import torch
import torch.nn as nn
from typing import Optional
from distributions import DiagTruncatedGaussianDistribution, MaskableCategoricalDistribution


# Range for the Gaussian standard deviation of the continuous policies: their
# sigmoid std head output in [0, 1] is rescaled to [MIN_STD, MAX_STD].
MAX_STD, MIN_STD = 1.0, 0.1

# Range for a beta parameter (not referenced by the policies in this file's
# visible code).
MAX_BETA, MIN_BETA = 2.0, 0.1

# forward policy with truncated Gaussian distribution
class ContinuousForwardPolicy(nn.Module):
    """Forward policy over continuous actions plus a learned exit decision.

    A shared MLP trunk (``common_net``) feeds three heads:

    * ``action_net`` -- per-dimension mean in [-1, 1] (Tanh);
    * ``std_net``    -- per-dimension std: a sigmoid rescaled to [MIN_STD, MAX_STD];
    * ``exit_net``   -- probability of taking the exit action (Sigmoid).

    The mean and std parameterise ``DiagTruncatedGaussianDistribution``.
    Actions returned by ``forward`` carry the exit flag as the first entry of
    the last dimension, followed by the continuous action.
    """

    def __init__(self, state_dim, action_dim,
                 hidden_sizes=(128, 128),
                 activation_fn=torch.nn.ReLU,
                 device="cpu"):
        """Build the trunk and the three heads.

        Args:
            state_dim: size of the last dimension of the observations.
            action_dim: number of continuous action dimensions (exit flag excluded).
            hidden_sizes: widths of the trunk's hidden layers; must be non-empty.
                A tuple default avoids the shared-mutable-default pitfall; any
                indexable sequence works.
            activation_fn: activation module class instantiated after each trunk layer.
            device: device used when ``forward`` converts non-tensor observations.
        """
        super().__init__()

        self.device = device
        self.action_dim = action_dim

        # Shared trunk: input layer, then one Linear+activation per consecutive
        # pair of hidden sizes (same layer order as before, so state_dict keys match).
        layers = [nn.Linear(state_dim, hidden_sizes[0]), activation_fn()]
        for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [nn.Linear(in_size, out_size), activation_fn()]
        self.common_net = nn.Sequential(*layers)

        # Mean head in [-1, 1].
        self.action_net = nn.Sequential(nn.Linear(hidden_sizes[-1], action_dim), nn.Tanh())
        # Std head in [0, 1]; rescaled to [MIN_STD, MAX_STD] at use sites.
        self.std_net = nn.Sequential(nn.Linear(hidden_sizes[-1], action_dim), nn.Sigmoid())
        # Exit-probability head, one scalar per state.
        self.exit_net = nn.Sequential(nn.Linear(hidden_sizes[-1], 1), nn.Sigmoid())

        self.action_dist = DiagTruncatedGaussianDistribution(action_dim)

    def forward(self, obs, reach_end, use_mask, epsilon=0):
        """Sample actions, with the exit flag prepended, for one state or a batch.

        Args:
            obs: observation(s). Non-tensors are converted to float on
                ``self.device``. Either a single state (1-D) or a batch whose
                last dimension is ``state_dim``.
            reach_end: numpy array of per-sample flags (converted with
                ``torch.from_numpy``). Used as the exit flag when ``use_mask`` is true.
            use_mask: if true, the exit flag comes from ``reach_end`` instead of
                the Bernoulli sample drawn from ``exit_net``.
            epsilon: per-sample probability of replacing the sampled continuous
                action with one drawn uniformly from [-1, 1].

        Returns:
            Tensor whose last dimension is ``[exit_flag, action_0, ..., action_{d-1}]``.
            The exit flag is 1.0/0.0 in the actions' dtype.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float, device=self.device)

        common_output = self.common_net(obs)
        mean_actions = self.action_net(common_output)
        std = self.std_net(common_output) * (MAX_STD - MIN_STD) + MIN_STD

        # Drop only the trailing head dimension. A bare squeeze() also removed a
        # batch dimension of size 1, which made the concatenation below fail.
        exit_prob = self.exit_net(common_output).squeeze(-1)
        # Sampled unconditionally (as before) so seeded runs keep the same RNG stream.
        exit_flag = torch.bernoulli(exit_prob)
        if use_mask:
            exit_flag = torch.from_numpy(reach_end).to(exit_flag.device).bool()

        distribution = self.action_dist.proba_distribution(mean_actions, std)
        actions = distribution.get_actions()

        # Epsilon-random exploration: one coin per sample over every leading
        # dimension. For an unbatched 1-D action this is a single coin; the old
        # (actions.shape[0], 1) mask broadcast the action into a matrix.
        if epsilon > 0.0:
            random_actions = 2 * (torch.rand_like(actions) - 0.5)  # uniform in [-1, 1]
            rand_mask = (torch.rand(actions.shape[:-1], device=actions.device) < epsilon).float().unsqueeze(-1)
            actions = actions * (1 - rand_mask) + random_actions * rand_mask

        # Prepend the exit flag along the last dimension. This covers both the
        # batched case (B,) -> (B, 1) and the unbatched scalar case () -> (1,).
        exit_column = exit_flag.to(actions.dtype).unsqueeze(-1)
        return torch.cat([exit_column, actions], dim=-1)

    def evaluate_actions(self, obs, actions, max_len, lengths=None, use_mask=False):
        """Per-step log-probability of stored ``[exit_flag, action...]`` entries.

        Args:
            obs: observations. Assumed (batch, seq, state_dim) when ``lengths``
                is given -- TODO confirm with callers.
            actions: tensor laid out like ``forward``'s output (exit flag first).
            max_len: unused; kept for call compatibility.
            lengths: optional per-trajectory lengths tensor. Steps with index
                >= length (>= length - 1 when ``use_mask``) are set to 0.
            use_mask: if true, only the continuous action's log-prob is scored.
                Otherwise the exit decision's log-prob is included: log(exit_prob)
                at steps whose stored flag is > 0.5, and
                action_log_prob + log(1 - exit_prob) elsewhere.

        Returns:
            Tensor of per-step log-probabilities.
        """
        common_output = self.common_net(obs)
        mean_actions = self.action_net(common_output)
        std = self.std_net(common_output) * (MAX_STD - MIN_STD) + MIN_STD

        # squeeze(-1) so a batch or sequence of size 1 keeps its dimension.
        exit_prob = self.exit_net(common_output).squeeze(-1)

        distribution = self.action_dist.proba_distribution(mean_actions, std)
        log_prob = distribution.log_prob(actions[..., 1:])

        if not use_mask:
            # actions[..., 0] already drops the feature dim; no extra squeeze,
            # which would also drop size-1 batch/seq dims.
            exit_flag = actions[..., 0]
            log_prob = torch.where(
                exit_flag > 0.5,
                torch.log(exit_prob + 1e-10),
                log_prob + torch.log(1 - exit_prob + 1e-10),
            )

        if lengths is not None:
            batch_size, seq_len = obs.size(0), obs.size(1)
            # With use_mask the last step is the forced exit (probability 1), so
            # it is excluded from scoring.
            valid_len = lengths - 1 if use_mask else lengths
            steps = torch.arange(seq_len, device=obs.device).expand(batch_size, seq_len)
            mask = steps < valid_len.to(obs.device).unsqueeze(1)
            # masked_fill instead of multiplication: -inf * 0 would yield NaN.
            log_prob = log_prob.masked_fill(~mask, 0.0)

        return log_prob


class ContinuousBackwardPolicy(nn.Module):
    """Backward policy over continuous actions (no exit head).

    A shared MLP trunk feeds a Tanh mean head and a sigmoid std head, which is
    rescaled to [MIN_STD, MAX_STD]. Together they parameterise
    ``DiagTruncatedGaussianDistribution``.
    """

    def __init__(self, state_dim, action_dim,
                 hidden_sizes=(128, 128),
                 activation_fn=torch.nn.ReLU,
                 device="cpu"):
        """Build the trunk and the mean/std heads.

        Args:
            state_dim: size of the last dimension of the observations.
            action_dim: number of continuous action dimensions.
            hidden_sizes: widths of the trunk's hidden layers; must be non-empty.
                A tuple default avoids the shared-mutable-default pitfall.
            activation_fn: activation module class instantiated after each trunk layer.
            device: device used when ``forward`` converts non-tensor observations.
        """
        super().__init__()

        self.device = device
        self.action_dim = action_dim

        # Same layer order as before, so state_dict keys and init order match.
        layers = [nn.Linear(state_dim, hidden_sizes[0]), activation_fn()]
        for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [nn.Linear(in_size, out_size), activation_fn()]
        self.common_net = nn.Sequential(*layers)

        self.action_net = nn.Sequential(nn.Linear(hidden_sizes[-1], action_dim), nn.Tanh())
        self.std_net = nn.Sequential(nn.Linear(hidden_sizes[-1], action_dim), nn.Sigmoid())

        self.action_dist = DiagTruncatedGaussianDistribution(action_dim)

    def forward(self, obs):
        """Sample backward actions.

        Non-tensor observations are converted to float on ``self.device``. A
        2-D result with a single row is returned as that row (1-D).
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float, device=self.device)

        common_output = self.common_net(obs)
        mean_actions = self.action_net(common_output)
        std = self.std_net(common_output) * (MAX_STD - MIN_STD) + MIN_STD

        distribution = self.action_dist.proba_distribution(mean_actions, std)
        actions = distribution.get_actions()

        if len(actions.shape) == 2 and actions.shape[0] == 1:
            actions = actions[0]

        return actions

    def evaluate_actions(self, obs, actions, lengths=None):
        """Per-step log-probability of stored actions under the backward policy.

        Args:
            obs: observations. Assumed (batch, seq, state_dim) when ``lengths``
                is given -- TODO confirm with callers.
            actions: stored actions. Column 0 of the last dimension is skipped;
                presumably it is the exit flag that ``ContinuousForwardPolicy.forward``
                prepends.
            lengths: optional per-trajectory lengths tensor. Steps with index
                >= length - 1 are set to 0.

        Returns:
            Tensor of per-step log-probabilities.
        """
        common_output = self.common_net(obs)
        mean_actions = self.action_net(common_output)
        std = self.std_net(common_output) * (MAX_STD - MIN_STD) + MIN_STD

        distribution = self.action_dist.proba_distribution(mean_actions, std)
        log_prob = distribution.log_prob(actions[..., 1:])

        if lengths is not None:
            batch_size, seq_len = obs.size(0), obs.size(1)
            steps = torch.arange(seq_len, device=obs.device).expand(batch_size, seq_len)
            # lengths - 1: the last transition is the exit action and is not scored here.
            mask = steps < (lengths.to(obs.device) - 1).unsqueeze(1)
            # masked_fill instead of multiplication: -inf * 0 would yield NaN.
            log_prob = log_prob.masked_fill(~mask, 0.0)

        return log_prob
    

class DiscreteForwardPolicy(nn.Module):
    """Forward policy over a discrete action set with optional action masking.

    An MLP maps observations to ``action_dim`` logits. The logits are passed to
    ``MaskableCategoricalDistribution`` together with optional action masks.
    """

    def __init__(self, state_dim, action_dim, hidden_sizes=(128, 128),
                 activation_fn=torch.nn.ReLU,
                 device="cpu"):
        """Build the logits network.

        Args:
            state_dim: size of the last dimension of the observations.
            action_dim: number of discrete actions.
            hidden_sizes: widths of the hidden layers; must be non-empty.
                A tuple default avoids the shared-mutable-default pitfall.
            activation_fn: activation module class instantiated after each hidden layer.
            device: device used when ``forward`` converts non-tensor observations.
        """
        super().__init__()

        self.device = device
        self.action_dim = action_dim

        layers = [nn.Linear(state_dim, hidden_sizes[0]), activation_fn()]
        for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [nn.Linear(in_size, out_size), activation_fn()]
        # Output raw logits (no softmax); they are handed to the distribution as logits.
        layers.append(nn.Linear(hidden_sizes[-1], action_dim))
        self.action_net = nn.Sequential(*layers)

        self.distribution = MaskableCategoricalDistribution(self.action_dim)

    def forward(self, obs, action_masks: Optional[torch.Tensor] = None, epsilon=0):
        """Sample discrete actions.

        Args:
            obs: observation(s). Non-tensors are converted to float on ``self.device``.
            action_masks: optional mask forwarded to the distribution.
            epsilon: per-sample probability of zeroing that sample's logits.
                Presumably this yields a uniform choice over the actions left
                valid by ``action_masks`` -- depends on how the distribution
                applies masks.

        Returns:
            Sampled actions. A 2-D result with a single row is returned as that row.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float, device=self.device)

        action_logits = self.action_net(obs)

        if epsilon > 0.0:
            # One coin per sample over every leading dimension. Unlike the old
            # `batch_size, _ = action_logits.shape` unpack, this also accepts a
            # single unbatched state.
            rand_mask = (torch.rand(action_logits.shape[:-1], device=obs.device) < epsilon).float().unsqueeze(-1)
            # Blend towards all-zero (uniform) logits for the selected samples.
            action_logits = (1 - rand_mask) * action_logits + rand_mask * torch.zeros_like(action_logits)

        self.distribution.proba_distribution(action_logits, action_masks)
        actions = self.distribution.get_actions()

        if len(actions.shape) == 2 and actions.shape[0] == 1:
            actions = actions[0]

        return actions

    def evaluate_actions(self, obs, actions, action_masks: Optional[torch.Tensor] = None, lengths=None):
        """Per-step log-probability of stored discrete actions.

        Args:
            obs: observations. Assumed (batch, seq, state_dim) when ``lengths``
                is given -- TODO confirm with callers.
            actions: stored actions to score.
            action_masks: optional mask forwarded to the distribution.
            lengths: optional per-trajectory lengths tensor. Steps with index
                >= length are set to 0.

        Returns:
            Tensor of per-step log-probabilities. Prints a diagnostic when any
            value is below -1000.
        """
        action_logits = self.action_net(obs)

        self.distribution.proba_distribution(action_logits, action_masks)
        log_prob = self.distribution.log_prob(actions)

        if lengths is not None:
            batch_size, seq_len = obs.size(0), obs.size(1)
            steps = torch.arange(seq_len, device=obs.device).expand(batch_size, seq_len)
            mask = steps < lengths.to(obs.device).unsqueeze(1)
            # masked_fill instead of multiplication: -inf * 0 would yield NaN.
            log_prob = log_prob.masked_fill(~mask, 0.0)

        # Diagnostic: a very negative log-prob presumably means a stored action
        # is marked invalid by action_masks. Note .item() forces a device sync.
        if log_prob.min().item() < -1000:
            print("Forward masking is wrong")

        return log_prob
    
class DiscreteBackwardPolicy(nn.Module):
    """Backward policy over a discrete action set with optional action masking.

    An MLP maps observations to ``action_dim`` logits. The logits are passed to
    ``MaskableCategoricalDistribution`` together with optional action masks.
    """

    def __init__(self, state_dim, action_dim, hidden_sizes=(128, 128),
                 activation_fn=torch.nn.ReLU,
                 device="cpu"):
        """Build the logits network.

        Args:
            state_dim: size of the last dimension of the observations.
            action_dim: number of discrete actions.
            hidden_sizes: widths of the hidden layers; must be non-empty.
                A tuple default avoids the shared-mutable-default pitfall.
            activation_fn: activation module class instantiated after each hidden layer.
            device: device used when ``forward`` converts non-tensor observations.
        """
        super().__init__()

        self.device = device
        self.action_dim = action_dim

        layers = [nn.Linear(state_dim, hidden_sizes[0]), activation_fn()]
        for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [nn.Linear(in_size, out_size), activation_fn()]
        # Output raw logits; they are handed to the distribution as logits.
        layers.append(nn.Linear(hidden_sizes[-1], action_dim))
        self.action_net = nn.Sequential(*layers)

        self.distribution = MaskableCategoricalDistribution(self.action_dim)

    def forward(self, obs, action_masks: Optional[torch.Tensor] = None):
        """Sample backward actions.

        Non-tensor observations are converted to float on ``self.device``. A
        2-D result with a single row is returned as that row.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float, device=self.device)
        action_logits = self.action_net(obs)

        self.distribution.proba_distribution(action_logits, action_masks)
        actions = self.distribution.get_actions()

        if len(actions.shape) == 2 and actions.shape[0] == 1:
            actions = actions[0]

        return actions

    def evaluate_actions(self, obs, actions, action_masks: Optional[torch.Tensor] = None, lengths=None):
        """Per-step log-probability of stored actions under the backward policy.

        Args:
            obs: observations. Assumed (batch, seq, state_dim) when ``lengths``
                is given -- TODO confirm with callers.
            actions: stored actions to score.
            action_masks: optional mask forwarded to the distribution.
            lengths: optional per-trajectory lengths tensor. Steps with index
                >= length - 1 are set to 0. When None, no length masking is
                applied; previously None raised a TypeError despite being the
                default.

        Returns:
            Tensor of per-step log-probabilities. Prints a diagnostic when any
            value is below -1000.
        """
        action_logits = self.action_net(obs)

        self.distribution.proba_distribution(action_logits, action_masks)
        log_prob = self.distribution.log_prob(actions)

        if lengths is not None:
            batch_size, seq_len = obs.size(0), obs.size(1)
            steps = torch.arange(seq_len, device=obs.device).expand(batch_size, seq_len)
            # lengths - 1: the last transition is ignored by the backward policy.
            mask = steps < (lengths.to(obs.device) - 1).unsqueeze(1)
            # masked_fill instead of multiplication: -inf * 0 would yield NaN.
            log_prob = log_prob.masked_fill(~mask, 0.0)

        # Diagnostic: a very negative log-prob presumably means a stored action
        # is marked invalid by action_masks. Note .item() forces a device sync.
        if log_prob.min().item() < -1000:
            print("Backward masking is wrong")

        return log_prob


class Critic(nn.Module):  # critic for value (advantage) function
    """MLP mapping an observation to a single scalar output (last dimension 1)."""

    def __init__(self, state_dim, hidden_sizes=(128, 128),
                 activation_fn=torch.nn.ReLU,
                 device="cpu"):
        """Build the value network.

        Args:
            state_dim: size of the last dimension of the observations.
            hidden_sizes: widths of the hidden layers; must be non-empty.
                A tuple default avoids the shared-mutable-default pitfall.
            activation_fn: activation module class instantiated after each hidden layer.
            device: device used when ``forward`` converts non-tensor observations.
        """
        super().__init__()
        self.device = device

        layers = [nn.Linear(state_dim, hidden_sizes[0]), activation_fn()]
        for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [nn.Linear(in_size, out_size), activation_fn()]
        layers.append(nn.Linear(hidden_sizes[-1], 1))

        self.critic = nn.Sequential(*layers)

    def forward(self, obs):
        """Return the value estimate(s) with a trailing dimension of size 1.

        Non-tensor observations are converted to float on ``self.device``.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float, device=self.device)
        return self.critic(obs)

class logF(nn.Module):  # logF scalar estimator for SubTB GFlowNets; same architecture as Critic.
    """MLP mapping a state to a single scalar log-flow estimate (last dimension 1)."""

    def __init__(self, state_dim, hidden_sizes=(128, 128),
                 activation_fn=torch.nn.ReLU,
                 device="cpu"):
        """Build the state-flow network.

        Args:
            state_dim: size of the last dimension of the observations.
            hidden_sizes: widths of the hidden layers; must be non-empty.
                A tuple default avoids the shared-mutable-default pitfall.
            activation_fn: activation module class instantiated after each hidden layer.
            device: device used when ``forward`` converts non-tensor observations.
        """
        super().__init__()
        self.device = device

        layers = [nn.Linear(state_dim, hidden_sizes[0]), activation_fn()]
        for in_size, out_size in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [nn.Linear(in_size, out_size), activation_fn()]
        layers.append(nn.Linear(hidden_sizes[-1], 1))

        # Attribute name kept as `critic` so existing state_dict keys still load.
        self.critic = nn.Sequential(*layers)

    def forward(self, obs):
        """Return the log-flow estimate(s) with a trailing dimension of size 1.

        Non-tensor observations are converted to float on ``self.device``.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float, device=self.device)
        return self.critic(obs)


# Used for GAFN
class RND(nn.Module):
    """Random Network Distillation novelty signal.

    ``random_target_network`` and ``predictor_network`` both embed a state into
    ``s_latent_dim`` features. The L2 distance between the two embeddings
    serves as:

    * the intrinsic reward, scaled by ``reward_scale``;
    * the predictor's training loss. The target's output is detached there, so
      this loss does not train the target.
    """

    def __init__(self, state_dim, reward_scale=0.5, hidden_dim=256, s_latent_dim=128, device="cpu"):
        """Build both networks.

        Args:
            state_dim: size of the last dimension of the states.
            reward_scale: multiplier applied to the intrinsic reward.
            hidden_dim: width of both hidden layers in each network.
            s_latent_dim: embedding size produced by each network.
            device: stored on the instance; not used to move the module here.
        """
        super().__init__()
        self.device = device

        self.random_target_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, s_latent_dim),
        )

        self.predictor_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, s_latent_dim),
        )

        self.reward_scale = reward_scale

    def forward(self, next_state):
        """Return ``(target_embedding, predicted_embedding)`` for ``next_state``."""
        random_phi_s_next = self.random_target_network(next_state)
        predicted_phi_s_next = self.predictor_network(next_state)
        return random_phi_s_next, predicted_phi_s_next

    @staticmethod
    def _step_mask(next_states, lengths):
        """Boolean (batch, seq) mask that is true where step index < trajectory length."""
        batch_size, seq_len = next_states.size(0), next_states.size(1)
        steps = torch.arange(seq_len, device=next_states.device).expand(batch_size, seq_len)
        return steps < lengths.to(next_states.device).unsqueeze(1)

    def compute_intrinsic_reward(self, next_states, lengths):
        """Per-step novelty bonus, zeroed past each trajectory's length.

        Args:
            next_states: states indexed (batch, seq, ...) as read by ``size(0)``
                and ``size(1)``.
            lengths: tensor of per-trajectory lengths, shape (batch,).

        Returns:
            ``reward_scale * ||predicted - target||_2`` per step, with no autograd graph.
        """
        # The result is detached anyway, so skip building a graph at all.
        with torch.no_grad():
            random_phi_s_next, predicted_phi_s_next = self.forward(next_states)
            intrinsic_reward = torch.linalg.vector_norm(
                predicted_phi_s_next - random_phi_s_next, ord=2, dim=-1
            )
            intrinsic_reward = intrinsic_reward * self.reward_scale

        return intrinsic_reward * self._step_mask(next_states, lengths)

    def compute_loss(self, next_states, lengths):
        """Predictor distillation loss: mean L2 distance to the detached target.

        Args:
            next_states: states indexed (batch, seq, ...).
            lengths: tensor of per-trajectory lengths, shape (batch,).

        Returns:
            Scalar loss tensor.

        Raises:
            AssertionError: if the loss is NaN.
        """
        random_phi_s_next, predicted_phi_s_next = self.forward(next_states)
        rnd_loss = torch.linalg.vector_norm(
            predicted_phi_s_next - random_phi_s_next.detach(), ord=2, dim=-1
        )
        rnd_loss = rnd_loss * self._step_mask(next_states, lengths)

        # NOTE(review): this averages over all batch*seq slots, padded ones
        # included, so the loss scale depends on the amount of padding. Kept
        # as-is to preserve the existing training behaviour.
        mean_rnd_loss = torch.mean(rnd_loss)

        assert not torch.isnan(mean_rnd_loss).any()

        return mean_rnd_loss