import copy

import deepspeed
from transformers import Trainer

from .losses import get_loss

import torch
import torch.nn as nn
import torch.nn.functional as F

import re
import torch
import deepspeed
from functools import partial

class CustomTrainerForgettingHidden(Trainer):
    """HF ``Trainer`` that unlearns by matching the hidden output of one submodule.

    Both loss terms are masked MSEs between activations captured at the output of
    the single submodule selected by ``module_regex``:

    * forget term: trained model vs. the frozen reference model on the forget
      batch, where the reference forward pass is run with ``attention_temp``
      (when truthy);
    * retain term: trained model vs. the frozen reference model on the retain
      batch, with a normal reference forward pass.

    ``loss = forget_coeff * forget_term + regularization_coeff * retain_term``.

    Only parameters whose names fully match one of ``trainable_params_regex`` are
    given to the optimizer (see ``create_optimizer``).

    Keyword arguments consumed here (everything else goes to ``Trainer``):
    ``loss_type``, ``ref_model``, ``trainable_params_regex`` (optional),
    ``module_regex`` (optional), ``forget_coeff``, ``regularization_coeff``,
    ``beta``, ``attention_temp``.
    """

    def __init__(self, *args, **kwargs):
        self.loss_type = kwargs.pop('loss_type')
        self.ref_model = kwargs.pop('ref_model')

        # Example values:
        #   trainable_params_regex=[r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"]
        #   module_regex=r"model\.layers\.7"
        # Each option is popped independently. A bare `except` used to reset BOTH
        # to defaults when EITHER was missing, silently discarding the other.
        # The old default pattern '*' is not a valid regex (re.fullmatch raises
        # "nothing to repeat"); r".*" matches every parameter name.
        trainable_params_regex = kwargs.pop('trainable_params_regex', None)
        if trainable_params_regex is None:
            trainable_params_regex = [r".*"]
        elif isinstance(trainable_params_regex, str):
            # A bare string would otherwise be iterated character by character.
            trainable_params_regex = [trainable_params_regex]
        self.trainable_params_regex = list(trainable_params_regex)

        module_regex = kwargs.pop('module_regex', None)
        self.module_regex = r"model\.layers\.31" if module_regex is None else module_regex

        # The coefficient of each part in the loss function (used in ablation studies).
        self.forget_coeff = kwargs.pop('forget_coeff')
        self.regularization_coeff = kwargs.pop('regularization_coeff')
        # beta for NPO/DPO/RS (stored; not used by the losses in this class).
        self.beta = kwargs.pop('beta')
        self.attention_temp = kwargs.pop('attention_temp')

        super(CustomTrainerForgettingHidden, self).__init__(*args, **kwargs)

        self.ref_model = self.e_prepare_deepspeed(self.ref_model)

        # Submodules whose outputs are captured by forward hooks.
        self.model_module = self._get_matching_module(self.model, self.module_regex)
        self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)

    def create_optimizer(self):
        """Build the optimizer over only the parameters matching ``trainable_params_regex``.

        The parent implementation keeps only parameters with ``requires_grad=True``.
        So every parameter is frozen, the matching ones are unfrozen, and the
        parent optimizer is built. Afterwards ``requires_grad=True`` is set on
        every parameter again, as in the original code. Non-matching parameters
        therefore still get gradients but are never stepped.

        Returns:
            The optimizer produced by ``Trainer.create_optimizer``. Returning it
            matters because the Transformers DeepSpeed integration uses the value
            returned by ``trainer.create_optimizer()``.

        Raises:
            ValueError: if no parameter name matches ``trainable_params_regex``.
        """
        self._freeze_all_params(self.model, False)
        try:
            matched = self._set_trainable_params(self.model, self.trainable_params_regex, True)
            if not matched:
                raise ValueError(
                    f"No parameter name matched trainable_params_regex={self.trainable_params_regex}"
                )
            optimizer = super().create_optimizer()
        finally:
            # Always restore, even if optimizer creation fails.
            # NOTE(review): this sets requires_grad=True on ALL parameters, not the
            # pre-call state; kept from the original — confirm it is intended.
            self._freeze_all_params(self.model, True)
        return optimizer

    def _get_matching_module(self, model, module_regex):
        """Return the one submodule whose qualified name fully matches ``module_regex``.

        DeepSpeed engines and ``DistributedDataParallel`` wrappers are unwrapped
        first, so names are relative to the underlying model.

        Raises:
            ValueError: if no module, or more than one module, matches.
        """
        if isinstance(model, (deepspeed.DeepSpeedEngine, torch.nn.parallel.DistributedDataParallel)):
            model = model.module

        pattern = re.compile(module_regex)
        matched_modules = {
            name: module
            for name, module in model.named_modules()
            if pattern.fullmatch(name)
        }

        if len(matched_modules) > 1:
            raise ValueError(
                f"More than one module matched with {module_regex}: {list(matched_modules.keys())}"
            )
        elif not matched_modules:
            raise ValueError(f"No module matched with {module_regex}")

        return next(iter(matched_modules.values()))

    def _freeze_all_params(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter (``False`` freezes, ``True`` unfreezes)."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
        """Set ``requires_grad`` on parameters whose full name matches any pattern.

        Each matching name is printed.

        Returns:
            List of the matching parameter names.
        """
        patterns = [re.compile(p) for p in trainable_params_regex]
        matched = []
        for name, param in model.named_parameters():
            if any(p.fullmatch(name) for p in patterns):
                param.requires_grad = requires_grad
                matched.append(name)
                print(name)
        return matched

    def _num_layers(self, model):
        """Return the layer count used to build ``layers_id``.

        Uses ``self.layer_num`` when it has been set (this class never sets it).
        Otherwise uses ``config.num_hidden_layers`` of the model, after
        unwrapping a DeepSpeed engine.
        """
        layer_num = getattr(self, "layer_num", None)
        if layer_num is not None:
            return layer_num
        base = model.module if isinstance(model, deepspeed.DeepSpeedEngine) else model
        num_layers = getattr(getattr(base, "config", None), "num_hidden_layers", None)
        if num_layers is None:
            raise AttributeError(
                "attention_temp requires `layer_num` on the trainer or a model "
                "config with `num_hidden_layers`."
            )
        return num_layers

    def forward_with_cache(self, model, inputs, module, no_grad=True, attention_temp=None):
        """Run ``model(**inputs)`` and capture the output of ``module`` via a forward hook.

        Args:
            model: model (possibly a DeepSpeed engine) to run.
            inputs: dict of keyword arguments for the forward call.
            module: submodule of ``model`` to capture. For tuple outputs, only
                element 0 is kept.
            no_grad: if True, autograd is disabled for the forward pass.
            attention_temp: if truthy, passed to the model together with
                ``layers_id`` listing every layer index.
                NOTE(review): this assumes a customised model whose forward
                accepts these kwargs — confirm.

        Returns:
            ``(activations, outputs)``. ``activations`` is the captured output
            with the last sequence position dropped (``[..., :-1, :]``), aligned
            with ``labels[..., 1:]``. ``outputs`` is the model's forward output.

        Raises:
            RuntimeError: if ``module`` was not called during the forward pass.
        """
        captured = []

        def hook(_module, _input, output):
            captured.append(output[0] if isinstance(output, tuple) else output)

        hook_handle = module.register_forward_hook(hook)
        # try/finally: if the forward raises, a leaked hook would keep appending
        # (and holding on to) activations on every later call.
        try:
            with torch.set_grad_enabled(not no_grad):
                if attention_temp:
                    outputs = model(
                        **inputs,
                        attention_temp=attention_temp,
                        layers_id=list(range(self._num_layers(model))),
                    )
                else:
                    outputs = model(**inputs)
        finally:
            hook_handle.remove()

        if not captured:
            raise RuntimeError("The hooked module was not called during the forward pass.")
        return captured[0][..., :-1, :], outputs

    @staticmethod
    def _masked_token_mean(elementwise_loss, mask):
        """Reduce a (b, s, d) element-wise loss to a scalar, ignoring masked positions.

        Steps:
            1. Average over the last dim.
            2. Zero out positions where ``mask`` is False/0 and sum over the sequence.
            3. Divide by each sequence's count of unmasked positions.
            4. Average over the batch.

        A sequence with no unmasked positions contributes 0 instead of NaN.
        """
        per_token = elementwise_loss.mean(dim=-1) * mask  # (b, s)
        # Numerator and denominator are both (b,). The old code divided (b,) by
        # (b, 1), which broadcast to (b, b) and mixed sequences when b > 1.
        num_tokens = mask.sum(dim=-1).clamp(min=1)  # (b,)
        return (per_token.sum(dim=-1) / num_tokens).mean()

    def compute_activation_loss_mse(self, activation1, activation2, mask):
        """Masked MSE between (b, s, d) activations with a (b, s) mask; returns a scalar."""
        squared_diff = torch.nn.functional.mse_loss(
            activation1, activation2, reduction="none"
        )  # (b, s, d)
        return self._masked_token_mean(squared_diff, mask)

    def compute_activation_loss_huber(self, activation1, activation2, mask, delta=1.0):
        """Masked Huber loss between (b, s, d) activations with a (b, s) mask; returns a scalar.

        ``delta`` is passed to ``torch.nn.functional.huber_loss``. Its default of
        1.0 matches that function's default.
        """
        elementwise = torch.nn.functional.huber_loss(
            activation1, activation2, reduction="none", delta=delta
        )  # (b, s, d)
        return self._masked_token_mean(elementwise, mask)

    def compute_retain_loss(self, model, retain_inputs):
        """Return (retain activation MSE, trained-model outputs, reference outputs) on the retain batch."""
        model_retain_activations, model_retain_outputs = self.forward_with_cache(
            model, retain_inputs, module=self.model_module, no_grad=False
        )
        ref_retain_activations, ref_retain_outputs = self.forward_with_cache(
            self.ref_model, retain_inputs, module=self.ref_module, no_grad=True
        )
        # Positions whose next-token label is not the ignore index -100.
        mask = retain_inputs["labels"][..., 1:] != -100  # (b, s-1)
        retain_loss = self.compute_activation_loss_mse(
            model_retain_activations,
            ref_retain_activations.to(model_retain_activations.device),
            mask,
        )
        return retain_loss, model_retain_outputs, ref_retain_outputs

    def compute_forget_loss(self, model, forget_inputs):
        """Return (forget activation MSE, trained-model outputs, reference outputs) on the forget batch.

        The reference pass is run with ``self.attention_temp``.
        """
        model_forget_activations, model_forget_outputs = self.forward_with_cache(
            model, forget_inputs, module=self.model_module, no_grad=False
        )
        ref_forget_activations_temp, ref_forget_outputs_temp = self.forward_with_cache(
            self.ref_model, forget_inputs, module=self.ref_module, no_grad=True,
            attention_temp=self.attention_temp,
        )
        mask = forget_inputs["labels"][..., 1:] != -100  # (b, s-1)
        forget_loss = self.compute_activation_loss_mse(
            model_forget_activations,
            ref_forget_activations_temp.to(model_forget_activations.device),
            mask,
        )
        return forget_loss, model_forget_outputs, ref_forget_outputs_temp

    @staticmethod
    def _to_model_inputs(batch):
        """Turn an ``(input_ids, labels, attention_mask)`` triple into model keyword arguments."""
        input_ids, labels, attention_mask = batch
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted forget + retain activation loss and print diagnostics.

        Args:
            model: the model being trained.
            inputs: ``(forget_batch, retain_batch)``. Each batch is an
                ``(input_ids, labels, attention_mask)`` triple.
            return_outputs: if True, return ``(loss, forget_outputs)``.
            num_items_in_batch: accepted for compatibility with newer
                ``Trainer`` versions that pass it; unused.
        """
        forget_inputs = self._to_model_inputs(inputs[0])
        forget_loss, forget_outputs, ref_forget_outputs = self.compute_forget_loss(
            model=model, forget_inputs=forget_inputs
        )

        retain_inputs = self._to_model_inputs(inputs[1])
        retain_loss, retain_outputs, ref_retain_outputs = self.compute_retain_loss(
            model=model, retain_inputs=retain_inputs
        )

        loss = self.forget_coeff * forget_loss + self.regularization_coeff * retain_loss

        # LM losses reported by the model outputs (diagnostics only, not optimised).
        loss_f = forget_outputs.loss.item()
        ref_loss_f = ref_forget_outputs.loss.item()
        loss_r = retain_outputs.loss.item()
        ref_loss_r = ref_retain_outputs.loss.item()

        print(f"Forget loss: {loss_f}                Ref Forget loss: {ref_loss_f}                Forget Diff: {loss_f-ref_loss_f}")
        print(f"Retain loss: {loss_r}                Ref Retain loss: {ref_loss_r}                Retain Diff: {loss_r-ref_loss_r}                Diff of Diff: {loss_f-ref_loss_f-(loss_r-ref_loss_r)}")
        print(f"our forget Loss: {forget_loss.item()}                our retain Loss: {retain_loss.item()}                Loss: {loss.item()}")

        return (loss, forget_outputs) if return_outputs else loss

    def e_prepare_deepspeed(self, model):
        """Wrap the frozen reference model in its own DeepSpeed engine for inference.

        The trainer's DeepSpeed config is copied. ZeRO-3 is kept (the reference
        model is sharded as well); any other stage is replaced by stage 0. The
        optimizer is disabled. The returned engine is in eval mode, and every
        parameter has ``requires_grad=False``.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = copy.deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        config_kwargs["optimizer"] = {"type": None}
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        # The reference model is never trained.
        for param in model.parameters():
            param.requires_grad = False

        return model
