diff options
Diffstat (limited to 'git_command.py')
-rw-r--r-- | git_command.py | 116 |
1 files changed, 15 insertions, 101 deletions
diff --git a/git_command.py b/git_command.py index d06fc77c..95db91f2 100644 --- a/git_command.py +++ b/git_command.py | |||
@@ -12,12 +12,10 @@ | |||
12 | # See the License for the specific language governing permissions and | 12 | # See the License for the specific language governing permissions and |
13 | # limitations under the License. | 13 | # limitations under the License. |
14 | 14 | ||
15 | import functools | ||
15 | import os | 16 | import os |
16 | import re | ||
17 | import sys | 17 | import sys |
18 | import subprocess | 18 | import subprocess |
19 | import tempfile | ||
20 | from signal import SIGTERM | ||
21 | 19 | ||
22 | from error import GitError | 20 | from error import GitError |
23 | from git_refs import HEAD | 21 | from git_refs import HEAD |
@@ -42,101 +40,15 @@ GIT_DIR = 'GIT_DIR' | |||
42 | LAST_GITDIR = None | 40 | LAST_GITDIR = None |
43 | LAST_CWD = None | 41 | LAST_CWD = None |
44 | 42 | ||
45 | _ssh_proxy_path = None | ||
46 | _ssh_sock_path = None | ||
47 | _ssh_clients = [] | ||
48 | _ssh_version = None | ||
49 | |||
50 | |||
51 | def _run_ssh_version(): | ||
52 | """run ssh -V to display the version number""" | ||
53 | return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() | ||
54 | |||
55 | |||
56 | def _parse_ssh_version(ver_str=None): | ||
57 | """parse a ssh version string into a tuple""" | ||
58 | if ver_str is None: | ||
59 | ver_str = _run_ssh_version() | ||
60 | m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) | ||
61 | if m: | ||
62 | return tuple(int(x) for x in m.group(1).split('.')) | ||
63 | else: | ||
64 | return () | ||
65 | |||
66 | |||
67 | def ssh_version(): | ||
68 | """return ssh version as a tuple""" | ||
69 | global _ssh_version | ||
70 | if _ssh_version is None: | ||
71 | try: | ||
72 | _ssh_version = _parse_ssh_version() | ||
73 | except subprocess.CalledProcessError: | ||
74 | print('fatal: unable to detect ssh version', file=sys.stderr) | ||
75 | sys.exit(1) | ||
76 | return _ssh_version | ||
77 | |||
78 | |||
79 | def ssh_sock(create=True): | ||
80 | global _ssh_sock_path | ||
81 | if _ssh_sock_path is None: | ||
82 | if not create: | ||
83 | return None | ||
84 | tmp_dir = '/tmp' | ||
85 | if not os.path.exists(tmp_dir): | ||
86 | tmp_dir = tempfile.gettempdir() | ||
87 | if ssh_version() < (6, 7): | ||
88 | tokens = '%r@%h:%p' | ||
89 | else: | ||
90 | tokens = '%C' # hash of %l%h%p%r | ||
91 | _ssh_sock_path = os.path.join( | ||
92 | tempfile.mkdtemp('', 'ssh-', tmp_dir), | ||
93 | 'master-' + tokens) | ||
94 | return _ssh_sock_path | ||
95 | |||
96 | |||
97 | def _ssh_proxy(): | ||
98 | global _ssh_proxy_path | ||
99 | if _ssh_proxy_path is None: | ||
100 | _ssh_proxy_path = os.path.join( | ||
101 | os.path.dirname(__file__), | ||
102 | 'git_ssh') | ||
103 | return _ssh_proxy_path | ||
104 | |||
105 | |||
106 | def _add_ssh_client(p): | ||
107 | _ssh_clients.append(p) | ||
108 | |||
109 | |||
110 | def _remove_ssh_client(p): | ||
111 | try: | ||
112 | _ssh_clients.remove(p) | ||
113 | except ValueError: | ||
114 | pass | ||
115 | |||
116 | |||
117 | def terminate_ssh_clients(): | ||
118 | global _ssh_clients | ||
119 | for p in _ssh_clients: | ||
120 | try: | ||
121 | os.kill(p.pid, SIGTERM) | ||
122 | p.wait() | ||
123 | except OSError: | ||
124 | pass | ||
125 | _ssh_clients = [] | ||
126 | |||
127 | |||
128 | _git_version = None | ||
129 | |||
130 | 43 | ||
131 | class _GitCall(object): | 44 | class _GitCall(object): |
45 | @functools.lru_cache(maxsize=None) | ||
132 | def version_tuple(self): | 46 | def version_tuple(self): |
133 | global _git_version | 47 | ret = Wrapper().ParseGitVersion() |
134 | if _git_version is None: | 48 | if ret is None: |
135 | _git_version = Wrapper().ParseGitVersion() | 49 | print('fatal: unable to detect git version', file=sys.stderr) |
136 | if _git_version is None: | 50 | sys.exit(1) |
137 | print('fatal: unable to detect git version', file=sys.stderr) | 51 | return ret |
138 | sys.exit(1) | ||
139 | return _git_version | ||
140 | 52 | ||
141 | def __getattr__(self, name): | 53 | def __getattr__(self, name): |
142 | name = name.replace('_', '-') | 54 | name = name.replace('_', '-') |
@@ -163,7 +75,8 @@ def RepoSourceVersion(): | |||
163 | proj = os.path.dirname(os.path.abspath(__file__)) | 75 | proj = os.path.dirname(os.path.abspath(__file__)) |
164 | env[GIT_DIR] = os.path.join(proj, '.git') | 76 | env[GIT_DIR] = os.path.join(proj, '.git') |
165 | result = subprocess.run([GIT, 'describe', HEAD], stdout=subprocess.PIPE, | 77 | result = subprocess.run([GIT, 'describe', HEAD], stdout=subprocess.PIPE, |
166 | encoding='utf-8', env=env, check=False) | 78 | stderr=subprocess.DEVNULL, encoding='utf-8', |
79 | env=env, check=False) | ||
167 | if result.returncode == 0: | 80 | if result.returncode == 0: |
168 | ver = result.stdout.strip() | 81 | ver = result.stdout.strip() |
169 | if ver.startswith('v'): | 82 | if ver.startswith('v'): |
@@ -254,7 +167,7 @@ class GitCommand(object): | |||
254 | capture_stderr=False, | 167 | capture_stderr=False, |
255 | merge_output=False, | 168 | merge_output=False, |
256 | disable_editor=False, | 169 | disable_editor=False, |
257 | ssh_proxy=False, | 170 | ssh_proxy=None, |
258 | cwd=None, | 171 | cwd=None, |
259 | gitdir=None): | 172 | gitdir=None): |
260 | env = self._GetBasicEnv() | 173 | env = self._GetBasicEnv() |
@@ -262,8 +175,8 @@ class GitCommand(object): | |||
262 | if disable_editor: | 175 | if disable_editor: |
263 | env['GIT_EDITOR'] = ':' | 176 | env['GIT_EDITOR'] = ':' |
264 | if ssh_proxy: | 177 | if ssh_proxy: |
265 | env['REPO_SSH_SOCK'] = ssh_sock() | 178 | env['REPO_SSH_SOCK'] = ssh_proxy.sock() |
266 | env['GIT_SSH'] = _ssh_proxy() | 179 | env['GIT_SSH'] = ssh_proxy.proxy |
267 | env['GIT_SSH_VARIANT'] = 'ssh' | 180 | env['GIT_SSH_VARIANT'] = 'ssh' |
268 | if 'http_proxy' in env and 'darwin' == sys.platform: | 181 | if 'http_proxy' in env and 'darwin' == sys.platform: |
269 | s = "'http.proxy=%s'" % (env['http_proxy'],) | 182 | s = "'http.proxy=%s'" % (env['http_proxy'],) |
@@ -346,7 +259,7 @@ class GitCommand(object): | |||
346 | raise GitError('%s: %s' % (command[1], e)) | 259 | raise GitError('%s: %s' % (command[1], e)) |
347 | 260 | ||
348 | if ssh_proxy: | 261 | if ssh_proxy: |
349 | _add_ssh_client(p) | 262 | ssh_proxy.add_client(p) |
350 | 263 | ||
351 | self.process = p | 264 | self.process = p |
352 | if input: | 265 | if input: |
@@ -358,7 +271,8 @@ class GitCommand(object): | |||
358 | try: | 271 | try: |
359 | self.stdout, self.stderr = p.communicate() | 272 | self.stdout, self.stderr = p.communicate() |
360 | finally: | 273 | finally: |
361 | _remove_ssh_client(p) | 274 | if ssh_proxy: |
275 | ssh_proxy.remove_client(p) | ||
362 | self.rc = p.wait() | 276 | self.rc = p.wait() |
363 | 277 | ||
364 | @staticmethod | 278 | @staticmethod |