| #!/usr/bin/env python3 | 
 | import argparse | 
 | import os | 
 | import re | 
 | from typing import Any, Dict, List, Tuple | 
 |  | 
 | import yaml | 
 |  | 
 |  | 
 | def main() -> None: | 
 |     parser = argparse.ArgumentParser() | 
 |  | 
 |     parser.add_argument("--repo", help="Path to the repository", default=".") | 
 |     parser.add_argument( | 
 |         "--commit", | 
 |         help="Commit the changes", | 
 |         default=False, | 
 |         action="store_true", | 
 |     ) | 
 |  | 
 |     subparsers = parser.add_subparsers() | 
 |     subparsers.required = True | 
 |  | 
 |     parser_merge = subparsers.add_parser( | 
 |         "merge", help="Merge a reference clang-tidy config" | 
 |     ) | 
 |     parser_merge.add_argument( | 
 |         "--reference", help="Path to reference clang-tidy", required=True | 
 |     ) | 
 |     parser_merge.set_defaults(func=subcmd_merge) | 
 |  | 
 |     parser_format = subparsers.add_parser( | 
 |         "format", help="Format a clang-tidy config" | 
 |     ) | 
 |     parser_format.set_defaults(func=subcmd_merge) | 
 |  | 
 |     parser_enable = subparsers.add_parser( | 
 |         "enable", help="Enable a rule in a reference clang-tidy config" | 
 |     ) | 
 |     parser_enable.add_argument("check", help="Check to enable") | 
 |     parser_enable.set_defaults(func=subcmd_enable) | 
 |  | 
 |     parser_disable = subparsers.add_parser( | 
 |         "disable", help="Enable a rule in a reference clang-tidy config" | 
 |     ) | 
 |     parser_disable.add_argument("check", help="Check to disable") | 
 |     parser_disable.add_argument( | 
 |         "--drop", help="Delete the check from the config", action="store_true" | 
 |     ) | 
 |     parser_disable.set_defaults(func=subcmd_disable) | 
 |  | 
 |     args = parser.parse_args() | 
 |     args.func(args) | 
 |  | 
 |  | 
 | def subcmd_merge(args: argparse.Namespace) -> None: | 
 |     repo_path, repo_config = load_config(args.repo) | 
 |     _, ref_config = ( | 
 |         load_config(args.reference) if "reference" in args else ("", {}) | 
 |     ) | 
 |  | 
 |     result = {} | 
 |  | 
 |     all_keys_set = set(repo_config.keys()) | set(ref_config.keys()) | 
 |     special_keys = ["Checks", "CheckOptions"] | 
 |  | 
 |     # Create ordered_keys: special keys first (if present, in their defined order), | 
 |     # followed by the rest of the keys sorted alphabetically. | 
 |     ordered_keys = [k for k in special_keys if k in all_keys_set] + sorted( | 
 |         list(all_keys_set - set(special_keys)) | 
 |     ) | 
 |  | 
 |     for key in ordered_keys: | 
 |         repo_value = repo_config.get(key) | 
 |         ref_value = ref_config.get(key) | 
 |  | 
 |         key_class = globals().get(f"Key_{key}") | 
 |         if key_class and hasattr(key_class, "merge"): | 
 |             result[key] = key_class.merge(repo_value, ref_value) | 
 |         elif repo_value: | 
 |             result[key] = repo_value | 
 |         else: | 
 |             result[key] = ref_value | 
 |  | 
 |     with open(repo_path, "w") as f: | 
 |         f.write(format_yaml_output(result)) | 
 |  | 
 |  | 
 | def subcmd_enable(args: argparse.Namespace) -> None: | 
 |     repo_path, repo_config = load_config(args.repo) | 
 |  | 
 |     if "Checks" in repo_config: | 
 |         repo_config["Checks"] = Key_Checks.enable( | 
 |             repo_config["Checks"], args.check | 
 |         ) | 
 |  | 
 |     with open(repo_path, "w") as f: | 
 |         f.write(format_yaml_output(repo_config)) | 
 |  | 
 |     pass | 
 |  | 
 |  | 
 | def subcmd_disable(args: argparse.Namespace) -> None: | 
 |     repo_path, repo_config = load_config(args.repo) | 
 |  | 
 |     if "Checks" in repo_config: | 
 |         repo_config["Checks"] = Key_Checks.disable( | 
 |             repo_config["Checks"], args.check, args.drop | 
 |         ) | 
 |  | 
 |     if "CheckOptions" in repo_config: | 
 |         repo_config["CheckOptions"] = Key_CheckOptions.disable( | 
 |             repo_config["CheckOptions"], args.check, args.drop | 
 |         ) | 
 |  | 
 |     with open(repo_path, "w") as f: | 
 |         f.write(format_yaml_output(repo_config)) | 
 |  | 
 |     pass | 
 |  | 
 |  | 
 | class Key_Checks: | 
 |     @staticmethod | 
 |     def merge(repo: str, ref: str) -> str: | 
 |         repo_checks = Key_Checks._split(repo) | 
 |         ref_checks = Key_Checks._split(ref) | 
 |  | 
 |         result: Dict[str, bool] = {} | 
 |  | 
 |         for k, v in repo_checks.items(): | 
 |             result[k] = v | 
 |         for k, v in ref_checks.items(): | 
 |             if k not in result: | 
 |                 result[k] = False | 
 |  | 
 |         return Key_Checks._join(result) | 
 |  | 
 |     @staticmethod | 
 |     def enable(repo: str, check: str) -> str: | 
 |         repo_checks = Key_Checks._split(repo) | 
 |         repo_checks[check] = True | 
 |         return Key_Checks._join(repo_checks) | 
 |  | 
 |     @staticmethod | 
 |     def disable(repo: str, check: str, drop: bool) -> str: | 
 |         repo_checks = Key_Checks._split(repo) | 
 |         if drop: | 
 |             repo_checks.pop(check, None) | 
 |         else: | 
 |             repo_checks[check] = False | 
 |         return Key_Checks._join(repo_checks) | 
 |  | 
 |     @staticmethod | 
 |     def _split(s: str) -> Dict[str, bool]: | 
 |         result: Dict[str, bool] = {} | 
 |         if not s: | 
 |             return result | 
 |         for item in s.split(): | 
 |             item = item.replace(",", "") | 
 |             # Ignore global wildcard because we handle that specifically. | 
 |             if item.startswith("-*"): | 
 |                 continue | 
 |             # Drop category wildcard disables since we already use a global wildcard. | 
 |             if item.startswith("-") and "*" in item: | 
 |                 continue | 
 |             if item.startswith("-"): | 
 |                 result[item[1:]] = False | 
 |             else: | 
 |                 result[item] = True | 
 |         return result | 
 |  | 
 |     @staticmethod | 
 |     def _join(data: Dict[str, bool]) -> str: | 
 |         return ( | 
 |             ",\n".join( | 
 |                 ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())] | 
 |             ) | 
 |             + "\n" | 
 |         ) | 
 |  | 
 |  | 
 | class Key_CheckOptions: | 
 |     @staticmethod | 
 |     def merge( | 
 |         repo: List[Dict[str, str]], ref: List[Dict[str, str]] | 
 |     ) -> List[Dict[str, str]]: | 
 |         unrolled_repo = Key_CheckOptions._unroll(repo) | 
 |         for item in ref or []: | 
 |             if item["key"] in unrolled_repo: | 
 |                 continue | 
 |             unrolled_repo[item["key"]] = item["value"] | 
 |  | 
 |         return Key_CheckOptions._roll(unrolled_repo) | 
 |  | 
 |     @staticmethod | 
 |     def disable( | 
 |         repo: List[Dict[str, str]], option: str, drop: bool | 
 |     ) -> List[Dict[str, str]]: | 
 |         if not drop: | 
 |             return repo | 
 |  | 
 |         unrolled_repo = Key_CheckOptions._unroll(repo) | 
 |  | 
 |         if option in unrolled_repo: | 
 |             unrolled_repo.pop(option, None) | 
 |  | 
 |         return Key_CheckOptions._roll(unrolled_repo) | 
 |  | 
 |     @staticmethod | 
 |     def _unroll(repo: List[Dict[str, str]]) -> Dict[str, str]: | 
 |         unrolled_repo: Dict[str, str] = {} | 
 |         for item in repo or []: | 
 |             unrolled_repo[item["key"]] = item["value"] | 
 |         return unrolled_repo | 
 |  | 
 |     @staticmethod | 
 |     def _roll(data: Dict[str, str]) -> List[Dict[str, str]]: | 
 |         return [{"key": k, "value": v} for k, v in sorted(data.items())] | 
 |  | 
 |  | 
 | def load_config(path: str) -> Tuple[str, Dict[str, Any]]: | 
 |     if "clang-tidy" not in path: | 
 |         path = os.path.join(path, ".clang-tidy") | 
 |  | 
 |     if not os.path.exists(path): | 
 |         return (path, {}) | 
 |  | 
 |     with open(path, "r") as f: | 
 |         data = "\n".join([x for x in f.readlines() if not x.startswith("#")]) | 
 |         return (path, yaml.safe_load(data)) | 
 |  | 
 |  | 
 | def format_yaml_output(data: Dict[str, Any]) -> str: | 
 |     """Convert to a prettier YAML string: | 
 |     - filter out excess empty lines | 
 |     - insert new lines between keys | 
 |     """ | 
 |     yaml_string = yaml.dump(data, sort_keys=False, indent=4) | 
 |     lines: List[str] = [] | 
 |     for line in yaml_string.split("\n"): | 
 |         # Strip excess new lines. | 
 |         if not line: | 
 |             continue | 
 |         # Add new line between keys. | 
 |         if len(lines) and re.match("[a-zA-Z0-9]+:", line): | 
 |             lines.append("") | 
 |         lines.append(line) | 
 |     lines.append("") | 
 |  | 
 |     return "\n".join(lines) | 
 |  | 
 |  | 
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()