# Takes a hypernetwork and trains the hyperclip model on the weights generated by embedding adaptation.


import torch.optim as optim
import argparse
import json
import os
import numpy as np
import torch
import wandb
from functools import *

from model.clip.hyperclip import build_hyperclip_from_classic_clip
from model.custom_hnet import CLIPAdapter, HyperGenerator, HyperDiscriminator, HyperEncoder, MetaModel
from training.hyperclip_learn import HyperclipTraining
from utils.build_opt import build_optimizer
from utils.config import hypergan_defaults
from utils import clip_utils
from utils.misc_utils import str2bool
from training.utils import log_metric

from scripts.init_utils import load_metamodel_from_checkpoint, load_vae_and_metamodel_from_checkpoint
from features.image_features import load_image_features
from features.ques_features import load_ques_features
from features.text_features import load_text_features

# Command-line interface. Apart from --wandb_mode, no option declares a
# default, so an option that is not passed parses to None.
parser = argparse.ArgumentParser()
parser.add_argument('--wandb_mode', type=str, default='online',
                    help='Set to "disabled" to disable Weights & Biases logging')

# (flag, type, help) triples, registered in this exact order.
# NOTE: the help text 'required' is informational only; argparse does not
# enforce it because none of these options is declared with required=True.
_CLI_OPTIONS = (
    ('--inner_epochs', str, None),

    ('--few_shot_checkpoint', str, 'required'),
    ('--vae_checkpoint', str, 'required'),
    ('--precompute_checkpoint', str, 'required'),

    ('--hyperclip_epochs', int, None),
    ('--hyperclip_batch_size', int, None),
    ('--val_epoch_interval', int, None),

    ('--hyperclip_optimizer', str, None),
    ('--hyperclip_learning_rate', float, None),
    ('--eval_inner_epochs', str, None),
    ('--guidance_optimizer', str, None),
    ('--guidance_learning_rate', float, None),

    ('--guidance_scheduler', str, None),

    # hyperclip
    ('--hyperclip_hidden_dim', str, None),
    ('--hyperclip_model', str, None),

    ('--guidance_init_l2_weight', str, None),
    ('--train_on_vae', str2bool, None),

    ('--checkpoint', str, None),
    ('--eval', str2bool, None),
    ('--langevin_eps', str, None),

    ('--normalize', str2bool, None),
)

for _flag, _type, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, help=_help)


# Project base directory: the parent of the directory that contains this file.
# The path is made absolute first because __file__ may be relative (e.g. when
# launched as `python scripts/<name>.py` on Python < 3.9); in that case
# dirname(dirname(...)) collapses to "" and `base_path + "/data/..."` would
# silently point at the filesystem root.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Baseline configuration. main() overlays every command-line option that was
# actually supplied on top of these values.
default_config = {
    # Settings for the "inner" loop (see build_optimizer(..., loop="inner")).
    "inner_epochs": "50",
    "inner_optimizer": "sgd",
    "inner_learning_rate": 0.1,
    "train_subtype": "random",
    "val_subtype": "random",

    # Model selection; the hidden dims are a comma-separated list of ints.
    "clip_model": "ViT-L/14@336px",
    "hyperclip_model": "mlp",
    "hyperclip_hidden_dim": "128,128",

    # Outer training loop.
    "hyperclip_epochs": 1000,
    "hyperclip_batch_size": 32,

    # Settings for the "hyperclip" optimizer loop.
    "hyperclip_optimizer": "adam",
    "hyperclip_learning_rate": 0.001,
    "hyperclip_weight_decay": 0,
    "hyperclip_momentum": 0.9,
    "hyperclip_sgd_nesterov": True,

    # Guided evaluation; the string-valued entries are comma-separated
    # sweeps that main() splits and converts.
    "eval_inner_epochs": "50",
    "guidance_optimizer": "adam",
    "guidance_learning_rate": 0.001,
    "guidance_momentum": 0.9,
    "guidance_sgd_nesterov": True,
    "guidance_init_l2_weight": "0",

    "guidance_scheduler": "none",

    # Validation cadence and checkpointing.
    "val_epoch_interval": 100,
    "save_checkpoint": False,
    "train_on_vae": False,

    "eval": False,
    "normalize": False,
    "langevin_eps": "0",
}

# Seed torch's and NumPy's global RNGs with one fixed value; `rng` is a
# separate NumPy RandomState created from the same seed.
_SEED = 42
torch.manual_seed(_SEED)
rng = np.random.RandomState(_SEED)
np.random.seed(_SEED)

def main(args):
    """Train the hyperclip hyper-network, or only evaluate it when ``eval`` is set.

    The effective configuration is a copy of ``default_config`` overridden by
    every command-line argument that was actually supplied (is not ``None``).
    It is registered with Weights & Biases and read back via ``wandb.config``.

    Outline:
      1. Load the meta-train / meta-test JSON files, the meta-model (plus the
         VAE generator/encoder when ``vae_checkpoint`` is configured) and the
         pre-computed image/text/question features.
      2. Build hyperclip; when ``checkpoint`` is configured, load the local
         weight file named after that W&B run.
      3. Run one epoch with ``output_mean_std=True`` to obtain the mean/std of
         the optimized parameters; they are applied through ``set_stats`` and
         logged only when ``normalize`` is enabled.
      4. For each meta-epoch: train (unless ``eval``), periodically log
         unguided and guided evaluation metrics, and optionally save a
         per-epoch checkpoint. In eval mode the function returns after the
         first evaluation pass.
      5. Save the final hyperclip weights and close the W&B run.

    Args:
        args: ``argparse.Namespace`` produced by the module-level ``parser``.

    Returns:
        None.
    """
    # Work on a copy: updating ``default_config`` itself would leak this run's
    # CLI overrides into the module-level defaults.
    cfg = dict(default_config)
    cfg.update({k: v for (k, v) in vars(args).items() if v is not None})
    print(cfg)
    wandb.init(project='train_hyperclip', entity="srl_ethz", config=cfg,
               mode=args.wandb_mode)
    config = wandb.config

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    meta_dir = os.path.join(base_path, "data", "VQA", "Meta")
    with open(os.path.join(meta_dir, "meta_train.json")) as file:
        train_data = json.load(file)
    with open(os.path.join(meta_dir, "meta_test.json")) as file:
        test_data = json.load(file)

    # Directory used both for restoring and for saving hyperclip weights.
    checkpoint_dir = os.path.join(base_path, "evaluation", "hyperclip")

    hnet_gen, hnet_enc = None, None
    if "vae_checkpoint" in config and config["vae_checkpoint"] is not None:
        meta_module, hnet_gen, hnet_enc, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
            load_vae_and_metamodel_from_checkpoint(config, device)
    else:
        meta_module, precomputed_latent, precomputed_latent_train_eval, precomputed_latent_val_eval = \
            load_metamodel_from_checkpoint(config, device)

    # pre-computed features
    image_features = load_image_features(config["clip_model"])
    text_features = load_text_features(config["clip_model"])
    ques_emb = load_ques_features(config["clip_model"])

    # An empty string means "no hidden layers"; otherwise a comma-separated list of ints.
    hyper_hidden_dims = [] if config["hyperclip_hidden_dim"] == "" \
        else [int(i) for i in config["hyperclip_hidden_dim"].split(",")]
    clip_location = os.path.expanduser(clip_utils.cached_location[config["clip_model"]])

    hyperclip = build_hyperclip_from_classic_clip(
        clip_location,
        hyper_model=config["hyperclip_model"],
        mainnet_param_count=meta_module.mnet.get_parameter_vector().shape[0],
        hyper_hidden_dims=hyper_hidden_dims,
        pretrained_it_location=clip_location,
        pretrained_hyper_location=None).to(device)

    if "checkpoint" in config:
        api = wandb.Api()
        loaded_run = api.run(config["checkpoint"])
        loaded_model_path = os.path.join(checkpoint_dir, "hyperclip_" + str(loaded_run.name) + ".pth")
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        hyperclip.hyper.load_state_dict(torch.load(loaded_model_path, map_location=device), strict=False)

    hyperclip.hyper.train()
    hyperclip_optimizer = build_optimizer(hyperclip.hyper.parameters(), config, loop="hyperclip")
    hyperclip_training = HyperclipTraining(meta_module, hnet_gen, hnet_enc,
                                           hyperclip, hyperclip_optimizer, image_features, text_features,
                                           ques_emb, config, device, train_on_vae=config["train_on_vae"])

    unguided_inner_optim = partial(build_optimizer, config=config, loop="inner")
    guided_inner_optim = partial(build_optimizer, config=config, loop="guidance")

    def run_guided_eval(meta_epoch, inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps,
                        train=False):
        """Run guided evaluation for one hyper-parameter combination and log it.

        Always evaluates on ``test_data``; additionally evaluates on
        ``train_data`` first when ``train`` is true. Metrics are tagged with
        ``meta_epoch`` and logged under a prefix that encodes the combination.
        """
        guidance_scheduler_fn = None
        if guidance_scheduler == "cos":
            guidance_scheduler_fn = partial(optim.lr_scheduler.CosineAnnealingLR, T_max=inner_epochs, eta_min=0)

        guided_kwargs = dict(guided_inner=True, use_vae=True, init_guidance_at="pre-trained", skip_cond=True,
                             guidance_init_l2_weight=guidance_init_l2_weight, langevin_eps=langevin_eps,
                             guidance_scheduler_fn=guidance_scheduler_fn)
        tag = "{}step_{}l2_{}_{}/".format(inner_epochs, guidance_init_l2_weight, guidance_scheduler, langevin_eps)

        if train:
            log_dict = hyperclip_training.run_epoch(train_data, inner_epochs, guided_inner_optim, **guided_kwargs)
            log_dict["epoch"] = meta_epoch
            log_metric(log_dict, "guided_eval_train_" + tag)

        log_dict = hyperclip_training.run_epoch(test_data, inner_epochs, guided_inner_optim, **guided_kwargs)
        log_dict["epoch"] = meta_epoch
        log_metric(log_dict, "guided_eval_val_" + tag)

    # Get STD and MEAN
    if precomputed_latent is not None:
        sampled_precomputed_latent = {k: v[0:1].to(device) for (k, v) in precomputed_latent.items()}
    else:
        sampled_precomputed_latent = None

    _, optimized_params_mean, optimized_params_std, _ = hyperclip_training.run_epoch(
        train_data, config["inner_epochs"], unguided_inner_optim,
        batch_size=config["hyperclip_batch_size"], precomputed_latent=sampled_precomputed_latent,
        output_mean_std=True, train_subtype=config["train_subtype"], val_subtype=config["val_subtype"],
        skip_cond=True)

    # Floor the std at 0.01 (lower bound passed to clip).
    optimized_params_mean, optimized_params_std = optimized_params_mean[0], optimized_params_std[0].clip(0.01)

    if config["normalize"]:
        hyperclip_training.set_stats(optimized_params_mean, optimized_params_std)
        wandb.log({"mean": optimized_params_mean, "std": optimized_params_std})

    for meta_epoch in range(config["hyperclip_epochs"]):
        if precomputed_latent is not None:
            # Pick one random slice along dim 0 and shuffle along dim 1.
            curr_sample_epoch = np.random.randint(0, precomputed_latent["clip_embedding"].shape[0])
            curr_sample_batch_perm = np.random.permutation(precomputed_latent["clip_embedding"].shape[1])
            sampled_precomputed_latent = {
                k: v[curr_sample_epoch:curr_sample_epoch + 1, curr_sample_batch_perm].to(device)
                for (k, v) in precomputed_latent.items()}
        else:
            sampled_precomputed_latent = None

        if not config["eval"]:
            hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
                                         batch_size=config["hyperclip_batch_size"],
                                         precomputed_latent=sampled_precomputed_latent,
                                         train=True, train_subtype=config["train_subtype"],
                                         val_subtype=config["val_subtype"], debug=True)

        # Unguided evaluation runs every (val_epoch_interval // 10 + 1) epochs.
        if config["eval"] or (meta_epoch + 1) % (config["val_epoch_interval"] // 10 + 1) == 0:
            log_dict = hyperclip_training.run_epoch(train_data, config["inner_epochs"], unguided_inner_optim,
                                                    precomputed_latent=precomputed_latent_train_eval)
            log_dict["epoch"] = meta_epoch
            log_metric(log_dict, "eval_train/")
            log_dict = hyperclip_training.run_epoch(test_data, config["inner_epochs"], unguided_inner_optim,
                                                    precomputed_latent=precomputed_latent_val_eval)
            log_dict["epoch"] = meta_epoch
            log_metric(log_dict, "eval_val/")

        if config["eval"] or (meta_epoch + 1) % config["val_epoch_interval"] == 0:
            if config["eval_inner_epochs"] != '':
                # Guided evaluation on the training split only every 1000 meta-epochs.
                eval_on_train = (meta_epoch + 1) % 1000 == 0
                # Full grid over the comma-separated sweep settings.
                for inner_epochs in config["eval_inner_epochs"].split(","):
                    for guidance_init_l2_weight in config["guidance_init_l2_weight"].split(","):
                        for guidance_scheduler in config["guidance_scheduler"].split(","):
                            for langevin_eps in config["langevin_eps"].split(","):
                                run_guided_eval(meta_epoch, int(inner_epochs), float(guidance_init_l2_weight),
                                                guidance_scheduler, float(langevin_eps), train=eval_on_train)

            if config["eval"]:
                # Evaluation-only run: close the W&B run before leaving.
                wandb.finish()
                return

            if config["save_checkpoint"]:
                os.makedirs(checkpoint_dir, exist_ok=True)
                hyperclip_output_path_checkpoint = os.path.join(
                    checkpoint_dir, "hyperclip_" + str(wandb.run.name) + "_" + str(meta_epoch) + ".pth")
                torch.save(hyperclip.hyper.state_dict(), hyperclip_output_path_checkpoint)
                print(f"Checkpoint for meta-epoch {meta_epoch} saved!")

    # Make sure the output directory exists so a finished run cannot fail at the final save.
    os.makedirs(checkpoint_dir, exist_ok=True)
    hyperclip_output_path_checkpoint = os.path.join(
        checkpoint_dir, "hyperclip_" + str(wandb.run.name) + ".pth")
    torch.save(hyperclip.hyper.state_dict(), hyperclip_output_path_checkpoint)

    wandb.finish()


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(parser.parse_args())
