"""Utility to extract flake output info using nix flake (show|check)."""

import json
import logging
import re
import shutil
import typing

from flupdt.common import bash_wrapper

# Patterns for the progress lines emitted by ``nix flake check --verbose``; group 1
# captures the attribute path of the output being checked. The trailing dots are
# escaped so they only match a literal "..." (an unescaped "." matches any character).
output_regexes = [
    re.compile(r"checking derivation (.*)\.\.\."),
    re.compile(r"checking NixOS configuration '(nixosConfigurations.*)'\.\.\."),
]


def traverse_json_base(json_dict: dict[str, typing.Any], path: list[str]) -> list[str]:
    """Recursively collect the dotted paths of nixos-configuration and derivation outputs.

    A nested dict whose ``type`` entry is ``"nixos-configuration"`` or ``"derivation"``
    is treated as a leaf and its dotted path is recorded; any other nested dict is
    descended into. Non-dict values are skipped.

    :param json_dict: (sub)tree of flake outputs to scan
    :param path: attribute names leading from the root to ``json_dict``
    :returns dotted paths of every leaf found under ``json_dict``, in iteration order
    """
    leaf_types = ("nixos-configuration", "derivation")
    collected: list[str] = []
    for name, child in json_dict.items():
        if not isinstance(child, dict):
            continue
        child_path = [*path, name]
        if "type" in child and child["type"] in leaf_types:
            collected.append(".".join(child_path))
        else:
            collected.extend(traverse_json_base(child, child_path))
    return collected


def traverse_json(json_dict: dict) -> list[str]:
    """Return the dotted paths of every nixos-configuration and derivation output.

    Starts the recursive walk performed by ``traverse_json_base`` at the root of the
    flake outputs, with an empty attribute path.

    :param json_dict: dict of flake outputs to check
    :returns a list of outputs that can be evaluated
    """
    root_path: list[str] = []
    return traverse_json_base(json_dict, path=root_path)


def get_derivations_from_check(nix_path: str, path_to_flake: str) -> list[str]:
    """Gets all derivations in a flake, using check instead of show.

    Runs ``nix flake check --verbose --keep-going`` in ``path_to_flake`` and scrapes
    attribute paths from its "checking ..." progress lines using ``output_regexes``.
    A non-zero exit code is only logged, so whatever outputs were reported are still
    returned.

    :param nix_path: path to nix binary
    :param path_to_flake: path to flake to be checked
    :returns a list of all valid derivations in the flake, in the order they were reported
    """
    # NOTE(review): the bash_wrapper result is indexed as (stdout, stderr, exit code),
    # inferred from how this module uses it -- confirm against flupdt.common.
    flake_check = bash_wrapper(f"{nix_path} flake check --verbose --keep-going", path=path_to_flake)
    if flake_check[2] != 0:
        logging.warning(
            "nix flake check returned non-zero exit code, collecting all available outputs"
        )
    derivations: list[str] = []
    # The progress messages are read from element 1 (presumably stderr).
    for line in flake_check[1].split("\n"):
        if not line.startswith("checking"):
            continue
        for pattern in output_regexes:
            # Lazy %-style args: only formatted when DEBUG logging is enabled.
            logging.debug("%s %s", line, pattern.pattern)
            match = pattern.match(line)
            if match is not None:
                logging.debug("%s", match.groups())
                derivations.append(match.group(1))
                # Record each line at most once; first matching pattern wins.
                break
    return derivations


def get_derivations(path_to_flake: str) -> list[str]:
    """Gets all derivations present in a flake.

    Uses ``nix flake show --json`` when it succeeds; if it exits non-zero, falls back
    to ``get_derivations_from_check``. Every ``nixosConfigurations.<name>`` entry is
    rewritten to ``nixosConfigurations.<name>.config.system.build.toplevel``.

    :param path_to_flake: path to flake to be checked
    :returns a list of all valid derivations in the flake
    :raises RuntimeError: fails if nix is not present in the PATH
    """
    nix_path = shutil.which("nix")
    if nix_path is None:
        status_msg = "nix is not available in the PATH, please verify that it is installed"
        raise RuntimeError(status_msg)
    flake_show = bash_wrapper(f"{nix_path} flake show --json", path=path_to_flake)
    if flake_show[2] != 0:
        logging.error("flake show returned non-zero exit code")
        logging.warning("falling back to full evaluation via nix flake check")
        derivations = get_derivations_from_check(nix_path, path_to_flake)
    else:
        derivations = traverse_json(json.loads(flake_show[0]))
    # Require the full "nixosConfigurations." component so that an unrelated output
    # such as "nixosConfigurationsExtra.foo" is not given the toplevel suffix.
    return [
        f"{drv}.config.system.build.toplevel" if drv.startswith("nixosConfigurations.") else drv
        for drv in derivations
    ]