import torch.nn.functional as F
import torch

from trainer.utils import compute_batch_nll
from trainer.unlearn.grad_diff import GradDiff

import re
import deepspeed
from functools import partial


class ATTU_output(GradDiff):
    def __init__(self, attention_temp=2.0, trainable_params_regex=None, *args, **kwargs):
        """Set up the trainer, its reference model and the trainable-parameter filter.

        Args:
            attention_temp: Value forwarded as ``attention_temp`` to the
                reference model's forward pass on forget batches.
                NOTE(review): presumably an attention-softmax temperature
                understood by a patched model class - confirm.
            trainable_params_regex: A regex string or an iterable of regex
                strings. A parameter is selected when its full name matches any
                pattern via ``re.fullmatch``. ``None`` selects the ``down_proj``
                weights of MLP layers 5, 6 and 7. Another example pattern:
                ``r"model\.layers\.(7|8)\.self_attn\..*"``.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        self.attention_temp = attention_temp

        # Create the reference model if the parent constructor left it unset.
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

        if trainable_params_regex is None:
            # Built per instance (no shared mutable default) and written as a
            # raw string so that ``\.`` is not an invalid escape sequence.
            trainable_params_regex = [
                r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"
            ]
        elif isinstance(trainable_params_regex, str):
            # A bare string would otherwise be iterated one character at a time.
            trainable_params_regex = [trainable_params_regex]

        # Private copy: later changes to the caller's sequence do not leak in.
        self.trainable_params_regex = list(trainable_params_regex)

    def create_optimizer(self):
        """Build the optimizer while only the regex-selected parameters are trainable.

        Every parameter of ``self.model`` is first marked ``requires_grad=False``,
        then the parameters matching ``self.trainable_params_regex`` are turned
        back on, so the parent implementation picks up only those. Afterwards
        ``requires_grad`` is re-enabled on all parameters.

        NOTE(review): after this call gradients are computed for every
        parameter although only the selected ones were visible to the optimizer
        when it was built - confirm this is intended.

        Returns:
            Whatever the parent ``create_optimizer`` returns. The result is
            propagated because callers such as the Transformers DeepSpeed
            integration use the return value of ``create_optimizer()``.
        """
        self._freeze_all_params(self.model, False)
        # This makes the optimizer select only the trainable params.
        self._set_trainable_params(self.model, self.trainable_params_regex, True)
        optimizer = super().create_optimizer()
        self._freeze_all_params(self.model, True)
        return optimizer

    def _freeze_all_params(self, model, requires_grad=True):
        """Freeze all parameters in the model initially."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
        """Unfreeze specific parameters that match the regex patterns."""
        for name, param in model.named_parameters():
            if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
                param.requires_grad = requires_grad
                print(f"{name}:requires_grad\t{requires_grad}")


    def kl_divergence_retain(self, model, target_model, inputs):

        with torch.no_grad():
            ref_outputs = target_model(**inputs)
            ref_probs = F.log_softmax(ref_outputs.logits[..., :-1, :], dim=-1)

        outputs = model(**inputs)
        current_probs = F.log_softmax(outputs.logits[..., :-1, :], dim=-1)

        # minimum KL divergence
        return F.kl_div(
            current_probs, ref_probs, reduction="none", log_target=True
        ).sum(-1), outputs, ref_outputs
    
    def kl_divergence_forget(self, model, target_model, inputs, attention_temp):
        """Per-position KL divergence from a temperature-modified reference to ``model``.

        ``target_model`` is called with ``**inputs`` plus
        ``attention_temp=attention_temp`` under ``torch.no_grad()``; ``model``
        is called with ``**inputs`` only. The logits at the last position are
        dropped, log-softmax is taken over the final axis, and
        ``KL(target || model)`` is summed over that axis.

        Args:
            model: Callable whose output exposes ``.logits``.
            target_model: Callable accepting an ``attention_temp`` keyword,
                whose output exposes ``.logits`` and ``.loss``.
            inputs: Keyword arguments passed to both models.
            attention_temp: Value forwarded to ``target_model`` only.

        Returns:
            Tuple ``(kl, outputs, ref_loss)``: the unreduced KL tensor, the
            output of ``model``, and the detached ``.loss`` of the reference
            forward pass.
        """
        # Reference forward pass with the modified attention temperature.
        with torch.no_grad():
            ref_outputs = target_model(**inputs, attention_temp=attention_temp)
            ref_log_probs = ref_outputs.logits[..., :-1, :].log_softmax(dim=-1)

        outputs = model(**inputs)
        cur_log_probs = outputs.logits[..., :-1, :].log_softmax(dim=-1)

        # KL(reference || current). Swapping the first two arguments would
        # give the reverse direction instead.
        pointwise = F.kl_div(
            cur_log_probs, ref_log_probs, reduction="none", log_target=True
        )
        per_position_kl = pointwise.sum(dim=-1)
        ref_loss = ref_outputs.loss.detach()
        return per_position_kl, outputs, ref_loss
    
    def compute_retain_loss(self, model, retain_inputs):
        retain_loss = None
        retain_outputs = None
        ref_retain_outputs = None

        if self.retain_loss_type == "NLL":
            retain_outputs = model(**retain_inputs)
            # with torch.no_grad():
            #     ref_retain_outputs = self.ref_model(**retain_inputs)
            retain_loss = retain_outputs.loss
            
        elif self.retain_loss_type == "KL":
            retain_loss, retain_outputs, ref_retain_outputs = self.kl_divergence_retain(
                self.model, self.ref_model, retain_inputs
            )
        else:
            raise NotImplementedError(
                f"{self.retain_loss_type} not implemented for retain set"
            )
        if self.retain_loss_type == "NLL":
            return retain_loss, retain_outputs.loss.detach(), torch.tensor([0])
        else:
            return retain_loss, retain_outputs.loss.detach(), ref_retain_outputs.loss.detach()

    def compute_loss(self, model, inputs, return_outputs=False):

        forget_inputs = inputs["forget"]
        forget_inputs = {
            "input_ids": forget_inputs["input_ids"],
            "attention_mask": forget_inputs["attention_mask"],
            "labels": forget_inputs["labels"],
        }
        forget_labels = forget_inputs["labels"]
        forget_mask = (forget_labels[..., 1:] != -100)

        # forget_loss, forget_outputs, ref_forget_outputs, ori_forget_output, token_ori_probs, token_ref_probs = self.kl_divergence_forget(model, self.ref_model, forget_inputs, self.attention_temp)
        forget_loss, forget_outputs, ref_forget_outputs = self.kl_divergence_forget(model, self.ref_model, forget_inputs, self.attention_temp)
        forget_loss = ((forget_loss * forget_mask).sum(-1) / forget_mask.sum(-1)).mean()

        # print(forget_loss.mean().requires_grad)
        # topk_ratio = 0.2 # example: keep top 20%

        # batch_losses = []
        # for i in range(forget_loss.size(0)):
        #     loss_i = forget_loss[i]          # shape: [seq_len - 1]
        #     mask_i = forget_mask[i]          # shape: [seq_len - 1]

        #     valid_loss = loss_i[mask_i]      # shape: [valid_seq_len]
        #     k = max(1, int(topk_ratio * valid_loss.numel()))
            
        #     topk_loss, _ = torch.topk(valid_loss, k=k, largest=True)
        #     avg_topk_loss = topk_loss.mean()
        #     batch_losses.append(avg_topk_loss)

        # forget_loss = torch.stack(batch_losses).mean()
        # print(forget_loss.requires_grad)
        

        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_labels = retain_inputs["labels"]
        retain_mask = (retain_labels[..., 1:] != -100)

        retain_loss, retain_outputs, ref_retain_outputs = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss_f = forget_outputs.loss
        ref_loss_f = ref_forget_outputs
        # ori_loss_f = ori_forget_output.loss

        loss_r = retain_outputs
        ref_loss_r = ref_retain_outputs

        if self.retain_loss_type == "KL":
            retain_loss = ((retain_loss * retain_mask).sum(-1) / retain_mask.sum(-1)).mean()

        loss = self.alpha * forget_loss + retain_loss

        # print(f"Forget loss: {loss_f.item()}                Ref Forget loss: {ref_loss_f.item()}                Ori Forget loss: {ori_loss_f.item()}                Forget Diff: {loss_f.item()-ori_loss_f.item()}")
        # print(f"Retain loss: {loss_r.item()}                Ref Retain loss: {ref_loss_r.item()}                Retain Diff: {loss_r.item()-ref_loss_r.item()}                Diff of Diff: {loss_f.item()-ori_loss_f.item()-(loss_r.item()-ref_loss_r.item())}")
        print(f"Forget loss: {loss_f.item()}                Ref Forget loss: {ref_loss_f.item()}")
        print(f"Retain loss: {loss_r.item()}                Ref Retain loss: {ref_loss_r.item()}                Retain Diff: {loss_r.item()-ref_loss_r.item()}")
        print(f"our forget Loss: {forget_loss.item()}                our retain Loss: {retain_loss.item()}                Loss: {loss.item()}")

        return (loss, forget_outputs) if return_outputs else loss
