import torch
from torch import nn
import torch.distributed as dist
from typing import Iterable, Dict, Any
from torch.optim import Optimizer

class DPOPG_Optimizer(Optimizer):
    """
    Implements the DPO-Projected Gradient (DPO-PG) Algorithm (See Appendix H).

    This class is not a standalone optimizer but rather a gradient processor.
    It should be used *before* an actual optimizer's `step()` method (e.g., AdamW).
    It first accumulates gradients derived from chosen samples and
    rejected samples separately. Then, in the `set_gradients` method,
    it computes the final gradients to be used by the main optimizer.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
            These are the parameters whose gradients will be processed.
        dtype (`torch.dtype`, optional, defaults to `torch.bfloat16`):
            The data type used to store the accumulated chosen/rejected gradients.
            Global statistics (dot products, norms) are always summed in float32,
            independently of this setting.
        fsdp_enabled (`bool`, optional, defaults to `False`):
            Whether Fully Sharded Data Parallel (FSDP) is being used. If `True`,
            gradient statistics (dot products, norms) will be aggregated across
            all distributed processes using `torch.distributed.all_reduce`.
        main_device (`torch.device`, optional):
            Device on which the scalar statistics are accumulated (and all-reduced
            when `fsdp_enabled`). Defaults to the current CUDA device
            (`torch.cuda.current_device()`) when CUDA is available, else CPU.
        eps (`float`, optional, defaults to `1e-9`):
            The projection is only applied when the squared L2 norm of the
            rejected gradient exceeds this value, avoiding division by ~zero.
    """

    # State keys under which the accumulated gradients are stored per parameter.
    _CHOSEN_KEY = "chosen_grad"
    _REJECTED_KEY = "rejected_grad"

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        dtype: torch.dtype = torch.bfloat16,
        fsdp_enabled: bool = False,
        main_device: torch.device = None,
        eps: float = 1e-9,
    ):
        self.fsdp_enabled = fsdp_enabled
        if main_device is None:
            # Use this process's current CUDA device instead of a hard-coded
            # cuda:0, so that in one-process-per-GPU setups the tensors passed
            # to all_reduce live on the rank's own device.
            main_device = (
                torch.device("cuda", torch.cuda.current_device())
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.main_device = main_device
        self.dtype = dtype
        self.eps = eps
        super().__init__(params, dict())

    def _iter_params(self) -> Iterable[nn.parameter.Parameter]:
        """Yield every parameter of every parameter group."""
        for group in self.param_groups:
            yield from group["params"]

    def _new_accumulator(self) -> torch.Tensor:
        """Return a float32 scalar zero on `main_device` for summing statistics."""
        return torch.zeros((), dtype=torch.float32, device=self.main_device)

    @torch.no_grad()
    def _accumulate_grad(self, key: str) -> None:
        """
        Add a copy of each parameter's current `p.grad` into `self.state[p][key]`.

        Parameters with `p.grad is None` are skipped. The copy is detached from
        autograd and stored in `self.dtype`.
        """
        for p in self._iter_params():
            if p.grad is None:
                continue
            # Single copy: detach, convert to the storage dtype, and force a new
            # tensor so later in-place changes to p.grad cannot alias the buffer.
            grad_copy = p.grad.detach().to(self.dtype, copy=True)
            state = self.state[p]
            if state.get(key) is None:
                state[key] = grad_copy
            else:
                state[key].add_(grad_copy)

    @torch.no_grad()
    def update_chosen_grad(self):
        """
        Extracts and accumulates the model's gradients after back-propagating
        the loss for chosen samples.

        These gradients are stored internally in `self.state[p]["chosen_grad"]`
        for each parameter `p`. If called multiple times (e.g., gradient accumulation),
        new gradients are added to existing ones. `p.grad` is copied as-is, so it
        must hold only the chosen-sample gradient (clear the model's gradients
        before the chosen and rejected backward passes).
        """
        self._accumulate_grad(self._CHOSEN_KEY)

    @torch.no_grad()
    def update_rejected_grad(self):
        """
        Extracts and accumulates the model's gradients after back-propagating
        the loss for rejected samples.

        These gradients are stored internally in `self.state[p]["rejected_grad"]`
        for each parameter `p`. If called multiple times, new gradients are added.
        `p.grad` is copied as-is, so it must hold only the rejected-sample gradient.
        """
        self._accumulate_grad(self._REJECTED_KEY)

    @torch.no_grad()
    def set_gradients(self) -> Dict[str, Any]:
        """
        Computes the final gradients based on the accumulated chosen and rejected
        gradients, applying the projection logic. Sets `p.grad` for each parameter
        to these final gradients.

        The update rule is:
            θ_{k+1} = θ_k - η * (∇L(y_w) - (max(0, ∇L(y_w)⋅∇L(y_l)) / ‖∇L(y_l)‖²) * ∇L(y_l))

        where:
            - θ_k      : model parameters at step k
            - η        : step size (learning rate)
            - ∇L(y_w)  : gradient of the negative log-likelihood of the chosen samples y_w
            - ∇L(y_l)  : gradient of the negative log-likelihood of the rejected samples y_l
            - ⋅        : vector dot product

        The projection is applied only when the dot product is positive and
        ‖∇L(y_l)‖² > `eps`. Otherwise the chosen gradient is used unchanged.
        A parameter with a chosen gradient but no rejected gradient is treated as
        having a zero rejected gradient. A parameter with no chosen gradient
        keeps its current `p.grad`.

        This method should be called after all chosen and rejected gradients for a
        batch (or accumulation steps) have been collected via `update_chosen_grad`
        and `update_rejected_grad`, and after `model.zero_grad()` has been called
        to clear any previous `p.grad` values that were set by `backward()`.
        The internally stored gradients are cleared afterwards.

        Returns:
            Dict[str, Any]: A dictionary of single-element lists:
                - "dot_prod": The dot product ∇L(y_w)⋅∇L(y_l) (before projection).
                - "rejected_grad_norm": The L2 norm of ∇L(y_l).
                - "update_norm": The L2 norm of the final gradient written to `p.grad`
                  (computed from the stored-dtype gradient).

        Raises:
            RuntimeError: If projection is applied and a parameter has a rejected
                gradient but no chosen gradient.
        """
        chosen_key = self._CHOSEN_KEY
        rejected_key = self._REJECTED_KEY

        # Global statistics of the *total accumulated* gradients.
        dot_prod = self._dot_prod(chosen_key, rejected_key)
        rejected_sq_norm = self._grad_squared_norm(rejected_key)

        metrics: Dict[str, Any] = {
            # Lists for consistent metric aggregation.
            "dot_prod": [dot_prod],
            "rejected_grad_norm": [rejected_sq_norm ** 0.5 if rejected_sq_norm > 0 else 0.0],
        }

        # A positive dot product means a step along -∇L(y_w) would also lower the
        # rejected-sample loss; remove that component. The eps guard avoids
        # dividing by a (near-)zero rejected norm.
        project = dot_prod > 0 and rejected_sq_norm > self.eps
        ratio = -dot_prod / rejected_sq_norm if project else 0.0

        update_sq_norm = self._new_accumulator()
        for p in self._iter_params():
            state = self.state[p]
            chosen_grad = state.get(chosen_key)
            rejected_grad = state.get(rejected_key)

            if chosen_grad is None:
                if project and rejected_grad is not None:
                    raise RuntimeError("Rejected grad present while chosen grad is missing.")
                continue

            if project and rejected_grad is not None:
                # chosen_grad <- chosen_grad + ratio * rejected_grad
                chosen_grad.add_(rejected_grad, alpha=ratio)

            p.grad = chosen_grad.to(p.dtype)
            update_sq_norm += chosen_grad.float().pow(2).sum().to(update_sq_norm.device)

        # Clears the internally stored chosen_grad and rejected_grad.
        self._reset_gradients()

        if self.fsdp_enabled:
            dist.all_reduce(update_sq_norm, op=dist.ReduceOp.SUM)

        metrics["update_norm"] = [float(update_sq_norm ** 0.5)]
        return metrics

    @torch.no_grad()
    def _reset_gradients(self):
        """
        Clears the internally stored "chosen_grad" and "rejected_grad" from the
        optimizer's state for all parameters. User should NOT call this, rather,
        this is called inside the `set_gradients` method to prepare for the
        next round of gradient accumulation.
        """
        for p in self._iter_params():
            state = self.state[p]
            for key in (self._CHOSEN_KEY, self._REJECTED_KEY):
                if key in state:
                    state[key] = None

    @torch.no_grad()
    def _dot_prod(self, key_1: str, key_2: str) -> float:
        """
        Computes the global dot product of two sets of gradients stored in the state.
        Handles FSDP reduction if enabled.

        Parameters missing either gradient contribute nothing. The sum is
        accumulated in float32 on `main_device`.

        Args:
            key_1 (str): State key for the first set of gradients.
            key_2 (str): State key for the second set of gradients.

        Returns:
            float: The global dot product.
        """
        acc = self._new_accumulator()
        for p in self._iter_params():
            state = self.state[p]
            grad1 = state.get(key_1)
            grad2 = state.get(key_2)
            if grad1 is None or grad2 is None:
                continue
            acc += (grad1.float() * grad2.float()).sum().to(acc.device)

        if self.fsdp_enabled:
            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
        return float(acc)

    @torch.no_grad()
    def _grad_squared_norm(self, key: str) -> float:
        """
        Computes the global squared L2 norm of gradients stored under the given key.
        Handles FSDP reduction if enabled.

        The sum is accumulated in float32 on `main_device`.

        Args:
            key (str): State key for the gradients.

        Returns:
            float: The global squared L2 norm.
        """
        acc = self._new_accumulator()
        for p in self._iter_params():
            grad = self.state[p].get(key)
            if grad is None:
                continue
            acc += grad.float().pow(2).sum().to(acc.device)

        if self.fsdp_enabled:
            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
        return float(acc)