from copy import deepcopy
import json
import os
import argparse
from sklearn.metrics import accuracy_score

import torch
from torch.nn import functional as F
import numpy as np
import wandb

from torch.utils.data import DataLoader
from data.dataloader.clip_vqa import CLIP_VQA
import model.custom_hnet as custom_hnet
from model.clip.hyperclip import build_hyperclip_from_classic_clip
from features.image_features import load_image_features
from features.ques_features import load_ques_features
from features.text_features import load_text_features
from model.custom_hnet import CLIPAdapter, EmbeddingModule, HyperEncoder, HyperGenerator, MetaModel
from utils import clip_utils
from utils.build_opt import build_optimizer
from utils.meta_utils import AlphaReparam, checkpoint_config_fields, checkpoint_config_fields_hnet, checkpoint_config_fields_gan, checkpoint_config_fields_hyperclip

# Command-line interface: wandb logging mode plus the runs / files the checkpoints are loaded from.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--wandb_mode',
    type=str,
    default='online',
    help='Set to "disabled" to disable Weights & Biases logging',
)
parser.add_argument(
    '--gen_run_id',
    type=str,
    default='srl_ethz/clip-hypernet/2ay9eyqh',
    help='The full "<entity>/<project>/<run_id>" identifier of the run to load for the generator',
)
parser.add_argument(
    '--hyperclip_run_id',
    type=str,
    default=None,
    help='The full "<entity>/<project>/<run_id>" identifier of the run to load for the hyperclip. By default it will be the same as the generator',
)
parser.add_argument(
    '--gen_load_file',
    type=str,
    default=None,
    help='The file to load the generative model checkpoint from',
)
parser.add_argument(
    '--hyperclip_load_file',
    type=str,
    default=None,
    help='The file to load the hyperclip checkpoint from',
)

# Directory two levels above this file; data and checkpoint paths below are built from it.
base_path = os.path.dirname(os.path.dirname(__file__))

# Seed NumPy's global generator, a dedicated RandomState and torch with the same fixed value.
np.random.seed(42)
rng = np.random.RandomState(42)
torch.manual_seed(42)

# Which kind of generator checkpoint is evaluated and how it is adapted to a task.
_METHOD_DEFAULTS = {
    "gan_algo": "vae",                          # one of "vae", "gan" or "hnet"
    "finetune_method": "hyperclip_guidance",    # one of "data" or "hyperclip_guidance"
    "hypernet_embedding_dim": 128,
}

# Settings of the data-driven inner loop (see inner_train).
_INNER_LOOP_DEFAULTS = {
    "stop_loss": 0.01,              # stop training once the loss drops below this value; None disables it
    "inner_epochs": 100,
    "inner_batch_size": 32,
    "inner_optimizer": "adam",
    "inner_learning_rate": 0.01,
    "inner_weight_decay": 0,
    "inner_adam_beta1": 0.9,
    "inner_adam_beta2": 0.999,
    "inner_momentum": 0,
    "inner_sgd_dampening": 0,
    "inner_sgd_nesterov": False,
    "inner_rmsprop_alpha": 0.99,
}

# Settings of the HyperCLIP-guided optimisation.
_GUIDANCE_DEFAULTS = {
    "guidance_opt_parameter": "hnet_embedding",     # one of "hnet_embedding" or "net_weights"
    "guidance_loss_term": "cosine",                 # one of "l2" or "cosine"

    "guidance_stop_loss": None,
    "guidance_epochs": 100,
    "guidance_optimizer": "adam",
    "guidance_learning_rate": 0.001,
    "guidance_weight_decay": 0,
    "guidance_adam_beta1": 0.9,
    "guidance_adam_beta2": 0.999,
    "guidance_momentum": 0,
    "guidance_sgd_dampening": 0,
    "guidance_sgd_nesterov": False,
    "guidance_rmsprop_alpha": 0.99,

    "guidance_init_l2_weight": 0,
    "guidance_start_from_finetuned": False,
    # One of "train" or "test". Only "test" yields a reliable number; "train" is merely a
    # sanity check that guidance works at all. The key (including its space) is read verbatim by main().
    "guidance on_set": "test",
}

# Flat configuration used as the starting point of every run; key order follows the sections above.
default_config = {**_METHOD_DEFAULTS, **_INNER_LOOP_DEFAULTS, **_GUIDANCE_DEFAULTS}

def _log_meta_val_summary(start_losses, end_losses, start_accuracies, end_accuracies):
    """Print and log to wandb the means over tasks of the collected losses and accuracies.

    Each argument is a list holding one entry per evaluated task.
    """
    start_mean_loss = np.mean(np.array(start_losses))
    end_mean_loss = np.mean(np.array(end_losses))
    start_mean_accuracy = np.mean(np.array(start_accuracies))
    end_mean_accuracy = np.mean(np.array(end_accuracies))
    print("Validation Zero-Shot Test Accuracy: {}, Post-Update Test Accuracy: {}".format(start_mean_accuracy, end_mean_accuracy))
    wandb.log({"meta_val_task_start_mean_loss": start_mean_loss,
               "meta_val_task_end_mean_loss": end_mean_loss,
               "meta_val_task_start_mean_accuracy": start_mean_accuracy,
               "meta_val_task_end_mean_accuracy": end_mean_accuracy,
               "meta_val_task_diff_mean_accuracy": end_mean_accuracy - start_mean_accuracy})


def main(args):
    """Evaluate task adaptation of a loaded checkpoint on the meta VQA tasks.

    The run configuration starts from a copy of ``default_config`` and is completed with
    fields read from the wandb runs ``args.gen_run_id`` and ``args.hyperclip_run_id``.
    Depending on ``config["finetune_method"]`` every task is adapted either by gradient
    descent on its own training split (``"data"``) or by pulling the HyperCLIP embedding of
    the main-network weights towards the task's question embedding
    (``"hyperclip_guidance"``). Accuracy on the task's test split is measured before and
    after adaptation, printed per task, and the means over tasks are logged to wandb.
    The model state is restored after every task. Any other ``finetune_method`` evaluates
    nothing.

    Args:
        args: parsed command-line namespace providing ``wandb_mode``, ``gen_run_id``,
            ``hyperclip_run_id``, ``gen_load_file`` and ``hyperclip_load_file``.
            ``hyperclip_run_id`` is overwritten with ``gen_run_id`` when it is None.

    Raises:
        ValueError: if ``gan_algo``, ``guidance_opt_parameter`` or ``guidance_loss_term``
            holds an unsupported value.
        NotImplementedError: if embedding guidance is requested for a generator that is
            neither "hnet" nor "vae".
    """
    # Work on a copy so that fields merged in from the loaded runs never leak into the module-level defaults.
    config = dict(default_config)
    api = wandb.Api()
    gen_loaded_run = api.run(args.gen_run_id)
    gen_loaded_config = gen_loaded_run.config
    if "hnet" in config["gan_algo"]:
        checkpoint_config_fields_generator = checkpoint_config_fields_hnet
    elif "vae" in config["gan_algo"] or "gan" in config["gan_algo"]:
        checkpoint_config_fields_generator = checkpoint_config_fields_gan
    else:
        raise ValueError("Unsupported gan_algo: {!r}".format(config["gan_algo"]))
    for key in checkpoint_config_fields_generator:
        config[key] = gen_loaded_config[key]
    if args.hyperclip_run_id is None:
        args.hyperclip_run_id = args.gen_run_id
    hyperclip_loaded_run = api.run(args.hyperclip_run_id)
    hyperclip_loaded_config = hyperclip_loaded_run.config
    # HyperCLIP fields are optional: only those present in the loaded run are copied.
    for key in checkpoint_config_fields_hyperclip:
        if key in hyperclip_loaded_config:
            config[key] = hyperclip_loaded_config[key]

    wandb.init(project="hyperclip-guidance", entity="srl_ethz", config=config, mode=args.wandb_mode)

    if args.gen_load_file is None:
        model_path = base_path + "/evaluation/ckp_reptile/clip_hypernet_"+str(gen_loaded_run.name)+".pth"
    else:
        model_path = base_path + "/evaluation/" + args.gen_load_file

    if args.hyperclip_load_file is None:
        hyperclip_path = base_path + "/evaluation/ckp_reptile/hyperclip_"+str(hyperclip_loaded_run.name)+".pth"
    else:
        hyperclip_path = base_path + "/evaluation/" + args.hyperclip_load_file

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(base_path + "/data/VQA/Meta/meta_train.json") as file:
        train_data = json.load(file)
    with open(base_path + "/data/VQA/Meta/meta_test.json") as file:
        test_data = json.load(file)

    train_tasks = list(train_data.keys())
    test_tasks = list(test_data.keys())

    #pre-computed features
    image_features = load_image_features(config["clip_model"])
    text_features = load_text_features(config["clip_model"])
    ques_emb = load_ques_features(config["clip_model"])

    if config["guidance_opt_parameter"] == "hnet_embedding":
        if "hnet" in config["gan_algo"]:
            meta_module = MetaModel(inner_param="enet", mainnet_use_bias=config["mainnet_use_bias"],
                                mainnet_hidden_dim=config["mainnet_hidden_dim"], hypernet_hidden_dim=config["hypernet_hidden_dim"],
                                embedding_dim=config["hypernet_embedding_dim"], config=config).to(device)
        elif "vae" in config["gan_algo"]:
            # For the VAE the hidden sizes are stored as a comma-separated string.
            vae_hidden_dims = [int(h) for h in config["hypernet_hidden_dim"].split(",")]
            meta_module = MetaModel(inner_param="enet", mainnet_use_bias=config["mainnet_use_bias"],
                                mainnet_hidden_dim=config["mainnet_hidden_dim"], hypernet_hidden_dim=vae_hidden_dims,
                                embedding_dim=config["hypergan_noise_dim"], config=config).to(device)

            hnet_dis = HyperEncoder(meta_module.mnet, e_dim=config["hypergan_noise_dim"], hidden_dims=list(vae_hidden_dims)).to(device)
        else:
            raise NotImplementedError("Not implemented for this GAN algorithm")
    elif config["guidance_opt_parameter"] == "net_weights":
        meta_module = MetaModel(inner_param="mnet", mainnet_use_bias=config["mainnet_use_bias"],
                                mainnet_hidden_dim=config["mainnet_hidden_dim"], config=config).to(device)
    else:
        raise ValueError("Unsupported guidance_opt_parameter: {!r}".format(config["guidance_opt_parameter"]))

    # Load the hyperclip model from the checkpoint
    hyperclip = build_hyperclip_from_classic_clip(os.path.expanduser(clip_utils.cached_location[config["clip_model"]]),
                                                    hyper_model = config["hyperclip_model"],
                                                    mainnet_param_count = meta_module.mnet.get_parameter_vector().shape[0],
                                                    hyper_hidden_dims = config["hyperclip_hidden_dim"],
                                                    pretrained_it_location = os.path.expanduser(clip_utils.cached_location[config["clip_model"]]),
                                                    pretrained_hyper_location=hyperclip_path).to(device)
    hyperclip.eval()

    if config["weight_change_reparam"]:

        # define with-weights start_mnet if building a weight-change reparam model
        start_mnet = custom_hnet.CLIPAdapter(e_dim=clip_utils.embedding_size[config["clip_model"]], hidden_layers=config["mainnet_hidden_dim"], use_bias=config["mainnet_use_bias"], no_weights=False, ignore_passed_weights=True).to(device)
        meta_module.mnet = AlphaReparam(meta_module.mnet, alpha=config["weight_change_reparam_alpha"], custom_start_net=start_mnet)

    # map_location lets checkpoints that were saved on a GPU be restored on a CPU-only host as well.
    if "hnet" in config["gan_algo"] or config["guidance_opt_parameter"] == "net_weights":
        #TODO: right now only supports loading maml-learned mnet init, not the hnet (was only used for random-hnet experiments)
        meta_module.mnet.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    elif "vae" in config["gan_algo"]:
        meta_module.hnet.load_state_dict(torch.load(model_path, map_location=device), strict=False)
        hnet_dis.load_state_dict(torch.load(model_path.replace("_gen", "_dis"), map_location=device), strict=False)
        #load mnet too, so that for the alpha reparam, it loads the start_net correctly
        if config["load_checkpoint_run_id"] is not None and config["load_checkpoint_run_id"] != "":
            loaded_model_path = base_path + "/evaluation/"+str(config["load_checkpoint_run_id"])
            meta_module.mnet.load_state_dict(torch.load(loaded_model_path, map_location=device), strict=False)
            #note: embedding in the gen hnet is still random, so generated mnet (not start_net) will still be arbitrary. Should not matter if we do "guidance_start_from_finetuned"

    if config["finetune_method"] == "data":
        #TODO: support vae
        meta_val_task_start_loss_ar = []
        meta_val_task_end_loss_ar = []
        meta_val_task_start_accuracy_ar = []
        meta_val_task_end_accuracy_ar = []
        for val_task_id, task_name in enumerate(test_tasks):
            val_train_dataset = CLIP_VQA(meta_data=test_data,
                                        dataSubType='train',
                                        task=task_name,
                                        image_features = image_features,
                                        text_features = text_features,
                                        ques_emb = ques_emb)

            val_test_dataset = CLIP_VQA(meta_data=test_data,
                                        dataSubType='test',
                                        task=task_name,
                                        image_features = image_features,
                                        text_features = text_features,
                                        ques_emb = ques_emb)

            val_train_dataloader = DataLoader(val_train_dataset, batch_size=config["inner_batch_size"],
                                            shuffle=True)
            val_test_dataloader = DataLoader(val_test_dataset, batch_size=32, shuffle=False)

            inner_params = meta_module.get_inner_params()
            # Validation zero-shot accuracy
            val_start_acc, _ = test_accuracy(meta_module, val_test_dataloader, params=inner_params)
            meta_val_task_start_accuracy_ar.append(val_start_acc)

            # Validation inner-loop training
            weights_before = deepcopy(meta_module.state_dict())
            # Important: only optimize enet, meta optimizer is for when we use data
            inner_optimizer = build_optimizer(meta_module.inner_params.values(), config, loop="inner")

            start_loss, end_loss = inner_train(meta_module,
                                            val_train_dataloader,
                                            inner_optimizer,
                                            config["inner_epochs"],
                                            device, stop_loss = config["stop_loss"],
                                            params=inner_params)
            meta_val_task_start_loss_ar.append(start_loss)
            meta_val_task_end_loss_ar.append(end_loss)

            # Validation inner test accuracy
            val_end_acc, _ = test_accuracy(meta_module, val_test_dataloader, params=inner_params)
            print("Validation Task {}: {}, Zero-Shot Test Acc: {}, Post-Update Test Acc: {}".format(val_task_id, task_name, val_start_acc, val_end_acc))
            meta_val_task_end_accuracy_ar.append(val_end_acc)

            # restore model uncontaminated by validation
            meta_module.load_state_dict(weights_before)

        _log_meta_val_summary(meta_val_task_start_loss_ar, meta_val_task_end_loss_ar,
                              meta_val_task_start_accuracy_ar, meta_val_task_end_accuracy_ar)

    elif config["finetune_method"] == "hyperclip_guidance":

        if config["guidance_loss_term"] not in ("cosine", "l2"):
            raise ValueError("Unsupported guidance_loss_term: {!r}".format(config["guidance_loss_term"]))

        if config["guidance_start_from_finetuned"]:
            # Fine-tune once on a fixed training task; the state reached here is what every
            # evaluated task starts from (and is restored to) in the loop below.
            task_idx = 42
            train_dataset = CLIP_VQA(meta_data = train_data,
                                    dataSubType = 'traintest',
                                    task = train_tasks[task_idx],
                                    image_features = image_features,
                                    text_features = text_features,
                                    ques_emb = ques_emb)
            train_dataloader = DataLoader(train_dataset, batch_size=config["inner_batch_size"], shuffle=True)
            inner_params = meta_module.get_inner_params()
            train_start_acc, _ = test_accuracy(meta_module, train_dataloader, params=inner_params)
            inner_optimizer = build_optimizer(meta_module.inner_params.values(), config, loop="inner")

            inner_train(meta_module,
                        train_dataloader,
                        inner_optimizer,
                        config["inner_epochs"],
                        device, stop_loss = config["stop_loss"],
                        params=inner_params)
            train_end_acc, _ = test_accuracy(meta_module, train_dataloader, params=inner_params)
            print("Init Train Task {}: {}, Zero-Shot Train Acc: {}, Post-Update Train Acc: {}".format(task_idx, train_tasks[task_idx], train_start_acc, train_end_acc))

        # Resolve once which meta split the guidance is evaluated on.
        guide_on_test = config["guidance on_set"] == "test"
        eval_data = test_data if guide_on_test else train_data
        eval_tasks = test_tasks if guide_on_test else train_tasks

        meta_val_task_start_loss_ar = []
        meta_val_task_end_loss_ar = []
        meta_val_task_start_accuracy_ar = []
        meta_val_task_end_accuracy_ar = []
        for val_task_id, task_name in enumerate(eval_tasks):

            val_test_dataset = CLIP_VQA(meta_data=eval_data,
                                        dataSubType='test',
                                        task=task_name,
                                        image_features = image_features,
                                        text_features = text_features,
                                        ques_emb = ques_emb)

            task_ques_emb = ques_emb[task_name][0]

            val_test_dataloader = DataLoader(val_test_dataset, batch_size=32, shuffle=False)

            inner_params = meta_module.get_inner_params()
            # Validation zero-shot accuracy
            val_start_acc, _ = test_accuracy(meta_module, val_test_dataloader, params=inner_params)
            meta_val_task_start_accuracy_ar.append(val_start_acc)

            # Validation inner-loop training
            weights_before = deepcopy(meta_module.state_dict())

            # The snapshots of the initial values must be detached: a plain clone() stays
            # connected to the parameter in the autograd graph, so the gradient of the
            # pull-back term would flow through both arguments and cancel exactly.
            if config["guidance_opt_parameter"] == "hnet_embedding":
                # Important: only optimize enet, meta optimizer is for when we use data
                e = meta_module.inner_params["enet.embedding"]
                guidance_optimizer = build_optimizer([e], config, loop="guidance")
                e_old = e.detach().clone()
            elif config["guidance_opt_parameter"] == "net_weights":
                net_params = list(meta_module.inner_params.values())
                guidance_optimizer = build_optimizer(net_params, config, loop="guidance")
                net_params_old = [p.detach().clone() for p in net_params]

            loss = 0
            for epoch in range(config["guidance_epochs"]):

                guidance_optimizer.zero_grad()

                weights = get_mainnet_weights(meta_module, ques_emb = task_ques_emb, reparam = config["weight_change_reparam"], params=inner_params)

                task_weight_emb = hyperclip.encode_hyper(weights)

                norm_task_ques_emb = task_ques_emb / task_ques_emb.norm(dim=-1, keepdim=True)
                norm_task_weight_emb = task_weight_emb / task_weight_emb.norm(dim=-1, keepdim=True)

                if config["guidance_loss_term"] == "cosine":
                    # negative inner product of the unit-normalised embeddings
                    inner_product_embs_loss = - norm_task_weight_emb @ norm_task_ques_emb.T
                elif config["guidance_loss_term"] == "l2":
                    inner_product_embs_loss = (norm_task_weight_emb - norm_task_ques_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

                # Optional penalty keeping the optimised parameters close to their initial values.
                if config["guidance_opt_parameter"] == "hnet_embedding":
                    init_l2_loss = F.mse_loss(e, e_old) * config["guidance_init_l2_weight"] / 2
                elif config["guidance_opt_parameter"] == "net_weights":
                    init_l2_loss = 0
                    for p, p_old in zip(net_params, net_params_old):
                        init_l2_loss += F.mse_loss(p, p_old) * config["guidance_init_l2_weight"] / 2

                loss = inner_product_embs_loss + init_l2_loss

                if epoch == 0:
                    start_loss = loss.detach().cpu().numpy()
                if val_task_id == 0:
                    # the loss curve is only logged for the first task
                    wandb.log({"guidance_loss": loss.detach().cpu().numpy()})
                loss.backward()
                guidance_optimizer.step()

                if config["guidance_stop_loss"] is not None and loss < config["guidance_stop_loss"]:
                    break

            end_loss = loss.detach().cpu().numpy()

            meta_val_task_start_loss_ar.append(start_loss)
            meta_val_task_end_loss_ar.append(end_loss)

            # Validation inner test accuracy
            val_end_acc, _ = test_accuracy(meta_module, val_test_dataloader, params=inner_params)
            print("Validation Task {}: {}, Zero-Shot Test Acc: {}, Post-Update Test Acc: {}".format(val_task_id, task_name, val_start_acc, val_end_acc))
            meta_val_task_end_accuracy_ar.append(val_end_acc)

            # restore model uncontaminated by validation
            meta_module.load_state_dict(weights_before)

        _log_meta_val_summary(meta_val_task_start_loss_ar, meta_val_task_end_loss_ar,
                              meta_val_task_start_accuracy_ar, meta_val_task_end_accuracy_ar)

    wandb.finish()

def get_pred(meta_module, dataloader, params=None):
    """Run ``meta_module`` over every batch of ``dataloader`` and gather outputs and labels.

    Each batch is a mapping with the keys "image_features", "text_features",
    "ques_emb" (only its first element is passed on) and "label". Labels are moved
    to the device of the batch's image features.

    Returns:
        A pair ``(outputs, labels)``, each the concatenation over all batches.
    """
    outputs, targets = [], []
    for batch in dataloader:
        image_feats = batch["image_features"]
        text_feats = batch["text_features"]
        question_emb = batch["ques_emb"][0]
        batch_labels = batch["label"].to(image_feats.device)
        outputs.append(meta_module(image_feats, text_feats, question_emb, params=params))
        targets.append(batch_labels)

    return torch.cat(outputs), torch.cat(targets)

def test_accuracy(meta_module, dataloader, params=None):
    """Evaluate ``meta_module`` on ``dataloader`` without tracking gradients.

    The module is switched to eval mode for the evaluation and afterwards put back
    into whatever mode it was in before the call (also when the evaluation raises),
    instead of being forced into train mode.

    Args:
        meta_module: module to evaluate; called through ``get_pred``.
        dataloader: iterable of batches as expected by ``get_pred``.
        params: optional parameters forwarded to ``meta_module``.

    Returns:
        ``(accuracy, loss)``: top-1 accuracy and the cross-entropy loss (Python float)
        computed over all samples of the dataloader.
    """
    was_training = meta_module.training
    # Validation inner-loop testing
    meta_module.eval()
    try:
        with torch.no_grad():
            output, y_true = get_pred(meta_module, dataloader, params=params)
            # index of the highest-scoring class per sample
            _, y_pred = output.topk(1)
            loss = F.cross_entropy(output, y_true)
    finally:
        meta_module.train(was_training)

    acc = accuracy_score(y_true.cpu().numpy(), y_pred.cpu().numpy())
    return acc, loss.item()

def inner_train(meta_module, train_dataloader, optimizer, inner_epochs, device, stop_loss=None, params=None):
    """Train ``meta_module`` with cross-entropy for up to ``inner_epochs`` passes over the data.

    Args:
        meta_module: callable as ``meta_module(image_features, text_features, ques_emb, params=params)``.
        train_dataloader: iterable of batches with the keys "image_features",
            "text_features", "ques_emb" (only its first element is used) and "label".
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per batch.
        inner_epochs: maximum number of passes over ``train_dataloader``.
        device: device the labels are moved to.
        stop_loss: if not None, training stops after the first epoch whose last batch
            loss is below this value.
        params: optional parameters forwarded to ``meta_module``.

    Returns:
        ``(start_loss, end_loss)`` as 0-d NumPy arrays: the loss of the very first batch
        and the loss of the last processed batch (both measured before the respective
        optimizer step). Both are NaN when no batch was processed, i.e. when
        ``inner_epochs`` is 0 or the dataloader is empty.
    """
    # NaN placeholders avoid a crash / unbound name when not a single step is taken.
    start_loss = end_loss = np.array(float("nan"))
    loss = None
    for epoch in range(inner_epochs):
        for batch_id, sample in enumerate(train_dataloader):
            sample_image_features = sample["image_features"]
            sample_text_features = sample["text_features"]
            sample_ques_emb = sample["ques_emb"]
            labels = sample["label"].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = meta_module(sample_image_features, sample_text_features, sample_ques_emb[0], params=params)

            loss = F.cross_entropy(outputs, labels)
            if epoch == 0 and batch_id == 0:
                start_loss = loss.detach().cpu().numpy()
            loss.backward()
            optimizer.step()
        # early stopping is decided on the loss of the epoch's last batch
        if stop_loss is not None and loss is not None and loss < stop_loss:
            break

    if loss is not None:
        end_loss = loss.detach().cpu().numpy()
    return start_loss, end_loss

def get_mainnet_weights(meta_module, ques_emb = None, reparam = False, params=None):
    """Return the main-network weights currently represented by ``meta_module``.

    * No hypernetwork (``meta_module.hnet is None``): the parameter vector of
      ``meta_module.mnet`` is returned, or of ``meta_module.mnet.net`` when ``reparam`` is true.
    * Hypernetwork with an embedding network: the hypernetwork is evaluated on the output
      of ``meta_module.enet`` (called with the "enet" sub-dict of ``params``); no
      ``params`` are passed to the hypernetwork itself.
    * Hypernetwork without an embedding network: the hypernetwork is evaluated on
      ``ques_emb`` with the "hnet" sub-dict of ``params``.
    """
    hypernet = meta_module.hnet
    if hypernet is None:
        weight_holder = meta_module.mnet.net if reparam else meta_module.mnet
        return weight_holder.get_parameter_vector()

    if meta_module.enet is not None:
        embedding = meta_module.enet(params=meta_module.get_subdict(params, "enet"))
        return hypernet.forward(uncond_input=embedding)

    hnet_params = meta_module.get_subdict(params, "hnet")
    return hypernet.forward(uncond_input=ques_emb, params=hnet_params)


# Script entry point: parse the command line and run the evaluation.
if __name__ == "__main__":
    main(parser.parse_args())


