import copy

import torch
import torch.nn as nn
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

from metric import AverageMeter
from trainer.masedit import MASEDIT, zerolike_params_dict
from trainer.utils import logging


class FAREDIT(MASEDIT):
    """MASEDIT variant that scores parameter importance by fine-tuning a copy.

    A deep copy of the model is trained on ``dataloader`` with a symmetric
    image/text cross-entropy loss, updating only parameters whose name
    contains ``self.args.edit_layer``. The importance of each parameter in
    ``self.trainable_params`` whose name contains ``"weight"`` is the absolute
    element-wise change between the trained copy and the original model.
    """

    def get_mask_parameter(self, model):
        """Select the parameters of ``model`` to optimize and freeze the rest.

        A parameter is selected when ``self.args.edit_layer`` is a substring of
        its name. Selected parameters get ``requires_grad=True``; all others
        get ``requires_grad=False``, so autograd neither computes nor stores
        gradients for them.

        Args:
            model: The ``nn.Module`` whose parameters are partitioned. Its
                parameters' ``requires_grad`` flags are modified in place.

        Returns:
            A list of optimizer parameter-group dicts, one per selected
            parameter, each with keys ``"name"``, ``"params"`` and ``"lr"``
            (set to ``self.args.lr_mask``).
        """
        param_group = []
        for n, p in model.named_parameters():
            trainable = self.args.edit_layer in n
            # requires_grad_ toggles autograd tracking in place; frozen
            # parameters then get no .grad buffers during backward().
            p.requires_grad_(trainable)
            if not trainable:
                continue
            param_group.append({
                "name": n,
                "params": p,
                "lr": self.args.lr_mask,
            })
        return param_group

    def compute_importance_score(self, model, dataloader, dataset=None, task=0, **kwargs):
        """Estimate per-parameter importance by fine-tuning a copy of ``model``.

        The copy is trained for ``self.args.mask_epochs`` epochs with AdamW
        (weight decay 0.2) under a cosine warm-restart schedule. After every
        epoch the scheduler is stepped and ``self.middle_evaluation`` is run on
        the copy.

        Args:
            model: The model to score. It is not trained; a deep copy is.
            dataloader: Iterable of ``(images, _, texts)`` batches. ``images``
                and ``texts`` are moved to ``self.args.device``.
            dataset: Passed through to ``self.middle_evaluation``.
            task: Task index used in the logged metric name and passed to
                ``self.middle_evaluation``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A dict built from ``zerolike_params_dict(model,
            device=self.args.device)``. For every name in
            ``self.trainable_params`` that contains ``"weight"``, the entry is
            replaced by ``|trained - original|`` (computed without autograd).
        """
        masked_model = copy.deepcopy(model)
        masked_model.to(self.args.device)
        params_to_optimize = self.get_mask_parameter(masked_model)
        optimizer_masks = optim.AdamW(
            params_to_optimize, lr=self.args.lr_mask,
            betas=(0.9, 0.999), eps=1e-8, weight_decay=0.2)
        lr_scheduler = CosineAnnealingWarmupRestarts(
            optimizer_masks,
            first_cycle_steps=self.args.mask_epochs,
            cycle_mult=1.0,
            max_lr=self.args.lr_mask,
            min_lr=0,
            warmup_steps=1)
        # Loss modules hold no state; build them once rather than per batch.
        loss_img = nn.CrossEntropyLoss()
        loss_txt = nn.CrossEntropyLoss()
        avg_loss = AverageMeter()
        for idx in range(self.args.mask_epochs):
            # middle_evaluation runs at the end of every epoch and may leave
            # the model in eval mode, so re-enter train mode each epoch.
            masked_model.train()
            for iiter, batch in enumerate(dataloader):
                images, _, texts = batch
                images = images.to(self.args.device)
                texts = texts.to(self.args.device)

                # Forward pass
                logits_per_image, logits_per_text = masked_model(images, texts)
                # The i-th image and i-th text are treated as the matching
                # pair, so target class i for row i in both directions.
                ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
                loss = (loss_img(logits_per_image, ground_truth)
                        + loss_txt(logits_per_text, ground_truth)) / 2
                optimizer_masks.zero_grad()
                loss.backward()
                optimizer_masks.step()
                # NOTE(review): CrossEntropyLoss already averages over the
                # batch; the extra division by batch size only affects the
                # logged value. Confirm this is intended.
                avg_loss.update(loss.item() / images.size(0), n=1)
                logging('iter', iiter + idx * len(dataloader),
                        f'train_mask_loss/{task}', avg_loss.val, self.args)

            lr_scheduler.step()
            _ = self.middle_evaluation(masked_model, dataset, task, idx, log_name='far accuracy')

        importance = dict(zerolike_params_dict(model, device=self.args.device))
        # Without no_grad the differences would carry a grad_fn that keeps the
        # trained copy's parameters reachable after masked_model is deleted.
        with torch.no_grad():
            for (key, p), (key_original, p_original) in zip(
                    masked_model.named_parameters(), model.named_parameters()):
                assert key == key_original
                if key in self.trainable_params and 'weight' in key:
                    # The copy lives on self.args.device; the original model
                    # may not, so align devices before subtracting.
                    importance[key] = (p - p_original.to(p.device)).abs()

        del masked_model
        return importance
                





        

