import argparse
import itertools
import json
import os
import traceback
from typing import Any

import numpy as np
import pandas as pd

import computational_experiments.real_world.src.data_utils as utils
from computational_experiments.real_world.src.run_experiment import run_experiment

# Exclusive upper bound passed to ``rng.integers`` when drawing per-run seeds
# (largest value representable as a signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Name of the JSON file, inside the base folder, that holds the option grid.
CONFIG_FILENAME = "options.json"
# Name of the CSV file, inside the base folder, that receives the results table.
DATAFRAME_FILENAME = "grid_results.csv"
# Name of the log file, inside the base folder, handed to ``utils.get_logger``.
LOG_FILENAME = "output.log"


def load_options(base_folder: str) -> dict[str, Any]:
    """Load the option grid stored in ``base_folder``.

    Args:
        base_folder: Directory that contains the ``CONFIG_FILENAME`` JSON file.

    Returns:
        The decoded JSON content of the options file.

    Raises:
        FileNotFoundError: If the options file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    options_path = os.path.join(base_folder, CONFIG_FILENAME)
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding, which differs between operating systems.
    with open(options_path, encoding="utf-8") as fin:
        return json.load(fin)


def expand_options(options_dict: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand an option grid into the list of all concrete configurations.

    Every key of ``options_dict`` maps to an iterable of candidate values; the
    result holds one dict per element of the Cartesian product of those
    iterables, ordered so that the last key varies fastest.

    Note that each value is iterated as-is, so a plain string is expanded
    character by character; wrap single values in a list.

    Args:
        options_dict: Mapping from option name to an iterable of candidates.

    Returns:
        A list of dicts, each mapping every option name to one chosen value.
        An empty ``options_dict`` yields a single empty configuration, and any
        option with no candidates yields an empty list.
    """
    keys = list(options_dict)
    # Iterating the values directly (instead of unpacking ``zip(*items)``)
    # keeps the empty-grid case from failing on tuple unpacking.
    return [
        dict(zip(keys, values))
        for values in itertools.product(*options_dict.values())
    ]


def _collect_run_record(config_id: int, config: Any, output: Any) -> dict[str, Any]:
    """Flatten one finished run into a single row for the results table."""
    record = {"training_time": output.training_time}
    record.update(
        (f"metric_{label}", metric)
        for label, metric in output.aggregate_metrics().items()
    )
    record.update(
        (f"prediction_time_{label}", elapsed)
        for label, elapsed in output.prediction_times.items()
    )
    record["config_id"] = config_id
    # The config attributes (including ``run_id``) become columns as well.
    record.update(vars(config))
    return record


def run_search(base_folder: str, n_repeats: int, seed: int) -> None:
    """Run every configuration of the option grid ``n_repeats`` times.

    The grid is read from the options file in ``base_folder`` and expanded
    into concrete configurations. Each run receives a 1-based ``run_id`` and
    its own seed drawn from a generator initialised with ``seed``. A run whose
    ``run_<run_id>`` folder already exists under the config's output folder is
    skipped. A run that raises is logged together with its config and
    traceback, and the search continues with the next run.

    The results of the runs executed by this call are written to
    ``DATAFRAME_FILENAME`` in ``base_folder``, indexed by ``run_id``. Skipped
    and failed runs do not appear in that table.

    Args:
        base_folder: Folder containing the options file; it also receives the
            log file and the results CSV.
        n_repeats: Number of runs per configuration.
        seed: Seed for the generator that produces the per-run seeds.
    """
    logger = utils.get_logger(os.path.join(base_folder, LOG_FILENAME))

    options_dict = load_options(base_folder)
    config_dicts = expand_options(options_dict)
    n_runs = len(config_dicts) * n_repeats

    logger.info(f"Prepared {n_runs} runs ({n_repeats=}).")

    rng = np.random.default_rng(seed)
    results = []
    for config_id, config_dict in enumerate(config_dicts):
        for repeat_id in range(n_repeats):
            run_id = config_id * n_repeats + repeat_id + 1
            # The seed is drawn before the skip check on purpose: every run
            # consumes exactly one draw, so a given run_id gets the same seed
            # whether or not earlier runs were skipped.
            run_kwargs = {
                **config_dict,
                "run_id": run_id,
                "seed": rng.integers(MAX_SEED),
            }
            config = utils.DataConfig(**run_kwargs)

            try:
                run_folder = os.path.join(config.output_folder, f"run_{run_id}")
                if os.path.exists(run_folder):
                    logger.info(f"{run_folder} exists, skipping run.")
                    continue

                output = run_experiment(config)
                results.append(_collect_run_record(config_id, config, output))
                logger.info(f"Run {run_id:>3}/{n_runs}: done.")
            except Exception:
                # Deliberately broad: one failing run must not abort the rest
                # of the grid. The full traceback goes to the log instead.
                msg = f"Run {run_id}/{n_runs} failed with the following config:\n\n"
                msg += f"\t{config}\n\n\t"
                msg += "\n\t".join(traceback.format_exc().splitlines())
                logger.error(msg)

    if not results:
        # Without any record there is no "run_id" column to index by, and
        # nothing worth writing.
        logger.info("No runs produced results; results file was not written.")
        return

    result_df = pd.DataFrame.from_records(results)
    result_df.set_index("run_id", inplace=True)
    result_df.to_csv(os.path.join(base_folder, DATAFRAME_FILENAME))


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and start the grid search.
    parser = argparse.ArgumentParser(
        description="Run a grid search over the options stored in a folder.",
    )
    parser.add_argument(
        "-f",
        "--folder",
        type=str,
        # Without a folder there is nothing to load; fail at parse time with a
        # usage message instead of a TypeError deep inside run_search.
        required=True,
        help="Base folder with options.json for storing the grid search results.",
    )
    parser.add_argument(
        "-r",
        "--n_repeats",
        type=int,
        default=1,
        help="Number of runs for each configuration (default: %(default)s).",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        # Stays optional: numpy's default_rng accepts None and then seeds
        # itself from fresh OS entropy.
        default=None,
        help="Global random seed.",
    )

    args = parser.parse_args()
    run_search(args.folder, args.n_repeats, args.seed)
