"""Borrowed implementation from https://github.com/centerforaisafety/wmdp/blob/main/rmu/unlearn.py"""

import re
import torch
import torch.nn.functional as F
import deepspeed
from functools import partial
from trainer.unlearn.grad_diff import GradDiff
from torch.linalg import pinv

# NOTE(review): This if/elif sits at module scope, yet it reads `self`, `model` and
# `inputs`. None of these is defined at module level in this file, so importing the
# module raises NameError on the first line below. It looks like the body of a
# trainer's compute_loss that was pasted here; move it into a method (or delete it)
# before use. It also reads self.loss_type, self.temp, self.oracle_model, self.norm
# and self.lm_head, none of which is set anywhere in this file's visible code.
if self.loss_type == "ours_shit":
    # Experimental variant. On the forget batch it applies a logit-lens hinge
    # penalty at every intermediate layer. On the retain batch it matches hidden
    # states against self.oracle_model.

    # Margin added to (token_logit - mu) inside the hinge.
    epsilon = 1e-3
    # `inputs` is unpacked as (forget_inputs, retain_inputs); each one is unpacked as
    # an (input_ids, labels, attention_mask) triple.
    forget_inputs, _ = inputs
    input_ids, labels, attention_mask = forget_inputs
    outputs = model(input_ids,labels=labels, attention_mask=attention_mask, output_hidden_states=True)

    # Only used in the diagnostic print at the end of this branch.
    forget_loss = outputs.loss

    hidden_states = outputs.hidden_states
    output_logits = outputs.logits[..., :-1, :].contiguous()

    loss = 0.0
    # Position t scores token t+1, so labels and token ids are shifted left by one.
    shifted_labels = labels[..., 1:].contiguous()
    shifted_input_ids = input_ids[..., 1:].contiguous().unsqueeze(-1)

    # Positions labelled -100 are excluded from every average below.
    loss_mask = (shifted_labels != -100)
    # Select middle layers: every hidden state except the first (index 0) and the last.
    selected_layers = [i for i in range(1,len(hidden_states)-1)]

    for layer in selected_layers:
        # Logit lens: this layer's hidden states go straight through model.lm_head
        # (no final norm is applied here).
        # logits = F.log_softmax(model.lm_head(hidden_states[layer][..., :-1, :]),dim=-1)
        logits = model.lm_head(hidden_states[layer][..., :-1, :])

        with torch.no_grad():
            # mu: logits averaged under the temperature-scaled softmax (expected
            # logit). sigma is computed but never used.
            new_logits = logits.detach()
            probablities = F.softmax(new_logits/self.temp,dim=-1).detach()
            mu = (probablities*new_logits).sum(-1,keepdim=True).detach()
            sigma = (probablities*torch.square(new_logits-mu)).sum(-1).sqrt().detach()

            mu = mu.squeeze(-1)

        # Logit of the actual next token at each position (keeps the gradient).
        token_logits = logits.gather(dim=-1, index=shifted_input_ids.to(logits.device)).squeeze(-1)

        with torch.no_grad():
            # Hinge gate: only positions where token_logit + epsilon exceeds mu are penalised.
            mask = ((token_logits - mu) + epsilon > 0.0).detach()

        # Hinge penalty, averaged per sample over labelled positions, then over the
        # batch and over the selected layers.
        # NOTE(review): a sample with no labelled position makes loss_mask.sum(-1)
        # zero, and the loss becomes NaN.
        loss += ((mask * (token_logits - mu + epsilon) * loss_mask).sum(-1) / (loss_mask.sum(-1))).mean()/len(selected_layers)
        

    # Disabled alternative: the same hinge applied to the final output logits.
    # loss *= 3.0
    # logits = F.log_softmax(output_logits,dim=-1)

    # with torch.no_grad():
    #     new_logits = F.log_softmax(output_logits.detach(),dim=-1)
    #     probablities = F.softmax(new_logits/self.temp,dim=-1).detach()
    #     mu = (probablities*new_logits).sum(-1,keepdim=True).detach()
    #     sigma = (probablities*torch.square(new_logits-mu)).sum(-1).sqrt().detach()

    #     mu = mu.squeeze(-1)

    # token_logits = logits.gather(dim=-1, index=shifted_input_ids).squeeze(-1)

    # with torch.no_grad():
    #     mask = ((token_logits - mu) + epsilon > 0.0).detach()

    # loss += ((mask * (token_logits - mu + epsilon) * loss_mask).sum(-1) / (loss_mask.sum(-1))).mean()

    # Retain batch: pull the trained model's hidden states towards those of
    # self.oracle_model, which runs without gradients.
    _ , retain_inputs = inputs
    retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
    
    shifted_labels = retain_labels[..., 1:].contiguous()
    shifted_input_ids = retain_input_ids[..., 1:].contiguous().unsqueeze(-1)

    loss_mask = (shifted_labels != -100)

    with torch.no_grad():
        retain_outputs = self.oracle_model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask, output_hidden_states=True)
        retain_hidden_states = retain_outputs.hidden_states
        


    outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask, output_hidden_states=True)
    # Only used in the diagnostic print below.
    retain_loss = outputs.loss

    hidden_states = outputs.hidden_states
    output_logits = outputs.logits[..., :-1, :].contiguous()

    # Select middle layers: same range as the forget part (indices 1 .. len-2).
    selected_layers = [i for i in range(1,len(hidden_states)-1)]

    for layer in selected_layers:
        # Squared L2 distance per position, averaged over labelled positions per
        # sample, then over the batch and over the selected layers.
        loss += (((loss_mask*((hidden_states[layer][..., :-1, :]-retain_hidden_states[layer][..., :-1, :]).square().sum(-1))).sum(-1))/(loss_mask.sum(-1).float())).mean()/(float(len(selected_layers)))


    print("*************************************************************************")
    print(f"Forget loss: {forget_loss.item()}                Retain loss: {retain_loss.item()}                Loss: {loss.item()}")

elif self.loss_type == "ours":
    # Variant that builds an edited target hidden state. The next-token logit is
    # replaced by mu, and the edited logits are mapped back to hidden space through
    # self.lm_head_pinv. The model's hidden state is then pulled towards that target.
    
    # With epsilon = -1e-5, the gate below fires when token_logit > mu + 1e-5.
    epsilon = -1e-5
    forget_inputs, _ = inputs
    input_ids, labels, attention_mask = forget_inputs

    # first_mask = first_occurrence_mask(input_ids)[:,1:]
    
    outputs = model(input_ids,labels=labels, attention_mask=attention_mask, output_hidden_states=True)

    # Only used in the diagnostic print at the end of this branch.
    forget_loss = outputs.loss


    # print(outputs.logits.dtype)

    hidden_states = outputs.hidden_states
    output_logits = outputs.logits[..., :-1, :].contiguous()

    loss = 0.0
    shifted_labels = labels[..., 1:].contiguous()
    shifted_input_ids = input_ids[..., 1:].contiguous().unsqueeze(-1)

    loss_mask = (shifted_labels != -100)
    # Only the last hidden state is selected (the all-layers variant is commented out).
    # selected_layers = [i for i in range(len(hidden_states))]
    selected_layers = [len(hidden_states)-1]
    forget_loss_sum = 0.0

    for layer in selected_layers:
        # logits = F.log_softmax(model.lm_head(hidden_states[layer][..., :-1, :]),dim=-1)
        # print(hidden_states[layer][..., :-1, :].dtype)
        # print(self.lm_head.weight.data.dtype)
        
        # The target construction runs without gradients; only `difference`
        # below carries a gradient, through hidden_states.
        with torch.no_grad():
            # logits = self.lm_head(hidden_states[layer][..., :-1, :])
            if(layer==len(hidden_states)-1):
                # Final layer: reuse the model's own output logits.
                logits = output_logits.detach()
            else:
                # Not reached with the current selected_layers (last index only).
                # self.lm_head is used here as a matrix, not a module; presumably
                # (hidden, vocab) — TODO confirm.
                logits = torch.matmul(self.norm(hidden_states[layer][..., :-1, :]),self.lm_head)

            if(shifted_input_ids.device!=logits.device):
                shifted_input_ids = shifted_input_ids.to(logits.device)


            token_logits = logits.gather(dim=-1, index=shifted_input_ids)
            

            # probablities = F.softmax(logits*(10.0 ** (self.temps[f"layer_{layer}_Normal_Temp_all_"]['temp']*0.05)),dim=-1)
            probablities = F.softmax(logits,dim=-1)
            # mu: expected logit under the model's own softmax.
            mu = (probablities*logits).sum(-1,keepdim=True)

            # mu = logits.mean(-1,keepdim=True)

            # Positions whose next-token logit exceeds mu by more than 1e-5.
            token_mask = ((token_logits - mu + epsilon) > 0.0).detach().squeeze(-1)
            # sigma = (probablities*torch.square(new_logits-mu)).sum(-1).sqrt().detach()
            # Replace the next-token logit with mu, in place.
            # NOTE(review): `logits` comes from output_logits.detach(), which shares
            # storage with output_logits, so this write also changes output_logits
            # (and possibly outputs.logits when .contiguous() did not copy).
            logits.scatter_(-1, shifted_input_ids, mu)
            # Map the edited logits back to hidden space via the LM-head pseudo-inverse.
            new_hidden_state = torch.matmul(logits,self.lm_head_pinv)

            # Split the elementwise Huber (delta = 1) loss into quadratic and linear regions.
            extra_mask = (hidden_states[layer][..., :-1, :] - new_hidden_state).abs() < 1.0
            inverse_extra_mask = (hidden_states[layer][..., :-1, :] - new_hidden_state).abs() >= 1.0


        difference = hidden_states[layer][..., :-1, :] - new_hidden_state
        # Huber distance summed over the hidden dimension. Only token_mask &
        # loss_mask positions count, normalised by the labelled-position count per
        # sample, then averaged over the batch.
        forget_loss_sum += ((((0.5 *extra_mask*(difference.square()))+(inverse_extra_mask*(difference.abs()-0.5))).sum(-1) * token_mask * loss_mask).sum(-1) / ((loss_mask).sum(-1))).mean()
        # forget_loss_sum += ((((0.5 *extra_mask*(difference.square()))+(inverse_extra_mask*(difference.abs()-0.5))).sum(-1) * token_mask * loss_mask * first_mask).sum(-1) / ((loss_mask*first_mask).sum(-1))).mean()/len(selected_layers)
        # forget_loss_sum += (((hidden_states[layer][..., :-1, :] - new_hidden_state).square().sum(-1) * token_mask * loss_mask * first_mask).sum(-1) / ((loss_mask*first_mask).sum(-1))).mean()/len(selected_layers)
    
    
    # Hard-coded weight of the forget term relative to the retain term.
    forget_loss_sum = forget_loss_sum * 40
    
    # Retain batch: squared-L2 hidden-state matching against self.oracle_model.
    _ , retain_inputs = inputs
    retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
    
    shifted_labels = retain_labels[..., 1:].contiguous()
    shifted_input_ids = retain_input_ids[..., 1:].contiguous().unsqueeze(-1)

    # From here on, loss_mask refers to the retain batch.
    loss_mask = (shifted_labels != -100)

    with torch.no_grad():
        retain_outputs = self.oracle_model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask, output_hidden_states=True)
        retain_hidden_states = retain_outputs.hidden_states
        


    outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask, output_hidden_states=True)
    # Only used in the diagnostic print below.
    retain_loss = outputs.loss

    hidden_states = outputs.hidden_states
    output_logits = outputs.logits[..., :-1, :].contiguous()

    # Every hidden state is selected, including index 0 and the last one.
    selected_layers = [i for i in range(len(hidden_states))]
    retain_loss_sum = 0.0

    for layer in selected_layers:
        # loss += (((loss_mask*((hidden_states[layer][..., :-1, :]-retain_hidden_states[layer][..., :-1, :]).square().sum(-1))).sum(-1))/(loss_mask.sum(-1).float())).mean()/(float(len(selected_layers)))
        # Summed (not averaged) over layers.
        retain_loss_sum += (((loss_mask*((hidden_states[layer][..., :-1, :]-retain_hidden_states[layer][..., :-1, :]).square().sum(-1))).sum(-1))/(loss_mask.sum(-1))).mean()

    loss = forget_loss_sum + retain_loss_sum
    print("*************************************************************************")
    print(f"Forget loss: {forget_loss.item()}                Retain loss: {retain_loss.item()}")
    # NOTE(review): token_mask comes from the forget batch, but loss_mask here is the
    # retain-batch mask. The "Token Mask" ratio below mixes the two and fails to
    # broadcast if the batches differ in shape.
    print(f"sum forget Loss: {forget_loss_sum.item()}                sum retain Loss: {retain_loss_sum.item()}                Loss: {loss.item()}                Token Mask: {(token_mask*loss_mask).sum().item()/loss_mask.sum().item()}")



class distill_inverse(GradDiff):
    """GradDiff-based unlearning trainer that caches the LM-head pseudo-inverse and
    optimises only regex-selected parameters.

    Adapted from the RMU trainer referenced in the module docstring. The forget and
    retain losses are computed in one of two ways:

    * an activation MSE at one cached module, when
      ``self.retain_loss_type == "EMBED_DIFF"``;
    * otherwise, a per-token KL divergence to the reference model's next-token
      distribution.
    """

    def __init__(
        self,
        module_regex=r"model\.layers\.7",
        trainable_params_regex=None,
        attention_temp=2.0,
        *args,
        **kwargs,
    ):
        """Set up the trainer, the LM-head pseudo-inverse and the module/parameter selection.

        Args:
            module_regex (str): Regex that must ``re.fullmatch`` exactly one module
                name. That module's output is cached by ``forward_with_cache``.
            trainable_params_regex (list of str | None): Regexes (``re.fullmatch``)
                selecting the parameters the optimizer is built over. ``None``
                selects the ``mlp.down_proj.weight`` of ``model.layers`` 5, 6 and 7.
            attention_temp (float): Stored as ``self.attention_temp``; not read
                elsewhere in this class.
            *args, **kwargs: Forwarded to ``GradDiff.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Build a fresh list per instance instead of sharing a mutable default argument.
        if trainable_params_regex is None:
            trainable_params_regex = [r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"]

        lm_head = self.model.lm_head
        # Transposed float32 copy of the LM-head weight. For an nn.Linear head
        # (weight of shape (vocab, hidden)) this is (hidden, vocab).
        self.lm_head_weight = lm_head.weight.data.clone().to(dtype=torch.float32).T.contiguous()
        head_bias = getattr(lm_head, "bias", None)
        self.lm_head_bias = (
            head_bias.data.clone().to(dtype=torch.float32) if head_bias is not None else None
        )
        print(self.lm_head_bias)

        # Moore-Penrose pseudo-inverse. Right-multiplying logits by it gives the
        # least-squares hidden state that produces them (bias excluded).
        self.lm_head_pinv = pinv(self.lm_head_weight)

        print("pinv.shape:", self.lm_head_pinv.shape)
        print("pinv.device:", self.lm_head_pinv.device)

        # Create the reference model if GradDiff did not provide one.
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

        self.trainable_params_regex = trainable_params_regex

        # Resolve the cached module once, in both the trained and the reference model.
        self.module_regex = module_regex
        self.model_module = self._get_matching_module(self.model, self.module_regex)
        self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)
        self.attention_temp = attention_temp

    def create_optimizer(self):
        """Build the optimizer over the regex-selected parameters only.

        Every parameter is first frozen and the selected ones are unfrozen, so the
        parent's ``create_optimizer`` only sees those as trainable. Afterwards
        ``requires_grad`` is set back to True on every parameter.
        """
        self._freeze_all_params(self.model, False)
        self._set_trainable_params(self.model, self.trainable_params_regex, True)
        super().create_optimizer()
        self._freeze_all_params(self.model, True)

    def _get_matching_module(self, model, module_regex):
        """Return the single submodule whose name fully matches ``module_regex``.

        DeepSpeed engines and DistributedDataParallel wrappers are unwrapped first, so
        names are matched against the underlying model.

        Raises:
            ValueError: If zero or more than one module name matches.
        """
        if isinstance(model, (deepspeed.DeepSpeedEngine, torch.nn.parallel.DistributedDataParallel)):
            model = model.module

        matched_modules = {
            name: module
            for name, module in model.named_modules()
            if re.fullmatch(module_regex, name)
        }

        if len(matched_modules) > 1:
            raise ValueError(
                f"More than one module matched with {module_regex}: {list(matched_modules.keys())}"
            )
        elif not matched_modules:
            raise ValueError(f"No module matched with {module_regex}")

        return next(iter(matched_modules.values()))

    def _freeze_all_params(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter (False freezes, True unfreezes)."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
        """Set ``requires_grad`` on parameters whose name fully matches any pattern, printing each name."""
        for name, param in model.named_parameters():
            if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
                param.requires_grad = requires_grad
                print(name)

    def forward_with_cache(self, model, inputs, module, no_grad=True):
        """Run ``model(**inputs)`` while capturing the forward output of ``module``.

        Args:
            model: Model to run.
            inputs (dict): Keyword arguments for the forward call.
            module: Submodule whose output is captured. For tuple outputs, the first
                element is kept.
            no_grad (bool): Run without gradient tracking when True.

        Returns:
            tuple: ``(activation[..., :-1, :], outputs)``. The last sequence position
            is dropped so the activation aligns with shifted labels.
        """
        cache = []

        def _capture(_module, _inputs, output):
            cache.append(output[0] if isinstance(output, tuple) else output)

        handle = module.register_forward_hook(_capture)
        try:
            with torch.set_grad_enabled(not no_grad):
                outputs = model(**inputs)
        finally:
            # Remove the hook even if the forward pass raises, so it cannot leak.
            handle.remove()
        return cache[0][..., :-1, :], outputs

    @staticmethod
    def _masked_mean_over_tokens(per_token, mask):
        """Average a (b, s) per-token loss over positions where ``mask`` is True.

        The average is taken per sample, then over the batch. A sample with no True
        position contributes 0 instead of NaN.
        """
        mask = mask.to(per_token.device)
        summed = (per_token * mask).sum(dim=-1)  # (b,)
        num_tokens = mask.sum(dim=-1).clamp(min=1)  # (b,) — keep 1-D so division stays per-sample
        return (summed / num_tokens).mean()

    def compute_activation_loss_mse(self, activation1, activation2, mask):
        """Squared error summed over the feature dimension, then masked-averaged over tokens."""
        squared_diff = F.mse_loss(activation1, activation2, reduction="none")  # (b, s, d)
        return self._masked_mean_over_tokens(squared_diff.sum(dim=-1), mask)

    def compute_activation_loss_huber(self, activation1, activation2, mask):
        """Huber loss summed over the feature dimension, then masked-averaged over tokens."""
        huber = F.huber_loss(activation1, activation2, reduction="none")  # (b, s, d)
        return self._masked_mean_over_tokens(huber.sum(dim=-1), mask)

    def compute_retain_loss(self, model, retain_inputs):
        """Keep the model close to the reference model on the retain batch.

        Returns:
            tuple: ``(retain_loss, model_retain_outputs, ref_retain_outputs)``, where
            ``retain_loss`` is a scalar tensor.
        """
        if self.retain_loss_type == "EMBED_DIFF":
            model_retain_activations, model_retain_outputs = self.forward_with_cache(
                model, retain_inputs, module=self.model_module, no_grad=False
            )
            ref_retain_activations, ref_retain_outputs = self.forward_with_cache(
                self.ref_model, retain_inputs, module=self.ref_module, no_grad=True
            )
            mask = retain_inputs["labels"][..., 1:] != -100  # (b, s)
            retain_loss = self.compute_activation_loss_mse(
                model_retain_activations,
                ref_retain_activations.to(model_retain_activations.device),
                mask,
            )
        else:
            model_retain_outputs = model(**retain_inputs)
            current_log_probs = F.log_softmax(model_retain_outputs.logits[..., :-1, :], dim=-1)

            with torch.no_grad():
                ref_retain_outputs = self.ref_model(**retain_inputs)
            ref_log_probs = F.log_softmax(
                ref_retain_outputs.logits[..., :-1, :], dim=-1
            ).to(current_log_probs.device)

            # Per-token KL(ref || model), reduced to a scalar over labelled tokens so
            # that compute_loss can call .item() and backward() on it.
            per_token_kl = F.kl_div(
                current_log_probs, ref_log_probs, reduction="none", log_target=True
            ).sum(-1)
            mask = retain_inputs["labels"][..., 1:] != -100
            retain_loss = self._masked_mean_over_tokens(per_token_kl, mask)

        return retain_loss, model_retain_outputs, ref_retain_outputs

    def compute_target(self, outputs, teacher_outputs, inputs, beta):
        """Return teacher next-token logits with ``beta`` subtracted at each label token.

        Args:
            outputs: Student model outputs. Unused; kept for interface compatibility.
            teacher_outputs: Reference-model outputs exposing ``.logits``.
            inputs (dict): Must contain ``"labels"``; -100 marks ignored positions.
            beta (float): Amount subtracted from the logit of the label token.

        Returns:
            The shifted teacher logits (last position dropped), lowered by ``beta`` at
            the label token of each labelled position. The LM-head bias is also
            subtracted when the head has one.
        """
        shift_teacher_logits = teacher_outputs.logits[..., :-1, :].detach().contiguous()
        shift_labels = inputs["labels"][..., 1:].to(shift_teacher_logits.device)

        # One-hot at the label token. Ignored positions are pointed at index 0 with
        # weight 0, because indexing with -100 directly would wrap around to an
        # unrelated vocabulary entry.
        valid = shift_labels != -100
        mask = torch.zeros_like(shift_teacher_logits)
        mask.scatter_(
            -1,
            shift_labels.masked_fill(~valid, 0).unsqueeze(-1),
            valid.unsqueeze(-1).to(mask.dtype),
        )

        target_logits = shift_teacher_logits - mask * beta
        if self.lm_head_bias is not None:
            target_logits = target_logits - self.lm_head_bias.to(target_logits.device)
        return target_logits

    def compute_forget_loss(self, model, forget_inputs):
        """Compute the forget-batch loss.

        Returns:
            tuple: ``(forget_loss, model_forget_outputs, ref_forget_outputs)``, where
            ``forget_loss`` is a scalar tensor.
        """
        # NOTE(review): this branch is selected by retain_loss_type, not a forget-specific setting.
        if self.retain_loss_type == "EMBED_DIFF":
            model_forget_activations, model_forget_outputs = self.forward_with_cache(
                model, forget_inputs, module=self.model_module, no_grad=False,
            )
            ref_forget_activations, ref_forget_outputs = self.forward_with_cache(
                self.ref_model, forget_inputs, module=self.ref_module, no_grad=True,
            )
            mask = forget_inputs["labels"][..., 1:] != -100  # (b, s)
            forget_loss = self.compute_activation_loss_mse(
                model_forget_activations,
                ref_forget_activations.to(model_forget_activations.device),
                mask,
            )
        else:
            model_forget_outputs = model(**forget_inputs)
            current_log_probs = F.log_softmax(model_forget_outputs.logits[..., :-1, :], dim=-1)

            with torch.no_grad():
                ref_forget_outputs = self.ref_model(**forget_inputs)
                ref_logits = ref_forget_outputs.logits[..., :-1, :]

            if self.lm_head_bias is not None:
                ref_logits = ref_logits - self.lm_head_bias.to(ref_logits.device)

            # NOTE(review): an unfinished target-construction step stood here (an
            # incomplete assignment plus a gather on undefined names), which made the
            # module a SyntaxError. It has been removed. As written, the target is the
            # reference distribution itself (minus the LM-head bias, if any), so this
            # KL pulls the model towards the reference on forget data. Confirm the
            # intended target (e.g. compute_target) before relying on this path.
            ref_log_probs = F.log_softmax(ref_logits, dim=-1).to(current_log_probs.device)
            per_token_kl = F.kl_div(
                current_log_probs, ref_log_probs, reduction="none", log_target=True
            ).sum(-1)
            mask = forget_inputs["labels"][..., 1:] != -100
            forget_loss = self._masked_mean_over_tokens(per_token_kl, mask)

        return forget_loss, model_forget_outputs, ref_forget_outputs

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return ``forget_loss + self.alpha * retain_loss`` and print diagnostics.

        Args:
            model: Model being trained.
            inputs (dict): Holds ``"forget"`` and ``"retain"`` batches. Each is a dict
                providing ``input_ids``, ``attention_mask`` and ``labels``; any other
                keys are dropped.
            return_outputs (bool): Also return the forget-batch model outputs.

        Returns:
            The scalar loss, or ``(loss, forget_outputs)`` when ``return_outputs`` is True.
        """
        keys = ("input_ids", "attention_mask", "labels")
        forget_inputs = {key: inputs["forget"][key] for key in keys}
        retain_inputs = {key: inputs["retain"][key] for key in keys}

        forget_loss, forget_outputs, ref_forget_outputs = self.compute_forget_loss(model=model, forget_inputs=forget_inputs)
        retain_loss, retain_outputs, ref_retain_outputs = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        # Plain LM losses, used for logging only.
        loss_f = forget_outputs.loss
        ref_loss_f = ref_forget_outputs.loss
        loss_r = retain_outputs.loss
        ref_loss_r = ref_retain_outputs.loss

        loss = forget_loss + self.alpha * retain_loss

        print(f"Forget loss: {loss_f.item()}                Ref Forget loss: {ref_loss_f.item()}                Forget Diff: {loss_f.item()-ref_loss_f.item()}")
        print(f"Retain loss: {loss_r.item()}                Ref Retain loss: {ref_loss_r.item()}                Retain Diff: {loss_r.item()-ref_loss_r.item()}                Diff of Diff: {loss_f.item()-ref_loss_f.item()-(loss_r.item()-ref_loss_r.item())}")
        print(f"our forget Loss: {forget_loss.item()}                our retain Loss: {retain_loss.item()}                Loss: {loss.item()}")

        return (loss, forget_outputs) if return_outputs else loss
    



def compute_undial_loss(model, ref_model, inputs, beta, ignore_index=-100):
    """Compute the UNDIAL self-distillation loss on a batch.

    The reference (teacher) model's next-token logits have ``beta`` subtracted at the
    label token of every labelled position. The softmax of these adjusted logits is
    the soft target for the trainable (student) model's cross-entropy.

    Args:
        model: Trainable model. Called as ``model(**inputs)`` and must expose ``.logits``.
        ref_model: Teacher model, called the same way under ``torch.no_grad()``.
        inputs (dict): Model keyword arguments. Must include ``"labels"``.
        beta (float): Amount subtracted from the teacher logit of each label token.
        ignore_index (int): Label value marking positions with no target token. The
            teacher logits there are left unadjusted. Defaults to -100.

    Returns:
        tuple: ``(loss, outputs)``. ``loss`` is the soft-target cross-entropy averaged
        over every shifted position. ``outputs`` is the student's forward output.
    """
    outputs = model(**inputs)
    shift_logits = outputs.logits[..., :-1, :].contiguous()

    with torch.no_grad():
        teacher_logits = ref_model(**inputs).logits
    shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()

    shift_labels = inputs["labels"][..., 1:].to(shift_teacher_logits.device)
    valid = shift_labels != ignore_index

    # One-hot of the label token at valid positions. Ignored positions are pointed at
    # index 0 with weight 0: indexing with -100 directly would wrap around to an
    # unrelated vocabulary entry (or raise IndexError for small vocabularies).
    mask = torch.zeros_like(shift_teacher_logits)
    mask.scatter_(
        -1,
        shift_labels.masked_fill(~valid, 0).unsqueeze(-1),
        valid.unsqueeze(-1).to(mask.dtype),
    )

    # Lower the teacher's logit for the label token, then use the result as a soft target.
    soft_label = F.softmax(shift_teacher_logits - mask * beta, dim=-1)

    # F.cross_entropy accepts class-probability targets, which matches
    # CrossEntropyLoss with soft labels.
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        soft_label.reshape(-1, soft_label.size(-1)).to(shift_logits.device),
        reduction="none",
    )
    return loss.mean(), outputs
