diff options
Diffstat (limited to 'ssh.py')
-rw-r--r-- | ssh.py | 277 |
1 files changed, 277 insertions, 0 deletions
@@ -0,0 +1,277 @@ | |||
1 | # Copyright (C) 2008 The Android Open Source Project | ||
2 | # | ||
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | # you may not use this file except in compliance with the License. | ||
5 | # You may obtain a copy of the License at | ||
6 | # | ||
7 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | # | ||
9 | # Unless required by applicable law or agreed to in writing, software | ||
10 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | # See the License for the specific language governing permissions and | ||
13 | # limitations under the License. | ||
14 | |||
15 | """Common SSH management logic.""" | ||
16 | |||
17 | import functools | ||
18 | import multiprocessing | ||
19 | import os | ||
20 | import re | ||
21 | import signal | ||
22 | import subprocess | ||
23 | import sys | ||
24 | import tempfile | ||
25 | import time | ||
26 | |||
27 | import platform_utils | ||
28 | from repo_trace import Trace | ||
29 | |||
30 | |||
31 | PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh') | ||
32 | |||
33 | |||
34 | def _run_ssh_version(): | ||
35 | """run ssh -V to display the version number""" | ||
36 | return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() | ||
37 | |||
38 | |||
39 | def _parse_ssh_version(ver_str=None): | ||
40 | """parse a ssh version string into a tuple""" | ||
41 | if ver_str is None: | ||
42 | ver_str = _run_ssh_version() | ||
43 | m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) | ||
44 | if m: | ||
45 | return tuple(int(x) for x in m.group(1).split('.')) | ||
46 | else: | ||
47 | return () | ||
48 | |||
49 | |||
50 | @functools.lru_cache(maxsize=None) | ||
51 | def version(): | ||
52 | """return ssh version as a tuple""" | ||
53 | try: | ||
54 | return _parse_ssh_version() | ||
55 | except subprocess.CalledProcessError: | ||
56 | print('fatal: unable to detect ssh version', file=sys.stderr) | ||
57 | sys.exit(1) | ||
58 | |||
59 | |||
60 | URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') | ||
61 | URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') | ||
62 | |||
63 | |||
64 | class ProxyManager: | ||
65 | """Manage various ssh clients & masters that we spawn. | ||
66 | |||
67 | This will take care of sharing state between multiprocessing children, and | ||
68 | make sure that if we crash, we don't leak any of the ssh sessions. | ||
69 | |||
70 | The code should work with a single-process scenario too, and not add too much | ||
71 | overhead due to the manager. | ||
72 | """ | ||
73 | |||
74 | # Path to the ssh program to run which will pass our master settings along. | ||
75 | # Set here more as a convenience API. | ||
76 | proxy = PROXY_PATH | ||
77 | |||
78 | def __init__(self, manager): | ||
79 | # Protect access to the list of active masters. | ||
80 | self._lock = multiprocessing.Lock() | ||
81 | # List of active masters (pid). These will be spawned on demand, and we are | ||
82 | # responsible for shutting them all down at the end. | ||
83 | self._masters = manager.list() | ||
84 | # Set of active masters indexed by "host:port" information. | ||
85 | # The value isn't used, but multiprocessing doesn't provide a set class. | ||
86 | self._master_keys = manager.dict() | ||
87 | # Whether ssh masters are known to be broken, so we give up entirely. | ||
88 | self._master_broken = manager.Value('b', False) | ||
89 | # List of active ssh sesssions. Clients will be added & removed as | ||
90 | # connections finish, so this list is just for safety & cleanup if we crash. | ||
91 | self._clients = manager.list() | ||
92 | # Path to directory for holding master sockets. | ||
93 | self._sock_path = None | ||
94 | |||
95 | def __enter__(self): | ||
96 | """Enter a new context.""" | ||
97 | return self | ||
98 | |||
99 | def __exit__(self, exc_type, exc_value, traceback): | ||
100 | """Exit a context & clean up all resources.""" | ||
101 | self.close() | ||
102 | |||
103 | def add_client(self, proc): | ||
104 | """Track a new ssh session.""" | ||
105 | self._clients.append(proc.pid) | ||
106 | |||
107 | def remove_client(self, proc): | ||
108 | """Remove a completed ssh session.""" | ||
109 | try: | ||
110 | self._clients.remove(proc.pid) | ||
111 | except ValueError: | ||
112 | pass | ||
113 | |||
114 | def add_master(self, proc): | ||
115 | """Track a new master connection.""" | ||
116 | self._masters.append(proc.pid) | ||
117 | |||
118 | def _terminate(self, procs): | ||
119 | """Kill all |procs|.""" | ||
120 | for pid in procs: | ||
121 | try: | ||
122 | os.kill(pid, signal.SIGTERM) | ||
123 | os.waitpid(pid, 0) | ||
124 | except OSError: | ||
125 | pass | ||
126 | |||
127 | # The multiprocessing.list() API doesn't provide many standard list() | ||
128 | # methods, so we have to manually clear the list. | ||
129 | while True: | ||
130 | try: | ||
131 | procs.pop(0) | ||
132 | except: | ||
133 | break | ||
134 | |||
135 | def close(self): | ||
136 | """Close this active ssh session. | ||
137 | |||
138 | Kill all ssh clients & masters we created, and nuke the socket dir. | ||
139 | """ | ||
140 | self._terminate(self._clients) | ||
141 | self._terminate(self._masters) | ||
142 | |||
143 | d = self.sock(create=False) | ||
144 | if d: | ||
145 | try: | ||
146 | platform_utils.rmdir(os.path.dirname(d)) | ||
147 | except OSError: | ||
148 | pass | ||
149 | |||
150 | def _open_unlocked(self, host, port=None): | ||
151 | """Make sure a ssh master session exists for |host| & |port|. | ||
152 | |||
153 | If one doesn't exist already, we'll create it. | ||
154 | |||
155 | We won't grab any locks, so the caller has to do that. This helps keep the | ||
156 | business logic of actually creating the master separate from grabbing locks. | ||
157 | """ | ||
158 | # Check to see whether we already think that the master is running; if we | ||
159 | # think it's already running, return right away. | ||
160 | if port is not None: | ||
161 | key = '%s:%s' % (host, port) | ||
162 | else: | ||
163 | key = host | ||
164 | |||
165 | if key in self._master_keys: | ||
166 | return True | ||
167 | |||
168 | if self._master_broken.value or 'GIT_SSH' in os.environ: | ||
169 | # Failed earlier, so don't retry. | ||
170 | return False | ||
171 | |||
172 | # We will make two calls to ssh; this is the common part of both calls. | ||
173 | command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host] | ||
174 | if port is not None: | ||
175 | command_base[1:1] = ['-p', str(port)] | ||
176 | |||
177 | # Since the key wasn't in _master_keys, we think that master isn't running. | ||
178 | # ...but before actually starting a master, we'll double-check. This can | ||
179 | # be important because we can't tell that that 'git@myhost.com' is the same | ||
180 | # as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file. | ||
181 | check_command = command_base + ['-O', 'check'] | ||
182 | try: | ||
183 | Trace(': %s', ' '.join(check_command)) | ||
184 | check_process = subprocess.Popen(check_command, | ||
185 | stdout=subprocess.PIPE, | ||
186 | stderr=subprocess.PIPE) | ||
187 | check_process.communicate() # read output, but ignore it... | ||
188 | isnt_running = check_process.wait() | ||
189 | |||
190 | if not isnt_running: | ||
191 | # Our double-check found that the master _was_ infact running. Add to | ||
192 | # the list of keys. | ||
193 | self._master_keys[key] = True | ||
194 | return True | ||
195 | except Exception: | ||
196 | # Ignore excpetions. We we will fall back to the normal command and print | ||
197 | # to the log there. | ||
198 | pass | ||
199 | |||
200 | command = command_base[:1] + ['-M', '-N'] + command_base[1:] | ||
201 | try: | ||
202 | Trace(': %s', ' '.join(command)) | ||
203 | p = subprocess.Popen(command) | ||
204 | except Exception as e: | ||
205 | self._master_broken.value = True | ||
206 | print('\nwarn: cannot enable ssh control master for %s:%s\n%s' | ||
207 | % (host, port, str(e)), file=sys.stderr) | ||
208 | return False | ||
209 | |||
210 | time.sleep(1) | ||
211 | ssh_died = (p.poll() is not None) | ||
212 | if ssh_died: | ||
213 | return False | ||
214 | |||
215 | self.add_master(p) | ||
216 | self._master_keys[key] = True | ||
217 | return True | ||
218 | |||
219 | def _open(self, host, port=None): | ||
220 | """Make sure a ssh master session exists for |host| & |port|. | ||
221 | |||
222 | If one doesn't exist already, we'll create it. | ||
223 | |||
224 | This will obtain any necessary locks to avoid inter-process races. | ||
225 | """ | ||
226 | # Bail before grabbing the lock if we already know that we aren't going to | ||
227 | # try creating new masters below. | ||
228 | if sys.platform in ('win32', 'cygwin'): | ||
229 | return False | ||
230 | |||
231 | # Acquire the lock. This is needed to prevent opening multiple masters for | ||
232 | # the same host when we're running "repo sync -jN" (for N > 1) _and_ the | ||
233 | # manifest <remote fetch="ssh://xyz"> specifies a different host from the | ||
234 | # one that was passed to repo init. | ||
235 | with self._lock: | ||
236 | return self._open_unlocked(host, port) | ||
237 | |||
238 | def preconnect(self, url): | ||
239 | """If |uri| will create a ssh connection, setup the ssh master for it.""" | ||
240 | m = URI_ALL.match(url) | ||
241 | if m: | ||
242 | scheme = m.group(1) | ||
243 | host = m.group(2) | ||
244 | if ':' in host: | ||
245 | host, port = host.split(':') | ||
246 | else: | ||
247 | port = None | ||
248 | if scheme in ('ssh', 'git+ssh', 'ssh+git'): | ||
249 | return self._open(host, port) | ||
250 | return False | ||
251 | |||
252 | m = URI_SCP.match(url) | ||
253 | if m: | ||
254 | host = m.group(1) | ||
255 | return self._open(host) | ||
256 | |||
257 | return False | ||
258 | |||
259 | def sock(self, create=True): | ||
260 | """Return the path to the ssh socket dir. | ||
261 | |||
262 | This has all the master sockets so clients can talk to them. | ||
263 | """ | ||
264 | if self._sock_path is None: | ||
265 | if not create: | ||
266 | return None | ||
267 | tmp_dir = '/tmp' | ||
268 | if not os.path.exists(tmp_dir): | ||
269 | tmp_dir = tempfile.gettempdir() | ||
270 | if version() < (6, 7): | ||
271 | tokens = '%r@%h:%p' | ||
272 | else: | ||
273 | tokens = '%C' # hash of %l%h%p%r | ||
274 | self._sock_path = os.path.join( | ||
275 | tempfile.mkdtemp('', 'ssh-', tmp_dir), | ||
276 | 'master-' + tokens) | ||
277 | return self._sock_path | ||