import fire
from loguru import logger
from tqdm import tqdm

from diffusion_plots import plot_from_json, NoConvergedExperiment


# Key pairs to plot, written as ordered groups of (first key, second keys).
# A first key may show up in more than one group: the groups are expanded
# strictly in the order given so that the resulting sequence of pairs (and
# therefore the order in which plots are produced) stays fixed.
_XY_GROUPS = (
    ("test_losses", ("train_losses",)),
    ("fid", ("generalization", "score_generalization")),
    (
        "score_generalization",
        (
            "ratio_bs_lr",
            "ratio_bs_lr2",
            "gradient_norms",
            "gradient_norms_lr",
            "gradient_norms_bs",
            "sgld_bound",
            "E_alpha",
            "Positive_magnitude",
        ),
    ),
    (
        "generalization",
        (
            "ratio_bs_lr",
            "ratio_bs_lr2",
            "gradient_norms",
            "gradient_norms_lr",
            "gradient_norms_bs",
            "sgld_bound",
            "E_alpha",
        ),
    ),
    ("worst_generalization", ("E_alpha", "gradient_norms_bs")),
    (
        "generalization",
        (
            "Positive_magnitude",
            "Positive_magnitude_small",
            "Positive_magnitude_sqrt_n",
        ),
    ),
    (
        "fid_generalization",
        (
            "ratio_bs_lr",
            "ratio_bs_lr2",
            "batch_size",
            "learning_rate",
            "gradient_norms",
            "gradient_norms_lr",
            "gradient_norms_bs",
            "sgld_bound",
            "E_alpha",
            "Positive_magnitude",
        ),
    ),
    (
        "fid",
        (
            "ratio_bs_lr",
            "learning_rate",
            "batch_size",
            "gradient_norms",
            "gradient_norms_lr",
            "gradient_norms_bs",
            "sgld_bound",
            "E_alpha",
            "train_losses",
            "train_score_losses",
            "test_losses",
            "test_score_losses",
            "Positive_magnitude",
        ),
    ),
    (
        "test_losses",
        (
            "ratio_bs_lr",
            "gradient_norms",
            "gradient_norms_lr",
            "gradient_norms_bs",
            "sgld_bound",
            "E_alpha",
            "Positive_magnitude",
        ),
    ),
    (
        "test_score_losses",
        (
            "ratio_bs_lr",
            "gradient_norms",
            "sgld_bound",
            "E_alpha",
            "Positive_magnitude",
        ),
    ),
    (
        "test_wass",
        (
            "ratio_bs_lr",
            "gradient_norms",
            "gradient_norms_lr",
            "sgld_bound",
            "E_alpha",
            "Positive_magnitude",
        ),
    ),
    ("learning_rate", ("fid",)),
    (
        "train_fid",
        (
            "ratio_bs_lr",
            "learning_rate",
            "batch_size",
            "gradient_norms",
            "gradient_norms_lr",
            "sgld_bound",
            "E_alpha",
        ),
    ),
)

# Flat list of (first key, second key) tuples, one per plot. The name suggests
# these are the x/y quantities of each plot; which element maps to which axis
# is decided by `plot_from_json` -- TODO confirm there.
XY_LIST = [
    (first_key, second_key)
    for first_key, second_keys in _XY_GROUPS
    for second_key in second_keys
]


def all_plots_from_json(json_path, scale: str = "log"):
    """Call ``plot_from_json`` once for every key pair in ``XY_LIST``.

    A ``tqdm`` progress bar tracks the iteration. A failure on one pair does
    not stop the run: ``KeyError`` and ``NoConvergedExperiment`` raised by
    ``plot_from_json`` are logged as warnings and the loop moves on to the
    next pair. Any other exception propagates to the caller.

    Args:
        json_path: Location of the JSON file; converted with ``str()`` before
            being handed to ``plot_from_json``.
        scale: Forwarded unchanged as the fourth positional argument of
            ``plot_from_json`` (presumably the axis scale -- see that
            function for the accepted values). Defaults to ``"log"``.
    """
    source = str(json_path)

    for key_pair in tqdm(XY_LIST):
        first_key, second_key = key_pair[0], key_pair[1]

        # Custom labels could be added here too
        try:
            plot_from_json(source, first_key, second_key, scale)
        except KeyError:
            logger.warning(f"One of the keys {key_pair} is not available in {source}")
        except NoConvergedExperiment:
            logger.warning(f"No converged experiment for {source}")
            


# Script entry point: hand `all_plots_from_json` to Fire so that its
# parameters (`json_path`, `scale`) are exposed as command-line arguments.
if __name__ == "__main__":
    fire.Fire(all_plots_from_json)