import time
import torch
import torch.nn.functional as F
import utils
from copy import deepcopy

from .impl import iterative_unlearn

class ContrastiveUnlearner:
    """Contrastive unlearning loss driven by a frozen reference model.

    For a batch of images the reference model produces two targets:

    * the *negative*: its logits on the unmodified images, and
    * the *positive*: its logits on the same images after the patches that
      receive the most CLS attention in the last block have been zeroed out.

    ``compute_loss`` then pushes the model being trained towards the positive
    and away from the negative with a two-way InfoNCE-style cross entropy.

    The reference model must expose ``blocks[-1].attn.attn_drop``.
    NOTE(review): this assumes a timm-style ViT in which ``attn_drop`` is
    called on the softmaxed ``[B, H, N+1, N+1]`` attention matrix with the CLS
    token at index 0. Implementations that use a fused attention kernel may
    never call ``attn_drop``, in which case nothing is captured and a
    ``RuntimeError`` is raised -- confirm against the model in use.

    Args:
        original_model: Reference model; it is switched to eval mode.
        patch_size: Side length in pixels of one square patch.
        tau: Temperature dividing the cosine similarities.
        mask_ratio: Fraction of patches to zero out in the positive view.
    """

    def __init__(self, original_model, patch_size=16, tau=0.07, mask_ratio=0.05):
        # Created first so close()/__del__ stay safe if anything below raises.
        self.attentions = []
        self.hook_handles = []

        self.original_model = original_model.eval()
        self.patch_size = patch_size
        self.tau = tau
        self.mask_ratio = mask_ratio

        # A pre-hook on attn_drop sees whatever tensor is about to be passed
        # through the attention dropout of the last block.
        target_attention = original_model.blocks[-1].attn
        self.hook_handles.append(
            target_attention.attn_drop.register_forward_pre_hook(
                self._save_attention_hook
            )
        )

    def _save_attention_hook(self, module, inputs):
        """Forward pre-hook: store a detached copy of the first positional input."""
        self.attentions.append(inputs[0].detach().clone())

    def _forward_with_attention(self, imgs):
        """Run the reference model once and return ``(logits, attention)``.

        Raises:
            RuntimeError: If the hook did not fire during the forward pass.
        """
        self.attentions.clear()
        with torch.no_grad():
            logits = self.original_model(imgs)
        if not self.attentions:
            raise RuntimeError("Failed to capture attention weights")
        return logits, self.attentions[-1]

    def _select_top_patches(self, attn_weights):
        """Return indices of the patches the CLS token attends to most.

        Args:
            attn_weights: Captured attention, indexed as ``[B, H, N+1, N+1]``.

        Returns:
            Long tensor ``[B, k]`` of patch indices (CLS token excluded).
        """
        # Row 0 is the CLS query; column 0 (CLS attending to itself) is dropped.
        cls_attentions = attn_weights[..., 0, 1:].mean(dim=1)  # [B, N]
        num_patches = cls_attentions.size(-1)

        # int() truncation yields 0 for small grids, which would mask nothing
        # and make the positive and negative targets identical (a constant
        # loss with no gradient). Keep at least one patch and at most all.
        k = min(num_patches, max(1, int(self.mask_ratio * num_patches)))
        _, topk_indices = cls_attentions.topk(k, dim=-1)
        return topk_indices

    def _get_attention_mask(self, imgs):
        """Forward ``imgs`` through the reference model and pick patches to mask.

        Returns:
            Long tensor ``[B, k]`` of the most CLS-attended patch indices.

        Raises:
            RuntimeError: If no attention tensor was captured.
        """
        _, attn_weights = self._forward_with_attention(imgs)
        return self._select_top_patches(attn_weights)

    def _mask_images(self, imgs, topk_indices):
        """Zero out the selected patches of every image.

        Args:
            imgs: Tensor ``[B, C, H, W]``; H and W must be multiples of
                ``patch_size``.
            topk_indices: Long tensor ``[B, k]`` of row-major patch indices.

        Returns:
            A tensor shaped like ``imgs`` with the chosen patches set to zero.
        """
        B, C, H, W = imgs.shape
        ph, pw = H // self.patch_size, W // self.patch_size

        # One flag per patch, then blow each flag up to a patch_size square.
        mask = torch.zeros(B, ph * pw, dtype=torch.bool, device=imgs.device)
        mask.scatter_(1, topk_indices, True)
        mask = (
            mask.view(B, 1, ph, pw)
            .repeat_interleave(self.patch_size, dim=2)
            .repeat_interleave(self.patch_size, dim=3)
        )  # [B, 1, H, W], broadcast over channels below

        return imgs * (~mask)

    def compute_loss(self, current_model, imgs):
        """Compute the contrastive unlearning loss for one batch.

        Args:
            current_model: Model being trained; receives gradients.
            imgs: Image batch ``[B, C, H, W]``.

        Returns:
            Scalar cross-entropy loss over ``[positive, negative]`` similarity
            pairs, with the positive as the target class.

        Raises:
            RuntimeError: If no attention tensor was captured.
        """
        try:
            with torch.no_grad():
                # A single reference pass yields both the negative target and
                # the attention map used to choose which patches to hide.
                negative_logits, attn_weights = self._forward_with_attention(imgs)
                topk_indices = self._select_top_patches(attn_weights)
                masked_imgs = self._mask_images(imgs, topk_indices)
                positive_logits = self.original_model(masked_imgs)
        finally:
            # The hook also fires on the masked pass; do not keep those
            # activations alive between batches.
            self.attentions.clear()

        anchor_logits = current_model(imgs)
        pos_sim = F.cosine_similarity(anchor_logits, positive_logits, dim=-1) / self.tau
        neg_sim = F.cosine_similarity(anchor_logits, negative_logits, dim=-1) / self.tau

        # Column 0 holds the positive similarity, hence all-zero labels.
        logits = torch.stack([pos_sim, neg_sim], dim=-1)
        labels = torch.zeros(imgs.size(0), dtype=torch.long, device=imgs.device)
        return F.cross_entropy(logits, labels)

    def close(self):
        """Remove the registered hooks and drop captured tensors (idempotent)."""
        handles = getattr(self, "hook_handles", None)
        if handles:
            for handle in handles:
                handle.remove()
            handles.clear()
        captured = getattr(self, "attentions", None)
        if captured:
            captured.clear()

    def __del__(self):
        self.close()



def _apply_gradient_mask(model, mask):
    """Multiply each existing parameter gradient in place by ``mask[name]``."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            param.grad *= mask[name]


@iterative_unlearn
def CON(data_loaders, model, criterion, optimizer, epoch, args, mask=None):
    """Run one epoch of contrastive unlearning or retain-set fine-tuning.

    For ``epoch < 2`` the model is trained on the forget loader with the
    contrastive loss from ``ContrastiveUnlearner``, using a frozen copy of
    ``model`` taken at the start of this call as the reference. For later
    epochs the model is trained on the retain loader with ``criterion``.

    NOTE(review): the reference copy is taken on every call. If the decorator
    invokes this function once per epoch, the reference used in epoch 1 is the
    model as it stands after epoch 0 rather than the pre-unlearning model --
    confirm this is intended.

    Args:
        data_loaders: Mapping with ``"forget"`` and ``"retain"`` loaders that
            yield ``(image, target)`` batches.
        model: Model to update in place.
        criterion: Loss applied to ``(output, target)`` in the retain phase.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Current epoch index; selects the phase.
        args: Namespace providing ``print_freq``.
        mask: Optional mapping from parameter name to a multiplier applied to
            that parameter's gradient before each optimizer step.

    Returns:
        ``top1.avg`` -- the running top-1 accuracy on the retain batches seen
        during this call. The meter is never updated in the forget phase.
    """
    forget_loader = data_loaders["forget"]
    retain_loader = data_loaders["retain"]
    start = time.time()
    # NOTE(review): used as the progress denominator below although only the
    # retain loader is iterated there; kept to leave the log output unchanged.
    loader_len = len(forget_loader) + len(retain_loader)

    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    model.train()

    if epoch < 2:
        # Only the forget phase needs the reference copy, so the deepcopy and
        # hook registration are skipped entirely for retain epochs.
        contrast_module = ContrastiveUnlearner(
            original_model=deepcopy(model),
            patch_size=16,
            tau=0.07,
        )
        try:
            for image, _ in forget_loader:
                image = image.cuda()

                loss_contrast = contrast_module.compute_loss(model, image)

                optimizer.zero_grad()
                loss_contrast.backward()
                if mask:
                    _apply_gradient_mask(model, mask)
                optimizer.step()
        finally:
            # The hook is a bound method of the unlearner, which itself holds
            # the copied model; removing it breaks that reference cycle so the
            # copy can be released without waiting for the cycle collector.
            for handle in contrast_module.hook_handles:
                handle.remove()
    else:
        for i, (image, target) in enumerate(retain_loader):
            image = image.cuda()
            target = target.cuda()

            output_clean = model(image)
            loss = criterion(output_clean, target)

            optimizer.zero_grad()
            loss.backward()
            if mask:
                _apply_gradient_mask(model, mask)
            optimizer.step()

            # Cast to float so metrics are computed in full precision.
            output = output_clean.float()
            loss = loss.float()
            prec1 = utils.accuracy(output.data, target)[0]

            losses.update(loss.item(), image.size(0))
            top1.update(prec1.item(), image.size(0))

            if (i + 1) % args.print_freq == 0:
                end = time.time()
                print('Epoch: [{0}][{1}/{2}]\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Time {3:.2f}'.format(
                            epoch, i, loader_len, end-start, loss=losses, top1=top1))

    return top1.avg