import copy
import os
import pickle
import time
from clip import clip

import torch
import torch.nn as nn
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from tqdm import tqdm

import wandb
from clip.clip import tokenize
from clip.modified_clip import make_model

from dataset.cc12m import cc12m
from dataset.imagenet import ImageNet, zeroshot_classifier
from metric import AverageMeter, ClassIncrementalMetric, TaskIncrementalMetric
from trainer.finetune import FinetuneCLIP
from trainer.masedit import MASEDIT, zerolike_params_dict
from trainer.utils import accuracy, get_ckpt_save_path, logging, resume

def zerolike_params_dict(model, device=None):
    """
    Create a list of (name, tensor) pairs, one all-zero tensor per parameter.

    Each tensor has the same shape and dtype as its parameter and does not
    require grad. The list follows ``model.named_parameters()`` order.

    :param model: a pytorch model
    :param device: target device for every zero tensor; when None, each
        tensor is created on its parameter's own device
    :return: list of ``(name, tensor)`` pairs
    """
    return [
        (name, torch.zeros_like(param, device=param.device if device is None else device))
        for name, param in model.named_parameters()
    ]


def copy_params_dict(model, copy_grad=False, device=None):
    """
    Create a list of (name, tensor) snapshots of a model's parameters or gradients.

    Every tensor is a detached clone, so later updates to the model do not
    affect it. ``device`` is honoured in both modes.

    :param model: a pytorch model
    :param copy_grad: if True, copy ``param.grad`` instead of the parameter
        values; a parameter that has no gradient yet yields an all-zero tensor
    :param device: target device for every copy; when None, each copy stays
        on its parameter's own device
    :return: list of ``(name, tensor)`` pairs in ``named_parameters()`` order
    """
    snapshots = []
    for name, param in model.named_parameters():
        target = param.device if device is None else device
        if not copy_grad:
            source = param.detach()
        elif param.grad is None:
            # Frozen / not-yet-backpropagated parameters have no grad: treat as zero.
            source = torch.zeros_like(param)
        else:
            source = param.grad.detach()
        snapshots.append((name, source.clone().to(target)))
    return snapshots

class MaskLearn(FinetuneCLIP):
    def __init__(self, args):
        """Set up the mask-learning trainer and its MAS bookkeeping.

        :param args: run configuration forwarded to ``FinetuneCLIP``;
            ``args.scale`` becomes the weight of the importance penalty.
        """
        super(MaskLearn, self).__init__(args)
        # Weight given to the previous importance when blending in a new estimate.
        self.alpha = 0.5
        # Multiplier applied to compute_importance_penalty in compute_loss.
        self._lambda = self.args.scale
        # Becomes True once a real importance estimate has been stored.
        self.importance_computed = False
        # Names unfrozen by unfreeze_model; only these are penalised.
        self.trainable_params = []

    def get_mask_parameter(self, model):
        """Build per-parameter optimizer groups for mask and edit-layer weights.

        Parameters whose name contains ``'mask'`` are unfrozen with
        ``args.lr_mask``. Other parameters whose name contains
        ``args.edit_layer`` are unfrozen with ``args.lr``. Every remaining
        parameter is frozen and left out of the groups.

        :param model: model whose ``named_parameters()`` are scanned.
        :return: list of ``{"name", "params", "lr"}`` dicts usable as
            optimizer parameter groups.
        """
        param_group = []
        for n, p in model.named_parameters():
            if 'mask' in n:
                lr = self.args.lr_mask
            elif self.args.edit_layer in n:
                lr = self.args.lr
            else:
                # Must be ``requires_grad`` (a misspelling silently sets an unused attribute).
                p.requires_grad = False
                continue
            p.requires_grad = True
            param_group.append({
                "name": n,
                "params": p,
                "lr": lr,
            })
        return param_group
    
    @torch.no_grad()
    def init_mask(self, model, weight):
        """Initialise every mask parameter from a matching score tensor.

        For each parameter whose name contains ``'mask'``, the entry
        ``weight[name.replace('mask_real', 'weight')]`` is looked up. The mask
        is set to ``|w| / max|w|``, so values lie in [0, 1]. An all-zero (or
        NaN-max) score tensor gives an all-zero mask instead of NaNs.

        :param model: model holding the mask parameters (modified in place).
        :param weight: mapping from parameter name to score tensor, e.g. the
            output of ``compute_importance_score``.
        :raises KeyError: if a mask has no matching entry in ``weight``.
        """
        for n, p in model.named_parameters():
            if 'mask' not in n:
                continue
            original_key = n.replace('mask_real', 'weight')
            magnitude = weight[original_key].abs()
            peak = magnitude.max()
            if peak > 0:
                p.data.copy_(magnitude / peak)
            else:
                # 0 / 0 would fill the mask with NaN.
                p.data.zero_()
    def unfreeze_model(self, model):
        """Switch ``model`` to train mode and choose which parameters get gradients.

        Rules, checked in this order for each parameter name:
        ``args.update_all`` unfreezes everything; ``'visual.proj'`` follows
        ``args.finetune_proj``. If ``'visual'`` appears in ``args.method``, only
        names containing both ``args.edit_layer`` and ``'visual'`` are trained;
        otherwise any name containing ``args.edit_layer`` is trained. Unfrozen
        names are appended (once) to ``self.trainable_params``, which is then
        printed. This method only appends; it never removes names.
        """
        args = self.args

        def _is_trainable(pname):
            if args.update_all:
                return True
            if pname == 'visual.proj':
                return bool(args.finetune_proj)
            if 'visual' in args.method:
                return args.edit_layer in pname and 'visual' in pname
            return args.edit_layer in pname

        model.train()
        for pname, tensor in model.named_parameters():
            keep = _is_trainable(pname)
            tensor.requires_grad = keep
            if keep and pname not in self.trainable_params:
                self.trainable_params.append(pname)
        print('Trainable parameters: ', self.trainable_params)

    def update_model(self, model):
        """Call ``merge()`` on every submodule that defines it, printing each one first.

        :param model: model whose ``modules()`` are walked.
        """
        for submodule in filter(lambda m: hasattr(m, 'merge'), model.modules()):
            print(submodule)
            submodule.merge()
    
    def train(self, model, dataset, task):
        """Learn masks (and edit-layer weights) for ``task``, then merge them.

        Workflow:
          1. Build AdamW groups from ``get_mask_parameter`` and a cosine
             warm-restart schedule that is stepped once per epoch.
          2. ``unfreeze_model``. When ``args.init_mask == 'gradient'``,
             initialise the masks from gradient scores. Then refresh the
             importance estimate via ``compute_importance``.
          3. Run ``args.mask_epochs`` epochs. When a buffer loader is
             available, each step also draws one replay batch (cycling the
             buffer); ``compute_loss`` only mixes it in once ``epoch > 0``.
          4. After every epoch, evaluate. Unless ``'sparsity'`` is in
             ``args.mask_strategy``, also report the fraction of mask entries
             below ``args.threshold``.
          5. Update the dataset buffer and call ``update_model``, which
             invokes ``merge()`` on modules that define it.

        :param model: model being edited in place.
        :param dataset: continual-learning dataset wrapper.
        :param task: index of the current task.
        """
        train_dataloader, buffer_loader, _validset, total_batches = self.get_iterator(
            dataset, task)

        optimizer_masks = optim.AdamW(
            self.get_mask_parameter(model),
            lr=self.args.lr_mask,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.2,
        )
        lr_scheduler = CosineAnnealingWarmupRestarts(
            optimizer_masks,
            first_cycle_steps=self.args.mask_epochs,
            cycle_mult=1.0,
            max_lr=self.args.lr_mask,
            min_lr=0,
            warmup_steps=1,
        )

        self.unfreeze_model(model)
        batch_time = AverageMeter()
        loss_meter = AverageMeter()

        if self.args.init_mask == 'gradient':
            importance = self.compute_importance_score(model, train_dataloader)
            self.init_mask(model, importance)
        self.compute_importance(dataset, model, task=task)

        for epoch in range(self.args.mask_epochs):
            buffer_iterator = iter(buffer_loader) if buffer_loader else None
            for iiter, batch in enumerate(train_dataloader):
                batch_size = self.get_batch_size(batch)
                end = time.time()
                batch_b, buffer_iterator = self._mask_next_buffer_batch(buffer_iterator, buffer_loader)

                total_loss = self.compute_loss(batch, model, buffer=batch_b, epoch=epoch)
                optimizer_masks.zero_grad()
                total_loss.backward()
                optimizer_masks.step()
                batch_time.update(time.time() - end)
                loss_meter.update(total_loss.item() / batch_size, n=batch_size)
                logging('iter', iiter + epoch * total_batches,
                        f'train_loss/{task}', loss_meter.val, self.args)
                if iiter % self.args.print_frequency == 0:
                    print(' Epoch: [{0}/{1}], Batch: [{2}/{3}]\t'.format(epoch, self.args.mask_epochs, iiter, total_batches),
                          f'Batch Time {batch_time.val: .3f} ({batch_time.avg: .3f})\t'
                          f'Loss {loss_meter.val:.4f} ({loss_meter.avg: .4f}) \t'
                          f'Estimated Task Time {batch_time.avg * total_batches * self.args.mask_epochs / 3600: .3f} H'
                          )
            lr_scheduler.step()
            model.eval()
            _ = self.middle_evaluation(model, dataset, task, epoch, log_name='mask accuracy')
            model.train()
            if 'sparsity' not in self.args.mask_strategy:
                self._mask_report_sparsity(model, epoch)
        model.eval()

        print('Update Buffer....')
        dataset.update_buffer(task)
        self.update_model(model)

    def _mask_next_buffer_batch(self, buffer_iterator, buffer_loader):
        """Return ``(batch, iterator)`` from the replay buffer, restarting it when exhausted.

        When ``buffer_iterator`` is falsy (no buffer), ``(None, buffer_iterator)``
        is returned.
        """
        if not buffer_iterator:
            return None, buffer_iterator
        try:
            return next(buffer_iterator), buffer_iterator
        except StopIteration:
            buffer_iterator = iter(buffer_loader)
            return next(buffer_iterator), buffer_iterator

    def _mask_report_sparsity(self, model, epoch):
        """Print and log how many ``mask_real`` entries fall below ``args.threshold``.

        Only modules whose type name contains ``'ElementWise'`` are counted.
        """
        print('Num 0ed out parameters:')
        count_mask_out = 0
        count_all = 0
        for module in model.modules():
            if 'ElementWise' in str(type(module)):
                mask = module.mask_real.data
                # Plain Python numbers, so prints/logs are not device tensors.
                count_mask_out += int(mask.lt(self.args.threshold).sum().item())
                count_all += mask.numel()
        # A model without ElementWise layers would otherwise raise ZeroDivisionError.
        ratio = count_mask_out / count_all if count_all else 0.0
        print(f' Total masked out weights: {count_mask_out}. Ratio: {ratio}')
        logging('epoch', epoch, 'masked out ratio', ratio, self.args)

    def compute_loss(self, batch, model, **kwargs):
        """Symmetric CLIP contrastive loss plus the importance penalty and optional mask L1.

        :param batch: ``(images, _, texts)`` for the current task.
        :param model: model returning ``(logits_per_image, logits_per_text)``.
        :param kwargs: ``buffer`` — optional replay batch with the same layout,
            concatenated to ``batch`` only when ``epoch > 0``; ``epoch`` —
            current epoch (default 0).
        :return: scalar loss tensor
            ``contrastive + _lambda * penalty [+ l1_loss_scale * mask_l1]``.
        """
        buffer = kwargs.get('buffer', None)
        epoch = kwargs.get('epoch', 0)
        loss_img = nn.CrossEntropyLoss()
        loss_txt = nn.CrossEntropyLoss()
        images, _, texts = batch
        if buffer and epoch > 0:
            images_b, _, texts_b = buffer
            images = torch.cat([images, images_b])
            texts = torch.cat([texts, texts_b])

        images = images.to(self.args.device)
        texts = texts.to(self.args.device)
        # The i-th image matches the i-th text.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)

        logits_per_image, logits_per_text = model(images, texts)

        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        penalty = self._lambda * self.compute_importance_penalty(self.args, model)
        if self.args.mask_l1_loss:
            l1_loss = self.args.l1_loss_scale * self.compute_mask_l1_loss(model)
        else:
            # Zero on the loss's own device (a hard-coded .cuda() breaks CPU runs).
            l1_loss = total_loss.new_zeros(())
        return total_loss + penalty + l1_loss
    
    def compute_mask_l1_loss(self, model):
        """Return the L1 norm (sum of absolute values) of all mask parameters.

        :param model: model whose parameters with ``'mask'`` in their name are summed.
        :return: scalar tensor on ``args.device``; zero when the model has no masks.
        """
        loss = torch.zeros((), device=self.args.device)
        for n, p in model.named_parameters():
            if 'mask' in n:
                # Out-of-place add: safe for autograd regardless of leaf status or device.
                loss = loss + p.abs().sum()
        return loss

    
    def compute_importance_penalty(self, args, model):
        """Return the importance-weighted squared drift of the trainable parameters.

        Sums ``importance[name] * (param - params[name]) ** 2`` over every
        parameter whose name is in ``self.trainable_params``. ``self.params``
        holds the anchor weights snapshotted by ``setup_importance`` /
        ``compute_update_importance``.

        :param args: unused; ``self.args`` is read instead.
        :param model: model whose current parameters are compared to the anchors.
        :return: scalar float tensor on ``self.args.device``.
        """
        start = torch.tensor(0).float().to(self.args.device)
        per_param = (
            torch.sum(
                self.importance[pname] *
                (weight - self.params[pname].expand(weight.shape)).pow(2)
            )
            for pname, weight in model.named_parameters()
            if pname in self.trainable_params
        )
        return sum(per_param, start)

    def setup_importance(self, model):
        """Reset the anchor weights and importance before the first task.

        ``self.params`` receives a detached copy of every parameter, and
        ``self.importance`` an all-zero tensor of matching shape.
        """
        self.params = {name: snap for name, snap in copy_params_dict(model)}
        self.importance = {name: zeros for name, zeros in zerolike_params_dict(model)}
    
    def compute_importance(self, dataset, model, task):
        """Refresh the parameter-importance estimate at the start of ``task``.

        Task 0: reset anchors/importance via ``setup_importance``. When
        ``'condset'`` is in ``args.mas_importance_compute``, also estimate
        importance on the cc12m conditional set. Later tasks: when
        ``'curset'`` is in that setting, estimate importance on the previous
        task's training split (without buffer); otherwise do nothing.
        """
        if task != 0:
            if 'curset' not in self.args.mas_importance_compute:
                return
            last_task_data = dataset.get_dataset(task - 1, is_train=True, with_buffer=False)
            last_task_loader = self.get_loader(last_task_data)
            print('Compute importance for the last task...')
            self.compute_update_importance(model, last_task_loader)
            return

        self.setup_importance(model)
        if 'condset' not in self.args.mas_importance_compute:
            return
        print('Compute importance for conditional set...')
        cond_loader = self.get_loader(cc12m(transform=dataset.transform))
        self.compute_update_importance(model, cond_loader)
    
    def compute_update_importance(self, model, dataloader):
        """Snapshot the current weights and fold a fresh importance estimate in.

        ``self.params`` is replaced by copies of the current parameters on
        ``args.device``; these are the penalty's anchor. The first estimate is
        stored as-is. Later ones are blended as
        ``alpha * old + (1 - alpha) * new``.

        :param model: model to snapshot and score.
        :param dataloader: data used by ``_get_importance``.
        """
        # ``device`` belongs to copy_params_dict; passing it to dict() instead
        # would add a bogus 'device' entry and leave copies on the source device.
        self.params = dict(copy_params_dict(model, device=self.args.device))

        curr_importance = self._get_importance(model, dataloader)
        if not self.importance_computed:
            self.importance = curr_importance
            self.importance_computed = True
            return

        for name in self.importance:
            self.importance[name] = (self.alpha * self.importance[name]
                                     + (1 - self.alpha) * curr_importance[name].data)
                
    def _get_importance(self, model, dataloader,  loss_type='l2'):
        """Estimate per-parameter importance as the mean absolute gradient over ``dataloader``.

        For each batch, the model runs in train mode and a scalar loss is
        back-propagated. For every parameter that requires grad,
        ``|grad| * batch_size`` is accumulated; the total is divided by the
        number of samples seen. With ``args.importance_max_normalize``, each
        trainable entry is then divided by its max along the last dimension.

        :param model: model to score (left in train mode).
        :param dataloader: yields ``(images, _, texts)`` batches.
        :param loss_type: ``'l2'`` (batch mean of the squared L2 norm of the
            image logits) or ``'cn'`` (symmetric contrastive cross-entropy).
        :raises ValueError: for any other ``loss_type``.
        :return: dict mapping parameter name to importance tensor. If the
            loader is empty, the tensors stay zero instead of becoming NaN.
        """
        if loss_type not in ('l2', 'cn'):
            raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'l2' or 'cn'")

        importance = dict(zerolike_params_dict(model, device=self.args.device))
        loss_img = nn.CrossEntropyLoss()
        loss_txt = nn.CrossEntropyLoss()

        model.train()
        size = 0
        for images, _, texts in tqdm(dataloader):
            images = images.to(self.args.device)
            texts = texts.to(self.args.device)

            model.zero_grad()
            logits_per_image, logits_per_text = model(images, texts)
            if loss_type == 'l2':
                loss = torch.norm(logits_per_image, p="fro", dim=1).pow(2).mean()
            else:
                ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
                loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            loss.backward()

            # Weight by batch size so the final division yields a per-sample mean.
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    importance[name] += param.grad.abs() * len(images)
            size += len(images)

        if size > 0:
            importance = {name: imp / size for name, imp in importance.items()}

        if self.args.importance_max_normalize:
            for name in importance:
                if name in self.trainable_params:
                    importance[name] = torch.nn.functional.normalize(importance[name], dim=-1, p=torch.inf)

        return importance
            




 
 





    def compute_importance_score(self, model, dataloader,  loss_type='l2', **kwargs):
        """Score weights by their summed gradient on a freshly loaded copy of ``model``.

        A new model is built with ``clip.load`` (same ``args``), placed on
        ``args.device``, and given ``model``'s state dict non-strictly. For
        every module in ``model`` whose type name contains ``'ElementWise'``,
        the weight (and bias, when both sides have one) is also copied onto
        the module at the same position in the fresh model. The signed
        gradient of ``loss_type`` is then summed over all batches, per
        parameter.

        :param model: model whose current weights are scored (not modified).
        :param dataloader: yields ``(images, _, texts)`` batches.
        :param loss_type: ``'l2'`` (batch mean of the squared L2 norm of the
            image logits) or ``'cn'`` (symmetric contrastive cross-entropy).
        :param kwargs: unused.
        :raises ValueError: for any other ``loss_type``.
        :return: dict mapping parameter name (of the fresh model) to its
            summed gradient.
        """
        if loss_type not in ('l2', 'cn'):
            raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'l2' or 'cn'")

        weight = model.state_dict()
        mask_model = model
        model, _ = clip.load(self.args.model, download_root='./clip_models/', args=self.args)
        # Inputs and importance buffers live on args.device; keep the scoring model there too.
        model = model.to(self.args.device)
        model.load_state_dict(weight, strict=False)
        for module, mask_module in zip(model.modules(), mask_model.modules()):
            if 'ElementWise' in str(type(mask_module)):
                module.weight.data.copy_(mask_module.weight.data)
                if getattr(mask_module, 'bias', None) is not None and getattr(module, 'bias', None) is not None:
                    module.bias.data.copy_(mask_module.bias.data)

        importance = dict(zerolike_params_dict(model, device=self.args.device))
        loss_img = nn.CrossEntropyLoss()
        loss_txt = nn.CrossEntropyLoss()

        model.train()
        for images, _, texts in tqdm(dataloader):
            images = images.to(self.args.device)
            texts = texts.to(self.args.device)

            # Clear grads every batch. Otherwise .grad holds a running sum, and
            # earlier batches would be re-added on every later step.
            model.zero_grad()
            logits_per_image, logits_per_text = model(images, texts)
            if loss_type == 'l2':
                loss = torch.norm(logits_per_image, p="fro", dim=1).pow(2).mean()
            else:
                ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
                loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            loss.backward()

            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    importance[name] += param.grad

        return importance