from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap
from torch.nn import CrossEntropyLoss


class GradBase:
    """Base class for gradient-based unlearning methods."""

    def __init__(
        self,
        lr: float,
        loss: nn.Module,
        optimizer_unlearn: Optional[torch.optim.Optimizer] = None,
        optimizer_retain: Optional[torch.optim.Optimizer] = None,
    ):
        """
        Store the settings shared by the gradient-based unlearning methods.

        Args:
            lr: Step size used when parameters are updated; must be positive
            loss: Loss function evaluated on the model outputs
            optimizer_unlearn: Optimizer associated with the unlearn data (optional)
            optimizer_retain: Optimizer associated with the retain data (optional)
        """
        self.lr, self.loss = lr, loss
        self.optimizer_unlearn, self.optimizer_retain = (
            optimizer_unlearn,
            optimizer_retain,
        )

        # Validated after assignment, exactly as before: the attributes are
        # already set when the assertion fires.
        assert self.lr > 0, "Learning rate must be greater than 0"

    @staticmethod
    def _extract_gradients(params: Iterator) -> List[torch.Tensor]:
        """Extract gradients from model parameters."""
        grads = []
        for param in params:
            grads.append(param.grad.clone() if param.grad is not None else None)
        return grads

    def _compute_per_sample_grads(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        data: torch.Tensor,
        labels: torch.Tensor,
        original_model: Optional[nn.Module] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Compute per-sample gradients using torch.func vmap.

        Args:
            model: The model to compute gradients for
            optimizer: The optimizer to use (optional)
            data: Input data
            labels: Target labels
            original_model: Original model for knowledge distillation (optional)

        Returns:
            Tuple of (per_sample_gradients, per_sample_losses)
        """
        model.zero_grad()
        if optimizer is not None:
            optimizer.zero_grad()

        params = {k: v.detach() for k, v in model.named_parameters() if v.requires_grad}
        buffers = {k: v.detach() for k, v in model.named_buffers() if v.requires_grad}

        if original_model is not None:
            original_params = {
                k: v.detach()
                for k, v in original_model.named_parameters()
                if v.requires_grad
            }
            original_buffers = {
                k: v.detach()
                for k, v in original_model.named_buffers()
                if v.requires_grad
            }

        def compute_loss(
            params: Dict[str, torch.Tensor],
            buffers: Dict[str, torch.Tensor],
            sample: torch.Tensor,
            target: torch.Tensor,
        ) -> torch.Tensor:
            """Compute loss for a single sample."""
            target = target.unsqueeze(0)
            batch = sample.unsqueeze(0)

            predictions = functional_call(model, (params, buffers), (batch,))

            if isinstance(self.loss, CrossEntropyLoss):
                loss = self.loss(predictions, target)
            elif hasattr(self.loss, "__name__") and self.loss.__name__ == "DistillKL":
                assert (
                    original_model is not None
                ), "Original model is required for knowledge distillation"
                with torch.no_grad():
                    target_logits = functional_call(
                        original_model,
                        (original_params, original_buffers),
                        (batch,),
                    )
                loss = self.loss(predictions, target_logits)
            else:
                raise ValueError(f"Unsupported loss function: {self.loss}")

            return loss

        # Define a function to compute both loss and gradients
        def compute_loss_and_grad(
            params: Dict[str, torch.Tensor],
            buffers: Dict[str, torch.Tensor],
            sample: torch.Tensor,
            target: torch.Tensor,
        ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
            """Compute loss and gradients for a single sample."""
            loss = compute_loss(params, buffers, sample, target)
            grads = grad(compute_loss)(params, buffers, sample, target)
            return loss, grads

        # Vectorize the computation over all samples in the batch
        ft_compute_sample_grad = vmap(compute_loss_and_grad, in_dims=(None, None, 0, 0))
        per_sample_loss, per_sample_grads = ft_compute_sample_grad(
            params, buffers, data, labels
        )

        # Convert per_sample_grads from dict to list if needed
        if isinstance(per_sample_grads, dict):
            per_sample_grads = list(per_sample_grads.values())

        return per_sample_grads, per_sample_loss

    def _compute_grads(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        data: torch.Tensor,
        labels: torch.Tensor,
        grad_mode: str,
        original_model: Optional[nn.Module] = None,
    ) -> Tuple[Union[Dict[str, torch.Tensor], List[torch.Tensor]], torch.Tensor]:
        """
        Compute gradients using either mean or per-sample mode.

        Args:
            model: The model to compute gradients for
            optimizer: The optimizer to use
            data: Input data
            labels: Target labels
            grad_mode: "mean" or "per_sample"
            original_model: Original model for knowledge distillation

        Returns:
            Tuple of (gradients, loss)
        """
        if grad_mode == "mean":
            return self._compute_agg_grads_and_loss(
                model, optimizer, data, labels, original_model
            )
        else:  # grad_mode == "per_sample"
            per_sample_grads, per_sample_loss = self._compute_per_sample_grads(
                model, optimizer, data, labels, original_model
            )
            return per_sample_grads, per_sample_loss.mean()

    def _compute_loss_and_grads(
        self,
        model: nn.Module,
        unlearn_data: torch.Tensor,
        unlearn_labels: torch.Tensor,
        retain_data: torch.Tensor,
        retain_labels: torch.Tensor,
        retain_grad_mode: str,
        original_model: Optional[nn.Module] = None,
    ) -> Tuple[
        List[torch.Tensor],
        Union[Dict[str, torch.Tensor], List[torch.Tensor]],
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Compute both unlearn and retain gradients.

        Args:
            model: The model to compute gradients for
            unlearn_data: Data to unlearn
            unlearn_labels: Labels for unlearn data
            retain_data: Data to retain
            retain_labels: Labels for retain data
            retain_grad_mode: "mean" or "per_sample" for retain gradients
            original_model: Original model for knowledge distillation

        Returns:
            Tuple of (unlearn_grads, retain_grads, unlearn_loss, retain_loss)
        """
        unlearn_grads, unlearn_loss = self._compute_grads(
            model,
            self.optimizer_unlearn,
            unlearn_data,
            unlearn_labels,
            grad_mode="mean",
            original_model=original_model,
        )
        retain_grads, retain_loss = self._compute_grads(
            model,
            self.optimizer_retain,
            retain_data,
            retain_labels,
            retain_grad_mode,
            original_model,
        )
        return unlearn_grads, retain_grads, unlearn_loss, retain_loss

    def _compute_agg_grads_and_loss(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        data_input: torch.Tensor,
        labels: torch.Tensor,
        original_model: Optional[nn.Module] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Compute aggregated gradients and loss.

        Args:
            model: The model to compute gradients for
            optimizer: The optimizer to use
            data_input: Input data
            labels: Target labels
            original_model: Original model for knowledge distillation

        Returns:
            Tuple of (gradients, loss)
        """
        model.zero_grad()
        if optimizer is not None:
            optimizer.zero_grad()

        logits = model(data_input)

        if isinstance(self.loss, CrossEntropyLoss):
            loss = self.loss(logits, labels)
        elif hasattr(self.loss, "__name__") and self.loss.__name__ == "DistillKL":
            assert (
                original_model is not None
            ), "Original model is required for knowledge distillation"
            with torch.no_grad():
                target_logits = original_model(data_input)
            loss = self.loss(logits, target_logits)
        else:
            raise ValueError(f"Unsupported loss function: {self.loss}")

        loss.backward()
        grads = self._extract_gradients(model.parameters())

        return grads, loss

    def _update_model_params(
        self, params: Iterator, grads: List[torch.Tensor], mode: str = "decent"
    ) -> None:
        """
        Update model parameters based on gradients.

        Args:
            params: Model parameters to update
            grads: Gradients to apply
            mode: "decent" for gradient descent, "accent" for gradient ascent
        """
        if mode == "decent":
            sign = -1
        else:  # mode == "ascent"
            sign = 1

        for p, g in zip((p for p in params if p.requires_grad), grads):
            if g is not None and not torch.isnan(g).any():
                assert p.shape == g.shape, "Shape mismatch model param and gradient"
                p.data += sign * self.lr * g

    @staticmethod
    def _check_for_nan_gradients(unlearn_grads, retain_grads):
        """Check if there are any NaN values in the gradients."""
        for g_u, g_r in zip(unlearn_grads, retain_grads):
            if torch.isnan(g_u).any() or torch.isnan(g_r).any():
                raise ValueError("The gradients exploded!")

    @abstractmethod
    def __call__(self, model, unlearn_data, unlearn_labels, retain_data, retain_labels):
        """Apply the unlearning method to ``model``; subclasses provide the logic."""
        # The base implementation does nothing and yields None.
        return None


class OrthoGrad(GradBase):
    """
    Orthogonal Gradient-based unlearning method that projects unlearn gradients
    to be orthogonal to retain gradients.
    """

    def __init__(
        self,
        lr: float,
        loss: nn.Module,
        optimizer_unlearn: Optional[torch.optim.Optimizer] = None,
        optimizer_retain: Optional[torch.optim.Optimizer] = None,
        retain_grad_mode: str = "per_sample",
        update_mode: str = "both",
        original_model: Optional[nn.Module] = None,
        grad_mask: Optional[List[torch.Tensor]] = None,
        alpha: float = 0.5,
    ):
        """
        Configure the OrthoGrad unlearning method.

        Args:
            lr: Learning rate for the unlearning process
            loss: Loss function to use
            optimizer_unlearn: Optimizer for unlearning data
            optimizer_retain: Optimizer for retain data
            retain_grad_mode: "mean" or "per_sample" for retain gradients
            update_mode: "both" or "accent". NOTE(review): the value is
                validated and stored, but nothing visible in this file reads
                it afterwards - confirm whether it is still needed.
            original_model: Original model for knowledge distillation
            grad_mask: Per-parameter masks applied to the gradients
            alpha: Weight of the retain gradient in the combined update; the
                projected unlearn gradient is weighted by ``1 - alpha``
        """
        allowed_update_modes = ("both", "accent")
        allowed_retain_modes = ["mean", "per_sample"]

        # Mode checks run before anything is stored on the instance.
        assert (
            update_mode in allowed_update_modes
        ), "update_mode must be 'both' or 'accent'"
        assert (
            retain_grad_mode in allowed_retain_modes
        ), "retain_grad_mode must be 'mean' or 'per_sample'"

        self.alpha = alpha
        self.update_mode = update_mode
        self.retain_grad_mode = retain_grad_mode
        self.original_model = original_model
        self.grad_mask = grad_mask

        # The base class stores lr / loss / optimizers and validates lr.
        super().__init__(lr, loss, optimizer_unlearn, optimizer_retain)

    @staticmethod
    def _project_orthogonal(
        unlearn_grads: torch.Tensor, retain_grads: torch.Tensor
    ) -> torch.Tensor:
        """
        Project unlearn gradients to be orthogonal to retain gradients.

        Args:
            unlearn_grads: Gradients from unlearn data
            retain_grads: Gradients from retain data

        Returns:
            Projected unlearn gradients
        """
        # Start with v_proj as the original vector v
        unlearn_grads_proj = unlearn_grads.clone()

        # Iterate through each vector in retain_grads
        for g_i in retain_grads:
            # Skip if the retain gradient is zero
            if g_i.norm() < 1e-10:
                continue

            # Calculate projection of unlearn_grads onto g_i
            projection = torch.dot(unlearn_grads_proj, g_i) / (g_i.norm() ** 2) * g_i

            # Subtract the projection to make unlearn_grads orthogonal to g_i
            unlearn_grads_proj -= projection

        return unlearn_grads_proj

    def _per_sample_projection(
        self, unlearn_grads: torch.Tensor, retain_grads: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply per-sample projection using QR decomposition.

        Args:
            unlearn_grads: Gradients from unlearn data
            retain_grads: Gradients from retain data

        Returns:
            Orthogonalized unlearn gradients
        """
        orig_shape = unlearn_grads.shape

        # Reshape retain_grads based on retain_grad_mode
        retain_grads = (
            retain_grads.flatten(1)
            if self.retain_grad_mode == "per_sample"
            else retain_grads.flatten().unsqueeze(0)
        )

        # Flatten unlearn_grads
        unlearn_grads = unlearn_grads.flatten()

        # Perform QR decomposition on retain_grads
        q, r = torch.linalg.qr(retain_grads.T)

        # Project unlearn_grads to be orthogonal to retain_grads
        orthogonal_unlearn_grads = self._project_orthogonal(unlearn_grads, q.T)

        # Reshape back to original shape
        return orthogonal_unlearn_grads.reshape(orig_shape)

    def __call__(
        self,
        model: nn.Module,
        unlearn_data: torch.Tensor,
        unlearn_labels: torch.Tensor,
        retain_data: torch.Tensor,
        retain_labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform one orthogonal-gradient unlearning step on ``model`` in place.

        The unlearn gradient of every parameter is projected orthogonally to
        that parameter's retain gradients. With ``alpha == 0`` the model takes
        an ascent step along the projected unlearn gradient only; otherwise it
        takes a descent step along
        ``alpha * retain - (1 - alpha) * projected_unlearn``.

        Args:
            model: The model to unlearn from
            unlearn_data: Data to unlearn
            unlearn_labels: Labels for unlearn data
            retain_data: Data to retain
            retain_labels: Labels for retain data

        Returns:
            Tuple of (unlearn_loss, retain_loss), both evaluated before the
            parameter update
        """
        # NOTE(review): self.update_mode is validated in __init__ but is not
        # consulted anywhere in this method - confirm that this is intended.
        (
            forget_grads,
            keep_grads,
            unlearn_loss,
            retain_loss,
        ) = self._compute_loss_and_grads(
            model,
            unlearn_data,
            unlearn_labels,
            retain_data,
            retain_labels,
            self.retain_grad_mode,
            self.original_model,
        )

        # Masks act in place: unlearn gradients keep the masked-in entries,
        # retain gradients keep the complement.
        if self.grad_mask is not None:
            for forget_g, keep_g, mask in zip(
                forget_grads, keep_grads, self.grad_mask
            ):
                forget_g *= mask
                keep_g *= ~mask

        self._check_for_nan_gradients(forget_grads, keep_grads)

        # Parameter-wise projection of the unlearn gradient.
        projected = [
            self._per_sample_projection(forget_g, keep_g)
            for forget_g, keep_g in zip(forget_grads, keep_grads)
        ]

        if self.alpha == 0:
            # Unlearn-only update: ascend along the projected gradient.
            self._update_model_params(model.parameters(), projected, mode="accent")
            return unlearn_loss, retain_loss

        # Per-sample retain gradients are averaged over the sample dimension.
        per_sample = self.retain_grad_mode == "per_sample"
        keep_mean = [g.mean(dim=0) if per_sample else g for g in keep_grads]

        # Descending along this mix descends on retain and ascends on unlearn.
        step = [
            self.alpha * keep_g - ((1 - self.alpha) * proj_g)
            for keep_g, proj_g in zip(keep_mean, projected)
        ]
        self._update_model_params(model.parameters(), step, mode="decent")

        return unlearn_loss, retain_loss


# Smoke-run of OrthoGrad on a tiny MLP with random data
if __name__ == "__main__":
    import torch.optim as optim

    # Fixed seed so the model weights and the data below are reproducible
    torch.manual_seed(42)

    class SimpleModel(nn.Module):
        """Three-layer MLP: 10 -> 50 -> 20 -> 2, ReLU after the first two."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 50)
            self.fc2 = nn.Linear(50, 20)
            self.fc3 = nn.Linear(20, 2)

        def forward(self, x):
            for hidden_layer in (self.fc1, self.fc2):
                x = torch.relu(hidden_layer(x))
            return self.fc3(x)

    # The model is built first, right after seeding
    model = SimpleModel()
    loss_fn = nn.CrossEntropyLoss()

    batch_size = 4

    def make_batch():
        """Draw random inputs, then random binary labels."""
        inputs = torch.randn(batch_size, 10)
        targets = torch.randint(0, 2, (batch_size,))
        return inputs, targets

    # Unlearn batch is drawn before the retain batch
    unlearn_data, unlearn_labels = make_batch()
    retain_data, retain_labels = make_batch()

    ortho_grad = OrthoGrad(
        lr=0.01, loss=loss_fn, retain_grad_mode="per_sample", alpha=0.5
    )

    # One unlearning step; the returned losses are not inspected further
    unlearn_loss, retain_loss = ortho_grad(
        model, unlearn_data, unlearn_labels, retain_data, retain_labels
    )
