import json
from pathlib import Path

import fire
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import seaborn as sns
from loguru import logger

# Single-character matplotlib marker codes; plot_from_dict picks one per batch size.
MARKERS = list("o^sv+D*p")

# Base style applied to every figure produced by this script.
plt.style.use("_classic_test_patch")

# Bold, 20pt text everywhere, installed through matplotlib's rc parameters.
font = dict(weight="bold", size=20)
matplotlib.rc("font", **font)


class NoConvergedExperiment(Exception):
    """Signal that no experiment holds both of the requested metrics.

    Raised by ``plot_from_dict`` when filtering ``results`` by the
    generalization and complexity keys leaves nothing to plot.
    """


# Human-readable axis / colorbar labels keyed by the metric names stored in
# the results JSON. Strings wrapped in "$...$" are rendered by matplotlib's
# mathtext when used as labels.
LABELS = dict(
    fid="FID",
    generalization="Generalization gap",
    score_generalization="Score-based generalization error",
    ratio_bs_lr=r"$b / \eta$",
    ratio_bs_lr2=r"$b / \eta^2$",
    ratio_lr_bs=r"$\eta / b$",
    batch_size="Batch size",
    learning_rate="Learning rate",
    train_losses="Train loss",
    test_losses="Test loss",
    train_score_losses="Score-based train loss",
    test_score_losses="Score-based test loss",
    E_alpha=r"$E^1(W^{(n)})$",
    Positive_magnitude=r"$\mathbf{PMag}(W^{(n)})$",
    Positive_magnitude_small=r"$\mathbf{PMag}(10^{-2} \cdot W^{(n)})$",
    Positive_magnitude_sqrt_n=r"$\mathbf{PMag}(\sqrt{n} \cdot W^{(n)})$",
    sgld_bound=r"$\sqrt{ \eta \langle \Vert \widehat{g}_k \Vert^2 \rangle / n}$",
    gradient_norms=r"$\langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    gradient_norms_lr=r"$\eta \times \langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    gradient_norms_bs=r"$b \times \langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    test_wass="Test Wasserstein metric",
    fid_generalization="FID generalization",
    train_fid="train FID",
    worst_generalization="worst score generalization gap",
)



def plot_from_dict(results: dict,
                   generalization_key: str = "score_generalization",
                   complexity_key: str = "ratio_lr_bs",
                   yscale: str = "",
                   save: bool = True,
                   colorbar: bool = True,
                   ylabel: str = r"$\eta / b$",
                   xaxis=True,
                   legend: bool=False
                   ):
    """Scatter a generalization metric against a complexity measure on the current axes.

    Each experiment in ``results`` that has both ``generalization_key`` and
    ``complexity_key`` becomes one point: x is the generalization value, y
    the complexity value. The color encodes the learning rate on a
    logarithmic scale shared by all points, and the marker encodes the batch
    size.

    Args:
        results: Mapping from experiment id to a dict of metrics. Every
            plotted entry must also provide ``"batch_size"`` and
            ``"learning_rate"``.
        generalization_key: Metric plotted on x; also used to look up the
            x label in ``LABELS``.
        complexity_key: Metric plotted on y; also used to look up the
            y label in ``LABELS``.
        yscale: ``"log"`` switches the y axis to log scale; any other value
            keeps the default.
        save: Unused; kept for backward compatibility.
        colorbar: Whether to draw a learning-rate colorbar.
        ylabel: Unused (the y label is ``LABELS[complexity_key]``); kept for
            backward compatibility.
        xaxis: Whether to draw the x-axis label.
        legend: Whether to draw a batch-size legend.

    Returns:
        None. Drawing happens on the current pyplot axes.

    Raises:
        NoConvergedExperiment: if no experiment contains both keys.
    """
    COLORED_KEY = "learning_rate"

    experiments = [
        k for k, metrics in results.items()
        if generalization_key in metrics and complexity_key in metrics
    ]

    if not experiments:
        raise NoConvergedExperiment

    logger.info(f"Found {len(experiments)} converged experiments")

    # Group experiments by batch size, keeping their order within ``results``.
    by_batch_size = {}
    for k in experiments:
        by_batch_size.setdefault(results[k]["batch_size"], []).append(k)

    # Markers are assigned by rank among *all* batch sizes in ``results``
    # (not only those plotted here), so a batch size keeps the same marker in
    # every panel even though only one panel may show the legend. The modulo
    # avoids an IndexError when there are more batch sizes than markers.
    marker_order = sorted(
        {metrics["batch_size"] for metrics in results.values() if "batch_size" in metrics}
    )

    # A single LogNorm spanning all plotted learning rates: without it each
    # scatter autoscales its own norm, so equal learning rates get different
    # colors across batch sizes and the colorbar (built from the last
    # scatter) mislabels the others. Non-positive values cannot be placed on
    # a log scale and are left out of the range.
    positive_lrs = [
        lr for lr in (results[k][COLORED_KEY] for k in experiments)
        if isinstance(lr, (int, float)) and lr > 0
    ]
    if positive_lrs:
        norm = matplotlib.colors.LogNorm(vmin=min(positive_lrs), vmax=max(positive_lrs))
    else:
        norm = matplotlib.colors.LogNorm()

    # Last successfully drawn scatter; it serves as the colorbar mappable.
    sc = None

    for b in sorted(by_batch_size):
        b_experiments = by_batch_size[b]

        dim_tab = [results[k][complexity_key] for k in b_experiments]
        gen_tab = [results[k][generalization_key] for k in b_experiments]
        lr_tab = [results[k][COLORED_KEY] for k in b_experiments]

        # HACK: an empty list stands in for a missing metric value; skip the
        # whole batch size rather than plot partial data.
        if [] in gen_tab:
            logger.warning(f"Not full data for {generalization_key} and {complexity_key} - no plot generated")
            continue

        try:
            sc = plt.scatter(
                np.array(gen_tab),
                np.array(dim_tab),
                c=np.array(lr_tab),
                marker=MARKERS[marker_order.index(b) % len(MARKERS)],
                label=str(b),
                # A colormap name avoids ``plt.cm.get_cmap``, removed in matplotlib 3.9.
                cmap="brg",
                norm=norm,
                s=60
            )
        except ValueError:
            # Presumably raised for ragged / non-numeric metric values — confirm.
            logger.warning(f"no plot for {complexity_key} and {generalization_key}")
            continue

    if sc is None:
        return None

    if yscale == "log":
        plt.yscale("log")

    ax = plt.gca()

    ax.yaxis.set_minor_formatter(mticker.LogFormatter())
    ax.xaxis.set_minor_formatter(mticker.LogFormatterSciNotation())

    if xaxis:
        plt.xlabel(LABELS[generalization_key], fontweight="bold")
    plt.ylabel(LABELS[complexity_key], fontweight="bold")

    if colorbar:
        cbar = plt.colorbar(sc)
        cbar.set_label(LABELS[COLORED_KEY], fontweight="bold")

    plt.grid(visible=True, which="both")

    if legend:
        plt.legend(title="Batch size")

    plt.tight_layout()

    

# (x-axis key, y-axis key) pairs for the four panels of the grid figure,
# in subplot order 221, 222, 223, 224. Every panel shares the same x metric.
XY_LIST = [
    ("generalization", complexity)
    for complexity in (
        "gradient_norms_bs",
        "E_alpha",
        "Positive_magnitude_small",
        "Positive_magnitude_sqrt_n",
    )
]


def all_plots_from_json(json_path, scale: str = "log"):
    """Draw the 2x2 grid of ``XY_LIST`` panels from a JSON results file and save it as a PDF.

    The figure is written to ``<directory of json_path>/figures/grid.pdf``;
    the ``figures`` directory is created if needed.

    Args:
        json_path: Path (str or Path) to a JSON file whose top-level object
            maps experiment ids to metric dicts, as consumed by
            ``plot_from_dict``.
        scale: Forwarded as ``yscale`` to every panel; ``"log"`` makes the
            y axes logarithmic.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
        NoConvergedExperiment: propagated from ``plot_from_dict`` when a panel
            has no experiment containing both of its keys.
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"{str(json_path)} not found.")

    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with json_path.open("r", encoding="utf-8") as json_file:
        results = json.load(json_file)

    # Per-panel keyword overrides, aligned with XY_LIST and subplots 221..224:
    # colorbars only in the right column, x labels only on the bottom row,
    # and the batch-size legend in the bottom-left panel.
    panel_options = (
        {"colorbar": False, "xaxis": False},
        {"xaxis": False},
        {"colorbar": False, "legend": True},
        {},
    )

    plt.figure(figsize=(16, 11))

    for position, ((x_key, y_key), options) in enumerate(zip(XY_LIST, panel_options), start=1):
        plt.subplot(2, 2, position)
        plot_from_dict(results, x_key, y_key, scale, **options)

    plt.tight_layout()

    output_path = json_path.parent / "figures" / "grid.pdf"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving figure in {str(output_path)}")
    plt.savefig(str(output_path))



if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI arguments onto
    # all_plots_from_json's parameters (json_path, scale).
    fire.Fire(component=all_plots_from_json)
