import numpy as np
from tqdm import tqdm
from copy import deepcopy
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader
import wandb

from torch.nn import functional as F

from data.dataloader.clip_vqa import CLIP_VQA
from model.custom_hnet import CLIPAdapter, HyperGenerator, HyperDiscriminator, EmbeddingModule
from utils import clip_utils

from sklearn.metrics import accuracy_score
from training.utils import *

class GradientBuffer():
    """Accumulate the ``.grad`` of a fixed list of parameters and write back their mean.

    Typical use: call :meth:`accumulate` after each ``backward()``, then call
    :meth:`unload` once to add the averaged gradient onto each parameter's
    ``.grad`` before an optimizer step.
    """

    def __init__(self, param_list):
        """Store ``param_list`` (tensors exposing ``.grad``) and allocate zeroed buffers."""
        self.param_list=param_list
        self.reset_buffer()

    def reset_buffer(self):
        """Replace the buffers with fresh zero tensors and reset the accumulation count."""
        self.grad_list=[torch.zeros_like(param) for param in self.param_list]
        self.num = 0

    def accumulate(self):
        """Add every parameter's current ``.grad`` into its buffer.

        Parameters whose ``.grad`` is ``None`` are skipped, but the call still
        counts towards the divisor used by :meth:`unload`.
        """
        for param, grad in zip(self.param_list, self.grad_list):
            if param.grad is not None:
                grad += param.grad.data
        self.num += 1

    def unload(self):
        """Add the mean buffered gradient onto each parameter's ``.grad`` and reset.

        The mean is *added* to whatever ``.grad`` currently holds (it does not
        overwrite it); parameters whose ``.grad`` is ``None`` are left untouched.
        Calling this with nothing accumulated is a no-op.
        """
        if self.num == 0:
            # Nothing was accumulated: dividing the zero buffers by 0 would
            # write NaN into every gradient, so leave the grads alone.
            return
        for param, grad in zip(self.param_list, self.grad_list):
            if param.grad is not None:
                param.grad.data += grad/self.num
        self.reset_buffer()

class MAML():
    """MAML-style meta-training / evaluation loop over a collection of tasks.

    For every task the inner parameters obtained from ``meta_module`` are
    adapted with a few gradient steps on the task's support split, and the
    adapted parameters are scored on the query split. When training, the
    query-loss gradients are averaged over tasks through a
    ``GradientBuffer`` and applied with ``meta_optimizer``.
    """

    def __init__(self, meta_module, meta_optimizer, image_features, text_features, ques_emb, config):
        """Store the collaborators used by :meth:`run_epoch`.

        Args:
            meta_module: module exposing ``meta_params``, ``get_inner_params()``,
                ``get_mainnet_weights()``, ``zero_grad()``, ``parameters()`` and the
                ``mnet`` / ``hnet`` / ``enet`` attributes read in :meth:`run_epoch`.
            meta_optimizer: optimizer whose ``step()`` applies the meta-update.
            image_features: forwarded unchanged to ``CLIP_VQA``.
            text_features: forwarded unchanged to ``CLIP_VQA``.
            ques_emb: forwarded unchanged to ``CLIP_VQA``.
            config: mapping; ``"guidance_init_l2_weight"`` is read when the
                guided inner loop is used.
        """
        self.meta_module=meta_module
        self.meta_optimizer=meta_optimizer
        self.image_features=image_features
        self.text_features=text_features
        self.ques_emb=ques_emb
        self.config=config

    @staticmethod
    def _gradnorm(module):
        """Return the global L2 norm of ``module``'s existing gradients, or -1 if ``module`` is None."""
        if module is None:
            return -1
        return np.sqrt(np.sum([p.grad.pow(2).sum().item() for p in module.parameters() if p.grad is not None]))

    def _make_task_data(self, data, task, train_subtype, val_subtype):
        """Build the support dataset plus support/query dataloaders for one task."""
        train_dataset = CLIP_VQA(meta_data=data,
                                 dataSubType=train_subtype,
                                 task=task,
                                 image_features=self.image_features,
                                 text_features=self.text_features,
                                 ques_emb=self.ques_emb)
        test_dataset = CLIP_VQA(meta_data=data,
                                dataSubType=val_subtype,
                                task=task,
                                image_features=self.image_features,
                                text_features=self.text_features,
                                ques_emb=self.ques_emb)
        train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
        test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
        return train_dataset, train_dataloader, test_dataloader

    def _guided_inner_loss(self, hyperclip_training, train_dataloader, init_inner_params, inner_params, log_debug):
        """Return the hyperclip-guided inner loss for the current inner parameters.

        The loss is the negative inner product between the normalised hyperclip
        embedding of the generated weights and the normalised task question
        embedding, plus an L2 penalty (scaled by
        ``config["guidance_init_l2_weight"] / 2``) pulling the inner parameters
        towards their initial values.
        """
        # The question embedding of the first sample of one batch stands in for the task.
        task_ques_emb = next(iter(train_dataloader))["ques_emb"][0]
        weights = self.meta_module.get_mainnet_weights(ques_emb=task_ques_emb, params=inner_params)
        hyperclip = hyperclip_training.hyperclip

        task_weight_emb = hyperclip.encode_hyper(weights)

        norm_task_ques_emb = task_ques_emb / task_ques_emb.norm(dim=-1, keepdim=True)
        norm_task_weight_emb = task_weight_emb / task_weight_emb.norm(dim=-1, keepdim=True)

        inner_product_embs_loss = - norm_task_weight_emb @ norm_task_ques_emb.T

        init_l2_loss = torch.stack(
            [(ip-p).pow(2).sum() for (ip,p)
             in zip(init_inner_params.values(), inner_params.values())]).sum()\
                       *self.config["guidance_init_l2_weight"] / 2
        if log_debug:
            wandb.log({"debug_inner_guidance_loss": inner_product_embs_loss.item(),
                       "debug_inner_l2_loss": init_l2_loss.item()})
        return inner_product_embs_loss + init_l2_loss

    def run_epoch(self, data, inner_epochs, inner_lr, meta_batch_size=1, train=False, second_order=False,
                  meta_grad_clip=-1, train_subtype="train", val_subtype="test", hyperclip_training=None,
                  eval_hyperclip=False, debug=False, guided_inner=False):
        """Run one pass over every task in ``data`` (in random order).

        Args:
            data: mapping whose keys are the tasks; also passed to ``CLIP_VQA`` as ``meta_data``.
            inner_epochs: number of inner-loop gradient steps per task.
            inner_lr: step size of the inner-loop updates.
            meta_batch_size: a meta-update is applied whenever the task index is a
                multiple of this value, and on the last task.
            train: if true, back-propagate the query loss and step ``meta_optimizer``.
            second_order: if true (and ``train``), keep the graph of the inner
                gradients so the meta-gradient flows through them.
            meta_grad_clip: max gradient norm applied before the meta-step; ``<= 0`` disables clipping.
            train_subtype: ``dataSubType`` of the support split.
            val_subtype: ``dataSubType`` of the query split.
            hyperclip_training: optional helper whose ``train`` / ``test`` method is
                called once per task and whose returned dict is merged into the logs.
            eval_hyperclip: call ``hyperclip_training.test`` instead of ``.train``.
            debug: log extra per-step metrics to wandb and gradient norms.
            guided_inner: use the hyperclip-guided inner loss instead of cross-entropy
                (only when ``hyperclip_training`` is given).

        Returns:
            ``mean_dict`` applied to the per-task metrics collected via ``append_dict``.
        """
        # TODO: deprecate hyperclip code

        if hyperclip_training is not None:
            hyperclip_training.reset_batch_count()
            hyperclip_training.reset_val_batch_count()
        log_dict = dict()
        buffer = GradientBuffer(self.meta_module.meta_params)
        tasks = list(data.keys())
        shuffled_train_tasks = torch.randperm(len(tasks))
        use_guidance = guided_inner and hyperclip_training is not None
        for inner_train_iter in tqdm(range(len(tasks))):
            curr_log_dict = dict()
            task_idx = shuffled_train_tasks[inner_train_iter]
            log_debug = debug and inner_train_iter % meta_batch_size == 0
            train_dataset, train_dataloader, test_dataloader = self._make_task_data(
                data, tasks[task_idx], train_subtype, val_subtype)

            init_inner_params = self.meta_module.get_inner_params()
            inner_params = self.meta_module.get_inner_params()
            train_start_acc, train_start_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
            val_start_acc, val_start_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params)

            # Inner loop
            for _ in range(inner_epochs):
                if use_guidance:
                    inner_loss = self._guided_inner_loss(hyperclip_training, train_dataloader,
                                                         init_inner_params, inner_params, log_debug)
                else:
                    # Module-level helper, same one used for the query loss below.
                    outputs, labels = get_pred(self.meta_module, train_dataloader, params=inner_params)
                    inner_loss = F.cross_entropy(outputs, labels)

                    if log_debug:
                        wandb.log({"debug_inner_loss": inner_loss.item()})
                grads = torch.autograd.grad(inner_loss, inner_params.values(), retain_graph=True,
                                            create_graph=bool(train and second_order))
                # Functional SGD step: build new tensors instead of mutating the parameters.
                params_next = OrderedDict()
                for (name, param), grad in zip(list(inner_params.items()), grads):
                    params_next[name] = param - inner_lr * grad
                inner_params = params_next

            # Accuracy / loss of the adapted parameters on both splits
            train_end_acc, train_end_loss = test_accuracy(self.meta_module, train_dataloader, params=inner_params)
            val_end_acc, val_end_loss = test_accuracy(self.meta_module, test_dataloader, params=inner_params)

            curr_log_dict["query_accuracy_start"] = val_start_acc
            curr_log_dict["query_accuracy_end"] = val_end_acc
            curr_log_dict["support_accuracy_start"] = train_start_acc
            curr_log_dict["support_accuracy_end"] = train_end_acc
            curr_log_dict["query_loss_start"] = val_start_loss
            curr_log_dict["query_loss_end"] = val_end_loss
            curr_log_dict["support_loss_start"] = train_start_loss
            curr_log_dict["support_loss_end"] = train_end_loss

            if train:
                # Query loss of the adapted parameters drives the meta-gradient
                self.meta_module.zero_grad()
                outputs, labels = get_pred(self.meta_module, test_dataloader, params=inner_params)
                loss=F.cross_entropy(outputs, labels)
                loss.backward()
                buffer.accumulate()

            if hyperclip_training is not None:
                task_emb = train_dataset.ques_emb[train_dataset.task]
                mnet_weights = self.meta_module.mnet.get_parameter_vector().detach()
                if eval_hyperclip:
                    hc_dict = hyperclip_training.test(task_emb, mnet_weights)
                else:
                    hc_dict = hyperclip_training.train(task_emb, mnet_weights)
                # Only merge when a hyperclip dict was actually produced for this task.
                curr_log_dict.update(hc_dict)

            # NOTE(review): this fires on iteration 0, i.e. after a single task;
            # confirm whether (inner_train_iter + 1) % meta_batch_size was intended.
            if train and (inner_train_iter % meta_batch_size==0 or inner_train_iter==len(tasks)-1):
                # NOTE(review): unload() adds the buffered mean onto the .grad left by
                # the last backward() rather than replacing it — confirm this is intended.
                buffer.unload()
                # Clip before stepping, otherwise the clipping cannot affect the update.
                if meta_grad_clip>0:
                    torch.nn.utils.clip_grad_norm_(self.meta_module.parameters(), meta_grad_clip)
                self.meta_optimizer.step()

                if debug :
                    curr_log_dict["gradnorm_mnet"] = self._gradnorm(self.meta_module.mnet)
                    curr_log_dict["gradnorm_hnet"] = self._gradnorm(self.meta_module.hnet)
                    curr_log_dict["gradnorm_enet"] = self._gradnorm(self.meta_module.enet)

            append_dict(log_dict, curr_log_dict)

            if log_debug:
                log_metric(mean_dict(log_dict), prefix = "debug_")
                log_dict=dict()


        return mean_dict(log_dict)
