import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
import os
import os.path as osp

from gymnasium.spaces import Box as SafeBox
from gymnasium.spaces import Discrete as SafeDiscrete


def load_pytorch_policy(fpath, itr, deterministic=False):
    """Load a PyTorch model saved with the Spinning Up logger.

    Args:
        fpath: Experiment directory that contains a ``pyt_save`` folder.
        itr: Checkpoint suffix. ``''`` selects ``model.pt``; an epoch number
            (``str`` or ``int``) selects ``model<itr>.pt``.
        deterministic: Unused; kept for interface compatibility.

    Returns:
        ``(get_v, model)``. ``get_v(x)`` converts ``x`` to a float32 tensor,
        calls ``model.step`` on it and returns the second and third items of
        the result (``v`` and ``vc``). ``model`` is the unpickled object.
    """
    import inspect

    # Formatting (instead of string concatenation) also accepts an int itr.
    fname = osp.join(fpath, 'pyt_save', f'model{itr}.pt')
    print('\n\nLoading from %s.\n\n'%fname)

    # The saved file must unpickle to an object exposing ``.step`` (not a bare
    # state_dict). PyTorch 2.6 switched torch.load's default to
    # ``weights_only=True``, which refuses such objects, so opt out explicitly
    # when the installed version supports the argument.
    # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(fname, **load_kwargs)

    # Value-query helper for a single (or batched) observation.
    def get_v(x):
        with torch.no_grad():
            obs = torch.as_tensor(x, dtype=torch.float32)
            _, v, vc, _ = model.step(obs)
        return v, vc

    return get_v, model

def load_policy(fpath, itr='last', deterministic=False):
    """
    Load a saved PyTorch model from a Spinning Up output directory.

    If the directory contains a ``tf1_save`` entry, the run is treated as a
    TensorFlow save; that backend is not implemented and a
    ``NotImplementedError`` is raised. Otherwise the checkpoint is loaded from
    ``<fpath>/pyt_save`` via ``load_pytorch_policy``.

    Args:
        fpath: Experiment output directory.
        itr: ``'last'`` to pick the highest-numbered ``model<N>.pt`` (falling
            back to ``model.pt`` when none exists), or an ``int`` epoch.
        deterministic: Forwarded to ``load_pytorch_policy``.

    Returns:
        ``(get_v, model)`` as returned by ``load_pytorch_policy``.

    Raises:
        NotImplementedError: for TF1 saves.
        AssertionError: if ``itr`` is neither ``'last'`` nor an ``int``.
    """
    import re

    # Determine whether this is a TF1 save or a PyTorch save.
    if any('tf1_save' in name for name in os.listdir(fpath)):
        # The TF1 path was never implemented; fail clearly instead of running
        # into undefined names further down.
        raise NotImplementedError(
            "Loading TF1 ('tf1_save') checkpoints is not supported; "
            "only PyTorch saves can be loaded.")

    if itr == 'last':
        # Numbered checkpoints are named 'model<N>.pt'; the un-numbered
        # 'model.pt' and any unrelated files are ignored here.
        pytsave_path = osp.join(fpath, 'pyt_save')
        pattern = re.compile(r'model(\d+)\.pt')
        matches = (pattern.fullmatch(name) for name in os.listdir(pytsave_path))
        saves = [int(m.group(1)) for m in matches if m]
        itr = '%d' % max(saves) if saves else ''
    else:
        assert isinstance(itr, int), \
            "Bad value provided for itr (needs to be int or 'last')."
        itr = '%d' % itr

    return load_pytorch_policy(fpath, itr, deterministic)

def combined_shape(length, shape=None):
    """Return the shape of a buffer holding ``length`` items of shape ``shape``.

    ``shape`` may be ``None`` (scalar items), a scalar dimension, or an
    iterable of dimensions.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length,) + tuple(shape)


def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected ``nn.Sequential`` network.

    Consecutive entries of ``sizes`` define the in/out widths of each
    ``nn.Linear``. Every linear layer is followed by an instance of
    ``activation``, except the last, which is followed by
    ``output_activation``.
    """
    last_idx = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == last_idx else activation
        modules.extend([nn.Linear(fan_in, fan_out), act_cls()])
    return nn.Sequential(*modules)


def count_vars(module):
    """Return the total number of scalar entries in ``module``'s parameters."""
    return sum(np.prod(param.shape) for param in module.parameters())


def discount_cumsum(x, discount):
    """
    Discounted cumulative sum along axis 0 (the rllab trick).

    input:
        vector x,
        [x0,
         x1,
         x2]

    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]
    """
    # lfilter with b=[1], a=[1, -discount] computes y[t] = x[t] + discount*y[t-1].
    # Running it over the time-reversed input and reversing the result turns
    # that recursion into the forward-looking discounted sum.
    numerator = [1]
    denominator = [1, float(-discount)]
    reversed_sums = scipy.signal.lfilter(numerator, denominator, x[::-1], axis=0)
    return reversed_sums[::-1]


class Actor(nn.Module):
    """Abstract policy module.

    Subclasses implement ``_distribution`` (observations -> distribution) and
    ``_log_prob_from_distribution`` (distribution, actions -> log-probs).
    """

    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        raise NotImplementedError

    def forward(self, obs, act=None):
        """Return ``(pi, logp_a)``.

        ``pi`` is the distribution for ``obs``. ``logp_a`` is the
        log-likelihood of ``act`` under ``pi``, or ``None`` when no actions
        are given.
        """
        dist = self._distribution(obs)
        if act is None:
            return dist, None
        return dist, self._log_prob_from_distribution(dist, act)


class MLPCategoricalActor(Actor):
    """Discrete-action policy: an MLP produces logits of a Categorical."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.logits_net = mlp(layer_sizes, activation)

    def _distribution(self, obs):
        return Categorical(logits=self.logits_net(obs))

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act)


class MLPGaussianActor(Actor):
    """Diagonal-Gaussian policy.

    An MLP predicts the mean. The log standard deviation is a learned,
    state-independent vector initialised to -0.5.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        init_log_std = np.full(act_dim, -0.5, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(init_log_std))
        self.mu_net = mlp([obs_dim, *hidden_sizes, act_dim], activation)

    def _distribution(self, obs):
        mean = self.mu_net(obs)
        return Normal(mean, self.log_std.exp())

    def _log_prob_from_distribution(self, pi, act):
        # Normal.log_prob is element-wise; summing the last (action) axis
        # gives the joint log-density of the diagonal Gaussian.
        per_dim = pi.log_prob(act)
        return per_dim.sum(axis=-1)


class MLPCritic(nn.Module):
    """Value network mapping each observation to a single scalar."""

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        self.v_net = mlp([obs_dim, *hidden_sizes, 1], activation)

    def forward(self, obs):
        # Drop the trailing size-1 output axis so the values have shape
        # obs.shape[:-1]; downstream losses rely on this shape.
        values = self.v_net(obs)
        return values.squeeze(-1)

class MLPActorCritic(nn.Module):
    """Actor-critic with a policy ``pi`` and two independent MLP critics.

    ``pi`` is a Gaussian actor for ``Box`` action spaces and a Categorical
    actor for ``Discrete`` ones (gym or gymnasium). ``v`` and ``vc`` are
    separate ``MLPCritic`` heads; presumably ``vc`` estimates the cost value
    (TODO confirm against the training code).
    """

    def __init__(self, observation_space, action_space,
                 hidden_sizes=(64,64), activation=nn.Tanh):
        super().__init__()

        obs_dim = observation_space.shape[0]

        # The policy builder depends on the action space.
        if isinstance(action_space, (Box, SafeBox)):
            self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)
        elif isinstance(action_space, (Discrete, SafeDiscrete)):
            self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)
        else:
            # Fail at construction instead of leaving ``self.pi`` undefined
            # and crashing later with an AttributeError.
            raise NotImplementedError(
                f"Unsupported action space type: {type(action_space).__name__}")

        # Value functions.
        self.v  = MLPCritic(obs_dim, hidden_sizes, activation)
        self.vc = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        """Sample actions; return ``(a, v, vc, logp_a)`` as numpy arrays."""
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
            vc = self.vc(obs)
        return a.cpu().numpy(), v.cpu().numpy(), vc.cpu().numpy(), logp_a.cpu().numpy()

    def act(self, obs):
        """Return a deterministic action tensor for ``obs``.

        Gaussian policies return the predicted mean. Categorical policies,
        which have no ``mu_net``, return the index of the largest logit.
        """
        if isinstance(self.pi, MLPCategoricalActor):
            return self.pi.logits_net(obs).argmax(dim=-1)
        return self.pi.mu_net(obs)

def build_mlp_network(sizes):
    """Build a Tanh MLP trunk plus two sigmoid output heads.

    Returns ``(trunk, safe_head, reward_head)``.
    - ``trunk`` stacks ``nn.Linear`` layers for consecutive ``sizes``, with
      Tanh between them and no activation after the last layer. Each weight is
      re-initialised with ``kaiming_uniform_(a=sqrt(5))``.
    - Each head is ``Linear(sizes[-1], 1)`` followed by ``Sigmoid``, so it
      outputs a value in (0, 1).
    """
    # NOTE(review): when sizes[-1] == 1 both heads read the same single
    # scalar from the trunk, so the two ratios are both monotone functions of
    # one feature. Confirm this bottleneck is intended.
    last_idx = len(sizes) - 2
    trunk = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        linear = nn.Linear(fan_in, fan_out)
        nn.init.kaiming_uniform_(linear.weight, a=np.sqrt(5))
        trunk.append(linear)
        if idx < last_idx:
            trunk.append(nn.Tanh())

    # Separate output heads for safe_ratio and reward_ratio.
    safe_head = nn.Sequential(nn.Linear(sizes[-1], 1), nn.Sigmoid())
    reward_head = nn.Sequential(nn.Linear(sizes[-1], 1), nn.Sigmoid())

    return nn.Sequential(*trunk), safe_head, reward_head

@torch.jit.script
def gbellmf(x, a, b):
    """Generalized bell membership centred at 0.5 with width ``a`` and slope ``b``.

    Computes 1 / (1 + |(x - 0.5) / (a + 1e-5)| ** (2 * b)); the 1e-5 guards
    against dividing by exactly zero when ``a`` is 0.
    """
    scaled = (x - 0.5) / (a + 1e-5)
    return 1 / (1 + torch.abs(scaled) ** (2 * b))

@torch.jit.script
def sigmf(x, a):
    """Increasing sigmoid membership centred at 0.5 with slope ``a``.

    Equal to 1 / (1 + exp(-a * (x - 0.5))). It is computed with torch.sigmoid
    because the hand-written form overflows exp() to inf for large |a|, which
    yields NaN gradients.
    """
    return torch.sigmoid(a * (x - 0.5))

@torch.jit.script
def zmf(x, a):
    """Decreasing sigmoid membership centred at 0.5 with slope ``a``.

    Equal to 1 / (1 + exp(a * (x - 0.5))). It is computed with torch.sigmoid
    because the hand-written form overflows exp() to inf for large |a|, which
    yields NaN gradients.
    """
    return torch.sigmoid(-a * (x - 0.5))


class ANFIS(nn.Module):
    """Fuzzy inference network mapping a state to a distribution over output levels.

    Pipeline:
    1. A shared MLP trunk feeds two sigmoid heads that produce
       ``safe_ratio`` and ``reward_ratio``.
    2. Each ratio is fuzzified into "low" (zmf) and "high" (sigmf) memberships.
    3. A fixed rule base turns these into high / medium / low disturbance rules.
    4. The rules are aggregated over ``out_level + 1`` evenly spaced points in
       [-1, 1].

    ``forward`` returns the aggregate normalised to sum to 1 per sample.
    """

    def __init__(self, out_level, state_dim, hidden_sizes=(64, 32), device="cuda:0"):
        super().__init__()
        self.out_level = out_level
        # list(...) accepts both tuples and lists and never aliases the caller's object.
        self.shared_layers, self.safe_output, self.reward_output = build_mlp_network(
            [state_dim] + list(hidden_sizes) + [1])
        # Output-universe sample points (out_level + 1 values in [-1, 1]).
        # Stored as a non-persistent buffer so module.to(...) moves it with the
        # parameters, while the state_dict keys stay exactly as before.
        self.register_buffer(
            "fy", torch.linspace(-1, 1, out_level + 1).to(device), persistent=False)

        # Input membership parameters (state safety & reward level): low, high.
        self.f_reward_hi = nn.Parameter(torch.rand(1, 1))
        self.f_reward_lo = nn.Parameter(torch.rand(1, 1))
        self.f_safety_hi = nn.Parameter(torch.rand(1, 1))
        self.f_safety_lo = nn.Parameter(torch.rand(1, 1))
        # Output membership parameters (transition disturb level): low, mid, high.
        # Row 0 is the bell width ``a`` and row 1 the slope ``b`` for gbellmf.
        self.f_y_hi = nn.Parameter(torch.randn(2, 1))
        self.f_y_mid = nn.Parameter(torch.randn(2, 1))
        self.f_y_lo = nn.Parameter(torch.randn(2, 1))

    def apply_rules(self, h_reward_lo, h_reward_hi, h_safe_lo, h_safe_hi):
        """Fire the rule base and aggregate it over the output points.

        The memberships are (batch, 1) tensors as produced by ``forward``.

        Returns ``(high, medium, low, normalized_aggregated)``. Each of the
        first three is a clipped rule output of shape [batch, level+1].
        ``normalized_aggregated`` is their element-wise max, divided by its
        per-row sum.
        """
        abs_fy = torch.abs(self.fy)
        h_y_lo = gbellmf(abs_fy, self.f_y_lo[0], self.f_y_lo[1])  # [level+1]
        h_y_mid = gbellmf(abs_fy, self.f_y_mid[0], self.f_y_mid[1])  # [level+1]
        h_y_hi = gbellmf(abs_fy, self.f_y_hi[0], self.f_y_hi[1])  # [level+1]

        w_hi = h_reward_lo * h_safe_hi
        # NOTE(review): w_mid is exactly w_hi + w_lo, so w_mid / w_sum is
        # always 0.5 whatever the inputs. A "medium" rule would more usually
        # fire on the remaining combinations (reward_lo*safe_lo,
        # reward_hi*safe_hi). Confirm the intended rule base before changing.
        w_mid = h_reward_lo * h_safe_hi + h_reward_hi * h_safe_lo
        w_lo = h_reward_hi * h_safe_lo

        w_sum = w_hi + w_mid + w_lo

        # Clip each output membership by its normalised rule strength
        # (broadcast [batch, 1] against [level+1]).
        h_rule_high_disturb = torch.min(w_hi / w_sum, h_y_hi)
        h_rule_medium_disturb = torch.min(w_mid / w_sum, h_y_mid)
        h_rule_low_disturb = torch.min(w_lo / w_sum, h_y_lo)

        aggregated = torch.max(h_rule_high_disturb, torch.max(h_rule_medium_disturb, h_rule_low_disturb))  # [batch, level+1]
        normalized_aggregated = aggregated / (torch.sum(aggregated, dim=1, keepdim=True))
        return h_rule_high_disturb, h_rule_medium_disturb, h_rule_low_disturb, normalized_aggregated

    def forward(self, state):
        """Map ``state`` (assumed (batch, state_dim)) to a [batch, level+1] distribution."""
        batch_size = state.size(0)
        shared_output = self.shared_layers(state)
        safe_ratio = self.safe_output(shared_output)
        reward_ratio = self.reward_output(shared_output)

        h_reward_lo = zmf(reward_ratio, torch.relu(self.f_reward_lo[0])).view(batch_size, -1)
        h_reward_hi = sigmf(reward_ratio, torch.relu(self.f_reward_hi[0])).view(batch_size, -1)

        # Both safety memberships are evaluated on safe_ratio. The "high"
        # membership previously used reward_ratio by mistake.
        h_safe_lo = zmf(safe_ratio, torch.relu(self.f_safety_lo[0])).view(batch_size, -1)
        h_safe_hi = sigmf(safe_ratio, torch.relu(self.f_safety_hi[0])).view(batch_size, -1)

        _, _, _, normalized_aggregated = self.apply_rules(h_reward_lo, h_reward_hi, h_safe_lo, h_safe_hi)

        return normalized_aggregated
