#!/usr/bin/env python3

"""Default processing of flake outputs for evaluating flake updates."""

import logging
import os
import sys
from argparse import Namespace
from pathlib import Path

from flupdt.cli import parse_inputs
from flupdt.common import configure_logger, partition
from flupdt.flake_build import build_output
from flupdt.flake_diff import compare_derivations
from flupdt.flake_eval import evaluate_output
from flupdt.flake_show import get_derivations

# Module-level logger; main() calls configure_logger("INFO") before dispatching.
logger = logging.getLogger(__name__)


def batch_eval(args: Namespace, flake_path: str, derivations: list[str]) -> None:
    """Bulk run evaluations or builds on a derivation set.

    For every derivation, ``evaluate_output`` is run when ``args.evaluate`` is
    set and ``build_output`` when ``args.build`` is set. If both flags are set,
    the build result replaces the evaluation result for that derivation.

    :params args: argument namespace; reads ``evaluate``, ``build`` and ``json``
    :params flake_path: path to flake to be evaluated
    :params derivations: list of derivations to run against
    :returns None
    :raises SystemExit: with status 1 if any recorded result is ``None``
        (presumably a failed evaluation/build — confirm in flupdt)
    """
    drv_map = {}
    for d in derivations:
        if args.evaluate:
            drv_map[d] = evaluate_output(flake_path, d)
        if args.build:
            drv_map[d] = build_output(flake_path, d)
    if args.json:
        from json import dump

        # Path(...) accepts both str and Path. Calling the unbound Path.open
        # with a plain str only works on some Python versions (3.11+).
        with Path(args.json).open("w", encoding="utf-8") as f:
            dump(drv_map, f)
    if any(x is None for x in drv_map.values()):
        sys.exit(1)


def compare_drvs(args: Namespace) -> None:
    """Compares derivation jsons using nvd.

    Loads the pre- and post-update snapshots (JSON objects keyed by flake
    output name). It warns about outputs that appear in only one snapshot and
    runs ``compare_derivations`` on every output present in both. Results are
    written to ``args.compare_output_file`` when
    ``args.compare_output_to_file`` is set; otherwise they go to
    ``os.devnull``.

    :param args: argparse namespace; reads ``compare_pre_json``,
        ``compare_post_json``, ``compare_output_to_file``,
        ``compare_output_file`` and ``flake_path``
    :returns: None
    """
    from json import load

    # Path(...) accepts both str and Path. The unbound Path.open(str, ...) form
    # only works on some Python versions.
    with (
        Path(args.compare_pre_json).open("r", encoding="utf-8") as pre,
        Path(args.compare_post_json).open("r", encoding="utf-8") as post,
    ):
        pre_json_dict = load(pre)
        post_json_dict = load(post)

    logger.debug(f"pre-snapshot derivations: {pre_json_dict}")
    logger.debug(f"post-snapshot derivations: {post_json_dict}")

    pre_json_keys = set(pre_json_dict)
    post_json_keys = set(post_json_dict)

    # Only outputs present in BOTH snapshots can be compared. A union here
    # would leave both "missing" sets always empty. It would also raise
    # KeyError when indexing the snapshot that lacks the output.
    common_keys_to_eval = pre_json_keys & post_json_keys

    missing_post_keys = pre_json_keys - post_json_keys
    missing_pre_keys = post_json_keys - pre_json_keys

    if missing_pre_keys:
        logger.warning(f"Following outputs are missing from pre-snapshot: {missing_pre_keys}")
    if missing_post_keys:
        logger.warning(f"Following outputs are missing from post-snapshot: {missing_post_keys}")

    logger.info(f"Evaluating the following outputs for differences: {common_keys_to_eval}")

    out_file = os.devnull
    if args.compare_output_to_file:
        out_file = args.compare_output_file

    with Path(out_file).open("w", encoding="utf-8") as f:
        # JSON object keys are str, and str hashing is randomized per run.
        # Sorting keeps the report order stable between runs.
        for output_key in sorted(common_keys_to_eval):
            comp_out = compare_derivations(
                args.flake_path, pre_json_dict[output_key], post_json_dict[output_key]
            )
            f.write(f"comparing {output_key}:" + "\n")
            if comp_out:
                f.write(comp_out + "\n\n")
            else:
                f.write("comparison output is empty, please check script logs\n\n")


def build_or_eval(args: Namespace) -> None:
    """Evaluate or build every output of the flake at ``args.flake_path``.

    The outputs reported by ``get_derivations`` are split on whether their
    name starts with ``hydraJobs``. The non-Hydra group is always passed to
    ``batch_eval``. The Hydra group is processed only when ``args.keep_hydra``
    is set. This assumes ``partition`` returns (non-matching, matching), as
    the variable names suggest — confirm in flupdt.common.

    param args: argparse namespace
    """
    flake_path = args.flake_path

    def _is_hydra_job(name):
        return name.startswith("hydraJobs")

    regular_outputs, hydra_outputs = partition(_is_hydra_job, get_derivations(flake_path))
    derivations = list(regular_outputs)
    hydra_jobs = list(hydra_outputs)

    logger.info(f"derivations: {derivations}")
    batch_eval(args, flake_path, derivations)

    if args.keep_hydra:
        batch_eval(args, flake_path, hydra_jobs)
    else:
        logger.info("--keep-hydra flag is not specified, removing Hydra jobs")


def main() -> None:
    """Configure logging, read CLI arguments and run the selected mode.

    ``compare_drvs`` handles the request when ``args.compare_drvs`` is truthy.
    Otherwise ``build_or_eval`` evaluates/builds the flake outputs.

    :returns: None
    """
    configure_logger("INFO")
    args = parse_inputs()
    run_mode = compare_drvs if args.compare_drvs else build_or_eval
    run_mode(args)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()