from __future__ import annotations
from collections import deque
from typing import Iterable, List, Optional, Tuple
import torch


class Lagrange:
    """Lagrange multiplier for constrained optimization.

    The multiplier is a single scalar ``torch.nn.Parameter`` trained with Adam on
    the loss returned by :meth:`compute_lambda_loss`. After every step it is
    clamped to ``[0, lagrangian_upper_bound]``.

    Args:
        cost_limit: the cost limit.
        lagrangian_multiplier_init: the initial value of the lagrangian multiplier;
            negative values are raised to 0.
        lagrangian_multiplier_lr: the learning rate of the lagrangian multiplier.
        lagrangian_upper_bound: the upper bound of the lagrangian multiplier;
            ``None`` means no upper bound.

    Attributes:
        cost_limit: the cost limit (stored; not read by the methods of this class).
        lagrangian_multiplier_lr: the learning rate of the lagrangian multiplier.
        lagrangian_upper_bound: the upper bound of the lagrangian multiplier.
        _lagrangian_multiplier: the raw multiplier parameter optimized by Adam.
        lambda_range_projection: ReLU used to project the multiplier onto ``[0, inf)``.
        lambda_optimizer: Adam optimizer over ``_lagrangian_multiplier``.
    """

    def __init__(
        self,
        cost_limit: float,
        lagrangian_multiplier_init: float,
        lagrangian_multiplier_lr: float,
        lagrangian_upper_bound: float | None = None,
    ) -> None:
        """Initialize an instance of :class:`Lagrange`."""
        self.cost_limit: float = cost_limit
        self.lagrangian_multiplier_lr: float = lagrangian_multiplier_lr
        self.lagrangian_upper_bound: float | None = lagrangian_upper_bound

        # The multiplier must be non-negative, so a negative init is raised to 0.
        init_value = max(lagrangian_multiplier_init, 0.0)
        self._lagrangian_multiplier: torch.nn.Parameter = torch.nn.Parameter(
            torch.as_tensor(init_value),
            requires_grad=True,
        )
        self.lambda_range_projection: torch.nn.ReLU = torch.nn.ReLU()

        self.lambda_optimizer: torch.optim.Optimizer = torch.optim.Adam(
            [
                self._lagrangian_multiplier,
            ],
            lr=lagrangian_multiplier_lr,
        )

    @property
    def lagrangian_multiplier(self) -> float:
        """The ReLU-projected lagrangian multiplier.

        Returns:
            the lagrangian multiplier as a plain Python float (detached from autograd).
        """
        return self.lambda_range_projection(self._lagrangian_multiplier).detach().item()

    def compute_lambda_loss(self, mean_ep_cost: float) -> torch.Tensor:
        """Compute the loss of the lagrangian multiplier.

        Minimizing ``-lambda * mean_ep_cost`` increases ``lambda`` when
        ``mean_ep_cost`` is positive and decreases it when negative.

        Args:
            mean_ep_cost: the mean episode cost.

        Returns:
            the scalar loss tensor of the lagrangian multiplier.
        """
        # NOTE(review): `cost_limit` is stored but not subtracted here, so the
        # multiplier grows whenever `mean_ep_cost` > 0. If callers pass the raw mean
        # episode cost (rather than the violation `cost - cost_limit`), this is
        # likely a bug -- confirm against the call sites.
        return -self._lagrangian_multiplier * (mean_ep_cost)

    def update_lagrange_multiplier(self, Jc: float) -> None:
        """Take one Adam step on the lagrangian multiplier, then clamp it.

        Args:
            Jc: the mean episode cost, forwarded to :meth:`compute_lambda_loss`.
        """
        self.lambda_optimizer.zero_grad()
        lambda_loss = self.compute_lambda_loss(Jc)
        lambda_loss.backward()
        self.lambda_optimizer.step()
        # Project back onto [0, upper_bound]; a None upper bound leaves it unbounded.
        self._lagrangian_multiplier.data.clamp_(
            0.0,
            self.lagrangian_upper_bound,
        )


class HCRL(Lagrange):
    """HCRL algorithm for balancing rewards and constraints in constrained optimization.

    Extends :class:`Lagrange` with a combination of the objective gradient ``g1``
    and the constraint gradient ``g2``. A mixing weight ``alpha`` is found by a
    small inner optimization, and a scaling parameter ``c`` is re-derived from the
    angle between ``g1`` and ``g2`` after each update.

    Args:
        cost_limit: Cost limit for constraints.
        lagrangian_multiplier_init: Initial value of the Lagrangian multiplier.
        lagrangian_multiplier_lr: Learning rate for Lagrangian multiplier updates.
        lagrangian_upper_bound: Upper bound for the Lagrangian multiplier.
        eta: Learning rate for parameter updates (stored; not read in this class).
        beta: Learning rate for Lagrangian multiplier updates (stored; not read in
            this class).
        c_init: Initial value of ``c``, and the scale used when ``c`` is recomputed.
        alpha_lr: Learning rate for the inner ``alpha`` optimization.
        alpha_iters: Number of Adam steps for the inner ``alpha`` optimization.

    Attributes:
        c: Current scaling parameter. It starts as ``c_init`` and becomes a 0-d
            tensor after the first :meth:`update_parameters` call.
        c_init: Scale applied when recomputing ``c``.
    """

    def __init__(
        self,
        cost_limit: float,
        lagrangian_multiplier_init: float,
        lagrangian_multiplier_lr: float,
        lagrangian_upper_bound: float,
        eta: float,
        beta: float,
        c_init: float = 0.9,
        alpha_lr: float = 0.01,
        alpha_iters: int = 20,
    ) -> None:
        super().__init__(
            cost_limit, lagrangian_multiplier_init, lagrangian_multiplier_lr, lagrangian_upper_bound
        )
        self.eta = eta
        self.beta = beta
        self.c = c_init
        self.c_init = c_init
        self.alpha_lr = alpha_lr
        self.alpha_iters = alpha_iters

    def optimize_alpha(self, g1: torch.Tensor, g2: torch.Tensor, g0: torch.Tensor) -> torch.Tensor:
        """Optimize the mixing weight ``alpha`` with projected Adam steps.

        Minimizes ``g_alpha @ g0 + c * ||g0|| * ||g_alpha||`` where
        ``g_alpha = alpha * g1 + (1 - alpha) * g2``. It starts from ``alpha = 0.5``
        and clamps ``alpha`` to ``[0, 1]`` after every step. The gradient inputs
        are treated as constants (detached).

        Args:
            g1: Gradient of the objective function (1-D; it is passed to ``torch.dot``).
            g2: Gradient of the constraint function (1-D, same size as ``g1``).
            g0: Base gradient (same size as ``g1``).

        Returns:
            The optimized ``alpha`` as a detached 0-d float32 tensor on ``g1``'s device.
        """
        g0 = g0.detach()
        g1 = g1.detach()
        g2 = g2.detach()

        alpha = torch.tensor(0.5, device=g1.device, requires_grad=True)
        optimizer = torch.optim.Adam([alpha], lr=self.alpha_lr)

        # Pre-compute norms and the dot product once. ||g_alpha|| is expanded in
        # terms of them, so the loop never recomputes a full-length norm.
        g1_norm = torch.norm(g1)
        g2_norm = torch.norm(g2)
        g0_norm = torch.norm(g0)
        dot_g1_g2 = torch.dot(g1, g2)

        for _ in range(self.alpha_iters):
            optimizer.zero_grad()

            g_alpha = alpha * g1 + (1 - alpha) * g2
            g_alpha_sq_norm = (
                g1_norm**2 * alpha**2
                + g2_norm**2 * (1 - alpha) ** 2
                + 2 * alpha * (1 - alpha) * dot_g1_g2
            )
            # Floor the radicand before sqrt, for two reasons. Rounding can push it
            # slightly below zero (NaN forward). sqrt's derivative is infinite at
            # zero (NaN gradient when g1 and g2 cancel or are both zero). Below the
            # floor, clamp passes no gradient, so alpha stays finite.
            g_alpha_norm = torch.sqrt(g_alpha_sq_norm.clamp(min=1e-20))

            loss = g_alpha @ g0 + self.c * g0_norm * g_alpha_norm
            loss.backward()
            optimizer.step()

            # Project alpha back onto the feasible interval [0, 1].
            with torch.no_grad():
                alpha.clamp_(0, 1)

        return alpha.detach().float()

    def update_parameters(
        self, g1: torch.Tensor, g2: torch.Tensor
    ) -> torch.Tensor:
        """Compute the update direction for one HCRL step and refresh ``c``.

        Args:
            g1: Gradient of the objective function (1-D).
            g2: Gradient of the constraint function (1-D, same size as ``g1``).

        Returns:
            The update direction. As currently written this is the base gradient
            ``g1 + lambda * g2``, where ``lambda`` is the current multiplier.
        """
        # The multiplier is a constant in this combination. Detaching it keeps
        # the returned direction from carrying an autograd graph back to the
        # multiplier parameter.
        multiplier = self._lagrangian_multiplier.detach()
        g0 = g1 + multiplier * g2

        # Optimize alpha and compute combined gradient
        alpha_star = self.optimize_alpha(g1, g2, g0)
        g_alpha_star = alpha_star * g1 + (1 - alpha_star) * g2

        lambda_k = torch.norm(g_alpha_star) / (2 * self.c * torch.norm(g0) + 1e-8)  # noqa: F841

        # NOTE(review): `direction` ignores `alpha_star` and `lambda_k`, so the
        # alpha optimization above does not influence the returned step. If a
        # correction term built from them was intended, it is missing -- confirm
        # against the algorithm's specification.
        direction = g0

        # c = c_init * sin(theta / 2), where theta is the angle between g1 and g2.
        # So c -> 0 when the gradients align and c -> c_init when they are opposed.
        cos_theta = torch.dot(g1, g2) / (torch.norm(g1) * torch.norm(g2) + 1e-8)
        cos_theta = torch.clamp(cos_theta, -1, 1)
        theta_bar = torch.arccos(cos_theta)
        self.c = self.c_init * torch.sin(theta_bar / 2)

        return direction



class PIDLagrangian:
    """PID Lagrangian multiplier for constrained optimization.

    Args:
        cost_limit: the cost limit.
        lagrangian_multiplier_init: the initial value of the PID integral term.
        pid_kp: the proportional gain of the PID controller.
        pid_ki: the integral gain of the PID controller.
        pid_kd: the derivative gain of the PID controller.
        pid_d_delay: the delay (in updates) of the derivative term; must be >= 1.
        pid_delta_p_ema_alpha: the exponential moving average alpha of the delta_p.
        pid_delta_d_ema_alpha: the exponential moving average alpha of the delta_d.
        sum_norm: whether to normalize the sum of the PID output. When either
            ``sum_norm`` or ``diff_norm`` is set, ``penalty_max`` is not applied.
        diff_norm: whether to normalize the difference of the PID output. This
            clamps the integral term and the penalty to at most 1.
        penalty_max: the maximum value of the penalty, applied only when both
            ``sum_norm`` and ``diff_norm`` are False (default 100).
        penalty_floor: lower bound of the final clip applied after every update
            (default 0.001).
        penalty_ceiling: upper bound of the final clip applied after every update
            (default 0.01). The floor is applied after the ceiling, so it wins if
            the two conflict.

    Attributes:
        lagrangian_multiplier: the current penalty (read-only property).

    Raises:
        ValueError: if ``pid_d_delay`` is smaller than 1.

    References:
        - Title: Responsive Safety in Reinforcement Learning by PID Lagrangian Methods
        - Authors: Adam Stooke, Joshua Achiam, Pieter Abbeel.
        - URL: `CPPOPID <https://arxiv.org/abs/2007.03964>`_
    """

    def __init__(
        self,
        cost_limit: float,
        lagrangian_multiplier_init: float,
        pid_kp: float,
        pid_ki: float,
        pid_kd: float,
        pid_d_delay: int,
        pid_delta_p_ema_alpha: float,
        pid_delta_d_ema_alpha: float,
        sum_norm: bool = True,
        diff_norm: bool = False,
        penalty_max: int = 100,
        penalty_floor: float = 0.001,
        penalty_ceiling: float = 0.01,
    ) -> None:
        """Initialize an instance of :class:`PIDLagrangian`."""
        # A zero-length deque would silently drop every append and make the
        # derivative lookup `_cost_ds[0]` raise IndexError on the first update.
        if pid_d_delay < 1:
            raise ValueError(f"pid_d_delay must be >= 1, got {pid_d_delay}")
        self._pid_kp: float = pid_kp
        self._pid_ki: float = pid_ki
        self._pid_kd: float = pid_kd
        self._pid_d_delay = pid_d_delay
        self._pid_delta_p_ema_alpha: float = pid_delta_p_ema_alpha
        self._pid_delta_d_ema_alpha: float = pid_delta_d_ema_alpha
        self._penalty_max: int = penalty_max
        self._penalty_floor: float = penalty_floor
        self._penalty_ceiling: float = penalty_ceiling
        self._sum_norm: bool = sum_norm
        self._diff_norm: bool = diff_norm
        self._pid_i: float = lagrangian_multiplier_init
        # History of smoothed costs for the derivative term; index 0 is the oldest.
        self._cost_ds: deque[float] = deque(maxlen=self._pid_d_delay)
        self._cost_ds.append(0.0)
        self._delta_p: float = 0.0
        self._cost_d: float = 0.0
        self._cost_limit: float = cost_limit
        # NOTE(review): the initial penalty is 0.1, not `lagrangian_multiplier_init`,
        # and lies above the default `penalty_ceiling`; confirm this is intended.
        self._cost_penalty: float = 0.1

    @property
    def lagrangian_multiplier(self) -> float:
        """The current penalty produced by the PID controller."""
        return self._cost_penalty

    def update_lagrange_multiplier(self, ep_cost_avg: float) -> None:
        """Run one PID step on the constraint violation and refresh the penalty.

        Args:
            ep_cost_avg: the average episode cost. The controller's error signal
                is ``ep_cost_avg - cost_limit``.
        """
        delta = float(ep_cost_avg - self._cost_limit)
        # Integral term: kept non-negative (and <= 1 when diff_norm is set).
        self._pid_i = max(0.0, self._pid_i + delta * self._pid_ki)
        if self._diff_norm:
            self._pid_i = max(0.0, min(1.0, self._pid_i))
        # Proportional term: EMA of the violation.
        a_p = self._pid_delta_p_ema_alpha
        self._delta_p *= a_p
        self._delta_p += (1 - a_p) * delta
        # Derivative term: EMA of the raw cost compared with its value from up to
        # pid_d_delay updates ago. Only cost increases contribute.
        a_d = self._pid_delta_d_ema_alpha
        self._cost_d *= a_d
        self._cost_d += (1 - a_d) * float(ep_cost_avg)
        pid_d = max(0.0, self._cost_d - self._cost_ds[0])
        pid_o = self._pid_kp * self._delta_p + self._pid_i + self._pid_kd * pid_d
        self._cost_penalty = max(0.0, pid_o)
        if self._diff_norm:
            self._cost_penalty = min(1.0, self._cost_penalty)
        if not (self._diff_norm or self._sum_norm):
            self._cost_penalty = min(self._cost_penalty, self._penalty_max)
        # Final clip, applied regardless of the normalization flags above.
        self._cost_penalty = min(self._cost_penalty, self._penalty_ceiling)
        self._cost_penalty = max(self._cost_penalty, self._penalty_floor)
        self._cost_ds.append(self._cost_d)


def apply_grad_vector_to_params(
    model_params: Iterable[torch.Tensor], grad_vector: torch.Tensor, accumulate: bool = False
) -> None:
    """Write a flat gradient vector into the ``.grad`` fields of ``model_params``.

    Parameters are consumed in iteration order. Each one takes the next
    ``param.numel()`` entries of ``grad_vector``, reshaped to the parameter's shape
    and cast to its dtype/device. When dtype and device already match, no copy is
    made, and in overwrite mode the new ``.grad`` shares storage with
    ``grad_vector``.

    Args:
        model_params (Iterable[torch.Tensor]): Iterable of model parameter tensors.
        grad_vector (torch.Tensor): A single vector representing the gradients.
        accumulate (bool): If True, add to each parameter's existing ``.grad``
            (a missing ``.grad`` is treated as zero); otherwise overwrite it.

    Raises:
        TypeError: If ``grad_vector`` is not a ``torch.Tensor``.
    """
    if not isinstance(grad_vector, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, but got: {type(grad_vector).__name__}")

    # Offset of the current parameter's slice within grad_vector.
    pointer = 0
    for param in model_params:
        num_elements = param.numel()
        # `.data` drops any autograd history on grad_vector. `.to(param)` is a
        # no-op when dtype/device already match.
        chunk = grad_vector[pointer:pointer + num_elements].view_as(param).data.to(param)
        if accumulate and param.grad is not None:
            # Out-of-place add so a previously assigned view into another vector
            # is not mutated.
            param.grad = param.grad + chunk
        else:
            param.grad = chunk

        pointer += num_elements