import json
from ast import literal_eval
from pathlib import Path

import fire
import numpy as np
import pandas as pd
from loguru import logger
from tqdm import tqdm


# Default directory scanned for experiment ``*.csv`` files when csv_to_json
# is called without an explicit ``result_dir``.
RESULT_DIR: str = "temp/adam_500_mnist_23_04_2025/"


# Convert each experiment CSV in a results directory into per-run JSON summaries.


def csv_to_json(result_dir: str = RESULT_DIR, max_lr: float = 0.001):
    """Summarise every experiment CSV in ``result_dir`` as a JSON file.

    Each ``*.csv`` file directly inside ``result_dir`` is read with pandas and
    every row is treated as one run. Runs whose ``p_optim_lr`` is below
    ``max_lr`` are summarised (see ``_summarize_run``). All kept runs of one
    CSV are written to ``<result_dir>/<csv stem>/all_results.json``, keyed by
    the row index.

    Args:
        result_dir: Directory containing the experiment CSV files.
        max_lr: Runs with ``p_optim_lr >= max_lr`` are skipped. The default
            reproduces the previously hard-coded 0.001 cutoff. Pass
            ``float("inf")`` to keep every run.

    Raises:
        NotADirectoryError: If ``result_dir`` is not an existing directory.
    """
    result_dir = Path(result_dir)
    # An explicit raise (unlike ``assert``) still fires under ``python -O``.
    if not result_dir.is_dir():
        raise NotADirectoryError(str(result_dir))

    # Sorted so files are processed (and logged) in a reproducible order.
    exp_list = sorted(result_dir.glob("*.csv"))
    logger.info(f"Found  {len(exp_list)} CSV file(s)")

    for exp in tqdm(exp_list):
        # Transposing before ``to_dict`` yields ``{row_index: {column: value}}``,
        # i.e. one column->value dict per run.
        results = pd.read_csv(str(exp)).transpose().to_dict()

        all_results = {}
        for k, row in results.items():
            # WARNING HACK: only runs with a learning rate below ``max_lr`` are
            # kept. The filter runs before any entry is created, so skipped
            # runs no longer show up as empty ``{}`` objects in the output.
            if row["p_optim_lr"] >= max_lr:
                continue
            all_results[k] = _summarize_run(row)

        output_path = result_dir / exp.stem / "all_results.json"
        output_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Collecting all results in {str(output_path)}")

        with open(str(output_path), "w") as output_file:
            json.dump(all_results, output_file, indent=2)


def _summarize_run(row: dict) -> dict:
    """Build the JSON summary for one run (one CSV row as a column->value dict)."""
    lr = row["p_optim_lr"]
    bs = row["p_training_batch_size"]
    train_losses = row["ema_0.999_train_losses"]
    test_losses = row["ema_0.999_test_losses"]
    train_score_losses = row["ema_0.999_train_score_losses"]
    test_score_losses = row["ema_0.999_test_score_losses"]
    grad_norm = _parse_last_grad_norm(row["eval_grad_norm"])

    return {
        "score_generalization": test_score_losses - train_score_losses,
        "generalization": test_losses - train_losses,
        "train_score_losses": train_score_losses,
        "test_score_losses": test_score_losses,
        "ratio_bs_lr": bs / lr,
        "gradient_norms": grad_norm,
        "sgld_bound": grad_norm * bs,
        "learning_rate": lr,
        "batch_size": bs,
        "fid": row["ema_0.999_fid"],
    }


def _parse_last_grad_norm(raw: str) -> float:
    """Return the last number of a bracketed, whitespace-separated array string."""
    # NOTE(review): assumes the CSV stores gradient norms as text such as
    # "[0.1 0.2 0.3]" (a printed array) -- confirm against the code writing it.
    # ``split()`` with no argument also tolerates repeated spaces and newlines,
    # which ``split(" ")`` turned into empty or glued tokens.
    last = raw.strip()[1:-1].split()[-1]
    logger.debug(last)
    return float(last)

    
if __name__ == "__main__":
    # Run csv_to_json as a command-line tool through python-fire.
    fire.Fire(component=csv_to_json)