import torch
import re

import torch.nn.functional as F

from constants import DEVICE_CPU
from enums import LossFunction


def _compute_kl_divergence(q_out, float_logits, batch, pad_id, t=1):
    """Compute the masked, temperature-scaled KL divergence KL(float || quantized).

    Both logit tensors are divided by the temperature ``t`` and converted to
    log-probabilities over the last dimension. The divergence is summed over
    that dimension, positions where ``batch == pad_id`` are zeroed out, and the
    result is averaged over the remaining positions.

    Args:
        q_out: Student (quantized model) logits, shaped [B, T, V].
        float_logits: Teacher (float model) logits, same shape as ``q_out``.
        batch: Token IDs with B * T elements, used only to build the padding mask.
        pad_id: Padding token ID; matching positions do not contribute.
        t: Softmax temperature. The loss is multiplied by ``t ** 2``.

    Returns:
        Scalar tensor. Zero when every position is padding.
    """
    # Get log probabilities (computed in float32 regardless of the input dtype)
    float_logp = F.log_softmax(float_logits.float() / t, dim=-1)
    q_logp = F.log_softmax(q_out.float() / t, dim=-1)

    B, T, V = q_logp.shape

    # F.kl_div(input, target) measures KL(target || input), so the teacher is
    # passed as the target. reshape (not view) tolerates non-contiguous inputs.
    kl_per_token = (t ** 2) * F.kl_div(
        q_logp.reshape(B * T, V),
        float_logp.reshape(B * T, V),
        log_target=True,
        reduction="none",
    ).sum(dim=-1)  # sum over vocab -> [B*T]

    # Mask padding tokens; keep the mask on the same device/dtype as the loss
    mask = (batch != pad_id).reshape(B * T).to(
        device=kl_per_token.device, dtype=kl_per_token.dtype
    )
    kl_masked = kl_per_token * mask

    # Mean over real tokens. The clamp keeps an all-padding batch at 0 instead
    # of producing 0 / 0 = NaN (same guard as _compute_mse).
    return kl_masked.sum() / mask.sum().clamp_min(1.0)


def _compute_mse(q_out, float_logits, batch, pad_id=None):
    """Mean squared error between two sets of model outputs.

    The squared error is taken elementwise in float32 and averaged over the
    last dimension, giving one value per position (for [B, T, V] logits this
    is a [B, T] tensor).

    Without ``pad_id`` the per-position values are simply averaged. With
    ``pad_id`` only positions where ``batch != pad_id`` contribute, and the sum
    is divided by the number of such positions (at least 1).
    """
    # [B, T, V] -> [B, T]: elementwise squared error, averaged over the vocab axis
    per_position = F.mse_loss(
        q_out.float(), float_logits.float(), reduction="none"
    ).mean(dim=-1)

    # No padding information: plain average over every position
    if pad_id is None:
        return per_position.mean()

    # 1.0 for real tokens, 0.0 for padding (batch holds the token IDs)
    keep = (batch != pad_id).to(per_position.dtype)
    real_count = keep.sum().clamp_min(1.0)
    masked_total = (per_position * keep).sum()
    return masked_total / real_count


class OutDistill:
    """Output distillation loss with configurable distance metric.

    Computes loss between quantized and float model outputs using the specified
    distance metric ("kl" or "mse"). The float (teacher) model is moved to the
    batch's device for its forward pass and back to CPU afterwards.
    """

    def __init__(self, model, distance_metric: str, pad_id=None, t=1):
        """
        Args:
            model: Float model for distillation (teacher)
            distance_metric: Distance metric to use ("kl" or "mse")
            pad_id: Padding token ID (required for "kl", optional for "mse")
            t: Softmax temperature forwarded to the KL loss (unused for "mse")

        Raises:
            ValueError: If ``distance_metric`` is "kl" and ``pad_id`` is None.
        """
        self.model = model
        self.distance_metric = distance_metric
        self.pad_id = pad_id
        self.t = t

        # Validate that pad_id is provided for metrics that need it
        if self.distance_metric == "kl" and self.pad_id is None:
            raise ValueError(f"pad_id is required for {distance_metric} metric")

    def __call__(self, q_out, batch):
        """
        Compute output distillation loss.

        Args:
            q_out: Logits from quantized model (tensor)
            batch: Input batch (token IDs for LLM)

        Returns:
            Loss value (scalar tensor)

        Raises:
            ValueError: If the configured distance metric is not supported.
        """
        # Fail fast: reject an unknown metric before paying for a teacher
        # forward pass and two device transfers.
        if self.distance_metric not in ("kl", "mse"):
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")

        device = batch.device
        self.model.to(device)
        self.model.eval()

        # Get teacher (float model) outputs
        try:
            with torch.no_grad():
                float_outputs = self.model(batch)
        finally:
            # Always give the accelerator memory back, even when the teacher's
            # forward pass raises; otherwise the teacher stays resident.
            self.model.to("cpu")
            torch.cuda.empty_cache()

        # Some models return an output object with a ``logits`` attribute,
        # others return the logits tensor directly.
        float_logits = (
            float_outputs.logits
            if hasattr(float_outputs, "logits")
            else float_outputs
        )

        # Compute loss based on distance metric
        if self.distance_metric == "kl":
            return _compute_kl_divergence(
                q_out, float_logits, batch, self.pad_id, t=self.t
            )
        return _compute_mse(q_out, float_logits, batch, self.pad_id)


# Legacy alias for backward compatibility: OutKLLoss is bound to the very same
# class object as OutDistill, so both names construct identical instances.
OutKLLoss = OutDistill


class FlatQDistill:
    """Frobenius norm distillation loss at transformer block outputs.

    Forward hooks on the float model cache each block's output on CPU; forward
    hooks on the quantized model compare its block outputs against that cache
    and accumulate one loss per block. Calling the instance returns the
    equally weighted sum of the per-block losses and resets the state.

    Expected usage per step: ``set_batch(batch)``, ``compute_float_outputs()``,
    run the quantized model, then call the instance.
    """

    def __init__(self, model, distance_metric: str, q_model, pad_id=None):
        """
        Args:
            model: Float model for distillation (teacher)
            distance_metric: Stored but not consulted; the loss is always the
                Frobenius norm
            q_model: Quantized model (student)
            pad_id: Padding token ID (optional, for masking)

        Raises:
            ValueError: If the float model has no module named 'model.layers.<i>'.
        """
        self.model = model
        self.distance_metric = distance_metric
        self.pad_id = pad_id

        # Find all transformer blocks (model.layers.<i>)
        pattern = re.compile(r"^model\.layers\.\d+$")

        self.modules_to_distill = [
            n for n, _ in model.named_modules()
            if pattern.match(n)
        ]

        if not self.modules_to_distill:
            raise ValueError("No transformer blocks found matching pattern 'model.layers.<i>'")

        self.per_module_loss = {}
        self.float_batch_per_module = {}
        self.batch = None
        self._clear_module_losses()

        # Equal weights for all blocks
        n_blocks = len(self.modules_to_distill)
        self.weights = torch.full((n_blocks,), 1.0 / n_blocks)

        # Float output hooks
        self.float_handles = self._register_float_hooks()

        # Loss hooks on quantized model
        self.handles = self._register_loss_hooks(q_model)

    def set_batch(self, batch):
        """Set the current batch for loss computation."""
        self.batch = batch

    def _clear_module_losses(self):
        """Reset accumulated losses, cached float outputs and the current batch."""
        self.per_module_loss = {n: 0.0 for n in self.modules_to_distill}
        self.float_batch_per_module = {n: None for n in self.modules_to_distill}
        self.batch = None

    def _create_float_capture_hook(self, module_name):
        """Create a hook that captures float model outputs and stores them on CPU."""
        def hook(module, input, output):
            # Store first element if tuple (common for transformers)
            captured = output[0] if isinstance(output, tuple) else output
            self.float_batch_per_module[module_name] = captured.detach().cpu()

        return hook

    def _register_float_hooks(self):
        """Register hooks on float model to capture intermediate outputs."""
        print("Registering float output capture hooks for FlatQDistill...")

        handles = []

        for module_name in self.modules_to_distill:
            module = self._get_module_by_name(self.model, module_name)
            if module is None:
                raise ValueError(f"Module '{module_name}' not found in float model")

            hook = self._create_float_capture_hook(module_name)
            handles.append(module.register_forward_hook(hook))

        print(f"Registered {len(handles)} float capture hooks for FlatQDistill")

        return handles

    def compute_float_outputs(self):
        """Run forward pass through float model to populate cached outputs.

        Raises:
            RuntimeError: If no batch has been set via ``set_batch``.
        """
        if self.batch is None:
            raise RuntimeError("set_batch() must be called before compute_float_outputs()")

        device = self.batch.device
        self.model.to(device)
        self.model.eval()

        try:
            with torch.no_grad():
                _ = self.model(self.batch)
        finally:
            # Return the teacher to CPU even if its forward pass raised
            self.model.to("cpu")

    def _get_module_by_name(self, model, module_name):
        """
        Get a module by its name path (e.g., 'model.layers.0').

        Args:
            model: The model
            module_name: Dot-separated path to the module

        Returns:
            The module, or None if not found
        """
        module = model

        try:
            for part in module_name.split('.'):
                module = getattr(module, part)
        except AttributeError:
            return None
        return module

    def _compute_frobenius_norm_loss(self, q_out, float_out, batch, pad_id):
        """
        Compute Frobenius norm distance between quantized and float outputs.

        Args:
            q_out: Quantized model output [B, T, H]
            float_out: Float model output [B, T, H]
            batch: Input batch (token IDs); only read when ``pad_id`` is given
            pad_id: Padding token ID, or None to use every position

        Returns:
            Scalar loss: the Frobenius norm of the (masked) difference divided
            by the number of non-padding positions (B * T without a mask).
        """
        diff = q_out.float() - float_out.float()

        if pad_id is not None:
            # Mask padding tokens: [B, T] -> [B, T, 1]. The mask is moved to
            # diff's device first so the denominator lives there as well.
            mask = (batch != pad_id).unsqueeze(-1).to(device=diff.device, dtype=diff.dtype)
            diff = diff * mask
            denom = mask.sum().clamp_min(1.0)
        else:
            B, T, _ = diff.shape
            denom = B * T

        # ||Q - F||_F = sqrt(sum((Q - F)^2)), normalized by the token count
        frob_squared = (diff ** 2).sum()
        return torch.sqrt(frob_squared) / denom

    def _create_loss_hook(self, module_name):
        """Create a hook that computes Frobenius norm loss."""
        def hook(module, input, q_out):
            # Handle tuple outputs
            q_output = q_out[0] if isinstance(q_out, tuple) else q_out

            cached = self.float_batch_per_module[module_name]
            if cached is None:
                raise RuntimeError(
                    f"No float output cached for '{module_name}'; call set_batch() and "
                    "compute_float_outputs() before running the quantized model"
                )
            if self.pad_id is not None and self.batch is None:
                raise RuntimeError(
                    "set_batch() must be called before running the quantized model "
                    "when pad_id is set"
                )

            # .to() returns a copy on the target device; the cached tensor
            # itself stays on CPU, so nothing has to be moved back afterwards.
            float_output = cached.to(q_output.device)

            loss = self._compute_frobenius_norm_loss(
                q_output, float_output, self.batch, self.pad_id
            )

            self.per_module_loss[module_name] += loss

        return hook

    def _register_loss_hooks(self, q_model):
        """Register loss computation hooks on quantized model."""
        print("Registering Frobenius norm loss hooks on quantized model...")

        handles = []

        for module_name in self.modules_to_distill:
            module = self._get_module_by_name(q_model, module_name)
            if module is None:
                print(f"Warning: Module '{module_name}' not found in quantized model, skipping hook")
                continue

            hook = self._create_loss_hook(module_name)
            handles.append(module.register_forward_hook(hook))

        print(f"Registered {len(handles)} loss hooks for FlatQDistill")
        return handles

    def remove_hooks(self):
        """Remove all hooks."""
        for handle in self.handles:
            handle.remove()
        for handle in self.float_handles:
            handle.remove()
        print("Removed all FlatQDistill hooks")

    def __call__(self, q_out, batch):
        """
        Compute average Frobenius norm loss across all transformer blocks.

        Args:
            q_out: Final output from quantized model (not used, blocks are hooked)
            batch: Input batch; its device is used only when no block produced
                a tensor loss

        Returns:
            Scalar loss averaged across all blocks
        """
        values = list(self.per_module_loss.values())

        # Blocks whose hook never fired (e.g. skipped because the quantized
        # model lacks them) still hold the initial Python float 0.0, which
        # torch.stack cannot handle; promote them to tensors first.
        device = next((v.device for v in values if torch.is_tensor(v)), batch.device)
        losses = torch.stack([
            v.to(device) if torch.is_tensor(v)
            else torch.tensor(v, dtype=torch.float32, device=device)
            for v in values
        ])

        loss = torch.sum(losses * self.weights.to(device))
        self._clear_module_losses()
        return loss


class UnembedDistill:
    """Unembedding distillation loss at the MLP outputs of each transformer block.

    Forward hooks on the float model cache every ``model.layers.<i>.mlp``
    output on CPU. Forward hooks on the quantized model project both the cached
    float output and the quantized output through ``embed_weight`` to
    vocabulary logits and accumulate the configured distance between them.
    Calling the instance returns the equally weighted sum of the per-module
    losses and resets the state.

    Expected usage per step: ``set_batch(batch)``, ``compute_float_outputs()``,
    run the quantized model, then call the instance.
    """

    def __init__(self, model, distance_metric: str, q_model, embed_weight, pad_id=None):
        """
        Args:
            model: Float model for distillation (teacher)
            distance_metric: Distance metric to use ("kl" or "mse")
            q_model: Quantized model (student)
            embed_weight: Weight used to project hidden states to logits via
                'bsh,vh->bsv' (so it is indexed [vocab, hidden]); stored detached
            pad_id: Padding token ID (required for "kl", optional for "mse")

        Raises:
            ValueError: If "kl" is requested without ``pad_id`` or no module
                named 'model.layers.<i>.mlp' exists in the float model.
        """
        self.model = model
        self.distance_metric = distance_metric
        self.pad_id = pad_id
        self.embed_weight = embed_weight.detach()

        # Validate that pad_id is provided for metrics that need it
        if self.distance_metric == "kl" and self.pad_id is None:
            raise ValueError(f"pad_id is required for {distance_metric} metric")

        # TODO: change to argument
        pattern = re.compile(r"^model\.layers\.\d+\.mlp$")  # model.model.layers.<i>.mlp

        self.modules_to_distill = [
            n for n, _ in model.named_modules()
            if pattern.match(n)
        ]

        # Without this check the weight computation below divides by zero
        if not self.modules_to_distill:
            raise ValueError("No modules found matching pattern 'model.layers.<i>.mlp'")

        self.per_module_loss = {}
        self.float_batch_per_module = {}
        self.batch = None
        self._clear_module_losses()

        # TODO: change to argument
        n_modules = len(self.modules_to_distill)
        self.weights = torch.full((n_modules,), 1.0 / n_modules)

        # Float output hooks
        self.float_handles = self._register_float_hooks()

        # Loss hooks on quantized model
        self.handles = self._register_loss_hooks(q_model)

    def set_batch(self, batch):
        """Set the current batch for loss computation."""
        self.batch = batch

    def _clear_module_losses(self):
        """Reset accumulated losses, cached float outputs and the current batch."""
        self.per_module_loss = {n: 0.0 for n in self.modules_to_distill}
        self.float_batch_per_module = {n: None for n in self.modules_to_distill}

        # TODO: save to mask no need for the whole batch
        self.batch = None

    def _create_float_capture_hook(self, module_name):
        """Create a hook that captures float model outputs and stores them on CPU."""
        def hook(module, input, output):
            # Store first element if tuple (common for transformers)
            captured = output[0] if isinstance(output, tuple) else output
            self.float_batch_per_module[module_name] = captured.detach().cpu()

        return hook

    def _register_float_hooks(self):
        """Register hooks on float model to capture intermediate outputs."""
        print("Registering float output capture hooks...")

        handles = []

        for module_name in self.modules_to_distill:
            module = self._get_module_by_name(self.model, module_name)
            if module is None:
                raise ValueError(f"Module '{module_name}' not found in float model")

            hook = self._create_float_capture_hook(module_name)
            handles.append(module.register_forward_hook(hook))

        print(f"Registered {len(handles)} float capture hooks")

        return handles

    def compute_float_outputs(self):
        """Run forward pass through float model to populate cached outputs.

        Raises:
            RuntimeError: If no batch has been set via ``set_batch``.
        """
        if self.batch is None:
            raise RuntimeError("set_batch() must be called before compute_float_outputs()")

        device = self.batch.device
        self.model.to(device)
        self.model.eval()

        try:
            with torch.no_grad():
                _ = self.model(self.batch)
        finally:
            # Return the teacher to CPU even if its forward pass raised
            self.model.to("cpu")

    def _get_module_by_name(self, model, module_name):
        """
        Get a module by its name path (e.g., 'model.layers.0').

        Args:
            model: The model
            module_name: Dot-separated path to the module

        Returns:
            The module, or None if not found
        """
        module = model

        try:
            for part in module_name.split('.'):
                module = getattr(module, part)
        except AttributeError:
            return None
        return module

    def _create_loss_hook(self, module_name):
        """Create a hook that accumulates the logit-space distance for one module."""
        def hook(module, input, q_out):
            if self.distance_metric == "kl":
                loss_fn = _compute_kl_divergence
            elif self.distance_metric == "mse":
                loss_fn = _compute_mse
            else:
                raise ValueError(f"Unsupported distance metric: {self.distance_metric}")

            # Handle tuple outputs
            q_hidden = q_out[0] if isinstance(q_out, tuple) else q_out

            cached = self.float_batch_per_module[module_name]
            if cached is None:
                raise RuntimeError(
                    f"No float output cached for '{module_name}'; call set_batch() and "
                    "compute_float_outputs() before running the quantized model"
                )

            # Align device/dtype with the student activations so einsum does
            # not fail on mixed inputs. The cached tensor itself stays on CPU.
            weight = self.embed_weight.to(device=q_hidden.device, dtype=q_hidden.dtype)
            f_hidden = cached.to(device=q_hidden.device, dtype=q_hidden.dtype)

            # Project hidden states to vocabulary logits: [B, S, H] x [V, H] -> [B, S, V]
            f_logits = torch.einsum('bsh,vh->bsv', f_hidden, weight)
            q_logits = torch.einsum('bsh,vh->bsv', q_hidden, weight)

            # Both metrics are evaluated on the projected logits
            loss = loss_fn(q_logits, f_logits, self.batch, self.pad_id)

            self.per_module_loss[module_name] += loss

        return hook

    def _register_loss_hooks(self, q_model):
        """Register loss computation hooks on the quantized model."""
        print("Warning: registering unembedding distillation hooks on the model.")

        handles = []

        for module_name in self.modules_to_distill:
            module = self._get_module_by_name(q_model, module_name)
            if module is None:
                print(f"Warning: Module '{module_name}' not found, skipping hook")
                continue

            hook = self._create_loss_hook(module_name)
            handles.append(module.register_forward_hook(hook))

        print(f"Registered {len(handles)} forward hooks")
        return handles

    def remove_hooks(self):
        """Remove all hooks."""
        for handle in self.handles:
            handle.remove()
        for handle in self.float_handles:
            handle.remove()
        print("Removed all hooks")

    def __call__(self, q_out, batch):
        """
        Combine the per-module losses accumulated by the hooks.

        Args:
            q_out: Final output from quantized model (not used, modules are hooked)
            batch: Input batch; its device is used only when no module produced
                a tensor loss

        Returns:
            Scalar loss: equally weighted sum of the per-module losses
        """
        values = list(self.per_module_loss.values())

        # Modules whose hook never fired still hold the initial Python float
        # 0.0, which torch.stack cannot handle; promote them to tensors first.
        device = next((v.device for v in values if torch.is_tensor(v)), batch.device)
        losses = torch.stack([
            v.to(device) if torch.is_tensor(v)
            else torch.tensor(v, dtype=torch.float32, device=device)
            for v in values
        ])

        loss = torch.sum(losses * self.weights.to(device))
        self._clear_module_losses()
        return loss


def get_loss_function(loss_function: str, distance_metric: str, model, **kwargs):
    """
    Build the loss object selected by ``loss_function``.

    ``loss_function`` is compared against the ``.value`` of the
    ``LossFunction`` members, in the order output / unembed / flat-q.

    Args:
        loss_function: Identifier of the loss to build
        distance_metric: Distance metric handed to the loss (e.g., "kl", "mse")
        model: Float model for distillation
        **kwargs: Extra constructor arguments for the chosen loss (e.g., pad_id)

    Returns:
        Loss function instance

    Raises:
        ValueError: If ``loss_function`` matches none of the known values.
    """
    # Arguments every loss class receives, regardless of which one is chosen
    shared = {"model": model, "distance_metric": distance_metric}

    if loss_function == LossFunction.OUTPUT_DISTILLATION.value:
        return OutDistill(**shared, **kwargs)

    if loss_function == LossFunction.UNEMBED_DISTILLATION.value:
        return UnembedDistill(**shared, **kwargs)

    if loss_function == LossFunction.FLAT_Q_DISTILLATION.value:
        return FlatQDistill(**shared, **kwargs)

    raise ValueError(f"Unsupported loss function: {loss_function}")
