From 5291eafa412117b80ebbf645fc51559dd0b2caaf Mon Sep 17 00:00:00 2001 From: Mike Frysinger Date: Wed, 5 May 2021 15:53:03 -0400 Subject: [PATCH] ssh: move all ssh logic to a common place We had ssh logic sprinkled between two git modules, and neither was quite the right home for it. This largely moves the logic as-is to its new home. We'll leave major refactoring to followup commits. Bug: https://crbug.com/gerrit/12389 Change-Id: I300a8f7dba74f2bd132232a5eb1e856a8490e0e9 Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305483 Reviewed-by: Chris Mcdonald Tested-by: Mike Frysinger --- git_command.py | 91 +------------- git_config.py | 156 +---------------------- main.py | 7 +- ssh.py | 257 ++++++++++++++++++++++++++++++++++++++ tests/test_git_command.py | 32 ----- tests/test_ssh.py | 52 ++++++++ 6 files changed, 320 insertions(+), 275 deletions(-) create mode 100644 ssh.py create mode 100644 tests/test_ssh.py diff --git a/git_command.py b/git_command.py index f8cb280..fabad0e 100644 --- a/git_command.py +++ b/git_command.py @@ -14,16 +14,14 @@ import functools import os -import re import sys import subprocess -import tempfile -from signal import SIGTERM from error import GitError from git_refs import HEAD import platform_utils from repo_trace import REPO_TRACE, IsTrace, Trace +import ssh from wrapper import Wrapper GIT = 'git' @@ -43,85 +41,6 @@ GIT_DIR = 'GIT_DIR' LAST_GITDIR = None LAST_CWD = None -_ssh_proxy_path = None -_ssh_sock_path = None -_ssh_clients = [] - - -def _run_ssh_version(): - """run ssh -V to display the version number""" - return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() - - -def _parse_ssh_version(ver_str=None): - """parse a ssh version string into a tuple""" - if ver_str is None: - ver_str = _run_ssh_version() - m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) - if m: - return tuple(int(x) for x in m.group(1).split('.')) - else: - return () - - -@functools.lru_cache(maxsize=None) -def 
ssh_version(): - """return ssh version as a tuple""" - try: - return _parse_ssh_version() - except subprocess.CalledProcessError: - print('fatal: unable to detect ssh version', file=sys.stderr) - sys.exit(1) - - -def ssh_sock(create=True): - global _ssh_sock_path - if _ssh_sock_path is None: - if not create: - return None - tmp_dir = '/tmp' - if not os.path.exists(tmp_dir): - tmp_dir = tempfile.gettempdir() - if ssh_version() < (6, 7): - tokens = '%r@%h:%p' - else: - tokens = '%C' # hash of %l%h%p%r - _ssh_sock_path = os.path.join( - tempfile.mkdtemp('', 'ssh-', tmp_dir), - 'master-' + tokens) - return _ssh_sock_path - - -def _ssh_proxy(): - global _ssh_proxy_path - if _ssh_proxy_path is None: - _ssh_proxy_path = os.path.join( - os.path.dirname(__file__), - 'git_ssh') - return _ssh_proxy_path - - -def _add_ssh_client(p): - _ssh_clients.append(p) - - -def _remove_ssh_client(p): - try: - _ssh_clients.remove(p) - except ValueError: - pass - - -def terminate_ssh_clients(): - global _ssh_clients - for p in _ssh_clients: - try: - os.kill(p.pid, SIGTERM) - p.wait() - except OSError: - pass - _ssh_clients = [] - class _GitCall(object): @functools.lru_cache(maxsize=None) @@ -256,8 +175,8 @@ class GitCommand(object): if disable_editor: env['GIT_EDITOR'] = ':' if ssh_proxy: - env['REPO_SSH_SOCK'] = ssh_sock() - env['GIT_SSH'] = _ssh_proxy() + env['REPO_SSH_SOCK'] = ssh.sock() + env['GIT_SSH'] = ssh.proxy() env['GIT_SSH_VARIANT'] = 'ssh' if 'http_proxy' in env and 'darwin' == sys.platform: s = "'http.proxy=%s'" % (env['http_proxy'],) @@ -340,7 +259,7 @@ class GitCommand(object): raise GitError('%s: %s' % (command[1], e)) if ssh_proxy: - _add_ssh_client(p) + ssh.add_client(p) self.process = p if input: @@ -352,7 +271,7 @@ class GitCommand(object): try: self.stdout, self.stderr = p.communicate() finally: - _remove_ssh_client(p) + ssh.remove_client(p) self.rc = p.wait() @staticmethod diff --git a/git_config.py b/git_config.py index fcd0446..1d8d136 100644 --- a/git_config.py +++ 
b/git_config.py @@ -18,25 +18,17 @@ from http.client import HTTPException import json import os import re -import signal import ssl import subprocess import sys -try: - import threading as _threading -except ImportError: - import dummy_threading as _threading -import time import urllib.error import urllib.request from error import GitError, UploadError import platform_utils from repo_trace import Trace - +import ssh from git_command import GitCommand -from git_command import ssh_sock -from git_command import terminate_ssh_clients from git_refs import R_CHANGES, R_HEADS, R_TAGS ID_RE = re.compile(r'^[0-9a-f]{40}$') @@ -440,129 +432,6 @@ class RefSpec(object): return s -_master_processes = [] -_master_keys = set() -_ssh_master = True -_master_keys_lock = None - - -def init_ssh(): - """Should be called once at the start of repo to init ssh master handling. - - At the moment, all we do is to create our lock. - """ - global _master_keys_lock - assert _master_keys_lock is None, "Should only call init_ssh once" - _master_keys_lock = _threading.Lock() - - -def _open_ssh(host, port=None): - global _ssh_master - - # Bail before grabbing the lock if we already know that we aren't going to - # try creating new masters below. - if sys.platform in ('win32', 'cygwin'): - return False - - # Acquire the lock. This is needed to prevent opening multiple masters for - # the same host when we're running "repo sync -jN" (for N > 1) _and_ the - # manifest specifies a different host from the - # one that was passed to repo init. - _master_keys_lock.acquire() - try: - - # Check to see whether we already think that the master is running; if we - # think it's already running, return right away. - if port is not None: - key = '%s:%s' % (host, port) - else: - key = host - - if key in _master_keys: - return True - - if not _ssh_master or 'GIT_SSH' in os.environ: - # Failed earlier, so don't retry. - return False - - # We will make two calls to ssh; this is the common part of both calls. 
- command_base = ['ssh', - '-o', 'ControlPath %s' % ssh_sock(), - host] - if port is not None: - command_base[1:1] = ['-p', str(port)] - - # Since the key wasn't in _master_keys, we think that master isn't running. - # ...but before actually starting a master, we'll double-check. This can - # be important because we can't tell that that 'git@myhost.com' is the same - # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file. - check_command = command_base + ['-O', 'check'] - try: - Trace(': %s', ' '.join(check_command)) - check_process = subprocess.Popen(check_command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - check_process.communicate() # read output, but ignore it... - isnt_running = check_process.wait() - - if not isnt_running: - # Our double-check found that the master _was_ infact running. Add to - # the list of keys. - _master_keys.add(key) - return True - except Exception: - # Ignore excpetions. We we will fall back to the normal command and print - # to the log there. - pass - - command = command_base[:1] + ['-M', '-N'] + command_base[1:] - try: - Trace(': %s', ' '.join(command)) - p = subprocess.Popen(command) - except Exception as e: - _ssh_master = False - print('\nwarn: cannot enable ssh control master for %s:%s\n%s' - % (host, port, str(e)), file=sys.stderr) - return False - - time.sleep(1) - ssh_died = (p.poll() is not None) - if ssh_died: - return False - - _master_processes.append(p) - _master_keys.add(key) - return True - finally: - _master_keys_lock.release() - - -def close_ssh(): - global _master_keys_lock - - terminate_ssh_clients() - - for p in _master_processes: - try: - os.kill(p.pid, signal.SIGTERM) - p.wait() - except OSError: - pass - del _master_processes[:] - _master_keys.clear() - - d = ssh_sock(create=False) - if d: - try: - platform_utils.rmdir(os.path.dirname(d)) - except OSError: - pass - - # We're done with the lock, so we can delete it. 
- _master_keys_lock = None - - -URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') @@ -614,27 +483,6 @@ def GetUrlCookieFile(url, quiet): yield cookiefile, None -def _preconnect(url): - m = URI_ALL.match(url) - if m: - scheme = m.group(1) - host = m.group(2) - if ':' in host: - host, port = host.split(':') - else: - port = None - if scheme in ('ssh', 'git+ssh', 'ssh+git'): - return _open_ssh(host, port) - return False - - m = URI_SCP.match(url) - if m: - host = m.group(1) - return _open_ssh(host) - - return False - - class Remote(object): """Configuration options related to a remote. """ @@ -673,7 +521,7 @@ class Remote(object): def PreConnectFetch(self): connectionUrl = self._InsteadOf() - return _preconnect(connectionUrl) + return ssh.preconnect(connectionUrl) def ReviewUrl(self, userEmail, validate_certs): if self._review_url is None: diff --git a/main.py b/main.py index 8aba2ec..9674433 100755 --- a/main.py +++ b/main.py @@ -39,7 +39,7 @@ from color import SetDefaultColoring import event_log from repo_trace import SetTrace from git_command import user_agent -from git_config import init_ssh, close_ssh, RepoConfig +from git_config import RepoConfig from git_trace2_event_log import EventLog from command import InteractiveCommand from command import MirrorSafeCommand @@ -56,6 +56,7 @@ from error import RepoChangedException import gitc_utils from manifest_xml import GitcClient, RepoClient from pager import RunPager, TerminatePager +import ssh from wrapper import WrapperPath, Wrapper from subcmds import all_commands @@ -592,7 +593,7 @@ def _Main(argv): repo = _Repo(opt.repodir) try: try: - init_ssh() + ssh.init() init_http() name, gopts, argv = repo._ParseArgs(argv) run = lambda: repo._Run(name, gopts, argv) or 0 @@ -604,7 +605,7 @@ def _Main(argv): else: result = run() finally: - close_ssh() + ssh.close() except KeyboardInterrupt: print('aborted by user', file=sys.stderr) result = 1 diff --git a/ssh.py 
b/ssh.py new file mode 100644 index 0000000..d06c4eb --- /dev/null +++ b/ssh.py @@ -0,0 +1,257 @@ +# Copyright (C) 2008 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common SSH management logic.""" + +import functools +import os +import re +import signal +import subprocess +import sys +import tempfile +try: + import threading as _threading +except ImportError: + import dummy_threading as _threading +import time + +import platform_utils +from repo_trace import Trace + + +_ssh_proxy_path = None +_ssh_sock_path = None +_ssh_clients = [] + + +def _run_ssh_version(): + """run ssh -V to display the version number""" + return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() + + +def _parse_ssh_version(ver_str=None): + """parse a ssh version string into a tuple""" + if ver_str is None: + ver_str = _run_ssh_version() + m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) + if m: + return tuple(int(x) for x in m.group(1).split('.')) + else: + return () + + +@functools.lru_cache(maxsize=None) +def version(): + """return ssh version as a tuple""" + try: + return _parse_ssh_version() + except subprocess.CalledProcessError: + print('fatal: unable to detect ssh version', file=sys.stderr) + sys.exit(1) + + +def proxy(): + global _ssh_proxy_path + if _ssh_proxy_path is None: + _ssh_proxy_path = os.path.join( + os.path.dirname(__file__), + 'git_ssh') + return _ssh_proxy_path + + +def add_client(p): + 
_ssh_clients.append(p) + + +def remove_client(p): + try: + _ssh_clients.remove(p) + except ValueError: + pass + + +def _terminate_clients(): + global _ssh_clients + for p in _ssh_clients: + try: + os.kill(p.pid, signal.SIGTERM) + p.wait() + except OSError: + pass + _ssh_clients = [] + + +_master_processes = [] +_master_keys = set() +_ssh_master = True +_master_keys_lock = None + + +def init(): + """Should be called once at the start of repo to init ssh master handling. + + At the moment, all we do is to create our lock. + """ + global _master_keys_lock + assert _master_keys_lock is None, "Should only call init once" + _master_keys_lock = _threading.Lock() + + +def _open_ssh(host, port=None): + global _ssh_master + + # Bail before grabbing the lock if we already know that we aren't going to + # try creating new masters below. + if sys.platform in ('win32', 'cygwin'): + return False + + # Acquire the lock. This is needed to prevent opening multiple masters for + # the same host when we're running "repo sync -jN" (for N > 1) _and_ the + # manifest specifies a different host from the + # one that was passed to repo init. + _master_keys_lock.acquire() + try: + + # Check to see whether we already think that the master is running; if we + # think it's already running, return right away. + if port is not None: + key = '%s:%s' % (host, port) + else: + key = host + + if key in _master_keys: + return True + + if not _ssh_master or 'GIT_SSH' in os.environ: + # Failed earlier, so don't retry. + return False + + # We will make two calls to ssh; this is the common part of both calls. + command_base = ['ssh', + '-o', 'ControlPath %s' % sock(), + host] + if port is not None: + command_base[1:1] = ['-p', str(port)] + + # Since the key wasn't in _master_keys, we think that master isn't running. + # ...but before actually starting a master, we'll double-check. 
This can + # be important because we can't tell that 'git@myhost.com' is the same + # as 'myhost.com' where "User git" is set up in the user's ~/.ssh/config file. + check_command = command_base + ['-O', 'check'] + try: + Trace(': %s', ' '.join(check_command)) + check_process = subprocess.Popen(check_command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + check_process.communicate() # read output, but ignore it... + isnt_running = check_process.wait() + + if not isnt_running: + # Our double-check found that the master _was_ in fact running. Add to + # the list of keys. + _master_keys.add(key) + return True + except Exception: + # Ignore exceptions. We will fall back to the normal command and print + # to the log there. + pass + + command = command_base[:1] + ['-M', '-N'] + command_base[1:] + try: + Trace(': %s', ' '.join(command)) + p = subprocess.Popen(command) + except Exception as e: + _ssh_master = False + print('\nwarn: cannot enable ssh control master for %s:%s\n%s' + % (host, port, str(e)), file=sys.stderr) + return False + + time.sleep(1) + ssh_died = (p.poll() is not None) + if ssh_died: + return False + + _master_processes.append(p) + _master_keys.add(key) + return True + finally: + _master_keys_lock.release() + + +def close(): + global _master_keys_lock + + _terminate_clients() + + for p in _master_processes: + try: + os.kill(p.pid, signal.SIGTERM) + p.wait() + except OSError: + pass + del _master_processes[:] + _master_keys.clear() + + d = sock(create=False) + if d: + try: + platform_utils.rmdir(os.path.dirname(d)) + except OSError: + pass + + # We're done with the lock, so we can delete it. 
+ _master_keys_lock = None + + +URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') +URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') + + +def preconnect(url): + m = URI_ALL.match(url) + if m: + scheme = m.group(1) + host = m.group(2) + if ':' in host: + host, port = host.split(':') + else: + port = None + if scheme in ('ssh', 'git+ssh', 'ssh+git'): + return _open_ssh(host, port) + return False + + m = URI_SCP.match(url) + if m: + host = m.group(1) + return _open_ssh(host) + + return False + +def sock(create=True): + global _ssh_sock_path + if _ssh_sock_path is None: + if not create: + return None + tmp_dir = '/tmp' + if not os.path.exists(tmp_dir): + tmp_dir = tempfile.gettempdir() + if version() < (6, 7): + tokens = '%r@%h:%p' + else: + tokens = '%C' # hash of %l%h%p%r + _ssh_sock_path = os.path.join( + tempfile.mkdtemp('', 'ssh-', tmp_dir), + 'master-' + tokens) + return _ssh_sock_path diff --git a/tests/test_git_command.py b/tests/test_git_command.py index 76c092f..93300a6 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py @@ -26,38 +26,6 @@ import git_command import wrapper -class SSHUnitTest(unittest.TestCase): - """Tests the ssh functions.""" - - def test_parse_ssh_version(self): - """Check parse_ssh_version() handling.""" - ver = git_command._parse_ssh_version('Unknown\n') - self.assertEqual(ver, ()) - ver = git_command._parse_ssh_version('OpenSSH_1.0\n') - self.assertEqual(ver, (1, 0)) - ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') - self.assertEqual(ver, (6, 6, 1)) - ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') - self.assertEqual(ver, (7, 6)) - - def test_ssh_version(self): - """Check ssh_version() handling.""" - with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'): - self.assertEqual(git_command.ssh_version(), (1, 2)) - - def test_ssh_sock(self): - """Check ssh_sock() function.""" - with 
mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): - # old ssh version uses port - with mock.patch('git_command.ssh_version', return_value=(6, 6)): - self.assertTrue(git_command.ssh_sock().endswith('%p')) - git_command._ssh_sock_path = None - # new ssh version uses hash - with mock.patch('git_command.ssh_version', return_value=(6, 7)): - self.assertTrue(git_command.ssh_sock().endswith('%C')) - git_command._ssh_sock_path = None - - class GitCallUnitTest(unittest.TestCase): """Tests the _GitCall class (via git_command.git).""" diff --git a/tests/test_ssh.py b/tests/test_ssh.py new file mode 100644 index 0000000..5a4f27e --- /dev/null +++ b/tests/test_ssh.py @@ -0,0 +1,52 @@ +# Copyright 2019 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unittests for the ssh.py module.""" + +import unittest +from unittest import mock + +import ssh + + +class SshTests(unittest.TestCase): + """Tests the ssh functions.""" + + def test_parse_ssh_version(self): + """Check _parse_ssh_version() handling.""" + ver = ssh._parse_ssh_version('Unknown\n') + self.assertEqual(ver, ()) + ver = ssh._parse_ssh_version('OpenSSH_1.0\n') + self.assertEqual(ver, (1, 0)) + ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') + self.assertEqual(ver, (6, 6, 1)) + ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') + self.assertEqual(ver, (7, 6)) + + def test_version(self): + """Check version() handling.""" + with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): + self.assertEqual(ssh.version(), (1, 2)) + + def test_ssh_sock(self): + """Check sock() function.""" + with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): + # old ssh version uses port + with mock.patch('ssh.version', return_value=(6, 6)): + self.assertTrue(ssh.sock().endswith('%p')) + ssh._ssh_sock_path = None + # new ssh version uses hash + with mock.patch('ssh.version', return_value=(6, 7)): + self.assertTrue(ssh.sock().endswith('%C')) + ssh._ssh_sock_path = None