import json
from pathlib import Path

import fire
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from loguru import logger
from matplotlib import rc
from tqdm import tqdm

from diffusion_plots import NoConvergedExperiment


"""
    This script is specific to the GMMs experiments, do not use it for other purposes
"""

# Marker and colour sequences shared by the plotting helpers of this module.
MARKERS = [
    "o", "^", "s", "v",
    "+", "D", "*", "p",
]
COLORS = [
    "b",
    "orange",
    "g",
    "r",
    "purple",
    "brown",
]

# Global matplotlib configuration, applied once at import time:
# the style sheet first, then a bold 14pt default font on top of it.
plt.style.use("_classic_test_patch")

font = dict(weight="bold", size=14)

matplotlib.rc("font", **font)


# (generalization_key, complexity_key) pairs rendered by `all_plots`.
# Each pair appears exactly once: a repeated pair would only re-render the
# same figure and overwrite the file produced by its first occurrence.
ALL_PLOTS = [
    ("test_wass", "ratio_bs_lr"),
    ("test_losses", "train_losses"),
    # Score-based generalization error
    ("score_generalization", "ratio_bs_lr"),
    ("score_generalization", "gradient_norms"),
    ("score_generalization", "gradient_norms_lr"),
    ("score_generalization", "sgld_bound"),
    ("score_generalization", "E_alpha"),
    ("score_generalization", "Positive_magnitude"),
    # Generalization gap
    ("generalization", "ratio_bs_lr"),
    ("generalization", "gradient_norms"),
    ("generalization", "gradient_norms_lr"),
    ("generalization", "sgld_bound"),
    ("generalization", "E_alpha"),
    ("generalization", "Positive_magnitude"),
    # Wasserstein generalization
    ("wasserstein_generalization", "ratio_bs_lr"),
    ("wasserstein_generalization", "gradient_norms"),
    ("wasserstein_generalization", "gradient_norms_lr"),
    ("wasserstein_generalization", "sgld_bound"),
    ("wasserstein_generalization", "E_alpha"),
    ("wasserstein_generalization", "learning_rate"),
    ("wasserstein_generalization", "Positive_magnitude"),
    # Test Wasserstein metric (its "ratio_bs_lr" plot is the first entry)
    ("test_wass", "gradient_norms"),
    ("test_wass", "gradient_norms_lr"),
    ("test_wass", "sgld_bound"),
    ("test_wass", "E_alpha"),
    ("test_wass", "learning_rate"),
    ("test_wass", "Positive_magnitude"),
    # Train Wasserstein metric
    ("train_wass", "ratio_bs_lr"),
    ("train_wass", "gradient_norms"),
    ("train_wass", "gradient_norms_lr"),
    ("train_wass", "sgld_bound"),
    ("train_wass", "E_alpha"),
    ("train_wass", "Positive_magnitude"),
    # Train loss
    ("train_losses", "ratio_bs_lr"),
    ("train_losses", "gradient_norms"),
    ("train_losses", "gradient_norms_lr"),
    ("train_losses", "sgld_bound"),
    ("train_losses", "E_alpha"),
    ("train_losses", "Positive_magnitude"),
]

# Axis labels, keyed by the metric / hyperparameter names found in the
# results dictionaries. Raw strings hold LaTeX (mathtext) markup.
LABELS = {
    "fid": "FID",
    "generalization": "Generalization gap",
    "score_generalization": "Score-based generalization error",
    "ratio_bs_lr": r"$b / \eta$",
    "ratio_lr_bs": r"$\eta / b$",
    "batch_size": "Batch size",
    "learning_rate": "Learning rate",
    "train_losses": "Train loss",
    "test_losses": "Test loss",
    "train_score_losses": "Score-based train loss",
    "test_score_losses": "Score-based test loss",
    "E_alpha": r"$E^1(W^{(n)})$",
    "Positive_magnitude": r"$\mathrm{PMag}(W^{(n)})$",
    "sgld_bound": (
        r"$\sqrt{ \eta \beta \langle \Vert \widehat{g}_k \Vert^2 \rangle / n}$"
    ),
    "gradient_norms": r"$\langle \Vert \widehat{g}_k \Vert^2 \rangle$",
    "gradient_norms_lr": (
        r"$\eta \times \langle \Vert \widehat{g}_k \Vert^2 \rangle$"
    ),
    "gradient_norms_bs": (
        r"$b \times \langle \Vert \widehat{g}_k \Vert^2 \rangle$"
    ),
    "test_wass": "Test Wasserstein metric",
    "train_wass": "Train Wasserstein metric",
    "wasserstein_generalization": "Wasserstein generalization",
    "fid_generalization": "FID generalization",
}


# Default location of the aggregated results file read by the plotting
# entry points.
RESULTS_PATH = "temp/simple_dataset_threshold_035/all_results.json"

def plot_from_dict(results: dict,
                   output_path: str,
                   generalization_key: str = "test_wass",
                   complexity_key: str = "learning_rate",
                   colored_key: str = "temperature",
                   n_samples: int = 8192,
                   scale: str = "",
                   save: bool = True,
                   colorbar: bool = True,
                   ylabel: str = ""
                   ):
    """Plot ``generalization_key`` against ``complexity_key``, one curve per ``colored_key`` value.

    Only the experiments whose entry contains both ``generalization_key`` and
    ``complexity_key`` are used. For every distinct value of ``colored_key``
    among them, the function draws the curve sorted by ``complexity_key``, a
    shaded band of +/- ``<generalization_key>_std`` around it, and a seaborn
    regression fit.

    Args:
        results: mapping from experiment name to a dict of that experiment's
            values (metrics and hyperparameters).
        output_path: file the figure is written to when ``save`` is true;
            missing parent directories are created.
        generalization_key: key plotted on the y axis; must be a key of
            ``LABELS``, and ``<generalization_key>_std`` must be present in
            every plotted experiment.
        complexity_key: key plotted on the x axis; must be a key of ``LABELS``.
        colored_key: key whose distinct values define the curves; also used
            (underscores replaced by spaces) as the legend title.
        n_samples: unused here; kept for signature compatibility.
        scale: ``"log"`` switches the x axis to a logarithmic scale; any other
            value leaves it linear.
        save: when false, the figure is built and closed without being written.
        colorbar: unused here; kept for signature compatibility.
        ylabel: unused here; the y label always comes from ``LABELS``.

    Raises:
        KeyError: if a plotted experiment lacks ``colored_key`` or the
            ``_std`` key, or if an axis key has no entry in ``LABELS``.
    """
    experiments = [
        name for name, entry in results.items()
        if generalization_key in entry and complexity_key in entry
    ]

    logger.info(f"Found {len(experiments)} converged experiments for {generalization_key} and {complexity_key}")

    std_key = generalization_key + "_std"

    # One curve per distinct value, in increasing order so that the colour
    # assignment and the legend order are deterministic.
    colored_values = sorted({results[name][colored_key] for name in experiments})

    # A dedicated figure, always closed in the `finally` below: an exception
    # raised midway must not leave artists behind for the next plot.
    fig, ax = plt.subplots()
    try:
        for rank, value in enumerate(colored_values):
            # Cycle through the palette instead of indexing past its end.
            color = COLORS[rank % len(COLORS)]

            group = [name for name in experiments if results[name][colored_key] == value]

            complexity = np.array([results[name][complexity_key] for name in group])
            gen = np.array([results[name][generalization_key] for name in group])
            deviation = np.array([results[name][std_key] for name in group])

            # Sort the three arrays together along the x axis.
            order = np.argsort(complexity)
            complexity = complexity[order]
            gen = gen[order]
            deviation = deviation[order]

            ax.fill_between(
                complexity,
                gen - deviation,
                gen + deviation,
                color=color,
                alpha=0.25,
            )

            ax.plot(
                complexity,
                gen,
                marker="o",
                linewidth=1,
                label=str(value),
                color=color,
            )

            # Linear regression overlay; the fit is best-effort and a failure
            # must not prevent the rest of the figure from being produced.
            try:
                sns.regplot(
                    x=complexity,
                    y=gen,
                    line_kws={"color": "red", "linewidth": 1.5},
                    ax=ax,
                )
            except TypeError as err:
                logger.debug(f"Skipping regression line for {colored_key}={value}: {err}")

        if scale == "log":
            ax.set_xscale("log")

        ax.set_ylabel(LABELS[generalization_key], fontweight="bold")
        ax.set_xlabel(LABELS[complexity_key], fontweight="bold")

        ax.grid(visible=True, which="both")
        ax.legend(title=colored_key.replace("_", " "))

        fig.tight_layout()

        if save:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Saving figure in {str(output_path)}")
            fig.savefig(str(output_path))
    finally:
        plt.close(fig)

    
def plot_from_json(json_path: str = RESULTS_PATH,
                   generalization_key: str = "generalization",
                   complexity_key: str = "ratio_bs_lr",
                   colored_key: str = "temperature",
                   n_samples="",
                   scale: str = "",
                   save: bool = True,
                   colorbar: bool = True,
                   ylabel: str = r"$\eta / b$"
                   ):
    """Load a results JSON file and plot one figure from it with ``plot_from_dict``.

    The figure path is
    ``<json dir>/0___figures/<generalization_key>_<n_samples>_<complexity_key>.pdf``.

    Args:
        json_path: path of the JSON file holding the results dictionary.
        generalization_key: forwarded to ``plot_from_dict``; also part of the
            figure file name.
        complexity_key: forwarded to ``plot_from_dict``; also part of the
            figure file name.
        colored_key: forwarded to ``plot_from_dict``.
        n_samples: inserted in the figure file name and forwarded to
            ``plot_from_dict``.
        scale: forwarded to ``plot_from_dict``.
        save: when true, the figure directory is created here; the flag is
            also forwarded to ``plot_from_dict``.
        colorbar: forwarded to ``plot_from_dict``.
        ylabel: forwarded to ``plot_from_dict``.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"{str(json_path)} not found.")

    with json_path.open("r", encoding="utf-8") as json_file:
        results = json.load(json_file)

    # The extension is appended to the name rather than set via `with_suffix`,
    # which would cut the name at its last dot (e.g. when `n_samples` is a
    # float such as 0.5).
    figure_name = f"{generalization_key}_{n_samples}_{complexity_key}.pdf"

    # The path is always a real one, so the callee never receives the
    # literal string "None" when `save` is false.
    output_path = json_path.parent / "0___figures" / figure_name
    if save:
        output_path.parent.mkdir(parents=True, exist_ok=True)

    plot_from_dict(
        results,
        str(output_path),
        generalization_key,
        complexity_key,
        colored_key,
        n_samples,
        scale,
        save,
        colorbar,
        ylabel
    )


def all_plots(json_path: str = RESULTS_PATH,
                n_samples=512,
                scale: str = "",
                colored_key: str = "n_samples"):
    """Render every plot listed in ``ALL_PLOTS`` from one results JSON file.

    Each (generalization_key, complexity_key) pair is passed to
    ``plot_from_json`` with ``save=True``. A failing pair is logged as a
    warning and skipped, so one missing key does not stop the others.

    Args:
        json_path: path of the JSON file holding the results dictionary.
        n_samples: forwarded to ``plot_from_json``, where it becomes part of
            the figure file names.
        scale: forwarded to ``plot_from_json`` (``"log"`` for a log x axis).
        colored_key: key whose values define the curves of each figure.
    """
    logger.debug(colored_key)

    for plot in tqdm(ALL_PLOTS):
        generalization_key, complexity_key = plot

        # Custom labels could be added here too
        try:
            plot_from_json(
                str(json_path),
                generalization_key=generalization_key,
                complexity_key=complexity_key,
                colored_key=colored_key,
                # Forward the caller's value instead of a hard-coded constant.
                n_samples=n_samples,
                scale=scale,
                save=True,
                colorbar=True,
            )
        except KeyError:
            logger.warning(f"One of the keys {plot} is not available in {str(json_path)}")
        except NoConvergedExperiment:
            logger.warning(f"No converged experiment for {str(json_path)}")
        except ValueError:
            # NOTE(review): the message assumes the ValueError comes from the
            # log scale; other sources are possible - confirm.
            logger.warning("no log scale")
            




# Command-line entry point: `fire` exposes the parameters of `all_plots`
# (json_path, n_samples, scale, colored_key) as command-line arguments.
if __name__ == "__main__":
    fire.Fire(all_plots)
