"""Borrowed implementation from https://github.com/centerforaisafety/wmdp/blob/main/rmu/unlearn.py"""

import re
import torch
import deepspeed
from functools import partial
from trainer.unlearn.grad_diff import GradDiff


class ATTU_hidden(GradDiff):
    """Unlearning trainer that matches hidden activations at a single module.

    The output of the module selected by ``module_regex`` is captured with a
    forward hook on both the trained model and the reference model.

    * Forget split: the trained model's activations are regressed (masked MSE)
      onto the reference model's activations, where the reference forward pass
      receives ``attention_temp`` together with ``layers_id`` covering layers
      ``0 .. layer_num - 1``.
    * Retain split: the trained model's activations are regressed onto the
      reference model's ordinary (no temperature) activations.

    The optimized objective is ``alpha * forget_loss + retain_loss``. Only
    parameters whose names fully match ``trainable_params_regex`` are handed
    to the optimizer.

    NOTE(review): assumes the reference model's forward accepts the custom
    ``attention_temp`` and ``layers_id`` keyword arguments — TODO confirm
    against the modeling code in use.
    """

    def __init__(
        self,
        module_regex=r"model\.layers\.7",
        trainable_params_regex=None,
        attention_temp=2.0,
        *args,
        **kwargs,
    ):
        """Set up the reference model, hooked modules and trainable-parameter filter.

        Args:
            module_regex (str): Regex that must ``re.fullmatch`` exactly one
                module name in both the model and the reference model; that
                module's output is the activation being matched. The integer
                following ``model.layers.`` in this pattern (with or without
                escaped dots) is stored as ``layer_num``.
            trainable_params_regex (list of str, optional): Regexes; parameters
                whose full names match any of them are optimized. When omitted,
                a fresh default list matching
                ``model.layers.{5,6,7}.mlp.down_proj.weight`` is used.
            attention_temp (float): Temperature forwarded to the reference
                model on the forget split. A falsy value (``None`` or ``0``)
                disables it.
            *args, **kwargs: Forwarded to ``GradDiff.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Create the reference model if the parent did not already set one.
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

        if trainable_params_regex is None:
            # Built per instance instead of as a shared mutable default argument.
            trainable_params_regex = [r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"]
        self.trainable_params_regex = trainable_params_regex

        # Resolve the hooked module once in each model.
        self.module_regex = module_regex
        self.model_module = self._get_matching_module(self.model, self.module_regex)
        self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)
        self.attention_temp = attention_temp

        # module_regex is normally an escaped regex such as r"model\.layers\.7";
        # the optional backslashes also accept an unescaped "model.layers.7".
        self.layer_num = None
        match = re.search(r"model\\?\.layers\\?\.(\d+)", module_regex)
        if match:
            self.layer_num = int(match.group(1))
            print(self.layer_num)

    def create_optimizer(self):
        """Build the optimizer over the ``trainable_params_regex`` parameters only.

        Returns:
            Whatever the parent ``create_optimizer`` returns.
        """
        self._freeze_all_params(self.model, False)
        # The parent optimizer factory only picks up params with requires_grad=True.
        self._set_trainable_params(self.model, self.trainable_params_regex, True)
        optimizer = super().create_optimizer()
        # NOTE(review): this re-enables requires_grad on *every* parameter after
        # the optimizer is built. Non-selected parameters are never stepped, but
        # backward still computes their gradients, which may also be included in
        # gradient-norm clipping over model.parameters() — confirm intended.
        self._freeze_all_params(self.model, True)
        return optimizer

    def _get_matching_module(self, model, module_regex):
        """Return the single submodule whose name fully matches ``module_regex``.

        DeepSpeed engines and DistributedDataParallel wrappers are unwrapped
        first so names are resolved against the underlying model.

        Raises:
            ValueError: if no module name, or more than one, matches.
        """
        if isinstance(
            model,
            (deepspeed.DeepSpeedEngine, torch.nn.parallel.DistributedDataParallel),
        ):
            model = model.module  # the wrapped PyTorch model

        matched_modules = {
            name: module
            for name, module in model.named_modules()
            if re.fullmatch(module_regex, name)
        }

        if len(matched_modules) > 1:
            raise ValueError(
                f"More than one module matched with {module_regex}: {list(matched_modules.keys())}"
            )
        if not matched_modules:
            raise ValueError(f"No module matched with {module_regex}")

        return next(iter(matched_modules.values()))

    def _freeze_all_params(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``.

        Pass ``False`` to freeze everything; the default ``True`` unfreezes.
        """
        for param in model.parameters():
            param.requires_grad = requires_grad

    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
        """Set ``requires_grad`` on parameters whose names fully match any pattern.

        The name of each matched parameter is printed.
        """
        for name, param in model.named_parameters():
            if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
                param.requires_grad = requires_grad
                print(name)

    def forward_with_cache(self, model, inputs, module, no_grad=True, attention_temp=None):
        """Run ``model`` on ``inputs`` and capture the forward output of ``module``.

        Args:
            model: Model to run.
            inputs (dict): Keyword arguments passed to ``model(...)``.
            module: Submodule of ``model`` to hook. For a tuple output only
                element 0 is kept.
            no_grad (bool): If True, autograd is disabled for the forward pass.
            attention_temp (float, optional): When truthy, passed to ``model``
                as ``attention_temp`` along with
                ``layers_id=list(range(self.layer_num))``.

        Returns:
            tuple: ``(activations, outputs)``. ``activations`` is the first
            captured output with its last sequence position dropped
            (``[..., :-1, :]``) so it lines up with ``labels[..., 1:]``;
            ``outputs`` is the model's return value.

        Raises:
            ValueError: if ``attention_temp`` is set but ``self.layer_num``
                could not be parsed from ``module_regex``.
            RuntimeError: if ``module`` was not invoked during the forward pass.
        """
        if attention_temp and self.layer_num is None:
            raise ValueError(
                f"attention_temp requires a layer index, but none could be parsed "
                f"from module_regex={self.module_regex!r}"
            )

        cache = []

        def hook(_module, _input, output):
            cache.append(output[0] if isinstance(output, tuple) else output)

        hook_handle = module.register_forward_hook(hook)
        try:
            with torch.set_grad_enabled(not no_grad):
                if attention_temp:
                    outputs = model(
                        **inputs,
                        attention_temp=attention_temp,
                        layers_id=list(range(self.layer_num)),
                    )
                else:
                    outputs = model(**inputs)
        finally:
            # Detach the hook even if the forward pass raises, so it cannot leak.
            hook_handle.remove()

        if not cache:
            raise RuntimeError(
                f"Hooked module (module_regex={self.module_regex!r}) was not called "
                "during the forward pass"
            )
        return cache[0][..., :-1, :], outputs

    @staticmethod
    def _masked_token_mean(elementwise_loss, mask):
        """Reduce a per-element loss to a scalar over unmasked tokens.

        Averages over the last (hidden) dimension, then over the tokens where
        ``mask`` is True within each sample, then over the batch.

        Args:
            elementwise_loss: Tensor of shape (b, s, d).
            mask: Boolean tensor of shape (b, s); True marks tokens that count.

        Returns:
            Scalar tensor. A sample with no unmasked tokens contributes 0.
        """
        per_token = elementwise_loss.mean(dim=-1) * mask  # (b, s)
        per_sample_sum = per_token.sum(dim=-1)  # (b,)
        # Keep the divisor 1-D: a (b, 1) divisor would broadcast against the
        # (b,) numerator to (b, b). The clamp avoids 0/0 for samples whose
        # labels are all masked.
        num_tokens = mask.sum(dim=-1).clamp(min=1)  # (b,)
        return (per_sample_sum / num_tokens).mean()

    def compute_activation_loss_mse(self, activation1, activation2, mask):
        """Masked mean-squared error between two activation tensors.

        Args:
            activation1, activation2: Tensors of shape (b, s, d).
            mask: Boolean tensor of shape (b, s) selecting the tokens to score.

        Returns:
            Scalar tensor: per-sample mean over selected tokens, averaged over
            the batch.
        """
        squared_diff = torch.nn.functional.mse_loss(
            activation1, activation2, reduction="none"
        )  # (b, s, d)
        return self._masked_token_mean(squared_diff, mask)

    def compute_activation_loss_huber(self, activation1, activation2, mask):
        """Masked Huber loss between two activation tensors.

        Same reduction as ``compute_activation_loss_mse``, using
        ``torch.nn.functional.huber_loss`` with its default ``delta``.
        """
        huber = torch.nn.functional.huber_loss(
            activation1, activation2, reduction="none"
        )  # (b, s, d)
        return self._masked_token_mean(huber, mask)

    def _activation_matching_loss(self, model, inputs, ref_attention_temp=None):
        """Masked MSE between the model's and the reference model's hooked activations.

        Args:
            model: The model being trained.
            inputs (dict): Batch with ``input_ids``, ``attention_mask`` and ``labels``.
            ref_attention_temp (float, optional): Forwarded to the reference
                model's forward pass via ``forward_with_cache``.

        Returns:
            tuple: ``(loss, model_outputs, ref_outputs)``.
        """
        model_activations, model_outputs = self.forward_with_cache(
            model, inputs, module=self.model_module, no_grad=False
        )
        ref_activations, ref_outputs = self.forward_with_cache(
            self.ref_model,
            inputs,
            module=self.ref_module,
            no_grad=True,
            attention_temp=ref_attention_temp,
        )
        # Activations drop the last position, so score against the shifted labels.
        mask = inputs["labels"][..., 1:] != -100  # (b, s - 1)
        loss = self.compute_activation_loss_mse(
            model_activations,
            ref_activations.to(model_activations.device),
            mask,
        )
        return loss, model_outputs, ref_outputs

    def compute_retain_loss(self, model, retain_inputs):
        """Pull retain-split activations toward the reference model's plain activations.

        Returns:
            tuple: ``(retain_loss, model_outputs, ref_outputs)``.
        """
        return self._activation_matching_loss(model, retain_inputs)

    def compute_forget_loss(self, model, forget_inputs):
        """Pull forget-split activations toward the temperature-modified reference activations.

        Returns:
            tuple: ``(forget_loss, model_outputs, ref_outputs)``.
        """
        return self._activation_matching_loss(
            model, forget_inputs, ref_attention_temp=self.attention_temp
        )

    @staticmethod
    def _select_model_inputs(batch):
        """Keep only the keys that are passed to the model's forward."""
        return {key: batch[key] for key in ("input_ids", "attention_mask", "labels")}

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute ``alpha * forget_loss + retain_loss`` and print diagnostics.

        Args:
            model: The model being trained.
            inputs (dict): Must contain ``"forget"`` and ``"retain"`` batches.
            return_outputs (bool): If True, also return the forget-split outputs.
            num_items_in_batch: Accepted and ignored, so callers that pass it
                (e.g. newer Hugging Face Trainer releases) do not fail.

        Returns:
            The loss tensor, or ``(loss, forget_outputs)`` if ``return_outputs``.
        """
        forget_inputs = self._select_model_inputs(inputs["forget"])
        forget_loss, forget_outputs, ref_forget_outputs = self.compute_forget_loss(
            model=model, forget_inputs=forget_inputs
        )

        retain_inputs = self._select_model_inputs(inputs["retain"])
        retain_loss, retain_outputs, ref_retain_outputs = self.compute_retain_loss(
            model=model, retain_inputs=retain_inputs
        )

        loss = self.alpha * forget_loss + retain_loss

        # Diagnostics from the models' own `.loss` outputs. Each `.item()` copies
        # a value to the host, so read every value only once.
        loss_f = forget_outputs.loss.item()
        ref_loss_f = ref_forget_outputs.loss.item()
        loss_r = retain_outputs.loss.item()
        ref_loss_r = ref_retain_outputs.loss.item()
        forget_diff = loss_f - ref_loss_f
        retain_diff = loss_r - ref_loss_r
        sep = " " * 16
        print(
            f"Forget loss: {loss_f}{sep}Ref Forget loss: {ref_loss_f}{sep}"
            f"Forget Diff: {forget_diff}"
        )
        print(
            f"Retain loss: {loss_r}{sep}Ref Retain loss: {ref_loss_r}{sep}"
            f"Retain Diff: {retain_diff}{sep}Diff of Diff: {forget_diff - retain_diff}"
        )
        print(
            f"our forget Loss: {forget_loss.item()}{sep}"
            f"our retain Loss: {retain_loss.item()}{sep}Loss: {loss.item()}"
        )

        return (loss, forget_outputs) if return_outputs else loss
