
import torch
import torch.nn as nn
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
import torch.nn.functional as F
from clip.modified_clip import make_model
from metric import AverageMeter
from trainer.masedit import MASEDIT, zerolike_params_dict
from trainer.utils import logging
from tqdm import tqdm
import numpy as np

class MASKEDIT(MASEDIT):
    """MASEDIT variant whose parameter-importance scores are learned masks.

    ``compute_importance_score`` does the following:

    1. Builds a masked copy of the model with ``make_model``.
    2. Optionally initialises its masks from the parent class's importance
       scores.
    3. Trains the masks (and any ``args.edit_layer`` weights) with a
       symmetric image/text cross-entropy loss.
    4. Returns the learned ``mask_real`` tensors, keyed by the matching
       ``weight`` parameter name.
    """

    def get_mask_parameter(self, model):
        """Choose which parameters of ``model`` to train and build optimizer groups.

        * ``mask_real*`` parameters are trained with ``args.lr_mask``. Under
          the ``addlowrank`` strategy only the ``mask_real_*`` factors are
          trained, and the plain ``mask_real`` tensor is frozen.
        * Parameters whose name contains ``args.edit_layer`` (and no
          ``mask``) are trained with ``args.lr``.
        * Every other parameter is frozen (``requires_grad = False``).

        Returns:
            A list of ``{"name", "params", "lr"}`` dicts, one per trainable
            parameter, usable as ``torch.optim`` parameter groups.
        """
        param_group = []
        for n, p in model.named_parameters():
            if 'mask_real' in n:
                if self.args.mask_strategy == 'addlowrank' and 'mask_real_' not in n:
                    p.requires_grad = False
                    continue
                lr = self.args.lr_mask
            elif self.args.edit_layer in n and 'mask' not in n:
                lr = self.args.lr
            else:
                # Freeze everything that is not optimized, so autograd does
                # not compute or store gradients for it.
                # NOTE(review): if make_model shares Parameter objects with
                # the source model, this also freezes them there. Confirm
                # that make_model copies the weights.
                p.requires_grad = False
                continue

            p.requires_grad = True
            param_group.append({
                "name": n,
                "params": p,
                "lr": lr,
            })
        return param_group

    @torch.no_grad()
    def init_mask(self, model, weight):
        """Initialise the mask parameters of ``model`` in place from ``weight``.

        Args:
            model: masked model whose parameter names contain ``mask_real`` /
                ``mask_gradient`` next to the corresponding ``weight`` names.
            weight: dict mapping original ``...weight`` parameter names to
                tensors. ``compute_importance_score`` passes the parent
                class's importance scores.

        ``mask_real`` tensors are initialised only when
        ``args.mask_learnable_param_init == 'gradient'``. ``mask_gradient``
        tensors are copied from ``weight`` only when
        ``args.mask_frozen_param_init == 'gradient'``. Finally ``init()`` is
        called on every module whose type name contains ``ElementWise``.
        """
        for n, p in model.named_parameters():
            if 'mask' not in n:
                continue

            if 'real' in n and self.args.mask_learnable_param_init == 'gradient':
                original_key = n.replace('mask_real', 'weight')
                if original_key not in weight.keys():
                    print(n, original_key)
                    continue

                w = weight[original_key]
                abs_w = w.abs()
                # Smallest positive normal value of w's dtype. It guards
                # log(0) and division by an all-zero maximum, both of which
                # would otherwise produce -inf/NaN masks.
                tiny = torch.finfo(w.dtype).tiny

                if self.args.mask_strategy in ['addrelu', 'addtanh', 'addsigmoid', 'addrelutanh', 'addrelutanhneuron']:
                    if self.args.log_init:
                        if 'neuron' not in self.args.mask_strategy:
                            # Log-magnitude clipped to [-10, 10], rescaled to [-1, 1].
                            p.data.copy_(torch.clip(torch.log(abs_w), -10, 10) / 10)
                        else:
                            # One value per index along dim 0: the L2 norm of
                            # the log-magnitudes over dim 1, scaled so that the
                            # largest |value| is 1.
                            log_abs = torch.log(abs_w.clamp_min(tiny))
                            p.data.copy_(F.normalize(log_abs.norm(dim=1), dim=0, p=torch.inf).reshape(-1, 1))
                            visualize_matrix(p.detach().cpu())
                    else:
                        # |w| / max|w| mapped from [0, 1] onto [-1, 1].
                        p.data.copy_(abs_w / abs_w.max().clamp_min(tiny) * 2 - 1)
                else:
                    # |w| / max|w| in [0, 1].
                    p.data.copy_(abs_w / abs_w.max().clamp_min(tiny))

            elif 'gradient' in n and self.args.mask_frozen_param_init == 'gradient':
                original_key = n.replace('mask_gradient', 'weight')
                p.data.copy_(weight[original_key])

        for module in model.modules():
            if 'ElementWise' in str(type(module)):
                module.init()

    def compute_importance_score(self, model, dataloader, dataset=None, task=0, **kwargs):
        """Learn masks on a masked copy of ``model`` and return them as importance.

        Args:
            model: model whose parameters are scored; passed to ``make_model``.
            dataloader: iterable of ``(images, _, texts)`` batches.
            dataset: unused here; kept for interface compatibility.
            task: task index, used only in logging tags.
            **kwargs: ignored.

        Returns:
            The dict built by ``zerolike_params_dict(model, ...)``. Every key
            in ``self.trainable_params`` that has a matching ``mask_real`` in
            the masked model is replaced by that learned mask, detached from
            autograd.
        """
        masked_model = make_model(self.args, weight=model)
        if self.args.init_mask == 'gradient':
            importance = super().compute_importance_score(model, dataloader)
            self.init_mask(masked_model, importance)
        masked_model.to(self.args.device)
        masked_model.train()

        params_to_optimize = self.get_mask_parameter(masked_model)
        optimizer_masks = optim.AdamW(params_to_optimize, lr=self.args.lr_mask,
                                      betas=(0.9, 0.999), eps=1e-8, weight_decay=0.2)
        # The cosine schedule spans args.mask_epochs steps; it is stepped once per epoch.
        lr_scheduler = CosineAnnealingWarmupRestarts(
            optimizer_masks,
            first_cycle_steps=self.args.mask_epochs,
            cycle_mult=1.0,
            max_lr=self.args.lr_mask,
            min_lr=0,
            warmup_steps=0)
        criterion = nn.CrossEntropyLoss()
        avg_loss = AverageMeter()
        avg_l1_loss = AverageMeter()

        for epoch in range(self.args.mask_epochs):
            for iiter, (images, _, texts) in enumerate(tqdm(dataloader)):
                step = iiter + epoch * len(dataloader)
                images = images.to(self.args.device)
                texts = texts.to(self.args.device)
                # Only full-size batches update the running meters and are logged.
                full_batch = images.size(0) == self.args.batch_size

                # Symmetric contrastive loss: the i-th image matches the i-th text.
                logits_per_image, logits_per_text = masked_model(images, texts)
                ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
                loss = (criterion(logits_per_image, ground_truth)
                        + criterion(logits_per_text, ground_truth)) / 2

                if self.args.mask_l1_loss:
                    l1_loss = self.args.l1_loss_scale * self.compute_mask_l1_loss(masked_model)
                    loss = loss + l1_loss
                    if full_batch:
                        avg_l1_loss.update(l1_loss.item())
                        logging('iter', step, f'train_mask_l1loss/{task}', avg_l1_loss.val, self.args)

                optimizer_masks.zero_grad()
                loss.backward()
                optimizer_masks.step()

                if full_batch:
                    avg_loss.update(loss.item(), n=1)
                    logging('iter', step, f'train_mask_loss/{task}', avg_loss.val, self.args)

            lr_scheduler.step()

        for module in masked_model.modules():
            if 'ElementWise' in str(type(module)):
                module.compute_importance()

        importance = dict(zerolike_params_dict(model, device=self.args.device))
        for key, p in masked_model.named_parameters():
            # Only learned masks are importance scores. Without this check the
            # masked model's own ``...weight`` tensors would also be stored
            # (or overwrite a mask, depending on parameter order).
            if 'mask_real' not in key:
                continue
            original_key = key.replace('mask_real', 'weight')
            if original_key in self.trainable_params:
                # Detach so the result does not keep the discarded model's
                # Parameter (and its .grad) alive.
                importance[original_key] = p.detach()
                print(key)
                visualize_matrix(importance[original_key].cpu())

        del masked_model
        return importance

    def compute_mask_l1_loss(self, model):
        """Sum the L1 penalty over every ``ElementWise`` module of ``model``.

        Each module contributes ``compute_l1_loss()`` when
        ``args.l1_loss_on_mask`` is set, otherwise
        ``compute_unmasked_l1_loss()``.

        Returns:
            A scalar tensor on ``args.device``, which is zero when the model
            has no such modules.
        """
        # Allocate on the configured device rather than hard-coding CUDA, so
        # CPU and non-default GPUs work.
        loss = torch.zeros((), dtype=torch.float32, device=self.args.device)
        for module in model.modules():
            if 'ElementWise' not in str(type(module)):
                continue
            if self.args.l1_loss_on_mask:
                loss = loss + module.compute_l1_loss()
            else:
                loss = loss + module.compute_unmasked_l1_loss()
        return loss
                





        

def visualize_matrix(matrix, bins=10):
    """Print a textual histogram of the values in a tensor.

    Args:
        matrix: a ``torch.Tensor`` on any device. It is detached and copied
            to the CPU; bfloat16 is upcast to float32 because NumPy cannot
            represent it.
        bins: number of equal-width histogram bins (default 10).

    NaN/inf entries are excluded from the histogram, because ``np.histogram``
    cannot auto-detect a range that contains them. Their count is printed
    after the bins instead.
    """
    tensor = matrix.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    values = tensor.numpy().flatten()

    finite = np.isfinite(values)
    num_nonfinite = int(values.size - np.count_nonzero(finite))
    histogram, bin_edges = np.histogram(values[finite], bins=bins)

    print("Value Distribution:")
    for i, (start, end) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        bin_count = histogram[i]
        print(f"Bin {i}: Range [{start:.2f}, {end:.2f}]: Count {bin_count}")
    if num_nonfinite:
        print(f"Non-finite values (nan/inf): {num_nonfinite}")