import copy

import torch
import torch.nn as nn
from tqdm import tqdm
import pickle
from dataset.cc12m import cc12m
from trainer.finetune import FinetuneCLIP
import os 
def zerolike_params_dict(model, device=None):
    """
    Create a list of (name, tensor) pairs with one zero tensor per parameter.

    The list has as many entries as the model has named parameters, in the
    order of ``model.named_parameters()``. Each tensor has the same shape and
    dtype as its parameter.

    :param model: a pytorch model
    :param device: device on which the zero tensors are allocated; when None,
        each tensor is allocated on the device of its parameter
    :return: list of (parameter name, zero tensor) tuples
    """
    # Allocate directly on the target device rather than creating the tensor
    # next to the parameter and copying it over afterwards.
    return [
        (name, torch.zeros_like(param, device=param.device if device is None else device))
        for name, param in model.named_parameters()
    ]


def copy_params_dict(model, copy_grad=False, device=None):
    """
    Create a list of (name, tensor) pairs copied from the model.

    The list has as many entries as the model has named parameters, in the
    order of ``model.named_parameters()``. Every tensor is a detached clone,
    so later changes to the model do not affect the returned values.

    :param model: a pytorch model
    :param copy_grad: if True returns gradients instead of parameter values;
        every parameter must then already hold a gradient (a parameter whose
        ``grad`` is None raises AttributeError)
    :param device: device the copies are moved to; when None, each copy stays
        on the device of its parameter
    :return: list of (parameter name, copied tensor) tuples
    """
    copies = []
    for name, param in model.named_parameters():
        source = param.grad if copy_grad else param
        target = param.device if device is None else device
        # `device` is honoured for gradients as well as for parameter values.
        copies.append((name, source.detach().clone().to(target)))
    return copies


class MASEDIT(FinetuneCLIP):
    """
    Fine-tuning strategy that restricts updates with per-parameter 0/1 masks
    and optionally adds an importance-weighted quadratic penalty to the loss.

    Masks are built in ``compute_importance`` from gradient-based scores and
    are multiplied into the gradients in ``update_model`` right before the
    optimizer step.
    """

    def __init__(self, args):
        """
        :param args: run configuration, forwarded to the parent class; this
            class reads it back through ``self.args``
        """
        super().__init__(args)
        # NOTE(review): not read or written anywhere else in this class.
        self.magnitudes = {}
        # parameter name -> 0/1 tensor multiplied into the gradient
        self.mask = {}
        # mixing weight between the stored and the newly computed importance
        self.alpha = 0.5
        # weight of the importance penalty added in compute_loss
        self._lambda = self.args.scale
        # set to True by the first compute_update_importance call
        self.importance_computed = False

        # names of the parameters made trainable by unfreeze_model
        self.trainable_params = []

    def setup_importance(self, model):
        """
        Snapshot the current parameters and reset the stored importance to zero.

        :param model: a pytorch model
        """
        # Parameters before the first task starts
        self.params = dict(copy_params_dict(model))
        # Importance starts at zero for every parameter
        self.importance = dict(zerolike_params_dict(model))

    def unfreeze_model(self, model):
        """
        Put the model in train mode and select which parameters are trainable.

        A parameter is trainable when ``args.update_all`` is set; otherwise
        ``visual.proj`` follows ``args.finetune_proj`` and every other
        parameter is trainable when its name contains ``args.edit_layer``
        (and also ``'visual'`` when ``args.method`` contains ``'visual'``).
        Names of trainable parameters are accumulated in
        ``self.trainable_params``; names are never removed from that list.

        :param model: a pytorch model
        """
        model.train()
        for name, param in model.named_parameters():
            if self.args.update_all:
                trainable = True
            elif name == 'visual.proj':
                trainable = bool(self.args.finetune_proj)
            elif 'visual' in self.args.method:
                trainable = self.args.edit_layer in name and 'visual' in name
            else:
                trainable = self.args.edit_layer in name

            param.requires_grad = bool(trainable)
            if trainable and name not in self.trainable_params:
                self.trainable_params.append(name)
        print('Trainable parameters: ', self.trainable_params)

    def compute_importance(self, dataset, model, task):
        """
        Build the update mask of every trainable parameter for ``task``.

        On task 0 the stored parameters/importance are initialised and, when
        ``args.mas_importance_compute`` contains ``'condset'``, the stored
        importance is computed on the cc12m set. On later tasks, when it
        contains ``'curset'``, the stored importance is updated on the
        previous task's data. Afterwards gradient scores are computed on the
        current task and turned into 0/1 masks according to
        ``args.selection``, ``args.score`` and ``args.sparsity``.

        :param dataset: object providing ``get_dataset(...)`` and ``transform``
        :param model: a pytorch model
        :param task: index of the current task
        :raises ValueError: if the selection/score combination is unsupported
        """
        if task == 0:
            self.setup_importance(model)
            if 'condset' in self.args.mas_importance_compute:
                print ('Compute importance for conditional set...')
                condset = cc12m(transform=dataset.transform)
                cond_dataloader = self.get_loader(condset)
                self.compute_update_importance(model, cond_dataloader)
        elif 'curset' in self.args.mas_importance_compute:
            prev_set = dataset.get_dataset(task - 1, is_train=True, with_buffer=False)
            buffer = self.get_loader(prev_set)
            print ('Compute importance for the last task...')
            self.compute_update_importance(model, buffer)

        cur_set = dataset.get_dataset(task, is_train=True, with_buffer=False)
        loader = self.get_loader(cur_set, is_train=True)
        print('Compute importance for the current task...')
        cur_importance = self.compute_importance_score(
            model, loader, loss_type=self.args.select_loss_type, task=task, dataset=dataset)
        if self.args.save_ckpt:
            ckpt_path = os.path.join(self.args.log_path, f'task{task}_importance.torchSave')
            with open(ckpt_path, 'wb') as file:
                torch.save(cur_importance, file)

        cur_activation = None
        if self.args.score in ['mulact', 'normact', 'act']:
            cur_activation = self._get_activation(model, loader)

        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in self.trainable_params:
                    continue
                # Parameters listed in full_update_param (or everything when
                # sparsity is 1.0) are updated densely.
                if any(pattern in name for pattern in self.args.full_update_param) or self.args.sparsity == 1.0:
                    self.mask[name] = torch.ones(param.shape, dtype=param.dtype).to(self.args.device)
                    continue
                if name not in cur_importance.keys():
                    print(f' importance of `{name} is none')
                    continue
                importance = cur_importance[name]

                if self.args.selection == 'neuron':
                    # One score per row (dim 0); the top rows are updated.
                    # NOTE(review): activation entries are NumPy values (see
                    # _get_activation); torch.topk below cannot take a NumPy
                    # value for score == 'act' -- confirm intended type.
                    if self.args.score == 'norm':
                        magnitudes = torch.norm(importance, dim=1)
                    elif self.args.score == 'random':
                        magnitudes = torch.randn(param.shape[0])
                    elif self.args.score in ('cos', 'soft') and self.importance_computed:
                        magnitudes = (
                            (importance * self.importance[name]).sum(dim=1)
                            / self.importance[name].norm(dim=1))
                    elif self.args.score == 'div' and self.importance_computed:
                        magnitudes = torch.norm(importance / self.importance[name], dim=1)
                    elif self.args.score == 'normgrad':
                        magnitudes = torch.norm(importance, dim=1).cpu() / torch.norm(param.data, dim=1).cpu()
                    elif self.args.score == 'mulact':
                        magnitudes = torch.norm(importance, dim=1).cpu() * abs(
                            cur_activation['model.' + name.replace('.weight', '')])
                    elif self.args.score == 'normact':
                        magnitudes = torch.norm(importance, dim=1).cpu() / abs(
                            cur_activation['model.' + name.replace('.weight', '')])
                    elif self.args.score == 'act':
                        magnitudes = abs(cur_activation['model.' + name.replace('.weight', '')])
                    else:
                        raise ValueError(
                            f"score '{self.args.score}' is not available for selection 'neuron' "
                            f"(importance_computed={self.importance_computed})")
                    num_to_update = int(param.shape[0] * self.args.sparsity)
                    _, indices = torch.topk(magnitudes, num_to_update)
                    # Shape/dtype come from the parameter itself so the mask
                    # does not depend on a live gradient.
                    self.mask[name] = torch.zeros(param.shape, dtype=param.dtype).to(self.args.device)
                    self.mask[name][indices] = 1
                elif self.args.selection == 'neuron-weight':
                    if self.args.score != 'norm':
                        # Previously no mask was stored here, which surfaced
                        # later as a KeyError in update_model.
                        raise ValueError(
                            f"score '{self.args.score}' is not available for selection 'neuron-weight'")
                    # Stage 1: keep the rows in the upper half by norm.
                    # NOTE(review): the strict '>' drops the k-th row itself
                    # (and ties); confirm this is intended.
                    magnitudes = importance.norm(dim=-1)
                    k = int(magnitudes.numel() * 0.5)
                    topk_values, _ = torch.topk(magnitudes, k=k)
                    min_value = topk_values[-1]
                    mask = torch.zeros_like(magnitudes)
                    mask[magnitudes > min_value] = 1
                    # Stage 2: inside the kept rows, keep the largest entries.
                    magnitudes = importance.abs() * mask.reshape(-1, 1)
                    k = int(magnitudes.numel() * self.args.sparsity)
                    topk_values, _ = torch.topk(magnitudes.view(-1), k=k)
                    mask = torch.zeros_like(magnitudes)
                    min_value = topk_values[-1]
                    mask[magnitudes > min_value] = 1
                    self.mask[name] = mask
                elif self.args.selection == 'weight':
                    if self.args.score == 'norm':
                        magnitudes = importance.abs()
                    elif self.args.score == 'random':
                        magnitudes = torch.randn(param.shape).to(self.args.device)
                    else:
                        raise ValueError(
                            f"score '{self.args.score}' is not available for selection 'weight'")

                    k = int(magnitudes.numel() * self.args.sparsity)
                    _, topk_indices = torch.topk(magnitudes.view(-1), k=k)
                    self.mask[name] = torch.zeros_like(magnitudes).to(self.args.device)
                    self.mask[name].view(-1)[topk_indices] = 1
                else:
                    raise ValueError(f"unknown selection '{self.args.selection}'")

    def update_model(self, model, optimizer, **kwargs):
        """
        Mask the gradients and take one optimizer step.

        :param model: a pytorch model whose gradients were already computed
        :param optimizer: optimizer to step
        :param kwargs: accepted for interface compatibility; not used here
        :raises KeyError: if a parameter has a gradient but no stored mask
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is not None:
                    # Only entries whose mask value is 1 keep their gradient.
                    param.grad = self.mask[name] * param.grad
        optimizer.step()

    def compute_importance_penalty(self, args, model):
        """
        Return the importance-weighted squared distance to the stored parameters.

        Sums ``importance * (param - stored_param) ** 2`` over every trainable
        parameter.

        :param args: unused; kept for interface compatibility
        :param model: a pytorch model
        :return: scalar tensor on ``self.args.device``
        """
        loss_reg = torch.zeros((), dtype=torch.float32, device=self.args.device)

        for name, param in model.named_parameters():
            if name in self.trainable_params:
                loss_reg += torch.sum(
                    self.importance[name] *
                    (param - self.params[name].expand(param.shape)).pow(2)
                )

        return loss_reg

    def compute_loss(self, batch, model, **kwargs):
        """
        Compute the symmetric cross-entropy loss over image/text logits.

        The target of sample ``i`` is index ``i``. When a ``buffer`` batch is
        passed and ``epoch > 0`` it is concatenated to the current batch. When
        ``args.scale > 0`` the importance penalty, scaled by ``self._lambda``,
        is added.

        :param batch: tuple ``(images, _, texts)``
        :param model: model returning ``(logits_per_image, logits_per_text)``
        :param kwargs: optional ``buffer`` (same layout as ``batch``) and ``epoch``
        :return: scalar loss tensor
        """
        buffer = kwargs.get('buffer', None)
        epoch = kwargs.get('epoch', 0)
        criterion = nn.CrossEntropyLoss()
        images, _, texts = batch
        if buffer and epoch > 0:
            images_b, _, texts_b = buffer
            images = torch.cat([images, images_b])
            texts = torch.cat([texts, texts_b])

        images = images.to(self.args.device)
        texts = texts.to(self.args.device)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)

        logits_per_image, logits_per_text = model(images, texts)

        total_loss = (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2
        if self.args.scale > 0.0:
            penalty = self._lambda * self.compute_importance_penalty(self.args, model)
            return total_loss + penalty
        return total_loss

    def compute_update_importance(self, model, dataloader):
        """
        Snapshot the parameters and refresh the stored importance.

        The first call stores the freshly computed importance as is. Later
        calls blend it into the stored one:
        ``alpha * stored + (1 - alpha) * new``.

        :param model: a pytorch model
        :param dataloader: iterable of ``(images, _, texts)`` batches
        """
        # `device` belongs to copy_params_dict; passing it to dict() instead
        # would add a bogus 'device' entry to the snapshot.
        self.params = dict(copy_params_dict(model, device=self.args.device))

        curr_importance = self._get_importance(model, dataloader)
        if not self.importance_computed:
            self.importance = curr_importance
            self.importance_computed = True
            return

        for name in self.importance.keys():
            self.importance[name] = (self.alpha * self.importance[name]
                                     + (1 - self.alpha) * curr_importance[name].data)

    def _get_importance(self, model, dataloader, loss_type='l2'):
        """
        Compute a sample-weighted average of absolute gradients per parameter.

        :param model: model returning ``(logits_per_image, logits_per_text)``
        :param dataloader: iterable of ``(images, _, texts)`` batches
        :param loss_type: ``'l2'`` (mean squared row norm of the image logits)
            or ``'cn'`` (symmetric cross-entropy with diagonal targets)
        :return: dict mapping parameter name to importance tensor on
            ``self.args.device``
        :raises ValueError: if ``loss_type`` is unknown
        """
        importance = dict(zerolike_params_dict(model, device=self.args.device))

        model.train()
        size = 0

        for batch in tqdm(dataloader):
            images, _, texts = batch
            images = images.to(self.args.device)
            texts = texts.to(self.args.device)

            # Gradients are reset for every batch, so each batch contributes
            # its own gradient only.
            model.zero_grad()
            logits_per_image, logits_per_text = model(images, texts)

            if loss_type == 'l2':
                loss = torch.norm(logits_per_image, p="fro", dim=1).pow(2).mean()
            elif loss_type == 'cn':
                ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
                criterion = nn.CrossEntropyLoss()
                loss = (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2
            else:
                raise ValueError(f"unknown loss_type '{loss_type}'")
            loss.backward()

            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    importance[name].data += param.grad.clone().abs() * len(images)
            size += len(images)

        # An empty dataloader leaves the importance at zero instead of
        # producing NaNs from a division by zero.
        denom = max(size, 1)
        importance = {name: value / denom for name, value in importance.items()}

        if self.args.importance_max_normalize:
            for name in importance.keys():
                if name in self.trainable_params:
                    # Scale by the max absolute value along the last dimension
                    importance[name] = torch.nn.functional.normalize(importance[name], dim=-1, p=torch.inf)

        return importance

    def compute_importance_score(self, model, dataloader, loss_type='l2', **kwargs):
        """
        Accumulate signed gradients over (part of) ``dataloader``.

        The loop stops once more than
        ``len(dataloader) * args.cur_importance_batch_percentage`` batches were
        processed and, in the last processed batch, no accumulated entry of
        any updated parameter had a magnitude below ``1e-12``.

        NOTE(review): gradients are zeroed once before the loop and never
        between batches, so ``param.grad`` already holds the running sum when
        it is added to ``importance``. Earlier batches are therefore counted
        multiple times -- confirm this is intended before changing it.

        :param model: model returning ``(logits_per_image, logits_per_text)``
        :param dataloader: sized iterable of ``(images, _, texts)`` batches
        :param loss_type: ``'l2'`` or ``'cn'`` (see ``_get_importance``)
        :param kwargs: accepted for interface compatibility; not used here
        :return: dict mapping parameter name to accumulated gradient tensor
        :raises ValueError: if ``loss_type`` is unknown
        """
        importance = dict(zerolike_params_dict(model, device=self.args.device))

        model.train()
        model.zero_grad()
        total_batch = len(dataloader)
        num_batch_for_importance = total_batch * self.args.cur_importance_batch_percentage
        print(f'Total batch for importance {total_batch}, use {num_batch_for_importance} batches')

        for num_batch, batch in enumerate(tqdm(dataloader)):
            # Cleared below when any accumulated entry is still (near) zero
            stop_flag = 1
            images, _, texts = batch
            images = images.to(self.args.device)
            texts = texts.to(self.args.device)

            logits_per_image, logits_per_text = model(images, texts)

            if loss_type == 'l2':
                loss = torch.norm(logits_per_image, p="fro", dim=1).pow(2).mean()
            elif loss_type == 'cn':
                ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
                criterion = nn.CrossEntropyLoss()
                loss = (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2
            else:
                raise ValueError(f"unknown loss_type '{loss_type}'")
            loss.backward()

            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    importance[name].data += param.grad.clone()
                    if importance[name].data.abs().min() < 1e-12:
                        stop_flag = 0
            if num_batch > num_batch_for_importance and stop_flag:
                break

        return importance

    def _get_activation(self, model, loader):
        """
        Summarise the ``mlp.c_proj`` outputs of a copy of the model per layer.

        Forward hooks are registered on the parent module of every trainable
        parameter of a deep copy of ``model``. Up to 22 batches are run, and
        the captured ``mlp.c_proj`` outputs of the first 12 visual and text
        residual blocks are concatenated along dim 1. Each layer is then
        reduced to one value: max of the absolute output over dim 2, top-5
        over dim 0, mean of those values.

        NOTE(review): this requires the ``mlp.c_proj`` modules of those blocks
        to be among the hooked modules; otherwise ``__out__`` is missing.

        :param model: model exposing ``visual.transformer.resblocks`` and
            ``transformer.resblocks``
        :param loader: iterable of ``(images, _, texts)`` batches
        :return: dict mapping ``'model.<module path>'`` to a NumPy value
        """
        def parent_module(model, pname):
            # Walk the dotted parameter name down to the module owning it.
            comps = pname.split('.')
            parent = model
            for comp in comps[:-1]:
                if hasattr(parent, comp):
                    parent = getattr(parent, comp)
                elif comp.isdigit():
                    parent = parent[int(comp)]
                else:
                    raise RuntimeError(f"Couldn't find child module {comp}")
            assert hasattr(parent, comps[-1])
            return parent

        def linear_forward_hook(mod, activations, output):
            # Keep the latest input/output on the module itself.
            assert len(activations) == 1
            mod.__x__ = activations[0].detach()
            mod.__out__ = output.detach()

        def hook_clip(model, pnames):
            handles = []
            for m in [parent_module(model, pname) for pname in pnames]:
                handles.append(m.register_forward_hook(linear_forward_hook))

            model.handles = handles

        def collect_output(model, out):
            for i in range(12):
                for prefix, tower in (('model.visual.transformer', model.visual.transformer),
                                      ('model.transformer', model.transformer)):
                    captured = tower.resblocks[i].mlp.c_proj.__out__.detach().clone().cpu().to(torch.float32)
                    name = f'{prefix}.resblocks.{i}.mlp.c_proj'
                    if name not in out:
                        out[name] = captured
                    else:
                        out[name] = torch.cat([out[name], captured], dim=1)
            return out

        out = {}
        # Hooks are attached to a throwaway copy, so the trained model is
        # left untouched.
        model_for_select = copy.deepcopy(model)
        hook_clip(model_for_select, self.trainable_params)
        # Only the forward activations are needed, so no graph is built and
        # no loss is computed.
        with torch.no_grad():
            for iiter, batch in enumerate(tqdm(loader)):
                images, _, texts = batch
                # Use the configured device, as every other method here does.
                images = images.to(self.args.device)
                texts = texts.to(self.args.device)

                model_for_select(images, texts)
                out = collect_output(model_for_select, out)
                if iiter > 20:
                    break

            num_top_tokens = 5
            act = {}
            for layer_name, layer_output in out.items():
                # Largest absolute value along dim 2
                max_activations, _ = torch.max(layer_output.abs(), dim=2)
                # Top num_top_tokens values along dim 0
                top_values = torch.topk(max_activations, num_top_tokens, dim=0)[0]
                # Mean over both remaining dims
                mean_top_token_features = torch.mean(top_values, dim=(0, 1))
                act[layer_name] = mean_top_token_features.cpu().numpy()
        return act
