"""Borrowed implementation from https://github.com/centerforaisafety/wmdp/blob/main/rmu/unlearn.py"""

import re
import torch
import deepspeed
from functools import partial
from trainer.unlearn.grad_diff import GradDiff


class ATTU(GradDiff):
    """Unlearning trainer that matches intermediate activations against a reference model.

    Activations are captured from one module of the trained model (selected by
    ``module_regex``) and from the same module of ``self.ref_model``. With
    ``retain_loss_type == "EMBED_DIFF"`` the retain loss is a masked MSE and the
    forget loss a masked Huber loss between those activations; otherwise both
    fall back to the parent's ``compute_retain_loss``.
    """

    def __init__(
        self,
        module_regex=r"model\.layers\.7",
        trainable_params_regex=(r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight",),
        attention_temp=2.0,
        *args,
        **kwargs,
    ):
        """
        Build the trainer and resolve the modules whose activations are compared.

        Args:
            module_regex (str): Regex, applied with ``re.fullmatch`` to module
                names, selecting the module whose output is cached.
            trainable_params_regex (str or iterable of str): Regex pattern(s),
                applied with ``re.fullmatch`` to parameter names, selecting the
                parameters that are trainable while the optimizer is created.
            attention_temp (float): Passed as the ``attention_temp`` keyword to
                the reference model in one of the forget-set forward passes.
            *args, **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Create reference model if not already set
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

        # Accept a single pattern as well as a collection, and store a private
        # copy so the instance never aliases the caller's (or the default) object.
        if isinstance(trainable_params_regex, str):
            trainable_params_regex = [trainable_params_regex]
        self.trainable_params_regex = list(trainable_params_regex)

        # Resolve the actual module objects once, for both models.
        self.module_regex = module_regex
        self.model_module = self._get_matching_module(self.model, self.module_regex)
        self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)
        self.attention_temp = attention_temp

    def create_optimizer(self):
        """Create the optimizer while only the regex-selected parameters require grad.

        All parameters are first marked ``requires_grad=False`` and the ones
        matching ``self.trainable_params_regex`` are re-enabled, so that the
        parent's optimizer construction (presumably filtering on
        ``requires_grad`` - confirm against the parent class) only picks those
        up. Afterwards ``requires_grad`` is set to True on *every* parameter,
        also when the parent raises.

        Returns:
            Whatever the parent's ``create_optimizer`` returns.
        """
        self._freeze_all_params(self.model, False)
        self._set_trainable_params(self.model, self.trainable_params_regex, True)
        try:
            optimizer = super().create_optimizer()
        finally:
            self._freeze_all_params(self.model, True)
        return optimizer

    def _get_matching_module(self, model, module_regex):
        """Return the first module whose qualified name fully matches ``module_regex``.

        A DeepSpeed engine is unwrapped to its inner ``.module`` first. If the
        regex matches several modules, the first one in ``named_modules()``
        order is returned and the rest are ignored.

        Raises:
            ValueError: If no module name matches.
        """
        if isinstance(model, deepspeed.DeepSpeedEngine):
            model = model.module  # Extract the actual PyTorch model inside

        first_match = next(
            (
                submodule
                for name, submodule in model.named_modules()
                if re.fullmatch(module_regex, name)
            ),
            None,
        )
        if first_match is None:
            raise ValueError(f"No module matched with {module_regex}")
        return first_match

    def _freeze_all_params(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``.

        Despite the name, the default (``True``) makes all parameters
        trainable; pass ``False`` to freeze them.
        """
        for param in model.parameters():
            param.requires_grad = requires_grad

    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
        """Set ``requires_grad`` on parameters whose name fully matches any pattern."""
        for name, param in model.named_parameters():
            if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
                param.requires_grad = requires_grad

    @staticmethod
    def _call_model(model, inputs, no_grad, attention_temp):
        """Call ``model(**inputs)``; pass ``attention_temp`` only when it is truthy."""
        with torch.set_grad_enabled(not no_grad):
            if attention_temp:
                return model(**inputs, attention_temp=attention_temp)
            return model(**inputs)

    def forward_with_cache(self, model, inputs, module, no_grad=True, attention_temp=None):
        """Run a forward pass and capture the output of ``module``.

        Args:
            model: Callable invoked as ``model(**inputs)``.
            inputs (dict): Keyword arguments for the model.
            module: Module on which a forward hook is registered.
            no_grad (bool): When True the pass runs with gradients disabled.
            attention_temp: If truthy, forwarded to the model as the
                ``attention_temp`` keyword; otherwise the keyword is omitted.

        Returns:
            tuple: ``(activation, outputs)``. ``activation`` is the output of
            the first call to ``module`` during the pass (element 0 when the
            module returns a tuple); ``outputs`` is the model's return value.

        Raises:
            RuntimeError: If ``module`` was never executed during the pass.
        """
        captured = []

        def capture(_module, _inputs, output):
            # Nothing is returned, so the module's output is left untouched.
            captured.append(output[0] if isinstance(output, tuple) else output)

        handle = module.register_forward_hook(capture)
        try:
            outputs = self._call_model(model, inputs, no_grad, attention_temp)
        finally:
            # Always detach the hook, even if the forward pass raised;
            # otherwise it would keep firing on later passes.
            handle.remove()

        if not captured:
            raise RuntimeError(
                "The hooked module was not executed during the forward pass; "
                "no activation was captured."
            )
        return captured[0], outputs

    def forward_with_caches(self, model, inputs, modules_dict, no_grad=True, attention_temp=None):
        """Run a forward pass and capture the outputs of several modules.

        Args:
            model: Callable invoked as ``model(**inputs)``.
            inputs (dict): Keyword arguments for the model.
            modules_dict (dict): Maps a name to the module to hook.
            no_grad (bool): When True the pass runs with gradients disabled.
            attention_temp: If truthy, forwarded to the model as the
                ``attention_temp`` keyword; otherwise the keyword is omitted.

        Returns:
            tuple: ``(cache, outputs)`` where ``cache`` maps each name to the
            output of its module (element 0 for tuple outputs). If a module
            runs more than once, the last output wins. Modules that never ran
            have no entry.
        """
        captured = {}

        def capture(name, _module, _inputs, output):
            captured[name] = output[0] if isinstance(output, tuple) else output

        handles = []
        try:
            for name, module in modules_dict.items():
                handles.append(module.register_forward_hook(partial(capture, name)))
            outputs = self._call_model(model, inputs, no_grad, attention_temp)
        finally:
            # Remove every hook that was registered, even on failure.
            for handle in handles:
                handle.remove()

        return captured, outputs

    @staticmethod
    def _masked_token_mean(elementwise_loss, mask):
        """Reduce an element-wise loss to a scalar: mean over the last dim,
        masked mean over tokens, then mean over the batch.

        Assumes ``elementwise_loss`` is ``[b, s, d]`` and ``mask`` is ``[b, s]``.
        """
        per_sample_sum = (elementwise_loss * mask.unsqueeze(-1)).mean(dim=-1).sum(dim=-1)  # [b]
        # Keep both operands at shape [b]: dividing [b] by a keepdim'd [b, 1]
        # would broadcast to [b, b] and mix samples with different token counts.
        # The clamp makes a sample without any unmasked token contribute 0
        # instead of NaN.
        num_tokens = mask.sum(dim=-1).clamp(min=1)  # [b]
        return (per_sample_sum / num_tokens).mean()

    def compute_activation_loss_mse(self, activation1, activation2, mask):
        """Masked mean-squared error between two activation tensors.

        Args:
            activation1, activation2: Tensors of identical shape, assumed
                ``[b, s, d]``.
            mask: ``[b, s]`` tensor; tokens where it is False/0 are ignored.

        Returns:
            Scalar tensor: squared error averaged over the feature dim, then
            over the unmasked tokens of each sample, then over the batch.
        """
        squared_diff = torch.nn.functional.mse_loss(
            activation1, activation2, reduction="none"
        )
        return self._masked_token_mean(squared_diff, mask)

    def match_mean_and_variance(self, h1, h2, eps=1e-5):
        """Rescale ``h1`` so that, along the last dim, it has ``h2``'s mean and std.

        ``eps`` is added to ``h1``'s std to avoid dividing by zero.
        """
        mean1 = h1.mean(dim=-1, keepdim=True)
        std1 = h1.std(dim=-1, keepdim=True) + eps

        mean2 = h2.mean(dim=-1, keepdim=True)
        std2 = h2.std(dim=-1, keepdim=True)

        # Standardize h1, then re-express it in h2's statistics.
        h1_norm = (h1 - mean1) / std1
        return h1_norm * std2 + mean2

    def match_magnitude(self, h1, h2, eps=1e-8):
        """Rescale ``h1`` so that its L2 norm along the last dim equals ``h2``'s.

        ``eps`` is added to ``h1``'s norm to avoid dividing by zero.
        """
        norm1 = h1.norm(dim=-1, keepdim=True) + eps
        norm2 = h2.norm(dim=-1, keepdim=True)
        return h1 / norm1 * norm2

    def compute_activation_loss_huber(self, activation1, activation2, mask, activation3=None):
        """Masked Huber loss between two activation tensors.

        Args:
            activation1: Activations being trained, assumed ``[b, s, d]``.
            activation2: Target activations of the same shape.
            mask: ``[b, s]`` tensor; tokens where it is False/0 are ignored.
            activation3: Optional tensor. When given, ``activation2`` is first
                rescaled to ``activation3``'s per-token mean and std (see
                ``match_mean_and_variance``).

        Returns:
            Scalar tensor, reduced the same way as ``compute_activation_loss_mse``.
        """
        # ``is not None`` rather than truthiness: bool() of a multi-element
        # tensor raises.
        if activation3 is not None:
            activation2 = self.match_mean_and_variance(activation2, activation3)

        elementwise = torch.nn.functional.huber_loss(
            activation1, activation2, reduction="none"
        )
        return self._masked_token_mean(elementwise, mask)

    def compute_retain_loss(self, model, retain_inputs):
        """Compute the retain loss.

        Returns:
            tuple: ``(retain_loss, model_outputs, ref_outputs)``. For
            ``retain_loss_type == "EMBED_DIFF"`` the outputs are the return
            values of the trained and the reference model on ``retain_inputs``.
            For any other loss type the loss comes from the parent class and
            both outputs are ``None``.
        """
        if self.retain_loss_type != "EMBED_DIFF":
            return super().compute_retain_loss(model, retain_inputs), None, None

        model_retain_activations, model_retain_outputs = self.forward_with_cache(
            model, retain_inputs, module=self.model_module, no_grad=False
        )
        ref_retain_activations, ref_retain_outputs = self.forward_with_cache(
            self.ref_model, retain_inputs, module=self.ref_module, no_grad=True
        )
        mask = retain_inputs["labels"] != -100  # Shape: [b, s]

        if isinstance(model_retain_activations, dict):
            # Only reached when the captured activation is itself a dict of
            # tensors; average the per-entry losses.
            losses = [
                self.compute_activation_loss_mse(
                    activation,
                    ref_retain_activations[name].to(activation.device),
                    mask,
                )
                for name, activation in model_retain_activations.items()
            ]
            retain_loss = torch.stack(losses).mean()
        else:
            retain_loss = self.compute_activation_loss_mse(
                model_retain_activations,
                ref_retain_activations.to(model_retain_activations.device),
                mask,
            )

        return retain_loss, model_retain_outputs, ref_retain_outputs

    def compute_forget_loss(self, model, forget_inputs):
        """Compute the forget loss.

        Returns:
            tuple: ``(forget_loss, model_outputs, ref_outputs)``. For
            ``retain_loss_type == "EMBED_DIFF"`` the outputs are the return
            values of the trained model and of the reference model run
            *without* ``attention_temp``. For any other loss type the loss is
            the negated parent retain loss and both outputs are ``None``.
        """
        if self.retain_loss_type != "EMBED_DIFF":
            return -super().compute_retain_loss(model, forget_inputs), None, None

        model_forget_activations, model_forget_outputs = self.forward_with_cache(
            model, forget_inputs, module=self.model_module, no_grad=False
        )
        ref_forget_activations_temp, _ = self.forward_with_cache(
            self.ref_model,
            forget_inputs,
            module=self.ref_module,
            no_grad=True,
            attention_temp=self.attention_temp,
        )
        ref_forget_activations, ref_forget_outputs = self.forward_with_cache(
            self.ref_model, forget_inputs, module=self.ref_module, no_grad=True
        )
        mask = forget_inputs["labels"] != -100  # Shape: [b, s]

        if isinstance(model_forget_activations, dict):
            # Target: attention_temp activations rescaled to the plain
            # reference statistics. mask and activation3 are passed in the
            # order the loss signature expects.
            losses = [
                self.compute_activation_loss_huber(
                    activation,
                    ref_forget_activations_temp[name].to(activation.device),
                    mask,
                    activation3=ref_forget_activations[name].to(activation.device),
                )
                for name, activation in model_forget_activations.items()
            ]
            forget_loss = torch.stack(losses).mean()
        else:
            # NOTE(review): this branch targets the *plain* reference
            # activations, so ``ref_forget_activations_temp`` (and therefore
            # ``attention_temp``) has no effect on the loss. The dict branch
            # above suggests the intended target is the attention_temp
            # activations - confirm with the author before changing.
            forget_loss = self.compute_activation_loss_huber(
                model_forget_activations,
                ref_forget_activations.to(model_forget_activations.device),
                mask,
            )

        return forget_loss, model_forget_outputs, ref_forget_outputs

    def compute_loss(self, model, inputs, return_outputs=False):
        """Combine forget and retain losses as ``gamma * forget + alpha * retain``.

        Args:
            model: The model being trained.
            inputs (dict): Has ``"forget"`` and ``"retain"`` entries, each
                providing ``input_ids``, ``attention_mask`` and ``labels``.
            return_outputs (bool): When True, also return the trained model's
                outputs on the forget batch.

        Returns:
            The loss, or ``(loss, forget_outputs)`` when ``return_outputs`` is True.
        """
        forget_inputs = inputs["forget"]
        forget_inputs = {
            "input_ids": forget_inputs["input_ids"],
            "attention_mask": forget_inputs["attention_mask"],
            "labels": forget_inputs["labels"],
        }
        forget_loss, forget_outputs, ref_forget_outputs = self.compute_forget_loss(
            model=model, forget_inputs=forget_inputs
        )

        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_loss, retain_outputs, ref_retain_outputs = self.compute_retain_loss(
            model=model, retain_inputs=retain_inputs
        )

        # Must be computed before the diagnostics below, which print it.
        loss = self.gamma * forget_loss + self.alpha * retain_loss

        # Language-model loss diagnostics; skipped when any of the four
        # outputs is missing or carries no loss.
        all_outputs = (forget_outputs, ref_forget_outputs, retain_outputs, ref_retain_outputs)
        if all(getattr(output, "loss", None) is not None for output in all_outputs):
            loss_f, ref_loss_f, loss_r, ref_loss_r = (
                output.loss.item() for output in all_outputs
            )
            forget_diff = loss_f - ref_loss_f
            retain_diff = loss_r - ref_loss_r
            print(f"Forget loss: {loss_f}                Ref Forget loss: {ref_loss_f}                Forget Diff: {forget_diff}")
            print(f"Retain loss: {loss_r}                Ref Retain loss: {ref_loss_r}                Retain Diff: {retain_diff}                Diff of Diff: {forget_diff - retain_diff}")
        print(f"our forget Loss: {forget_loss.item()}                our retain Loss: {retain_loss.item()}                Loss: {loss.item()}")

        return (loss, forget_outputs) if return_outputs else loss
