import json
from pathlib import Path

import fire
import numpy as np
import torch
from loguru import logger
from sklearn.metrics.pairwise import pairwise_distances
from tqdm import tqdm

from analysis.plot_E_alpha import E_alpha
from analysis.plot_magnitude_simple import positive_magnitude


# Folder scanned by ``pt_to_json`` when no ``result_dir`` argument is given.
RESULT_DIR = "temp/adam_22_04_2025/"
# Not referenced by the code visible in this file.
PATH = "mnist/new_eval_500_False/"

# Not referenced by the code visible in this file.
EVAL_KEYS = [
    "losses_batch",
    "score_losses_batch",
    "grad_norm",
]

# Names of the metrics that ``pt_to_json`` copies into every experiment
# summary (despite the name, the same list is also used when the
# non-EMA "eval" entry is the source).
EMA_KEYS = [
    "precision",
    "recall",
    "f_1_pr",
    "density",
    "coverage",
    "f_1_dc",
    "fid",
    "train_fid",
    "train_losses",
    "train_score_losses",
    "test_losses",
    "test_score_losses",
    "train_wass",
    "test_wass",
]

def _first_pt_file(folder: Path, prefix: str):
    """Return the first ``*.pt`` file under *folder* whose stem starts with *prefix*, or None."""
    return next(
        (f for f in folder.rglob("*.pt") if f.stem.startswith(prefix)),
        None,
    )


def _last_recorded(source, keys) -> dict:
    """Collect, for every key of *keys* found in *source*, its last recorded value.

    ``source[key][-1]`` is used when the entry can be indexed that way;
    otherwise the entry itself is kept. Keys absent from *source* are skipped.
    """
    collected = {}
    for key in keys:
        try:
            history = source[key]
            try:
                value = history[-1]
            except (IndexError, TypeError):
                # IndexError: empty sequence / numpy scalar / 0-dim array.
                # TypeError: plain Python scalars are not subscriptable.
                value = history
        except KeyError:
            continue
        collected[key] = value
    return collected


def _gradient_statistics(gn_array, learning_rate, batch_size, n_samples, beta) -> dict:
    """Summarise the squared gradient norms stored in *gn_array*.

    Returns the four gradient-related entries of an experiment summary; all
    of them are None when *gn_array* is empty.
    """
    if len(gn_array) >= 1:
        gn_sgld = float(
            np.sqrt((learning_rate * beta * gn_array * gn_array).mean() / n_samples)
        )
        gn_raw = float((gn_array * gn_array).mean())
        gn_bs = gn_raw * batch_size
        gn_lr = learning_rate * gn_raw
    else:
        gn_sgld, gn_bs, gn_raw, gn_lr = None, None, None, None

    return {
        "gradient_norms": gn_raw,
        "sgld_bound": gn_sgld,
        "gradient_norms_bs": gn_bs,
        "gradient_norms_lr": gn_lr,
    }


def _topological_statistics(losses, train_loss) -> dict:
    """Compute the topological entries of an experiment summary.

    *losses* is the ``["topological"]["losses"]`` entry of the results file.
    It must support ``.mean(dim=1)`` and conversion through ``np.array``
    (presumably a 2-D CPU torch tensor -- TODO confirm).
    """
    # Estimating the worst generalization error: distance between the final
    # train loss and the smallest mean (over dim 1) of the stored losses.
    best_train_loss = losses.mean(dim=1).min()
    stats = {"worst_generalization": train_loss - float(best_train_loss)}

    dist_matrix = pairwise_distances(np.array(losses), metric="manhattan")
    assert dist_matrix.ndim == 2, dist_matrix.shape

    ealpha = E_alpha(dist_matrix)
    # NOTE(review): t=32 is stored under the "sqrt_n" name; presumably the
    # square root of some sample count -- confirm before changing it.
    pmag_sqrt_n = positive_magnitude(dist_matrix, t=32.)
    pmag_small = positive_magnitude(dist_matrix, t=0.01)
    pmag = positive_magnitude(dist_matrix, t=1.)

    stats.update(
        {
            "E_alpha": float(ealpha),
            "Positive_magnitude": float(pmag),
            "Positive_magnitude_small": float(pmag_small),
            "Positive_magnitude_sqrt_n": float(pmag_sqrt_n),
        }
    )
    return stats


def _summarize_experiment(exp: Path, topological: bool, several_seeds: bool, ema: bool):
    """Build the summary dict of one experiment folder, or None if it has no result files.

    The key names read below are very specific to the current experiment
    format (``parameters*.pt`` for the configuration, ``eval*.pt`` for the
    evaluation results).
    """
    metadata_path = _first_pt_file(exp, "parameters")
    results_path = _first_pt_file(exp, "eval")
    if metadata_path is None or results_path is None:
        logger.warning(f"there is an empty experiment folder {str(exp)}")
        return None

    # NOTE(review): torch.load relies on pickle; only run this script on
    # trusted result folders.
    metadata = torch.load(metadata_path)
    # The results file is loaded once and reused for every section below.
    payload = torch.load(results_path)

    exp_dict = {}

    if several_seeds:
        # HACK, this is specific to the file names formats
        exp_dict["seed"] = exp.stem.split("_")[0]

    exp_dict.update(
        {
            "learning_rate": metadata["optim"]['lr'],
            "batch_size": metadata["training"]['batch_size'],
            "n_samples": metadata["data"]["n_samples"],
        }
    )

    # The temperature is optional in the stored configuration.
    try:
        exp_dict["temperature"] = metadata["optim"]['temperature']
    except KeyError:
        pass

    results_eval = payload["eval"]
    if ema and "ema_evals" in payload:
        logger.debug("Using ema evaluations")
        metric_source = payload["ema_evals"][0][0]
    else:
        # HACK to handle the simple datasets
        metric_source = results_eval
    exp_dict.update(_last_recorded(metric_source, EMA_KEYS))

    # HACK: the Wasserstein gap is only added when a float test value was
    # actually recorded (the metric may be missing or hold a placeholder).
    test_wass = exp_dict.get("test_wass")
    if isinstance(test_wass, float) and "train_wass" in exp_dict:
        exp_dict["wasserstein_generalization"] = test_wass - exp_dict["train_wass"]

    # Derived quantities; the loss metrics are required here, so a missing
    # one raises KeyError.
    exp_dict.update(
        {
            "score_generalization": exp_dict["test_score_losses"] - exp_dict["train_score_losses"],
            "generalization": exp_dict["test_losses"] - exp_dict["train_losses"],
            "ratio_bs_lr": exp_dict["batch_size"] / exp_dict["learning_rate"],
            "ratio_bs_lr2": exp_dict["batch_size"] / (exp_dict["learning_rate"]**2),
            "ratio_lr_bs": exp_dict["learning_rate"] / exp_dict["batch_size"],
        }
    )

    # FID metrics are optional: when they are missing or not numeric the
    # key is deliberately left out.
    try:
        exp_dict["fid_generalization"] = exp_dict["fid"] - exp_dict["train_fid"]
    except (KeyError, TypeError):
        pass

    # Gradient norms: prefer the ones stored with the topological data.
    has_topological = "topological" in payload
    if has_topological:
        gn_array = payload["topological"]["grad_norms"]
    else:
        gn_array = np.array(results_eval["grad_norm"])

    beta = 1. / exp_dict["temperature"] if "temperature" in exp_dict else 1

    exp_dict.update(
        _gradient_statistics(
            gn_array,
            learning_rate=exp_dict["learning_rate"],
            batch_size=exp_dict["batch_size"],
            n_samples=exp_dict["n_samples"],
            beta=beta,
        )
    )

    # Topological quantities
    if topological:
        if has_topological:
            exp_dict.update(
                _topological_statistics(
                    payload["topological"]["losses"],
                    exp_dict["train_losses"],
                )
            )
        else:
            logger.warning(f"No topological bounds in {str(results_path)}")

    return exp_dict


def pt_to_json(result_dir: str = RESULT_DIR,
               topological: bool = True,
               several_seeds: bool = False,
               ema: bool = True):
    """Gather the results of every experiment folder into ``all_results.json``.

    Each sub-folder of *result_dir* is searched recursively for a
    ``parameters*.pt`` file (configuration) and an ``eval*.pt`` file
    (evaluation results). One summary dict is built per folder and the
    collection, keyed by a running integer, is written to
    ``<result_dir>/all_results.json``. Folders lacking one of the two files
    are skipped with a warning.

    Args:
        result_dir: Directory containing one sub-folder per experiment.
        topological: If True, add the ``E_alpha`` / positive-magnitude
            quantities when the results file has a "topological" entry.
        several_seeds: If True, store as "seed" the part of the folder name
            preceding the first underscore.
        ema: If True, read the metrics from the "ema_evals" entry when the
            results file has one; otherwise (or when it is absent) read them
            from the "eval" entry.

    Raises:
        AssertionError: If *result_dir* is not an existing directory.
        KeyError: If a required configuration entry or loss metric is missing.
    """
    # HACK
    if several_seeds:
        logger.warning("several_seeds_activated, will only work for gmms")

    result_dir = Path(result_dir)
    if not result_dir.is_dir():
        # Raised explicitly (rather than via ``assert``) so that the check
        # survives ``python -O`` while keeping the same exception type.
        raise AssertionError(str(result_dir))

    # Sorted so that the experiment numbering is reproducible between runs
    # (the order returned by glob is unspecified).
    exp_list = sorted(d for d in result_dir.glob("*") if d.is_dir())
    logger.info(f"Found  {len(exp_list)} folders")

    final_results = {}
    for exp in tqdm(exp_list):
        exp_dict = _summarize_experiment(exp, topological, several_seeds, ema)
        if exp_dict is None:
            continue
        # Skipped folders do not consume a number.
        final_results[len(final_results)] = exp_dict

    # all_results
    output_path = result_dir / "all_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Collecting all results in {str(output_path)}")

    with output_path.open("w", encoding="utf-8") as output_file:
        json.dump(final_results, output_file, indent=2)



# When run as a script, hand ``pt_to_json`` to python-fire so that it can be
# driven from the command line.
if __name__ == "__main__":
    fire.Fire(pt_to_json)