import argparse
import os
import itertools
import json

import numpy as np
from typing import Any
import pandas as pd
import traceback

from climate.src.run_experiment import run_experiment
import climate.src.climate_utils as utils

# Exclusive upper bound for the per-run seeds drawn in run_search, so every
# seed lies in [0, 2**31 - 1) and fits in a signed 32-bit integer.
MAX_SEED: int = int(np.iinfo(np.int32).max)

# File names, all resolved relative to the grid-search base folder.
CONFIG_FILENAME: str = "options.json"  # grid definition read by load_options
DATAFRAME_FILENAME: str = "grid_results.csv"  # per-run result table written by run_search
LOG_FILENAME: str = "output.log"  # log file passed to utils.get_logger

def load_options(base_folder: str) -> dict[str, Any]:
    """Read the grid-search definition stored in ``base_folder``.

    The file name is ``CONFIG_FILENAME``. Its content is presumably a JSON
    object mapping option names to candidate values (it is handed to
    ``expand_options`` by ``run_search``).

    Args:
        base_folder: Folder that contains the options file.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If the options file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    options_path = os.path.join(base_folder, CONFIG_FILENAME)
    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(options_path, encoding="utf-8") as fin:
        return json.load(fin)

def expand_options(options_dict: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand a grid definition into the Cartesian product of its options.

    Every value that is a list (or tuple) is a grid axis whose elements are
    tried in turn. Any other value (string, number, bool, None, dict) is a
    single fixed setting shared by every configuration. This stops a string
    from being split into characters or a number from raising ``TypeError``.
    Configurations are ordered like ``itertools.product``: the last key
    varies fastest.

    Args:
        options_dict: Mapping from option name to candidate value(s).

    Returns:
        One dict per configuration, keys in the order of ``options_dict``.
        An empty ``options_dict`` yields a single empty configuration. An
        empty list for any option yields no configurations at all.
    """
    keys = list(options_dict)
    axes = [
        value if isinstance(value, (list, tuple)) else [value]
        for value in options_dict.values()
    ]
    return [dict(zip(keys, values)) for values in itertools.product(*axes)]



def _summarize_run(output: Any, config_id: int, config: Any) -> dict[str, Any]:
    """Flatten one experiment's output and its config into a single result row.

    Keys, in insertion order:
    - ``training_time``;
    - one ``metric_<label>`` per entry of ``output.aggregate_metrics()``;
    - one ``prediction_time_<label>`` per entry of ``output.prediction_times``;
    - ``config_id``;
    - every attribute of ``config`` (via ``vars``).

    Later keys overwrite earlier ones on a name clash.
    """
    row: dict[str, Any] = {"training_time": output.training_time}
    row.update({f"metric_{label}": metric
                for label, metric in output.aggregate_metrics().items()})
    row.update({f"prediction_time_{label}": duration
                for label, duration in output.prediction_times.items()})
    row["config_id"] = config_id
    row.update(vars(config))
    return row


def run_search(base_folder: str, n_repeats: int, seed: int | None):
    """Run every grid configuration ``n_repeats`` times and save the results.

    The grid is read from ``base_folder`` via ``load_options``/``expand_options``.
    Each run gets a 1-based ``run_id`` and its own seed drawn from a single
    generator seeded with ``seed``, so the per-run seeds are reproducible.
    Runs whose experiment raises are logged and skipped. Successful runs
    become one row each in ``DATAFRAME_FILENAME`` (indexed by ``run_id``).
    If no run succeeds, a header-only CSV is written instead.

    Args:
        base_folder: Folder holding the options file; the log file and the
            result CSV are written here too.
        n_repeats: Number of runs per configuration.
        seed: Seed for the generator that draws the per-run seeds
            (``None`` lets NumPy pick fresh entropy).

    Raises:
        Errors from loading the options or from constructing
        ``utils.ClimateConfig`` are not caught and abort the whole search.
    """
    logger = utils.get_logger(os.path.join(base_folder, LOG_FILENAME))

    config_dicts = expand_options(load_options(base_folder))
    n_runs = len(config_dicts) * n_repeats

    logger.info(f"Prepared {n_runs} runs ({n_repeats=}).")

    rng = np.random.default_rng(seed)
    results = []
    for config_id, config_dict in enumerate(config_dicts):
        for repeat_id in range(n_repeats):
            run_id = config_id * n_repeats + repeat_id + 1
            # Build a fresh dict per run rather than mutating the shared grid entry.
            # `rng.integers` returns a NumPy int64; convert it to a plain int so it
            # behaves like any Python int downstream (e.g. it is JSON-serializable).
            run_dict = {
                **config_dict,
                "run_id": run_id,
                "seed": int(rng.integers(MAX_SEED)),
            }
            config = utils.ClimateConfig(**run_dict)

            try:
                output = run_experiment(config)
                row = _summarize_run(output, config_id, config)
            except Exception as e:
                msg = f"Run {run_id}/{n_runs} failed with the following config:\n\n"
                msg += f"\t{config}\n\n\t"
                msg += "\n\t".join(traceback.format_exception(e))
                logger.error(msg)
            else:
                results.append(row)
                logger.info(f"Run {run_id:>3}/{n_runs}: done.")

    if results:
        result_df = pd.DataFrame.from_records(results)
        result_df.set_index("run_id", inplace=True)
    else:
        # A frame built from no records has no "run_id" column, so set_index
        # would raise KeyError; write a header-only table instead.
        logger.error(f"No successful runs out of {n_runs}; writing an empty result table.")
        result_df = pd.DataFrame(index=pd.Index([], name="run_id"))
    result_df.to_csv(os.path.join(base_folder, DATAFRAME_FILENAME))



if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a grid search over the options in <folder>/options.json."
    )
    # Without `required`/defaults, omitted arguments arrive as None and the
    # search crashes later (os.path.join(None, ...), len(...) * None).
    parser.add_argument(
        "-f",
        "--folder",
        type=str,
        required=True,
        help="Base folder with options.json for storing the grid search results.",
    )
    parser.add_argument(
        "-r",
        "--n_repeats",
        type=int,
        default=1,
        help="Number of runs for each configuration (default: 1).",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=None,
        help="Global random seed (default: None, i.e. fresh OS entropy).",
    )

    args = parser.parse_args()
    if args.n_repeats < 1:
        parser.error("--n_repeats must be a positive integer.")
    run_search(args.folder, args.n_repeats, args.seed)
