summaryrefslogtreecommitdiffstats
path: root/ssh.py
diff options
context:
space:
mode:
Diffstat (limited to 'ssh.py')
-rw-r--r--ssh.py277
1 files changed, 277 insertions, 0 deletions
diff --git a/ssh.py b/ssh.py
new file mode 100644
index 00000000..0ae8d120
--- /dev/null
+++ b/ssh.py
@@ -0,0 +1,277 @@
1# Copyright (C) 2008 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Common SSH management logic."""
16
17import functools
18import multiprocessing
19import os
20import re
21import signal
22import subprocess
23import sys
24import tempfile
25import time
26
27import platform_utils
28from repo_trace import Trace
29
30
31PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh')
32
33
34def _run_ssh_version():
35 """run ssh -V to display the version number"""
36 return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
37
38
39def _parse_ssh_version(ver_str=None):
40 """parse a ssh version string into a tuple"""
41 if ver_str is None:
42 ver_str = _run_ssh_version()
43 m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
44 if m:
45 return tuple(int(x) for x in m.group(1).split('.'))
46 else:
47 return ()
48
49
50@functools.lru_cache(maxsize=None)
51def version():
52 """return ssh version as a tuple"""
53 try:
54 return _parse_ssh_version()
55 except subprocess.CalledProcessError:
56 print('fatal: unable to detect ssh version', file=sys.stderr)
57 sys.exit(1)
58
59
60URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
61URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
62
63
64class ProxyManager:
65 """Manage various ssh clients & masters that we spawn.
66
67 This will take care of sharing state between multiprocessing children, and
68 make sure that if we crash, we don't leak any of the ssh sessions.
69
70 The code should work with a single-process scenario too, and not add too much
71 overhead due to the manager.
72 """
73
74 # Path to the ssh program to run which will pass our master settings along.
75 # Set here more as a convenience API.
76 proxy = PROXY_PATH
77
78 def __init__(self, manager):
79 # Protect access to the list of active masters.
80 self._lock = multiprocessing.Lock()
81 # List of active masters (pid). These will be spawned on demand, and we are
82 # responsible for shutting them all down at the end.
83 self._masters = manager.list()
84 # Set of active masters indexed by "host:port" information.
85 # The value isn't used, but multiprocessing doesn't provide a set class.
86 self._master_keys = manager.dict()
87 # Whether ssh masters are known to be broken, so we give up entirely.
88 self._master_broken = manager.Value('b', False)
89 # List of active ssh sesssions. Clients will be added & removed as
90 # connections finish, so this list is just for safety & cleanup if we crash.
91 self._clients = manager.list()
92 # Path to directory for holding master sockets.
93 self._sock_path = None
94
95 def __enter__(self):
96 """Enter a new context."""
97 return self
98
99 def __exit__(self, exc_type, exc_value, traceback):
100 """Exit a context & clean up all resources."""
101 self.close()
102
103 def add_client(self, proc):
104 """Track a new ssh session."""
105 self._clients.append(proc.pid)
106
107 def remove_client(self, proc):
108 """Remove a completed ssh session."""
109 try:
110 self._clients.remove(proc.pid)
111 except ValueError:
112 pass
113
114 def add_master(self, proc):
115 """Track a new master connection."""
116 self._masters.append(proc.pid)
117
118 def _terminate(self, procs):
119 """Kill all |procs|."""
120 for pid in procs:
121 try:
122 os.kill(pid, signal.SIGTERM)
123 os.waitpid(pid, 0)
124 except OSError:
125 pass
126
127 # The multiprocessing.list() API doesn't provide many standard list()
128 # methods, so we have to manually clear the list.
129 while True:
130 try:
131 procs.pop(0)
132 except:
133 break
134
135 def close(self):
136 """Close this active ssh session.
137
138 Kill all ssh clients & masters we created, and nuke the socket dir.
139 """
140 self._terminate(self._clients)
141 self._terminate(self._masters)
142
143 d = self.sock(create=False)
144 if d:
145 try:
146 platform_utils.rmdir(os.path.dirname(d))
147 except OSError:
148 pass
149
150 def _open_unlocked(self, host, port=None):
151 """Make sure a ssh master session exists for |host| & |port|.
152
153 If one doesn't exist already, we'll create it.
154
155 We won't grab any locks, so the caller has to do that. This helps keep the
156 business logic of actually creating the master separate from grabbing locks.
157 """
158 # Check to see whether we already think that the master is running; if we
159 # think it's already running, return right away.
160 if port is not None:
161 key = '%s:%s' % (host, port)
162 else:
163 key = host
164
165 if key in self._master_keys:
166 return True
167
168 if self._master_broken.value or 'GIT_SSH' in os.environ:
169 # Failed earlier, so don't retry.
170 return False
171
172 # We will make two calls to ssh; this is the common part of both calls.
173 command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host]
174 if port is not None:
175 command_base[1:1] = ['-p', str(port)]
176
177 # Since the key wasn't in _master_keys, we think that master isn't running.
178 # ...but before actually starting a master, we'll double-check. This can
179 # be important because we can't tell that that 'git@myhost.com' is the same
180 # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
181 check_command = command_base + ['-O', 'check']
182 try:
183 Trace(': %s', ' '.join(check_command))
184 check_process = subprocess.Popen(check_command,
185 stdout=subprocess.PIPE,
186 stderr=subprocess.PIPE)
187 check_process.communicate() # read output, but ignore it...
188 isnt_running = check_process.wait()
189
190 if not isnt_running:
191 # Our double-check found that the master _was_ infact running. Add to
192 # the list of keys.
193 self._master_keys[key] = True
194 return True
195 except Exception:
196 # Ignore excpetions. We we will fall back to the normal command and print
197 # to the log there.
198 pass
199
200 command = command_base[:1] + ['-M', '-N'] + command_base[1:]
201 try:
202 Trace(': %s', ' '.join(command))
203 p = subprocess.Popen(command)
204 except Exception as e:
205 self._master_broken.value = True
206 print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
207 % (host, port, str(e)), file=sys.stderr)
208 return False
209
210 time.sleep(1)
211 ssh_died = (p.poll() is not None)
212 if ssh_died:
213 return False
214
215 self.add_master(p)
216 self._master_keys[key] = True
217 return True
218
219 def _open(self, host, port=None):
220 """Make sure a ssh master session exists for |host| & |port|.
221
222 If one doesn't exist already, we'll create it.
223
224 This will obtain any necessary locks to avoid inter-process races.
225 """
226 # Bail before grabbing the lock if we already know that we aren't going to
227 # try creating new masters below.
228 if sys.platform in ('win32', 'cygwin'):
229 return False
230
231 # Acquire the lock. This is needed to prevent opening multiple masters for
232 # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
233 # manifest <remote fetch="ssh://xyz"> specifies a different host from the
234 # one that was passed to repo init.
235 with self._lock:
236 return self._open_unlocked(host, port)
237
238 def preconnect(self, url):
239 """If |uri| will create a ssh connection, setup the ssh master for it."""
240 m = URI_ALL.match(url)
241 if m:
242 scheme = m.group(1)
243 host = m.group(2)
244 if ':' in host:
245 host, port = host.split(':')
246 else:
247 port = None
248 if scheme in ('ssh', 'git+ssh', 'ssh+git'):
249 return self._open(host, port)
250 return False
251
252 m = URI_SCP.match(url)
253 if m:
254 host = m.group(1)
255 return self._open(host)
256
257 return False
258
259 def sock(self, create=True):
260 """Return the path to the ssh socket dir.
261
262 This has all the master sockets so clients can talk to them.
263 """
264 if self._sock_path is None:
265 if not create:
266 return None
267 tmp_dir = '/tmp'
268 if not os.path.exists(tmp_dir):
269 tmp_dir = tempfile.gettempdir()
270 if version() < (6, 7):
271 tokens = '%r@%h:%p'
272 else:
273 tokens = '%C' # hash of %l%h%p%r
274 self._sock_path = os.path.join(
275 tempfile.mkdtemp('', 'ssh-', tmp_dir),
276 'master-' + tokens)
277 return self._sock_path