summaryrefslogtreecommitdiffstats
path: root/git_command.py
diff options
context:
space:
mode:
Diffstat (limited to 'git_command.py')
-rw-r--r--git_command.py116
1 files changed, 15 insertions, 101 deletions
diff --git a/git_command.py b/git_command.py
index d06fc77c..95db91f2 100644
--- a/git_command.py
+++ b/git_command.py
@@ -12,12 +12,10 @@
12# See the License for the specific language governing permissions and 12# See the License for the specific language governing permissions and
13# limitations under the License. 13# limitations under the License.
14 14
15import functools
15import os 16import os
16import re
17import sys 17import sys
18import subprocess 18import subprocess
19import tempfile
20from signal import SIGTERM
21 19
22from error import GitError 20from error import GitError
23from git_refs import HEAD 21from git_refs import HEAD
@@ -42,101 +40,15 @@ GIT_DIR = 'GIT_DIR'
42LAST_GITDIR = None 40LAST_GITDIR = None
43LAST_CWD = None 41LAST_CWD = None
44 42
45_ssh_proxy_path = None
46_ssh_sock_path = None
47_ssh_clients = []
48_ssh_version = None
49
50
51def _run_ssh_version():
52 """run ssh -V to display the version number"""
53 return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
54
55
56def _parse_ssh_version(ver_str=None):
57 """parse a ssh version string into a tuple"""
58 if ver_str is None:
59 ver_str = _run_ssh_version()
60 m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
61 if m:
62 return tuple(int(x) for x in m.group(1).split('.'))
63 else:
64 return ()
65
66
67def ssh_version():
68 """return ssh version as a tuple"""
69 global _ssh_version
70 if _ssh_version is None:
71 try:
72 _ssh_version = _parse_ssh_version()
73 except subprocess.CalledProcessError:
74 print('fatal: unable to detect ssh version', file=sys.stderr)
75 sys.exit(1)
76 return _ssh_version
77
78
79def ssh_sock(create=True):
80 global _ssh_sock_path
81 if _ssh_sock_path is None:
82 if not create:
83 return None
84 tmp_dir = '/tmp'
85 if not os.path.exists(tmp_dir):
86 tmp_dir = tempfile.gettempdir()
87 if ssh_version() < (6, 7):
88 tokens = '%r@%h:%p'
89 else:
90 tokens = '%C' # hash of %l%h%p%r
91 _ssh_sock_path = os.path.join(
92 tempfile.mkdtemp('', 'ssh-', tmp_dir),
93 'master-' + tokens)
94 return _ssh_sock_path
95
96
97def _ssh_proxy():
98 global _ssh_proxy_path
99 if _ssh_proxy_path is None:
100 _ssh_proxy_path = os.path.join(
101 os.path.dirname(__file__),
102 'git_ssh')
103 return _ssh_proxy_path
104
105
106def _add_ssh_client(p):
107 _ssh_clients.append(p)
108
109
110def _remove_ssh_client(p):
111 try:
112 _ssh_clients.remove(p)
113 except ValueError:
114 pass
115
116
117def terminate_ssh_clients():
118 global _ssh_clients
119 for p in _ssh_clients:
120 try:
121 os.kill(p.pid, SIGTERM)
122 p.wait()
123 except OSError:
124 pass
125 _ssh_clients = []
126
127
128_git_version = None
129
130 43
131class _GitCall(object): 44class _GitCall(object):
45 @functools.lru_cache(maxsize=None)
132 def version_tuple(self): 46 def version_tuple(self):
133 global _git_version 47 ret = Wrapper().ParseGitVersion()
134 if _git_version is None: 48 if ret is None:
135 _git_version = Wrapper().ParseGitVersion() 49 print('fatal: unable to detect git version', file=sys.stderr)
136 if _git_version is None: 50 sys.exit(1)
137 print('fatal: unable to detect git version', file=sys.stderr) 51 return ret
138 sys.exit(1)
139 return _git_version
140 52
141 def __getattr__(self, name): 53 def __getattr__(self, name):
142 name = name.replace('_', '-') 54 name = name.replace('_', '-')
@@ -163,7 +75,8 @@ def RepoSourceVersion():
163 proj = os.path.dirname(os.path.abspath(__file__)) 75 proj = os.path.dirname(os.path.abspath(__file__))
164 env[GIT_DIR] = os.path.join(proj, '.git') 76 env[GIT_DIR] = os.path.join(proj, '.git')
165 result = subprocess.run([GIT, 'describe', HEAD], stdout=subprocess.PIPE, 77 result = subprocess.run([GIT, 'describe', HEAD], stdout=subprocess.PIPE,
166 encoding='utf-8', env=env, check=False) 78 stderr=subprocess.DEVNULL, encoding='utf-8',
79 env=env, check=False)
167 if result.returncode == 0: 80 if result.returncode == 0:
168 ver = result.stdout.strip() 81 ver = result.stdout.strip()
169 if ver.startswith('v'): 82 if ver.startswith('v'):
@@ -254,7 +167,7 @@ class GitCommand(object):
254 capture_stderr=False, 167 capture_stderr=False,
255 merge_output=False, 168 merge_output=False,
256 disable_editor=False, 169 disable_editor=False,
257 ssh_proxy=False, 170 ssh_proxy=None,
258 cwd=None, 171 cwd=None,
259 gitdir=None): 172 gitdir=None):
260 env = self._GetBasicEnv() 173 env = self._GetBasicEnv()
@@ -262,8 +175,8 @@ class GitCommand(object):
262 if disable_editor: 175 if disable_editor:
263 env['GIT_EDITOR'] = ':' 176 env['GIT_EDITOR'] = ':'
264 if ssh_proxy: 177 if ssh_proxy:
265 env['REPO_SSH_SOCK'] = ssh_sock() 178 env['REPO_SSH_SOCK'] = ssh_proxy.sock()
266 env['GIT_SSH'] = _ssh_proxy() 179 env['GIT_SSH'] = ssh_proxy.proxy
267 env['GIT_SSH_VARIANT'] = 'ssh' 180 env['GIT_SSH_VARIANT'] = 'ssh'
268 if 'http_proxy' in env and 'darwin' == sys.platform: 181 if 'http_proxy' in env and 'darwin' == sys.platform:
269 s = "'http.proxy=%s'" % (env['http_proxy'],) 182 s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -346,7 +259,7 @@ class GitCommand(object):
346 raise GitError('%s: %s' % (command[1], e)) 259 raise GitError('%s: %s' % (command[1], e))
347 260
348 if ssh_proxy: 261 if ssh_proxy:
349 _add_ssh_client(p) 262 ssh_proxy.add_client(p)
350 263
351 self.process = p 264 self.process = p
352 if input: 265 if input:
@@ -358,7 +271,8 @@ class GitCommand(object):
358 try: 271 try:
359 self.stdout, self.stderr = p.communicate() 272 self.stdout, self.stderr = p.communicate()
360 finally: 273 finally:
361 _remove_ssh_client(p) 274 if ssh_proxy:
275 ssh_proxy.remove_client(p)
362 self.rc = p.wait() 276 self.rc = p.wait()
363 277
364 @staticmethod 278 @staticmethod