import argparse
import json
import os
import numpy as np
from tqdm import tqdm
from copy import deepcopy

import torch
from torch.utils.data import DataLoader
import wandb

from data.dataloader.clip_vqa import CLIP_VQA
from model.classic_hypernet import PrimaryNetwork
from model.clip.hyperclip import build_hyperclip_from_classic_clip
from model.custom_hnet import CLIPAdapter, HyperGenerator, HyperDiscriminator, EmbeddingModule, MetaModel
from training.hyperclip_learn import HyperclipTraining
from training.maml import MAML, log_metric
import training.gan_learn as gan_learn
from utils.build_opt import build_optimizer
from utils.config import hypergan_defaults
from utils import clip_utils
from utils.meta_utils import AlphaReparam
from utils.wandb_utils import populate_wandb_table
from utils.misc_utils import str2bool
from model import custom_hnet

from features.image_features import load_image_features
from features.ques_features import load_ques_features
from features.text_features import load_text_features

import argparse


def _build_parser():
    """Create the command-line parser for this training script.

    The W&B options and the meta_module options carry explicit defaults; every
    other option defaults to None, and main() only copies non-None values over
    ``default_config``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--wandb_mode', type=str, default='online',
                     help='Set to "disabled" to disable Weights & Biases logging')
    cli.add_argument('--wandb_project', type=str, default='maml_baseline',
                     help='Wandb project name')

    # Optional overrides of default_config entries (None when omitted).
    override_specs = (
        ('meta_batch_size', int),
        ('inner_epochs', int),
        ('inner_learning_rate', float),
        ('eval_inner_epochs', str),
        ('second_order', str2bool),
        ('train_subtype', str),
        ('val_subtype', str),
        ('meta_learning_rate', float),
        ('meta_grad_clip', float),
        ('train_hyperclip', str2bool),
        ('weight_change_reparam', str2bool),
        ('weight_change_reparam_alpha', float),
    )
    for flag, conv in override_specs:
        cli.add_argument('--' + flag, type=conv)

    # meta_module options.
    # NOTE(review): these defaults are not None, so they always override the
    # matching default_config entries (e.g. "enet" here vs "mnet" there,
    # False here vs True there) — confirm which value is intended.
    cli.add_argument('--inner_param', type=str, default="enet")
    cli.add_argument('--hypernet_hidden_dim', type=str, default="128,128,128")
    cli.add_argument('--straight_through', type=str2bool, default=False)

    cli.add_argument('--load_checkpoint', type=str)
    return cli


parser = _build_parser()

# Directory two levels above this script; main() builds its data and
# checkpoint paths as ``base_path + "/..."``. abspath() guards against a
# relative __file__ (possible before Python 3.9), which would otherwise make
# base_path "" and turn "/data/..." into a filesystem-root path.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Seed torch and NumPy (global state plus a dedicated RandomState) for reproducibility.
SEED = 42
torch.manual_seed(SEED)
rng = np.random.RandomState(SEED)
np.random.seed(SEED)

# Baseline hyper-parameters for a run. main() builds the run configuration
# from this dict overlaid with every CLI option whose parsed value is not None.
default_config = dict(
    # meta_module: inner-loop parameterisation of the main network
    inner_param="mnet",
    hypernet_hidden_dim="128,128,128",
    straight_through=True,

    mainnet_use_bias=True,
    mainnet_hidden_dim=[256],
    embedding_dim=128,

    clip_model="ViT-L/14@336px",

    train_subtype="test",
    val_subtype="train",

    # outer (meta) loop
    meta_epochs=200,
    meta_batch_size=32,
    second_order=False,
    meta_grad_clip=10,

    # inner loop
    inner_epochs=10,
    inner_learning_rate=0.1,

    eval_inner_epochs='50',

    # meta optimizer
    meta_optimizer="adam",
    meta_learning_rate=0.001,
    meta_weight_decay=0,
    meta_adam_beta1=0.9,
    meta_adam_beta2=0.999,
    meta_momentum=0.9,
    meta_sgd_dampening=0,
    meta_sgd_nesterov=True,
    meta_rmsprop_alpha=0.99,

    # HyperCLIP model
    train_hyperclip=True,
    hyperclip_model="mlp",
    hyperclip_hidden_dim=[128, 128],
    hyperclip_batch_size=32,

    # HyperCLIP optimizer
    hyperclip_optimizer="adam",
    hyperclip_learning_rate=0.002,
    hyperclip_weight_decay=0,
    hyperclip_adam_beta1=0.9,
    hyperclip_adam_beta2=0.999,
    hyperclip_momentum=0.9,
    hyperclip_sgd_dampening=0,
    hyperclip_sgd_nesterov=True,
    hyperclip_rmsprop_alpha=0.99,

    weight_change_reparam=True,
    weight_change_reparam_alpha=0.1,

    guidance_init_l2_weight=0,

    # bookkeeping
    val_epoch_interval=10,
    save_checkpoint=False,
    load_checkpoint="",
)


def _checkpoint_path(prefix, meta_epoch):
    """Return ``<base_path>/evaluation/maml/<prefix><wandb run name>_<meta_epoch>.pth``.

    The output directory is created if it does not exist yet.
    """
    out_dir = base_path + "/evaluation/maml"
    os.makedirs(out_dir, exist_ok=True)
    return out_dir + "/" + prefix + str(wandb.run.name) + "_" + str(meta_epoch) + ".pth"


def _record_best(log_dict, slot, meta_epoch, best_val_acc, best_val_epoch):
    """Update the running best ``query_accuracy_end`` for ``slot`` and copy it into ``log_dict``.

    Mutates ``best_val_acc`` / ``best_val_epoch`` in place and sets the
    ``best_accuracy`` and ``best_epoch`` keys of ``log_dict``.
    """
    if best_val_acc[slot] < log_dict["query_accuracy_end"]:
        best_val_acc[slot] = log_dict["query_accuracy_end"]
        best_val_epoch[slot] = meta_epoch
    log_dict["best_accuracy"] = best_val_acc[slot]
    log_dict["best_epoch"] = best_val_epoch[slot]


def main(args):
    """Meta-train the VQA model with MAML, optionally co-training HyperCLIP.

    The run configuration is ``default_config`` overlaid with every option in
    ``args`` whose value is not None. It is registered with Weights & Biases
    and read back through ``wandb.config``. Every ``val_epoch_interval``
    meta-epochs the model is evaluated on the meta-train and meta-test splits.
    Guided (HyperCLIP) and extra inner-step-count evaluations run as well when
    configured. The run stops early if meta-test ``query_accuracy_end`` drops
    below 0.3.

    Args:
        args: Parsed command-line namespace (from ``parser.parse_args()``).
    """
    # Work on a copy: updating default_config itself would leak one call's
    # overrides into the next (e.g. several runs in one process).
    cfg = deepcopy(default_config)
    # Use the namespace we were given instead of re-reading sys.argv.
    cfg.update({k: v for k, v in vars(args).items() if v is not None})

    wandb.init(project=args.wandb_project, entity="srl_ethz", config=cfg,
               mode=args.wandb_mode)
    config = wandb.config

    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(base_path + "/data/VQA/Meta/meta_train.json", encoding="utf-8") as file:
        train_data = json.load(file)
    with open(base_path + "/data/VQA/Meta/meta_test.json", encoding="utf-8") as file:
        test_data = json.load(file)

    # Pre-computed features for the configured CLIP backbone.
    image_features = load_image_features(config["clip_model"])
    text_features = load_text_features(config["clip_model"])
    ques_emb = load_ques_features(config["clip_model"])

    meta_module = MetaModel(
        inner_param=config["inner_param"],
        mainnet_use_bias=config["mainnet_use_bias"],
        mainnet_hidden_dim=config["mainnet_hidden_dim"],
        hypernet_hidden_dim=[int(i) for i in config["hypernet_hidden_dim"].split(",")],
        embedding_dim=config["embedding_dim"],
        straight_through=config["straight_through"],
        config=config).to(device)

    if config["load_checkpoint"] != "":
        # load_checkpoint is appended to base_path verbatim, so it should start with "/".
        loaded_model_path = base_path + str(config["load_checkpoint"])
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        state_dict = torch.load(loaded_model_path, map_location=device)
        meta_module.load_state_dict(state_dict, strict=False)

    if config["weight_change_reparam"]:
        # Wrap the main network in AlphaReparam with a freshly built CLIPAdapter as the start net.
        new_base_mnet = custom_hnet.CLIPAdapter(
            e_dim=clip_utils.embedding_size[config["clip_model"]],
            hidden_layers=config["mainnet_hidden_dim"],
            use_bias=config["mainnet_use_bias"], no_weights=False,
            ignore_passed_weights=True).to(device)
        meta_module.mnet = AlphaReparam(meta_module.mnet, alpha=config["weight_change_reparam_alpha"],
                                        custom_start_net=new_base_mnet)

    meta_optimizer = build_optimizer(meta_module.meta_params, config, loop="meta")
    meta_trainer = MAML(meta_module, meta_optimizer, image_features, text_features, ques_emb, config)

    hyperclip = None
    hyperclip_training = None
    if config["train_hyperclip"]:
        clip_location = os.path.expanduser(clip_utils.cached_location[config["clip_model"]])
        hyperclip = build_hyperclip_from_classic_clip(
            clip_location,
            hyper_model=config["hyperclip_model"],
            mainnet_param_count=meta_module.mnet.get_parameter_vector().shape[0],
            hyper_hidden_dims=config["hyperclip_hidden_dim"],
            pretrained_it_location=clip_location).to(device)
        hyperclip.hyper.train()
        hyperclip_optimizer = build_optimizer(hyperclip.hyper.parameters(), config, loop="hyperclip")
        hyperclip_training = HyperclipTraining(hyperclip, hyperclip_optimizer,
                                               clip_utils.embedding_size[config["clip_model"]],
                                               config["hyperclip_batch_size"],
                                               len(test_data), device)

    # Extra inner-loop step counts to evaluate, e.g. "50,100". Blank entries
    # (empty string, trailing comma) are ignored instead of crashing int('').
    eval_steps = [s.strip() for s in str(config["eval_inner_epochs"]).split(",") if s.strip()]
    # Slot 0 tracks the default inner_epochs evaluation; slot i + 1 tracks eval_steps[i].
    best_val_acc = [0] * (1 + len(eval_steps))
    best_val_epoch = [0] * (1 + len(eval_steps))

    last_epoch = None
    for meta_epoch in range(config["meta_epochs"]):
        last_epoch = meta_epoch
        # train is False when a checkpoint was loaded (presumably evaluation-only) — confirm in MAML.run_epoch.
        meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"],
                               meta_batch_size=config["meta_batch_size"],
                               train=config["load_checkpoint"] == "", second_order=config["second_order"],
                               meta_grad_clip=config["meta_grad_clip"],
                               train_subtype=config["train_subtype"], val_subtype=config["val_subtype"],
                               hyperclip_training=hyperclip_training, debug=True)

        if meta_epoch % config["val_epoch_interval"] == 0:
            log_dict = meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"],
                                              hyperclip_training=hyperclip_training, eval_hyperclip=True)
            log_metric(log_dict, "eval_train/")
            log_dict = meta_trainer.run_epoch(test_data, config["inner_epochs"], config["inner_learning_rate"],
                                              hyperclip_training=hyperclip_training, eval_hyperclip=True)
            _record_best(log_dict, 0, meta_epoch, best_val_acc, best_val_epoch)
            log_metric(log_dict, "eval_val/")

            # Abort runs whose meta-test query accuracy has collapsed. Close the
            # W&B run explicitly, since the final checkpoint/finish below is skipped.
            if log_dict["query_accuracy_end"] < 0.3:
                print("Stopping training")
                wandb.finish()
                return

            if config["train_hyperclip"]:
                log_dict = meta_trainer.run_epoch(train_data, config["inner_epochs"], config["inner_learning_rate"],
                                                  hyperclip_training=hyperclip_training, eval_hyperclip=True,
                                                  guided_inner=True)
                log_metric(log_dict, "guided_eval_train/")
                log_dict = meta_trainer.run_epoch(test_data, config["inner_epochs"], config["inner_learning_rate"],
                                                  hyperclip_training=hyperclip_training, eval_hyperclip=True,
                                                  guided_inner=True)
                log_metric(log_dict, "guided_eval_val/")

            for slot, steps in enumerate(eval_steps, start=1):
                log_dict = meta_trainer.run_epoch(train_data, int(steps), config["inner_learning_rate"])
                log_metric(log_dict, "eval_train_{}step/".format(steps))

                log_dict = meta_trainer.run_epoch(test_data, int(steps), config["inner_learning_rate"])
                _record_best(log_dict, slot, meta_epoch, best_val_acc, best_val_epoch)
                log_metric(log_dict, "eval_val_{}step/".format(steps))

        if config["save_checkpoint"] and meta_epoch % 10 == 0:
            torch.save(meta_module.state_dict(), _checkpoint_path("meta_module", meta_epoch))
            if hyperclip is not None:
                torch.save(hyperclip.hyper.state_dict(), _checkpoint_path("hyperclip_", meta_epoch))
            print(f"Checkpoint for meta-epoch {meta_epoch} saved!")

    # Always keep the final meta_module weights. With meta_epochs == 0 no epoch
    # ran, so there is no epoch number to name the file after (and nothing new to save).
    if last_epoch is not None:
        torch.save(meta_module.state_dict(), _checkpoint_path("meta_module", last_epoch))
        print(f"Checkpoint for meta-epoch {last_epoch} saved!")

    wandb.finish()


if __name__ == "__main__":
    # Parse the command line once at the entry point and hand the namespace to main().
    main(parser.parse_args())
