#!/usr/bin/env python3
import argparse
import os
import re
from typing import Any, Dict, List, Tuple

import yaml


def main() -> None:
    """Parse the command line and dispatch to the selected subcommand handler."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    # NOTE(review): no handler in this file reads ``args.commit`` yet.
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # An explicit ``dest`` gives argparse a name to report when no subcommand
    # is supplied; without it, older Python versions raise a TypeError
    # instead of printing the "required" usage error.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    # ``format`` reuses the merge handler: with no ``--reference`` option the
    # repository config is merged against an empty config, i.e. re-normalized.
    parser_format = subparsers.add_parser(
        "format", help="Format a clang-tidy config"
    )
    parser_format.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a check in the repository clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a check in the repository clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)

def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge a reference clang-tidy config into the repository config.

    Also backs the ``format`` subcommand, which registers no ``--reference``
    option; the reference is then an empty config, so the repository config
    is only re-normalized and rewritten.

    Output keys are ordered ``Checks``, ``CheckOptions`` (when present), then
    every other key alphabetically.  A key with a ``Key_<name>`` class that
    defines ``merge`` is combined by that method.  Any other key keeps the
    repository value when the repository sets it to something other than
    ``None``; otherwise the reference value is used.

    Args:
        args: Parsed CLI arguments; reads ``repo`` and, when present,
            ``reference``.
    """
    repo_path, repo_config = load_config(args.repo)
    # ``format`` shares this handler but has no --reference argument.
    _, ref_config = (
        load_config(args.reference) if "reference" in args else ("", {})
    )

    special_keys = ["Checks", "CheckOptions"]
    all_keys = set(repo_config) | set(ref_config)
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        # Keys that need custom merge logic have a ``Key_<name>`` class.
        key_class = globals().get(f"Key_{key}")
        if key_class and hasattr(key_class, "merge"):
            result[key] = key_class.merge(repo_value, ref_value)
        elif repo_value is not None:
            # Compare against None rather than testing truthiness: an explicit
            # ``false`` / ``''`` / ``0`` in the repository config must survive
            # the merge instead of being replaced (or turned into ``null``).
            result[key] = repo_value
        else:
            result[key] = ref_value

    # Render before opening so the file is not truncated if rendering fails.
    rendered = format_yaml_output(result)
    with open(repo_path, "w") as f:
        f.write(rendered)


def subcmd_enable(args: argparse.Namespace) -> None:
    """Mark ``args.check`` as enabled in the repository clang-tidy config.

    Only an existing ``Checks`` key is edited.  The config file is always
    rewritten, so a config without ``Checks`` is simply re-formatted.
    """
    config_path, config = load_config(args.repo)

    has_checks = "Checks" in config
    if has_checks:
        config["Checks"] = Key_Checks.enable(config["Checks"], args.check)

    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))


def subcmd_disable(args: argparse.Namespace) -> None:
    """Disable ``args.check`` in the repository clang-tidy config.

    ``Checks`` is passed to ``Key_Checks.disable`` and ``CheckOptions`` to
    ``Key_CheckOptions.disable``, both with ``args.drop``; keys absent from
    the config are left alone.  The config file is always rewritten.
    """
    config_path, config = load_config(args.repo)
    check_name, drop_it = args.check, args.drop

    per_key_disablers = (
        ("Checks", Key_Checks.disable),
        ("CheckOptions", Key_CheckOptions.disable),
    )
    for config_key, disabler in per_key_disablers:
        if config_key in config:
            config[config_key] = disabler(config[config_key], check_name, drop_it)

    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))


class Key_Checks:
    """Helpers for the clang-tidy ``Checks`` key.

    A ``Checks`` string is normalized into a ``{check_name: enabled}`` mapping
    and written back as ``-*`` followed by one explicit entry per check,
    sorted by name, one per line.
    """

    @staticmethod
    def merge(repo: str, ref: str) -> str:
        """Combine repository and reference ``Checks`` strings.

        Repository entries keep their enabled/disabled state; checks that only
        the reference mentions are added as disabled.  Either argument may be
        ``None`` or empty.
        """
        result = Key_Checks._split(repo)
        for check in Key_Checks._split(ref):
            result.setdefault(check, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` explicitly enabled."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed when ``drop``."""
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: str) -> Dict[str, bool]:
        """Parse a ``Checks`` string into ``{check_name: enabled}``.

        Entries may be separated by commas, whitespace, or both.  This covers
        the single-line form (``"a,-b"``) as well as this script's
        one-entry-per-line output.  Returns ``{}`` for ``None`` or ``""``.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        for item in re.split(r"[\s,]+", s):
            if not item:
                # Leading/trailing separators yield empty fragments.
                continue
            # Skip the global "-*" and category-wide disables such as
            # "-llvm-*": _join always emits a leading "-*".
            if item.startswith("-") and "*" in item:
                continue
            if item.startswith("-"):
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Render ``data`` as ``-*`` plus sorted entries, one per line."""
        return (
            ",\n".join(
                ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())]
            )
            + "\n"
        )


class Key_CheckOptions:
    """Helpers for the clang-tidy ``CheckOptions`` key.

    Input may be the list form (``[{"key": ..., "value": ...}, ...]``), the
    mapping form (``{key: value, ...}``), or ``None``.  Options are handled
    internally as a flat ``{option_key: value}`` dict.  They are written back
    in list form, sorted by key.
    """

    @staticmethod
    def merge(repo: Any, ref: Any) -> List[Dict[str, str]]:
        """Add reference options that the repository config does not set.

        Repository values win for keys present in both.  If the reference
        repeats a key, its first occurrence is used.
        """
        merged = Key_CheckOptions._unroll(repo)
        for key, value in Key_CheckOptions._pairs(ref):
            merged.setdefault(key, value)
        return Key_CheckOptions._roll(merged)

    @staticmethod
    def disable(repo: Any, option: str, drop: bool) -> Any:
        """Remove the options that belong to a check when ``drop`` is set.

        With ``drop`` false, ``repo`` is returned unchanged.  Otherwise every
        option whose key equals ``option`` is removed.  So is every option
        namespaced under it as ``"<option>.<Name>"``, the form clang-tidy uses
        for per-check options.
        """
        if not drop:
            return repo

        prefix = f"{option}."
        remaining = {
            key: value
            for key, value in Key_CheckOptions._unroll(repo).items()
            if key != option and not key.startswith(prefix)
        }
        return Key_CheckOptions._roll(remaining)

    @staticmethod
    def _pairs(options: Any) -> List[Tuple[str, str]]:
        """Return ``(key, value)`` pairs from either ``CheckOptions`` form."""
        if not options:
            return []
        if isinstance(options, dict):
            return list(options.items())
        return [(item["key"], item["value"]) for item in options]

    @staticmethod
    def _unroll(repo: Any) -> Dict[str, str]:
        """Flatten ``CheckOptions`` into a dict; later duplicates win."""
        return dict(Key_CheckOptions._pairs(repo))

    @staticmethod
    def _roll(data: Dict[str, str]) -> List[Dict[str, str]]:
        """Convert a flat dict back to the key-sorted list form."""
        return [{"key": k, "value": v} for k, v in sorted(data.items())]


def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Load a clang-tidy YAML config.

    Args:
        path: A directory, or a path whose final component does not mention
            ``clang-tidy``; either way ``<path>/.clang-tidy`` is used.
            Otherwise ``path`` is treated as the config file itself
            (e.g. ``repo/.clang-tidy``).

    Returns:
        ``(resolved_path, config)``.  ``config`` is ``{}`` when the file does
        not exist or holds no YAML document.
    """
    # Only inspect the final component, so a parent directory whose name
    # contains "clang-tidy" is not mistaken for a config file.
    if os.path.isdir(path) or "clang-tidy" not in os.path.basename(path):
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r") as f:
        # Drop full-line comments.  Each line keeps its own newline, so the
        # lines are concatenated; re-joining with "\n" would insert a blank
        # line after every line and alter multi-line scalar values.
        data = "".join(line for line in f if not line.startswith("#"))
    # An empty or comment-only file loads as None; callers expect a dict.
    return (path, yaml.safe_load(data) or {})


def format_yaml_output(data: Dict[str, Any]) -> str:
    """Dump ``data`` to YAML, then tidy the blank lines.

    Blank lines emitted by ``yaml.dump`` are removed.  A single blank line is
    inserted before every line that starts with an alphanumeric key followed
    by ``:`` (i.e. a top-level key), except the first line.  The result ends
    with a newline.
    """
    key_line = re.compile("[a-zA-Z0-9]+:")
    dumped = yaml.dump(data, sort_keys=False, indent=4)

    rendered: List[str] = []
    for row in dumped.split("\n"):
        if row == "":
            continue
        if rendered and key_line.match(row):
            rendered.append("")
        rendered.append(row)

    return "\n".join(rendered + [""])


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
