# Trains (or, with --eval, only evaluates) a latent diffusion model on latents obtained from a pre-trained meta-model and, optionally, a hypernetwork VAE

import argparse
import json
import os
import numpy as np
import torch
import wandb
from functools import *

from model.clip.hyperclip import build_hyperclip_from_classic_clip
from model.custom_hnet import CLIPAdapter, HyperGenerator, HyperDiscriminator, HyperEncoder, MetaModel
from model.diffusion.latent_diffuser import LatentDiffuser, LatentDiffuserV2
from training.hyperclip_learn import HyperclipTraining
from training.latent_diffusion_learn import EMA, LatentDiffusionTraining
from utils.build_opt import build_optimizer
from utils.config import hypergan_defaults
from utils import clip_utils
from training.utils import append_dict, log_metric
from utils.misc_utils import str2bool

from scripts.init_utils import load_metamodel_from_checkpoint, load_vae_and_metamodel_from_checkpoint
from features.image_features import load_image_features
from features.ques_features import load_ques_features
from features.text_features import load_text_features
from utils.diffusion_utils import make_beta_schedule
from utils.misc_utils import str2bool

def _build_parser():
    """Create the command-line parser for this script.

    Apart from --wandb_mode, --seed and --num_ensemble (which carry explicit
    defaults), every option defaults to None when it is not given on the
    command line.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--wandb_mode', type=str, default='online',
                     help='Set to "disabled" to disable Weights & Biases logging')

    # (option name, type converter, extra add_argument keywords), listed in
    # the order in which the options are registered (and shown by --help).
    option_specs = [
        ('inner_epochs', str, {}),
        ('inner_optimizer', str, {}),
        ('inner_learning_rate', float, {}),

        ('few_shot_checkpoint', str, {'help': 'required'}),
        ('vae_checkpoint', str, {'help': 'required'}),
        ('precompute_checkpoint', str, {'help': 'required'}),

        ('diffusion_epochs', int, {}),
        ('diffusion_batch_size', int, {}),
        ('fisher_metric', str2bool, {}),

        ('diffusion_n_timestep', int, {}),
        ('diffusion_optimizer', str, {}),
        ('diffusion_learning_rate', float, {}),
        ('diffusion_beta_schedule', str, {}),
        ('diffusion_linear_start', float, {}),
        ('diffusion_linear_end', float, {}),
        ('diffusion_timestep_emb_dim', int, {}),
        # todo: select different samplers for diffusion (as opposed to guidance optimizers)

        ('diffusion_hidden_dim', str, {}),
        ('diffusion_model', str, {}),

        ('diffusion_class_free_guidance', str2bool, {}),
        ('diffusion_class_free_training_uncond_rate', float, {}),
        ('diffusion_class_free_guidance_gammas', str, {}),

        ('diffusion_few_shot_guidance_gammas', str, {}),
        ('diffusion_few_shot_guidance', str2bool, {}),

        ('diffusion_try_init_from_pretrained', str2bool, {}),
        ('diffusion_start_from_t_frac', str, {}),

        ('diffusion_fast_sampling_factor', int, {}),

        ('diffusion_keep_tasks_frac', float, {}),

        ('n_shot_trials_maxN', int, {}),
        ('filter_tasks_answers', str2bool, {}),
        ('n_shot_trials_reruns', int, {}),
        ('skip_train_guidance', str2bool, {}),

        ('val_epoch_interval', int, {}),
        ('train_on_vae', str2bool, {}),
        ('save_checkpoint', str2bool, {}),

        ('seed', int, {'default': 42}),

        ('num_ensemble', int, {'default': 1}),

        ('checkpoint', str, {}),
        ('eval', str2bool, {}),
    ]
    for option_name, converter, extra_kwargs in option_specs:
        cli.add_argument('--' + option_name, type=converter, **extra_kwargs)
    return cli


parser = _build_parser()
# Directory two levels above this file; the code below looks for "data/" and
# "evaluation/" underneath it. abspath() is applied first because a relative
# __file__ (e.g. "train.py") would otherwise collapse to "" and make
# base_path + "/data/..." resolve against the filesystem root.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Default hyper-parameters. Command-line options that are not None override
# these values inside main().
default_config = dict(
    # --- inner loop (task adaptation) ---
    inner_epochs='10',
    inner_optimizer="sgd",
    inner_learning_rate=0.1,
    train_subtype="random",
    val_subtype="random",

    # --- model selection ---
    clip_model="ViT-L/14@336px",
    diffusion_model="latent_diffuser",
    diffusion_hidden_dim="128,128",  # comma-separated hidden layer sizes

    fisher_metric=False,

    # --- diffusion process ---
    diffusion_epochs=200,
    diffusion_batch_size=32,
    # number of diffusion timesteps (does not impact training time, only
    # eval sampling time if using standard DDPM sampling)
    diffusion_n_timestep=1000,
    # linear, cosine, sqrt_linear, sqrt (cosine should be best,
    # minidiffusion example uses sqrt_linear)
    diffusion_beta_schedule="cosine",
    diffusion_linear_start=1e-4,
    diffusion_linear_end=2e-2,
    diffusion_timestep_emb_dim=100,  # embedding dimension for timestep encoding

    # --- diffusion optimizer ---
    diffusion_optimizer="adam",
    diffusion_learning_rate=0.001,
    diffusion_weight_decay=0,
    diffusion_momentum=0.9,
    diffusion_sgd_nesterov=True,
    diffusion_ema_decay=0.9999,

    # --- classifier-free guidance ---
    diffusion_class_free_guidance=False,  # enable classifier-free guidance training
    # fraction of unconditioned embeddings given during diffusion training
    diffusion_class_free_training_uncond_rate=0.2,
    # classifier-free guidance gamma values used in the different generation evals
    diffusion_class_free_guidance_gammas="0.,1.,1.5,4.,7.5,15.",

    # alongside normal diffusion eval, also try to diffuse from a pre-trained init latent
    diffusion_try_init_from_pretrained=False,
    # (used when "diffusion_try_init_from_pretrained" is True) fraction of
    # timesteps to start diffusion from in the pre-trained init evals
    diffusion_start_from_t_frac='1.,0.5,0.1,0.01',

    # how many diffusion samples to skip when using fast sampling from the
    # IDDPM paper (None or 1 = no skipping)
    diffusion_fast_sampling_factor=None,

    # fraction of training tasks to keep (to test having an unconditional
    # train set larger than the conditional train set)
    diffusion_keep_tasks_frac=1,

    # --- few-shot guidance ---
    diffusion_few_shot_guidance=False,
    # few-shot guidance gamma values iterated over during evaluation
    diffusion_few_shot_guidance_gammas="0.3,1,3",

    # --- n-shot trials ---
    # if not None, do n-shot trials for n in range(1, n_shot_trials_maxN+1)
    n_shot_trials_maxN=None,
    # if True, don't filter out tasks, but only filter out answers in tasks
    # that don't have enough examples for n-shot trials
    filter_tasks_answers=False,
    # how many times to repeat each n-shot trial for mean and std accuracy calculation
    n_shot_trials_reruns=1,
    skip_train_guidance=False,

    # --- bookkeeping ---
    val_epoch_interval=10,
    save_checkpoint=False,
    train_on_vae=False,
    eval=False,
)

# The command line is parsed at import time so that the NumPy and PyTorch
# random number generators are seeded with --seed before main() runs.
args = parser.parse_args()
np.random.seed(args.seed)
rng = np.random.RandomState(args.seed)
torch.manual_seed(args.seed)

def _sample_precomputed_latent(precomputed_latent, device):
    """Draw one random slice of the precomputed latents and move it to `device`.

    For every tensor in `precomputed_latent`, one random index is selected
    along the first axis (kept as a length-1 axis) and the second axis is
    shuffled with a random permutation. The sizes of both axes are taken from
    the "clip_embedding" entry. Returns None when no latents were precomputed.
    """
    if precomputed_latent is None:
        return None
    first_axis_size, second_axis_size = precomputed_latent["clip_embedding"].shape[:2]
    # Same draw order as before (index first, permutation second) so that the
    # NumPy random stream is consumed identically.
    picked_index = np.random.randint(0, first_axis_size)
    permutation = np.random.permutation(second_axis_size)
    return {k: v[picked_index:picked_index + 1, permutation].to(device)
            for (k, v) in precomputed_latent.items()}


def main(args):
    """Train and/or evaluate the latent diffusion model.

    Args:
        args: parsed command-line namespace. Every attribute that is not None
            overrides the corresponding entry of `default_config`.

    Side effects:
        Starts a Weights & Biases run, reads the meta-train / meta-test JSON
        files under `base_path`/data/VQA/Meta, logs metrics through
        `log_metric`, and (unless --eval is set) writes model checkpoints and
        optional n-shot trial results under `base_path`/evaluation/diffusion.
        With --eval set, a single evaluation pass is run and the function
        returns without saving anything.
    """
    # Copy the defaults so that CLI overrides do not mutate the module-level dict.
    cfg = dict(default_config)
    cfg.update({k: v for (k, v) in vars(args).items() if v is not None})
    print(cfg)
    wandb.init(project='train_latent_diffusion', entity="srl_ethz", config=cfg,
               mode=args.wandb_mode)
    config = wandb.config

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
        train_data = json.load(file)
    with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
        test_data = json.load(file)

    hnet_gen, hnet_enc = None, None
    if "vae_checkpoint" in config and config["vae_checkpoint"] is not None:
        meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
            load_vae_and_metamodel_from_checkpoint(config, device)
    else:
        meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
            load_metamodel_from_checkpoint(config, device)

    # The diffusion operates on the flattened embedding of the meta-model,
    # unless a VAE generator is available and training on its input is requested.
    if hnet_gen is None or not config["train_on_vae"]:
        embedding_dim = meta_module.inner_params["enet.embedding"].flatten().shape[0]
    else:
        embedding_dim = hnet_gen.input_dim

    # pre-computed features
    image_features = load_image_features(config["clip_model"])
    text_features = load_text_features(config["clip_model"])
    ques_emb = load_ques_features(config["clip_model"])

    diffuser_classes = {
        "latent_diffuser": LatentDiffuser,
        "latent_diffuser_v2": LatentDiffuserV2,
    }
    if config["diffusion_model"] not in diffuser_classes:
        raise ValueError("Unknown diffusion_model")
    latent_diffuser = diffuser_classes[config["diffusion_model"]](
        x_dim=embedding_dim,
        cond_emb_dim=clip_utils.embedding_size[config["clip_model"]],
        timestep_emb_dim=config["diffusion_timestep_emb_dim"],
        hidden_dims=[int(i) for i in config["diffusion_hidden_dim"].split(",")],
    ).to(device)

    if "checkpoint" in config and config["checkpoint"] is not None:
        api = wandb.Api()
        loaded_run = api.run(config["checkpoint"])
        loaded_model_path = base_path + "/evaluation/diffusion/diffusion_" + str(loaded_run.name) + ".pth"
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        latent_diffuser.load_state_dict(torch.load(loaded_model_path, map_location=device), strict=False)

    diffusion_optimizer = build_optimizer(latent_diffuser.parameters(), config, loop="diffusion")
    unguided_inner_optim = partial(build_optimizer, config=config, loop="inner")

    sampled_precomputed_latent = _sample_precomputed_latent(precomputed_latent, device)

    diffusion_training = LatentDiffusionTraining(meta_module=meta_module,
                                                 diffusion_model=latent_diffuser,
                                                 optimizer=diffusion_optimizer,
                                                 n_timestep=config["diffusion_n_timestep"],
                                                 beta=make_beta_schedule(config["diffusion_beta_schedule"], config["diffusion_n_timestep"], linear_start=config["diffusion_linear_start"], linear_end=config["diffusion_linear_end"]),
                                                 ema=EMA(latent_diffuser, config["diffusion_ema_decay"]),
                                                 net_feature_dim=embedding_dim,
                                                 train_data_for_mean_std=train_data,
                                                 inner_opt_for_mean_std=unguided_inner_optim,
                                                 image_features=image_features,
                                                 text_features=text_features,
                                                 ques_emb=ques_emb,
                                                 config=config,
                                                 compute_hessian=config["fisher_metric"],
                                                 device=device,
                                                 hnet_gen=hnet_gen, hnet_enc=hnet_enc, train_on_vae=config["train_on_vae"],
                                                 compute_latents_mean_std=True,
                                                 class_free_guidance=config["diffusion_class_free_guidance"],
                                                 class_free_training_uncond_rate=config["diffusion_class_free_training_uncond_rate"],
                                                 precomputed_latent_for_mean_std=sampled_precomputed_latent)

    # n-shot trials are only run (and stored) when a maximum N is configured.
    n_shot_enabled = "n_shot_trials_maxN" in config and config["n_shot_trials_maxN"] is not None
    n_shot_trials = []

    best_val_dict = {"best_val_accuracy": 0, "best_val_epoch": 0}

    output_prefix = base_path + "/evaluation/diffusion/diffusion_"
    if not config["eval"]:
        # Checkpoints / n-shot results are written below; make sure the target
        # directory exists so a long run does not fail at save time.
        os.makedirs(os.path.dirname(output_prefix), exist_ok=True)

    for meta_epoch in range(config["diffusion_epochs"]):
        sampled_precomputed_latent = _sample_precomputed_latent(precomputed_latent, device)

        if not config["eval"]:
            diffusion_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
                                         batch_size=config["diffusion_batch_size"],
                                         train=True, train_subtype=config["train_subtype"], val_subtype=config["val_subtype"],
                                         precomputed_latent=sampled_precomputed_latent, keep_tasks_frac=config["diffusion_keep_tasks_frac"], debug=True)

        if config["eval"] or (meta_epoch+1) % config["val_epoch_interval"] == 0:

            if n_shot_enabled:
                n_shot_trial_dict = {"epoch": meta_epoch}

            # Baseline: unguided inner-loop adaptation on train and validation tasks.
            log_dict = diffusion_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
                                                    precomputed_latent=precomputed_latent_train_eval, keep_tasks_frac=config["diffusion_keep_tasks_frac"])
            log_dict["epoch"] = meta_epoch
            log_metric(log_dict, "eval_train/")
            log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim,
                                                    precomputed_latent=precomputed_latent_val_eval)
            log_dict["epoch"] = meta_epoch
            log_metric(log_dict, "eval_val/")
            if n_shot_enabled:
                for n_shot in range(1, config["n_shot_trials_maxN"]+1):
                    trial_key = f"eval_val/n_shot_{n_shot}/"
                    n_shot_trial_dict[trial_key] = {}
                    for _ in range(config["n_shot_trials_reruns"]):
                        log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                filter_tasks_by_max_k=config["n_shot_trials_maxN"], filter_tasks_answers=config["filter_tasks_answers"],
                                                                n_shot_training=n_shot)
                        log_dict["epoch"] = meta_epoch
                        append_dict(n_shot_trial_dict[trial_key], log_dict)

            if config["diffusion_class_free_guidance"]:
                # One full set of diffused evaluations per classifier-free guidance gamma.
                # `gamma` stays a string so that metric prefixes keep the exact CLI spelling.
                for gamma in config["diffusion_class_free_guidance_gammas"].split(","):
                    if not config["skip_train_guidance"]:
                        log_dict = diffusion_training.run_epoch(train_data, None, None, guided_inner=True, use_vae=True, skip_cond=True, class_guidance_gamma=float(gamma),
                                                                init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                keep_tasks_frac=config["diffusion_keep_tasks_frac"])
                        log_dict["epoch"] = meta_epoch
                        log_metric(log_dict, "diffused_eval_train_{}gamma/".format(gamma))

                    log_dict, _, _, opt_latents = diffusion_training.run_epoch(test_data, None, None, guided_inner=True, use_vae=True, skip_cond=True, class_guidance_gamma=float(gamma),
                                                                               init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                               output_mean_std=True)
                    log_dict["epoch"] = meta_epoch
                    if best_val_dict["best_val_accuracy"] < log_dict["query_accuracy_end"]:
                        best_val_dict["best_val_accuracy"] = log_dict["query_accuracy_end"]
                        best_val_dict["best_val_epoch"] = meta_epoch
                    log_metric(log_dict, "diffused_eval_val_{}gamma/".format(gamma))
                    log_metric(best_val_dict)

                    if n_shot_enabled:
                        for n_shot in range(1, config["n_shot_trials_maxN"]+1):
                            trial_key = f"diffused_eval_val_{gamma}gamma/n_shot_{n_shot}/"
                            n_shot_trial_dict[trial_key] = {}
                            for _ in range(config["n_shot_trials_reruns"]):
                                log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                        filter_tasks_by_max_k=config["n_shot_trials_maxN"], filter_tasks_answers=config["filter_tasks_answers"],
                                                                        n_shot_training=n_shot, opt_latents_for_n_shot=opt_latents)
                                log_dict["epoch"] = meta_epoch
                                append_dict(n_shot_trial_dict[trial_key], log_dict)

                    for _ in range(config["n_shot_trials_reruns"]):
                        log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                n_shot_training="full", opt_latents_for_n_shot=opt_latents)
                        log_dict["epoch"] = meta_epoch
                        log_metric(log_dict, f"diffused_eval_val_{gamma}gamma/few_shot/")

                    if config["num_ensemble"] > 1:
                        # Same evaluations again with an ensemble; note that
                        # opt_latents is replaced by the ensemble result from here on.
                        log_dict, _, _, opt_latents = diffusion_training.run_epoch(test_data, None, None, guided_inner=True, use_vae=True, skip_cond=True, class_guidance_gamma=float(gamma),
                                                                                   init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                                   output_mean_std=True, num_ensemble=config["num_ensemble"])
                        log_dict["epoch"] = meta_epoch
                        log_metric(log_dict, "diffused_eval_val_{}gamma_{}ens/".format(gamma, config["num_ensemble"]))
                        log_metric(best_val_dict)

                        if n_shot_enabled:
                            for n_shot in range(1, config["n_shot_trials_maxN"]+1):
                                trial_key = f"diffused_eval_val_{gamma}gamma_{config['num_ensemble']}ens/n_shot_{n_shot}/"
                                n_shot_trial_dict[trial_key] = {}
                                for _ in range(config["n_shot_trials_reruns"]):
                                    log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                            filter_tasks_by_max_k=config["n_shot_trials_maxN"], filter_tasks_answers=config["filter_tasks_answers"],
                                                                            n_shot_training=n_shot, opt_latents_for_n_shot=opt_latents)
                                    log_dict["epoch"] = meta_epoch
                                    append_dict(n_shot_trial_dict[trial_key], log_dict)

                        for _ in range(config["n_shot_trials_reruns"]):
                            log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                    n_shot_training="full", opt_latents_for_n_shot=opt_latents)
                            log_dict["epoch"] = meta_epoch
                            log_metric(log_dict, f"diffused_eval_val_{gamma}gamma_{config['num_ensemble']}ens/few_shot/")

                    if config["diffusion_few_shot_guidance"]:
                        for few_shot_gamma in config["diffusion_few_shot_guidance_gammas"].split(","):
                            few_shot_gamma = float(few_shot_gamma)
                            if not config["skip_train_guidance"]:
                                log_dict = diffusion_training.run_epoch(train_data, None, None, guided_inner=True, use_vae=True, skip_cond=True, class_guidance_gamma=float(gamma),
                                                                        init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                        keep_tasks_frac=config["diffusion_keep_tasks_frac"],
                                                                        few_shot_gamma=few_shot_gamma, few_shot_guidance=True)
                                log_dict["epoch"] = meta_epoch
                                log_metric(log_dict, "diffused_eval_train_{}gamma_{}fewshotgamma/".format(gamma, few_shot_gamma))

                            log_dict = diffusion_training.run_epoch(test_data, None, None, guided_inner=True, use_vae=True, skip_cond=True, class_guidance_gamma=float(gamma),
                                                                    init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                    few_shot_gamma=few_shot_gamma, few_shot_guidance=True)
                            log_dict["epoch"] = meta_epoch
                            log_metric(log_dict, "diffused_eval_val_{}gamma_{}fewshotgamma/".format(gamma, few_shot_gamma))

                            if n_shot_enabled:
                                for n_shot in range(1, config["n_shot_trials_maxN"]+1):
                                    # Initialise the very key that is appended to below. The
                                    # previous code reset "diffused_eval_val_{gamma}gamma/n_shot_{n}/"
                                    # instead, which discarded the results stored under that key
                                    # earlier in this gamma iteration and then raised a KeyError
                                    # on the few-shot-gamma key.
                                    trial_key = f"diffused_eval_val_{gamma}gamma_{few_shot_gamma}fewshotgamma/n_shot_{n_shot}/"
                                    n_shot_trial_dict[trial_key] = {}
                                    for _ in range(config["n_shot_trials_reruns"]):
                                        log_dict = diffusion_training.run_epoch(test_data, None, None, guided_inner=True, use_vae=True, skip_cond=True, class_guidance_gamma=float(gamma),
                                                                                filter_tasks_by_max_k=config["n_shot_trials_maxN"], filter_tasks_answers=config["filter_tasks_answers"],
                                                                                n_shot_training=n_shot,
                                                                                init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                                few_shot_gamma=few_shot_gamma, few_shot_guidance=True)
                                        log_dict["epoch"] = meta_epoch
                                        append_dict(n_shot_trial_dict[trial_key], log_dict)

                            for _ in range(config["n_shot_trials_reruns"]):
                                log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                        n_shot_training="full", opt_latents_for_n_shot=opt_latents, few_shot_gamma=few_shot_gamma, few_shot_guidance=True)
                                log_dict["epoch"] = meta_epoch
                                log_metric(log_dict, f"diffused_eval_val_{gamma}gamma_{few_shot_gamma}fewshotgamma/few_shot/")

            else:
                if not config["skip_train_guidance"]:
                    log_dict = diffusion_training.run_epoch(train_data, None, None, guided_inner=True, use_vae=True, skip_cond=True,
                                                            init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                            keep_tasks_frac=config["diffusion_keep_tasks_frac"])
                    log_dict["epoch"] = meta_epoch
                    log_metric(log_dict, "diffused_eval_train/")

                # NOTE(review): four values are unpacked here without passing
                # output_mean_std=True, unlike the classifier-free branch above —
                # confirm against run_epoch's return contract.
                log_dict, _, _, opt_latents = diffusion_training.run_epoch(test_data, None, None, guided_inner=True, use_vae=True, skip_cond=True,
                                                                           init_guidance_at="random", guidance_start_from_t_frac=1., fast_sampling_factor=config["diffusion_fast_sampling_factor"])
                log_dict["epoch"] = meta_epoch
                if best_val_dict["best_val_accuracy"] < log_dict["query_accuracy_end"]:
                    best_val_dict["best_val_accuracy"] = log_dict["query_accuracy_end"]
                    best_val_dict["best_val_epoch"] = meta_epoch
                log_metric(log_dict, "diffused_eval_val/")
                log_metric(best_val_dict)
                if n_shot_enabled:
                    for n_shot in range(1, config["n_shot_trials_maxN"]+1):
                        trial_key = f"diffused_eval_val/n_shot_{n_shot}/"
                        n_shot_trial_dict[trial_key] = {}
                        for _ in range(config["n_shot_trials_reruns"]):
                            log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                                    filter_tasks_by_max_k=config["n_shot_trials_maxN"], filter_tasks_answers=config["filter_tasks_answers"],
                                                                    n_shot_training=n_shot, opt_latents_for_n_shot=opt_latents)
                            log_dict["epoch"] = meta_epoch
                            append_dict(n_shot_trial_dict[trial_key], log_dict)

                for _ in range(config["n_shot_trials_reruns"]):
                    log_dict = diffusion_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim, use_vae=True, skip_cond=True,
                                                            n_shot_training="full", opt_latents_for_n_shot=opt_latents)
                    log_dict["epoch"] = meta_epoch
                    log_metric(log_dict, f"diffused_eval_val/few_shot/")

            if config["diffusion_try_init_from_pretrained"]:
                # right now only running this without class-free guidance

                for start_from_t_frac in config["diffusion_start_from_t_frac"].split(","):
                    if not config["skip_train_guidance"]:
                        log_dict = diffusion_training.run_epoch(train_data, None, None, guided_inner=True, use_vae=True, skip_cond=True,
                                                                init_guidance_at="pre-trained", guidance_start_from_t_frac=float(start_from_t_frac), fast_sampling_factor=config["diffusion_fast_sampling_factor"],
                                                                keep_tasks_frac=config["diffusion_keep_tasks_frac"])
                        log_dict["epoch"] = meta_epoch
                        log_metric(log_dict, "diffused_eval_train_from_pretrained_{}T/".format(start_from_t_frac))

                    log_dict = diffusion_training.run_epoch(test_data, None, None, guided_inner=True, use_vae=True, skip_cond=True,
                                                            init_guidance_at="pre-trained", guidance_start_from_t_frac=float(start_from_t_frac), fast_sampling_factor=config["diffusion_fast_sampling_factor"])
                    log_dict["epoch"] = meta_epoch
                    if best_val_dict["best_val_accuracy"] < log_dict["query_accuracy_end"]:
                        best_val_dict["best_val_accuracy"] = log_dict["query_accuracy_end"]
                        best_val_dict["best_val_epoch"] = meta_epoch
                    log_metric(log_dict, "diffused_eval_val_from_pretrained_{}T/".format(start_from_t_frac))
                    log_metric(best_val_dict)

            if n_shot_enabled:
                n_shot_trials += [n_shot_trial_dict]

        # Evaluation-only mode: one pass through the evaluations, nothing is saved.
        if config["eval"]:
            return

        # NOTE(review): checkpoints are written when meta_epoch is a multiple of
        # val_epoch_interval, whereas validation runs when meta_epoch+1 is —
        # confirm the one-epoch offset is intended.
        if config["save_checkpoint"] and meta_epoch % config["val_epoch_interval"] == 0:
            diffusion_output_path_checkpoint = output_prefix + str(wandb.run.name) + "_" + str(meta_epoch) + ".pth"
            torch.save(latent_diffuser.state_dict(), diffusion_output_path_checkpoint)
            print(f"Checkpoint for meta-epoch {meta_epoch} saved!")

        if n_shot_enabled:
            # Rewritten after every epoch with all n-shot trial results gathered so far.
            diffusion_output_path_checkpoint = output_prefix + str(wandb.run.name) + "_n_shot.npy"
            np.save(diffusion_output_path_checkpoint, n_shot_trials)

    diffusion_output_path_checkpoint = output_prefix + str(wandb.run.name) + ".pth"
    torch.save(latent_diffuser.state_dict(), diffusion_output_path_checkpoint)

    wandb.finish()


# Script entry point: run training/evaluation with the command-line arguments
# that were parsed at module level above.
if __name__ == "__main__":
    main(args)
