summaryrefslogtreecommitdiffstats
path: root/tests/test_subcmds_forall.py
blob: 84744f89371c16db33d7eae8f2c0bfb057a380ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright (C) 2024 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unittests for the forall subcmd."""

from io import StringIO
import os
from shutil import rmtree
import subprocess
import tempfile
import unittest
from unittest import mock

import git_command
import manifest_xml
import project
import subcmds


class AllCommands(unittest.TestCase):
    """Check that forall runs a command in every project of a manifest."""

    def setUp(self):
        """Create a scratch .repo/ layout inside a fresh temporary dir."""
        self.tempdirobj = tempfile.TemporaryDirectory(prefix="forall_tests")
        self.tempdir = self.tempdirobj.name
        self.repodir = os.path.join(self.tempdir, ".repo")
        self.manifest_dir = os.path.join(self.repodir, "manifests")
        self.manifest_file = os.path.join(
            self.repodir, manifest_xml.MANIFEST_FILE_NAME
        )
        self.local_manifest_dir = os.path.join(
            self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME
        )
        os.mkdir(self.repodir)
        os.mkdir(self.manifest_dir)

    def tearDown(self):
        """Remove the temporary directory created by setUp."""
        # Let the TemporaryDirectory clean up after itself. Deleting the tree
        # behind its back leaves its finalizer attached, which later emits an
        # "Implicitly cleaning up" ResourceWarning.
        self.tempdirobj.cleanup()

    def initTempGitTree(self, git_dir):
        """Create a new empty git checkout at |git_dir| for testing.

        The checkout's initial branch is always "main".
        """
        # Tests need to assume that "main" is the default branch at init.
        # `git init --initial-branch` is not supported until git 2.28, so
        # older versions fall back to a template whose HEAD points at main.
        cmd = ["git", "init", "-q"]
        if git_command.git_require((2, 28, 0)):
            cmd += ["--initial-branch=main"]
        else:
            templatedir = os.path.join(self.tempdir, ".test-template")
            # This helper is called once per project, so the template dir
            # may already exist; without exist_ok the second call would
            # raise FileExistsError.
            os.makedirs(templatedir, exist_ok=True)
            head = os.path.join(templatedir, "HEAD")
            with open(head, "w", encoding="utf-8") as fp:
                fp.write("ref: refs/heads/main\n")
            cmd += ["--template", templatedir]
        cmd += [git_dir]
        subprocess.check_call(cmd)

    def getXmlManifestWith8Projects(self):
        """Create and return a manifest setup with 8 projects.

        Writes a manifest listing project1..project8 (checked out at
        tests/path1..tests/path8). Creates the matching .repo/projects and
        .repo/project-objects directories, and initializes an empty git
        checkout for each project so that forall has somewhere to run.

        Returns:
            A manifest_xml.XmlManifest parsed from the written manifest.
        """
        # Set up a manifest git dir for parsing to work.
        gitdir = os.path.join(self.repodir, "manifests.git")
        os.mkdir(gitdir)
        with open(os.path.join(gitdir, "config"), "w", encoding="utf-8") as fp:
            fp.write(
                """[remote "origin"]
                    url = https://localhost:0/manifest
                    verbose = false
                """
            )

        # Add the manifest data.
        manifest_data = """
                <manifest>
                    <remote name="origin" fetch="http://localhost" />
                    <default remote="origin" revision="refs/heads/main" />
                    <project name="project1" path="tests/path1" />
                    <project name="project2" path="tests/path2" />
                    <project name="project3" path="tests/path3" />
                    <project name="project4" path="tests/path4" />
                    <project name="project5" path="tests/path5" />
                    <project name="project6" path="tests/path6" />
                    <project name="project7" path="tests/path7" />
                    <project name="project8" path="tests/path8" />
                </manifest>
            """
        with open(self.manifest_file, "w", encoding="utf-8") as fp:
            fp.write(manifest_data)

        # Set up 8 empty projects to match the manifest.
        for x in range(1, 9):
            name = "project" + str(x)
            relpath = os.path.join("tests", "path" + str(x))
            os.makedirs(os.path.join(self.repodir, "projects", relpath + ".git"))
            os.makedirs(
                os.path.join(self.repodir, "project-objects", name + ".git")
            )
            self.initTempGitTree(os.path.join(self.tempdir, relpath))

        return manifest_xml.XmlManifest(self.repodir, self.manifest_file)

    # Capture everything the forall run prints to stdout.
    @mock.patch("sys.stdout", new_callable=StringIO)
    def test_forall_all_projects_called_once(self, mock_stdout):
        """Test that all projects get a command run once each."""
        manifest_with_8_projects = self.getXmlManifestWith8Projects()

        cmd = subcmds.forall.Forall()
        cmd.manifest = manifest_with_8_projects

        # Echo each project's name, so the output shows which projects ran.
        opts, args = cmd.OptionParser.parse_args(["-c", "echo $REPO_PROJECT"])
        opts.verbose = False

        # Mock GetRevisionId so that Execute does not fail on the remote
        # check.
        with mock.patch.object(
            project.Project, "GetRevisionId", return_value="refs/heads/main"
        ):
            cmd.Execute(opts, args)

        # stdout is mocked for the whole test; write any debugging output to
        # sys.stderr instead.
        output = mock_stdout.getvalue()
        lines = [line for line in output.split("\n") if line]

        # Verify that we got every project name in the output ...
        for x in range(1, 9):
            self.assertIn("project" + str(x), output)

        # ... and no more lines than expected (one per project).
        self.assertEqual(len(lines), 8)