import json
from pathlib import Path

import fire
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from loguru import logger

# Marker shapes, indexed per batch size in the scatter plots.
MARKERS = ["o", "^", "s", "v", "+", "D", "*", "p"]

# Global matplotlib configuration, applied once at import time.
# The rc call comes after the style so the bold font is not overridden by it.
plt.style.use("_classic_test_patch")

font = dict(weight="bold", size=14)

matplotlib.rc("font", **font)


class NoConvergedExperiment(Exception):
    """Raised when no experiment in the results is usable for plotting."""


# Human-readable axis / colorbar labels, keyed by the metric names stored in
# the results dictionaries. Values may contain matplotlib mathtext (raw strings).
LABELS = {
    # Generalization metrics
    "fid": "FID",
    "generalization": "Generalization gap",
    "score_generalization": "Score-based generalization error",
    # Hyper-parameters and their ratios
    "ratio_bs_lr": r"$b / \eta$",
    "ratio_bs_lr2": r"$b / \eta^2$",
    "ratio_lr_bs": r"$\eta / b$",
    "batch_size": "Batch size",
    "learning_rate": "Learning rate",
    # Losses
    "train_losses": "Train loss",
    "test_losses": "Test loss",
    "train_score_losses": "Score-based train loss",
    "test_score_losses": "Score-based test loss",
    # Complexity measures / bounds
    "E_alpha": r"$E^1(W^{(n)})$",
    "Positive_magnitude": r"$\mathbf{PMag}(W^{(n)})$",
    "Positive_magnitude_small": r"$\mathbf{PMag}(10^{-2} \cdot W^{(n)})$",
    "Positive_magnitude_sqrt_n": r"$\mathbf{PMag}(\sqrt{n} \cdot W^{(n)})$",
    "sgld_bound": r"$\sqrt{ \eta \langle \Vert \widehat{g}_k \Vert^2 \rangle / n}$",
    "gradient_norms": r"$\langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    "gradient_norms_lr": r"$\eta \times \langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    "gradient_norms_bs": r"$b \times \langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    # Other evaluation metrics (capitalized like every other label)
    "test_wass": "Test Wasserstein metric",
    "fid_generalization": "FID generalization",
    "train_fid": "Train FID",
    "worst_generalization": "Worst score generalization gap",
}



def plot_from_dict(results: dict,
                   output_path: str,
                   generalization_key: str = "score_generalization",
                   complexity_key: str = "ratio_lr_bs",
                   yscale: str = "",
                   save: bool = True,
                   colorbar: bool = True,
                   ylabel: str = r"$\eta / b$",
                   max_learning_rate: float = 1.e-3,
                   ):
    """Scatter-plot a generalization metric against a complexity measure.

    One scatter series is drawn per batch size, each with its own marker.
    Markers are assigned in increasing batch-size order and cycle when there
    are more batch sizes than entries in ``MARKERS``. Points are colored by
    learning rate on a single log-scale normalization shared by every series,
    so a given color means the same learning rate everywhere.

    Args:
        results: mapping from experiment id to a dict of metrics. Only
            experiments containing ``"learning_rate"``, ``"batch_size"``,
            ``generalization_key`` and ``complexity_key`` are considered.
        output_path: file the figure is written to when ``save`` is true.
        generalization_key: metric plotted on the x axis.
        complexity_key: metric plotted on the y axis.
        yscale: ``"log"`` for a logarithmic y axis; anything else keeps it linear.
        save: if False, nothing is written and the current pyplot figure is
            left open for the caller (e.g. to ``plt.show()`` it).
        colorbar: whether to draw a learning-rate colorbar.
        ylabel: unused, kept for backward compatibility; the y label is
            looked up in ``LABELS`` from ``complexity_key``.
        max_learning_rate: experiments with a learning rate greater than or
            equal to this value are excluded.

    Returns:
        None.

    Raises:
        NoConvergedExperiment: if no experiment passes the filters above.
    """
    # HACK: change this key to change what drives the color map.
    colored_key = "learning_rate"

    experiments = [
        k for k, exp in results.items()
        if generalization_key in exp
        and complexity_key in exp
        and "batch_size" in exp
        and "learning_rate" in exp
        and exp["learning_rate"] < max_learning_rate
    ]
    if not experiments:
        raise NoConvergedExperiment

    logger.info(f"Found {len(experiments)} converged experiments")

    # pyplot.get_cmap works on old and new matplotlib (cm.get_cmap was removed in 3.9).
    color_map = plt.get_cmap("brg")

    # One norm shared by all series. LogNorm cannot take non-positive bounds,
    # so only positive values define the range.
    positive_values = [results[k][colored_key] for k in experiments
                       if results[k][colored_key] > 0]
    if positive_values:
        norm = matplotlib.colors.LogNorm(vmin=min(positive_values),
                                         vmax=max(positive_values))
    else:
        norm = matplotlib.colors.LogNorm()

    batch_sizes = sorted({results[k]["batch_size"] for k in experiments})

    # Last successfully drawn scatter; it carries the norm used by the colorbar.
    sc = None

    for b_idx, b in enumerate(batch_sizes):
        b_experiments = [k for k in experiments if results[k]["batch_size"] == b]

        dim_tab = [results[k][complexity_key] for k in b_experiments]
        gen_tab = [results[k][generalization_key] for k in b_experiments]
        color_tab = [results[k][colored_key] for k in b_experiments]

        # An empty list appears to mark a missing measurement (presumably
        # written that way in the results JSON); skip incomplete series.
        if [] in gen_tab or [] in dim_tab:
            logger.warning(f"Not full data for {generalization_key} and {complexity_key} - no plot generated")
            continue

        try:
            sc = plt.scatter(
                np.array(gen_tab),
                np.array(dim_tab),
                c=np.array(color_tab),
                marker=MARKERS[b_idx % len(MARKERS)],
                label=str(b),
                cmap=color_map,
                norm=norm,
            )
        except ValueError:
            logger.warning(f"no plot for {complexity_key} and {generalization_key}")
            continue

    if sc is None:
        # Nothing was drawn; discard the figure so it does not leak into the next plot.
        plt.close()
        return None

    if yscale == "log":
        plt.yscale("log")

    plt.xlabel(LABELS.get(generalization_key, generalization_key), fontweight="bold")
    plt.ylabel(LABELS.get(complexity_key, complexity_key), fontweight="bold")

    if colorbar:
        cbar = plt.colorbar(sc)
        cbar.set_label(LABELS.get(colored_key, colored_key), fontweight="bold")

    plt.grid(visible=True, which="both")
    plt.legend(title="Batch size")

    if not save:
        return None

    try:
        plt.tight_layout()

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving figure in {str(output_path)}")
        plt.savefig(str(output_path))
    except ValueError:
        # tight_layout/savefig can fail, e.g. with a log scale over non-positive data.
        logger.warning(f"No figure saved for {str(output_path)}, try removing the log scale")
    finally:
        plt.close()

    return None

    

def plot_from_json(json_path: str,
                   generalization_key: str = "learning_rate",
                   complexity_key: str = "ratio_bs_lr",
                   scale: str = "",
                   save: bool = True,
                   colorbar: bool = True,
                   ylabel: str = r"$\eta / b$"
                   ):
    """Load experiment results from a JSON file and plot them with ``plot_from_dict``.

    When ``save`` is true, the figure goes to
    ``<json_path dir>/figures/<generalization_key>_<complexity_key>.pdf``.

    Args:
        json_path: path to a JSON file mapping experiment ids to metric dicts.
        generalization_key: metric plotted on the x axis.
        complexity_key: metric plotted on the y axis.
        scale: forwarded as ``yscale`` (``"log"`` for a logarithmic y axis).
        save: whether the figure should be saved; forwarded as ``save``.
        colorbar: whether to draw a colorbar.
        ylabel: forwarded unchanged as ``ylabel``.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
        NoConvergedExperiment: propagated from ``plot_from_dict`` when no
            experiment is usable.
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"{str(json_path)} not found.")

    # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
    with open(json_path, "r", encoding="utf-8") as json_file:
        results = json.load(json_file)

    if save:
        # Build the file name directly: with_suffix() would truncate keys containing a dot.
        output_path = (json_path.parent / "figures"
                       / f"{generalization_key}_{complexity_key}.pdf")
        output_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        output_path = None

    # When save is False, output_path is None and is forwarded as str(None),
    # together with save=False.
    plot_from_dict(
        results,
        str(output_path),
        generalization_key=generalization_key,
        complexity_key=complexity_key,
        yscale=scale,
        save=save,
        colorbar=colorbar,
        ylabel=ylabel,
    )


if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI flags onto plot_from_json's parameters.
    fire.Fire(component=plot_from_json)
