blob: f1aeb93e0b6911c1583e1228483878a236da1d88 [file] [log] [blame] [edit]
#!/usr/bin/env python3
import argparse
import os
import re
from typing import Any, Dict, List, Tuple
import yaml
def main() -> None:
    """Parse the command line and dispatch to the selected sub-command.

    Global options:
        --repo    Path to the repository (defaults to the current directory).
        --commit  Boolean flag; it is parsed here, but none of the handlers
                  visible in this file read it.

    Sub-commands (exactly one is required):
        merge --reference PATH   -> subcmd_merge
        enable CHECK             -> subcmd_enable
        disable CHECK [--drop]   -> subcmd_disable

    Each sub-parser stores its handler in ``args.func``; the handler is then
    called with the parsed namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo", help="Path to the repository", default=".")
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # Give the sub-command group a ``dest`` so that a missing sub-command is
    # reported as a normal usage error; a required group without a ``dest``
    # makes some Python releases fail with a TypeError while building the
    # error message.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a rule in a reference clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    # The help text previously said "Enable" here (copy-paste slip).
    parser_disable = subparsers.add_parser(
        "disable", help="Disable a rule in a reference clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)
def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge a reference clang-tidy config into the repository's config.

    Both configs are read with ``load_config`` (from ``args.repo`` and
    ``args.reference``).  For every key present in either config:

    * if a module-level class named ``Key_<key>`` with a ``merge`` attribute
      exists, its ``merge(repo_value, ref_value)`` result is used;
    * otherwise the repository's value wins whenever the repository sets one
      (any value other than ``None``), and the reference value is used only
      as a fallback.

    The output lists ``Checks`` and ``CheckOptions`` first (when present),
    followed by the remaining keys in alphabetical order, and is written back
    to the repository's config path.
    """
    repo_path, repo_config = load_config(args.repo)
    _, ref_config = load_config(args.reference)

    all_keys = set(repo_config) | set(ref_config)
    special_keys = ["Checks", "CheckOptions"]

    # Special keys first (in their defined order), then the rest sorted.
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        key_class = globals().get(f"Key_{key}")
        if key_class and hasattr(key_class, "merge"):
            result[key] = key_class.merge(repo_value, ref_value)
        elif repo_value is not None:
            # Compare against None rather than truthiness: an explicit falsy
            # repository setting (False, 0, '' or an empty list) is still a
            # deliberate choice and must not be replaced by the reference.
            result[key] = repo_value
        else:
            result[key] = ref_value

    with open(repo_path, "w") as f:
        f.write(format_yaml_output(result))
def subcmd_enable(args: argparse.Namespace) -> None:
    """Turn on ``args.check`` in the repository's clang-tidy config.

    The config located via ``args.repo`` is loaded; when it has a ``Checks``
    key, that value is rewritten by ``Key_Checks.enable``.  The config is
    then written back to the same path whether or not it was modified.
    """
    config_path, config = load_config(args.repo)

    if "Checks" in config:
        updated_checks = Key_Checks.enable(config["Checks"], args.check)
        config["Checks"] = updated_checks

    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))
def subcmd_disable(args: argparse.Namespace) -> None:
    """Turn off (or remove) ``args.check`` in the repository's config.

    The config located via ``args.repo`` is loaded; when it has a ``Checks``
    key, that value is rewritten by ``Key_Checks.disable``, with
    ``args.drop`` selecting removal of the entry instead of negating it.
    The config is then written back to the same path whether or not it was
    modified.
    """
    config_path, config = load_config(args.repo)

    if "Checks" in config:
        updated_checks = Key_Checks.disable(
            config["Checks"], args.check, args.drop
        )
        config["Checks"] = updated_checks

    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))
class Key_Checks:
    """Operations on the ``Checks`` value of a clang-tidy config.

    The value is treated as a string of check names separated by commas
    and/or whitespace, where a leading ``-`` marks a check as disabled.
    Internally it is handled as a ``{check_name: enabled}`` mapping, and it
    is always serialised back as ``-*`` followed by the sorted entries, one
    per line.
    """

    @staticmethod
    def merge(repo: str, ref: str) -> str:
        """Combine the repository's checks with the reference's checks.

        Every check listed by the repository keeps its enabled/disabled
        state.  Checks that only the reference lists are added in the
        disabled state.  Either argument may be empty or ``None``.
        """
        result = Key_Checks._split(repo)
        for name in Key_Checks._split(ref):
            # Reference-only checks are recorded but never switched on.
            result.setdefault(name, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` marked as enabled."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed when ``drop``.

        Dropping a check that is not listed is a no-op.
        """
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: str) -> Dict[str, bool]:
        """Parse a ``Checks`` string into a ``{name: enabled}`` mapping.

        Entries may be separated by commas, whitespace, or both, so the
        single-line form ``-*,foo,-bar`` is parsed the same way as the
        one-entry-per-line form produced by ``_join``.  The global ``-*``
        entry is skipped because ``_join`` always re-adds it; any other
        glob (for example ``bugprone-*``) is kept like a normal entry.
        An empty or ``None`` input yields an empty mapping.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result

        for item in re.split(r"[,\s]+", s):
            # Leading/trailing separators produce empty tokens.
            if not item or item == "-*":
                continue
            if item.startswith("-"):
                result[item[1:]] = False
            else:
                result[item] = True

        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Serialise a ``{name: enabled}`` mapping back into a string.

        The output starts with ``-*``, followed by the entries sorted by
        name (disabled ones prefixed with ``-``), joined with ``",\\n"`` and
        terminated by a newline.
        """
        return (
            ",\n".join(
                ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())]
            )
            + "\n"
        )
class Key_CheckOptions:
    """Operations on the ``CheckOptions`` value of a clang-tidy config."""

    @staticmethod
    def merge(
        repo: List[Dict[str, str]], ref: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Combine two lists of ``{"key": ..., "value": ...}`` entries.

        Entries from ``repo`` take priority; an entry from ``ref`` is added
        only when its key has not been seen yet.  Either argument may be
        empty or ``None``.  The result is sorted by key.
        """
        # Start from the repository's options; a repeated key keeps the
        # value of its last occurrence.
        options: Dict[str, str] = {
            entry["key"]: entry["value"] for entry in (repo or [])
        }

        # Fill in options that only the reference provides.
        for entry in ref or []:
            name = entry["key"]
            if name not in options:
                options[name] = entry["value"]

        merged: List[Dict[str, str]] = []
        for name, value in sorted(options.items()):
            merged.append({"key": name, "value": value})
        return merged
def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Locate and load a clang-tidy config.

    ``path`` may name the config file itself or a directory containing a
    ``.clang-tidy`` file.  It is treated as a directory when it is an
    existing directory, or when the string "clang-tidy" does not occur in
    it; in that case ``.clang-tidy`` is appended.

    Returns:
        A ``(resolved_path, config)`` tuple.  ``config`` is an empty dict
        when the file does not exist or contains no YAML document.
    """
    # The isdir() test keeps a directory whose own path happens to contain
    # "clang-tidy" from being opened as if it were the config file.
    if os.path.isdir(path) or "clang-tidy" not in path:
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r") as f:
        # safe_load() yields None for an empty file; normalise that to {}
        # so callers can always treat the result as a mapping.
        return (path, yaml.safe_load(f) or {})
def format_yaml_output(data: Dict[str, Any]) -> str:
    """Dump ``data`` as YAML and tidy up the vertical spacing.

    - every empty line produced by the dumper is removed
    - one empty line is inserted before each line that starts with an
      alphanumeric key followed by a colon, except before the first line
    - the result ends with a single newline
    """
    dumped = yaml.dump(data, sort_keys=False, indent=4)

    pretty: List[str] = []
    for text in dumped.split("\n"):
        if text == "":
            # Drop the dumper's own blank lines.
            continue

        # Separate keys with a blank line, but never start the output
        # with one.
        if pretty and re.match("[a-zA-Z0-9]+:", text):
            pretty.append("")

        pretty.append(text)

    # A trailing empty element makes the joined text end with a newline.
    pretty.append("")
    return "\n".join(pretty)
# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()