import json
from pathlib import Path

import fire
import numpy as np
from loguru import logger
from tqdm import tqdm


# Fields that average_results copies unchanged from an experiment into its
# group; they are expected to be identical across seeds.
COMMON_KEYS = [
    "ratio_bs_lr", "ratio_lr_bs",
    "learning_rate", "batch_size",
    "n_samples", "temperature",
]

# Fields that average_results collects across the experiments of a group
# and then reduces to a mean and a standard deviation.
AVERAGE_KEYS = [
    # losses
    "train_losses", "train_score_losses",
    "test_losses", "test_score_losses",
    # Wasserstein metrics
    "train_wass", "test_wass",
    # generalization metrics
    "score_generalization", "generalization", "wasserstein_generalization",
    # bound and gradient statistics
    "sgld_bound", "gradient_norms", "gradient_norms_bs",
]


# Fields whose stringified values are joined with "_" by average_results
# to build the identifier of the group an experiment belongs to.
JOINED_KEYS = ["batch_size", "learning_rate", "n_samples", "temperature"]


def transpose_by_seed(all_results_path: str):
    """
    Regroup a flat JSON file of experiments by their seed.

    The JSON file should contain a mapping from experiment keys to
    experiment dicts, where "seed" is a key in each experiment.
    The regrouped mapping ``{seed: {experiment_key: experiment}}`` is
    written to ``transposed_all_results.json`` next to the input file
    and is also returned.

    Note: the returned dict is keyed by the seed values exactly as loaded,
    whereas ``json.dump`` turns non-string keys (e.g. integer seeds) into
    strings in the written file.

    Args:
        all_results_path: path of the JSON file to read.

    Returns:
        The experiments regrouped by seed.

    Raises:
        FileNotFoundError: if ``all_results_path`` does not exist.
        KeyError: if an experiment has no "seed" entry.
    """

    all_results_path = Path(all_results_path)
    if not all_results_path.exists():
        raise FileNotFoundError(f"{str(all_results_path)} not found.")

    with open(str(all_results_path), "r", encoding="utf-8") as json_file:
        all_results = json.load(json_file)

    transposed_dict = {}

    for key, experiment in all_results.items():
        # Explicit check instead of `assert`, which is stripped when Python
        # runs with -O and would let malformed input through unchecked.
        if "seed" not in experiment:
            raise KeyError(f"Experiment {key!r} has no 'seed' entry.")

        transposed_dict.setdefault(experiment["seed"], {})[key] = experiment

    # The output is written in the same directory as the input file
    output_path = all_results_path.parent / "transposed_all_results.json"

    logger.info(f"Collecting all results in {str(output_path)}")

    with open(str(output_path), "w", encoding="utf-8") as output_file:
        json.dump(transposed_dict, output_file, indent=2)
    logger.info(f"Saving transposed JSON in {str(output_path)}")

    return transposed_dict
    

def average_results(all_results: dict) -> dict:
    """
    Average experiment results over groups sharing the same settings.

    all_results maps each seed to a dict of experiments
    (``{seed: {experiment_key: experiment}}``).
    Experiments are grouped by the "_"-joined string values of JOINED_KEYS.
    For each group, the COMMON_KEYS entries are copied as-is, and every
    AVERAGE_KEYS entry ``k`` is replaced by its mean over the group, with
    ``k + "_std"`` (standard deviation) and ``k + "_list"`` (the collected
    values) added alongside.

    Means, standard deviations and lists are returned as plain Python
    floats/lists (not NumPy scalars or arrays) so that the result can be
    passed directly to ``json.dump``.

    If a metric cannot be averaged (non-numeric or ragged values), a warning
    is logged and the raw list of collected values is kept under ``k``.

    Raises:
        KeyError: if an experiment lacks one of the expected keys.
    """

    # we first construct a dict of lists
    dict_of_lists = {}

    for experiments in tqdm(all_results.values()):
        for experiment in tqdm(experiments.values()):

            key_id = "_".join(str(experiment[k]) for k in JOINED_KEYS)
            group = dict_of_lists.setdefault(key_id, {})

            # All the following should not change with the seed
            for common_key in COMMON_KEYS:
                group[common_key] = experiment[common_key]

            for average_key in AVERAGE_KEYS:
                group.setdefault(average_key, []).append(experiment[average_key])

    # turn the dict_of_lists into the desired dict
    for key_id, group in tqdm(dict_of_lists.items()):

        for average_key in AVERAGE_KEYS:
            collected = group[average_key]

            try:
                # Recent NumPy raises ValueError on ragged input, while
                # non-numeric values make mean()/std() raise TypeError.
                values = np.array(collected)
                mean = float(values.mean())
                std = float(values.std())
            except (TypeError, ValueError) as error:
                logger.warning(
                    f"Could not average {average_key} for {key_id}: {error}"
                )
                continue

            group[average_key] = mean
            group[average_key + "_std"] = std
            # tolist() yields native Python numbers (nested lists for
            # multi-dimensional data), which json can serialize.
            group[average_key + "_list"] = values.tolist()

    return dict_of_lists


def average_from_json(json_path: str):
    """
    Average the experiments stored in a JSON results file.

    The experiments are first regrouped by seed with transpose_by_seed
    (which also writes ``transposed_all_results.json``), then averaged with
    average_results. The averages are written to ``average_results.json``
    in the same directory as ``json_path``.

    Args:
        json_path: path of the JSON file containing all experiments.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
    """

    json_path = Path(json_path)
    # Explicit check instead of `assert` (stripped under `python -O`);
    # same exception type and message as transpose_by_seed.
    if not json_path.exists():
        raise FileNotFoundError(f"{str(json_path)} not found.")

    transposed_results = transpose_by_seed(json_path)

    averaged = average_results(transposed_results)

    # The averages are written next to the input file
    output_path = json_path.parent / "average_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Collecting all results in {str(output_path)}")

    with open(str(output_path), "w", encoding="utf-8") as output_file:
        json.dump(averaged, output_file, indent=2)


# Script entry point: python-fire builds the command-line interface from
# average_from_json, so its json_path parameter is given on the command line.
if __name__ == "__main__":
    fire.Fire(average_from_json)