summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/fixtures/test.gitconfig20
-rw-r--r--tests/test_editor.py4
-rw-r--r--tests/test_error.py53
-rw-r--r--tests/test_git_command.py54
-rw-r--r--tests/test_git_config.py134
-rw-r--r--tests/test_git_superproject.py376
-rw-r--r--tests/test_git_trace2_event_log.py329
-rw-r--r--tests/test_hooks.py55
-rw-r--r--tests/test_manifest_xml.py845
-rw-r--r--tests/test_platform_utils.py50
-rw-r--r--tests/test_project.py297
-rw-r--r--tests/test_ssh.py74
-rw-r--r--tests/test_subcmds.py73
-rw-r--r--tests/test_subcmds_init.py49
-rw-r--r--tests/test_wrapper.py511
15 files changed, 2850 insertions, 74 deletions
diff --git a/tests/fixtures/test.gitconfig b/tests/fixtures/test.gitconfig
index 3c573c9e..b178cf60 100644
--- a/tests/fixtures/test.gitconfig
+++ b/tests/fixtures/test.gitconfig
@@ -1,3 +1,23 @@
1[section] 1[section]
2 empty 2 empty
3 nonempty = true 3 nonempty = true
4 boolinvalid = oops
5 booltrue = true
6 boolfalse = false
7 intinvalid = oops
8 inthex = 0x10
9 inthexk = 0x10k
10 int = 10
11 intk = 10k
12 intm = 10m
13 intg = 10g
14[repo "syncstate.main"]
15 synctime = 2021-09-14T17:23:43.537338Z
16 version = 1
17[repo "syncstate.sys"]
18 argv = ['/usr/bin/pytest-3']
19[repo "syncstate.superproject"]
20 test = false
21[repo "syncstate.options"]
22 verbose = true
23 mpupdate = false
diff --git a/tests/test_editor.py b/tests/test_editor.py
index fbcfcdbd..cfd4f5ed 100644
--- a/tests/test_editor.py
+++ b/tests/test_editor.py
@@ -1,5 +1,3 @@
1# -*- coding:utf-8 -*-
2#
3# Copyright (C) 2019 The Android Open Source Project 1# Copyright (C) 2019 The Android Open Source Project
4# 2#
5# Licensed under the Apache License, Version 2.0 (the "License"); 3# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,8 +14,6 @@
16 14
17"""Unittests for the editor.py module.""" 15"""Unittests for the editor.py module."""
18 16
19from __future__ import print_function
20
21import unittest 17import unittest
22 18
23from editor import Editor 19from editor import Editor
diff --git a/tests/test_error.py b/tests/test_error.py
new file mode 100644
index 00000000..82b00c24
--- /dev/null
+++ b/tests/test_error.py
@@ -0,0 +1,53 @@
1# Copyright 2021 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the error.py module."""
16
17import inspect
18import pickle
19import unittest
20
21import error
22
23
24class PickleTests(unittest.TestCase):
25 """Make sure all our custom exceptions can be pickled."""
26
27 def getExceptions(self):
28 """Return all our custom exceptions."""
29 for name in dir(error):
30 cls = getattr(error, name)
31 if isinstance(cls, type) and issubclass(cls, Exception):
32 yield cls
33
34 def testExceptionLookup(self):
35 """Make sure our introspection logic works."""
36 classes = list(self.getExceptions())
37 self.assertIn(error.HookError, classes)
38 # Don't assert the exact number to avoid being a change-detector test.
39 self.assertGreater(len(classes), 10)
40
41 def testPickle(self):
42 """Try to pickle all the exceptions."""
43 for cls in self.getExceptions():
44 args = inspect.getfullargspec(cls.__init__).args[1:]
45 obj = cls(*args)
46 p = pickle.dumps(obj)
47 try:
48 newobj = pickle.loads(p)
49 except Exception as e: # pylint: disable=broad-except
50 self.fail('Class %s is unable to be pickled: %s\n'
51 'Incomplete super().__init__(...) call?' % (cls, e))
52 self.assertIsInstance(newobj, cls)
53 self.assertEqual(str(obj), str(newobj))
diff --git a/tests/test_git_command.py b/tests/test_git_command.py
index 51171a32..93300a6f 100644
--- a/tests/test_git_command.py
+++ b/tests/test_git_command.py
@@ -1,5 +1,3 @@
1# -*- coding:utf-8 -*-
2#
3# Copyright 2019 The Android Open Source Project 1# Copyright 2019 The Android Open Source Project
4# 2#
5# Licensed under the Apache License, Version 2.0 (the "License"); 3# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,12 +14,16 @@
16 14
17"""Unittests for the git_command.py module.""" 15"""Unittests for the git_command.py module."""
18 16
19from __future__ import print_function
20
21import re 17import re
22import unittest 18import unittest
23 19
20try:
21 from unittest import mock
22except ImportError:
23 import mock
24
24import git_command 25import git_command
26import wrapper
25 27
26 28
27class GitCallUnitTest(unittest.TestCase): 29class GitCallUnitTest(unittest.TestCase):
@@ -35,7 +37,7 @@ class GitCallUnitTest(unittest.TestCase):
35 # We don't dive too deep into the values here to avoid having to update 37 # We don't dive too deep into the values here to avoid having to update
36 # whenever git versions change. We do check relative to this min version 38 # whenever git versions change. We do check relative to this min version
37 # as this is what `repo` itself requires via MIN_GIT_VERSION. 39 # as this is what `repo` itself requires via MIN_GIT_VERSION.
38 MIN_GIT_VERSION = (1, 7, 2) 40 MIN_GIT_VERSION = (2, 10, 2)
39 self.assertTrue(isinstance(ver.major, int)) 41 self.assertTrue(isinstance(ver.major, int))
40 self.assertTrue(isinstance(ver.minor, int)) 42 self.assertTrue(isinstance(ver.minor, int))
41 self.assertTrue(isinstance(ver.micro, int)) 43 self.assertTrue(isinstance(ver.micro, int))
@@ -76,3 +78,45 @@ class UserAgentUnitTest(unittest.TestCase):
76 # the general form. 78 # the general form.
77 m = re.match(r'^git/[^ ]+ ([^ ]+) git-repo/[^ ]+', ua) 79 m = re.match(r'^git/[^ ]+ ([^ ]+) git-repo/[^ ]+', ua)
78 self.assertIsNotNone(m) 80 self.assertIsNotNone(m)
81
82
83class GitRequireTests(unittest.TestCase):
84 """Test the git_require helper."""
85
86 def setUp(self):
87 ver = wrapper.GitVersion(1, 2, 3, 4)
88 mock.patch.object(git_command.git, 'version_tuple', return_value=ver).start()
89
90 def tearDown(self):
91 mock.patch.stopall()
92
93 def test_older_nonfatal(self):
94 """Test non-fatal require calls with old versions."""
95 self.assertFalse(git_command.git_require((2,)))
96 self.assertFalse(git_command.git_require((1, 3)))
97 self.assertFalse(git_command.git_require((1, 2, 4)))
98 self.assertFalse(git_command.git_require((1, 2, 3, 5)))
99
100 def test_newer_nonfatal(self):
101 """Test non-fatal require calls with newer versions."""
102 self.assertTrue(git_command.git_require((0,)))
103 self.assertTrue(git_command.git_require((1, 0)))
104 self.assertTrue(git_command.git_require((1, 2, 0)))
105 self.assertTrue(git_command.git_require((1, 2, 3, 0)))
106
107 def test_equal_nonfatal(self):
108 """Test require calls with equal values."""
109 self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=False))
110 self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=True))
111
112 def test_older_fatal(self):
113 """Test fatal require calls with old versions."""
114 with self.assertRaises(SystemExit) as e:
115 git_command.git_require((2,), fail=True)
116 self.assertNotEqual(0, e.code)
117
118 def test_older_fatal_msg(self):
119 """Test fatal require calls with old versions and message."""
120 with self.assertRaises(SystemExit) as e:
121 git_command.git_require((2,), fail=True, msg='so sad')
122 self.assertNotEqual(0, e.code)
diff --git a/tests/test_git_config.py b/tests/test_git_config.py
index b735f27f..faf12a2e 100644
--- a/tests/test_git_config.py
+++ b/tests/test_git_config.py
@@ -1,5 +1,3 @@
1# -*- coding:utf-8 -*-
2#
3# Copyright (C) 2009 The Android Open Source Project 1# Copyright (C) 2009 The Android Open Source Project
4# 2#
5# Licensed under the Apache License, Version 2.0 (the "License"); 3# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,21 +14,22 @@
16 14
17"""Unittests for the git_config.py module.""" 15"""Unittests for the git_config.py module."""
18 16
19from __future__ import print_function
20
21import os 17import os
18import tempfile
22import unittest 19import unittest
23 20
24import git_config 21import git_config
25 22
23
26def fixture(*paths): 24def fixture(*paths):
27 """Return a path relative to test/fixtures. 25 """Return a path relative to test/fixtures.
28 """ 26 """
29 return os.path.join(os.path.dirname(__file__), 'fixtures', *paths) 27 return os.path.join(os.path.dirname(__file__), 'fixtures', *paths)
30 28
31class GitConfigUnitTest(unittest.TestCase): 29
32 """Tests the GitConfig class. 30class GitConfigReadOnlyTests(unittest.TestCase):
33 """ 31 """Read-only tests of the GitConfig class."""
32
34 def setUp(self): 33 def setUp(self):
35 """Create a GitConfig object using the test.gitconfig fixture. 34 """Create a GitConfig object using the test.gitconfig fixture.
36 """ 35 """
@@ -68,5 +67,126 @@ class GitConfigUnitTest(unittest.TestCase):
68 val = config.GetString('empty') 67 val = config.GetString('empty')
69 self.assertEqual(val, None) 68 self.assertEqual(val, None)
70 69
70 def test_GetBoolean_undefined(self):
71 """Test GetBoolean on key that doesn't exist."""
72 self.assertIsNone(self.config.GetBoolean('section.missing'))
73
74 def test_GetBoolean_invalid(self):
75 """Test GetBoolean on invalid boolean value."""
76 self.assertIsNone(self.config.GetBoolean('section.boolinvalid'))
77
78 def test_GetBoolean_true(self):
79 """Test GetBoolean on valid true boolean."""
80 self.assertTrue(self.config.GetBoolean('section.booltrue'))
81
82 def test_GetBoolean_false(self):
83 """Test GetBoolean on valid false boolean."""
84 self.assertFalse(self.config.GetBoolean('section.boolfalse'))
85
86 def test_GetInt_undefined(self):
87 """Test GetInt on key that doesn't exist."""
88 self.assertIsNone(self.config.GetInt('section.missing'))
89
90 def test_GetInt_invalid(self):
91 """Test GetInt on invalid integer value."""
92 self.assertIsNone(self.config.GetBoolean('section.intinvalid'))
93
94 def test_GetInt_valid(self):
95 """Test GetInt on valid integers."""
96 TESTS = (
97 ('inthex', 16),
98 ('inthexk', 16384),
99 ('int', 10),
100 ('intk', 10240),
101 ('intm', 10485760),
102 ('intg', 10737418240),
103 )
104 for key, value in TESTS:
105 self.assertEqual(value, self.config.GetInt('section.%s' % (key,)))
106
107 def test_GetSyncAnalysisStateData(self):
108 """Test config entries with a sync state analysis data."""
109 superproject_logging_data = {}
110 superproject_logging_data['test'] = False
111 options = type('options', (object,), {})()
112 options.verbose = 'true'
113 options.mp_update = 'false'
114 TESTS = (
115 ('superproject.test', 'false'),
116 ('options.verbose', 'true'),
117 ('options.mpupdate', 'false'),
118 ('main.version', '1'),
119 )
120 self.config.UpdateSyncAnalysisState(options, superproject_logging_data)
121 sync_data = self.config.GetSyncAnalysisStateData()
122 for key, value in TESTS:
123 self.assertEqual(sync_data[f'{git_config.SYNC_STATE_PREFIX}{key}'], value)
124 self.assertTrue(sync_data[f'{git_config.SYNC_STATE_PREFIX}main.synctime'])
125
126
127class GitConfigReadWriteTests(unittest.TestCase):
128 """Read/write tests of the GitConfig class."""
129
130 def setUp(self):
131 self.tmpfile = tempfile.NamedTemporaryFile()
132 self.config = self.get_config()
133
134 def get_config(self):
135 """Get a new GitConfig instance."""
136 return git_config.GitConfig(self.tmpfile.name)
137
138 def test_SetString(self):
139 """Test SetString behavior."""
140 # Set a value.
141 self.assertIsNone(self.config.GetString('foo.bar'))
142 self.config.SetString('foo.bar', 'val')
143 self.assertEqual('val', self.config.GetString('foo.bar'))
144
145 # Make sure the value was actually written out.
146 config = self.get_config()
147 self.assertEqual('val', config.GetString('foo.bar'))
148
149 # Update the value.
150 self.config.SetString('foo.bar', 'valll')
151 self.assertEqual('valll', self.config.GetString('foo.bar'))
152 config = self.get_config()
153 self.assertEqual('valll', config.GetString('foo.bar'))
154
155 # Delete the value.
156 self.config.SetString('foo.bar', None)
157 self.assertIsNone(self.config.GetString('foo.bar'))
158 config = self.get_config()
159 self.assertIsNone(config.GetString('foo.bar'))
160
161 def test_SetBoolean(self):
162 """Test SetBoolean behavior."""
163 # Set a true value.
164 self.assertIsNone(self.config.GetBoolean('foo.bar'))
165 for val in (True, 1):
166 self.config.SetBoolean('foo.bar', val)
167 self.assertTrue(self.config.GetBoolean('foo.bar'))
168
169 # Make sure the value was actually written out.
170 config = self.get_config()
171 self.assertTrue(config.GetBoolean('foo.bar'))
172 self.assertEqual('true', config.GetString('foo.bar'))
173
174 # Set a false value.
175 for val in (False, 0):
176 self.config.SetBoolean('foo.bar', val)
177 self.assertFalse(self.config.GetBoolean('foo.bar'))
178
179 # Make sure the value was actually written out.
180 config = self.get_config()
181 self.assertFalse(config.GetBoolean('foo.bar'))
182 self.assertEqual('false', config.GetString('foo.bar'))
183
184 # Delete the value.
185 self.config.SetBoolean('foo.bar', None)
186 self.assertIsNone(self.config.GetBoolean('foo.bar'))
187 config = self.get_config()
188 self.assertIsNone(config.GetBoolean('foo.bar'))
189
190
71if __name__ == '__main__': 191if __name__ == '__main__':
72 unittest.main() 192 unittest.main()
diff --git a/tests/test_git_superproject.py b/tests/test_git_superproject.py
new file mode 100644
index 00000000..a24fc7f0
--- /dev/null
+++ b/tests/test_git_superproject.py
@@ -0,0 +1,376 @@
1# Copyright (C) 2021 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the git_superproject.py module."""
16
17import json
18import os
19import platform
20import tempfile
21import unittest
22from unittest import mock
23
24import git_superproject
25import git_trace2_event_log
26import manifest_xml
27import platform_utils
28from test_manifest_xml import sort_attributes
29
30
31class SuperprojectTestCase(unittest.TestCase):
32 """TestCase for the Superproject module."""
33
34 PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID'
35 PARENT_SID_VALUE = 'parent_sid'
36 SELF_SID_REGEX = r'repo-\d+T\d+Z-.*'
37 FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX)
38
39 def setUp(self):
40 """Set up superproject every time."""
41 self.tempdir = tempfile.mkdtemp(prefix='repo_tests')
42 self.repodir = os.path.join(self.tempdir, '.repo')
43 self.manifest_file = os.path.join(
44 self.repodir, manifest_xml.MANIFEST_FILE_NAME)
45 os.mkdir(self.repodir)
46 self.platform = platform.system().lower()
47
48 # By default we initialize with the expected case where
49 # repo launches us (so GIT_TRACE2_PARENT_SID is set).
50 env = {
51 self.PARENT_SID_KEY: self.PARENT_SID_VALUE,
52 }
53 self.git_event_log = git_trace2_event_log.EventLog(env=env)
54
55 # The manifest parsing really wants a git repo currently.
56 gitdir = os.path.join(self.repodir, 'manifests.git')
57 os.mkdir(gitdir)
58 with open(os.path.join(gitdir, 'config'), 'w') as fp:
59 fp.write("""[remote "origin"]
60 url = https://localhost:0/manifest
61""")
62
63 manifest = self.getXmlManifest("""
64<manifest>
65 <remote name="default-remote" fetch="http://localhost" />
66 <default remote="default-remote" revision="refs/heads/main" />
67 <superproject name="superproject"/>
68 <project path="art" name="platform/art" groups="notdefault,platform-""" + self.platform + """
69 " /></manifest>
70""")
71 self._superproject = git_superproject.Superproject(manifest, self.repodir,
72 self.git_event_log)
73
74 def tearDown(self):
75 """Tear down superproject every time."""
76 platform_utils.rmtree(self.tempdir)
77
78 def getXmlManifest(self, data):
79 """Helper to initialize a manifest for testing."""
80 with open(self.manifest_file, 'w') as fp:
81 fp.write(data)
82 return manifest_xml.XmlManifest(self.repodir, self.manifest_file)
83
84 def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True):
85 """Helper function to verify common event log keys."""
86 self.assertIn('event', log_entry)
87 self.assertIn('sid', log_entry)
88 self.assertIn('thread', log_entry)
89 self.assertIn('time', log_entry)
90
91 # Do basic data format validation.
92 self.assertEqual(expected_event_name, log_entry['event'])
93 if full_sid:
94 self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX)
95 else:
96 self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX)
97 self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$')
98
99 def readLog(self, log_path):
100 """Helper function to read log data into a list."""
101 log_data = []
102 with open(log_path, mode='rb') as f:
103 for line in f:
104 log_data.append(json.loads(line))
105 return log_data
106
107 def verifyErrorEvent(self):
108 """Helper to verify that error event is written."""
109
110 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
111 log_path = self.git_event_log.Write(path=tempdir)
112 self.log_data = self.readLog(log_path)
113
114 self.assertEqual(len(self.log_data), 2)
115 error_event = self.log_data[1]
116 self.verifyCommonKeys(self.log_data[0], expected_event_name='version')
117 self.verifyCommonKeys(error_event, expected_event_name='error')
118 # Check for 'error' event specific fields.
119 self.assertIn('msg', error_event)
120 self.assertIn('fmt', error_event)
121
122 def test_superproject_get_superproject_no_superproject(self):
123 """Test with no url."""
124 manifest = self.getXmlManifest("""
125<manifest>
126</manifest>
127""")
128 superproject = git_superproject.Superproject(manifest, self.repodir, self.git_event_log)
129 # Test that exit condition is false when there is no superproject tag.
130 sync_result = superproject.Sync()
131 self.assertFalse(sync_result.success)
132 self.assertFalse(sync_result.fatal)
133 self.verifyErrorEvent()
134
135 def test_superproject_get_superproject_invalid_url(self):
136 """Test with an invalid url."""
137 manifest = self.getXmlManifest("""
138<manifest>
139 <remote name="test-remote" fetch="localhost" />
140 <default remote="test-remote" revision="refs/heads/main" />
141 <superproject name="superproject"/>
142</manifest>
143""")
144 superproject = git_superproject.Superproject(manifest, self.repodir, self.git_event_log)
145 sync_result = superproject.Sync()
146 self.assertFalse(sync_result.success)
147 self.assertTrue(sync_result.fatal)
148
149 def test_superproject_get_superproject_invalid_branch(self):
150 """Test with an invalid branch."""
151 manifest = self.getXmlManifest("""
152<manifest>
153 <remote name="test-remote" fetch="localhost" />
154 <default remote="test-remote" revision="refs/heads/main" />
155 <superproject name="superproject"/>
156</manifest>
157""")
158 self._superproject = git_superproject.Superproject(manifest, self.repodir,
159 self.git_event_log)
160 with mock.patch.object(self._superproject, '_branch', 'junk'):
161 sync_result = self._superproject.Sync()
162 self.assertFalse(sync_result.success)
163 self.assertTrue(sync_result.fatal)
164
165 def test_superproject_get_superproject_mock_init(self):
166 """Test with _Init failing."""
167 with mock.patch.object(self._superproject, '_Init', return_value=False):
168 sync_result = self._superproject.Sync()
169 self.assertFalse(sync_result.success)
170 self.assertTrue(sync_result.fatal)
171
172 def test_superproject_get_superproject_mock_fetch(self):
173 """Test with _Fetch failing."""
174 with mock.patch.object(self._superproject, '_Init', return_value=True):
175 os.mkdir(self._superproject._superproject_path)
176 with mock.patch.object(self._superproject, '_Fetch', return_value=False):
177 sync_result = self._superproject.Sync()
178 self.assertFalse(sync_result.success)
179 self.assertTrue(sync_result.fatal)
180
181 def test_superproject_get_all_project_commit_ids_mock_ls_tree(self):
182 """Test with LsTree being a mock."""
183 data = ('120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00'
184 '160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00'
185 '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00'
186 '120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00'
187 '160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00')
188 with mock.patch.object(self._superproject, '_Init', return_value=True):
189 with mock.patch.object(self._superproject, '_Fetch', return_value=True):
190 with mock.patch.object(self._superproject, '_LsTree', return_value=data):
191 commit_ids_result = self._superproject._GetAllProjectsCommitIds()
192 self.assertEqual(commit_ids_result.commit_ids, {
193 'art': '2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea',
194 'bootable/recovery': 'e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06',
195 'build/bazel': 'ade9b7a0d874e25fff4bf2552488825c6f111928'
196 })
197 self.assertFalse(commit_ids_result.fatal)
198
199 def test_superproject_write_manifest_file(self):
200 """Test with writing manifest to a file after setting revisionId."""
201 self.assertEqual(len(self._superproject._manifest.projects), 1)
202 project = self._superproject._manifest.projects[0]
203 project.SetRevisionId('ABCDEF')
204 # Create temporary directory so that it can write the file.
205 os.mkdir(self._superproject._superproject_path)
206 manifest_path = self._superproject._WriteManifestFile()
207 self.assertIsNotNone(manifest_path)
208 with open(manifest_path, 'r') as fp:
209 manifest_xml_data = fp.read()
210 self.assertEqual(
211 sort_attributes(manifest_xml_data),
212 '<?xml version="1.0" ?><manifest>'
213 '<remote fetch="http://localhost" name="default-remote"/>'
214 '<default remote="default-remote" revision="refs/heads/main"/>'
215 '<project groups="notdefault,platform-' + self.platform + '" '
216 'name="platform/art" path="art" revision="ABCDEF" upstream="refs/heads/main"/>'
217 '<superproject name="superproject"/>'
218 '</manifest>')
219
220 def test_superproject_update_project_revision_id(self):
221 """Test with LsTree being a mock."""
222 self.assertEqual(len(self._superproject._manifest.projects), 1)
223 projects = self._superproject._manifest.projects
224 data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00'
225 '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00')
226 with mock.patch.object(self._superproject, '_Init', return_value=True):
227 with mock.patch.object(self._superproject, '_Fetch', return_value=True):
228 with mock.patch.object(self._superproject,
229 '_LsTree',
230 return_value=data):
231 # Create temporary directory so that it can write the file.
232 os.mkdir(self._superproject._superproject_path)
233 update_result = self._superproject.UpdateProjectsRevisionId(projects)
234 self.assertIsNotNone(update_result.manifest_path)
235 self.assertFalse(update_result.fatal)
236 with open(update_result.manifest_path, 'r') as fp:
237 manifest_xml_data = fp.read()
238 self.assertEqual(
239 sort_attributes(manifest_xml_data),
240 '<?xml version="1.0" ?><manifest>'
241 '<remote fetch="http://localhost" name="default-remote"/>'
242 '<default remote="default-remote" revision="refs/heads/main"/>'
243 '<project groups="notdefault,platform-' + self.platform + '" '
244 'name="platform/art" path="art" '
245 'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>'
246 '<superproject name="superproject"/>'
247 '</manifest>')
248
249 def test_superproject_update_project_revision_id_no_superproject_tag(self):
250 """Test update of commit ids of a manifest without superproject tag."""
251 manifest = self.getXmlManifest("""
252<manifest>
253 <remote name="default-remote" fetch="http://localhost" />
254 <default remote="default-remote" revision="refs/heads/main" />
255 <project name="test-name"/>
256</manifest>
257""")
258 self.maxDiff = None
259 self._superproject = git_superproject.Superproject(manifest, self.repodir,
260 self.git_event_log)
261 self.assertEqual(len(self._superproject._manifest.projects), 1)
262 projects = self._superproject._manifest.projects
263 project = projects[0]
264 project.SetRevisionId('ABCDEF')
265 update_result = self._superproject.UpdateProjectsRevisionId(projects)
266 self.assertIsNone(update_result.manifest_path)
267 self.assertFalse(update_result.fatal)
268 self.verifyErrorEvent()
269 self.assertEqual(
270 sort_attributes(manifest.ToXml().toxml()),
271 '<?xml version="1.0" ?><manifest>'
272 '<remote fetch="http://localhost" name="default-remote"/>'
273 '<default remote="default-remote" revision="refs/heads/main"/>'
274 '<project name="test-name" revision="ABCDEF" upstream="refs/heads/main"/>'
275 '</manifest>')
276
277 def test_superproject_update_project_revision_id_from_local_manifest_group(self):
278 """Test update of commit ids of a manifest that have local manifest no superproject group."""
279 local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ':local'
280 manifest = self.getXmlManifest("""
281<manifest>
282 <remote name="default-remote" fetch="http://localhost" />
283 <remote name="goog" fetch="http://localhost2" />
284 <default remote="default-remote" revision="refs/heads/main" />
285 <superproject name="superproject"/>
286 <project path="vendor/x" name="platform/vendor/x" remote="goog"
287 groups=\"""" + local_group + """
288 " revision="master-with-vendor" clone-depth="1" />
289 <project path="art" name="platform/art" groups="notdefault,platform-""" + self.platform + """
290 " /></manifest>
291""")
292 self.maxDiff = None
293 self._superproject = git_superproject.Superproject(manifest, self.repodir,
294 self.git_event_log)
295 self.assertEqual(len(self._superproject._manifest.projects), 2)
296 projects = self._superproject._manifest.projects
297 data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00')
298 with mock.patch.object(self._superproject, '_Init', return_value=True):
299 with mock.patch.object(self._superproject, '_Fetch', return_value=True):
300 with mock.patch.object(self._superproject,
301 '_LsTree',
302 return_value=data):
303 # Create temporary directory so that it can write the file.
304 os.mkdir(self._superproject._superproject_path)
305 update_result = self._superproject.UpdateProjectsRevisionId(projects)
306 self.assertIsNotNone(update_result.manifest_path)
307 self.assertFalse(update_result.fatal)
308 with open(update_result.manifest_path, 'r') as fp:
309 manifest_xml_data = fp.read()
310 # Verify platform/vendor/x's project revision hasn't changed.
311 self.assertEqual(
312 sort_attributes(manifest_xml_data),
313 '<?xml version="1.0" ?><manifest>'
314 '<remote fetch="http://localhost" name="default-remote"/>'
315 '<remote fetch="http://localhost2" name="goog"/>'
316 '<default remote="default-remote" revision="refs/heads/main"/>'
317 '<project groups="notdefault,platform-' + self.platform + '" '
318 'name="platform/art" path="art" '
319 'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>'
320 '<project clone-depth="1" groups="' + local_group + '" '
321 'name="platform/vendor/x" path="vendor/x" remote="goog" '
322 'revision="master-with-vendor"/>'
323 '<superproject name="superproject"/>'
324 '</manifest>')
325
326 def test_superproject_update_project_revision_id_with_pinned_manifest(self):
327 """Test update of commit ids of a pinned manifest."""
328 manifest = self.getXmlManifest("""
329<manifest>
330 <remote name="default-remote" fetch="http://localhost" />
331 <default remote="default-remote" revision="refs/heads/main" />
332 <superproject name="superproject"/>
333 <project path="vendor/x" name="platform/vendor/x" revision="" />
334 <project path="vendor/y" name="platform/vendor/y"
335 revision="52d3c9f7c107839ece2319d077de0cd922aa9d8f" />
336 <project path="art" name="platform/art" groups="notdefault,platform-""" + self.platform + """
337 " /></manifest>
338""")
339 self.maxDiff = None
340 self._superproject = git_superproject.Superproject(manifest, self.repodir,
341 self.git_event_log)
342 self.assertEqual(len(self._superproject._manifest.projects), 3)
343 projects = self._superproject._manifest.projects
344 data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00'
345 '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00')
346 with mock.patch.object(self._superproject, '_Init', return_value=True):
347 with mock.patch.object(self._superproject, '_Fetch', return_value=True):
348 with mock.patch.object(self._superproject,
349 '_LsTree',
350 return_value=data):
351 # Create temporary directory so that it can write the file.
352 os.mkdir(self._superproject._superproject_path)
353 update_result = self._superproject.UpdateProjectsRevisionId(projects)
354 self.assertIsNotNone(update_result.manifest_path)
355 self.assertFalse(update_result.fatal)
356 with open(update_result.manifest_path, 'r') as fp:
357 manifest_xml_data = fp.read()
358 # Verify platform/vendor/x's project revision hasn't changed.
359 self.assertEqual(
360 sort_attributes(manifest_xml_data),
361 '<?xml version="1.0" ?><manifest>'
362 '<remote fetch="http://localhost" name="default-remote"/>'
363 '<default remote="default-remote" revision="refs/heads/main"/>'
364 '<project groups="notdefault,platform-' + self.platform + '" '
365 'name="platform/art" path="art" '
366 'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>'
367 '<project name="platform/vendor/x" path="vendor/x" '
368 'revision="e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06" upstream="refs/heads/main"/>'
369 '<project name="platform/vendor/y" path="vendor/y" '
370 'revision="52d3c9f7c107839ece2319d077de0cd922aa9d8f"/>'
371 '<superproject name="superproject"/>'
372 '</manifest>')
373
374
375if __name__ == '__main__':
376 unittest.main()
diff --git a/tests/test_git_trace2_event_log.py b/tests/test_git_trace2_event_log.py
new file mode 100644
index 00000000..89dcfb92
--- /dev/null
+++ b/tests/test_git_trace2_event_log.py
@@ -0,0 +1,329 @@
1# Copyright (C) 2020 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the git_trace2_event_log.py module."""
16
17import json
18import os
19import tempfile
20import unittest
21from unittest import mock
22
23import git_trace2_event_log
24
25
26class EventLogTestCase(unittest.TestCase):
27 """TestCase for the EventLog module."""
28
29 PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID'
30 PARENT_SID_VALUE = 'parent_sid'
31 SELF_SID_REGEX = r'repo-\d+T\d+Z-.*'
32 FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX)
33
34 def setUp(self):
35 """Load the event_log module every time."""
36 self._event_log_module = None
37 # By default we initialize with the expected case where
38 # repo launches us (so GIT_TRACE2_PARENT_SID is set).
39 env = {
40 self.PARENT_SID_KEY: self.PARENT_SID_VALUE,
41 }
42 self._event_log_module = git_trace2_event_log.EventLog(env=env)
43 self._log_data = None
44
45 def verifyCommonKeys(self, log_entry, expected_event_name=None, full_sid=True):
46 """Helper function to verify common event log keys."""
47 self.assertIn('event', log_entry)
48 self.assertIn('sid', log_entry)
49 self.assertIn('thread', log_entry)
50 self.assertIn('time', log_entry)
51
52 # Do basic data format validation.
53 if expected_event_name:
54 self.assertEqual(expected_event_name, log_entry['event'])
55 if full_sid:
56 self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX)
57 else:
58 self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX)
59 self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$')
60
61 def readLog(self, log_path):
62 """Helper function to read log data into a list."""
63 log_data = []
64 with open(log_path, mode='rb') as f:
65 for line in f:
66 log_data.append(json.loads(line))
67 return log_data
68
69 def remove_prefix(self, s, prefix):
70 """Return a copy string after removing |prefix| from |s|, if present or the original string."""
71 if s.startswith(prefix):
72 return s[len(prefix):]
73 else:
74 return s
75
76 def test_initial_state_with_parent_sid(self):
77 """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
78 self.assertRegex(self._event_log_module.full_sid, self.FULL_SID_REGEX)
79
80 def test_initial_state_no_parent_sid(self):
81 """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
82 # Setup an empty environment dict (no parent sid).
83 self._event_log_module = git_trace2_event_log.EventLog(env={})
84 self.assertRegex(self._event_log_module.full_sid, self.SELF_SID_REGEX)
85
86 def test_version_event(self):
87 """Test 'version' event data is valid.
88
89 Verify that the 'version' event is written even when no other
90 events are addded.
91
92 Expected event log:
93 <version event>
94 """
95 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
96 log_path = self._event_log_module.Write(path=tempdir)
97 self._log_data = self.readLog(log_path)
98
99 # A log with no added events should only have the version entry.
100 self.assertEqual(len(self._log_data), 1)
101 version_event = self._log_data[0]
102 self.verifyCommonKeys(version_event, expected_event_name='version')
103 # Check for 'version' event specific fields.
104 self.assertIn('evt', version_event)
105 self.assertIn('exe', version_event)
106 # Verify "evt" version field is a string.
107 self.assertIsInstance(version_event['evt'], str)
108
109 def test_start_event(self):
110 """Test and validate 'start' event data is valid.
111
112 Expected event log:
113 <version event>
114 <start event>
115 """
116 self._event_log_module.StartEvent()
117 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
118 log_path = self._event_log_module.Write(path=tempdir)
119 self._log_data = self.readLog(log_path)
120
121 self.assertEqual(len(self._log_data), 2)
122 start_event = self._log_data[1]
123 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
124 self.verifyCommonKeys(start_event, expected_event_name='start')
125 # Check for 'start' event specific fields.
126 self.assertIn('argv', start_event)
127 self.assertTrue(isinstance(start_event['argv'], list))
128
129 def test_exit_event_result_none(self):
130 """Test 'exit' event data is valid when result is None.
131
132 We expect None result to be converted to 0 in the exit event data.
133
134 Expected event log:
135 <version event>
136 <exit event>
137 """
138 self._event_log_module.ExitEvent(None)
139 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
140 log_path = self._event_log_module.Write(path=tempdir)
141 self._log_data = self.readLog(log_path)
142
143 self.assertEqual(len(self._log_data), 2)
144 exit_event = self._log_data[1]
145 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
146 self.verifyCommonKeys(exit_event, expected_event_name='exit')
147 # Check for 'exit' event specific fields.
148 self.assertIn('code', exit_event)
149 # 'None' result should convert to 0 (successful) return code.
150 self.assertEqual(exit_event['code'], 0)
151
152 def test_exit_event_result_integer(self):
153 """Test 'exit' event data is valid when result is an integer.
154
155 Expected event log:
156 <version event>
157 <exit event>
158 """
159 self._event_log_module.ExitEvent(2)
160 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
161 log_path = self._event_log_module.Write(path=tempdir)
162 self._log_data = self.readLog(log_path)
163
164 self.assertEqual(len(self._log_data), 2)
165 exit_event = self._log_data[1]
166 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
167 self.verifyCommonKeys(exit_event, expected_event_name='exit')
168 # Check for 'exit' event specific fields.
169 self.assertIn('code', exit_event)
170 self.assertEqual(exit_event['code'], 2)
171
172 def test_command_event(self):
173 """Test and validate 'command' event data is valid.
174
175 Expected event log:
176 <version event>
177 <command event>
178 """
179 name = 'repo'
180 subcommands = ['init' 'this']
181 self._event_log_module.CommandEvent(name='repo', subcommands=subcommands)
182 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
183 log_path = self._event_log_module.Write(path=tempdir)
184 self._log_data = self.readLog(log_path)
185
186 self.assertEqual(len(self._log_data), 2)
187 command_event = self._log_data[1]
188 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
189 self.verifyCommonKeys(command_event, expected_event_name='command')
190 # Check for 'command' event specific fields.
191 self.assertIn('name', command_event)
192 self.assertIn('subcommands', command_event)
193 self.assertEqual(command_event['name'], name)
194 self.assertEqual(command_event['subcommands'], subcommands)
195
196 def test_def_params_event_repo_config(self):
197 """Test 'def_params' event data outputs only repo config keys.
198
199 Expected event log:
200 <version event>
201 <def_param event>
202 <def_param event>
203 """
204 config = {
205 'git.foo': 'bar',
206 'repo.partialclone': 'true',
207 'repo.partialclonefilter': 'blob:none',
208 }
209 self._event_log_module.DefParamRepoEvents(config)
210
211 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
212 log_path = self._event_log_module.Write(path=tempdir)
213 self._log_data = self.readLog(log_path)
214
215 self.assertEqual(len(self._log_data), 3)
216 def_param_events = self._log_data[1:]
217 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
218
219 for event in def_param_events:
220 self.verifyCommonKeys(event, expected_event_name='def_param')
221 # Check for 'def_param' event specific fields.
222 self.assertIn('param', event)
223 self.assertIn('value', event)
224 self.assertTrue(event['param'].startswith('repo.'))
225
226 def test_def_params_event_no_repo_config(self):
227 """Test 'def_params' event data won't output non-repo config keys.
228
229 Expected event log:
230 <version event>
231 """
232 config = {
233 'git.foo': 'bar',
234 'git.core.foo2': 'baz',
235 }
236 self._event_log_module.DefParamRepoEvents(config)
237
238 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
239 log_path = self._event_log_module.Write(path=tempdir)
240 self._log_data = self.readLog(log_path)
241
242 self.assertEqual(len(self._log_data), 1)
243 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
244
245 def test_data_event_config(self):
246 """Test 'data' event data outputs all config keys.
247
248 Expected event log:
249 <version event>
250 <data event>
251 <data event>
252 """
253 config = {
254 'git.foo': 'bar',
255 'repo.partialclone': 'false',
256 'repo.syncstate.superproject.hassuperprojecttag': 'true',
257 'repo.syncstate.superproject.sys.argv': ['--', 'sync', 'protobuf'],
258 }
259 prefix_value = 'prefix'
260 self._event_log_module.LogDataConfigEvents(config, prefix_value)
261
262 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
263 log_path = self._event_log_module.Write(path=tempdir)
264 self._log_data = self.readLog(log_path)
265
266 self.assertEqual(len(self._log_data), 5)
267 data_events = self._log_data[1:]
268 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
269
270 for event in data_events:
271 self.verifyCommonKeys(event)
272 # Check for 'data' event specific fields.
273 self.assertIn('key', event)
274 self.assertIn('value', event)
275 key = event['key']
276 key = self.remove_prefix(key, f'{prefix_value}/')
277 value = event['value']
278 self.assertEqual(self._event_log_module.GetDataEventName(value), event['event'])
279 self.assertTrue(key in config and value == config[key])
280
281 def test_error_event(self):
282 """Test and validate 'error' event data is valid.
283
284 Expected event log:
285 <version event>
286 <error event>
287 """
288 msg = 'invalid option: --cahced'
289 fmt = 'invalid option: %s'
290 self._event_log_module.ErrorEvent(msg, fmt)
291 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
292 log_path = self._event_log_module.Write(path=tempdir)
293 self._log_data = self.readLog(log_path)
294
295 self.assertEqual(len(self._log_data), 2)
296 error_event = self._log_data[1]
297 self.verifyCommonKeys(self._log_data[0], expected_event_name='version')
298 self.verifyCommonKeys(error_event, expected_event_name='error')
299 # Check for 'error' event specific fields.
300 self.assertIn('msg', error_event)
301 self.assertIn('fmt', error_event)
302 self.assertEqual(error_event['msg'], msg)
303 self.assertEqual(error_event['fmt'], fmt)
304
305 def test_write_with_filename(self):
306 """Test Write() with a path to a file exits with None."""
307 self.assertIsNone(self._event_log_module.Write(path='path/to/file'))
308
309 def test_write_with_git_config(self):
310 """Test Write() uses the git config path when 'git config' call succeeds."""
311 with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir:
312 with mock.patch.object(self._event_log_module,
313 '_GetEventTargetPath', return_value=tempdir):
314 self.assertEqual(os.path.dirname(self._event_log_module.Write()), tempdir)
315
316 def test_write_no_git_config(self):
317 """Test Write() with no git config variable present exits with None."""
318 with mock.patch.object(self._event_log_module,
319 '_GetEventTargetPath', return_value=None):
320 self.assertIsNone(self._event_log_module.Write())
321
322 def test_write_non_string(self):
323 """Test Write() with non-string type for |path| throws TypeError."""
324 with self.assertRaises(TypeError):
325 self._event_log_module.Write(path=1234)
326
327
328if __name__ == '__main__':
329 unittest.main()
diff --git a/tests/test_hooks.py b/tests/test_hooks.py
new file mode 100644
index 00000000..6632b3e5
--- /dev/null
+++ b/tests/test_hooks.py
@@ -0,0 +1,55 @@
1# Copyright (C) 2019 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the hooks.py module."""
16
17import hooks
18import unittest
19
20class RepoHookShebang(unittest.TestCase):
21 """Check shebang parsing in RepoHook."""
22
23 def test_no_shebang(self):
24 """Lines w/out shebangs should be rejected."""
25 DATA = (
26 '',
27 '#\n# foo\n',
28 '# Bad shebang in script\n#!/foo\n'
29 )
30 for data in DATA:
31 self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data))
32
33 def test_direct_interp(self):
34 """Lines whose shebang points directly to the interpreter."""
35 DATA = (
36 ('#!/foo', '/foo'),
37 ('#! /foo', '/foo'),
38 ('#!/bin/foo ', '/bin/foo'),
39 ('#! /usr/foo ', '/usr/foo'),
40 ('#! /usr/foo -args', '/usr/foo'),
41 )
42 for shebang, interp in DATA:
43 self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang),
44 interp)
45
46 def test_env_interp(self):
47 """Lines whose shebang launches through `env`."""
48 DATA = (
49 ('#!/usr/bin/env foo', 'foo'),
50 ('#!/bin/env foo', 'foo'),
51 ('#! /bin/env /bin/foo ', '/bin/foo'),
52 )
53 for shebang, interp in DATA:
54 self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang),
55 interp)
diff --git a/tests/test_manifest_xml.py b/tests/test_manifest_xml.py
new file mode 100644
index 00000000..cb3eb855
--- /dev/null
+++ b/tests/test_manifest_xml.py
@@ -0,0 +1,845 @@
1# Copyright (C) 2019 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the manifest_xml.py module."""
16
17import os
18import platform
19import re
20import shutil
21import tempfile
22import unittest
23import xml.dom.minidom
24
25import error
26import manifest_xml
27
28
29# Invalid paths that we don't want in the filesystem.
30INVALID_FS_PATHS = (
31 '',
32 '.',
33 '..',
34 '../',
35 './',
36 './/',
37 'foo/',
38 './foo',
39 '../foo',
40 'foo/./bar',
41 'foo/../../bar',
42 '/foo',
43 './../foo',
44 '.git/foo',
45 # Check case folding.
46 '.GIT/foo',
47 'blah/.git/foo',
48 '.repo/foo',
49 '.repoconfig',
50 # Block ~ due to 8.3 filenames on Windows filesystems.
51 '~',
52 'foo~',
53 'blah/foo~',
54 # Block Unicode characters that get normalized out by filesystems.
55 u'foo\u200Cbar',
56 # Block newlines.
57 'f\n/bar',
58 'f\r/bar',
59)
60
61# Make sure platforms that use path separators (e.g. Windows) are also
62# rejected properly.
63if os.path.sep != '/':
64 INVALID_FS_PATHS += tuple(x.replace('/', os.path.sep) for x in INVALID_FS_PATHS)
65
66
67def sort_attributes(manifest):
68 """Sort the attributes of all elements alphabetically.
69
70 This is needed because different versions of the toxml() function from
71 xml.dom.minidom outputs the attributes of elements in different orders.
72 Before Python 3.8 they were output alphabetically, later versions preserve
73 the order specified by the user.
74
75 Args:
76 manifest: String containing an XML manifest.
77
78 Returns:
79 The XML manifest with the attributes of all elements sorted alphabetically.
80 """
81 new_manifest = ''
82 # This will find every element in the XML manifest, whether they have
83 # attributes or not. This simplifies recreating the manifest below.
84 matches = re.findall(r'(<[/?]?[a-z-]+\s*)((?:\S+?="[^"]+"\s*?)*)(\s*[/?]?>)', manifest)
85 for head, attrs, tail in matches:
86 m = re.findall(r'\S+?="[^"]+"', attrs)
87 new_manifest += head + ' '.join(sorted(m)) + tail
88 return new_manifest
89
90
91class ManifestParseTestCase(unittest.TestCase):
92 """TestCase for parsing manifests."""
93
94 def setUp(self):
95 self.tempdir = tempfile.mkdtemp(prefix='repo_tests')
96 self.repodir = os.path.join(self.tempdir, '.repo')
97 self.manifest_dir = os.path.join(self.repodir, 'manifests')
98 self.manifest_file = os.path.join(
99 self.repodir, manifest_xml.MANIFEST_FILE_NAME)
100 self.local_manifest_dir = os.path.join(
101 self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME)
102 os.mkdir(self.repodir)
103 os.mkdir(self.manifest_dir)
104
105 # The manifest parsing really wants a git repo currently.
106 gitdir = os.path.join(self.repodir, 'manifests.git')
107 os.mkdir(gitdir)
108 with open(os.path.join(gitdir, 'config'), 'w') as fp:
109 fp.write("""[remote "origin"]
110 url = https://localhost:0/manifest
111""")
112
113 def tearDown(self):
114 shutil.rmtree(self.tempdir, ignore_errors=True)
115
116 def getXmlManifest(self, data):
117 """Helper to initialize a manifest for testing."""
118 with open(self.manifest_file, 'w') as fp:
119 fp.write(data)
120 return manifest_xml.XmlManifest(self.repodir, self.manifest_file)
121
122 @staticmethod
123 def encodeXmlAttr(attr):
124 """Encode |attr| using XML escape rules."""
125 return attr.replace('\r', '&#x000d;').replace('\n', '&#x000a;')
126
127
128class ManifestValidateFilePaths(unittest.TestCase):
129 """Check _ValidateFilePaths helper.
130
131 This doesn't access a real filesystem.
132 """
133
134 def check_both(self, *args):
135 manifest_xml.XmlManifest._ValidateFilePaths('copyfile', *args)
136 manifest_xml.XmlManifest._ValidateFilePaths('linkfile', *args)
137
138 def test_normal_path(self):
139 """Make sure good paths are accepted."""
140 self.check_both('foo', 'bar')
141 self.check_both('foo/bar', 'bar')
142 self.check_both('foo', 'bar/bar')
143 self.check_both('foo/bar', 'bar/bar')
144
145 def test_symlink_targets(self):
146 """Some extra checks for symlinks."""
147 def check(*args):
148 manifest_xml.XmlManifest._ValidateFilePaths('linkfile', *args)
149
150 # We allow symlinks to end in a slash since we allow them to point to dirs
151 # in general. Technically the slash isn't necessary.
152 check('foo/', 'bar')
153 # We allow a single '.' to get a reference to the project itself.
154 check('.', 'bar')
155
156 def test_bad_paths(self):
157 """Make sure bad paths (src & dest) are rejected."""
158 for path in INVALID_FS_PATHS:
159 self.assertRaises(
160 error.ManifestInvalidPathError, self.check_both, path, 'a')
161 self.assertRaises(
162 error.ManifestInvalidPathError, self.check_both, 'a', path)
163
164
165class ValueTests(unittest.TestCase):
166 """Check utility parsing code."""
167
168 def _get_node(self, text):
169 return xml.dom.minidom.parseString(text).firstChild
170
171 def test_bool_default(self):
172 """Check XmlBool default handling."""
173 node = self._get_node('<node/>')
174 self.assertIsNone(manifest_xml.XmlBool(node, 'a'))
175 self.assertIsNone(manifest_xml.XmlBool(node, 'a', None))
176 self.assertEqual(123, manifest_xml.XmlBool(node, 'a', 123))
177
178 node = self._get_node('<node a=""/>')
179 self.assertIsNone(manifest_xml.XmlBool(node, 'a'))
180
181 def test_bool_invalid(self):
182 """Check XmlBool invalid handling."""
183 node = self._get_node('<node a="moo"/>')
184 self.assertEqual(123, manifest_xml.XmlBool(node, 'a', 123))
185
186 def test_bool_true(self):
187 """Check XmlBool true values."""
188 for value in ('yes', 'true', '1'):
189 node = self._get_node('<node a="%s"/>' % (value,))
190 self.assertTrue(manifest_xml.XmlBool(node, 'a'))
191
192 def test_bool_false(self):
193 """Check XmlBool false values."""
194 for value in ('no', 'false', '0'):
195 node = self._get_node('<node a="%s"/>' % (value,))
196 self.assertFalse(manifest_xml.XmlBool(node, 'a'))
197
198 def test_int_default(self):
199 """Check XmlInt default handling."""
200 node = self._get_node('<node/>')
201 self.assertIsNone(manifest_xml.XmlInt(node, 'a'))
202 self.assertIsNone(manifest_xml.XmlInt(node, 'a', None))
203 self.assertEqual(123, manifest_xml.XmlInt(node, 'a', 123))
204
205 node = self._get_node('<node a=""/>')
206 self.assertIsNone(manifest_xml.XmlInt(node, 'a'))
207
208 def test_int_good(self):
209 """Check XmlInt numeric handling."""
210 for value in (-1, 0, 1, 50000):
211 node = self._get_node('<node a="%s"/>' % (value,))
212 self.assertEqual(value, manifest_xml.XmlInt(node, 'a'))
213
214 def test_int_invalid(self):
215 """Check XmlInt invalid handling."""
216 with self.assertRaises(error.ManifestParseError):
217 node = self._get_node('<node a="xx"/>')
218 manifest_xml.XmlInt(node, 'a')
219
220
221class XmlManifestTests(ManifestParseTestCase):
222 """Check manifest processing."""
223
224 def test_empty(self):
225 """Parse an 'empty' manifest file."""
226 manifest = self.getXmlManifest(
227 '<?xml version="1.0" encoding="UTF-8"?>'
228 '<manifest></manifest>')
229 self.assertEqual(manifest.remotes, {})
230 self.assertEqual(manifest.projects, [])
231
232 def test_link(self):
233 """Verify Link handling with new names."""
234 manifest = manifest_xml.XmlManifest(self.repodir, self.manifest_file)
235 with open(os.path.join(self.manifest_dir, 'foo.xml'), 'w') as fp:
236 fp.write('<manifest></manifest>')
237 manifest.Link('foo.xml')
238 with open(self.manifest_file) as fp:
239 self.assertIn('<include name="foo.xml" />', fp.read())
240
241 def test_toxml_empty(self):
242 """Verify the ToXml() helper."""
243 manifest = self.getXmlManifest(
244 '<?xml version="1.0" encoding="UTF-8"?>'
245 '<manifest></manifest>')
246 self.assertEqual(manifest.ToXml().toxml(), '<?xml version="1.0" ?><manifest/>')
247
248 def test_todict_empty(self):
249 """Verify the ToDict() helper."""
250 manifest = self.getXmlManifest(
251 '<?xml version="1.0" encoding="UTF-8"?>'
252 '<manifest></manifest>')
253 self.assertEqual(manifest.ToDict(), {})
254
255 def test_repo_hooks(self):
256 """Check repo-hooks settings."""
257 manifest = self.getXmlManifest("""
258<manifest>
259 <remote name="test-remote" fetch="http://localhost" />
260 <default remote="test-remote" revision="refs/heads/main" />
261 <project name="repohooks" path="src/repohooks"/>
262 <repo-hooks in-project="repohooks" enabled-list="a, b"/>
263</manifest>
264""")
265 self.assertEqual(manifest.repo_hooks_project.name, 'repohooks')
266 self.assertEqual(manifest.repo_hooks_project.enabled_repo_hooks, ['a', 'b'])
267
268 def test_repo_hooks_unordered(self):
269 """Check repo-hooks settings work even if the project def comes second."""
270 manifest = self.getXmlManifest("""
271<manifest>
272 <remote name="test-remote" fetch="http://localhost" />
273 <default remote="test-remote" revision="refs/heads/main" />
274 <repo-hooks in-project="repohooks" enabled-list="a, b"/>
275 <project name="repohooks" path="src/repohooks"/>
276</manifest>
277""")
278 self.assertEqual(manifest.repo_hooks_project.name, 'repohooks')
279 self.assertEqual(manifest.repo_hooks_project.enabled_repo_hooks, ['a', 'b'])
280
281 def test_unknown_tags(self):
282 """Check superproject settings."""
283 manifest = self.getXmlManifest("""
284<manifest>
285 <remote name="test-remote" fetch="http://localhost" />
286 <default remote="test-remote" revision="refs/heads/main" />
287 <superproject name="superproject"/>
288 <iankaz value="unknown (possible) future tags are ignored"/>
289 <x-custom-tag>X tags are always ignored</x-custom-tag>
290</manifest>
291""")
292 self.assertEqual(manifest.superproject['name'], 'superproject')
293 self.assertEqual(manifest.superproject['remote'].name, 'test-remote')
294 self.assertEqual(
295 sort_attributes(manifest.ToXml().toxml()),
296 '<?xml version="1.0" ?><manifest>'
297 '<remote fetch="http://localhost" name="test-remote"/>'
298 '<default remote="test-remote" revision="refs/heads/main"/>'
299 '<superproject name="superproject"/>'
300 '</manifest>')
301
302 def test_remote_annotations(self):
303 """Check remote settings."""
304 manifest = self.getXmlManifest("""
305<manifest>
306 <remote name="test-remote" fetch="http://localhost">
307 <annotation name="foo" value="bar"/>
308 </remote>
309</manifest>
310""")
311 self.assertEqual(manifest.remotes['test-remote'].annotations[0].name, 'foo')
312 self.assertEqual(manifest.remotes['test-remote'].annotations[0].value, 'bar')
313 self.assertEqual(
314 sort_attributes(manifest.ToXml().toxml()),
315 '<?xml version="1.0" ?><manifest>'
316 '<remote fetch="http://localhost" name="test-remote">'
317 '<annotation name="foo" value="bar"/>'
318 '</remote>'
319 '</manifest>')
320
321
322class IncludeElementTests(ManifestParseTestCase):
323 """Tests for <include>."""
324
325 def test_group_levels(self):
326 root_m = os.path.join(self.manifest_dir, 'root.xml')
327 with open(root_m, 'w') as fp:
328 fp.write("""
329<manifest>
330 <remote name="test-remote" fetch="http://localhost" />
331 <default remote="test-remote" revision="refs/heads/main" />
332 <include name="level1.xml" groups="level1-group" />
333 <project name="root-name1" path="root-path1" />
334 <project name="root-name2" path="root-path2" groups="r2g1,r2g2" />
335</manifest>
336""")
337 with open(os.path.join(self.manifest_dir, 'level1.xml'), 'w') as fp:
338 fp.write("""
339<manifest>
340 <include name="level2.xml" groups="level2-group" />
341 <project name="level1-name1" path="level1-path1" />
342</manifest>
343""")
344 with open(os.path.join(self.manifest_dir, 'level2.xml'), 'w') as fp:
345 fp.write("""
346<manifest>
347 <project name="level2-name1" path="level2-path1" groups="l2g1,l2g2" />
348</manifest>
349""")
350 include_m = manifest_xml.XmlManifest(self.repodir, root_m)
351 for proj in include_m.projects:
352 if proj.name == 'root-name1':
353 # Check include group not set on root level proj.
354 self.assertNotIn('level1-group', proj.groups)
355 if proj.name == 'root-name2':
356 # Check root proj group not removed.
357 self.assertIn('r2g1', proj.groups)
358 if proj.name == 'level1-name1':
359 # Check level1 proj has inherited group level 1.
360 self.assertIn('level1-group', proj.groups)
361 if proj.name == 'level2-name1':
362 # Check level2 proj has inherited group levels 1 and 2.
363 self.assertIn('level1-group', proj.groups)
364 self.assertIn('level2-group', proj.groups)
365 # Check level2 proj group not removed.
366 self.assertIn('l2g1', proj.groups)
367
368 def test_allow_bad_name_from_user(self):
369 """Check handling of bad name attribute from the user's input."""
370 def parse(name):
371 name = self.encodeXmlAttr(name)
372 manifest = self.getXmlManifest(f"""
373<manifest>
374 <remote name="default-remote" fetch="http://localhost" />
375 <default remote="default-remote" revision="refs/heads/main" />
376 <include name="{name}" />
377</manifest>
378""")
379 # Force the manifest to be parsed.
380 manifest.ToXml()
381
382 # Setup target of the include.
383 target = os.path.join(self.tempdir, 'target.xml')
384 with open(target, 'w') as fp:
385 fp.write('<manifest></manifest>')
386
387 # Include with absolute path.
388 parse(os.path.abspath(target))
389
390 # Include with relative path.
391 parse(os.path.relpath(target, self.manifest_dir))
392
393 def test_bad_name_checks(self):
394 """Check handling of bad name attribute."""
395 def parse(name):
396 name = self.encodeXmlAttr(name)
397 # Setup target of the include.
398 with open(os.path.join(self.manifest_dir, 'target.xml'), 'w') as fp:
399 fp.write(f'<manifest><include name="{name}"/></manifest>')
400
401 manifest = self.getXmlManifest("""
402<manifest>
403 <remote name="default-remote" fetch="http://localhost" />
404 <default remote="default-remote" revision="refs/heads/main" />
405 <include name="target.xml" />
406</manifest>
407""")
408 # Force the manifest to be parsed.
409 manifest.ToXml()
410
411 # Handle empty name explicitly because a different codepath rejects it.
412 with self.assertRaises(error.ManifestParseError):
413 parse('')
414
415 for path in INVALID_FS_PATHS:
416 if not path:
417 continue
418
419 with self.assertRaises(error.ManifestInvalidPathError):
420 parse(path)
421
422
423class ProjectElementTests(ManifestParseTestCase):
424 """Tests for <project>."""
425
426 def test_group(self):
427 """Check project group settings."""
428 manifest = self.getXmlManifest("""
429<manifest>
430 <remote name="test-remote" fetch="http://localhost" />
431 <default remote="test-remote" revision="refs/heads/main" />
432 <project name="test-name" path="test-path"/>
433 <project name="extras" path="path" groups="g1,g2,g1"/>
434</manifest>
435""")
436 self.assertEqual(len(manifest.projects), 2)
437 # Ordering isn't guaranteed.
438 result = {
439 manifest.projects[0].name: manifest.projects[0].groups,
440 manifest.projects[1].name: manifest.projects[1].groups,
441 }
442 project = manifest.projects[0]
443 self.assertCountEqual(
444 result['test-name'],
445 ['name:test-name', 'all', 'path:test-path'])
446 self.assertCountEqual(
447 result['extras'],
448 ['g1', 'g2', 'g1', 'name:extras', 'all', 'path:path'])
449 groupstr = 'default,platform-' + platform.system().lower()
450 self.assertEqual(groupstr, manifest.GetGroupsStr())
451 groupstr = 'g1,g2,g1'
452 manifest.manifestProject.config.SetString('manifest.groups', groupstr)
453 self.assertEqual(groupstr, manifest.GetGroupsStr())
454
455 def test_set_revision_id(self):
456 """Check setting of project's revisionId."""
457 manifest = self.getXmlManifest("""
458<manifest>
459 <remote name="default-remote" fetch="http://localhost" />
460 <default remote="default-remote" revision="refs/heads/main" />
461 <project name="test-name"/>
462</manifest>
463""")
464 self.assertEqual(len(manifest.projects), 1)
465 project = manifest.projects[0]
466 project.SetRevisionId('ABCDEF')
467 self.assertEqual(
468 sort_attributes(manifest.ToXml().toxml()),
469 '<?xml version="1.0" ?><manifest>'
470 '<remote fetch="http://localhost" name="default-remote"/>'
471 '<default remote="default-remote" revision="refs/heads/main"/>'
472 '<project name="test-name" revision="ABCDEF" upstream="refs/heads/main"/>'
473 '</manifest>')
474
475 def test_trailing_slash(self):
476 """Check handling of trailing slashes in attributes."""
477 def parse(name, path):
478 name = self.encodeXmlAttr(name)
479 path = self.encodeXmlAttr(path)
480 return self.getXmlManifest(f"""
481<manifest>
482 <remote name="default-remote" fetch="http://localhost" />
483 <default remote="default-remote" revision="refs/heads/main" />
484 <project name="{name}" path="{path}" />
485</manifest>
486""")
487
488 manifest = parse('a/path/', 'foo')
489 self.assertEqual(manifest.projects[0].gitdir,
490 os.path.join(self.tempdir, '.repo/projects/foo.git'))
491 self.assertEqual(manifest.projects[0].objdir,
492 os.path.join(self.tempdir, '.repo/project-objects/a/path.git'))
493
494 manifest = parse('a/path', 'foo/')
495 self.assertEqual(manifest.projects[0].gitdir,
496 os.path.join(self.tempdir, '.repo/projects/foo.git'))
497 self.assertEqual(manifest.projects[0].objdir,
498 os.path.join(self.tempdir, '.repo/project-objects/a/path.git'))
499
500 manifest = parse('a/path', 'foo//////')
501 self.assertEqual(manifest.projects[0].gitdir,
502 os.path.join(self.tempdir, '.repo/projects/foo.git'))
503 self.assertEqual(manifest.projects[0].objdir,
504 os.path.join(self.tempdir, '.repo/project-objects/a/path.git'))
505
506 def test_toplevel_path(self):
507 """Check handling of path=. specially."""
508 def parse(name, path):
509 name = self.encodeXmlAttr(name)
510 path = self.encodeXmlAttr(path)
511 return self.getXmlManifest(f"""
512<manifest>
513 <remote name="default-remote" fetch="http://localhost" />
514 <default remote="default-remote" revision="refs/heads/main" />
515 <project name="{name}" path="{path}" />
516</manifest>
517""")
518
519 for path in ('.', './', './/', './//'):
520 manifest = parse('server/path', path)
521 self.assertEqual(manifest.projects[0].gitdir,
522 os.path.join(self.tempdir, '.repo/projects/..git'))
523
524 def test_bad_path_name_checks(self):
525 """Check handling of bad path & name attributes."""
526 def parse(name, path):
527 name = self.encodeXmlAttr(name)
528 path = self.encodeXmlAttr(path)
529 manifest = self.getXmlManifest(f"""
530<manifest>
531 <remote name="default-remote" fetch="http://localhost" />
532 <default remote="default-remote" revision="refs/heads/main" />
533 <project name="{name}" path="{path}" />
534</manifest>
535""")
536 # Force the manifest to be parsed.
537 manifest.ToXml()
538
539 # Verify the parser is valid by default to avoid buggy tests below.
540 parse('ok', 'ok')
541
542 # Handle empty name explicitly because a different codepath rejects it.
543 # Empty path is OK because it defaults to the name field.
544 with self.assertRaises(error.ManifestParseError):
545 parse('', 'ok')
546
547 for path in INVALID_FS_PATHS:
548 if not path or path.endswith('/'):
549 continue
550
551 with self.assertRaises(error.ManifestInvalidPathError):
552 parse(path, 'ok')
553
554 # We have a dedicated test for path=".".
555 if path not in {'.'}:
556 with self.assertRaises(error.ManifestInvalidPathError):
557 parse('ok', path)
558
559
560class SuperProjectElementTests(ManifestParseTestCase):
561 """Tests for <superproject>."""
562
563 def test_superproject(self):
564 """Check superproject settings."""
565 manifest = self.getXmlManifest("""
566<manifest>
567 <remote name="test-remote" fetch="http://localhost" />
568 <default remote="test-remote" revision="refs/heads/main" />
569 <superproject name="superproject"/>
570</manifest>
571""")
572 self.assertEqual(manifest.superproject['name'], 'superproject')
573 self.assertEqual(manifest.superproject['remote'].name, 'test-remote')
574 self.assertEqual(manifest.superproject['remote'].url, 'http://localhost/superproject')
575 self.assertEqual(manifest.superproject['revision'], 'refs/heads/main')
576 self.assertEqual(
577 sort_attributes(manifest.ToXml().toxml()),
578 '<?xml version="1.0" ?><manifest>'
579 '<remote fetch="http://localhost" name="test-remote"/>'
580 '<default remote="test-remote" revision="refs/heads/main"/>'
581 '<superproject name="superproject"/>'
582 '</manifest>')
583
584 def test_superproject_revision(self):
585 """Check superproject settings with a different revision attribute"""
586 self.maxDiff = None
587 manifest = self.getXmlManifest("""
588<manifest>
589 <remote name="test-remote" fetch="http://localhost" />
590 <default remote="test-remote" revision="refs/heads/main" />
591 <superproject name="superproject" revision="refs/heads/stable" />
592</manifest>
593""")
594 self.assertEqual(manifest.superproject['name'], 'superproject')
595 self.assertEqual(manifest.superproject['remote'].name, 'test-remote')
596 self.assertEqual(manifest.superproject['remote'].url, 'http://localhost/superproject')
597 self.assertEqual(manifest.superproject['revision'], 'refs/heads/stable')
598 self.assertEqual(
599 sort_attributes(manifest.ToXml().toxml()),
600 '<?xml version="1.0" ?><manifest>'
601 '<remote fetch="http://localhost" name="test-remote"/>'
602 '<default remote="test-remote" revision="refs/heads/main"/>'
603 '<superproject name="superproject" revision="refs/heads/stable"/>'
604 '</manifest>')
605
606 def test_superproject_revision_default_negative(self):
607 """Check superproject settings with a same revision attribute"""
608 self.maxDiff = None
609 manifest = self.getXmlManifest("""
610<manifest>
611 <remote name="test-remote" fetch="http://localhost" />
612 <default remote="test-remote" revision="refs/heads/stable" />
613 <superproject name="superproject" revision="refs/heads/stable" />
614</manifest>
615""")
616 self.assertEqual(manifest.superproject['name'], 'superproject')
617 self.assertEqual(manifest.superproject['remote'].name, 'test-remote')
618 self.assertEqual(manifest.superproject['remote'].url, 'http://localhost/superproject')
619 self.assertEqual(manifest.superproject['revision'], 'refs/heads/stable')
620 self.assertEqual(
621 sort_attributes(manifest.ToXml().toxml()),
622 '<?xml version="1.0" ?><manifest>'
623 '<remote fetch="http://localhost" name="test-remote"/>'
624 '<default remote="test-remote" revision="refs/heads/stable"/>'
625 '<superproject name="superproject"/>'
626 '</manifest>')
627
628 def test_superproject_revision_remote(self):
629 """Check superproject settings with a same revision attribute"""
630 self.maxDiff = None
631 manifest = self.getXmlManifest("""
632<manifest>
633 <remote name="test-remote" fetch="http://localhost" revision="refs/heads/main" />
634 <default remote="test-remote" />
635 <superproject name="superproject" revision="refs/heads/stable" />
636</manifest>
637""")
638 self.assertEqual(manifest.superproject['name'], 'superproject')
639 self.assertEqual(manifest.superproject['remote'].name, 'test-remote')
640 self.assertEqual(manifest.superproject['remote'].url, 'http://localhost/superproject')
641 self.assertEqual(manifest.superproject['revision'], 'refs/heads/stable')
642 self.assertEqual(
643 sort_attributes(manifest.ToXml().toxml()),
644 '<?xml version="1.0" ?><manifest>'
645 '<remote fetch="http://localhost" name="test-remote" revision="refs/heads/main"/>'
646 '<default remote="test-remote"/>'
647 '<superproject name="superproject" revision="refs/heads/stable"/>'
648 '</manifest>')
649
650 def test_remote(self):
651 """Check superproject settings with a remote."""
652 manifest = self.getXmlManifest("""
653<manifest>
654 <remote name="default-remote" fetch="http://localhost" />
655 <remote name="superproject-remote" fetch="http://localhost" />
656 <default remote="default-remote" revision="refs/heads/main" />
657 <superproject name="platform/superproject" remote="superproject-remote"/>
658</manifest>
659""")
660 self.assertEqual(manifest.superproject['name'], 'platform/superproject')
661 self.assertEqual(manifest.superproject['remote'].name, 'superproject-remote')
662 self.assertEqual(manifest.superproject['remote'].url, 'http://localhost/platform/superproject')
663 self.assertEqual(manifest.superproject['revision'], 'refs/heads/main')
664 self.assertEqual(
665 sort_attributes(manifest.ToXml().toxml()),
666 '<?xml version="1.0" ?><manifest>'
667 '<remote fetch="http://localhost" name="default-remote"/>'
668 '<remote fetch="http://localhost" name="superproject-remote"/>'
669 '<default remote="default-remote" revision="refs/heads/main"/>'
670 '<superproject name="platform/superproject" remote="superproject-remote"/>'
671 '</manifest>')
672
673 def test_defalut_remote(self):
674 """Check superproject settings with a default remote."""
675 manifest = self.getXmlManifest("""
676<manifest>
677 <remote name="default-remote" fetch="http://localhost" />
678 <default remote="default-remote" revision="refs/heads/main" />
679 <superproject name="superproject" remote="default-remote"/>
680</manifest>
681""")
682 self.assertEqual(manifest.superproject['name'], 'superproject')
683 self.assertEqual(manifest.superproject['remote'].name, 'default-remote')
684 self.assertEqual(manifest.superproject['revision'], 'refs/heads/main')
685 self.assertEqual(
686 sort_attributes(manifest.ToXml().toxml()),
687 '<?xml version="1.0" ?><manifest>'
688 '<remote fetch="http://localhost" name="default-remote"/>'
689 '<default remote="default-remote" revision="refs/heads/main"/>'
690 '<superproject name="superproject"/>'
691 '</manifest>')
692
693
694class ContactinfoElementTests(ManifestParseTestCase):
695 """Tests for <contactinfo>."""
696
697 def test_contactinfo(self):
698 """Check contactinfo settings."""
699 bugurl = 'http://localhost/contactinfo'
700 manifest = self.getXmlManifest(f"""
701<manifest>
702 <contactinfo bugurl="{bugurl}"/>
703</manifest>
704""")
705 self.assertEqual(manifest.contactinfo.bugurl, bugurl)
706 self.assertEqual(
707 manifest.ToXml().toxml(),
708 '<?xml version="1.0" ?><manifest>'
709 f'<contactinfo bugurl="{bugurl}"/>'
710 '</manifest>')
711
712
713class DefaultElementTests(ManifestParseTestCase):
714 """Tests for <default>."""
715
716 def test_default(self):
717 """Check default settings."""
718 a = manifest_xml._Default()
719 a.revisionExpr = 'foo'
720 a.remote = manifest_xml._XmlRemote(name='remote')
721 b = manifest_xml._Default()
722 b.revisionExpr = 'bar'
723 self.assertEqual(a, a)
724 self.assertNotEqual(a, b)
725 self.assertNotEqual(b, a.remote)
726 self.assertNotEqual(a, 123)
727 self.assertNotEqual(a, None)
728
729
730class RemoteElementTests(ManifestParseTestCase):
731 """Tests for <remote>."""
732
733 def test_remote(self):
734 """Check remote settings."""
735 a = manifest_xml._XmlRemote(name='foo')
736 a.AddAnnotation('key1', 'value1', 'true')
737 b = manifest_xml._XmlRemote(name='foo')
738 b.AddAnnotation('key2', 'value1', 'true')
739 c = manifest_xml._XmlRemote(name='foo')
740 c.AddAnnotation('key1', 'value2', 'true')
741 d = manifest_xml._XmlRemote(name='foo')
742 d.AddAnnotation('key1', 'value1', 'false')
743 self.assertEqual(a, a)
744 self.assertNotEqual(a, b)
745 self.assertNotEqual(a, c)
746 self.assertNotEqual(a, d)
747 self.assertNotEqual(a, manifest_xml._Default())
748 self.assertNotEqual(a, 123)
749 self.assertNotEqual(a, None)
750
751
752class RemoveProjectElementTests(ManifestParseTestCase):
753 """Tests for <remove-project>."""
754
755 def test_remove_one_project(self):
756 manifest = self.getXmlManifest("""
757<manifest>
758 <remote name="default-remote" fetch="http://localhost" />
759 <default remote="default-remote" revision="refs/heads/main" />
760 <project name="myproject" />
761 <remove-project name="myproject" />
762</manifest>
763""")
764 self.assertEqual(manifest.projects, [])
765
766 def test_remove_one_project_one_remains(self):
767 manifest = self.getXmlManifest("""
768<manifest>
769 <remote name="default-remote" fetch="http://localhost" />
770 <default remote="default-remote" revision="refs/heads/main" />
771 <project name="myproject" />
772 <project name="yourproject" />
773 <remove-project name="myproject" />
774</manifest>
775""")
776
777 self.assertEqual(len(manifest.projects), 1)
778 self.assertEqual(manifest.projects[0].name, 'yourproject')
779
780 def test_remove_one_project_doesnt_exist(self):
781 with self.assertRaises(manifest_xml.ManifestParseError):
782 manifest = self.getXmlManifest("""
783<manifest>
784 <remote name="default-remote" fetch="http://localhost" />
785 <default remote="default-remote" revision="refs/heads/main" />
786 <remove-project name="myproject" />
787</manifest>
788""")
789 manifest.projects
790
791 def test_remove_one_optional_project_doesnt_exist(self):
792 manifest = self.getXmlManifest("""
793<manifest>
794 <remote name="default-remote" fetch="http://localhost" />
795 <default remote="default-remote" revision="refs/heads/main" />
796 <remove-project name="myproject" optional="true" />
797</manifest>
798""")
799 self.assertEqual(manifest.projects, [])
800
801
802class ExtendProjectElementTests(ManifestParseTestCase):
803 """Tests for <extend-project>."""
804
805 def test_extend_project_dest_path_single_match(self):
806 manifest = self.getXmlManifest("""
807<manifest>
808 <remote name="default-remote" fetch="http://localhost" />
809 <default remote="default-remote" revision="refs/heads/main" />
810 <project name="myproject" />
811 <extend-project name="myproject" dest-path="bar" />
812</manifest>
813""")
814 self.assertEqual(len(manifest.projects), 1)
815 self.assertEqual(manifest.projects[0].relpath, 'bar')
816
817 def test_extend_project_dest_path_multi_match(self):
818 with self.assertRaises(manifest_xml.ManifestParseError):
819 manifest = self.getXmlManifest("""
820<manifest>
821 <remote name="default-remote" fetch="http://localhost" />
822 <default remote="default-remote" revision="refs/heads/main" />
823 <project name="myproject" path="x" />
824 <project name="myproject" path="y" />
825 <extend-project name="myproject" dest-path="bar" />
826</manifest>
827""")
828 manifest.projects
829
830 def test_extend_project_dest_path_multi_match_path_specified(self):
831 manifest = self.getXmlManifest("""
832<manifest>
833 <remote name="default-remote" fetch="http://localhost" />
834 <default remote="default-remote" revision="refs/heads/main" />
835 <project name="myproject" path="x" />
836 <project name="myproject" path="y" />
837 <extend-project name="myproject" path="x" dest-path="bar" />
838</manifest>
839""")
840 self.assertEqual(len(manifest.projects), 2)
841 if manifest.projects[0].relpath == 'y':
842 self.assertEqual(manifest.projects[1].relpath, 'bar')
843 else:
844 self.assertEqual(manifest.projects[0].relpath, 'bar')
845 self.assertEqual(manifest.projects[1].relpath, 'y')
diff --git a/tests/test_platform_utils.py b/tests/test_platform_utils.py
new file mode 100644
index 00000000..55b7805c
--- /dev/null
+++ b/tests/test_platform_utils.py
@@ -0,0 +1,50 @@
1# Copyright 2021 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the platform_utils.py module."""
16
17import os
18import tempfile
19import unittest
20
21import platform_utils
22
23
24class RemoveTests(unittest.TestCase):
25 """Check remove() helper."""
26
27 def testMissingOk(self):
28 """Check missing_ok handling."""
29 with tempfile.TemporaryDirectory() as tmpdir:
30 path = os.path.join(tmpdir, 'test')
31
32 # Should not fail.
33 platform_utils.remove(path, missing_ok=True)
34
35 # Should fail.
36 self.assertRaises(OSError, platform_utils.remove, path)
37 self.assertRaises(OSError, platform_utils.remove, path, missing_ok=False)
38
39 # Should not fail if it exists.
40 open(path, 'w').close()
41 platform_utils.remove(path, missing_ok=True)
42 self.assertFalse(os.path.exists(path))
43
44 open(path, 'w').close()
45 platform_utils.remove(path)
46 self.assertFalse(os.path.exists(path))
47
48 open(path, 'w').close()
49 platform_utils.remove(path, missing_ok=False)
50 self.assertFalse(os.path.exists(path))
diff --git a/tests/test_project.py b/tests/test_project.py
index 77126dff..9b2cc4e9 100644
--- a/tests/test_project.py
+++ b/tests/test_project.py
@@ -1,5 +1,3 @@
1# -*- coding:utf-8 -*-
2#
3# Copyright (C) 2019 The Android Open Source Project 1# Copyright (C) 2019 The Android Open Source Project
4# 2#
5# Licensed under the Apache License, Version 2.0 (the "License"); 3# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,8 +14,6 @@
16 14
17"""Unittests for the project.py module.""" 15"""Unittests for the project.py module."""
18 16
19from __future__ import print_function
20
21import contextlib 17import contextlib
22import os 18import os
23import shutil 19import shutil
@@ -25,7 +21,10 @@ import subprocess
25import tempfile 21import tempfile
26import unittest 22import unittest
27 23
24import error
25import git_command
28import git_config 26import git_config
27import platform_utils
29import project 28import project
30 29
31 30
@@ -36,49 +35,22 @@ def TempGitTree():
36 # Python 2 support entirely. 35 # Python 2 support entirely.
37 try: 36 try:
38 tempdir = tempfile.mkdtemp(prefix='repo-tests') 37 tempdir = tempfile.mkdtemp(prefix='repo-tests')
39 subprocess.check_call(['git', 'init'], cwd=tempdir) 38
39 # Tests need to assume, that main is default branch at init,
40 # which is not supported in config until 2.28.
41 cmd = ['git', 'init']
42 if git_command.git_require((2, 28, 0)):
43 cmd += ['--initial-branch=main']
44 else:
45 # Use template dir for init.
46 templatedir = tempfile.mkdtemp(prefix='.test-template')
47 with open(os.path.join(templatedir, 'HEAD'), 'w') as fp:
48 fp.write('ref: refs/heads/main\n')
49 cmd += ['--template', templatedir]
50 subprocess.check_call(cmd, cwd=tempdir)
40 yield tempdir 51 yield tempdir
41 finally: 52 finally:
42 shutil.rmtree(tempdir) 53 platform_utils.rmtree(tempdir)
43
44
45class RepoHookShebang(unittest.TestCase):
46 """Check shebang parsing in RepoHook."""
47
48 def test_no_shebang(self):
49 """Lines w/out shebangs should be rejected."""
50 DATA = (
51 '',
52 '# -*- coding:utf-8 -*-\n',
53 '#\n# foo\n',
54 '# Bad shebang in script\n#!/foo\n'
55 )
56 for data in DATA:
57 self.assertIsNone(project.RepoHook._ExtractInterpFromShebang(data))
58
59 def test_direct_interp(self):
60 """Lines whose shebang points directly to the interpreter."""
61 DATA = (
62 ('#!/foo', '/foo'),
63 ('#! /foo', '/foo'),
64 ('#!/bin/foo ', '/bin/foo'),
65 ('#! /usr/foo ', '/usr/foo'),
66 ('#! /usr/foo -args', '/usr/foo'),
67 )
68 for shebang, interp in DATA:
69 self.assertEqual(project.RepoHook._ExtractInterpFromShebang(shebang),
70 interp)
71
72 def test_env_interp(self):
73 """Lines whose shebang launches through `env`."""
74 DATA = (
75 ('#!/usr/bin/env foo', 'foo'),
76 ('#!/bin/env foo', 'foo'),
77 ('#! /bin/env /bin/foo ', '/bin/foo'),
78 )
79 for shebang, interp in DATA:
80 self.assertEqual(project.RepoHook._ExtractInterpFromShebang(shebang),
81 interp)
82 54
83 55
84class FakeProject(object): 56class FakeProject(object):
@@ -114,7 +86,7 @@ class ReviewableBranchTests(unittest.TestCase):
114 86
115 # Start off with the normal details. 87 # Start off with the normal details.
116 rb = project.ReviewableBranch( 88 rb = project.ReviewableBranch(
117 fakeproj, fakeproj.config.GetBranch('work'), 'master') 89 fakeproj, fakeproj.config.GetBranch('work'), 'main')
118 self.assertEqual('work', rb.name) 90 self.assertEqual('work', rb.name)
119 self.assertEqual(1, len(rb.commits)) 91 self.assertEqual(1, len(rb.commits))
120 self.assertIn('Del file', rb.commits[0]) 92 self.assertIn('Del file', rb.commits[0])
@@ -127,10 +99,239 @@ class ReviewableBranchTests(unittest.TestCase):
127 self.assertTrue(rb.date) 99 self.assertTrue(rb.date)
128 100
129 # Now delete the tracking branch! 101 # Now delete the tracking branch!
130 fakeproj.work_git.branch('-D', 'master') 102 fakeproj.work_git.branch('-D', 'main')
131 rb = project.ReviewableBranch( 103 rb = project.ReviewableBranch(
132 fakeproj, fakeproj.config.GetBranch('work'), 'master') 104 fakeproj, fakeproj.config.GetBranch('work'), 'main')
133 self.assertEqual(0, len(rb.commits)) 105 self.assertEqual(0, len(rb.commits))
134 self.assertFalse(rb.base_exists) 106 self.assertFalse(rb.base_exists)
135 # Hard to assert anything useful about this. 107 # Hard to assert anything useful about this.
136 self.assertTrue(rb.date) 108 self.assertTrue(rb.date)
109
110
111class CopyLinkTestCase(unittest.TestCase):
112 """TestCase for stub repo client checkouts.
113
114 It'll have a layout like:
115 tempdir/ # self.tempdir
116 checkout/ # self.topdir
117 git-project/ # self.worktree
118
119 Attributes:
120 tempdir: A dedicated temporary directory.
121 worktree: The top of the repo client checkout.
122 topdir: The top of a project checkout.
123 """
124
125 def setUp(self):
126 self.tempdir = tempfile.mkdtemp(prefix='repo_tests')
127 self.topdir = os.path.join(self.tempdir, 'checkout')
128 self.worktree = os.path.join(self.topdir, 'git-project')
129 os.makedirs(self.topdir)
130 os.makedirs(self.worktree)
131
132 def tearDown(self):
133 shutil.rmtree(self.tempdir, ignore_errors=True)
134
135 @staticmethod
136 def touch(path):
137 with open(path, 'w'):
138 pass
139
140 def assertExists(self, path, msg=None):
141 """Make sure |path| exists."""
142 if os.path.exists(path):
143 return
144
145 if msg is None:
146 msg = ['path is missing: %s' % path]
147 while path != '/':
148 path = os.path.dirname(path)
149 if not path:
150 # If we're given something like "foo", abort once we get to "".
151 break
152 result = os.path.exists(path)
153 msg.append('\tos.path.exists(%s): %s' % (path, result))
154 if result:
155 msg.append('\tcontents: %r' % os.listdir(path))
156 break
157 msg = '\n'.join(msg)
158
159 raise self.failureException(msg)
160
161
162class CopyFile(CopyLinkTestCase):
163 """Check _CopyFile handling."""
164
165 def CopyFile(self, src, dest):
166 return project._CopyFile(self.worktree, src, self.topdir, dest)
167
168 def test_basic(self):
169 """Basic test of copying a file from a project to the toplevel."""
170 src = os.path.join(self.worktree, 'foo.txt')
171 self.touch(src)
172 cf = self.CopyFile('foo.txt', 'foo')
173 cf._Copy()
174 self.assertExists(os.path.join(self.topdir, 'foo'))
175
176 def test_src_subdir(self):
177 """Copy a file from a subdir of a project."""
178 src = os.path.join(self.worktree, 'bar', 'foo.txt')
179 os.makedirs(os.path.dirname(src))
180 self.touch(src)
181 cf = self.CopyFile('bar/foo.txt', 'new.txt')
182 cf._Copy()
183 self.assertExists(os.path.join(self.topdir, 'new.txt'))
184
185 def test_dest_subdir(self):
186 """Copy a file to a subdir of a checkout."""
187 src = os.path.join(self.worktree, 'foo.txt')
188 self.touch(src)
189 cf = self.CopyFile('foo.txt', 'sub/dir/new.txt')
190 self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub')))
191 cf._Copy()
192 self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'new.txt'))
193
194 def test_update(self):
195 """Make sure changed files get copied again."""
196 src = os.path.join(self.worktree, 'foo.txt')
197 dest = os.path.join(self.topdir, 'bar')
198 with open(src, 'w') as f:
199 f.write('1st')
200 cf = self.CopyFile('foo.txt', 'bar')
201 cf._Copy()
202 self.assertExists(dest)
203 with open(dest) as f:
204 self.assertEqual(f.read(), '1st')
205
206 with open(src, 'w') as f:
207 f.write('2nd!')
208 cf._Copy()
209 with open(dest) as f:
210 self.assertEqual(f.read(), '2nd!')
211
212 def test_src_block_symlink(self):
213 """Do not allow reading from a symlinked path."""
214 src = os.path.join(self.worktree, 'foo.txt')
215 sym = os.path.join(self.worktree, 'sym')
216 self.touch(src)
217 platform_utils.symlink('foo.txt', sym)
218 self.assertExists(sym)
219 cf = self.CopyFile('sym', 'foo')
220 self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
221
222 def test_src_block_symlink_traversal(self):
223 """Do not allow reading through a symlink dir."""
224 realfile = os.path.join(self.tempdir, 'file.txt')
225 self.touch(realfile)
226 src = os.path.join(self.worktree, 'bar', 'file.txt')
227 platform_utils.symlink(self.tempdir, os.path.join(self.worktree, 'bar'))
228 self.assertExists(src)
229 cf = self.CopyFile('bar/file.txt', 'foo')
230 self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
231
232 def test_src_block_copy_from_dir(self):
233 """Do not allow copying from a directory."""
234 src = os.path.join(self.worktree, 'dir')
235 os.makedirs(src)
236 cf = self.CopyFile('dir', 'foo')
237 self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
238
239 def test_dest_block_symlink(self):
240 """Do not allow writing to a symlink."""
241 src = os.path.join(self.worktree, 'foo.txt')
242 self.touch(src)
243 platform_utils.symlink('dest', os.path.join(self.topdir, 'sym'))
244 cf = self.CopyFile('foo.txt', 'sym')
245 self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
246
247 def test_dest_block_symlink_traversal(self):
248 """Do not allow writing through a symlink dir."""
249 src = os.path.join(self.worktree, 'foo.txt')
250 self.touch(src)
251 platform_utils.symlink(tempfile.gettempdir(),
252 os.path.join(self.topdir, 'sym'))
253 cf = self.CopyFile('foo.txt', 'sym/foo.txt')
254 self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
255
256 def test_src_block_copy_to_dir(self):
257 """Do not allow copying to a directory."""
258 src = os.path.join(self.worktree, 'foo.txt')
259 self.touch(src)
260 os.makedirs(os.path.join(self.topdir, 'dir'))
261 cf = self.CopyFile('foo.txt', 'dir')
262 self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
263
264
265class LinkFile(CopyLinkTestCase):
266 """Check _LinkFile handling."""
267
268 def LinkFile(self, src, dest):
269 return project._LinkFile(self.worktree, src, self.topdir, dest)
270
271 def test_basic(self):
272 """Basic test of linking a file from a project into the toplevel."""
273 src = os.path.join(self.worktree, 'foo.txt')
274 self.touch(src)
275 lf = self.LinkFile('foo.txt', 'foo')
276 lf._Link()
277 dest = os.path.join(self.topdir, 'foo')
278 self.assertExists(dest)
279 self.assertTrue(os.path.islink(dest))
280 self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest))
281
282 def test_src_subdir(self):
283 """Link to a file in a subdir of a project."""
284 src = os.path.join(self.worktree, 'bar', 'foo.txt')
285 os.makedirs(os.path.dirname(src))
286 self.touch(src)
287 lf = self.LinkFile('bar/foo.txt', 'foo')
288 lf._Link()
289 self.assertExists(os.path.join(self.topdir, 'foo'))
290
291 def test_src_self(self):
292 """Link to the project itself."""
293 dest = os.path.join(self.topdir, 'foo', 'bar')
294 lf = self.LinkFile('.', 'foo/bar')
295 lf._Link()
296 self.assertExists(dest)
297 self.assertEqual(os.path.join('..', 'git-project'), os.readlink(dest))
298
299 def test_dest_subdir(self):
300 """Link a file to a subdir of a checkout."""
301 src = os.path.join(self.worktree, 'foo.txt')
302 self.touch(src)
303 lf = self.LinkFile('foo.txt', 'sub/dir/foo/bar')
304 self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub')))
305 lf._Link()
306 self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'foo', 'bar'))
307
308 def test_src_block_relative(self):
309 """Do not allow relative symlinks."""
310 BAD_SOURCES = (
311 './',
312 '..',
313 '../',
314 'foo/.',
315 'foo/./bar',
316 'foo/..',
317 'foo/../foo',
318 )
319 for src in BAD_SOURCES:
320 lf = self.LinkFile(src, 'foo')
321 self.assertRaises(error.ManifestInvalidPathError, lf._Link)
322
323 def test_update(self):
324 """Make sure changed targets get updated."""
325 dest = os.path.join(self.topdir, 'sym')
326
327 src = os.path.join(self.worktree, 'foo.txt')
328 self.touch(src)
329 lf = self.LinkFile('foo.txt', 'sym')
330 lf._Link()
331 self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest))
332
333 # Point the symlink somewhere else.
334 os.unlink(dest)
335 platform_utils.symlink(self.tempdir, dest)
336 lf._Link()
337 self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest))
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
new file mode 100644
index 00000000..ffb5cb94
--- /dev/null
+++ b/tests/test_ssh.py
@@ -0,0 +1,74 @@
1# Copyright 2019 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the ssh.py module."""
16
17import multiprocessing
18import subprocess
19import unittest
20from unittest import mock
21
22import ssh
23
24
25class SshTests(unittest.TestCase):
26 """Tests the ssh functions."""
27
28 def test_parse_ssh_version(self):
29 """Check _parse_ssh_version() handling."""
30 ver = ssh._parse_ssh_version('Unknown\n')
31 self.assertEqual(ver, ())
32 ver = ssh._parse_ssh_version('OpenSSH_1.0\n')
33 self.assertEqual(ver, (1, 0))
34 ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
35 self.assertEqual(ver, (6, 6, 1))
36 ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
37 self.assertEqual(ver, (7, 6))
38
39 def test_version(self):
40 """Check version() handling."""
41 with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
42 self.assertEqual(ssh.version(), (1, 2))
43
44 def test_context_manager_empty(self):
45 """Verify context manager with no clients works correctly."""
46 with multiprocessing.Manager() as manager:
47 with ssh.ProxyManager(manager):
48 pass
49
50 def test_context_manager_child_cleanup(self):
51 """Verify orphaned clients & masters get cleaned up."""
52 with multiprocessing.Manager() as manager:
53 with ssh.ProxyManager(manager) as ssh_proxy:
54 client = subprocess.Popen(['sleep', '964853320'])
55 ssh_proxy.add_client(client)
56 master = subprocess.Popen(['sleep', '964853321'])
57 ssh_proxy.add_master(master)
58 # If the process still exists, these will throw timeout errors.
59 client.wait(0)
60 master.wait(0)
61
62 def test_ssh_sock(self):
63 """Check sock() function."""
64 manager = multiprocessing.Manager()
65 proxy = ssh.ProxyManager(manager)
66 with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
67 # old ssh version uses port
68 with mock.patch('ssh.version', return_value=(6, 6)):
69 self.assertTrue(proxy.sock().endswith('%p'))
70
71 proxy._sock_path = None
72 # new ssh version uses hash
73 with mock.patch('ssh.version', return_value=(6, 7)):
74 self.assertTrue(proxy.sock().endswith('%C'))
diff --git a/tests/test_subcmds.py b/tests/test_subcmds.py
new file mode 100644
index 00000000..bc53051a
--- /dev/null
+++ b/tests/test_subcmds.py
@@ -0,0 +1,73 @@
1# Copyright (C) 2020 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the subcmds module (mostly __init__.py than subcommands)."""
16
17import optparse
18import unittest
19
20import subcmds
21
22
23class AllCommands(unittest.TestCase):
24 """Check registered all_commands."""
25
26 def test_required_basic(self):
27 """Basic checking of registered commands."""
28 # NB: We don't test all subcommands as we want to avoid "change detection"
29 # tests, so we just look for the most common/important ones here that are
30 # unlikely to ever change.
31 for cmd in {'cherry-pick', 'help', 'init', 'start', 'sync', 'upload'}:
32 self.assertIn(cmd, subcmds.all_commands)
33
34 def test_naming(self):
35 """Verify we don't add things that we shouldn't."""
36 for cmd in subcmds.all_commands:
37 # Reject filename suffixes like "help.py".
38 self.assertNotIn('.', cmd)
39
40 # Make sure all '_' were converted to '-'.
41 self.assertNotIn('_', cmd)
42
43 # Reject internal python paths like "__init__".
44 self.assertFalse(cmd.startswith('__'))
45
46 def test_help_desc_style(self):
47 """Force some consistency in option descriptions.
48
49 Python's optparse & argparse has a few default options like --help. Their
50 option description text uses lowercase sentence fragments, so enforce our
51 options follow the same style so UI is consistent.
52
53 We enforce:
54 * Text starts with lowercase.
55 * Text doesn't end with period.
56 """
57 for name, cls in subcmds.all_commands.items():
58 cmd = cls()
59 parser = cmd.OptionParser
60 for option in parser.option_list:
61 if option.help == optparse.SUPPRESS_HELP:
62 continue
63
64 c = option.help[0]
65 self.assertEqual(
66 c.lower(), c,
67 msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text '
68 f'should start with lowercase: "{option.help}"')
69
70 self.assertNotEqual(
71 option.help[-1], '.',
72 msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text '
73 f'should not end in a period: "{option.help}"')
diff --git a/tests/test_subcmds_init.py b/tests/test_subcmds_init.py
new file mode 100644
index 00000000..af4346de
--- /dev/null
+++ b/tests/test_subcmds_init.py
@@ -0,0 +1,49 @@
1# Copyright (C) 2020 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unittests for the subcmds/init.py module."""
16
17import unittest
18
19from subcmds import init
20
21
22class InitCommand(unittest.TestCase):
23 """Check registered all_commands."""
24
25 def setUp(self):
26 self.cmd = init.Init()
27
28 def test_cli_parser_good(self):
29 """Check valid command line options."""
30 ARGV = (
31 [],
32 )
33 for argv in ARGV:
34 opts, args = self.cmd.OptionParser.parse_args(argv)
35 self.cmd.ValidateOptions(opts, args)
36
37 def test_cli_parser_bad(self):
38 """Check invalid command line options."""
39 ARGV = (
40 # Too many arguments.
41 ['url', 'asdf'],
42
43 # Conflicting options.
44 ['--mirror', '--archive'],
45 )
46 for argv in ARGV:
47 opts, args = self.cmd.OptionParser.parse_args(argv)
48 with self.assertRaises(SystemExit):
49 self.cmd.ValidateOptions(opts, args)
diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py
index 8ef8d48d..e9a1f64a 100644
--- a/tests/test_wrapper.py
+++ b/tests/test_wrapper.py
@@ -1,5 +1,3 @@
1# -*- coding:utf-8 -*-
2#
3# Copyright (C) 2015 The Android Open Source Project 1# Copyright (C) 2015 The Android Open Source Project
4# 2#
5# Licensed under the Apache License, Version 2.0 (the "License"); 3# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,27 +14,87 @@
16 14
17"""Unittests for the wrapper.py module.""" 15"""Unittests for the wrapper.py module."""
18 16
19from __future__ import print_function 17import contextlib
20 18from io import StringIO
21import os 19import os
20import re
21import shutil
22import sys
23import tempfile
22import unittest 24import unittest
25from unittest import mock
23 26
27import git_command
28import main
29import platform_utils
24import wrapper 30import wrapper
25 31
32
33@contextlib.contextmanager
34def TemporaryDirectory():
35 """Create a new empty git checkout for testing."""
36 # TODO(vapier): Convert this to tempfile.TemporaryDirectory once we drop
37 # Python 2 support entirely.
38 try:
39 tempdir = tempfile.mkdtemp(prefix='repo-tests')
40 yield tempdir
41 finally:
42 platform_utils.rmtree(tempdir)
43
44
26def fixture(*paths): 45def fixture(*paths):
27 """Return a path relative to tests/fixtures. 46 """Return a path relative to tests/fixtures.
28 """ 47 """
29 return os.path.join(os.path.dirname(__file__), 'fixtures', *paths) 48 return os.path.join(os.path.dirname(__file__), 'fixtures', *paths)
30 49
31class RepoWrapperUnitTest(unittest.TestCase): 50
32 """Tests helper functions in the repo wrapper 51class RepoWrapperTestCase(unittest.TestCase):
33 """ 52 """TestCase for the wrapper module."""
53
34 def setUp(self): 54 def setUp(self):
35 """Load the wrapper module every time 55 """Load the wrapper module every time."""
36 """
37 wrapper._wrapper_module = None 56 wrapper._wrapper_module = None
38 self.wrapper = wrapper.Wrapper() 57 self.wrapper = wrapper.Wrapper()
39 58
59
60class RepoWrapperUnitTest(RepoWrapperTestCase):
61 """Tests helper functions in the repo wrapper
62 """
63
64 def test_version(self):
65 """Make sure _Version works."""
66 with self.assertRaises(SystemExit) as e:
67 with mock.patch('sys.stdout', new_callable=StringIO) as stdout:
68 with mock.patch('sys.stderr', new_callable=StringIO) as stderr:
69 self.wrapper._Version()
70 self.assertEqual(0, e.exception.code)
71 self.assertEqual('', stderr.getvalue())
72 self.assertIn('repo launcher version', stdout.getvalue())
73
74 def test_python_constraints(self):
75 """The launcher should never require newer than main.py."""
76 self.assertGreaterEqual(main.MIN_PYTHON_VERSION_HARD,
77 wrapper.MIN_PYTHON_VERSION_HARD)
78 self.assertGreaterEqual(main.MIN_PYTHON_VERSION_SOFT,
79 wrapper.MIN_PYTHON_VERSION_SOFT)
80 # Make sure the versions are themselves in sync.
81 self.assertGreaterEqual(wrapper.MIN_PYTHON_VERSION_SOFT,
82 wrapper.MIN_PYTHON_VERSION_HARD)
83
84 def test_init_parser(self):
85 """Make sure 'init' GetParser works."""
86 parser = self.wrapper.GetParser(gitc_init=False)
87 opts, args = parser.parse_args([])
88 self.assertEqual([], args)
89 self.assertIsNone(opts.manifest_url)
90
91 def test_gitc_init_parser(self):
92 """Make sure 'gitc-init' GetParser works."""
93 parser = self.wrapper.GetParser(gitc_init=True)
94 opts, args = parser.parse_args([])
95 self.assertEqual([], args)
96 self.assertIsNone(opts.manifest_file)
97
40 def test_get_gitc_manifest_dir_no_gitc(self): 98 def test_get_gitc_manifest_dir_no_gitc(self):
41 """ 99 """
42 Test reading a missing gitc config file 100 Test reading a missing gitc config file
@@ -72,9 +130,442 @@ class RepoWrapperUnitTest(unittest.TestCase):
72 self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/extra'), 'test') 130 self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/extra'), 'test')
73 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test'), 'test') 131 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test'), 'test')
74 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/'), 'test') 132 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/'), 'test')
75 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/extra'), 'test') 133 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/extra'),
134 'test')
76 self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/'), None) 135 self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/'), None)
77 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/'), None) 136 self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/'), None)
78 137
138
139class SetGitTrace2ParentSid(RepoWrapperTestCase):
140 """Check SetGitTrace2ParentSid behavior."""
141
142 KEY = 'GIT_TRACE2_PARENT_SID'
143 VALID_FORMAT = re.compile(r'^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$')
144
145 def test_first_set(self):
146 """Test env var not yet set."""
147 env = {}
148 self.wrapper.SetGitTrace2ParentSid(env)
149 self.assertIn(self.KEY, env)
150 value = env[self.KEY]
151 self.assertRegex(value, self.VALID_FORMAT)
152
153 def test_append(self):
154 """Test env var is appended."""
155 env = {self.KEY: 'pfx'}
156 self.wrapper.SetGitTrace2ParentSid(env)
157 self.assertIn(self.KEY, env)
158 value = env[self.KEY]
159 self.assertTrue(value.startswith('pfx/'))
160 self.assertRegex(value[4:], self.VALID_FORMAT)
161
162 def test_global_context(self):
163 """Check os.environ gets updated by default."""
164 os.environ.pop(self.KEY, None)
165 self.wrapper.SetGitTrace2ParentSid()
166 self.assertIn(self.KEY, os.environ)
167 value = os.environ[self.KEY]
168 self.assertRegex(value, self.VALID_FORMAT)
169
170
171class RunCommand(RepoWrapperTestCase):
172 """Check run_command behavior."""
173
174 def test_capture(self):
175 """Check capture_output handling."""
176 ret = self.wrapper.run_command(['echo', 'hi'], capture_output=True)
177 self.assertEqual(ret.stdout, 'hi\n')
178
179 def test_check(self):
180 """Check check handling."""
181 self.wrapper.run_command(['true'], check=False)
182 self.wrapper.run_command(['true'], check=True)
183 self.wrapper.run_command(['false'], check=False)
184 with self.assertRaises(self.wrapper.RunError):
185 self.wrapper.run_command(['false'], check=True)
186
187
188class RunGit(RepoWrapperTestCase):
189 """Check run_git behavior."""
190
191 def test_capture(self):
192 """Check capture_output handling."""
193 ret = self.wrapper.run_git('--version')
194 self.assertIn('git', ret.stdout)
195
196 def test_check(self):
197 """Check check handling."""
198 with self.assertRaises(self.wrapper.CloneFailure):
199 self.wrapper.run_git('--version-asdfasdf')
200 self.wrapper.run_git('--version-asdfasdf', check=False)
201
202
203class ParseGitVersion(RepoWrapperTestCase):
204 """Check ParseGitVersion behavior."""
205
206 def test_autoload(self):
207 """Check we can load the version from the live git."""
208 ret = self.wrapper.ParseGitVersion()
209 self.assertIsNotNone(ret)
210
211 def test_bad_ver(self):
212 """Check handling of bad git versions."""
213 ret = self.wrapper.ParseGitVersion(ver_str='asdf')
214 self.assertIsNone(ret)
215
216 def test_normal_ver(self):
217 """Check handling of normal git versions."""
218 ret = self.wrapper.ParseGitVersion(ver_str='git version 2.25.1')
219 self.assertEqual(2, ret.major)
220 self.assertEqual(25, ret.minor)
221 self.assertEqual(1, ret.micro)
222 self.assertEqual('2.25.1', ret.full)
223
224 def test_extended_ver(self):
225 """Check handling of extended distro git versions."""
226 ret = self.wrapper.ParseGitVersion(
227 ver_str='git version 1.30.50.696.g5e7596f4ac-goog')
228 self.assertEqual(1, ret.major)
229 self.assertEqual(30, ret.minor)
230 self.assertEqual(50, ret.micro)
231 self.assertEqual('1.30.50.696.g5e7596f4ac-goog', ret.full)
232
233
234class CheckGitVersion(RepoWrapperTestCase):
235 """Check _CheckGitVersion behavior."""
236
237 def test_unknown(self):
238 """Unknown versions should abort."""
239 with mock.patch.object(self.wrapper, 'ParseGitVersion', return_value=None):
240 with self.assertRaises(self.wrapper.CloneFailure):
241 self.wrapper._CheckGitVersion()
242
243 def test_old(self):
244 """Old versions should abort."""
245 with mock.patch.object(
246 self.wrapper, 'ParseGitVersion',
247 return_value=self.wrapper.GitVersion(1, 0, 0, '1.0.0')):
248 with self.assertRaises(self.wrapper.CloneFailure):
249 self.wrapper._CheckGitVersion()
250
251 def test_new(self):
252 """Newer versions should run fine."""
253 with mock.patch.object(
254 self.wrapper, 'ParseGitVersion',
255 return_value=self.wrapper.GitVersion(100, 0, 0, '100.0.0')):
256 self.wrapper._CheckGitVersion()
257
258
259class Requirements(RepoWrapperTestCase):
260 """Check Requirements handling."""
261
262 def test_missing_file(self):
263 """Don't crash if the file is missing (old version)."""
264 testdir = os.path.dirname(os.path.realpath(__file__))
265 self.assertIsNone(self.wrapper.Requirements.from_dir(testdir))
266 self.assertIsNone(self.wrapper.Requirements.from_file(
267 os.path.join(testdir, 'xxxxxxxxxxxxxxxxxxxxxxxx')))
268
269 def test_corrupt_data(self):
270 """If the file can't be parsed, don't blow up."""
271 self.assertIsNone(self.wrapper.Requirements.from_file(__file__))
272 self.assertIsNone(self.wrapper.Requirements.from_data(b'x'))
273
274 def test_valid_data(self):
275 """Make sure we can parse the file we ship."""
276 self.assertIsNotNone(self.wrapper.Requirements.from_data(b'{}'))
277 rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
278 self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir))
279 self.assertIsNotNone(self.wrapper.Requirements.from_file(os.path.join(
280 rootdir, 'requirements.json')))
281
282 def test_format_ver(self):
283 """Check format_ver can format."""
284 self.assertEqual('1.2.3', self.wrapper.Requirements._format_ver((1, 2, 3)))
285 self.assertEqual('1', self.wrapper.Requirements._format_ver([1]))
286
287 def test_assert_all_unknown(self):
288 """Check assert_all works with incompatible file."""
289 reqs = self.wrapper.Requirements({})
290 reqs.assert_all()
291
292 def test_assert_all_new_repo(self):
293 """Check assert_all accepts new enough repo."""
294 reqs = self.wrapper.Requirements({'repo': {'hard': [1, 0]}})
295 reqs.assert_all()
296
297 def test_assert_all_old_repo(self):
298 """Check assert_all rejects old repo."""
299 reqs = self.wrapper.Requirements({'repo': {'hard': [99999, 0]}})
300 with self.assertRaises(SystemExit):
301 reqs.assert_all()
302
303 def test_assert_all_new_python(self):
304 """Check assert_all accepts new enough python."""
305 reqs = self.wrapper.Requirements({'python': {'hard': sys.version_info}})
306 reqs.assert_all()
307
308 def test_assert_all_old_python(self):
309 """Check assert_all rejects old python."""
310 reqs = self.wrapper.Requirements({'python': {'hard': [99999, 0]}})
311 with self.assertRaises(SystemExit):
312 reqs.assert_all()
313
314 def test_assert_ver_unknown(self):
315 """Check assert_ver works with incompatible file."""
316 reqs = self.wrapper.Requirements({})
317 reqs.assert_ver('xxx', (1, 0))
318
319 def test_assert_ver_new(self):
320 """Check assert_ver allows new enough versions."""
321 reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}})
322 reqs.assert_ver('git', (1, 0))
323 reqs.assert_ver('git', (1, 5))
324 reqs.assert_ver('git', (2, 0))
325 reqs.assert_ver('git', (2, 5))
326
327 def test_assert_ver_old(self):
328 """Check assert_ver rejects old versions."""
329 reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}})
330 with self.assertRaises(SystemExit):
331 reqs.assert_ver('git', (0, 5))
332
333
334class NeedSetupGnuPG(RepoWrapperTestCase):
335 """Check NeedSetupGnuPG behavior."""
336
337 def test_missing_dir(self):
338 """The ~/.repoconfig tree doesn't exist yet."""
339 with TemporaryDirectory() as tempdir:
340 self.wrapper.home_dot_repo = os.path.join(tempdir, 'foo')
341 self.assertTrue(self.wrapper.NeedSetupGnuPG())
342
343 def test_missing_keyring(self):
344 """The keyring-version file doesn't exist yet."""
345 with TemporaryDirectory() as tempdir:
346 self.wrapper.home_dot_repo = tempdir
347 self.assertTrue(self.wrapper.NeedSetupGnuPG())
348
349 def test_empty_keyring(self):
350 """The keyring-version file exists, but is empty."""
351 with TemporaryDirectory() as tempdir:
352 self.wrapper.home_dot_repo = tempdir
353 with open(os.path.join(tempdir, 'keyring-version'), 'w'):
354 pass
355 self.assertTrue(self.wrapper.NeedSetupGnuPG())
356
357 def test_old_keyring(self):
358 """The keyring-version file exists, but it's old."""
359 with TemporaryDirectory() as tempdir:
360 self.wrapper.home_dot_repo = tempdir
361 with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp:
362 fp.write('1.0\n')
363 self.assertTrue(self.wrapper.NeedSetupGnuPG())
364
365 def test_new_keyring(self):
366 """The keyring-version file exists, and is up-to-date."""
367 with TemporaryDirectory() as tempdir:
368 self.wrapper.home_dot_repo = tempdir
369 with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp:
370 fp.write('1000.0\n')
371 self.assertFalse(self.wrapper.NeedSetupGnuPG())
372
373
374class SetupGnuPG(RepoWrapperTestCase):
375 """Check SetupGnuPG behavior."""
376
377 def test_full(self):
378 """Make sure it works completely."""
379 with TemporaryDirectory() as tempdir:
380 self.wrapper.home_dot_repo = tempdir
381 self.wrapper.gpg_dir = os.path.join(self.wrapper.home_dot_repo, 'gnupg')
382 self.assertTrue(self.wrapper.SetupGnuPG(True))
383 with open(os.path.join(tempdir, 'keyring-version'), 'r') as fp:
384 data = fp.read()
385 self.assertEqual('.'.join(str(x) for x in self.wrapper.KEYRING_VERSION),
386 data.strip())
387
388
389class VerifyRev(RepoWrapperTestCase):
390 """Check verify_rev behavior."""
391
392 def test_verify_passes(self):
393 """Check when we have a valid signed tag."""
394 desc_result = self.wrapper.RunResult(0, 'v1.0\n', '')
395 gpg_result = self.wrapper.RunResult(0, '', '')
396 with mock.patch.object(self.wrapper, 'run_git',
397 side_effect=(desc_result, gpg_result)):
398 ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True)
399 self.assertEqual('v1.0^0', ret)
400
401 def test_unsigned_commit(self):
402 """Check we fall back to signed tag when we have an unsigned commit."""
403 desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '')
404 gpg_result = self.wrapper.RunResult(0, '', '')
405 with mock.patch.object(self.wrapper, 'run_git',
406 side_effect=(desc_result, gpg_result)):
407 ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True)
408 self.assertEqual('v1.0^0', ret)
409
410 def test_verify_fails(self):
411 """Check we fall back to signed tag when we have an unsigned commit."""
412 desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '')
413 gpg_result = Exception
414 with mock.patch.object(self.wrapper, 'run_git',
415 side_effect=(desc_result, gpg_result)):
416 with self.assertRaises(Exception):
417 self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True)
418
419
420class GitCheckoutTestCase(RepoWrapperTestCase):
421 """Tests that use a real/small git checkout."""
422
423 GIT_DIR = None
424 REV_LIST = None
425
426 @classmethod
427 def setUpClass(cls):
428 # Create a repo to operate on, but do it once per-class.
429 cls.GIT_DIR = tempfile.mkdtemp(prefix='repo-rev-tests')
430 run_git = wrapper.Wrapper().run_git
431
432 remote = os.path.join(cls.GIT_DIR, 'remote')
433 os.mkdir(remote)
434
435 # Tests need to assume, that main is default branch at init,
436 # which is not supported in config until 2.28.
437 if git_command.git_require((2, 28, 0)):
438 initstr = '--initial-branch=main'
439 else:
440 # Use template dir for init.
441 templatedir = tempfile.mkdtemp(prefix='.test-template')
442 with open(os.path.join(templatedir, 'HEAD'), 'w') as fp:
443 fp.write('ref: refs/heads/main\n')
444 initstr = '--template=' + templatedir
445
446 run_git('init', initstr, cwd=remote)
447 run_git('commit', '--allow-empty', '-minit', cwd=remote)
448 run_git('branch', 'stable', cwd=remote)
449 run_git('tag', 'v1.0', cwd=remote)
450 run_git('commit', '--allow-empty', '-m2nd commit', cwd=remote)
451 cls.REV_LIST = run_git('rev-list', 'HEAD', cwd=remote).stdout.splitlines()
452
453 run_git('init', cwd=cls.GIT_DIR)
454 run_git('fetch', remote, '+refs/heads/*:refs/remotes/origin/*', cwd=cls.GIT_DIR)
455
456 @classmethod
457 def tearDownClass(cls):
458 if not cls.GIT_DIR:
459 return
460
461 shutil.rmtree(cls.GIT_DIR)
462
463
464class ResolveRepoRev(GitCheckoutTestCase):
465 """Check resolve_repo_rev behavior."""
466
467 def test_explicit_branch(self):
468 """Check refs/heads/branch argument."""
469 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/stable')
470 self.assertEqual('refs/heads/stable', rrev)
471 self.assertEqual(self.REV_LIST[1], lrev)
472
473 with self.assertRaises(wrapper.CloneFailure):
474 self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/unknown')
475
476 def test_explicit_tag(self):
477 """Check refs/tags/tag argument."""
478 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/v1.0')
479 self.assertEqual('refs/tags/v1.0', rrev)
480 self.assertEqual(self.REV_LIST[1], lrev)
481
482 with self.assertRaises(wrapper.CloneFailure):
483 self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/unknown')
484
485 def test_branch_name(self):
486 """Check branch argument."""
487 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'stable')
488 self.assertEqual('refs/heads/stable', rrev)
489 self.assertEqual(self.REV_LIST[1], lrev)
490
491 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'main')
492 self.assertEqual('refs/heads/main', rrev)
493 self.assertEqual(self.REV_LIST[0], lrev)
494
495 def test_tag_name(self):
496 """Check tag argument."""
497 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'v1.0')
498 self.assertEqual('refs/tags/v1.0', rrev)
499 self.assertEqual(self.REV_LIST[1], lrev)
500
501 def test_full_commit(self):
502 """Check specific commit argument."""
503 commit = self.REV_LIST[0]
504 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit)
505 self.assertEqual(commit, rrev)
506 self.assertEqual(commit, lrev)
507
508 def test_partial_commit(self):
509 """Check specific (partial) commit argument."""
510 commit = self.REV_LIST[0][0:20]
511 rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit)
512 self.assertEqual(self.REV_LIST[0], rrev)
513 self.assertEqual(self.REV_LIST[0], lrev)
514
515 def test_unknown(self):
516 """Check unknown ref/commit argument."""
517 with self.assertRaises(wrapper.CloneFailure):
518 self.wrapper.resolve_repo_rev(self.GIT_DIR, 'boooooooya')
519
520
521class CheckRepoVerify(RepoWrapperTestCase):
522 """Check check_repo_verify behavior."""
523
524 def test_no_verify(self):
525 """Always fail with --no-repo-verify."""
526 self.assertFalse(self.wrapper.check_repo_verify(False))
527
528 def test_gpg_initialized(self):
529 """Should pass if gpg is setup already."""
530 with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=False):
531 self.assertTrue(self.wrapper.check_repo_verify(True))
532
533 def test_need_gpg_setup(self):
534 """Should pass/fail based on gpg setup."""
535 with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=True):
536 with mock.patch.object(self.wrapper, 'SetupGnuPG') as m:
537 m.return_value = True
538 self.assertTrue(self.wrapper.check_repo_verify(True))
539
540 m.return_value = False
541 self.assertFalse(self.wrapper.check_repo_verify(True))
542
543
544class CheckRepoRev(GitCheckoutTestCase):
545 """Check check_repo_rev behavior."""
546
547 def test_verify_works(self):
548 """Should pass when verification passes."""
549 with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True):
550 with mock.patch.object(self.wrapper, 'verify_rev', return_value='12345'):
551 rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable')
552 self.assertEqual('refs/heads/stable', rrev)
553 self.assertEqual('12345', lrev)
554
555 def test_verify_fails(self):
556 """Should fail when verification fails."""
557 with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True):
558 with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception):
559 with self.assertRaises(Exception):
560 self.wrapper.check_repo_rev(self.GIT_DIR, 'stable')
561
562 def test_verify_ignore(self):
563 """Should pass when verification is disabled."""
564 with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception):
565 rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable', repo_verify=False)
566 self.assertEqual('refs/heads/stable', rrev)
567 self.assertEqual(self.REV_LIST[1], lrev)
568
569
79if __name__ == '__main__': 570if __name__ == '__main__':
80 unittest.main() 571 unittest.main()