blob: 9f81c4e65d775b91f698e9d243038e8203e8c652 [file] [log] [blame]
#!/usr/bin/env python3
2import argparse
3import os
4import re
5from typing import Any, Dict, List, Tuple
6
7import yaml
8
9
def main() -> None:
    """Parse the command line and dispatch to the selected subcommand.

    Global options:
        --repo    Path to the repository; defaults to ".".
        --commit  Boolean flag. It is parsed here, but no handler visible in
                  this file reads it.

    Subcommands (each one stores its handler in ``args.func``):
        merge --reference PATH   -> subcmd_merge
        enable CHECK             -> subcmd_enable
        disable CHECK [--drop]   -> subcmd_disable
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # Give the subcommand a destination name: with required subparsers and no
    # dest, older Python 3 releases can fail with a TypeError instead of a
    # normal usage error when the subcommand is omitted.
    subparsers = parser.add_subparsers(dest="subcommand")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a rule in a reference clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a rule in a reference clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    # Every subparser registered a ``func`` default above.
    args.func(args)
49
50
def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge a reference clang-tidy config into the repository's config.

    Both configs are read with ``load_config`` (from ``args.repo`` and
    ``args.reference``) and the merged result is written back to the
    repository's config path.

    For each key found in either config:
      * if a module-level class named ``Key_<key>`` with a ``merge`` method
        exists, that method combines the two values;
      * otherwise the repository's value is kept whenever it is set, and the
        reference value is used only when the repository has none.

    "Checks" and "CheckOptions" are emitted first (in that order), followed
    by the remaining keys sorted alphabetically.
    """
    repo_path, repo_config = load_config(args.repo)
    _, ref_config = load_config(args.reference)

    all_keys = set(repo_config) | set(ref_config)
    special_keys = ["Checks", "CheckOptions"]

    # Special keys first (if present, in their defined order), then the rest
    # of the keys sorted alphabetically.
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        # Per-key merge strategies are looked up by naming convention.
        key_class = globals().get(f"Key_{key}")
        if key_class and hasattr(key_class, "merge"):
            result[key] = key_class.merge(repo_value, ref_value)
        elif repo_value is not None:
            # Compare against None rather than truthiness so that explicit
            # falsy repository settings (False, 0, "") are not silently
            # replaced by the reference value.
            result[key] = repo_value
        else:
            result[key] = ref_value

    with open(repo_path, "w") as f:
        f.write(format_yaml_output(result))
80
81
def subcmd_enable(args: argparse.Namespace) -> None:
    """Mark ``args.check`` as enabled in the repository's clang-tidy config.

    The config located via ``args.repo`` is loaded, its "Checks" entry (only
    if that key is present) is rewritten with the check enabled, and the
    whole config is written back to the same path.
    """
    config_path, config = load_config(args.repo)

    has_checks = "Checks" in config
    if has_checks:
        updated_checks = Key_Checks.enable(config["Checks"], args.check)
        config["Checks"] = updated_checks

    # The file is rewritten even when there was no "Checks" key to update.
    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))
94
95
def subcmd_disable(args: argparse.Namespace) -> None:
    """Mark ``args.check`` as disabled in the repository's clang-tidy config.

    The config located via ``args.repo`` is loaded, its "Checks" entry (only
    if that key is present) is rewritten via ``Key_Checks.disable`` --
    ``args.drop`` is forwarded so the check can be removed from the list
    instead of being negated -- and the whole config is written back to the
    same path.
    """
    config_path, config = load_config(args.repo)

    has_checks = "Checks" in config
    if has_checks:
        updated_checks = Key_Checks.disable(
            config["Checks"], args.check, args.drop
        )
        config["Checks"] = updated_checks

    # The file is rewritten even when there was no "Checks" key to update.
    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))
108
109
class Key_Checks:
    """Handlers for the ``Checks`` key of a clang-tidy config.

    The value is treated as text: a list of check names separated by commas
    and/or whitespace, where a leading ``-`` marks a check as disabled.
    Internally the list is represented as a dict mapping check name to an
    enabled flag.  Every string produced by this class starts with the global
    ``-*`` wildcard followed by the individual checks in sorted order.
    """

    @staticmethod
    def merge(repo: str, ref: str) -> str:
        """Combine the repository's checks with a reference list.

        Checks already listed by the repository keep their state.  Checks
        that appear only in the reference are added in the disabled state.
        Either argument may be empty or None.
        """
        result: Dict[str, bool] = dict(Key_Checks._split(repo))
        for name in Key_Checks._split(ref):
            result.setdefault(name, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` listed as enabled."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled.

        When ``drop`` is true the check is removed from the list entirely
        (a no-op if it was not listed); otherwise it is kept as ``-check``.
        """
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: str) -> Dict[str, bool]:
        """Parse a checks string into a ``{name: enabled}`` dict.

        Returns an empty dict for an empty or None input.  Negated wildcard
        entries (``-*`` and ``-category-*``) are discarded.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        # Commas are separators in their own right.  Turning them into
        # whitespace before splitting means a compact "a,b,-c" list is split
        # into its entries instead of collapsing into a single token, and
        # stray commas do not produce empty check names.
        for item in s.replace(",", " ").split():
            # Ignore global wildcard because we handle that specifically.
            if item.startswith("-*"):
                continue
            # Drop category wildcard disables since we already use a global
            # wildcard.
            if item.startswith("-") and "*" in item:
                continue
            if item.startswith("-"):
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Serialize a ``{name: enabled}`` dict back into a checks string.

        The output starts with ``-*``, lists the checks sorted by name (one
        per line, disabled ones prefixed with ``-``), separates entries with
        ``,\\n`` and ends with a newline.
        """
        return (
            ",\n".join(
                ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())]
            )
            + "\n"
        )
168
169
class Key_CheckOptions:
    """Handlers for the ``CheckOptions`` key of a clang-tidy config."""

    @staticmethod
    def merge(
        repo: List[Dict[str, str]], ref: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Combine repository and reference check options.

        Each argument is normally a list of ``{"key": ..., "value": ...}``
        dicts.  A plain ``{option: value}`` mapping, or None / an empty
        container, is accepted as well.

        Options set by the repository win; options found only in the
        reference are added.  The result is always a list of
        ``{"key": ..., "value": ...}`` dicts sorted by key.
        """
        merged: Dict[str, str] = {}
        for name, value in Key_CheckOptions._pairs(repo):
            merged[name] = value
        for name, value in Key_CheckOptions._pairs(ref):
            # Never override a value that is already present.
            merged.setdefault(name, value)

        return [{"key": k, "value": v} for k, v in sorted(merged.items())]

    @staticmethod
    def _pairs(options: Any) -> List[Tuple[str, str]]:
        """Flatten either supported ``CheckOptions`` form into (key, value) pairs."""
        if not options:
            return []
        if isinstance(options, dict):
            # Mapping form: option name -> value.
            return list(options.items())
        # List form: one {"key": ..., "value": ...} dict per option.
        return [(item["key"], item["value"]) for item in options]
186
187
def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Load a clang-tidy YAML config.

    ``path`` may name the config file itself or a directory containing a
    ``.clang-tidy`` file.  It is treated as a directory when it is an
    existing directory or when "clang-tidy" does not occur in it.

    Returns:
        A ``(resolved_path, config)`` tuple.  ``config`` is an empty dict
        when the file does not exist or contains no YAML content.
    """
    # Check for a real directory first so that a directory whose own path
    # happens to contain "clang-tidy" is not opened as if it were the file.
    if os.path.isdir(path) or "clang-tidy" not in path:
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r") as f:
        # Drop lines that start with "#" before parsing.  The kept lines
        # still carry their own newline, so they are concatenated as-is.
        data = "".join(line for line in f if not line.startswith("#"))

    # An empty or comment-only document parses to None; normalise that to an
    # empty dict so callers always receive a mapping.
    return (path, yaml.safe_load(data) or {})
Patrick Williamsb5167292025-05-23 15:58:46 -0400198
199
def format_yaml_output(data: Dict[str, Any]) -> str:
    """Dump ``data`` as YAML and tidy up the text.

    - blank lines emitted by the YAML dumper are removed
    - one blank line is inserted before every top-level key except the first
    - the result ends with a single newline
    """
    top_level_key = re.compile("[a-zA-Z0-9]+:")
    dumped = yaml.dump(data, sort_keys=False, indent=4)

    output: List[str] = []
    # filter(None, ...) skips the empty strings, i.e. the blank lines.
    for text in filter(None, dumped.split("\n")):
        # A key starting at column 0 opens a new section; separate it from
        # the previous one unless it is the very first line.
        if output and top_level_key.match(text):
            output.append("")
        output.append(text)

    # The trailing empty element yields the final newline after joining.
    return "\n".join(output + [""])
218
219
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()