import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete
from gymnasium.spaces import Box as SafeBox
from gymnasium.spaces import Discrete as SafeDiscrete

import torch
import torch.nn as nn
import torch.optim
from torch.distributions.normal import Normal
from collections import deque
from typing import Union, Optional

def combined_shape(length, shape=None):
    """Build a shape tuple whose leading dimension is ``length``.

    Args:
        length: Size of the leading dimension.
        shape: ``None``, a scalar trailing dimension, or an iterable of
            trailing dimensions.

    Returns:
        tuple: ``(length,)`` when ``shape`` is ``None``, ``(length, shape)``
        when ``shape`` is a scalar, otherwise ``(length, *shape)``.
    """
    if shape is None:
        return (length,)
    # ``torch`` has no ``isscalar`` attribute; NumPy provides the scalar check.
    return (length, shape) if np.isscalar(shape) else (length, *shape)

def build_mlp_network(sizes):
    """Assemble a Tanh MLP trunk plus two sigmoid output heads.

    Args:
        sizes: Layer widths ``[in, hidden..., out]``. One ``nn.Linear`` is
            created per consecutive pair; every linear layer except the last
            is followed by ``nn.Tanh``.

    Returns:
        tuple: ``(trunk, safe_output, reward_output)`` where ``trunk`` is the
        ``nn.Sequential`` MLP and each head is
        ``nn.Sequential(nn.Linear(sizes[-1], 1), nn.Sigmoid())``.
    """
    last_activated = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        linear = nn.Linear(fan_in, fan_out)
        nn.init.kaiming_uniform_(linear.weight, a=np.sqrt(5))
        modules.append(linear)
        # Only hidden layers get a non-linearity; the final layer stays linear.
        if idx < last_activated:
            modules.append(nn.Tanh())
    trunk = nn.Sequential(*modules)

    # Independent heads that map the trunk output to safe_ratio / reward_ratio.
    head_in = sizes[-1]
    safe_output = nn.Sequential(nn.Linear(head_in, 1), nn.Sigmoid())
    reward_output = nn.Sequential(nn.Linear(head_in, 1), nn.Sigmoid())

    return trunk, safe_output, reward_output

class Actor(nn.Module):
    """
    Actor network for policy-based reinforcement learning.

    This class represents an actor network that outputs a Normal distribution over actions
    given observations. The mean comes from an MLP; the standard deviation is a learnable,
    observation-independent parameter stored in log space.

    Args:
        obs_dim (int): Dimensionality of the observation space.
        act_dim (int): Dimensionality of the action space.
        hidden_sizes (list): Widths of the hidden layers of the mean network.

    Attributes:
        mean (nn.Sequential): MLP network representing the mean of the action distribution.
        log_std (nn.Parameter): Learnable parameter representing the log standard deviation of the action distribution.

    Example:
        obs_dim = 10
        act_dim = 2
        actor = Actor(obs_dim, act_dim)
        observation = torch.randn(1, obs_dim)
        action_distribution = actor(observation)
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes: list = [64, 64]):
        super().__init__()
        # build_mlp_network returns (trunk, safe_head, reward_head). Only the trunk is the
        # mean network; storing the whole tuple would leave ``self.mean`` uncallable and
        # its parameters unregistered with this module.
        mean_net, _, _ = build_mlp_network([obs_dim, *hidden_sizes, act_dim])
        self.mean = mean_net
        self.log_std = nn.Parameter(torch.zeros(act_dim), requires_grad=True)

    def forward(self, obs: torch.Tensor):
        """Return the ``Normal`` action distribution for ``obs``."""
        mean = self.mean(obs)
        std = torch.exp(self.log_std)
        return Normal(mean, std)

class VCritic(nn.Module):
    """
    Critic network for value-based reinforcement learning.

    This class represents a critic network that estimates the value function for input observations.

    Args:
        obs_dim (int): Dimensionality of the observation space.
        hidden_sizes (list): Widths of the hidden layers of the critic network.

    Attributes:
        critic (nn.Sequential): MLP network representing the critic function.

    Example:
        obs_dim = 10
        critic = VCritic(obs_dim)
        observation = torch.randn(1, obs_dim)
        value_estimate = critic(observation)
    """

    def __init__(self, obs_dim, hidden_sizes: list = [64, 64]):
        super().__init__()
        # build_mlp_network returns (trunk, safe_head, reward_head). Only the trunk is the
        # value network; storing the whole tuple would leave ``self.critic`` uncallable
        # and its parameters unregistered with this module.
        value_net, _, _ = build_mlp_network([obs_dim, *hidden_sizes, 1])
        self.critic = value_net

    def forward(self, obs):
        """Return the value estimate for ``obs`` with the trailing size-1 dimension removed."""
        return torch.squeeze(self.critic(obs), -1)

class ActorVCritic(nn.Module):
    """
    Actor-critic policy for reinforcement learning.

    This class represents an actor-critic policy that includes an actor network, two critic networks for reward
    and cost estimation, and provides methods for taking policy steps and estimating values.

    Args:
        obs_dim (int): Dimensionality of the observation space.
        act_dim (int): Dimensionality of the action space.
        hidden_sizes (list): Hidden layer widths shared by the actor and both critics.

    Attributes:
        reward_critic (VCritic): Value network for the reward signal.
        cost_critic (VCritic): Value network for the cost signal.
        actor (Actor): Policy network.

    Example:
        obs_dim = 10
        act_dim = 2
        actor_critic = ActorVCritic(obs_dim, act_dim)
        observation = torch.randn(1, obs_dim)
        action, log_prob, reward_value, cost_value = actor_critic.step(observation)
        value_estimate = actor_critic.get_value(observation)
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes: list = [64, 64]):
        super().__init__()
        self.reward_critic = VCritic(obs_dim, hidden_sizes)
        self.cost_critic = VCritic(obs_dim, hidden_sizes)
        self.actor = Actor(obs_dim, act_dim, hidden_sizes)

    def get_value(self, obs):
        """
        Estimate the reward value of observations using the reward critic.

        Args:
            obs (torch.Tensor): Input observation tensor.

        Returns:
            torch.Tensor: Estimated reward value for the input observation.
        """
        # This module defines no ``critic`` attribute, so the previous ``self.critic``
        # lookup raised AttributeError.
        # NOTE(review): the reward critic is assumed to be the intended "value";
        # use ``self.cost_critic`` directly for the cost estimate.
        return self.reward_critic(obs)

    def step(self, obs, deterministic=False):
        """
        Take a policy step based on observations.

        Args:
            obs (torch.Tensor): Input observation tensor.
            deterministic (bool): When True the distribution mean is used as the action;
                otherwise the action is drawn with ``rsample()``.

        Returns:
            tuple: Tuple containing action tensor, log probabilities of the action (summed over
                   the last axis), reward value estimate, and cost value estimate.
        """

        dist = self.actor(obs)
        if deterministic:
            action = dist.mean
        else:
            action = dist.rsample()
        log_prob = dist.log_prob(action).sum(axis=-1)
        value_r = self.reward_critic(obs)
        value_c = self.cost_critic(obs)
        return action, log_prob, value_r, value_c


class Lagrange:
    """Lagrange multiplier for constrained optimization.

    The raw multiplier is a learnable scalar updated with Adam; reads go
    through a ReLU projection and the raw value is clamped after every update.

    Args:
        cost_limit: the cost limit
        lagrangian_multiplier_init: the initial value of the lagrangian multiplier
            (negative values are raised to 0)
        lagrangian_multiplier_lr: the learning rate of the lagrangian multiplier
        lagrangian_upper_bound: the upper bound of the lagrangian multiplier,
            or ``None`` for no upper bound

    Attributes:
        cost_limit: the cost limit
        lagrangian_multiplier_lr: the learning rate of the lagrangian multiplier
        lagrangian_upper_bound: the upper bound of the lagrangian multiplier
        _lagrangian_multiplier: the lagrangian multiplier
        lambda_range_projection: the projection function of the lagrangian multiplier
        lambda_optimizer: the optimizer of the lagrangian multiplier
    """

    # pylint: disable-next=too-many-arguments
    def __init__(
        self,
        cost_limit: float,
        lagrangian_multiplier_init: float,
        lagrangian_multiplier_lr: float,
        lagrangian_upper_bound: Optional[float] = None,
    ) -> None:
        """Initialize an instance of :class:`Lagrange`."""
        self.cost_limit: float = cost_limit
        self.lagrangian_multiplier_lr: float = lagrangian_multiplier_lr
        self.lagrangian_upper_bound: Optional[float] = lagrangian_upper_bound

        init_value = max(lagrangian_multiplier_init, 0.0)
        # Force a floating dtype: an integer init (e.g. ``1``) would otherwise yield
        # an int64 tensor, which cannot require gradients.
        self._lagrangian_multiplier: torch.nn.Parameter = torch.nn.Parameter(
            torch.as_tensor(init_value, dtype=torch.get_default_dtype()),
            requires_grad=True,
        )
        self.lambda_range_projection: torch.nn.ReLU = torch.nn.ReLU()
        # fetch optimizer from PyTorch optimizer package
        self.lambda_optimizer: torch.optim.Optimizer = torch.optim.Adam(
            [
                self._lagrangian_multiplier,
            ],
            lr=lagrangian_multiplier_lr,
        )

    @property
    def lagrangian_multiplier(self) -> float:
        """The lagrangian multiplier.

        Returns:
            the ReLU-projected lagrangian multiplier as a detached Python float
        """
        return self.lambda_range_projection(self._lagrangian_multiplier).detach().item()

    def compute_lambda_loss(self, mean_ep_cost: float) -> torch.Tensor:
        """Compute the loss of the lagrangian multiplier.

        Args:
            mean_ep_cost: the mean episode cost

        Returns:
            the loss of the lagrangian multiplier,
            ``-lambda * (mean_ep_cost - cost_limit)``
        """
        return -self._lagrangian_multiplier * (mean_ep_cost - self.cost_limit)

    def update_lagrange_multiplier(self, Jc: float) -> None:
        """Update the lagrangian multiplier with one optimizer step.

        Args:
            Jc: the mean episode cost
        """
        self.lambda_optimizer.zero_grad()
        lambda_loss = self.compute_lambda_loss(Jc)
        lambda_loss.backward()
        self.lambda_optimizer.step()
        # enforce: lambda in [0, lagrangian_upper_bound]; a ``None`` bound leaves
        # the upper side unbounded.
        self._lagrangian_multiplier.data.clamp_(
            min=0.0,
            max=self.lagrangian_upper_bound,
        )

def calculate_adv_and_fuzzy_value_targets(
    values: torch.Tensor,
    rewards: torch.Tensor,
    fuzzy_value_next: torch.Tensor,
    lam: float,
    gamma: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and the fuzzy / true value targets.

    The last entry along the first dimension of ``values`` is used only as the
    bootstrap value, and the last entries of ``rewards`` and
    ``fuzzy_value_next`` are dropped, so every output is one step shorter
    than the inputs.

    Args:
        values: Value estimates, including the trailing bootstrap entry.
        rewards: Per-step rewards; the last entry is ignored.
        fuzzy_value_next: Fuzzy next-state values; the last entry is ignored.
        lam: GAE lambda.
        gamma: Discount factor.

    Returns:
        tuple: ``(adv, target_fuzzy_value, target_true_value)`` where
        ``adv`` is the discounted sum of TD residuals,
        ``target_fuzzy_value = rewards[:-1] + gamma * fuzzy_value_next[:-1]``
        and ``target_true_value = adv + values[:-1]``.
    """
    # GAE formula: A_t = \sum_{k=0}^{n-1} (lam*gamma)^k delta_{t+k}
    deltas = rewards[:-1] + gamma * values[1:] - values[:-1]
    adv = discount_cumsum(deltas, gamma * lam)
    target_fuzzy_value = rewards[:-1] + gamma * fuzzy_value_next[:-1]
    target_true_value = adv + values[:-1]
    return adv, target_fuzzy_value, target_true_value


def discount_cumsum(vector_x: torch.Tensor, discount: float) -> torch.Tensor:
    """
    Compute the discounted cumulative sum of a tensor along its first dimension.

    Element ``i`` of the result is ``x[i] + discount * x[i+1] + discount**2 * x[i+2] + ...``.
    The result is always a new ``float64`` tensor with the same shape as the input; the
    input tensor is never modified.

    Args:
        vector_x (torch.Tensor): Input tensor with shape `(length, ...)`.
        discount (float): Discount factor for future values.

    Returns:
        torch.Tensor: Tensor containing the discounted cumulative sum of `vector_x`.
    """
    # ``copy=True`` guarantees a fresh tensor even when the input is already float64;
    # a plain dtype cast would return the input itself and the in-place writes below
    # would overwrite the caller's data.
    result = vector_x.to(dtype=torch.float64, copy=True)
    length = result.shape[0]
    if length == 0:
        # Nothing to accumulate; indexing [-1] on an empty tensor would raise.
        return result
    running = result[-1]
    for idx in reversed(range(length - 1)):
        running = result[idx] + discount * running
        result[idx] = running
    return result

@torch.jit.script
def gbellmf(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Generalized bell membership function centred at 0.5.

    Evaluates ``1 / (1 + |(x - 0.5) / (a + 1e-5)| ** (2 * b))`` element-wise.
    """
    # The 1e-5 offset keeps the division defined when ``a`` is exactly zero.
    scaled_distance = torch.abs((x - 0.5) / (a + 1e-5))
    exponent = 2 * b
    return 1 / (1 + scaled_distance ** exponent)
@torch.jit.script
def sigmf(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    """Sigmoid membership function centred at 0.5 with slope ``a``.

    Evaluates ``1 / (1 + exp(-a * (x - 0.5)))`` element-wise.
    """
    exponent = -a * (x - 0.5)
    return 1 / (1 + torch.exp(exponent))
@torch.jit.script
def zmf(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    """Mirrored sigmoid membership function centred at 0.5 with slope ``a``.

    Evaluates ``1 / (1 + exp(a * (x - 0.5)))`` element-wise.
    """
    exponent = a * (x - 0.5)
    return 1 / (1 + torch.exp(exponent))

# class ANFIS(nn.Module):
#     def __init__(self, out_level, state_dim, hidden_sizes: list = [64, 32], device="cuda:0"):
#         super(ANFIS, self).__init__()
#         self.out_level = out_level
#         self.shared_layers, self.safe_output, self.reward_output = build_mlp_network([state_dim]+hidden_sizes+[1])
#         self.fy = torch.linspace(-1, 1, out_level + 1).to(device)
        
#         # Input membership functions (state safety & reward level): low high
#         self.sigmf_f_reward = nn.Parameter(torch.rand(1, 1))  # High
#         self.zmf_f_reward = nn.Parameter(torch.rand(1, 1))    # Low
#         self.sigmf_f_safety = nn.Parameter(torch.rand(1, 1))
#         self.zmf_f_safety = nn.Parameter(torch.rand(1, 1))
#         # Output membership functions (transition disturb level): low mid high 
#         self.sigmf_f_y = nn.Parameter(torch.randn(1, 1)) 
#         self.gbellmf_f_y = nn.Parameter(torch.randn(2, 1)) 
#         self.zmf_f_y = nn.Parameter(torch.randn(1, 1)) 

#     def apply_rules(self, h_reward_lo, h_reward_hi, h_safe_lo, h_safe_hi):
#         batch = h_reward_lo.size(0)
        
#         h_y_lo = zmf(torch.abs(self.fy), torch.relu(self.zmf_f_y[0]))  # [level+1]
#         h_y_mid = gbellmf(torch.abs(self.fy), torch.relu(self.gbellmf_f_y[0]), torch.relu(self.gbellmf_f_y[1]))  # [level+1]
#         h_y_hi = sigmf(torch.abs(self.fy), torch.relu(self.sigmf_f_y[0]))  # [level+1]
        
#         w_hi = h_reward_lo * h_safe_hi
#         w_mid = h_reward_lo * h_safe_hi + h_reward_hi * h_safe_lo
#         w_lo = h_reward_hi * h_safe_lo
        
#         w_sum = w_hi + w_mid + w_lo
        
#         h_rule_high_disturb = torch.min(w_hi / w_sum, h_y_hi)
#         h_rule_medium_disturb = torch.min(w_mid / w_sum, h_y_mid)
#         h_rule_low_disturb = torch.min(w_lo / w_sum, h_y_lo)
        
#         aggregated = torch.max(h_rule_high_disturb, torch.max(h_rule_medium_disturb, h_rule_low_disturb)) # [batch, level+1]
#         normalized_aggregated = aggregated / (torch.sum(aggregated, dim=1, keepdim=True))
#         return h_rule_high_disturb, h_rule_medium_disturb, h_rule_low_disturb, normalized_aggregated

#     def forward(self, state):
#         batch_size = state.size(0)
#         shared_output = self.shared_layers(state)
#         safe_ratio = self.safe_output(shared_output)
#         reward_ratio = self.reward_output(shared_output)


#         h_reward_lo = zmf(reward_ratio, torch.relu(self.zmf_f_reward[0])).view(batch_size, -1)  
#         h_reward_hi = sigmf(reward_ratio, torch.relu(self.sigmf_f_reward[0])).view(batch_size, -1)  

#         h_safe_lo = zmf(safe_ratio, torch.relu(self.zmf_f_safety[0])).view(batch_size, -1)  
#         h_safe_hi = sigmf(reward_ratio, torch.relu(self.sigmf_f_safety[0])).view(batch_size, -1)  

        
#         _, _, _, normalized_aggregated = self.apply_rules(h_reward_lo, h_reward_hi, h_safe_lo, h_safe_hi)

#         has_nan = torch.isnan(normalized_aggregated).any()
#         return normalized_aggregated

class ANFIS(nn.Module):
    """Fuzzy inference module mapping a state to a distribution over disturbance levels.

    A shared MLP feeds two sigmoid heads that produce a safety ratio and a
    reward ratio. Each ratio is fuzzified into "low"/"high" memberships, three
    rules are evaluated against bell-shaped output membership functions on
    ``out_level + 1`` grid points, and the aggregated result is normalised to
    sum to one per sample.

    Args:
        out_level: Number of intervals of the output grid; the grid has
            ``out_level + 1`` points spanning ``[-1, 1]``.
        state_dim: Dimensionality of the input state.
        hidden_sizes (list): Hidden layer widths of the shared MLP.
        device: Device on which the output grid ``fy`` is initially created.

    Attributes:
        shared_layers (nn.Sequential): Shared MLP trunk.
        safe_output (nn.Sequential): Sigmoid head producing the safety ratio.
        reward_output (nn.Sequential): Sigmoid head producing the reward ratio.
        fy (torch.Tensor): Output grid, registered as a non-persistent buffer.
    """

    def __init__(self, out_level, state_dim, hidden_sizes: list = [64, 32], device="cuda:0"):
        super().__init__()
        self.out_level = out_level
        self.shared_layers, self.safe_output, self.reward_output = build_mlp_network(
            [state_dim, *hidden_sizes, 1]
        )
        # Registered as a buffer so the grid follows ``.to()`` / ``.cuda()`` together
        # with the parameters; non-persistent so ``state_dict`` keys are unchanged.
        self.register_buffer(
            "fy",
            torch.linspace(-1, 1, out_level + 1).to(device),
            persistent=False,
        )

        # Input membership functions (state safety & reward level): low high
        self.f_reward_hi = nn.Parameter(torch.rand(1, 1))
        self.f_reward_lo = nn.Parameter(torch.rand(1, 1))
        self.f_safety_hi = nn.Parameter(torch.rand(1, 1))
        self.f_safety_lo = nn.Parameter(torch.rand(1, 1))
        # Output membership functions (transition disturb level): low mid high
        # Row 0 is the bell width ``a``, row 1 the bell exponent ``b``.
        self.f_y_hi = nn.Parameter(torch.randn(2, 1))
        self.f_y_mid = nn.Parameter(torch.randn(2, 1))
        self.f_y_lo = nn.Parameter(torch.randn(2, 1))

    def apply_rules(self, h_reward_lo, h_reward_hi, h_safe_lo, h_safe_hi):
        """Evaluate the three fuzzy rules and aggregate them.

        Args:
            h_reward_lo: Membership of the reward ratio in "low".
            h_reward_hi: Membership of the reward ratio in "high".
            h_safe_lo: Membership of the safety ratio in "low".
            h_safe_hi: Membership of the safety ratio in "high".

        Returns:
            tuple: ``(h_rule_high_disturb, h_rule_medium_disturb,
            h_rule_low_disturb, normalized_aggregated)``; the last entry is the
            element-wise maximum of the three rule outputs divided by its sum
            over dimension 1.
        """
        grid = torch.abs(self.fy)
        h_y_lo = gbellmf(grid, self.f_y_lo[0], self.f_y_lo[1])  # [level+1]
        h_y_mid = gbellmf(grid, self.f_y_mid[0], self.f_y_mid[1])  # [level+1]
        h_y_hi = gbellmf(grid, self.f_y_hi[0], self.f_y_hi[1])  # [level+1]

        w_hi = h_reward_lo * h_safe_hi
        # NOTE(review): as written w_mid == w_hi + w_lo, so w_mid / w_sum is always
        # 0.5. Possibly meant to combine the lo/lo and hi/hi cases -- confirm intent.
        w_mid = h_reward_lo * h_safe_hi + h_reward_hi * h_safe_lo
        w_lo = h_reward_hi * h_safe_lo

        w_sum = w_hi + w_mid + w_lo

        # Each rule output is the minimum of its normalised firing strength and
        # the matching output membership function.
        h_rule_high_disturb = torch.min(w_hi / w_sum, h_y_hi)
        h_rule_medium_disturb = torch.min(w_mid / w_sum, h_y_mid)
        h_rule_low_disturb = torch.min(w_lo / w_sum, h_y_lo)

        aggregated = torch.max(h_rule_high_disturb, torch.max(h_rule_medium_disturb, h_rule_low_disturb)) # [batch, level+1]
        normalized_aggregated = aggregated / (torch.sum(aggregated, dim=1, keepdim=True))
        return h_rule_high_disturb, h_rule_medium_disturb, h_rule_low_disturb, normalized_aggregated

    def forward(self, state):
        """Return the normalised aggregated membership over the output grid.

        Args:
            state (torch.Tensor): Batch of states; the first dimension is the batch.

        Returns:
            torch.Tensor: Normalised aggregated rule output, one row per sample.
        """
        batch_size = state.size(0)
        shared_output = self.shared_layers(state)
        safe_ratio = self.safe_output(shared_output)
        reward_ratio = self.reward_output(shared_output)

        # ReLU keeps the learned slopes non-negative so "low" stays decreasing
        # and "high" stays increasing in the ratio.
        h_reward_lo = zmf(reward_ratio, torch.relu(self.f_reward_lo[0])).view(batch_size, -1)
        h_reward_hi = sigmf(reward_ratio, torch.relu(self.f_reward_hi[0])).view(batch_size, -1)

        h_safe_lo = zmf(safe_ratio, torch.relu(self.f_safety_lo[0])).view(batch_size, -1)
        # The "high safety" membership must be computed from the safety ratio; it
        # was previously computed from the reward ratio.
        h_safe_hi = sigmf(safe_ratio, torch.relu(self.f_safety_hi[0])).view(batch_size, -1)

        _, _, _, normalized_aggregated = self.apply_rules(h_reward_lo, h_reward_hi, h_safe_lo, h_safe_hi)

        return normalized_aggregated
