From ea2e330e43c182dc16b0111ebc69ee5a71ee4ce1 Mon Sep 17 00:00:00 2001 From: Gavin Mak Date: Sat, 11 Mar 2023 06:46:20 +0000 Subject: Format codebase with black and check formatting in CQ Apply rules set by https://gerrit-review.googlesource.com/c/git-repo/+/362954/ across the codebase and fix any lingering errors caught by flake8. Also check black formatting in run_tests (and CQ). Bug: b/267675342 Change-Id: I972d77649dac351150dcfeb1cd1ad0ea2efc1956 Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/363474 Reviewed-by: Mike Frysinger Tested-by: Gavin Mak Commit-Queue: Gavin Mak --- tests/conftest.py | 4 +- tests/test_editor.py | 44 +- tests/test_error.py | 60 +- tests/test_git_command.py | 222 +++--- tests/test_git_config.py | 318 ++++---- tests/test_git_superproject.py | 782 +++++++++++--------- tests/test_git_trace2_event_log.py | 725 +++++++++--------- tests/test_hooks.py | 63 +- tests/test_manifest_xml.py | 1414 ++++++++++++++++++++---------------- tests/test_platform_utils.py | 54 +- tests/test_project.py | 855 +++++++++++----------- tests/test_repo_trace.py | 68 +- tests/test_ssh.py | 90 +-- tests/test_subcmds.py | 86 +-- tests/test_subcmds_init.py | 51 +- tests/test_subcmds_sync.py | 215 +++--- tests/test_update_manpages.py | 10 +- tests/test_wrapper.py | 1029 ++++++++++++++------------ 18 files changed, 3275 insertions(+), 2815 deletions(-) (limited to 'tests') diff --git a/tests/conftest.py b/tests/conftest.py index 3e43f6d3..e1a2292a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,5 +21,5 @@ import repo_trace @pytest.fixture(autouse=True) def disable_repo_trace(tmp_path): - """Set an environment marker to relax certain strict checks for test code.""" - repo_trace._TRACE_FILE = str(tmp_path / 'TRACE_FILE_from_test') + """Set an environment marker to relax certain strict checks for test code.""" # noqa: E501 + repo_trace._TRACE_FILE = str(tmp_path / "TRACE_FILE_from_test") diff --git a/tests/test_editor.py b/tests/test_editor.py index cfd4f5ed..8f5d160e 100644 --- a/tests/test_editor.py +++ b/tests/test_editor.py @@ -20,37 +20,37 @@ from editor import Editor class EditorTestCase(unittest.TestCase): - """Take care of resetting Editor state across tests.""" + """Take care of resetting Editor state across tests.""" - def setUp(self): - self.setEditor(None) + def setUp(self): + self.setEditor(None) - def tearDown(self): - self.setEditor(None) + def tearDown(self): + self.setEditor(None) - @staticmethod - def setEditor(editor): - Editor._editor = editor + @staticmethod + def setEditor(editor): + Editor._editor = editor class GetEditor(EditorTestCase): - """Check GetEditor behavior.""" + """Check GetEditor behavior.""" - def test_basic(self): - """Basic checking of _GetEditor.""" - self.setEditor(':') - self.assertEqual(':', Editor._GetEditor()) + def test_basic(self): + """Basic checking of _GetEditor.""" + self.setEditor(":") + self.assertEqual(":", Editor._GetEditor()) class EditString(EditorTestCase): - """Check EditString behavior.""" + """Check EditString behavior.""" - def test_no_editor(self): - """Check behavior when no editor is available.""" - self.setEditor(':') - self.assertEqual('foo', Editor.EditString('foo')) + def test_no_editor(self): + """Check behavior when no editor is available.""" + self.setEditor(":") + self.assertEqual("foo", Editor.EditString("foo")) - def test_cat_editor(self): - """Check behavior when editor is `cat`.""" - self.setEditor('cat') - self.assertEqual('foo', Editor.EditString('foo')) + def test_cat_editor(self): + """Check behavior when editor is `cat`.""" + self.setEditor("cat") + self.assertEqual("foo", Editor.EditString("foo")) diff --git a/tests/test_error.py b/tests/test_error.py index 82b00c24..784e2d57 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -22,32 +22,34 @@ import error class PickleTests(unittest.TestCase): - """Make sure all our custom exceptions can be pickled.""" - - def getExceptions(self): - """Return all our custom exceptions.""" - for name in dir(error): - cls = getattr(error, name) - if isinstance(cls, type) and issubclass(cls, Exception): - yield cls - - def testExceptionLookup(self): - """Make sure our introspection logic works.""" - classes = list(self.getExceptions()) - self.assertIn(error.HookError, classes) - # Don't assert the exact number to avoid being a change-detector test. - self.assertGreater(len(classes), 10) - - def testPickle(self): - """Try to pickle all the exceptions.""" - for cls in self.getExceptions(): - args = inspect.getfullargspec(cls.__init__).args[1:] - obj = cls(*args) - p = pickle.dumps(obj) - try: - newobj = pickle.loads(p) - except Exception as e: # pylint: disable=broad-except - self.fail('Class %s is unable to be pickled: %s\n' - 'Incomplete super().__init__(...) call?' % (cls, e)) - self.assertIsInstance(newobj, cls) - self.assertEqual(str(obj), str(newobj)) + """Make sure all our custom exceptions can be pickled.""" + + def getExceptions(self): + """Return all our custom exceptions.""" + for name in dir(error): + cls = getattr(error, name) + if isinstance(cls, type) and issubclass(cls, Exception): + yield cls + + def testExceptionLookup(self): + """Make sure our introspection logic works.""" + classes = list(self.getExceptions()) + self.assertIn(error.HookError, classes) + # Don't assert the exact number to avoid being a change-detector test. + self.assertGreater(len(classes), 10) + + def testPickle(self): + """Try to pickle all the exceptions.""" + for cls in self.getExceptions(): + args = inspect.getfullargspec(cls.__init__).args[1:] + obj = cls(*args) + p = pickle.dumps(obj) + try: + newobj = pickle.loads(p) + except Exception as e: # pylint: disable=broad-except + self.fail( + "Class %s is unable to be pickled: %s\n" + "Incomplete super().__init__(...) call?" % (cls, e) + ) + self.assertIsInstance(newobj, cls) + self.assertEqual(str(obj), str(newobj)) diff --git a/tests/test_git_command.py b/tests/test_git_command.py index 96408a23..c4c3a4c5 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py @@ -19,138 +19,146 @@ import os import unittest try: - from unittest import mock + from unittest import mock except ImportError: - import mock + import mock import git_command import wrapper class GitCommandTest(unittest.TestCase): - """Tests the GitCommand class (via git_command.git).""" + """Tests the GitCommand class (via git_command.git).""" - def setUp(self): + def setUp(self): + def realpath_mock(val): + return val - def realpath_mock(val): - return val + mock.patch.object( + os.path, "realpath", side_effect=realpath_mock + ).start() - mock.patch.object(os.path, 'realpath', side_effect=realpath_mock).start() + def tearDown(self): + mock.patch.stopall() - def tearDown(self): - mock.patch.stopall() + def test_alternative_setting_when_matching(self): + r = git_command._build_env( + objdir=os.path.join("zap", "objects"), gitdir="zap" + ) - def test_alternative_setting_when_matching(self): - r = git_command._build_env( - objdir = os.path.join('zap', 'objects'), - gitdir = 'zap' - ) + self.assertIsNone(r.get("GIT_ALTERNATE_OBJECT_DIRECTORIES")) + self.assertEqual( + r.get("GIT_OBJECT_DIRECTORY"), os.path.join("zap", "objects") + ) - self.assertIsNone(r.get('GIT_ALTERNATE_OBJECT_DIRECTORIES')) - self.assertEqual(r.get('GIT_OBJECT_DIRECTORY'), os.path.join('zap', 'objects')) + def test_alternative_setting_when_different(self): + r = git_command._build_env( + objdir=os.path.join("wow", "objects"), gitdir="zap" + ) - def test_alternative_setting_when_different(self): - r = git_command._build_env( - objdir = os.path.join('wow', 'objects'), - gitdir = 'zap' - ) - - self.assertEqual(r.get('GIT_ALTERNATE_OBJECT_DIRECTORIES'), os.path.join('zap', 'objects')) - self.assertEqual(r.get('GIT_OBJECT_DIRECTORY'), os.path.join('wow', 'objects')) + self.assertEqual( + r.get("GIT_ALTERNATE_OBJECT_DIRECTORIES"), + os.path.join("zap", "objects"), + ) + self.assertEqual( + r.get("GIT_OBJECT_DIRECTORY"), os.path.join("wow", "objects") + ) class GitCallUnitTest(unittest.TestCase): - """Tests the _GitCall class (via git_command.git).""" + """Tests the _GitCall class (via git_command.git).""" - def test_version_tuple(self): - """Check git.version_tuple() handling.""" - ver = git_command.git.version_tuple() - self.assertIsNotNone(ver) + def test_version_tuple(self): + """Check git.version_tuple() handling.""" + ver = git_command.git.version_tuple() + self.assertIsNotNone(ver) - # We don't dive too deep into the values here to avoid having to update - # whenever git versions change. We do check relative to this min version - # as this is what `repo` itself requires via MIN_GIT_VERSION. - MIN_GIT_VERSION = (2, 10, 2) - self.assertTrue(isinstance(ver.major, int)) - self.assertTrue(isinstance(ver.minor, int)) - self.assertTrue(isinstance(ver.micro, int)) + # We don't dive too deep into the values here to avoid having to update + # whenever git versions change. We do check relative to this min + # version as this is what `repo` itself requires via MIN_GIT_VERSION. + MIN_GIT_VERSION = (2, 10, 2) + self.assertTrue(isinstance(ver.major, int)) + self.assertTrue(isinstance(ver.minor, int)) + self.assertTrue(isinstance(ver.micro, int)) - self.assertGreater(ver.major, MIN_GIT_VERSION[0] - 1) - self.assertGreaterEqual(ver.micro, 0) - self.assertGreaterEqual(ver.major, 0) + self.assertGreater(ver.major, MIN_GIT_VERSION[0] - 1) + self.assertGreaterEqual(ver.micro, 0) + self.assertGreaterEqual(ver.major, 0) - self.assertGreaterEqual(ver, MIN_GIT_VERSION) - self.assertLess(ver, (9999, 9999, 9999)) + self.assertGreaterEqual(ver, MIN_GIT_VERSION) + self.assertLess(ver, (9999, 9999, 9999)) - self.assertNotEqual('', ver.full) + self.assertNotEqual("", ver.full) class UserAgentUnitTest(unittest.TestCase): - """Tests the UserAgent function.""" - - def test_smoke_os(self): - """Make sure UA OS setting returns something useful.""" - os_name = git_command.user_agent.os - # We can't dive too deep because of OS/tool differences, but we can check - # the general form. - m = re.match(r'^[^ ]+$', os_name) - self.assertIsNotNone(m) - - def test_smoke_repo(self): - """Make sure repo UA returns something useful.""" - ua = git_command.user_agent.repo - # We can't dive too deep because of OS/tool differences, but we can check - # the general form. - m = re.match(r'^git-repo/[^ ]+ ([^ ]+) git/[^ ]+ Python/[0-9.]+', ua) - self.assertIsNotNone(m) - - def test_smoke_git(self): - """Make sure git UA returns something useful.""" - ua = git_command.user_agent.git - # We can't dive too deep because of OS/tool differences, but we can check - # the general form. - m = re.match(r'^git/[^ ]+ ([^ ]+) git-repo/[^ ]+', ua) - self.assertIsNotNone(m) + """Tests the UserAgent function.""" + + def test_smoke_os(self): + """Make sure UA OS setting returns something useful.""" + os_name = git_command.user_agent.os + # We can't dive too deep because of OS/tool differences, but we can + # check the general form. + m = re.match(r"^[^ ]+$", os_name) + self.assertIsNotNone(m) + + def test_smoke_repo(self): + """Make sure repo UA returns something useful.""" + ua = git_command.user_agent.repo + # We can't dive too deep because of OS/tool differences, but we can + # check the general form. + m = re.match(r"^git-repo/[^ ]+ ([^ ]+) git/[^ ]+ Python/[0-9.]+", ua) + self.assertIsNotNone(m) + + def test_smoke_git(self): + """Make sure git UA returns something useful.""" + ua = git_command.user_agent.git + # We can't dive too deep because of OS/tool differences, but we can + # check the general form. + m = re.match(r"^git/[^ ]+ ([^ ]+) git-repo/[^ ]+", ua) + self.assertIsNotNone(m) class GitRequireTests(unittest.TestCase): - """Test the git_require helper.""" - - def setUp(self): - self.wrapper = wrapper.Wrapper() - ver = self.wrapper.GitVersion(1, 2, 3, 4) - mock.patch.object(git_command.git, 'version_tuple', return_value=ver).start() - - def tearDown(self): - mock.patch.stopall() - - def test_older_nonfatal(self): - """Test non-fatal require calls with old versions.""" - self.assertFalse(git_command.git_require((2,))) - self.assertFalse(git_command.git_require((1, 3))) - self.assertFalse(git_command.git_require((1, 2, 4))) - self.assertFalse(git_command.git_require((1, 2, 3, 5))) - - def test_newer_nonfatal(self): - """Test non-fatal require calls with newer versions.""" - self.assertTrue(git_command.git_require((0,))) - self.assertTrue(git_command.git_require((1, 0))) - self.assertTrue(git_command.git_require((1, 2, 0))) - self.assertTrue(git_command.git_require((1, 2, 3, 0))) - - def test_equal_nonfatal(self): - """Test require calls with equal values.""" - self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=False)) - self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=True)) - - def test_older_fatal(self): - """Test fatal require calls with old versions.""" - with self.assertRaises(SystemExit) as e: - git_command.git_require((2,), fail=True) - self.assertNotEqual(0, e.code) - - def test_older_fatal_msg(self): - """Test fatal require calls with old versions and message.""" - with self.assertRaises(SystemExit) as e: - git_command.git_require((2,), fail=True, msg='so sad') - self.assertNotEqual(0, e.code) + """Test the git_require helper.""" + + def setUp(self): + self.wrapper = wrapper.Wrapper() + ver = self.wrapper.GitVersion(1, 2, 3, 4) + mock.patch.object( + git_command.git, "version_tuple", return_value=ver + ).start() + + def tearDown(self): + mock.patch.stopall() + + def test_older_nonfatal(self): + """Test non-fatal require calls with old versions.""" + self.assertFalse(git_command.git_require((2,))) + self.assertFalse(git_command.git_require((1, 3))) + self.assertFalse(git_command.git_require((1, 2, 4))) + self.assertFalse(git_command.git_require((1, 2, 3, 5))) + + def test_newer_nonfatal(self): + """Test non-fatal require calls with newer versions.""" + self.assertTrue(git_command.git_require((0,))) + self.assertTrue(git_command.git_require((1, 0))) + self.assertTrue(git_command.git_require((1, 2, 0))) + self.assertTrue(git_command.git_require((1, 2, 3, 0))) + + def test_equal_nonfatal(self): + """Test require calls with equal values.""" + self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=False)) + self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=True)) + + def test_older_fatal(self): + """Test fatal require calls with old versions.""" + with self.assertRaises(SystemExit) as e: + git_command.git_require((2,), fail=True) + self.assertNotEqual(0, e.code) + + def test_older_fatal_msg(self): + """Test fatal require calls with old versions and message.""" + with self.assertRaises(SystemExit) as e: + git_command.git_require((2,), fail=True, msg="so sad") + self.assertNotEqual(0, e.code) diff --git a/tests/test_git_config.py b/tests/test_git_config.py index 3b0aa8b4..a44dca0f 100644 --- a/tests/test_git_config.py +++ b/tests/test_git_config.py @@ -22,167 +22,169 @@ import git_config def fixture(*paths): - """Return a path relative to test/fixtures. - """ - return os.path.join(os.path.dirname(__file__), 'fixtures', *paths) + """Return a path relative to test/fixtures.""" + return os.path.join(os.path.dirname(__file__), "fixtures", *paths) class GitConfigReadOnlyTests(unittest.TestCase): - """Read-only tests of the GitConfig class.""" - - def setUp(self): - """Create a GitConfig object using the test.gitconfig fixture. - """ - config_fixture = fixture('test.gitconfig') - self.config = git_config.GitConfig(config_fixture) - - def test_GetString_with_empty_config_values(self): - """ - Test config entries with no value. - - [section] - empty - - """ - val = self.config.GetString('section.empty') - self.assertEqual(val, None) - - def test_GetString_with_true_value(self): - """ - Test config entries with a string value. - - [section] - nonempty = true - - """ - val = self.config.GetString('section.nonempty') - self.assertEqual(val, 'true') - - def test_GetString_from_missing_file(self): - """ - Test missing config file - """ - config_fixture = fixture('not.present.gitconfig') - config = git_config.GitConfig(config_fixture) - val = config.GetString('empty') - self.assertEqual(val, None) - - def test_GetBoolean_undefined(self): - """Test GetBoolean on key that doesn't exist.""" - self.assertIsNone(self.config.GetBoolean('section.missing')) - - def test_GetBoolean_invalid(self): - """Test GetBoolean on invalid boolean value.""" - self.assertIsNone(self.config.GetBoolean('section.boolinvalid')) - - def test_GetBoolean_true(self): - """Test GetBoolean on valid true boolean.""" - self.assertTrue(self.config.GetBoolean('section.booltrue')) - - def test_GetBoolean_false(self): - """Test GetBoolean on valid false boolean.""" - self.assertFalse(self.config.GetBoolean('section.boolfalse')) - - def test_GetInt_undefined(self): - """Test GetInt on key that doesn't exist.""" - self.assertIsNone(self.config.GetInt('section.missing')) - - def test_GetInt_invalid(self): - """Test GetInt on invalid integer value.""" - self.assertIsNone(self.config.GetBoolean('section.intinvalid')) - - def test_GetInt_valid(self): - """Test GetInt on valid integers.""" - TESTS = ( - ('inthex', 16), - ('inthexk', 16384), - ('int', 10), - ('intk', 10240), - ('intm', 10485760), - ('intg', 10737418240), - ) - for key, value in TESTS: - self.assertEqual(value, self.config.GetInt('section.%s' % (key,))) + """Read-only tests of the GitConfig class.""" + + def setUp(self): + """Create a GitConfig object using the test.gitconfig fixture.""" + config_fixture = fixture("test.gitconfig") + self.config = git_config.GitConfig(config_fixture) + + def test_GetString_with_empty_config_values(self): + """ + Test config entries with no value. + + [section] + empty + + """ + val = self.config.GetString("section.empty") + self.assertEqual(val, None) + + def test_GetString_with_true_value(self): + """ + Test config entries with a string value. + + [section] + nonempty = true + + """ + val = self.config.GetString("section.nonempty") + self.assertEqual(val, "true") + + def test_GetString_from_missing_file(self): + """ + Test missing config file + """ + config_fixture = fixture("not.present.gitconfig") + config = git_config.GitConfig(config_fixture) + val = config.GetString("empty") + self.assertEqual(val, None) + + def test_GetBoolean_undefined(self): + """Test GetBoolean on key that doesn't exist.""" + self.assertIsNone(self.config.GetBoolean("section.missing")) + + def test_GetBoolean_invalid(self): + """Test GetBoolean on invalid boolean value.""" + self.assertIsNone(self.config.GetBoolean("section.boolinvalid")) + + def test_GetBoolean_true(self): + """Test GetBoolean on valid true boolean.""" + self.assertTrue(self.config.GetBoolean("section.booltrue")) + + def test_GetBoolean_false(self): + """Test GetBoolean on valid false boolean.""" + self.assertFalse(self.config.GetBoolean("section.boolfalse")) + + def test_GetInt_undefined(self): + """Test GetInt on key that doesn't exist.""" + self.assertIsNone(self.config.GetInt("section.missing")) + + def test_GetInt_invalid(self): + """Test GetInt on invalid integer value.""" + self.assertIsNone(self.config.GetBoolean("section.intinvalid")) + + def test_GetInt_valid(self): + """Test GetInt on valid integers.""" + TESTS = ( + ("inthex", 16), + ("inthexk", 16384), + ("int", 10), + ("intk", 10240), + ("intm", 10485760), + ("intg", 10737418240), + ) + for key, value in TESTS: + self.assertEqual(value, self.config.GetInt("section.%s" % (key,))) class GitConfigReadWriteTests(unittest.TestCase): - """Read/write tests of the GitConfig class.""" - - def setUp(self): - self.tmpfile = tempfile.NamedTemporaryFile() - self.config = self.get_config() - - def get_config(self): - """Get a new GitConfig instance.""" - return git_config.GitConfig(self.tmpfile.name) - - def test_SetString(self): - """Test SetString behavior.""" - # Set a value. - self.assertIsNone(self.config.GetString('foo.bar')) - self.config.SetString('foo.bar', 'val') - self.assertEqual('val', self.config.GetString('foo.bar')) - - # Make sure the value was actually written out. - config = self.get_config() - self.assertEqual('val', config.GetString('foo.bar')) - - # Update the value. - self.config.SetString('foo.bar', 'valll') - self.assertEqual('valll', self.config.GetString('foo.bar')) - config = self.get_config() - self.assertEqual('valll', config.GetString('foo.bar')) - - # Delete the value. - self.config.SetString('foo.bar', None) - self.assertIsNone(self.config.GetString('foo.bar')) - config = self.get_config() - self.assertIsNone(config.GetString('foo.bar')) - - def test_SetBoolean(self): - """Test SetBoolean behavior.""" - # Set a true value. - self.assertIsNone(self.config.GetBoolean('foo.bar')) - for val in (True, 1): - self.config.SetBoolean('foo.bar', val) - self.assertTrue(self.config.GetBoolean('foo.bar')) - - # Make sure the value was actually written out. - config = self.get_config() - self.assertTrue(config.GetBoolean('foo.bar')) - self.assertEqual('true', config.GetString('foo.bar')) - - # Set a false value. - for val in (False, 0): - self.config.SetBoolean('foo.bar', val) - self.assertFalse(self.config.GetBoolean('foo.bar')) - - # Make sure the value was actually written out. - config = self.get_config() - self.assertFalse(config.GetBoolean('foo.bar')) - self.assertEqual('false', config.GetString('foo.bar')) - - # Delete the value. - self.config.SetBoolean('foo.bar', None) - self.assertIsNone(self.config.GetBoolean('foo.bar')) - config = self.get_config() - self.assertIsNone(config.GetBoolean('foo.bar')) - - def test_GetSyncAnalysisStateData(self): - """Test config entries with a sync state analysis data.""" - superproject_logging_data = {} - superproject_logging_data['test'] = False - options = type('options', (object,), {})() - options.verbose = 'true' - options.mp_update = 'false' - TESTS = ( - ('superproject.test', 'false'), - ('options.verbose', 'true'), - ('options.mpupdate', 'false'), - ('main.version', '1'), - ) - self.config.UpdateSyncAnalysisState(options, superproject_logging_data) - sync_data = self.config.GetSyncAnalysisStateData() - for key, value in TESTS: - self.assertEqual(sync_data[f'{git_config.SYNC_STATE_PREFIX}{key}'], value) - self.assertTrue(sync_data[f'{git_config.SYNC_STATE_PREFIX}main.synctime']) + """Read/write tests of the GitConfig class.""" + + def setUp(self): + self.tmpfile = tempfile.NamedTemporaryFile() + self.config = self.get_config() + + def get_config(self): + """Get a new GitConfig instance.""" + return git_config.GitConfig(self.tmpfile.name) + + def test_SetString(self): + """Test SetString behavior.""" + # Set a value. + self.assertIsNone(self.config.GetString("foo.bar")) + self.config.SetString("foo.bar", "val") + self.assertEqual("val", self.config.GetString("foo.bar")) + + # Make sure the value was actually written out. + config = self.get_config() + self.assertEqual("val", config.GetString("foo.bar")) + + # Update the value. + self.config.SetString("foo.bar", "valll") + self.assertEqual("valll", self.config.GetString("foo.bar")) + config = self.get_config() + self.assertEqual("valll", config.GetString("foo.bar")) + + # Delete the value. + self.config.SetString("foo.bar", None) + self.assertIsNone(self.config.GetString("foo.bar")) + config = self.get_config() + self.assertIsNone(config.GetString("foo.bar")) + + def test_SetBoolean(self): + """Test SetBoolean behavior.""" + # Set a true value. + self.assertIsNone(self.config.GetBoolean("foo.bar")) + for val in (True, 1): + self.config.SetBoolean("foo.bar", val) + self.assertTrue(self.config.GetBoolean("foo.bar")) + + # Make sure the value was actually written out. + config = self.get_config() + self.assertTrue(config.GetBoolean("foo.bar")) + self.assertEqual("true", config.GetString("foo.bar")) + + # Set a false value. + for val in (False, 0): + self.config.SetBoolean("foo.bar", val) + self.assertFalse(self.config.GetBoolean("foo.bar")) + + # Make sure the value was actually written out. + config = self.get_config() + self.assertFalse(config.GetBoolean("foo.bar")) + self.assertEqual("false", config.GetString("foo.bar")) + + # Delete the value. + self.config.SetBoolean("foo.bar", None) + self.assertIsNone(self.config.GetBoolean("foo.bar")) + config = self.get_config() + self.assertIsNone(config.GetBoolean("foo.bar")) + + def test_GetSyncAnalysisStateData(self): + """Test config entries with a sync state analysis data.""" + superproject_logging_data = {} + superproject_logging_data["test"] = False + options = type("options", (object,), {})() + options.verbose = "true" + options.mp_update = "false" + TESTS = ( + ("superproject.test", "false"), + ("options.verbose", "true"), + ("options.mpupdate", "false"), + ("main.version", "1"), + ) + self.config.UpdateSyncAnalysisState(options, superproject_logging_data) + sync_data = self.config.GetSyncAnalysisStateData() + for key, value in TESTS: + self.assertEqual( + sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"], value + ) + self.assertTrue( + sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"] + ) diff --git a/tests/test_git_superproject.py b/tests/test_git_superproject.py index b9b597a6..eb542c60 100644 --- a/tests/test_git_superproject.py +++ b/tests/test_git_superproject.py @@ -28,297 +28,369 @@ from test_manifest_xml import sort_attributes class SuperprojectTestCase(unittest.TestCase): - """TestCase for the Superproject module.""" - - PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID' - PARENT_SID_VALUE = 'parent_sid' - SELF_SID_REGEX = r'repo-\d+T\d+Z-.*' - FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX) - - def setUp(self): - """Set up superproject every time.""" - self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') - self.tempdir = self.tempdirobj.name - self.repodir = os.path.join(self.tempdir, '.repo') - self.manifest_file = os.path.join( - self.repodir, manifest_xml.MANIFEST_FILE_NAME) - os.mkdir(self.repodir) - self.platform = platform.system().lower() - - # By default we initialize with the expected case where - # repo launches us (so GIT_TRACE2_PARENT_SID is set). - env = { - self.PARENT_SID_KEY: self.PARENT_SID_VALUE, - } - self.git_event_log = git_trace2_event_log.EventLog(env=env) - - # The manifest parsing really wants a git repo currently. - gitdir = os.path.join(self.repodir, 'manifests.git') - os.mkdir(gitdir) - with open(os.path.join(gitdir, 'config'), 'w') as fp: - fp.write("""[remote "origin"] + """TestCase for the Superproject module.""" + + PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" + PARENT_SID_VALUE = "parent_sid" + SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" + FULL_SID_REGEX = r"^%s/%s" % (PARENT_SID_VALUE, SELF_SID_REGEX) + + def setUp(self): + """Set up superproject every time.""" + self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") + self.tempdir = self.tempdirobj.name + self.repodir = os.path.join(self.tempdir, ".repo") + self.manifest_file = os.path.join( + self.repodir, manifest_xml.MANIFEST_FILE_NAME + ) + os.mkdir(self.repodir) + self.platform = platform.system().lower() + + # By default we initialize with the expected case where + # repo launches us (so GIT_TRACE2_PARENT_SID is set). + env = { + self.PARENT_SID_KEY: self.PARENT_SID_VALUE, + } + self.git_event_log = git_trace2_event_log.EventLog(env=env) + + # The manifest parsing really wants a git repo currently. + gitdir = os.path.join(self.repodir, "manifests.git") + os.mkdir(gitdir) + with open(os.path.join(gitdir, "config"), "w") as fp: + fp.write( + """[remote "origin"] url = https://localhost:0/manifest -""") +""" + ) - manifest = self.getXmlManifest(""" + manifest = self.getXmlManifest( + """ - -""") - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - - def tearDown(self): - """Tear down superproject every time.""" - self.tempdirobj.cleanup() - - def getXmlManifest(self, data): - """Helper to initialize a manifest for testing.""" - with open(self.manifest_file, 'w') as fp: - fp.write(data) - return manifest_xml.XmlManifest(self.repodir, self.manifest_file) - - def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True): - """Helper function to verify common event log keys.""" - self.assertIn('event', log_entry) - self.assertIn('sid', log_entry) - self.assertIn('thread', log_entry) - self.assertIn('time', log_entry) - - # Do basic data format validation. - self.assertEqual(expected_event_name, log_entry['event']) - if full_sid: - self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX) - else: - self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX) - self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$') - - def readLog(self, log_path): - """Helper function to read log data into a list.""" - log_data = [] - with open(log_path, mode='rb') as f: - for line in f: - log_data.append(json.loads(line)) - return log_data - - def verifyErrorEvent(self): - """Helper to verify that error event is written.""" - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self.git_event_log.Write(path=tempdir) - self.log_data = self.readLog(log_path) - - self.assertEqual(len(self.log_data), 2) - error_event = self.log_data[1] - self.verifyCommonKeys(self.log_data[0], expected_event_name='version') - self.verifyCommonKeys(error_event, expected_event_name='error') - # Check for 'error' event specific fields. - self.assertIn('msg', error_event) - self.assertIn('fmt', error_event) - - def test_superproject_get_superproject_no_superproject(self): - """Test with no url.""" - manifest = self.getXmlManifest(""" +""" + ) + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + + def tearDown(self): + """Tear down superproject every time.""" + self.tempdirobj.cleanup() + + def getXmlManifest(self, data): + """Helper to initialize a manifest for testing.""" + with open(self.manifest_file, "w") as fp: + fp.write(data) + return manifest_xml.XmlManifest(self.repodir, self.manifest_file) + + def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True): + """Helper function to verify common event log keys.""" + self.assertIn("event", log_entry) + self.assertIn("sid", log_entry) + self.assertIn("thread", log_entry) + self.assertIn("time", log_entry) + + # Do basic data format validation. + self.assertEqual(expected_event_name, log_entry["event"]) + if full_sid: + self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX) + else: + self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX) + self.assertRegex(log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$") + + def readLog(self, log_path): + """Helper function to read log data into a list.""" + log_data = [] + with open(log_path, mode="rb") as f: + for line in f: + log_data.append(json.loads(line)) + return log_data + + def verifyErrorEvent(self): + """Helper to verify that error event is written.""" + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self.git_event_log.Write(path=tempdir) + self.log_data = self.readLog(log_path) + + self.assertEqual(len(self.log_data), 2) + error_event = self.log_data[1] + self.verifyCommonKeys(self.log_data[0], expected_event_name="version") + self.verifyCommonKeys(error_event, expected_event_name="error") + # Check for 'error' event specific fields. + self.assertIn("msg", error_event) + self.assertIn("fmt", error_event) + + def test_superproject_get_superproject_no_superproject(self): + """Test with no url.""" + manifest = self.getXmlManifest( + """ -""") - self.assertIsNone(manifest.superproject) - - def test_superproject_get_superproject_invalid_url(self): - """Test with an invalid url.""" - manifest = self.getXmlManifest(""" +""" + ) + self.assertIsNone(manifest.superproject) + + def test_superproject_get_superproject_invalid_url(self): + """Test with an invalid url.""" + manifest = self.getXmlManifest( + """ -""") - superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('test-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - sync_result = superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - - def test_superproject_get_superproject_invalid_branch(self): - """Test with an invalid branch.""" - manifest = self.getXmlManifest(""" +""" + ) + superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("test-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + sync_result = superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) + + def test_superproject_get_superproject_invalid_branch(self): + """Test with an invalid branch.""" + manifest = self.getXmlManifest( + """ -""") - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('test-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - with mock.patch.object(self._superproject, '_branch', 'junk'): - sync_result = self._superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - self.verifyErrorEvent() - - def test_superproject_get_superproject_mock_init(self): - """Test with _Init failing.""" - with mock.patch.object(self._superproject, '_Init', return_value=False): - sync_result = self._superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - - def test_superproject_get_superproject_mock_fetch(self): - """Test with _Fetch failing.""" - with mock.patch.object(self._superproject, '_Init', return_value=True): - os.mkdir(self._superproject._superproject_path) - with mock.patch.object(self._superproject, '_Fetch', return_value=False): - sync_result = self._superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - - def test_superproject_get_all_project_commit_ids_mock_ls_tree(self): - """Test with LsTree being a mock.""" - data = ('120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00' - '160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' - '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00' - '120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00' - '160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, '_LsTree', return_value=data): - commit_ids_result = self._superproject._GetAllProjectsCommitIds() - self.assertEqual(commit_ids_result.commit_ids, { - 'art': '2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea', - 'bootable/recovery': 'e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06', - 'build/bazel': 'ade9b7a0d874e25fff4bf2552488825c6f111928' - }) - self.assertFalse(commit_ids_result.fatal) - - def test_superproject_write_manifest_file(self): - """Test with writing manifest to a file after setting revisionId.""" - self.assertEqual(len(self._superproject._manifest.projects), 1) - project = self._superproject._manifest.projects[0] - project.SetRevisionId('ABCDEF') - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - manifest_path = self._superproject._WriteManifestFile() - self.assertIsNotNone(manifest_path) - with open(manifest_path, 'r') as fp: - manifest_xml_data = fp.read() - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '') - - def test_superproject_update_project_revision_id(self): - """Test with LsTree being a mock.""" - self.assertEqual(len(self._superproject._manifest.projects), 1) - projects = self._superproject._manifest.projects - data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' - '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, - '_LsTree', - return_value=data): - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) - self.assertIsNotNone(update_result.manifest_path) - self.assertFalse(update_result.fatal) - with open(update_result.manifest_path, 'r') as fp: +""" + ) + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("test-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + with mock.patch.object(self._superproject, "_branch", "junk"): + sync_result = self._superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) + self.verifyErrorEvent() + + def test_superproject_get_superproject_mock_init(self): + """Test with _Init failing.""" + with mock.patch.object(self._superproject, "_Init", return_value=False): + sync_result = self._superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) + + def test_superproject_get_superproject_mock_fetch(self): + """Test with _Fetch failing.""" + with mock.patch.object(self._superproject, "_Init", return_value=True): + os.mkdir(self._superproject._superproject_path) + with mock.patch.object( + self._superproject, "_Fetch", return_value=False + ): + sync_result = self._superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) + + def test_superproject_get_all_project_commit_ids_mock_ls_tree(self): + """Test with LsTree being a mock.""" + data = ( + "120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00" + "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00" + "120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00" + "160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00" + ) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + commit_ids_result = ( + self._superproject._GetAllProjectsCommitIds() + ) + self.assertEqual( + commit_ids_result.commit_ids, + { + "art": "2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea", + "bootable/recovery": "e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06", + "build/bazel": "ade9b7a0d874e25fff4bf2552488825c6f111928", + }, + ) + self.assertFalse(commit_ids_result.fatal) + + def test_superproject_write_manifest_file(self): + """Test with writing manifest to a file after setting revisionId.""" + self.assertEqual(len(self._superproject._manifest.projects), 1) + project = self._superproject._manifest.projects[0] + project.SetRevisionId("ABCDEF") + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + manifest_path = self._superproject._WriteManifestFile() + self.assertIsNotNone(manifest_path) + with open(manifest_path, "r") as fp: manifest_xml_data = fp.read() - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '') - - def test_superproject_update_project_revision_id_no_superproject_tag(self): - """Test update of commit ids of a manifest without superproject tag.""" - manifest = self.getXmlManifest(""" + self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + "", + ) + + def test_superproject_update_project_revision_id(self): + """Test with LsTree being a mock.""" + self.assertEqual(len(self._superproject._manifest.projects), 1) + projects = self._superproject._manifest.projects + data = ( + "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00" + ) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + update_result = self._superproject.UpdateProjectsRevisionId( + projects, self.git_event_log + ) + self.assertIsNotNone(update_result.manifest_path) + self.assertFalse(update_result.fatal) + with open(update_result.manifest_path, "r") as fp: + manifest_xml_data = fp.read() + self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + "", + ) + + def test_superproject_update_project_revision_id_no_superproject_tag(self): + """Test update of commit ids of a manifest without superproject tag.""" + manifest = self.getXmlManifest( + """ -""") - self.maxDiff = None - self.assertIsNone(manifest.superproject) - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_superproject_update_project_revision_id_from_local_manifest_group(self): - """Test update of commit ids of a manifest that have local manifest no superproject group.""" - local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ':local' - manifest = self.getXmlManifest(""" +""" + ) + self.maxDiff = None + self.assertIsNone(manifest.superproject) + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) + + def test_superproject_update_project_revision_id_from_local_manifest_group( + self, + ): + """Test update of commit ids of a manifest that have local manifest no superproject group.""" + local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ":local" + manifest = self.getXmlManifest( + """ - -""") - self.maxDiff = None - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - self.assertEqual(len(self._superproject._manifest.projects), 2) - projects = self._superproject._manifest.projects - data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, - '_LsTree', - return_value=data): - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) - self.assertIsNotNone(update_result.manifest_path) - self.assertFalse(update_result.fatal) - with open(update_result.manifest_path, 'r') as fp: - manifest_xml_data = fp.read() - # Verify platform/vendor/x's project revision hasn't changed. - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '' - '') - - def test_superproject_update_project_revision_id_with_pinned_manifest(self): - """Test update of commit ids of a pinned manifest.""" - manifest = self.getXmlManifest(""" +""" + ) + self.maxDiff = None + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + self.assertEqual(len(self._superproject._manifest.projects), 2) + projects = self._superproject._manifest.projects + data = "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + update_result = self._superproject.UpdateProjectsRevisionId( + projects, self.git_event_log + ) + self.assertIsNotNone(update_result.manifest_path) + self.assertFalse(update_result.fatal) + with open(update_result.manifest_path, "r") as fp: + manifest_xml_data = fp.read() + # Verify platform/vendor/x's project revision hasn't + # changed. + self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + '' + "", + ) + + def test_superproject_update_project_revision_id_with_pinned_manifest(self): + """Test update of commit ids of a pinned manifest.""" + manifest = self.getXmlManifest( + """ @@ -326,80 +398,132 @@ class SuperprojectTestCase(unittest.TestCase): - -""") - self.maxDiff = None - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - self.assertEqual(len(self._superproject._manifest.projects), 3) - projects = self._superproject._manifest.projects - data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' - '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, - '_LsTree', - return_value=data): - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) - self.assertIsNotNone(update_result.manifest_path) - self.assertFalse(update_result.fatal) - with open(update_result.manifest_path, 'r') as fp: - manifest_xml_data = fp.read() - # Verify platform/vendor/x's project revision hasn't changed. - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '' - '' - '') - - def test_Fetch(self): - manifest = self.getXmlManifest(""" +""" + ) + self.maxDiff = None + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + self.assertEqual(len(self._superproject._manifest.projects), 3) + projects = self._superproject._manifest.projects + data = ( + "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00" + ) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + update_result = self._superproject.UpdateProjectsRevisionId( + projects, self.git_event_log + ) + self.assertIsNotNone(update_result.manifest_path) + self.assertFalse(update_result.fatal) + with open(update_result.manifest_path, "r") as fp: + manifest_xml_data = fp.read() + # Verify platform/vendor/x's project revision hasn't + # changed. + self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + '' + '' + "", + ) + + def test_Fetch(self): + manifest = self.getXmlManifest( + """ " /> -""") - self.maxDiff = None - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - os.mkdir(self._superproject._superproject_path) - os.mkdir(self._superproject._work_git) - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch('git_superproject.GitCommand', autospec=True) as mock_git_command: - with mock.patch('git_superproject.GitRefs.get', autospec=True) as mock_git_refs: - instance = mock_git_command.return_value - instance.Wait.return_value = 0 - mock_git_refs.side_effect = ['', '1234'] - - self.assertTrue(self._superproject._Fetch()) - self.assertEqual(mock_git_command.call_args.args,(None, [ - 'fetch', 'http://localhost/superproject', '--depth', '1', - '--force', '--no-tags', '--filter', 'blob:none', - 'refs/heads/main:refs/heads/main' - ])) - - # If branch for revision exists, set as --negotiation-tip. - self.assertTrue(self._superproject._Fetch()) - self.assertEqual(mock_git_command.call_args.args,(None, [ - 'fetch', 'http://localhost/superproject', '--depth', '1', - '--force', '--no-tags', '--filter', 'blob:none', - '--negotiation-tip', '1234', - 'refs/heads/main:refs/heads/main' - ])) +""" + ) + self.maxDiff = None + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + os.mkdir(self._superproject._superproject_path) + os.mkdir(self._superproject._work_git) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch( + "git_superproject.GitCommand", autospec=True + ) as mock_git_command: + with mock.patch( + "git_superproject.GitRefs.get", autospec=True + ) as mock_git_refs: + instance = mock_git_command.return_value + instance.Wait.return_value = 0 + mock_git_refs.side_effect = ["", "1234"] + + self.assertTrue(self._superproject._Fetch()) + self.assertEqual( + mock_git_command.call_args.args, + ( + None, + [ + "fetch", + "http://localhost/superproject", + "--depth", + "1", + "--force", + "--no-tags", + "--filter", + "blob:none", + "refs/heads/main:refs/heads/main", + ], + ), + ) + + # If branch for revision exists, set as --negotiation-tip. + self.assertTrue(self._superproject._Fetch()) + self.assertEqual( + mock_git_command.call_args.args, + ( + None, + [ + "fetch", + "http://localhost/superproject", + "--depth", + "1", + "--force", + "--no-tags", + "--filter", + "blob:none", + "--negotiation-tip", + "1234", + "refs/heads/main:refs/heads/main", + ], + ), + ) diff --git a/tests/test_git_trace2_event_log.py b/tests/test_git_trace2_event_log.py index 7e7dfb7a..a6078d38 100644 --- a/tests/test_git_trace2_event_log.py +++ b/tests/test_git_trace2_event_log.py @@ -27,361 +27,382 @@ import platform_utils def serverLoggingThread(socket_path, server_ready, received_traces): - """Helper function to receive logs over a Unix domain socket. - - Appends received messages on the provided socket and appends to received_traces. - - Args: - socket_path: path to a Unix domain socket on which to listen for traces - server_ready: a threading.Condition used to signal to the caller that this thread is ready to - accept connections - received_traces: a list to which received traces will be appended (after decoding to a utf-8 - string). - """ - platform_utils.remove(socket_path, missing_ok=True) - data = b'' - with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: - sock.bind(socket_path) - sock.listen(0) - with server_ready: - server_ready.notify() - with sock.accept()[0] as conn: - while True: - recved = conn.recv(4096) - if not recved: - break - data += recved - received_traces.extend(data.decode('utf-8').splitlines()) + """Helper function to receive logs over a Unix domain socket. + Appends received messages on the provided socket and appends to + received_traces. -class EventLogTestCase(unittest.TestCase): - """TestCase for the EventLog module.""" - - PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID' - PARENT_SID_VALUE = 'parent_sid' - SELF_SID_REGEX = r'repo-\d+T\d+Z-.*' - FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX) - - def setUp(self): - """Load the event_log module every time.""" - self._event_log_module = None - # By default we initialize with the expected case where - # repo launches us (so GIT_TRACE2_PARENT_SID is set). - env = { - self.PARENT_SID_KEY: self.PARENT_SID_VALUE, - } - self._event_log_module = git_trace2_event_log.EventLog(env=env) - self._log_data = None - - def verifyCommonKeys(self, log_entry, expected_event_name=None, full_sid=True): - """Helper function to verify common event log keys.""" - self.assertIn('event', log_entry) - self.assertIn('sid', log_entry) - self.assertIn('thread', log_entry) - self.assertIn('time', log_entry) - - # Do basic data format validation. - if expected_event_name: - self.assertEqual(expected_event_name, log_entry['event']) - if full_sid: - self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX) - else: - self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX) - self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$') - - def readLog(self, log_path): - """Helper function to read log data into a list.""" - log_data = [] - with open(log_path, mode='rb') as f: - for line in f: - log_data.append(json.loads(line)) - return log_data - - def remove_prefix(self, s, prefix): - """Return a copy string after removing |prefix| from |s|, if present or the original string.""" - if s.startswith(prefix): - return s[len(prefix):] - else: - return s - - def test_initial_state_with_parent_sid(self): - """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent.""" - self.assertRegex(self._event_log_module.full_sid, self.FULL_SID_REGEX) - - def test_initial_state_no_parent_sid(self): - """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set.""" - # Setup an empty environment dict (no parent sid). - self._event_log_module = git_trace2_event_log.EventLog(env={}) - self.assertRegex(self._event_log_module.full_sid, self.SELF_SID_REGEX) - - def test_version_event(self): - """Test 'version' event data is valid. - - Verify that the 'version' event is written even when no other - events are addded. - - Expected event log: - - """ - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - # A log with no added events should only have the version entry. - self.assertEqual(len(self._log_data), 1) - version_event = self._log_data[0] - self.verifyCommonKeys(version_event, expected_event_name='version') - # Check for 'version' event specific fields. - self.assertIn('evt', version_event) - self.assertIn('exe', version_event) - # Verify "evt" version field is a string. - self.assertIsInstance(version_event['evt'], str) - - def test_start_event(self): - """Test and validate 'start' event data is valid. - - Expected event log: - - - """ - self._event_log_module.StartEvent() - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - start_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(start_event, expected_event_name='start') - # Check for 'start' event specific fields. - self.assertIn('argv', start_event) - self.assertTrue(isinstance(start_event['argv'], list)) - - def test_exit_event_result_none(self): - """Test 'exit' event data is valid when result is None. - - We expect None result to be converted to 0 in the exit event data. - - Expected event log: - - - """ - self._event_log_module.ExitEvent(None) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - exit_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(exit_event, expected_event_name='exit') - # Check for 'exit' event specific fields. - self.assertIn('code', exit_event) - # 'None' result should convert to 0 (successful) return code. - self.assertEqual(exit_event['code'], 0) - - def test_exit_event_result_integer(self): - """Test 'exit' event data is valid when result is an integer. - - Expected event log: - - - """ - self._event_log_module.ExitEvent(2) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - exit_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(exit_event, expected_event_name='exit') - # Check for 'exit' event specific fields. - self.assertIn('code', exit_event) - self.assertEqual(exit_event['code'], 2) - - def test_command_event(self): - """Test and validate 'command' event data is valid. - - Expected event log: - - - """ - name = 'repo' - subcommands = ['init' 'this'] - self._event_log_module.CommandEvent(name='repo', subcommands=subcommands) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - command_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(command_event, expected_event_name='command') - # Check for 'command' event specific fields. - self.assertIn('name', command_event) - self.assertIn('subcommands', command_event) - self.assertEqual(command_event['name'], name) - self.assertEqual(command_event['subcommands'], subcommands) - - def test_def_params_event_repo_config(self): - """Test 'def_params' event data outputs only repo config keys. - - Expected event log: - - - + Args: + socket_path: path to a Unix domain socket on which to listen for traces + server_ready: a threading.Condition used to signal to the caller that + this thread is ready to accept connections + received_traces: a list to which received traces will be appended (after + decoding to a utf-8 string). """ - config = { - 'git.foo': 'bar', - 'repo.partialclone': 'true', - 'repo.partialclonefilter': 'blob:none', - } - self._event_log_module.DefParamRepoEvents(config) - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 3) - def_param_events = self._log_data[1:] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - - for event in def_param_events: - self.verifyCommonKeys(event, expected_event_name='def_param') - # Check for 'def_param' event specific fields. - self.assertIn('param', event) - self.assertIn('value', event) - self.assertTrue(event['param'].startswith('repo.')) - - def test_def_params_event_no_repo_config(self): - """Test 'def_params' event data won't output non-repo config keys. - - Expected event log: - - """ - config = { - 'git.foo': 'bar', - 'git.core.foo2': 'baz', - } - self._event_log_module.DefParamRepoEvents(config) - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 1) - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - - def test_data_event_config(self): - """Test 'data' event data outputs all config keys. - - Expected event log: - - - - """ - config = { - 'git.foo': 'bar', - 'repo.partialclone': 'false', - 'repo.syncstate.superproject.hassuperprojecttag': 'true', - 'repo.syncstate.superproject.sys.argv': ['--', 'sync', 'protobuf'], - } - prefix_value = 'prefix' - self._event_log_module.LogDataConfigEvents(config, prefix_value) - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 5) - data_events = self._log_data[1:] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - - for event in data_events: - self.verifyCommonKeys(event) - # Check for 'data' event specific fields. - self.assertIn('key', event) - self.assertIn('value', event) - key = event['key'] - key = self.remove_prefix(key, f'{prefix_value}/') - value = event['value'] - self.assertEqual(self._event_log_module.GetDataEventName(value), event['event']) - self.assertTrue(key in config and value == config[key]) - - def test_error_event(self): - """Test and validate 'error' event data is valid. - - Expected event log: - - - """ - msg = 'invalid option: --cahced' - fmt = 'invalid option: %s' - self._event_log_module.ErrorEvent(msg, fmt) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - error_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(error_event, expected_event_name='error') - # Check for 'error' event specific fields. - self.assertIn('msg', error_event) - self.assertIn('fmt', error_event) - self.assertEqual(error_event['msg'], msg) - self.assertEqual(error_event['fmt'], fmt) - - def test_write_with_filename(self): - """Test Write() with a path to a file exits with None.""" - self.assertIsNone(self._event_log_module.Write(path='path/to/file')) - - def test_write_with_git_config(self): - """Test Write() uses the git config path when 'git config' call succeeds.""" - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - with mock.patch.object(self._event_log_module, - '_GetEventTargetPath', return_value=tempdir): - self.assertEqual(os.path.dirname(self._event_log_module.Write()), tempdir) - - def test_write_no_git_config(self): - """Test Write() with no git config variable present exits with None.""" - with mock.patch.object(self._event_log_module, - '_GetEventTargetPath', return_value=None): - self.assertIsNone(self._event_log_module.Write()) - - def test_write_non_string(self): - """Test Write() with non-string type for |path| throws TypeError.""" - with self.assertRaises(TypeError): - self._event_log_module.Write(path=1234) - - def test_write_socket(self): - """Test Write() with Unix domain socket for |path| and validate received traces.""" - received_traces = [] - with tempfile.TemporaryDirectory(prefix='test_server_sockets') as tempdir: - socket_path = os.path.join(tempdir, "server.sock") - server_ready = threading.Condition() - # Start "server" listening on Unix domain socket at socket_path. - try: - server_thread = threading.Thread( - target=serverLoggingThread, - args=(socket_path, server_ready, received_traces)) - server_thread.start() - + platform_utils.remove(socket_path, missing_ok=True) + data = b"" + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.bind(socket_path) + sock.listen(0) with server_ready: - server_ready.wait(timeout=120) + server_ready.notify() + with sock.accept()[0] as conn: + while True: + recved = conn.recv(4096) + if not recved: + break + data += recved + received_traces.extend(data.decode("utf-8").splitlines()) + +class EventLogTestCase(unittest.TestCase): + """TestCase for the EventLog module.""" + + PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" + PARENT_SID_VALUE = "parent_sid" + SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" + FULL_SID_REGEX = r"^%s/%s" % (PARENT_SID_VALUE, SELF_SID_REGEX) + + def setUp(self): + """Load the event_log module every time.""" + self._event_log_module = None + # By default we initialize with the expected case where + # repo launches us (so GIT_TRACE2_PARENT_SID is set). + env = { + self.PARENT_SID_KEY: self.PARENT_SID_VALUE, + } + self._event_log_module = git_trace2_event_log.EventLog(env=env) + self._log_data = None + + def verifyCommonKeys( + self, log_entry, expected_event_name=None, full_sid=True + ): + """Helper function to verify common event log keys.""" + self.assertIn("event", log_entry) + self.assertIn("sid", log_entry) + self.assertIn("thread", log_entry) + self.assertIn("time", log_entry) + + # Do basic data format validation. + if expected_event_name: + self.assertEqual(expected_event_name, log_entry["event"]) + if full_sid: + self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX) + else: + self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX) + self.assertRegex(log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$") + + def readLog(self, log_path): + """Helper function to read log data into a list.""" + log_data = [] + with open(log_path, mode="rb") as f: + for line in f: + log_data.append(json.loads(line)) + return log_data + + def remove_prefix(self, s, prefix): + """Return a copy string after removing |prefix| from |s|, if present or + the original string.""" + if s.startswith(prefix): + return s[len(prefix) :] + else: + return s + + def test_initial_state_with_parent_sid(self): + """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent.""" + self.assertRegex(self._event_log_module.full_sid, self.FULL_SID_REGEX) + + def test_initial_state_no_parent_sid(self): + """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set.""" + # Setup an empty environment dict (no parent sid). + self._event_log_module = git_trace2_event_log.EventLog(env={}) + self.assertRegex(self._event_log_module.full_sid, self.SELF_SID_REGEX) + + def test_version_event(self): + """Test 'version' event data is valid. + + Verify that the 'version' event is written even when no other + events are addded. + + Expected event log: + + """ + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + # A log with no added events should only have the version entry. + self.assertEqual(len(self._log_data), 1) + version_event = self._log_data[0] + self.verifyCommonKeys(version_event, expected_event_name="version") + # Check for 'version' event specific fields. + self.assertIn("evt", version_event) + self.assertIn("exe", version_event) + # Verify "evt" version field is a string. + self.assertIsInstance(version_event["evt"], str) + + def test_start_event(self): + """Test and validate 'start' event data is valid. + + Expected event log: + + + """ self._event_log_module.StartEvent() - path = self._event_log_module.Write(path=f'af_unix:{socket_path}') - finally: - server_thread.join(timeout=5) - - self.assertEqual(path, f'af_unix:stream:{socket_path}') - self.assertEqual(len(received_traces), 2) - version_event = json.loads(received_traces[0]) - start_event = json.loads(received_traces[1]) - self.verifyCommonKeys(version_event, expected_event_name='version') - self.verifyCommonKeys(start_event, expected_event_name='start') - # Check for 'start' event specific fields. - self.assertIn('argv', start_event) - self.assertIsInstance(start_event['argv'], list) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + start_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(start_event, expected_event_name="start") + # Check for 'start' event specific fields. + self.assertIn("argv", start_event) + self.assertTrue(isinstance(start_event["argv"], list)) + + def test_exit_event_result_none(self): + """Test 'exit' event data is valid when result is None. + + We expect None result to be converted to 0 in the exit event data. + + Expected event log: + + + """ + self._event_log_module.ExitEvent(None) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + exit_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(exit_event, expected_event_name="exit") + # Check for 'exit' event specific fields. + self.assertIn("code", exit_event) + # 'None' result should convert to 0 (successful) return code. + self.assertEqual(exit_event["code"], 0) + + def test_exit_event_result_integer(self): + """Test 'exit' event data is valid when result is an integer. + + Expected event log: + + + """ + self._event_log_module.ExitEvent(2) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + exit_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(exit_event, expected_event_name="exit") + # Check for 'exit' event specific fields. + self.assertIn("code", exit_event) + self.assertEqual(exit_event["code"], 2) + + def test_command_event(self): + """Test and validate 'command' event data is valid. + + Expected event log: + + + """ + name = "repo" + subcommands = ["init" "this"] + self._event_log_module.CommandEvent( + name="repo", subcommands=subcommands + ) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + command_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(command_event, expected_event_name="command") + # Check for 'command' event specific fields. + self.assertIn("name", command_event) + self.assertIn("subcommands", command_event) + self.assertEqual(command_event["name"], name) + self.assertEqual(command_event["subcommands"], subcommands) + + def test_def_params_event_repo_config(self): + """Test 'def_params' event data outputs only repo config keys. + + Expected event log: + + + + """ + config = { + "git.foo": "bar", + "repo.partialclone": "true", + "repo.partialclonefilter": "blob:none", + } + self._event_log_module.DefParamRepoEvents(config) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 3) + def_param_events = self._log_data[1:] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + + for event in def_param_events: + self.verifyCommonKeys(event, expected_event_name="def_param") + # Check for 'def_param' event specific fields. + self.assertIn("param", event) + self.assertIn("value", event) + self.assertTrue(event["param"].startswith("repo.")) + + def test_def_params_event_no_repo_config(self): + """Test 'def_params' event data won't output non-repo config keys. + + Expected event log: + + """ + config = { + "git.foo": "bar", + "git.core.foo2": "baz", + } + self._event_log_module.DefParamRepoEvents(config) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 1) + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + + def test_data_event_config(self): + """Test 'data' event data outputs all config keys. + + Expected event log: + + + + """ + config = { + "git.foo": "bar", + "repo.partialclone": "false", + "repo.syncstate.superproject.hassuperprojecttag": "true", + "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"], + } + prefix_value = "prefix" + self._event_log_module.LogDataConfigEvents(config, prefix_value) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 5) + data_events = self._log_data[1:] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + + for event in data_events: + self.verifyCommonKeys(event) + # Check for 'data' event specific fields. + self.assertIn("key", event) + self.assertIn("value", event) + key = event["key"] + key = self.remove_prefix(key, f"{prefix_value}/") + value = event["value"] + self.assertEqual( + self._event_log_module.GetDataEventName(value), event["event"] + ) + self.assertTrue(key in config and value == config[key]) + + def test_error_event(self): + """Test and validate 'error' event data is valid. + + Expected event log: + + + """ + msg = "invalid option: --cahced" + fmt = "invalid option: %s" + self._event_log_module.ErrorEvent(msg, fmt) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + error_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(error_event, expected_event_name="error") + # Check for 'error' event specific fields. + self.assertIn("msg", error_event) + self.assertIn("fmt", error_event) + self.assertEqual(error_event["msg"], msg) + self.assertEqual(error_event["fmt"], fmt) + + def test_write_with_filename(self): + """Test Write() with a path to a file exits with None.""" + self.assertIsNone(self._event_log_module.Write(path="path/to/file")) + + def test_write_with_git_config(self): + """Test Write() uses the git config path when 'git config' call + succeeds.""" + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + with mock.patch.object( + self._event_log_module, + "_GetEventTargetPath", + return_value=tempdir, + ): + self.assertEqual( + os.path.dirname(self._event_log_module.Write()), tempdir + ) + + def test_write_no_git_config(self): + """Test Write() with no git config variable present exits with None.""" + with mock.patch.object( + self._event_log_module, "_GetEventTargetPath", return_value=None + ): + self.assertIsNone(self._event_log_module.Write()) + + def test_write_non_string(self): + """Test Write() with non-string type for |path| throws TypeError.""" + with self.assertRaises(TypeError): + self._event_log_module.Write(path=1234) + + def test_write_socket(self): + """Test Write() with Unix domain socket for |path| and validate received + traces.""" + received_traces = [] + with tempfile.TemporaryDirectory( + prefix="test_server_sockets" + ) as tempdir: + socket_path = os.path.join(tempdir, "server.sock") + server_ready = threading.Condition() + # Start "server" listening on Unix domain socket at socket_path. + try: + server_thread = threading.Thread( + target=serverLoggingThread, + args=(socket_path, server_ready, received_traces), + ) + server_thread.start() + + with server_ready: + server_ready.wait(timeout=120) + + self._event_log_module.StartEvent() + path = self._event_log_module.Write( + path=f"af_unix:{socket_path}" + ) + finally: + server_thread.join(timeout=5) + + self.assertEqual(path, f"af_unix:stream:{socket_path}") + self.assertEqual(len(received_traces), 2) + version_event = json.loads(received_traces[0]) + start_event = json.loads(received_traces[1]) + self.verifyCommonKeys(version_event, expected_event_name="version") + self.verifyCommonKeys(start_event, expected_event_name="start") + # Check for 'start' event specific fields. + self.assertIn("argv", start_event) + self.assertIsInstance(start_event["argv"], list) diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 6632b3e5..78277128 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -17,39 +17,38 @@ import hooks import unittest + class RepoHookShebang(unittest.TestCase): - """Check shebang parsing in RepoHook.""" + """Check shebang parsing in RepoHook.""" - def test_no_shebang(self): - """Lines w/out shebangs should be rejected.""" - DATA = ( - '', - '#\n# foo\n', - '# Bad shebang in script\n#!/foo\n' - ) - for data in DATA: - self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data)) + def test_no_shebang(self): + """Lines w/out shebangs should be rejected.""" + DATA = ("", "#\n# foo\n", "# Bad shebang in script\n#!/foo\n") + for data in DATA: + self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data)) - def test_direct_interp(self): - """Lines whose shebang points directly to the interpreter.""" - DATA = ( - ('#!/foo', '/foo'), - ('#! /foo', '/foo'), - ('#!/bin/foo ', '/bin/foo'), - ('#! /usr/foo ', '/usr/foo'), - ('#! /usr/foo -args', '/usr/foo'), - ) - for shebang, interp in DATA: - self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang), - interp) + def test_direct_interp(self): + """Lines whose shebang points directly to the interpreter.""" + DATA = ( + ("#!/foo", "/foo"), + ("#! /foo", "/foo"), + ("#!/bin/foo ", "/bin/foo"), + ("#! /usr/foo ", "/usr/foo"), + ("#! /usr/foo -args", "/usr/foo"), + ) + for shebang, interp in DATA: + self.assertEqual( + hooks.RepoHook._ExtractInterpFromShebang(shebang), interp + ) - def test_env_interp(self): - """Lines whose shebang launches through `env`.""" - DATA = ( - ('#!/usr/bin/env foo', 'foo'), - ('#!/bin/env foo', 'foo'), - ('#! /bin/env /bin/foo ', '/bin/foo'), - ) - for shebang, interp in DATA: - self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang), - interp) + def test_env_interp(self): + """Lines whose shebang launches through `env`.""" + DATA = ( + ("#!/usr/bin/env foo", "foo"), + ("#!/bin/env foo", "foo"), + ("#! /bin/env /bin/foo ", "/bin/foo"), + ) + for shebang, interp in DATA: + self.assertEqual( + hooks.RepoHook._ExtractInterpFromShebang(shebang), interp + ) diff --git a/tests/test_manifest_xml.py b/tests/test_manifest_xml.py index 3634701f..648acde8 100644 --- a/tests/test_manifest_xml.py +++ b/tests/test_manifest_xml.py @@ -27,291 +27,318 @@ import manifest_xml # Invalid paths that we don't want in the filesystem. INVALID_FS_PATHS = ( - '', - '.', - '..', - '../', - './', - './/', - 'foo/', - './foo', - '../foo', - 'foo/./bar', - 'foo/../../bar', - '/foo', - './../foo', - '.git/foo', + "", + ".", + "..", + "../", + "./", + ".//", + "foo/", + "./foo", + "../foo", + "foo/./bar", + "foo/../../bar", + "/foo", + "./../foo", + ".git/foo", # Check case folding. - '.GIT/foo', - 'blah/.git/foo', - '.repo/foo', - '.repoconfig', + ".GIT/foo", + "blah/.git/foo", + ".repo/foo", + ".repoconfig", # Block ~ due to 8.3 filenames on Windows filesystems. - '~', - 'foo~', - 'blah/foo~', + "~", + "foo~", + "blah/foo~", # Block Unicode characters that get normalized out by filesystems. - u'foo\u200Cbar', + "foo\u200Cbar", # Block newlines. - 'f\n/bar', - 'f\r/bar', + "f\n/bar", + "f\r/bar", ) # Make sure platforms that use path separators (e.g. Windows) are also # rejected properly. -if os.path.sep != '/': - INVALID_FS_PATHS += tuple(x.replace('/', os.path.sep) for x in INVALID_FS_PATHS) +if os.path.sep != "/": + INVALID_FS_PATHS += tuple( + x.replace("/", os.path.sep) for x in INVALID_FS_PATHS + ) def sort_attributes(manifest): - """Sort the attributes of all elements alphabetically. - - This is needed because different versions of the toxml() function from - xml.dom.minidom outputs the attributes of elements in different orders. - Before Python 3.8 they were output alphabetically, later versions preserve - the order specified by the user. - - Args: - manifest: String containing an XML manifest. - - Returns: - The XML manifest with the attributes of all elements sorted alphabetically. - """ - new_manifest = '' - # This will find every element in the XML manifest, whether they have - # attributes or not. This simplifies recreating the manifest below. - matches = re.findall(r'(<[/?]?[a-z-]+\s*)((?:\S+?="[^"]+"\s*?)*)(\s*[/?]?>)', manifest) - for head, attrs, tail in matches: - m = re.findall(r'\S+?="[^"]+"', attrs) - new_manifest += head + ' '.join(sorted(m)) + tail - return new_manifest + """Sort the attributes of all elements alphabetically. + + This is needed because different versions of the toxml() function from + xml.dom.minidom outputs the attributes of elements in different orders. + Before Python 3.8 they were output alphabetically, later versions preserve + the order specified by the user. + + Args: + manifest: String containing an XML manifest. + + Returns: + The XML manifest with the attributes of all elements sorted + alphabetically. + """ + new_manifest = "" + # This will find every element in the XML manifest, whether they have + # attributes or not. This simplifies recreating the manifest below. + matches = re.findall( + r'(<[/?]?[a-z-]+\s*)((?:\S+?="[^"]+"\s*?)*)(\s*[/?]?>)', manifest + ) + for head, attrs, tail in matches: + m = re.findall(r'\S+?="[^"]+"', attrs) + new_manifest += head + " ".join(sorted(m)) + tail + return new_manifest class ManifestParseTestCase(unittest.TestCase): - """TestCase for parsing manifests.""" - - def setUp(self): - self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') - self.tempdir = self.tempdirobj.name - self.repodir = os.path.join(self.tempdir, '.repo') - self.manifest_dir = os.path.join(self.repodir, 'manifests') - self.manifest_file = os.path.join( - self.repodir, manifest_xml.MANIFEST_FILE_NAME) - self.local_manifest_dir = os.path.join( - self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME) - os.mkdir(self.repodir) - os.mkdir(self.manifest_dir) - - # The manifest parsing really wants a git repo currently. - gitdir = os.path.join(self.repodir, 'manifests.git') - os.mkdir(gitdir) - with open(os.path.join(gitdir, 'config'), 'w') as fp: - fp.write("""[remote "origin"] + """TestCase for parsing manifests.""" + + def setUp(self): + self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") + self.tempdir = self.tempdirobj.name + self.repodir = os.path.join(self.tempdir, ".repo") + self.manifest_dir = os.path.join(self.repodir, "manifests") + self.manifest_file = os.path.join( + self.repodir, manifest_xml.MANIFEST_FILE_NAME + ) + self.local_manifest_dir = os.path.join( + self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME + ) + os.mkdir(self.repodir) + os.mkdir(self.manifest_dir) + + # The manifest parsing really wants a git repo currently. + gitdir = os.path.join(self.repodir, "manifests.git") + os.mkdir(gitdir) + with open(os.path.join(gitdir, "config"), "w") as fp: + fp.write( + """[remote "origin"] url = https://localhost:0/manifest -""") +""" + ) - def tearDown(self): - self.tempdirobj.cleanup() + def tearDown(self): + self.tempdirobj.cleanup() - def getXmlManifest(self, data): - """Helper to initialize a manifest for testing.""" - with open(self.manifest_file, 'w', encoding="utf-8") as fp: - fp.write(data) - return manifest_xml.XmlManifest(self.repodir, self.manifest_file) + def getXmlManifest(self, data): + """Helper to initialize a manifest for testing.""" + with open(self.manifest_file, "w", encoding="utf-8") as fp: + fp.write(data) + return manifest_xml.XmlManifest(self.repodir, self.manifest_file) - @staticmethod - def encodeXmlAttr(attr): - """Encode |attr| using XML escape rules.""" - return attr.replace('\r', ' ').replace('\n', ' ') + @staticmethod + def encodeXmlAttr(attr): + """Encode |attr| using XML escape rules.""" + return attr.replace("\r", " ").replace("\n", " ") class ManifestValidateFilePaths(unittest.TestCase): - """Check _ValidateFilePaths helper. - - This doesn't access a real filesystem. - """ - - def check_both(self, *args): - manifest_xml.XmlManifest._ValidateFilePaths('copyfile', *args) - manifest_xml.XmlManifest._ValidateFilePaths('linkfile', *args) - - def test_normal_path(self): - """Make sure good paths are accepted.""" - self.check_both('foo', 'bar') - self.check_both('foo/bar', 'bar') - self.check_both('foo', 'bar/bar') - self.check_both('foo/bar', 'bar/bar') - - def test_symlink_targets(self): - """Some extra checks for symlinks.""" - def check(*args): - manifest_xml.XmlManifest._ValidateFilePaths('linkfile', *args) - - # We allow symlinks to end in a slash since we allow them to point to dirs - # in general. Technically the slash isn't necessary. - check('foo/', 'bar') - # We allow a single '.' to get a reference to the project itself. - check('.', 'bar') - - def test_bad_paths(self): - """Make sure bad paths (src & dest) are rejected.""" - for path in INVALID_FS_PATHS: - self.assertRaises( - error.ManifestInvalidPathError, self.check_both, path, 'a') - self.assertRaises( - error.ManifestInvalidPathError, self.check_both, 'a', path) + """Check _ValidateFilePaths helper. + + This doesn't access a real filesystem. + """ + + def check_both(self, *args): + manifest_xml.XmlManifest._ValidateFilePaths("copyfile", *args) + manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) + + def test_normal_path(self): + """Make sure good paths are accepted.""" + self.check_both("foo", "bar") + self.check_both("foo/bar", "bar") + self.check_both("foo", "bar/bar") + self.check_both("foo/bar", "bar/bar") + + def test_symlink_targets(self): + """Some extra checks for symlinks.""" + + def check(*args): + manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) + + # We allow symlinks to end in a slash since we allow them to point to + # dirs in general. Technically the slash isn't necessary. + check("foo/", "bar") + # We allow a single '.' to get a reference to the project itself. + check(".", "bar") + + def test_bad_paths(self): + """Make sure bad paths (src & dest) are rejected.""" + for path in INVALID_FS_PATHS: + self.assertRaises( + error.ManifestInvalidPathError, self.check_both, path, "a" + ) + self.assertRaises( + error.ManifestInvalidPathError, self.check_both, "a", path + ) class ValueTests(unittest.TestCase): - """Check utility parsing code.""" - - def _get_node(self, text): - return xml.dom.minidom.parseString(text).firstChild - - def test_bool_default(self): - """Check XmlBool default handling.""" - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlBool(node, 'a')) - self.assertIsNone(manifest_xml.XmlBool(node, 'a', None)) - self.assertEqual(123, manifest_xml.XmlBool(node, 'a', 123)) - - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlBool(node, 'a')) - - def test_bool_invalid(self): - """Check XmlBool invalid handling.""" - node = self._get_node('') - self.assertEqual(123, manifest_xml.XmlBool(node, 'a', 123)) - - def test_bool_true(self): - """Check XmlBool true values.""" - for value in ('yes', 'true', '1'): - node = self._get_node('' % (value,)) - self.assertTrue(manifest_xml.XmlBool(node, 'a')) - - def test_bool_false(self): - """Check XmlBool false values.""" - for value in ('no', 'false', '0'): - node = self._get_node('' % (value,)) - self.assertFalse(manifest_xml.XmlBool(node, 'a')) - - def test_int_default(self): - """Check XmlInt default handling.""" - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlInt(node, 'a')) - self.assertIsNone(manifest_xml.XmlInt(node, 'a', None)) - self.assertEqual(123, manifest_xml.XmlInt(node, 'a', 123)) - - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlInt(node, 'a')) - - def test_int_good(self): - """Check XmlInt numeric handling.""" - for value in (-1, 0, 1, 50000): - node = self._get_node('' % (value,)) - self.assertEqual(value, manifest_xml.XmlInt(node, 'a')) - - def test_int_invalid(self): - """Check XmlInt invalid handling.""" - with self.assertRaises(error.ManifestParseError): - node = self._get_node('') - manifest_xml.XmlInt(node, 'a') + """Check utility parsing code.""" + + def _get_node(self, text): + return xml.dom.minidom.parseString(text).firstChild + + def test_bool_default(self): + """Check XmlBool default handling.""" + node = self._get_node("") + self.assertIsNone(manifest_xml.XmlBool(node, "a")) + self.assertIsNone(manifest_xml.XmlBool(node, "a", None)) + self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) + + node = self._get_node('') + self.assertIsNone(manifest_xml.XmlBool(node, "a")) + + def test_bool_invalid(self): + """Check XmlBool invalid handling.""" + node = self._get_node('') + self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) + + def test_bool_true(self): + """Check XmlBool true values.""" + for value in ("yes", "true", "1"): + node = self._get_node('' % (value,)) + self.assertTrue(manifest_xml.XmlBool(node, "a")) + + def test_bool_false(self): + """Check XmlBool false values.""" + for value in ("no", "false", "0"): + node = self._get_node('' % (value,)) + self.assertFalse(manifest_xml.XmlBool(node, "a")) + + def test_int_default(self): + """Check XmlInt default handling.""" + node = self._get_node("") + self.assertIsNone(manifest_xml.XmlInt(node, "a")) + self.assertIsNone(manifest_xml.XmlInt(node, "a", None)) + self.assertEqual(123, manifest_xml.XmlInt(node, "a", 123)) + + node = self._get_node('') + self.assertIsNone(manifest_xml.XmlInt(node, "a")) + + def test_int_good(self): + """Check XmlInt numeric handling.""" + for value in (-1, 0, 1, 50000): + node = self._get_node('' % (value,)) + self.assertEqual(value, manifest_xml.XmlInt(node, "a")) + + def test_int_invalid(self): + """Check XmlInt invalid handling.""" + with self.assertRaises(error.ManifestParseError): + node = self._get_node('') + manifest_xml.XmlInt(node, "a") class XmlManifestTests(ManifestParseTestCase): - """Check manifest processing.""" - - def test_empty(self): - """Parse an 'empty' manifest file.""" - manifest = self.getXmlManifest( - '' - '') - self.assertEqual(manifest.remotes, {}) - self.assertEqual(manifest.projects, []) - - def test_link(self): - """Verify Link handling with new names.""" - manifest = manifest_xml.XmlManifest(self.repodir, self.manifest_file) - with open(os.path.join(self.manifest_dir, 'foo.xml'), 'w') as fp: - fp.write('') - manifest.Link('foo.xml') - with open(self.manifest_file) as fp: - self.assertIn('', fp.read()) - - def test_toxml_empty(self): - """Verify the ToXml() helper.""" - manifest = self.getXmlManifest( - '' - '') - self.assertEqual(manifest.ToXml().toxml(), '') - - def test_todict_empty(self): - """Verify the ToDict() helper.""" - manifest = self.getXmlManifest( - '' - '') - self.assertEqual(manifest.ToDict(), {}) - - def test_toxml_omit_local(self): - """Does not include local_manifests projects when omit_local=True.""" - manifest = self.getXmlManifest( - '' - '' - '' - '' - '' - '') - self.assertEqual( - sort_attributes(manifest.ToXml(omit_local=True).toxml()), - '' - '' - '') - - def test_toxml_with_local(self): - """Does include local_manifests projects when omit_local=False.""" - manifest = self.getXmlManifest( - '' - '' - '' - '' - '' - '') - self.assertEqual( - sort_attributes(manifest.ToXml(omit_local=False).toxml()), - '' - '' - '' - '') - - def test_repo_hooks(self): - """Check repo-hooks settings.""" - manifest = self.getXmlManifest(""" + """Check manifest processing.""" + + def test_empty(self): + """Parse an 'empty' manifest file.""" + manifest = self.getXmlManifest( + '' "" + ) + self.assertEqual(manifest.remotes, {}) + self.assertEqual(manifest.projects, []) + + def test_link(self): + """Verify Link handling with new names.""" + manifest = manifest_xml.XmlManifest(self.repodir, self.manifest_file) + with open(os.path.join(self.manifest_dir, "foo.xml"), "w") as fp: + fp.write("") + manifest.Link("foo.xml") + with open(self.manifest_file) as fp: + self.assertIn('', fp.read()) + + def test_toxml_empty(self): + """Verify the ToXml() helper.""" + manifest = self.getXmlManifest( + '' "" + ) + self.assertEqual( + manifest.ToXml().toxml(), '' + ) + + def test_todict_empty(self): + """Verify the ToDict() helper.""" + manifest = self.getXmlManifest( + '' "" + ) + self.assertEqual(manifest.ToDict(), {}) + + def test_toxml_omit_local(self): + """Does not include local_manifests projects when omit_local=True.""" + manifest = self.getXmlManifest( + '' + '' + '' + '' + '' + "" + ) + self.assertEqual( + sort_attributes(manifest.ToXml(omit_local=True).toxml()), + '' + '' + '', + ) + + def test_toxml_with_local(self): + """Does include local_manifests projects when omit_local=False.""" + manifest = self.getXmlManifest( + '' + '' + '' + '' + '' + "" + ) + self.assertEqual( + sort_attributes(manifest.ToXml(omit_local=False).toxml()), + '' + '' + '' + '', + ) + + def test_repo_hooks(self): + """Check repo-hooks settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.repo_hooks_project.name, 'repohooks') - self.assertEqual(manifest.repo_hooks_project.enabled_repo_hooks, ['a', 'b']) - - def test_repo_hooks_unordered(self): - """Check repo-hooks settings work even if the project def comes second.""" - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.repo_hooks_project.name, "repohooks") + self.assertEqual( + manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"] + ) + + def test_repo_hooks_unordered(self): + """Check repo-hooks settings work even if the project def comes second.""" # noqa: E501 + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.repo_hooks_project.name, 'repohooks') - self.assertEqual(manifest.repo_hooks_project.enabled_repo_hooks, ['a', 'b']) - - def test_unknown_tags(self): - """Check superproject settings.""" - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.repo_hooks_project.name, "repohooks") + self.assertEqual( + manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"] + ) + + def test_unknown_tags(self): + """Check superproject settings.""" + manifest = self.getXmlManifest( + """ @@ -319,44 +346,54 @@ class XmlManifestTests(ManifestParseTestCase): X tags are always ignored -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_remote_annotations(self): - """Check remote settings.""" - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) + + def test_remote_annotations(self): + """Check remote settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.remotes['test-remote'].annotations[0].name, 'foo') - self.assertEqual(manifest.remotes['test-remote'].annotations[0].value, 'bar') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual( + manifest.remotes["test-remote"].annotations[0].name, "foo" + ) + self.assertEqual( + manifest.remotes["test-remote"].annotations[0].value, "bar" + ) + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + "" + "", + ) class IncludeElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_group_levels(self): - root_m = os.path.join(self.manifest_dir, 'root.xml') - with open(root_m, 'w') as fp: - fp.write(""" + def test_group_levels(self): + root_m = os.path.join(self.manifest_dir, "root.xml") + with open(root_m, "w") as fp: + fp.write( + """ @@ -364,438 +401,524 @@ class IncludeElementTests(ManifestParseTestCase): -""") - with open(os.path.join(self.manifest_dir, 'level1.xml'), 'w') as fp: - fp.write(""" +""" + ) + with open(os.path.join(self.manifest_dir, "level1.xml"), "w") as fp: + fp.write( + """ -""") - with open(os.path.join(self.manifest_dir, 'level2.xml'), 'w') as fp: - fp.write(""" +""" + ) + with open(os.path.join(self.manifest_dir, "level2.xml"), "w") as fp: + fp.write( + """ -""") - include_m = manifest_xml.XmlManifest(self.repodir, root_m) - for proj in include_m.projects: - if proj.name == 'root-name1': - # Check include group not set on root level proj. - self.assertNotIn('level1-group', proj.groups) - if proj.name == 'root-name2': - # Check root proj group not removed. - self.assertIn('r2g1', proj.groups) - if proj.name == 'level1-name1': - # Check level1 proj has inherited group level 1. - self.assertIn('level1-group', proj.groups) - if proj.name == 'level2-name1': - # Check level2 proj has inherited group levels 1 and 2. - self.assertIn('level1-group', proj.groups) - self.assertIn('level2-group', proj.groups) - # Check level2 proj group not removed. - self.assertIn('l2g1', proj.groups) - - def test_allow_bad_name_from_user(self): - """Check handling of bad name attribute from the user's input.""" - def parse(name): - name = self.encodeXmlAttr(name) - manifest = self.getXmlManifest(f""" +""" + ) + include_m = manifest_xml.XmlManifest(self.repodir, root_m) + for proj in include_m.projects: + if proj.name == "root-name1": + # Check include group not set on root level proj. + self.assertNotIn("level1-group", proj.groups) + if proj.name == "root-name2": + # Check root proj group not removed. + self.assertIn("r2g1", proj.groups) + if proj.name == "level1-name1": + # Check level1 proj has inherited group level 1. + self.assertIn("level1-group", proj.groups) + if proj.name == "level2-name1": + # Check level2 proj has inherited group levels 1 and 2. + self.assertIn("level1-group", proj.groups) + self.assertIn("level2-group", proj.groups) + # Check level2 proj group not removed. + self.assertIn("l2g1", proj.groups) + + def test_allow_bad_name_from_user(self): + """Check handling of bad name attribute from the user's input.""" + + def parse(name): + name = self.encodeXmlAttr(name) + manifest = self.getXmlManifest( + f""" -""") - # Force the manifest to be parsed. - manifest.ToXml() - - # Setup target of the include. - target = os.path.join(self.tempdir, 'target.xml') - with open(target, 'w') as fp: - fp.write('') - - # Include with absolute path. - parse(os.path.abspath(target)) - - # Include with relative path. - parse(os.path.relpath(target, self.manifest_dir)) - - def test_bad_name_checks(self): - """Check handling of bad name attribute.""" - def parse(name): - name = self.encodeXmlAttr(name) - # Setup target of the include. - with open(os.path.join(self.manifest_dir, 'target.xml'), 'w', encoding="utf-8") as fp: - fp.write(f'') - - manifest = self.getXmlManifest(""" +""" + ) + # Force the manifest to be parsed. + manifest.ToXml() + + # Setup target of the include. + target = os.path.join(self.tempdir, "target.xml") + with open(target, "w") as fp: + fp.write("") + + # Include with absolute path. + parse(os.path.abspath(target)) + + # Include with relative path. + parse(os.path.relpath(target, self.manifest_dir)) + + def test_bad_name_checks(self): + """Check handling of bad name attribute.""" + + def parse(name): + name = self.encodeXmlAttr(name) + # Setup target of the include. + with open( + os.path.join(self.manifest_dir, "target.xml"), + "w", + encoding="utf-8", + ) as fp: + fp.write(f'') + + manifest = self.getXmlManifest( + """ -""") - # Force the manifest to be parsed. - manifest.ToXml() +""" + ) + # Force the manifest to be parsed. + manifest.ToXml() - # Handle empty name explicitly because a different codepath rejects it. - with self.assertRaises(error.ManifestParseError): - parse('') + # Handle empty name explicitly because a different codepath rejects it. + with self.assertRaises(error.ManifestParseError): + parse("") - for path in INVALID_FS_PATHS: - if not path: - continue + for path in INVALID_FS_PATHS: + if not path: + continue - with self.assertRaises(error.ManifestInvalidPathError): - parse(path) + with self.assertRaises(error.ManifestInvalidPathError): + parse(path) class ProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_group(self): - """Check project group settings.""" - manifest = self.getXmlManifest(""" + def test_group(self): + """Check project group settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 2) - # Ordering isn't guaranteed. - result = { - manifest.projects[0].name: manifest.projects[0].groups, - manifest.projects[1].name: manifest.projects[1].groups, - } - project = manifest.projects[0] - self.assertCountEqual( - result['test-name'], - ['name:test-name', 'all', 'path:test-path']) - self.assertCountEqual( - result['extras'], - ['g1', 'g2', 'g1', 'name:extras', 'all', 'path:path']) - groupstr = 'default,platform-' + platform.system().lower() - self.assertEqual(groupstr, manifest.GetGroupsStr()) - groupstr = 'g1,g2,g1' - manifest.manifestProject.config.SetString('manifest.groups', groupstr) - self.assertEqual(groupstr, manifest.GetGroupsStr()) - - def test_set_revision_id(self): - """Check setting of project's revisionId.""" - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(len(manifest.projects), 2) + # Ordering isn't guaranteed. + result = { + manifest.projects[0].name: manifest.projects[0].groups, + manifest.projects[1].name: manifest.projects[1].groups, + } + self.assertCountEqual( + result["test-name"], ["name:test-name", "all", "path:test-path"] + ) + self.assertCountEqual( + result["extras"], + ["g1", "g2", "g1", "name:extras", "all", "path:path"], + ) + groupstr = "default,platform-" + platform.system().lower() + self.assertEqual(groupstr, manifest.GetGroupsStr()) + groupstr = "g1,g2,g1" + manifest.manifestProject.config.SetString("manifest.groups", groupstr) + self.assertEqual(groupstr, manifest.GetGroupsStr()) + + def test_set_revision_id(self): + """Check setting of project's revisionId.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - project = manifest.projects[0] - project.SetRevisionId('ABCDEF') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_trailing_slash(self): - """Check handling of trailing slashes in attributes.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - return self.getXmlManifest(f""" +""" + ) + self.assertEqual(len(manifest.projects), 1) + project = manifest.projects[0] + project.SetRevisionId("ABCDEF") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' # noqa: E501 + "", + ) + + def test_trailing_slash(self): + """Check handling of trailing slashes in attributes.""" + + def parse(name, path): + name = self.encodeXmlAttr(name) + path = self.encodeXmlAttr(path) + return self.getXmlManifest( + f""" -""") - - manifest = parse('a/path/', 'foo') - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', 'foo.git')) - self.assertEqual(os.path.normpath(manifest.projects[0].objdir), - os.path.join(self.tempdir, '.repo', 'project-objects', 'a', 'path.git')) - - manifest = parse('a/path', 'foo/') - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', 'foo.git')) - self.assertEqual(os.path.normpath(manifest.projects[0].objdir), - os.path.join(self.tempdir, '.repo', 'project-objects', 'a', 'path.git')) - - manifest = parse('a/path', 'foo//////') - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', 'foo.git')) - self.assertEqual(os.path.normpath(manifest.projects[0].objdir), - os.path.join(self.tempdir, '.repo', 'project-objects', 'a', 'path.git')) - - def test_toplevel_path(self): - """Check handling of path=. specially.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - return self.getXmlManifest(f""" +""" + ) + + manifest = parse("a/path/", "foo") + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + ) + self.assertEqual( + os.path.normpath(manifest.projects[0].objdir), + os.path.join( + self.tempdir, ".repo", "project-objects", "a", "path.git" + ), + ) + + manifest = parse("a/path", "foo/") + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + ) + self.assertEqual( + os.path.normpath(manifest.projects[0].objdir), + os.path.join( + self.tempdir, ".repo", "project-objects", "a", "path.git" + ), + ) + + manifest = parse("a/path", "foo//////") + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + ) + self.assertEqual( + os.path.normpath(manifest.projects[0].objdir), + os.path.join( + self.tempdir, ".repo", "project-objects", "a", "path.git" + ), + ) + + def test_toplevel_path(self): + """Check handling of path=. specially.""" + + def parse(name, path): + name = self.encodeXmlAttr(name) + path = self.encodeXmlAttr(path) + return self.getXmlManifest( + f""" -""") - - for path in ('.', './', './/', './//'): - manifest = parse('server/path', path) - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', '..git')) - - def test_bad_path_name_checks(self): - """Check handling of bad path & name attributes.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - manifest = self.getXmlManifest(f""" +""" + ) + + for path in (".", "./", ".//", ".///"): + manifest = parse("server/path", path) + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "..git"), + ) + + def test_bad_path_name_checks(self): + """Check handling of bad path & name attributes.""" + + def parse(name, path): + name = self.encodeXmlAttr(name) + path = self.encodeXmlAttr(path) + manifest = self.getXmlManifest( + f""" -""") - # Force the manifest to be parsed. - manifest.ToXml() +""" + ) + # Force the manifest to be parsed. + manifest.ToXml() - # Verify the parser is valid by default to avoid buggy tests below. - parse('ok', 'ok') + # Verify the parser is valid by default to avoid buggy tests below. + parse("ok", "ok") - # Handle empty name explicitly because a different codepath rejects it. - # Empty path is OK because it defaults to the name field. - with self.assertRaises(error.ManifestParseError): - parse('', 'ok') + # Handle empty name explicitly because a different codepath rejects it. + # Empty path is OK because it defaults to the name field. + with self.assertRaises(error.ManifestParseError): + parse("", "ok") - for path in INVALID_FS_PATHS: - if not path or path.endswith('/') or path.endswith(os.path.sep): - continue + for path in INVALID_FS_PATHS: + if not path or path.endswith("/") or path.endswith(os.path.sep): + continue - with self.assertRaises(error.ManifestInvalidPathError): - parse(path, 'ok') + with self.assertRaises(error.ManifestInvalidPathError): + parse(path, "ok") - # We have a dedicated test for path=".". - if path not in {'.'}: - with self.assertRaises(error.ManifestInvalidPathError): - parse('ok', path) + # We have a dedicated test for path=".". + if path not in {"."}: + with self.assertRaises(error.ManifestInvalidPathError): + parse("ok", path) class SuperProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_superproject(self): - """Check superproject settings.""" - manifest = self.getXmlManifest(""" + def test_superproject(self): + """Check superproject settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/main') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_superproject_revision(self): - """Check superproject settings with a different revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/main") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) + + def test_superproject_revision(self): + """Check superproject settings with a different revision attribute""" + self.maxDiff = None + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/stable') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_superproject_revision_default_negative(self): - """Check superproject settings with a same revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/stable") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) + + def test_superproject_revision_default_negative(self): + """Check superproject settings with a same revision attribute""" + self.maxDiff = None + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/stable') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_superproject_revision_remote(self): - """Check superproject settings with a same revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/stable") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) + + def test_superproject_revision_remote(self): + """Check superproject settings with a same revision attribute""" + self.maxDiff = None + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/stable') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') - - def test_remote(self): - """Check superproject settings with a remote.""" - manifest = self.getXmlManifest(""" +""" # noqa: E501 + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/stable") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' # noqa: E501 + '' + '' + "", + ) + + def test_remote(self): + """Check superproject settings with a remote.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'platform/superproject') - self.assertEqual(manifest.superproject.remote.name, 'superproject-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/platform/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/main') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '' - '') - - def test_defalut_remote(self): - """Check superproject settings with a default remote.""" - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(manifest.superproject.name, "platform/superproject") + self.assertEqual( + manifest.superproject.remote.name, "superproject-remote" + ) + self.assertEqual( + manifest.superproject.remote.url, + "http://localhost/platform/superproject", + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/main") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + '' # noqa: E501 + "", + ) + + def test_defalut_remote(self): + """Check superproject settings with a default remote.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'default-remote') - self.assertEqual(manifest.superproject.revision, 'refs/heads/main') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "default-remote") + self.assertEqual(manifest.superproject.revision, "refs/heads/main") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) class ContactinfoElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_contactinfo(self): - """Check contactinfo settings.""" - bugurl = 'http://localhost/contactinfo' - manifest = self.getXmlManifest(f""" + def test_contactinfo(self): + """Check contactinfo settings.""" + bugurl = "http://localhost/contactinfo" + manifest = self.getXmlManifest( + f""" -""") - self.assertEqual(manifest.contactinfo.bugurl, bugurl) - self.assertEqual( - manifest.ToXml().toxml(), - '' - f'' - '') +""" + ) + self.assertEqual(manifest.contactinfo.bugurl, bugurl) + self.assertEqual( + manifest.ToXml().toxml(), + '' + f'' + "", + ) class DefaultElementTests(ManifestParseTestCase): - """Tests for .""" - - def test_default(self): - """Check default settings.""" - a = manifest_xml._Default() - a.revisionExpr = 'foo' - a.remote = manifest_xml._XmlRemote(name='remote') - b = manifest_xml._Default() - b.revisionExpr = 'bar' - self.assertEqual(a, a) - self.assertNotEqual(a, b) - self.assertNotEqual(b, a.remote) - self.assertNotEqual(a, 123) - self.assertNotEqual(a, None) + """Tests for .""" + + def test_default(self): + """Check default settings.""" + a = manifest_xml._Default() + a.revisionExpr = "foo" + a.remote = manifest_xml._XmlRemote(name="remote") + b = manifest_xml._Default() + b.revisionExpr = "bar" + self.assertEqual(a, a) + self.assertNotEqual(a, b) + self.assertNotEqual(b, a.remote) + self.assertNotEqual(a, 123) + self.assertNotEqual(a, None) class RemoteElementTests(ManifestParseTestCase): - """Tests for .""" - - def test_remote(self): - """Check remote settings.""" - a = manifest_xml._XmlRemote(name='foo') - a.AddAnnotation('key1', 'value1', 'true') - b = manifest_xml._XmlRemote(name='foo') - b.AddAnnotation('key2', 'value1', 'true') - c = manifest_xml._XmlRemote(name='foo') - c.AddAnnotation('key1', 'value2', 'true') - d = manifest_xml._XmlRemote(name='foo') - d.AddAnnotation('key1', 'value1', 'false') - self.assertEqual(a, a) - self.assertNotEqual(a, b) - self.assertNotEqual(a, c) - self.assertNotEqual(a, d) - self.assertNotEqual(a, manifest_xml._Default()) - self.assertNotEqual(a, 123) - self.assertNotEqual(a, None) + """Tests for .""" + + def test_remote(self): + """Check remote settings.""" + a = manifest_xml._XmlRemote(name="foo") + a.AddAnnotation("key1", "value1", "true") + b = manifest_xml._XmlRemote(name="foo") + b.AddAnnotation("key2", "value1", "true") + c = manifest_xml._XmlRemote(name="foo") + c.AddAnnotation("key1", "value2", "true") + d = manifest_xml._XmlRemote(name="foo") + d.AddAnnotation("key1", "value1", "false") + self.assertEqual(a, a) + self.assertNotEqual(a, b) + self.assertNotEqual(a, c) + self.assertNotEqual(a, d) + self.assertNotEqual(a, manifest_xml._Default()) + self.assertNotEqual(a, 123) + self.assertNotEqual(a, None) class RemoveProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_remove_one_project(self): - manifest = self.getXmlManifest(""" + def test_remove_one_project(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.projects, []) +""" + ) + self.assertEqual(manifest.projects, []) - def test_remove_one_project_one_remains(self): - manifest = self.getXmlManifest(""" + def test_remove_one_project_one_remains(self): + manifest = self.getXmlManifest( + """ @@ -803,51 +926,59 @@ class RemoveProjectElementTests(ManifestParseTestCase): -""") +""" + ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].name, 'yourproject') + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].name, "yourproject") - def test_remove_one_project_doesnt_exist(self): - with self.assertRaises(manifest_xml.ManifestParseError): - manifest = self.getXmlManifest(""" + def test_remove_one_project_doesnt_exist(self): + with self.assertRaises(manifest_xml.ManifestParseError): + manifest = self.getXmlManifest( + """ -""") - manifest.projects +""" + ) + manifest.projects - def test_remove_one_optional_project_doesnt_exist(self): - manifest = self.getXmlManifest(""" + def test_remove_one_optional_project_doesnt_exist(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.projects, []) +""" + ) + self.assertEqual(manifest.projects, []) class ExtendProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_extend_project_dest_path_single_match(self): - manifest = self.getXmlManifest(""" + def test_extend_project_dest_path_single_match(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].relpath, 'bar') - - def test_extend_project_dest_path_multi_match(self): - with self.assertRaises(manifest_xml.ManifestParseError): - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].relpath, "bar") + + def test_extend_project_dest_path_multi_match(self): + with self.assertRaises(manifest_xml.ManifestParseError): + manifest = self.getXmlManifest( + """ @@ -855,11 +986,13 @@ class ExtendProjectElementTests(ManifestParseTestCase): -""") - manifest.projects +""" + ) + manifest.projects - def test_extend_project_dest_path_multi_match_path_specified(self): - manifest = self.getXmlManifest(""" + def test_extend_project_dest_path_multi_match_path_specified(self): + manifest = self.getXmlManifest( + """ @@ -867,34 +1000,39 @@ class ExtendProjectElementTests(ManifestParseTestCase): -""") - self.assertEqual(len(manifest.projects), 2) - if manifest.projects[0].relpath == 'y': - self.assertEqual(manifest.projects[1].relpath, 'bar') - else: - self.assertEqual(manifest.projects[0].relpath, 'bar') - self.assertEqual(manifest.projects[1].relpath, 'y') - - def test_extend_project_dest_branch(self): - manifest = self.getXmlManifest(""" +""" + ) + self.assertEqual(len(manifest.projects), 2) + if manifest.projects[0].relpath == "y": + self.assertEqual(manifest.projects[1].relpath, "bar") + else: + self.assertEqual(manifest.projects[0].relpath, "bar") + self.assertEqual(manifest.projects[1].relpath, "y") + + def test_extend_project_dest_branch(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].dest_branch, 'bar') - - def test_extend_project_upstream(self): - manifest = self.getXmlManifest(""" +""" # noqa: E501 + ) + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].dest_branch, "bar") + + def test_extend_project_upstream(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].upstream, 'bar') +""" + ) + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].upstream, "bar") diff --git a/tests/test_platform_utils.py b/tests/test_platform_utils.py index 55b7805c..7a42de01 100644 --- a/tests/test_platform_utils.py +++ b/tests/test_platform_utils.py @@ -22,29 +22,31 @@ import platform_utils class RemoveTests(unittest.TestCase): - """Check remove() helper.""" - - def testMissingOk(self): - """Check missing_ok handling.""" - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, 'test') - - # Should not fail. - platform_utils.remove(path, missing_ok=True) - - # Should fail. - self.assertRaises(OSError, platform_utils.remove, path) - self.assertRaises(OSError, platform_utils.remove, path, missing_ok=False) - - # Should not fail if it exists. - open(path, 'w').close() - platform_utils.remove(path, missing_ok=True) - self.assertFalse(os.path.exists(path)) - - open(path, 'w').close() - platform_utils.remove(path) - self.assertFalse(os.path.exists(path)) - - open(path, 'w').close() - platform_utils.remove(path, missing_ok=False) - self.assertFalse(os.path.exists(path)) + """Check remove() helper.""" + + def testMissingOk(self): + """Check missing_ok handling.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test") + + # Should not fail. + platform_utils.remove(path, missing_ok=True) + + # Should fail. + self.assertRaises(OSError, platform_utils.remove, path) + self.assertRaises( + OSError, platform_utils.remove, path, missing_ok=False + ) + + # Should not fail if it exists. + open(path, "w").close() + platform_utils.remove(path, missing_ok=True) + self.assertFalse(os.path.exists(path)) + + open(path, "w").close() + platform_utils.remove(path) + self.assertFalse(os.path.exists(path)) + + open(path, "w").close() + platform_utils.remove(path, missing_ok=False) + self.assertFalse(os.path.exists(path)) diff --git a/tests/test_project.py b/tests/test_project.py index c50d9940..bc8330b2 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -31,452 +31,493 @@ import project @contextlib.contextmanager def TempGitTree(): - """Create a new empty git checkout for testing.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - # Tests need to assume, that main is default branch at init, - # which is not supported in config until 2.28. - cmd = ['git', 'init'] - if git_command.git_require((2, 28, 0)): - cmd += ['--initial-branch=main'] - else: - # Use template dir for init. - templatedir = tempfile.mkdtemp(prefix='.test-template') - with open(os.path.join(templatedir, 'HEAD'), 'w') as fp: - fp.write('ref: refs/heads/main\n') - cmd += ['--template', templatedir] - subprocess.check_call(cmd, cwd=tempdir) - yield tempdir + """Create a new empty git checkout for testing.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + # Tests need to assume, that main is default branch at init, + # which is not supported in config until 2.28. + cmd = ["git", "init"] + if git_command.git_require((2, 28, 0)): + cmd += ["--initial-branch=main"] + else: + # Use template dir for init. + templatedir = tempfile.mkdtemp(prefix=".test-template") + with open(os.path.join(templatedir, "HEAD"), "w") as fp: + fp.write("ref: refs/heads/main\n") + cmd += ["--template", templatedir] + subprocess.check_call(cmd, cwd=tempdir) + yield tempdir class FakeProject(object): - """A fake for Project for basic functionality.""" + """A fake for Project for basic functionality.""" - def __init__(self, worktree): - self.worktree = worktree - self.gitdir = os.path.join(worktree, '.git') - self.name = 'fakeproject' - self.work_git = project.Project._GitGetByExec( - self, bare=False, gitdir=self.gitdir) - self.bare_git = project.Project._GitGetByExec( - self, bare=True, gitdir=self.gitdir) - self.config = git_config.GitConfig.ForRepository(gitdir=self.gitdir) + def __init__(self, worktree): + self.worktree = worktree + self.gitdir = os.path.join(worktree, ".git") + self.name = "fakeproject" + self.work_git = project.Project._GitGetByExec( + self, bare=False, gitdir=self.gitdir + ) + self.bare_git = project.Project._GitGetByExec( + self, bare=True, gitdir=self.gitdir + ) + self.config = git_config.GitConfig.ForRepository(gitdir=self.gitdir) class ReviewableBranchTests(unittest.TestCase): - """Check ReviewableBranch behavior.""" - - def test_smoke(self): - """A quick run through everything.""" - with TempGitTree() as tempdir: - fakeproj = FakeProject(tempdir) - - # Generate some commits. - with open(os.path.join(tempdir, 'readme'), 'w') as fp: - fp.write('txt') - fakeproj.work_git.add('readme') - fakeproj.work_git.commit('-mAdd file') - fakeproj.work_git.checkout('-b', 'work') - fakeproj.work_git.rm('-f', 'readme') - fakeproj.work_git.commit('-mDel file') - - # Start off with the normal details. - rb = project.ReviewableBranch( - fakeproj, fakeproj.config.GetBranch('work'), 'main') - self.assertEqual('work', rb.name) - self.assertEqual(1, len(rb.commits)) - self.assertIn('Del file', rb.commits[0]) - d = rb.unabbrev_commits - self.assertEqual(1, len(d)) - short, long = next(iter(d.items())) - self.assertTrue(long.startswith(short)) - self.assertTrue(rb.base_exists) - # Hard to assert anything useful about this. - self.assertTrue(rb.date) - - # Now delete the tracking branch! - fakeproj.work_git.branch('-D', 'main') - rb = project.ReviewableBranch( - fakeproj, fakeproj.config.GetBranch('work'), 'main') - self.assertEqual(0, len(rb.commits)) - self.assertFalse(rb.base_exists) - # Hard to assert anything useful about this. - self.assertTrue(rb.date) + """Check ReviewableBranch behavior.""" + + def test_smoke(self): + """A quick run through everything.""" + with TempGitTree() as tempdir: + fakeproj = FakeProject(tempdir) + + # Generate some commits. + with open(os.path.join(tempdir, "readme"), "w") as fp: + fp.write("txt") + fakeproj.work_git.add("readme") + fakeproj.work_git.commit("-mAdd file") + fakeproj.work_git.checkout("-b", "work") + fakeproj.work_git.rm("-f", "readme") + fakeproj.work_git.commit("-mDel file") + + # Start off with the normal details. + rb = project.ReviewableBranch( + fakeproj, fakeproj.config.GetBranch("work"), "main" + ) + self.assertEqual("work", rb.name) + self.assertEqual(1, len(rb.commits)) + self.assertIn("Del file", rb.commits[0]) + d = rb.unabbrev_commits + self.assertEqual(1, len(d)) + short, long = next(iter(d.items())) + self.assertTrue(long.startswith(short)) + self.assertTrue(rb.base_exists) + # Hard to assert anything useful about this. + self.assertTrue(rb.date) + + # Now delete the tracking branch! + fakeproj.work_git.branch("-D", "main") + rb = project.ReviewableBranch( + fakeproj, fakeproj.config.GetBranch("work"), "main" + ) + self.assertEqual(0, len(rb.commits)) + self.assertFalse(rb.base_exists) + # Hard to assert anything useful about this. + self.assertTrue(rb.date) class CopyLinkTestCase(unittest.TestCase): - """TestCase for stub repo client checkouts. - - It'll have a layout like this: - tempdir/ # self.tempdir - checkout/ # self.topdir - git-project/ # self.worktree - - Attributes: - tempdir: A dedicated temporary directory. - worktree: The top of the repo client checkout. - topdir: The top of a project checkout. - """ - - def setUp(self): - self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') - self.tempdir = self.tempdirobj.name - self.topdir = os.path.join(self.tempdir, 'checkout') - self.worktree = os.path.join(self.topdir, 'git-project') - os.makedirs(self.topdir) - os.makedirs(self.worktree) - - def tearDown(self): - self.tempdirobj.cleanup() - - @staticmethod - def touch(path): - with open(path, 'w'): - pass - - def assertExists(self, path, msg=None): - """Make sure |path| exists.""" - if os.path.exists(path): - return - - if msg is None: - msg = ['path is missing: %s' % path] - while path != '/': - path = os.path.dirname(path) - if not path: - # If we're given something like "foo", abort once we get to "". - break - result = os.path.exists(path) - msg.append('\tos.path.exists(%s): %s' % (path, result)) - if result: - msg.append('\tcontents: %r' % os.listdir(path)) - break - msg = '\n'.join(msg) - - raise self.failureException(msg) + """TestCase for stub repo client checkouts. + + It'll have a layout like this: + tempdir/ # self.tempdir + checkout/ # self.topdir + git-project/ # self.worktree + + Attributes: + tempdir: A dedicated temporary directory. + worktree: The top of the repo client checkout. + topdir: The top of a project checkout. + """ + + def setUp(self): + self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") + self.tempdir = self.tempdirobj.name + self.topdir = os.path.join(self.tempdir, "checkout") + self.worktree = os.path.join(self.topdir, "git-project") + os.makedirs(self.topdir) + os.makedirs(self.worktree) + + def tearDown(self): + self.tempdirobj.cleanup() + + @staticmethod + def touch(path): + with open(path, "w"): + pass + + def assertExists(self, path, msg=None): + """Make sure |path| exists.""" + if os.path.exists(path): + return + + if msg is None: + msg = ["path is missing: %s" % path] + while path != "/": + path = os.path.dirname(path) + if not path: + # If we're given something like "foo", abort once we get to + # "". + break + result = os.path.exists(path) + msg.append("\tos.path.exists(%s): %s" % (path, result)) + if result: + msg.append("\tcontents: %r" % os.listdir(path)) + break + msg = "\n".join(msg) + + raise self.failureException(msg) class CopyFile(CopyLinkTestCase): - """Check _CopyFile handling.""" - - def CopyFile(self, src, dest): - return project._CopyFile(self.worktree, src, self.topdir, dest) - - def test_basic(self): - """Basic test of copying a file from a project to the toplevel.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - cf = self.CopyFile('foo.txt', 'foo') - cf._Copy() - self.assertExists(os.path.join(self.topdir, 'foo')) - - def test_src_subdir(self): - """Copy a file from a subdir of a project.""" - src = os.path.join(self.worktree, 'bar', 'foo.txt') - os.makedirs(os.path.dirname(src)) - self.touch(src) - cf = self.CopyFile('bar/foo.txt', 'new.txt') - cf._Copy() - self.assertExists(os.path.join(self.topdir, 'new.txt')) - - def test_dest_subdir(self): - """Copy a file to a subdir of a checkout.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - cf = self.CopyFile('foo.txt', 'sub/dir/new.txt') - self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub'))) - cf._Copy() - self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'new.txt')) - - def test_update(self): - """Make sure changed files get copied again.""" - src = os.path.join(self.worktree, 'foo.txt') - dest = os.path.join(self.topdir, 'bar') - with open(src, 'w') as f: - f.write('1st') - cf = self.CopyFile('foo.txt', 'bar') - cf._Copy() - self.assertExists(dest) - with open(dest) as f: - self.assertEqual(f.read(), '1st') - - with open(src, 'w') as f: - f.write('2nd!') - cf._Copy() - with open(dest) as f: - self.assertEqual(f.read(), '2nd!') - - def test_src_block_symlink(self): - """Do not allow reading from a symlinked path.""" - src = os.path.join(self.worktree, 'foo.txt') - sym = os.path.join(self.worktree, 'sym') - self.touch(src) - platform_utils.symlink('foo.txt', sym) - self.assertExists(sym) - cf = self.CopyFile('sym', 'foo') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - - def test_src_block_symlink_traversal(self): - """Do not allow reading through a symlink dir.""" - realfile = os.path.join(self.tempdir, 'file.txt') - self.touch(realfile) - src = os.path.join(self.worktree, 'bar', 'file.txt') - platform_utils.symlink(self.tempdir, os.path.join(self.worktree, 'bar')) - self.assertExists(src) - cf = self.CopyFile('bar/file.txt', 'foo') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - - def test_src_block_copy_from_dir(self): - """Do not allow copying from a directory.""" - src = os.path.join(self.worktree, 'dir') - os.makedirs(src) - cf = self.CopyFile('dir', 'foo') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - - def test_dest_block_symlink(self): - """Do not allow writing to a symlink.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - platform_utils.symlink('dest', os.path.join(self.topdir, 'sym')) - cf = self.CopyFile('foo.txt', 'sym') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - - def test_dest_block_symlink_traversal(self): - """Do not allow writing through a symlink dir.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - platform_utils.symlink(tempfile.gettempdir(), - os.path.join(self.topdir, 'sym')) - cf = self.CopyFile('foo.txt', 'sym/foo.txt') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - - def test_src_block_copy_to_dir(self): - """Do not allow copying to a directory.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - os.makedirs(os.path.join(self.topdir, 'dir')) - cf = self.CopyFile('foo.txt', 'dir') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + """Check _CopyFile handling.""" + + def CopyFile(self, src, dest): + return project._CopyFile(self.worktree, src, self.topdir, dest) + + def test_basic(self): + """Basic test of copying a file from a project to the toplevel.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + cf = self.CopyFile("foo.txt", "foo") + cf._Copy() + self.assertExists(os.path.join(self.topdir, "foo")) + + def test_src_subdir(self): + """Copy a file from a subdir of a project.""" + src = os.path.join(self.worktree, "bar", "foo.txt") + os.makedirs(os.path.dirname(src)) + self.touch(src) + cf = self.CopyFile("bar/foo.txt", "new.txt") + cf._Copy() + self.assertExists(os.path.join(self.topdir, "new.txt")) + + def test_dest_subdir(self): + """Copy a file to a subdir of a checkout.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + cf = self.CopyFile("foo.txt", "sub/dir/new.txt") + self.assertFalse(os.path.exists(os.path.join(self.topdir, "sub"))) + cf._Copy() + self.assertExists(os.path.join(self.topdir, "sub", "dir", "new.txt")) + + def test_update(self): + """Make sure changed files get copied again.""" + src = os.path.join(self.worktree, "foo.txt") + dest = os.path.join(self.topdir, "bar") + with open(src, "w") as f: + f.write("1st") + cf = self.CopyFile("foo.txt", "bar") + cf._Copy() + self.assertExists(dest) + with open(dest) as f: + self.assertEqual(f.read(), "1st") + + with open(src, "w") as f: + f.write("2nd!") + cf._Copy() + with open(dest) as f: + self.assertEqual(f.read(), "2nd!") + + def test_src_block_symlink(self): + """Do not allow reading from a symlinked path.""" + src = os.path.join(self.worktree, "foo.txt") + sym = os.path.join(self.worktree, "sym") + self.touch(src) + platform_utils.symlink("foo.txt", sym) + self.assertExists(sym) + cf = self.CopyFile("sym", "foo") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + + def test_src_block_symlink_traversal(self): + """Do not allow reading through a symlink dir.""" + realfile = os.path.join(self.tempdir, "file.txt") + self.touch(realfile) + src = os.path.join(self.worktree, "bar", "file.txt") + platform_utils.symlink(self.tempdir, os.path.join(self.worktree, "bar")) + self.assertExists(src) + cf = self.CopyFile("bar/file.txt", "foo") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + + def test_src_block_copy_from_dir(self): + """Do not allow copying from a directory.""" + src = os.path.join(self.worktree, "dir") + os.makedirs(src) + cf = self.CopyFile("dir", "foo") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + + def test_dest_block_symlink(self): + """Do not allow writing to a symlink.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + platform_utils.symlink("dest", os.path.join(self.topdir, "sym")) + cf = self.CopyFile("foo.txt", "sym") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + + def test_dest_block_symlink_traversal(self): + """Do not allow writing through a symlink dir.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + platform_utils.symlink( + tempfile.gettempdir(), os.path.join(self.topdir, "sym") + ) + cf = self.CopyFile("foo.txt", "sym/foo.txt") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + + def test_src_block_copy_to_dir(self): + """Do not allow copying to a directory.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + os.makedirs(os.path.join(self.topdir, "dir")) + cf = self.CopyFile("foo.txt", "dir") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) class LinkFile(CopyLinkTestCase): - """Check _LinkFile handling.""" - - def LinkFile(self, src, dest): - return project._LinkFile(self.worktree, src, self.topdir, dest) - - def test_basic(self): - """Basic test of linking a file from a project into the toplevel.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - lf = self.LinkFile('foo.txt', 'foo') - lf._Link() - dest = os.path.join(self.topdir, 'foo') - self.assertExists(dest) - self.assertTrue(os.path.islink(dest)) - self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) - - def test_src_subdir(self): - """Link to a file in a subdir of a project.""" - src = os.path.join(self.worktree, 'bar', 'foo.txt') - os.makedirs(os.path.dirname(src)) - self.touch(src) - lf = self.LinkFile('bar/foo.txt', 'foo') - lf._Link() - self.assertExists(os.path.join(self.topdir, 'foo')) - - def test_src_self(self): - """Link to the project itself.""" - dest = os.path.join(self.topdir, 'foo', 'bar') - lf = self.LinkFile('.', 'foo/bar') - lf._Link() - self.assertExists(dest) - self.assertEqual(os.path.join('..', 'git-project'), os.readlink(dest)) - - def test_dest_subdir(self): - """Link a file to a subdir of a checkout.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - lf = self.LinkFile('foo.txt', 'sub/dir/foo/bar') - self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub'))) - lf._Link() - self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'foo', 'bar')) - - def test_src_block_relative(self): - """Do not allow relative symlinks.""" - BAD_SOURCES = ( - './', - '..', - '../', - 'foo/.', - 'foo/./bar', - 'foo/..', - 'foo/../foo', - ) - for src in BAD_SOURCES: - lf = self.LinkFile(src, 'foo') - self.assertRaises(error.ManifestInvalidPathError, lf._Link) - - def test_update(self): - """Make sure changed targets get updated.""" - dest = os.path.join(self.topdir, 'sym') - - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - lf = self.LinkFile('foo.txt', 'sym') - lf._Link() - self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) - - # Point the symlink somewhere else. - os.unlink(dest) - platform_utils.symlink(self.tempdir, dest) - lf._Link() - self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) + """Check _LinkFile handling.""" + + def LinkFile(self, src, dest): + return project._LinkFile(self.worktree, src, self.topdir, dest) + + def test_basic(self): + """Basic test of linking a file from a project into the toplevel.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + lf = self.LinkFile("foo.txt", "foo") + lf._Link() + dest = os.path.join(self.topdir, "foo") + self.assertExists(dest) + self.assertTrue(os.path.islink(dest)) + self.assertEqual( + os.path.join("git-project", "foo.txt"), os.readlink(dest) + ) + + def test_src_subdir(self): + """Link to a file in a subdir of a project.""" + src = os.path.join(self.worktree, "bar", "foo.txt") + os.makedirs(os.path.dirname(src)) + self.touch(src) + lf = self.LinkFile("bar/foo.txt", "foo") + lf._Link() + self.assertExists(os.path.join(self.topdir, "foo")) + + def test_src_self(self): + """Link to the project itself.""" + dest = os.path.join(self.topdir, "foo", "bar") + lf = self.LinkFile(".", "foo/bar") + lf._Link() + self.assertExists(dest) + self.assertEqual(os.path.join("..", "git-project"), os.readlink(dest)) + + def test_dest_subdir(self): + """Link a file to a subdir of a checkout.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + lf = self.LinkFile("foo.txt", "sub/dir/foo/bar") + self.assertFalse(os.path.exists(os.path.join(self.topdir, "sub"))) + lf._Link() + self.assertExists(os.path.join(self.topdir, "sub", "dir", "foo", "bar")) + + def test_src_block_relative(self): + """Do not allow relative symlinks.""" + BAD_SOURCES = ( + "./", + "..", + "../", + "foo/.", + "foo/./bar", + "foo/..", + "foo/../foo", + ) + for src in BAD_SOURCES: + lf = self.LinkFile(src, "foo") + self.assertRaises(error.ManifestInvalidPathError, lf._Link) + + def test_update(self): + """Make sure changed targets get updated.""" + dest = os.path.join(self.topdir, "sym") + + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + lf = self.LinkFile("foo.txt", "sym") + lf._Link() + self.assertEqual( + os.path.join("git-project", "foo.txt"), os.readlink(dest) + ) + + # Point the symlink somewhere else. + os.unlink(dest) + platform_utils.symlink(self.tempdir, dest) + lf._Link() + self.assertEqual( + os.path.join("git-project", "foo.txt"), os.readlink(dest) + ) class MigrateWorkTreeTests(unittest.TestCase): - """Check _MigrateOldWorkTreeGitDir handling.""" - - _SYMLINKS = { - 'config', 'description', 'hooks', 'info', 'logs', 'objects', - 'packed-refs', 'refs', 'rr-cache', 'shallow', 'svn', - } - _FILES = { - 'COMMIT_EDITMSG', 'FETCH_HEAD', 'HEAD', 'index', 'ORIG_HEAD', - 'unknown-file-should-be-migrated', - } - _CLEAN_FILES = { - 'a-vim-temp-file~', '#an-emacs-temp-file#', - } - - @classmethod - @contextlib.contextmanager - def _simple_layout(cls): - """Create a simple repo client checkout to test against.""" - with tempfile.TemporaryDirectory() as tempdir: - tempdir = Path(tempdir) - - gitdir = tempdir / '.repo/projects/src/test.git' - gitdir.mkdir(parents=True) - cmd = ['git', 'init', '--bare', str(gitdir)] - subprocess.check_call(cmd) - - dotgit = tempdir / 'src/test/.git' - dotgit.mkdir(parents=True) - for name in cls._SYMLINKS: - (dotgit / name).symlink_to(f'../../../.repo/projects/src/test.git/{name}') - for name in cls._FILES | cls._CLEAN_FILES: - (dotgit / name).write_text(name) - - yield tempdir - - def test_standard(self): - """Migrate a standard checkout that we expect.""" - with self._simple_layout() as tempdir: - dotgit = tempdir / 'src/test/.git' - project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) - - # Make sure the dir was transformed into a symlink. - self.assertTrue(dotgit.is_symlink()) - self.assertEqual(os.readlink(dotgit), os.path.normpath('../../.repo/projects/src/test.git')) - - # Make sure files were moved over. - gitdir = tempdir / '.repo/projects/src/test.git' - for name in self._FILES: - self.assertEqual(name, (gitdir / name).read_text()) - # Make sure files were removed. - for name in self._CLEAN_FILES: - self.assertFalse((gitdir / name).exists()) - - def test_unknown(self): - """A checkout with unknown files should abort.""" - with self._simple_layout() as tempdir: - dotgit = tempdir / 'src/test/.git' - (tempdir / '.repo/projects/src/test.git/random-file').write_text('one') - (dotgit / 'random-file').write_text('two') - with self.assertRaises(error.GitError): - project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) - - # Make sure no content was actually changed. - self.assertTrue(dotgit.is_dir()) - for name in self._FILES: - self.assertTrue((dotgit / name).is_file()) - for name in self._CLEAN_FILES: - self.assertTrue((dotgit / name).is_file()) - for name in self._SYMLINKS: - self.assertTrue((dotgit / name).is_symlink()) + """Check _MigrateOldWorkTreeGitDir handling.""" + + _SYMLINKS = { + "config", + "description", + "hooks", + "info", + "logs", + "objects", + "packed-refs", + "refs", + "rr-cache", + "shallow", + "svn", + } + _FILES = { + "COMMIT_EDITMSG", + "FETCH_HEAD", + "HEAD", + "index", + "ORIG_HEAD", + "unknown-file-should-be-migrated", + } + _CLEAN_FILES = { + "a-vim-temp-file~", + "#an-emacs-temp-file#", + } + + @classmethod + @contextlib.contextmanager + def _simple_layout(cls): + """Create a simple repo client checkout to test against.""" + with tempfile.TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) + + gitdir = tempdir / ".repo/projects/src/test.git" + gitdir.mkdir(parents=True) + cmd = ["git", "init", "--bare", str(gitdir)] + subprocess.check_call(cmd) + + dotgit = tempdir / "src/test/.git" + dotgit.mkdir(parents=True) + for name in cls._SYMLINKS: + (dotgit / name).symlink_to( + f"../../../.repo/projects/src/test.git/{name}" + ) + for name in cls._FILES | cls._CLEAN_FILES: + (dotgit / name).write_text(name) + + yield tempdir + + def test_standard(self): + """Migrate a standard checkout that we expect.""" + with self._simple_layout() as tempdir: + dotgit = tempdir / "src/test/.git" + project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) + + # Make sure the dir was transformed into a symlink. + self.assertTrue(dotgit.is_symlink()) + self.assertEqual( + os.readlink(dotgit), + os.path.normpath("../../.repo/projects/src/test.git"), + ) + + # Make sure files were moved over. + gitdir = tempdir / ".repo/projects/src/test.git" + for name in self._FILES: + self.assertEqual(name, (gitdir / name).read_text()) + # Make sure files were removed. + for name in self._CLEAN_FILES: + self.assertFalse((gitdir / name).exists()) + + def test_unknown(self): + """A checkout with unknown files should abort.""" + with self._simple_layout() as tempdir: + dotgit = tempdir / "src/test/.git" + (tempdir / ".repo/projects/src/test.git/random-file").write_text( + "one" + ) + (dotgit / "random-file").write_text("two") + with self.assertRaises(error.GitError): + project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) + + # Make sure no content was actually changed. + self.assertTrue(dotgit.is_dir()) + for name in self._FILES: + self.assertTrue((dotgit / name).is_file()) + for name in self._CLEAN_FILES: + self.assertTrue((dotgit / name).is_file()) + for name in self._SYMLINKS: + self.assertTrue((dotgit / name).is_symlink()) class ManifestPropertiesFetchedCorrectly(unittest.TestCase): - """Ensure properties are fetched properly.""" + """Ensure properties are fetched properly.""" - def setUpManifest(self, tempdir): - repodir = os.path.join(tempdir, '.repo') - manifest_dir = os.path.join(repodir, 'manifests') - manifest_file = os.path.join( - repodir, manifest_xml.MANIFEST_FILE_NAME) - local_manifest_dir = os.path.join( - repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME) - os.mkdir(repodir) - os.mkdir(manifest_dir) - manifest = manifest_xml.XmlManifest(repodir, manifest_file) + def setUpManifest(self, tempdir): + repodir = os.path.join(tempdir, ".repo") + manifest_dir = os.path.join(repodir, "manifests") + manifest_file = os.path.join(repodir, manifest_xml.MANIFEST_FILE_NAME) + os.mkdir(repodir) + os.mkdir(manifest_dir) + manifest = manifest_xml.XmlManifest(repodir, manifest_file) - return project.ManifestProject( - manifest, 'test/manifest', os.path.join(tempdir, '.git'), tempdir) + return project.ManifestProject( + manifest, "test/manifest", os.path.join(tempdir, ".git"), tempdir + ) - def test_manifest_config_properties(self): - """Test we are fetching the manifest config properties correctly.""" + def test_manifest_config_properties(self): + """Test we are fetching the manifest config properties correctly.""" - with TempGitTree() as tempdir: - fakeproj = self.setUpManifest(tempdir) + with TempGitTree() as tempdir: + fakeproj = self.setUpManifest(tempdir) - # Set property using the expected Set method, then ensure - # the porperty functions are using the correct Get methods. - fakeproj.config.SetString( - 'manifest.standalone', 'https://chicken/manifest.git') - self.assertEqual( - fakeproj.standalone_manifest_url, 'https://chicken/manifest.git') + # Set property using the expected Set method, then ensure + # the porperty functions are using the correct Get methods. + fakeproj.config.SetString( + "manifest.standalone", "https://chicken/manifest.git" + ) + self.assertEqual( + fakeproj.standalone_manifest_url, "https://chicken/manifest.git" + ) - fakeproj.config.SetString('manifest.groups', 'test-group, admin-group') - self.assertEqual(fakeproj.manifest_groups, 'test-group, admin-group') + fakeproj.config.SetString( + "manifest.groups", "test-group, admin-group" + ) + self.assertEqual( + fakeproj.manifest_groups, "test-group, admin-group" + ) - fakeproj.config.SetString('repo.reference', 'mirror/ref') - self.assertEqual(fakeproj.reference, 'mirror/ref') + fakeproj.config.SetString("repo.reference", "mirror/ref") + self.assertEqual(fakeproj.reference, "mirror/ref") - fakeproj.config.SetBoolean('repo.dissociate', False) - self.assertFalse(fakeproj.dissociate) + fakeproj.config.SetBoolean("repo.dissociate", False) + self.assertFalse(fakeproj.dissociate) - fakeproj.config.SetBoolean('repo.archive', False) - self.assertFalse(fakeproj.archive) + fakeproj.config.SetBoolean("repo.archive", False) + self.assertFalse(fakeproj.archive) - fakeproj.config.SetBoolean('repo.mirror', False) - self.assertFalse(fakeproj.mirror) + fakeproj.config.SetBoolean("repo.mirror", False) + self.assertFalse(fakeproj.mirror) - fakeproj.config.SetBoolean('repo.worktree', False) - self.assertFalse(fakeproj.use_worktree) + fakeproj.config.SetBoolean("repo.worktree", False) + self.assertFalse(fakeproj.use_worktree) - fakeproj.config.SetBoolean('repo.clonebundle', False) - self.assertFalse(fakeproj.clone_bundle) + fakeproj.config.SetBoolean("repo.clonebundle", False) + self.assertFalse(fakeproj.clone_bundle) - fakeproj.config.SetBoolean('repo.submodules', False) - self.assertFalse(fakeproj.submodules) + fakeproj.config.SetBoolean("repo.submodules", False) + self.assertFalse(fakeproj.submodules) - fakeproj.config.SetBoolean('repo.git-lfs', False) - self.assertFalse(fakeproj.git_lfs) + fakeproj.config.SetBoolean("repo.git-lfs", False) + self.assertFalse(fakeproj.git_lfs) - fakeproj.config.SetBoolean('repo.superproject', False) - self.assertFalse(fakeproj.use_superproject) + fakeproj.config.SetBoolean("repo.superproject", False) + self.assertFalse(fakeproj.use_superproject) - fakeproj.config.SetBoolean('repo.partialclone', False) - self.assertFalse(fakeproj.partial_clone) + fakeproj.config.SetBoolean("repo.partialclone", False) + self.assertFalse(fakeproj.partial_clone) - fakeproj.config.SetString('repo.depth', '48') - self.assertEqual(fakeproj.depth, '48') + fakeproj.config.SetString("repo.depth", "48") + self.assertEqual(fakeproj.depth, "48") - fakeproj.config.SetString('repo.clonefilter', 'blob:limit=10M') - self.assertEqual(fakeproj.clone_filter, 'blob:limit=10M') + fakeproj.config.SetString("repo.clonefilter", "blob:limit=10M") + self.assertEqual(fakeproj.clone_filter, "blob:limit=10M") - fakeproj.config.SetString('repo.partialcloneexclude', 'third_party/big_repo') - self.assertEqual(fakeproj.partial_clone_exclude, 'third_party/big_repo') + fakeproj.config.SetString( + "repo.partialcloneexclude", "third_party/big_repo" + ) + self.assertEqual( + fakeproj.partial_clone_exclude, "third_party/big_repo" + ) - fakeproj.config.SetString('manifest.platform', 'auto') - self.assertEqual(fakeproj.manifest_platform, 'auto') + fakeproj.config.SetString("manifest.platform", "auto") + self.assertEqual(fakeproj.manifest_platform, "auto") diff --git a/tests/test_repo_trace.py b/tests/test_repo_trace.py index 5faf2938..e4aeb5de 100644 --- a/tests/test_repo_trace.py +++ b/tests/test_repo_trace.py @@ -22,35 +22,39 @@ import repo_trace class TraceTests(unittest.TestCase): - """Check Trace behavior.""" - - def testTrace_MaxSizeEnforced(self): - content = 'git chicken' - - with repo_trace.Trace(content, first_trace=True): - pass - first_trace_size = os.path.getsize(repo_trace._TRACE_FILE) - - with repo_trace.Trace(content): - pass - self.assertGreater( - os.path.getsize(repo_trace._TRACE_FILE), first_trace_size) - - # Check we clear everything is the last chunk is larger than _MAX_SIZE. - with mock.patch('repo_trace._MAX_SIZE', 0): - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual(first_trace_size, - os.path.getsize(repo_trace._TRACE_FILE)) - - # Check we only clear the chunks we need to. - repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024) - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual(first_trace_size * 2, - os.path.getsize(repo_trace._TRACE_FILE)) - - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual(first_trace_size * 2, - os.path.getsize(repo_trace._TRACE_FILE)) + """Check Trace behavior.""" + + def testTrace_MaxSizeEnforced(self): + content = "git chicken" + + with repo_trace.Trace(content, first_trace=True): + pass + first_trace_size = os.path.getsize(repo_trace._TRACE_FILE) + + with repo_trace.Trace(content): + pass + self.assertGreater( + os.path.getsize(repo_trace._TRACE_FILE), first_trace_size + ) + + # Check we clear everything is the last chunk is larger than _MAX_SIZE. + with mock.patch("repo_trace._MAX_SIZE", 0): + with repo_trace.Trace(content, first_trace=True): + pass + self.assertEqual( + first_trace_size, os.path.getsize(repo_trace._TRACE_FILE) + ) + + # Check we only clear the chunks we need to. + repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024) + with repo_trace.Trace(content, first_trace=True): + pass + self.assertEqual( + first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE) + ) + + with repo_trace.Trace(content, first_trace=True): + pass + self.assertEqual( + first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE) + ) diff --git a/tests/test_ssh.py b/tests/test_ssh.py index ffb5cb94..a9c1be7f 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -23,52 +23,56 @@ import ssh class SshTests(unittest.TestCase): - """Tests the ssh functions.""" + """Tests the ssh functions.""" - def test_parse_ssh_version(self): - """Check _parse_ssh_version() handling.""" - ver = ssh._parse_ssh_version('Unknown\n') - self.assertEqual(ver, ()) - ver = ssh._parse_ssh_version('OpenSSH_1.0\n') - self.assertEqual(ver, (1, 0)) - ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') - self.assertEqual(ver, (6, 6, 1)) - ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') - self.assertEqual(ver, (7, 6)) + def test_parse_ssh_version(self): + """Check _parse_ssh_version() handling.""" + ver = ssh._parse_ssh_version("Unknown\n") + self.assertEqual(ver, ()) + ver = ssh._parse_ssh_version("OpenSSH_1.0\n") + self.assertEqual(ver, (1, 0)) + ver = ssh._parse_ssh_version( + "OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n" + ) + self.assertEqual(ver, (6, 6, 1)) + ver = ssh._parse_ssh_version( + "OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n" + ) + self.assertEqual(ver, (7, 6)) - def test_version(self): - """Check version() handling.""" - with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): - self.assertEqual(ssh.version(), (1, 2)) + def test_version(self): + """Check version() handling.""" + with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"): + self.assertEqual(ssh.version(), (1, 2)) - def test_context_manager_empty(self): - """Verify context manager with no clients works correctly.""" - with multiprocessing.Manager() as manager: - with ssh.ProxyManager(manager): - pass + def test_context_manager_empty(self): + """Verify context manager with no clients works correctly.""" + with multiprocessing.Manager() as manager: + with ssh.ProxyManager(manager): + pass - def test_context_manager_child_cleanup(self): - """Verify orphaned clients & masters get cleaned up.""" - with multiprocessing.Manager() as manager: - with ssh.ProxyManager(manager) as ssh_proxy: - client = subprocess.Popen(['sleep', '964853320']) - ssh_proxy.add_client(client) - master = subprocess.Popen(['sleep', '964853321']) - ssh_proxy.add_master(master) - # If the process still exists, these will throw timeout errors. - client.wait(0) - master.wait(0) + def test_context_manager_child_cleanup(self): + """Verify orphaned clients & masters get cleaned up.""" + with multiprocessing.Manager() as manager: + with ssh.ProxyManager(manager) as ssh_proxy: + client = subprocess.Popen(["sleep", "964853320"]) + ssh_proxy.add_client(client) + master = subprocess.Popen(["sleep", "964853321"]) + ssh_proxy.add_master(master) + # If the process still exists, these will throw timeout errors. + client.wait(0) + master.wait(0) - def test_ssh_sock(self): - """Check sock() function.""" - manager = multiprocessing.Manager() - proxy = ssh.ProxyManager(manager) - with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): - # old ssh version uses port - with mock.patch('ssh.version', return_value=(6, 6)): - self.assertTrue(proxy.sock().endswith('%p')) + def test_ssh_sock(self): + """Check sock() function.""" + manager = multiprocessing.Manager() + proxy = ssh.ProxyManager(manager) + with mock.patch("tempfile.mkdtemp", return_value="/tmp/foo"): + # Old ssh version uses port. + with mock.patch("ssh.version", return_value=(6, 6)): + self.assertTrue(proxy.sock().endswith("%p")) - proxy._sock_path = None - # new ssh version uses hash - with mock.patch('ssh.version', return_value=(6, 7)): - self.assertTrue(proxy.sock().endswith('%C')) + proxy._sock_path = None + # New ssh version uses hash. + with mock.patch("ssh.version", return_value=(6, 7)): + self.assertTrue(proxy.sock().endswith("%C")) diff --git a/tests/test_subcmds.py b/tests/test_subcmds.py index bc53051a..73b66e3f 100644 --- a/tests/test_subcmds.py +++ b/tests/test_subcmds.py @@ -21,53 +21,57 @@ import subcmds class AllCommands(unittest.TestCase): - """Check registered all_commands.""" + """Check registered all_commands.""" - def test_required_basic(self): - """Basic checking of registered commands.""" - # NB: We don't test all subcommands as we want to avoid "change detection" - # tests, so we just look for the most common/important ones here that are - # unlikely to ever change. - for cmd in {'cherry-pick', 'help', 'init', 'start', 'sync', 'upload'}: - self.assertIn(cmd, subcmds.all_commands) + def test_required_basic(self): + """Basic checking of registered commands.""" + # NB: We don't test all subcommands as we want to avoid "change + # detection" tests, so we just look for the most common/important ones + # here that are unlikely to ever change. + for cmd in {"cherry-pick", "help", "init", "start", "sync", "upload"}: + self.assertIn(cmd, subcmds.all_commands) - def test_naming(self): - """Verify we don't add things that we shouldn't.""" - for cmd in subcmds.all_commands: - # Reject filename suffixes like "help.py". - self.assertNotIn('.', cmd) + def test_naming(self): + """Verify we don't add things that we shouldn't.""" + for cmd in subcmds.all_commands: + # Reject filename suffixes like "help.py". + self.assertNotIn(".", cmd) - # Make sure all '_' were converted to '-'. - self.assertNotIn('_', cmd) + # Make sure all '_' were converted to '-'. + self.assertNotIn("_", cmd) - # Reject internal python paths like "__init__". - self.assertFalse(cmd.startswith('__')) + # Reject internal python paths like "__init__". + self.assertFalse(cmd.startswith("__")) - def test_help_desc_style(self): - """Force some consistency in option descriptions. + def test_help_desc_style(self): + """Force some consistency in option descriptions. - Python's optparse & argparse has a few default options like --help. Their - option description text uses lowercase sentence fragments, so enforce our - options follow the same style so UI is consistent. + Python's optparse & argparse has a few default options like --help. + Their option description text uses lowercase sentence fragments, so + enforce our options follow the same style so UI is consistent. - We enforce: - * Text starts with lowercase. - * Text doesn't end with period. - """ - for name, cls in subcmds.all_commands.items(): - cmd = cls() - parser = cmd.OptionParser - for option in parser.option_list: - if option.help == optparse.SUPPRESS_HELP: - continue + We enforce: + * Text starts with lowercase. + * Text doesn't end with period. + """ + for name, cls in subcmds.all_commands.items(): + cmd = cls() + parser = cmd.OptionParser + for option in parser.option_list: + if option.help == optparse.SUPPRESS_HELP: + continue - c = option.help[0] - self.assertEqual( - c.lower(), c, - msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text ' - f'should start with lowercase: "{option.help}"') + c = option.help[0] + self.assertEqual( + c.lower(), + c, + msg=f"subcmds/{name}.py: {option.get_opt_string()}: " + f'help text should start with lowercase: "{option.help}"', + ) - self.assertNotEqual( - option.help[-1], '.', - msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text ' - f'should not end in a period: "{option.help}"') + self.assertNotEqual( + option.help[-1], + ".", + msg=f"subcmds/{name}.py: {option.get_opt_string()}: " + f'help text should not end in a period: "{option.help}"', + ) diff --git a/tests/test_subcmds_init.py b/tests/test_subcmds_init.py index af4346de..25e5be56 100644 --- a/tests/test_subcmds_init.py +++ b/tests/test_subcmds_init.py @@ -20,30 +20,27 @@ from subcmds import init class InitCommand(unittest.TestCase): - """Check registered all_commands.""" - - def setUp(self): - self.cmd = init.Init() - - def test_cli_parser_good(self): - """Check valid command line options.""" - ARGV = ( - [], - ) - for argv in ARGV: - opts, args = self.cmd.OptionParser.parse_args(argv) - self.cmd.ValidateOptions(opts, args) - - def test_cli_parser_bad(self): - """Check invalid command line options.""" - ARGV = ( - # Too many arguments. - ['url', 'asdf'], - - # Conflicting options. - ['--mirror', '--archive'], - ) - for argv in ARGV: - opts, args = self.cmd.OptionParser.parse_args(argv) - with self.assertRaises(SystemExit): - self.cmd.ValidateOptions(opts, args) + """Check registered all_commands.""" + + def setUp(self): + self.cmd = init.Init() + + def test_cli_parser_good(self): + """Check valid command line options.""" + ARGV = ([],) + for argv in ARGV: + opts, args = self.cmd.OptionParser.parse_args(argv) + self.cmd.ValidateOptions(opts, args) + + def test_cli_parser_bad(self): + """Check invalid command line options.""" + ARGV = ( + # Too many arguments. + ["url", "asdf"], + # Conflicting options. + ["--mirror", "--archive"], + ) + for argv in ARGV: + opts, args = self.cmd.OptionParser.parse_args(argv) + with self.assertRaises(SystemExit): + self.cmd.ValidateOptions(opts, args) diff --git a/tests/test_subcmds_sync.py b/tests/test_subcmds_sync.py index 236d54e5..5c8e606e 100644 --- a/tests/test_subcmds_sync.py +++ b/tests/test_subcmds_sync.py @@ -23,111 +23,138 @@ import command from subcmds import sync -@pytest.mark.parametrize('use_superproject, cli_args, result', [ - (True, ['--current-branch'], True), - (True, ['--no-current-branch'], True), - (True, [], True), - (False, ['--current-branch'], True), - (False, ['--no-current-branch'], False), - (False, [], None), -]) +@pytest.mark.parametrize( + "use_superproject, cli_args, result", + [ + (True, ["--current-branch"], True), + (True, ["--no-current-branch"], True), + (True, [], True), + (False, ["--current-branch"], True), + (False, ["--no-current-branch"], False), + (False, [], None), + ], +) def test_get_current_branch_only(use_superproject, cli_args, result): - """Test Sync._GetCurrentBranchOnly logic. + """Test Sync._GetCurrentBranchOnly logic. - Sync._GetCurrentBranchOnly should return True if a superproject is requested, - and otherwise the value of the current_branch_only option. - """ - cmd = sync.Sync() - opts, _ = cmd.OptionParser.parse_args(cli_args) + Sync._GetCurrentBranchOnly should return True if a superproject is + requested, and otherwise the value of the current_branch_only option. + """ + cmd = sync.Sync() + opts, _ = cmd.OptionParser.parse_args(cli_args) - with mock.patch('git_superproject.UseSuperproject', - return_value=use_superproject): - assert cmd._GetCurrentBranchOnly(opts, cmd.manifest) == result + with mock.patch( + "git_superproject.UseSuperproject", return_value=use_superproject + ): + assert cmd._GetCurrentBranchOnly(opts, cmd.manifest) == result # Used to patch os.cpu_count() for reliable results. OS_CPU_COUNT = 24 -@pytest.mark.parametrize('argv, jobs_manifest, jobs, jobs_net, jobs_check', [ - # No user or manifest settings. - ([], None, OS_CPU_COUNT, 1, command.DEFAULT_LOCAL_JOBS), - # No user settings, so manifest settings control. - ([], 3, 3, 3, 3), - # User settings, but no manifest. - (['--jobs=4'], None, 4, 4, 4), - (['--jobs=4', '--jobs-network=5'], None, 4, 5, 4), - (['--jobs=4', '--jobs-checkout=6'], None, 4, 4, 6), - (['--jobs=4', '--jobs-network=5', '--jobs-checkout=6'], None, 4, 5, 6), - (['--jobs-network=5'], None, OS_CPU_COUNT, 5, command.DEFAULT_LOCAL_JOBS), - (['--jobs-checkout=6'], None, OS_CPU_COUNT, 1, 6), - (['--jobs-network=5', '--jobs-checkout=6'], None, OS_CPU_COUNT, 5, 6), - # User settings with manifest settings. - (['--jobs=4'], 3, 4, 4, 4), - (['--jobs=4', '--jobs-network=5'], 3, 4, 5, 4), - (['--jobs=4', '--jobs-checkout=6'], 3, 4, 4, 6), - (['--jobs=4', '--jobs-network=5', '--jobs-checkout=6'], 3, 4, 5, 6), - (['--jobs-network=5'], 3, 3, 5, 3), - (['--jobs-checkout=6'], 3, 3, 3, 6), - (['--jobs-network=5', '--jobs-checkout=6'], 3, 3, 5, 6), - # Settings that exceed rlimits get capped. - (['--jobs=1000000'], None, 83, 83, 83), - ([], 1000000, 83, 83, 83), -]) + +@pytest.mark.parametrize( + "argv, jobs_manifest, jobs, jobs_net, jobs_check", + [ + # No user or manifest settings. + ([], None, OS_CPU_COUNT, 1, command.DEFAULT_LOCAL_JOBS), + # No user settings, so manifest settings control. + ([], 3, 3, 3, 3), + # User settings, but no manifest. + (["--jobs=4"], None, 4, 4, 4), + (["--jobs=4", "--jobs-network=5"], None, 4, 5, 4), + (["--jobs=4", "--jobs-checkout=6"], None, 4, 4, 6), + (["--jobs=4", "--jobs-network=5", "--jobs-checkout=6"], None, 4, 5, 6), + ( + ["--jobs-network=5"], + None, + OS_CPU_COUNT, + 5, + command.DEFAULT_LOCAL_JOBS, + ), + (["--jobs-checkout=6"], None, OS_CPU_COUNT, 1, 6), + (["--jobs-network=5", "--jobs-checkout=6"], None, OS_CPU_COUNT, 5, 6), + # User settings with manifest settings. + (["--jobs=4"], 3, 4, 4, 4), + (["--jobs=4", "--jobs-network=5"], 3, 4, 5, 4), + (["--jobs=4", "--jobs-checkout=6"], 3, 4, 4, 6), + (["--jobs=4", "--jobs-network=5", "--jobs-checkout=6"], 3, 4, 5, 6), + (["--jobs-network=5"], 3, 3, 5, 3), + (["--jobs-checkout=6"], 3, 3, 3, 6), + (["--jobs-network=5", "--jobs-checkout=6"], 3, 3, 5, 6), + # Settings that exceed rlimits get capped. + (["--jobs=1000000"], None, 83, 83, 83), + ([], 1000000, 83, 83, 83), + ], +) def test_cli_jobs(argv, jobs_manifest, jobs, jobs_net, jobs_check): - """Tests --jobs option behavior.""" - mp = mock.MagicMock() - mp.manifest.default.sync_j = jobs_manifest + """Tests --jobs option behavior.""" + mp = mock.MagicMock() + mp.manifest.default.sync_j = jobs_manifest - cmd = sync.Sync() - opts, args = cmd.OptionParser.parse_args(argv) - cmd.ValidateOptions(opts, args) + cmd = sync.Sync() + opts, args = cmd.OptionParser.parse_args(argv) + cmd.ValidateOptions(opts, args) - with mock.patch.object(sync, '_rlimit_nofile', return_value=(256, 256)): - with mock.patch.object(os, 'cpu_count', return_value=OS_CPU_COUNT): - cmd._ValidateOptionsWithManifest(opts, mp) - assert opts.jobs == jobs - assert opts.jobs_network == jobs_net - assert opts.jobs_checkout == jobs_check + with mock.patch.object(sync, "_rlimit_nofile", return_value=(256, 256)): + with mock.patch.object(os, "cpu_count", return_value=OS_CPU_COUNT): + cmd._ValidateOptionsWithManifest(opts, mp) + assert opts.jobs == jobs + assert opts.jobs_network == jobs_net + assert opts.jobs_checkout == jobs_check class GetPreciousObjectsState(unittest.TestCase): - """Tests for _GetPreciousObjectsState.""" - - def setUp(self): - """Common setup.""" - self.cmd = sync.Sync() - self.project = p = mock.MagicMock(use_git_worktrees=False, - UseAlternates=False) - p.manifest.GetProjectsWithName.return_value = [p] - - self.opt = mock.Mock(spec_set=['this_manifest_only']) - self.opt.this_manifest_only = False - - def test_worktrees(self): - """False for worktrees.""" - self.project.use_git_worktrees = True - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) - - def test_not_shared(self): - """Singleton project.""" - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) - - def test_shared(self): - """Shared project.""" - self.project.manifest.GetProjectsWithName.return_value = [ - self.project, self.project - ] - self.assertTrue(self.cmd._GetPreciousObjectsState(self.project, self.opt)) - - def test_shared_with_alternates(self): - """Shared project, with alternates.""" - self.project.manifest.GetProjectsWithName.return_value = [ - self.project, self.project - ] - self.project.UseAlternates = True - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) - - def test_not_found(self): - """Project not found in manifest.""" - self.project.manifest.GetProjectsWithName.return_value = [] - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) + """Tests for _GetPreciousObjectsState.""" + + def setUp(self): + """Common setup.""" + self.cmd = sync.Sync() + self.project = p = mock.MagicMock( + use_git_worktrees=False, UseAlternates=False + ) + p.manifest.GetProjectsWithName.return_value = [p] + + self.opt = mock.Mock(spec_set=["this_manifest_only"]) + self.opt.this_manifest_only = False + + def test_worktrees(self): + """False for worktrees.""" + self.project.use_git_worktrees = True + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) + + def test_not_shared(self): + """Singleton project.""" + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) + + def test_shared(self): + """Shared project.""" + self.project.manifest.GetProjectsWithName.return_value = [ + self.project, + self.project, + ] + self.assertTrue( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) + + def test_shared_with_alternates(self): + """Shared project, with alternates.""" + self.project.manifest.GetProjectsWithName.return_value = [ + self.project, + self.project, + ] + self.project.UseAlternates = True + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) + + def test_not_found(self): + """Project not found in manifest.""" + self.project.manifest.GetProjectsWithName.return_value = [] + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) diff --git a/tests/test_update_manpages.py b/tests/test_update_manpages.py index 0de85be9..12b19ec4 100644 --- a/tests/test_update_manpages.py +++ b/tests/test_update_manpages.py @@ -20,9 +20,9 @@ from release import update_manpages class UpdateManpagesTest(unittest.TestCase): - """Tests the update-manpages code.""" + """Tests the update-manpages code.""" - def test_replace_regex(self): - """Check that replace_regex works.""" - data = '\n\033[1mSummary\033[m\n' - self.assertEqual(update_manpages.replace_regex(data),'\nSummary\n') + def test_replace_regex(self): + """Check that replace_regex works.""" + data = "\n\033[1mSummary\033[m\n" + self.assertEqual(update_manpages.replace_regex(data), "\nSummary\n") diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index ef879a5d..21fa094d 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -28,528 +28,615 @@ import wrapper def fixture(*paths): - """Return a path relative to tests/fixtures. - """ - return os.path.join(os.path.dirname(__file__), 'fixtures', *paths) + """Return a path relative to tests/fixtures.""" + return os.path.join(os.path.dirname(__file__), "fixtures", *paths) class RepoWrapperTestCase(unittest.TestCase): - """TestCase for the wrapper module.""" + """TestCase for the wrapper module.""" - def setUp(self): - """Load the wrapper module every time.""" - wrapper.Wrapper.cache_clear() - self.wrapper = wrapper.Wrapper() + def setUp(self): + """Load the wrapper module every time.""" + wrapper.Wrapper.cache_clear() + self.wrapper = wrapper.Wrapper() class RepoWrapperUnitTest(RepoWrapperTestCase): - """Tests helper functions in the repo wrapper - """ - - def test_version(self): - """Make sure _Version works.""" - with self.assertRaises(SystemExit) as e: - with mock.patch('sys.stdout', new_callable=StringIO) as stdout: - with mock.patch('sys.stderr', new_callable=StringIO) as stderr: - self.wrapper._Version() - self.assertEqual(0, e.exception.code) - self.assertEqual('', stderr.getvalue()) - self.assertIn('repo launcher version', stdout.getvalue()) - - def test_python_constraints(self): - """The launcher should never require newer than main.py.""" - self.assertGreaterEqual(main.MIN_PYTHON_VERSION_HARD, - self.wrapper.MIN_PYTHON_VERSION_HARD) - self.assertGreaterEqual(main.MIN_PYTHON_VERSION_SOFT, - self.wrapper.MIN_PYTHON_VERSION_SOFT) - # Make sure the versions are themselves in sync. - self.assertGreaterEqual(self.wrapper.MIN_PYTHON_VERSION_SOFT, - self.wrapper.MIN_PYTHON_VERSION_HARD) - - def test_init_parser(self): - """Make sure 'init' GetParser works.""" - parser = self.wrapper.GetParser(gitc_init=False) - opts, args = parser.parse_args([]) - self.assertEqual([], args) - self.assertIsNone(opts.manifest_url) - - def test_gitc_init_parser(self): - """Make sure 'gitc-init' GetParser works.""" - parser = self.wrapper.GetParser(gitc_init=True) - opts, args = parser.parse_args([]) - self.assertEqual([], args) - self.assertIsNone(opts.manifest_file) - - def test_get_gitc_manifest_dir_no_gitc(self): - """ - Test reading a missing gitc config file - """ - self.wrapper.GITC_CONFIG_FILE = fixture('missing_gitc_config') - val = self.wrapper.get_gitc_manifest_dir() - self.assertEqual(val, '') - - def test_get_gitc_manifest_dir(self): - """ - Test reading the gitc config file and parsing the directory - """ - self.wrapper.GITC_CONFIG_FILE = fixture('gitc_config') - val = self.wrapper.get_gitc_manifest_dir() - self.assertEqual(val, '/test/usr/local/google/gitc') - - def test_gitc_parse_clientdir_no_gitc(self): - """ - Test parsing the gitc clientdir without gitc running - """ - self.wrapper.GITC_CONFIG_FILE = fixture('missing_gitc_config') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/something'), None) - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test'), 'test') - - def test_gitc_parse_clientdir(self): - """ - Test parsing the gitc clientdir - """ - self.wrapper.GITC_CONFIG_FILE = fixture('gitc_config') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/something'), None) - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/extra'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/extra'), - 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/'), None) - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/'), None) + """Tests helper functions in the repo wrapper""" + + def test_version(self): + """Make sure _Version works.""" + with self.assertRaises(SystemExit) as e: + with mock.patch("sys.stdout", new_callable=StringIO) as stdout: + with mock.patch("sys.stderr", new_callable=StringIO) as stderr: + self.wrapper._Version() + self.assertEqual(0, e.exception.code) + self.assertEqual("", stderr.getvalue()) + self.assertIn("repo launcher version", stdout.getvalue()) + + def test_python_constraints(self): + """The launcher should never require newer than main.py.""" + self.assertGreaterEqual( + main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD + ) + self.assertGreaterEqual( + main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT + ) + # Make sure the versions are themselves in sync. + self.assertGreaterEqual( + self.wrapper.MIN_PYTHON_VERSION_SOFT, + self.wrapper.MIN_PYTHON_VERSION_HARD, + ) + + def test_init_parser(self): + """Make sure 'init' GetParser works.""" + parser = self.wrapper.GetParser(gitc_init=False) + opts, args = parser.parse_args([]) + self.assertEqual([], args) + self.assertIsNone(opts.manifest_url) + + def test_gitc_init_parser(self): + """Make sure 'gitc-init' GetParser works.""" + parser = self.wrapper.GetParser(gitc_init=True) + opts, args = parser.parse_args([]) + self.assertEqual([], args) + self.assertIsNone(opts.manifest_file) + + def test_get_gitc_manifest_dir_no_gitc(self): + """ + Test reading a missing gitc config file + """ + self.wrapper.GITC_CONFIG_FILE = fixture("missing_gitc_config") + val = self.wrapper.get_gitc_manifest_dir() + self.assertEqual(val, "") + + def test_get_gitc_manifest_dir(self): + """ + Test reading the gitc config file and parsing the directory + """ + self.wrapper.GITC_CONFIG_FILE = fixture("gitc_config") + val = self.wrapper.get_gitc_manifest_dir() + self.assertEqual(val, "/test/usr/local/google/gitc") + + def test_gitc_parse_clientdir_no_gitc(self): + """ + Test parsing the gitc clientdir without gitc running + """ + self.wrapper.GITC_CONFIG_FILE = fixture("missing_gitc_config") + self.assertEqual(self.wrapper.gitc_parse_clientdir("/something"), None) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test"), "test" + ) + + def test_gitc_parse_clientdir(self): + """ + Test parsing the gitc clientdir + """ + self.wrapper.GITC_CONFIG_FILE = fixture("gitc_config") + self.assertEqual(self.wrapper.gitc_parse_clientdir("/something"), None) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test"), "test" + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test/"), "test" + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test/extra"), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir( + "/test/usr/local/google/gitc/test" + ), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir( + "/test/usr/local/google/gitc/test/" + ), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir( + "/test/usr/local/google/gitc/test/extra" + ), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/"), None + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/test/usr/local/google/gitc/"), + None, + ) class SetGitTrace2ParentSid(RepoWrapperTestCase): - """Check SetGitTrace2ParentSid behavior.""" - - KEY = 'GIT_TRACE2_PARENT_SID' - VALID_FORMAT = re.compile(r'^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$') - - def test_first_set(self): - """Test env var not yet set.""" - env = {} - self.wrapper.SetGitTrace2ParentSid(env) - self.assertIn(self.KEY, env) - value = env[self.KEY] - self.assertRegex(value, self.VALID_FORMAT) - - def test_append(self): - """Test env var is appended.""" - env = {self.KEY: 'pfx'} - self.wrapper.SetGitTrace2ParentSid(env) - self.assertIn(self.KEY, env) - value = env[self.KEY] - self.assertTrue(value.startswith('pfx/')) - self.assertRegex(value[4:], self.VALID_FORMAT) - - def test_global_context(self): - """Check os.environ gets updated by default.""" - os.environ.pop(self.KEY, None) - self.wrapper.SetGitTrace2ParentSid() - self.assertIn(self.KEY, os.environ) - value = os.environ[self.KEY] - self.assertRegex(value, self.VALID_FORMAT) + """Check SetGitTrace2ParentSid behavior.""" + + KEY = "GIT_TRACE2_PARENT_SID" + VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$") + + def test_first_set(self): + """Test env var not yet set.""" + env = {} + self.wrapper.SetGitTrace2ParentSid(env) + self.assertIn(self.KEY, env) + value = env[self.KEY] + self.assertRegex(value, self.VALID_FORMAT) + + def test_append(self): + """Test env var is appended.""" + env = {self.KEY: "pfx"} + self.wrapper.SetGitTrace2ParentSid(env) + self.assertIn(self.KEY, env) + value = env[self.KEY] + self.assertTrue(value.startswith("pfx/")) + self.assertRegex(value[4:], self.VALID_FORMAT) + + def test_global_context(self): + """Check os.environ gets updated by default.""" + os.environ.pop(self.KEY, None) + self.wrapper.SetGitTrace2ParentSid() + self.assertIn(self.KEY, os.environ) + value = os.environ[self.KEY] + self.assertRegex(value, self.VALID_FORMAT) class RunCommand(RepoWrapperTestCase): - """Check run_command behavior.""" + """Check run_command behavior.""" - def test_capture(self): - """Check capture_output handling.""" - ret = self.wrapper.run_command(['echo', 'hi'], capture_output=True) - # echo command appends OS specific linesep, but on Windows + Git Bash - # we get UNIX ending, so we allow both. - self.assertIn(ret.stdout, ['hi' + os.linesep, 'hi\n']) + def test_capture(self): + """Check capture_output handling.""" + ret = self.wrapper.run_command(["echo", "hi"], capture_output=True) + # echo command appends OS specific linesep, but on Windows + Git Bash + # we get UNIX ending, so we allow both. + self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"]) - def test_check(self): - """Check check handling.""" - self.wrapper.run_command(['true'], check=False) - self.wrapper.run_command(['true'], check=True) - self.wrapper.run_command(['false'], check=False) - with self.assertRaises(self.wrapper.RunError): - self.wrapper.run_command(['false'], check=True) + def test_check(self): + """Check check handling.""" + self.wrapper.run_command(["true"], check=False) + self.wrapper.run_command(["true"], check=True) + self.wrapper.run_command(["false"], check=False) + with self.assertRaises(self.wrapper.RunError): + self.wrapper.run_command(["false"], check=True) class RunGit(RepoWrapperTestCase): - """Check run_git behavior.""" + """Check run_git behavior.""" - def test_capture(self): - """Check capture_output handling.""" - ret = self.wrapper.run_git('--version') - self.assertIn('git', ret.stdout) + def test_capture(self): + """Check capture_output handling.""" + ret = self.wrapper.run_git("--version") + self.assertIn("git", ret.stdout) - def test_check(self): - """Check check handling.""" - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.run_git('--version-asdfasdf') - self.wrapper.run_git('--version-asdfasdf', check=False) + def test_check(self): + """Check check handling.""" + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.run_git("--version-asdfasdf") + self.wrapper.run_git("--version-asdfasdf", check=False) class ParseGitVersion(RepoWrapperTestCase): - """Check ParseGitVersion behavior.""" - - def test_autoload(self): - """Check we can load the version from the live git.""" - ret = self.wrapper.ParseGitVersion() - self.assertIsNotNone(ret) - - def test_bad_ver(self): - """Check handling of bad git versions.""" - ret = self.wrapper.ParseGitVersion(ver_str='asdf') - self.assertIsNone(ret) - - def test_normal_ver(self): - """Check handling of normal git versions.""" - ret = self.wrapper.ParseGitVersion(ver_str='git version 2.25.1') - self.assertEqual(2, ret.major) - self.assertEqual(25, ret.minor) - self.assertEqual(1, ret.micro) - self.assertEqual('2.25.1', ret.full) - - def test_extended_ver(self): - """Check handling of extended distro git versions.""" - ret = self.wrapper.ParseGitVersion( - ver_str='git version 1.30.50.696.g5e7596f4ac-goog') - self.assertEqual(1, ret.major) - self.assertEqual(30, ret.minor) - self.assertEqual(50, ret.micro) - self.assertEqual('1.30.50.696.g5e7596f4ac-goog', ret.full) + """Check ParseGitVersion behavior.""" + + def test_autoload(self): + """Check we can load the version from the live git.""" + ret = self.wrapper.ParseGitVersion() + self.assertIsNotNone(ret) + + def test_bad_ver(self): + """Check handling of bad git versions.""" + ret = self.wrapper.ParseGitVersion(ver_str="asdf") + self.assertIsNone(ret) + + def test_normal_ver(self): + """Check handling of normal git versions.""" + ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1") + self.assertEqual(2, ret.major) + self.assertEqual(25, ret.minor) + self.assertEqual(1, ret.micro) + self.assertEqual("2.25.1", ret.full) + + def test_extended_ver(self): + """Check handling of extended distro git versions.""" + ret = self.wrapper.ParseGitVersion( + ver_str="git version 1.30.50.696.g5e7596f4ac-goog" + ) + self.assertEqual(1, ret.major) + self.assertEqual(30, ret.minor) + self.assertEqual(50, ret.micro) + self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full) class CheckGitVersion(RepoWrapperTestCase): - """Check _CheckGitVersion behavior.""" - - def test_unknown(self): - """Unknown versions should abort.""" - with mock.patch.object(self.wrapper, 'ParseGitVersion', return_value=None): - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper._CheckGitVersion() - - def test_old(self): - """Old versions should abort.""" - with mock.patch.object( - self.wrapper, 'ParseGitVersion', - return_value=self.wrapper.GitVersion(1, 0, 0, '1.0.0')): - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper._CheckGitVersion() - - def test_new(self): - """Newer versions should run fine.""" - with mock.patch.object( - self.wrapper, 'ParseGitVersion', - return_value=self.wrapper.GitVersion(100, 0, 0, '100.0.0')): - self.wrapper._CheckGitVersion() + """Check _CheckGitVersion behavior.""" + + def test_unknown(self): + """Unknown versions should abort.""" + with mock.patch.object( + self.wrapper, "ParseGitVersion", return_value=None + ): + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper._CheckGitVersion() + + def test_old(self): + """Old versions should abort.""" + with mock.patch.object( + self.wrapper, + "ParseGitVersion", + return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"), + ): + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper._CheckGitVersion() + + def test_new(self): + """Newer versions should run fine.""" + with mock.patch.object( + self.wrapper, + "ParseGitVersion", + return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"), + ): + self.wrapper._CheckGitVersion() class Requirements(RepoWrapperTestCase): - """Check Requirements handling.""" - - def test_missing_file(self): - """Don't crash if the file is missing (old version).""" - testdir = os.path.dirname(os.path.realpath(__file__)) - self.assertIsNone(self.wrapper.Requirements.from_dir(testdir)) - self.assertIsNone(self.wrapper.Requirements.from_file( - os.path.join(testdir, 'xxxxxxxxxxxxxxxxxxxxxxxx'))) - - def test_corrupt_data(self): - """If the file can't be parsed, don't blow up.""" - self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) - self.assertIsNone(self.wrapper.Requirements.from_data(b'x')) - - def test_valid_data(self): - """Make sure we can parse the file we ship.""" - self.assertIsNotNone(self.wrapper.Requirements.from_data(b'{}')) - rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) - self.assertIsNotNone(self.wrapper.Requirements.from_file(os.path.join( - rootdir, 'requirements.json'))) - - def test_format_ver(self): - """Check format_ver can format.""" - self.assertEqual('1.2.3', self.wrapper.Requirements._format_ver((1, 2, 3))) - self.assertEqual('1', self.wrapper.Requirements._format_ver([1])) - - def test_assert_all_unknown(self): - """Check assert_all works with incompatible file.""" - reqs = self.wrapper.Requirements({}) - reqs.assert_all() - - def test_assert_all_new_repo(self): - """Check assert_all accepts new enough repo.""" - reqs = self.wrapper.Requirements({'repo': {'hard': [1, 0]}}) - reqs.assert_all() - - def test_assert_all_old_repo(self): - """Check assert_all rejects old repo.""" - reqs = self.wrapper.Requirements({'repo': {'hard': [99999, 0]}}) - with self.assertRaises(SystemExit): - reqs.assert_all() - - def test_assert_all_new_python(self): - """Check assert_all accepts new enough python.""" - reqs = self.wrapper.Requirements({'python': {'hard': sys.version_info}}) - reqs.assert_all() - - def test_assert_all_old_python(self): - """Check assert_all rejects old python.""" - reqs = self.wrapper.Requirements({'python': {'hard': [99999, 0]}}) - with self.assertRaises(SystemExit): - reqs.assert_all() - - def test_assert_ver_unknown(self): - """Check assert_ver works with incompatible file.""" - reqs = self.wrapper.Requirements({}) - reqs.assert_ver('xxx', (1, 0)) - - def test_assert_ver_new(self): - """Check assert_ver allows new enough versions.""" - reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}}) - reqs.assert_ver('git', (1, 0)) - reqs.assert_ver('git', (1, 5)) - reqs.assert_ver('git', (2, 0)) - reqs.assert_ver('git', (2, 5)) - - def test_assert_ver_old(self): - """Check assert_ver rejects old versions.""" - reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}}) - with self.assertRaises(SystemExit): - reqs.assert_ver('git', (0, 5)) + """Check Requirements handling.""" + + def test_missing_file(self): + """Don't crash if the file is missing (old version).""" + testdir = os.path.dirname(os.path.realpath(__file__)) + self.assertIsNone(self.wrapper.Requirements.from_dir(testdir)) + self.assertIsNone( + self.wrapper.Requirements.from_file( + os.path.join(testdir, "xxxxxxxxxxxxxxxxxxxxxxxx") + ) + ) + + def test_corrupt_data(self): + """If the file can't be parsed, don't blow up.""" + self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) + self.assertIsNone(self.wrapper.Requirements.from_data(b"x")) + + def test_valid_data(self): + """Make sure we can parse the file we ship.""" + self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}")) + rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) + self.assertIsNotNone( + self.wrapper.Requirements.from_file( + os.path.join(rootdir, "requirements.json") + ) + ) + + def test_format_ver(self): + """Check format_ver can format.""" + self.assertEqual( + "1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3)) + ) + self.assertEqual("1", self.wrapper.Requirements._format_ver([1])) + + def test_assert_all_unknown(self): + """Check assert_all works with incompatible file.""" + reqs = self.wrapper.Requirements({}) + reqs.assert_all() + + def test_assert_all_new_repo(self): + """Check assert_all accepts new enough repo.""" + reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}}) + reqs.assert_all() + + def test_assert_all_old_repo(self): + """Check assert_all rejects old repo.""" + reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}}) + with self.assertRaises(SystemExit): + reqs.assert_all() + + def test_assert_all_new_python(self): + """Check assert_all accepts new enough python.""" + reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}}) + reqs.assert_all() + + def test_assert_all_old_python(self): + """Check assert_all rejects old python.""" + reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}}) + with self.assertRaises(SystemExit): + reqs.assert_all() + + def test_assert_ver_unknown(self): + """Check assert_ver works with incompatible file.""" + reqs = self.wrapper.Requirements({}) + reqs.assert_ver("xxx", (1, 0)) + + def test_assert_ver_new(self): + """Check assert_ver allows new enough versions.""" + reqs = self.wrapper.Requirements( + {"git": {"hard": [1, 0], "soft": [2, 0]}} + ) + reqs.assert_ver("git", (1, 0)) + reqs.assert_ver("git", (1, 5)) + reqs.assert_ver("git", (2, 0)) + reqs.assert_ver("git", (2, 5)) + + def test_assert_ver_old(self): + """Check assert_ver rejects old versions.""" + reqs = self.wrapper.Requirements( + {"git": {"hard": [1, 0], "soft": [2, 0]}} + ) + with self.assertRaises(SystemExit): + reqs.assert_ver("git", (0, 5)) class NeedSetupGnuPG(RepoWrapperTestCase): - """Check NeedSetupGnuPG behavior.""" - - def test_missing_dir(self): - """The ~/.repoconfig tree doesn't exist yet.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = os.path.join(tempdir, 'foo') - self.assertTrue(self.wrapper.NeedSetupGnuPG()) - - def test_missing_keyring(self): - """The keyring-version file doesn't exist yet.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - self.assertTrue(self.wrapper.NeedSetupGnuPG()) - - def test_empty_keyring(self): - """The keyring-version file exists, but is empty.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, 'keyring-version'), 'w'): - pass - self.assertTrue(self.wrapper.NeedSetupGnuPG()) - - def test_old_keyring(self): - """The keyring-version file exists, but it's old.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp: - fp.write('1.0\n') - self.assertTrue(self.wrapper.NeedSetupGnuPG()) - - def test_new_keyring(self): - """The keyring-version file exists, and is up-to-date.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp: - fp.write('1000.0\n') - self.assertFalse(self.wrapper.NeedSetupGnuPG()) + """Check NeedSetupGnuPG behavior.""" + + def test_missing_dir(self): + """The ~/.repoconfig tree doesn't exist yet.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = os.path.join(tempdir, "foo") + self.assertTrue(self.wrapper.NeedSetupGnuPG()) + + def test_missing_keyring(self): + """The keyring-version file doesn't exist yet.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + self.assertTrue(self.wrapper.NeedSetupGnuPG()) + + def test_empty_keyring(self): + """The keyring-version file exists, but is empty.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + with open(os.path.join(tempdir, "keyring-version"), "w"): + pass + self.assertTrue(self.wrapper.NeedSetupGnuPG()) + + def test_old_keyring(self): + """The keyring-version file exists, but it's old.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + with open(os.path.join(tempdir, "keyring-version"), "w") as fp: + fp.write("1.0\n") + self.assertTrue(self.wrapper.NeedSetupGnuPG()) + + def test_new_keyring(self): + """The keyring-version file exists, and is up-to-date.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + with open(os.path.join(tempdir, "keyring-version"), "w") as fp: + fp.write("1000.0\n") + self.assertFalse(self.wrapper.NeedSetupGnuPG()) class SetupGnuPG(RepoWrapperTestCase): - """Check SetupGnuPG behavior.""" - - def test_full(self): - """Make sure it works completely.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - self.wrapper.gpg_dir = os.path.join(self.wrapper.home_dot_repo, 'gnupg') - self.assertTrue(self.wrapper.SetupGnuPG(True)) - with open(os.path.join(tempdir, 'keyring-version'), 'r') as fp: - data = fp.read() - self.assertEqual('.'.join(str(x) for x in self.wrapper.KEYRING_VERSION), - data.strip()) + """Check SetupGnuPG behavior.""" + + def test_full(self): + """Make sure it works completely.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + self.wrapper.gpg_dir = os.path.join( + self.wrapper.home_dot_repo, "gnupg" + ) + self.assertTrue(self.wrapper.SetupGnuPG(True)) + with open(os.path.join(tempdir, "keyring-version"), "r") as fp: + data = fp.read() + self.assertEqual( + ".".join(str(x) for x in self.wrapper.KEYRING_VERSION), + data.strip(), + ) class VerifyRev(RepoWrapperTestCase): - """Check verify_rev behavior.""" - - def test_verify_passes(self): - """Check when we have a valid signed tag.""" - desc_result = self.wrapper.RunResult(0, 'v1.0\n', '') - gpg_result = self.wrapper.RunResult(0, '', '') - with mock.patch.object(self.wrapper, 'run_git', - side_effect=(desc_result, gpg_result)): - ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) - self.assertEqual('v1.0^0', ret) - - def test_unsigned_commit(self): - """Check we fall back to signed tag when we have an unsigned commit.""" - desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '') - gpg_result = self.wrapper.RunResult(0, '', '') - with mock.patch.object(self.wrapper, 'run_git', - side_effect=(desc_result, gpg_result)): - ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) - self.assertEqual('v1.0^0', ret) - - def test_verify_fails(self): - """Check we fall back to signed tag when we have an unsigned commit.""" - desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '') - gpg_result = Exception - with mock.patch.object(self.wrapper, 'run_git', - side_effect=(desc_result, gpg_result)): - with self.assertRaises(Exception): - self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) + """Check verify_rev behavior.""" + + def test_verify_passes(self): + """Check when we have a valid signed tag.""" + desc_result = self.wrapper.RunResult(0, "v1.0\n", "") + gpg_result = self.wrapper.RunResult(0, "", "") + with mock.patch.object( + self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + ): + ret = self.wrapper.verify_rev( + "/", "refs/heads/stable", "1234", True + ) + self.assertEqual("v1.0^0", ret) + + def test_unsigned_commit(self): + """Check we fall back to signed tag when we have an unsigned commit.""" + desc_result = self.wrapper.RunResult(0, "v1.0-10-g1234\n", "") + gpg_result = self.wrapper.RunResult(0, "", "") + with mock.patch.object( + self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + ): + ret = self.wrapper.verify_rev( + "/", "refs/heads/stable", "1234", True + ) + self.assertEqual("v1.0^0", ret) + + def test_verify_fails(self): + """Check we fall back to signed tag when we have an unsigned commit.""" + desc_result = self.wrapper.RunResult(0, "v1.0-10-g1234\n", "") + gpg_result = Exception + with mock.patch.object( + self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + ): + with self.assertRaises(Exception): + self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True) class GitCheckoutTestCase(RepoWrapperTestCase): - """Tests that use a real/small git checkout.""" - - GIT_DIR = None - REV_LIST = None - - @classmethod - def setUpClass(cls): - # Create a repo to operate on, but do it once per-class. - cls.tempdirobj = tempfile.TemporaryDirectory(prefix='repo-rev-tests') - cls.GIT_DIR = cls.tempdirobj.name - run_git = wrapper.Wrapper().run_git - - remote = os.path.join(cls.GIT_DIR, 'remote') - os.mkdir(remote) - - # Tests need to assume, that main is default branch at init, - # which is not supported in config until 2.28. - if git_command.git_require((2, 28, 0)): - initstr = '--initial-branch=main' - else: - # Use template dir for init. - templatedir = tempfile.mkdtemp(prefix='.test-template') - with open(os.path.join(templatedir, 'HEAD'), 'w') as fp: - fp.write('ref: refs/heads/main\n') - initstr = '--template=' + templatedir - - run_git('init', initstr, cwd=remote) - run_git('commit', '--allow-empty', '-minit', cwd=remote) - run_git('branch', 'stable', cwd=remote) - run_git('tag', 'v1.0', cwd=remote) - run_git('commit', '--allow-empty', '-m2nd commit', cwd=remote) - cls.REV_LIST = run_git('rev-list', 'HEAD', cwd=remote).stdout.splitlines() - - run_git('init', cwd=cls.GIT_DIR) - run_git('fetch', remote, '+refs/heads/*:refs/remotes/origin/*', cwd=cls.GIT_DIR) - - @classmethod - def tearDownClass(cls): - if not cls.tempdirobj: - return - - cls.tempdirobj.cleanup() + """Tests that use a real/small git checkout.""" + + GIT_DIR = None + REV_LIST = None + + @classmethod + def setUpClass(cls): + # Create a repo to operate on, but do it once per-class. + cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests") + cls.GIT_DIR = cls.tempdirobj.name + run_git = wrapper.Wrapper().run_git + + remote = os.path.join(cls.GIT_DIR, "remote") + os.mkdir(remote) + + # Tests need to assume, that main is default branch at init, + # which is not supported in config until 2.28. + if git_command.git_require((2, 28, 0)): + initstr = "--initial-branch=main" + else: + # Use template dir for init. + templatedir = tempfile.mkdtemp(prefix=".test-template") + with open(os.path.join(templatedir, "HEAD"), "w") as fp: + fp.write("ref: refs/heads/main\n") + initstr = "--template=" + templatedir + + run_git("init", initstr, cwd=remote) + run_git("commit", "--allow-empty", "-minit", cwd=remote) + run_git("branch", "stable", cwd=remote) + run_git("tag", "v1.0", cwd=remote) + run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote) + cls.REV_LIST = run_git( + "rev-list", "HEAD", cwd=remote + ).stdout.splitlines() + + run_git("init", cwd=cls.GIT_DIR) + run_git( + "fetch", + remote, + "+refs/heads/*:refs/remotes/origin/*", + cwd=cls.GIT_DIR, + ) + + @classmethod + def tearDownClass(cls): + if not cls.tempdirobj: + return + + cls.tempdirobj.cleanup() class ResolveRepoRev(GitCheckoutTestCase): - """Check resolve_repo_rev behavior.""" - - def test_explicit_branch(self): - """Check refs/heads/branch argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/stable') - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual(self.REV_LIST[1], lrev) - - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/unknown') - - def test_explicit_tag(self): - """Check refs/tags/tag argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/v1.0') - self.assertEqual('refs/tags/v1.0', rrev) - self.assertEqual(self.REV_LIST[1], lrev) - - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/unknown') - - def test_branch_name(self): - """Check branch argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'stable') - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual(self.REV_LIST[1], lrev) - - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'main') - self.assertEqual('refs/heads/main', rrev) - self.assertEqual(self.REV_LIST[0], lrev) - - def test_tag_name(self): - """Check tag argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'v1.0') - self.assertEqual('refs/tags/v1.0', rrev) - self.assertEqual(self.REV_LIST[1], lrev) - - def test_full_commit(self): - """Check specific commit argument.""" - commit = self.REV_LIST[0] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) - self.assertEqual(commit, rrev) - self.assertEqual(commit, lrev) - - def test_partial_commit(self): - """Check specific (partial) commit argument.""" - commit = self.REV_LIST[0][0:20] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) - self.assertEqual(self.REV_LIST[0], rrev) - self.assertEqual(self.REV_LIST[0], lrev) - - def test_unknown(self): - """Check unknown ref/commit argument.""" - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, 'boooooooya') + """Check resolve_repo_rev behavior.""" + + def test_explicit_branch(self): + """Check refs/heads/branch argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev( + self.GIT_DIR, "refs/heads/stable" + ) + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual(self.REV_LIST[1], lrev) + + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown") + + def test_explicit_tag(self): + """Check refs/tags/tag argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev( + self.GIT_DIR, "refs/tags/v1.0" + ) + self.assertEqual("refs/tags/v1.0", rrev) + self.assertEqual(self.REV_LIST[1], lrev) + + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown") + + def test_branch_name(self): + """Check branch argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable") + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual(self.REV_LIST[1], lrev) + + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main") + self.assertEqual("refs/heads/main", rrev) + self.assertEqual(self.REV_LIST[0], lrev) + + def test_tag_name(self): + """Check tag argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0") + self.assertEqual("refs/tags/v1.0", rrev) + self.assertEqual(self.REV_LIST[1], lrev) + + def test_full_commit(self): + """Check specific commit argument.""" + commit = self.REV_LIST[0] + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) + self.assertEqual(commit, rrev) + self.assertEqual(commit, lrev) + + def test_partial_commit(self): + """Check specific (partial) commit argument.""" + commit = self.REV_LIST[0][0:20] + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) + self.assertEqual(self.REV_LIST[0], rrev) + self.assertEqual(self.REV_LIST[0], lrev) + + def test_unknown(self): + """Check unknown ref/commit argument.""" + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya") class CheckRepoVerify(RepoWrapperTestCase): - """Check check_repo_verify behavior.""" + """Check check_repo_verify behavior.""" - def test_no_verify(self): - """Always fail with --no-repo-verify.""" - self.assertFalse(self.wrapper.check_repo_verify(False)) + def test_no_verify(self): + """Always fail with --no-repo-verify.""" + self.assertFalse(self.wrapper.check_repo_verify(False)) - def test_gpg_initialized(self): - """Should pass if gpg is setup already.""" - with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=False): - self.assertTrue(self.wrapper.check_repo_verify(True)) + def test_gpg_initialized(self): + """Should pass if gpg is setup already.""" + with mock.patch.object( + self.wrapper, "NeedSetupGnuPG", return_value=False + ): + self.assertTrue(self.wrapper.check_repo_verify(True)) - def test_need_gpg_setup(self): - """Should pass/fail based on gpg setup.""" - with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=True): - with mock.patch.object(self.wrapper, 'SetupGnuPG') as m: - m.return_value = True - self.assertTrue(self.wrapper.check_repo_verify(True)) + def test_need_gpg_setup(self): + """Should pass/fail based on gpg setup.""" + with mock.patch.object( + self.wrapper, "NeedSetupGnuPG", return_value=True + ): + with mock.patch.object(self.wrapper, "SetupGnuPG") as m: + m.return_value = True + self.assertTrue(self.wrapper.check_repo_verify(True)) - m.return_value = False - self.assertFalse(self.wrapper.check_repo_verify(True)) + m.return_value = False + self.assertFalse(self.wrapper.check_repo_verify(True)) class CheckRepoRev(GitCheckoutTestCase): - """Check check_repo_rev behavior.""" - - def test_verify_works(self): - """Should pass when verification passes.""" - with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True): - with mock.patch.object(self.wrapper, 'verify_rev', return_value='12345'): - rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable') - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual('12345', lrev) - - def test_verify_fails(self): - """Should fail when verification fails.""" - with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True): - with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception): - with self.assertRaises(Exception): - self.wrapper.check_repo_rev(self.GIT_DIR, 'stable') - - def test_verify_ignore(self): - """Should pass when verification is disabled.""" - with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception): - rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable', repo_verify=False) - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual(self.REV_LIST[1], lrev) + """Check check_repo_rev behavior.""" + + def test_verify_works(self): + """Should pass when verification passes.""" + with mock.patch.object( + self.wrapper, "check_repo_verify", return_value=True + ): + with mock.patch.object( + self.wrapper, "verify_rev", return_value="12345" + ): + rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable") + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual("12345", lrev) + + def test_verify_fails(self): + """Should fail when verification fails.""" + with mock.patch.object( + self.wrapper, "check_repo_verify", return_value=True + ): + with mock.patch.object( + self.wrapper, "verify_rev", side_effect=Exception + ): + with self.assertRaises(Exception): + self.wrapper.check_repo_rev(self.GIT_DIR, "stable") + + def test_verify_ignore(self): + """Should pass when verification is disabled.""" + with mock.patch.object( + self.wrapper, "verify_rev", side_effect=Exception + ): + rrev, lrev = self.wrapper.check_repo_rev( + self.GIT_DIR, "stable", repo_verify=False + ) + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual(self.REV_LIST[1], lrev) -- cgit v1.2.3-54-g00ecf