import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam


# Clamp range for the log standard deviation predicted by the policy
# (applied in GaussianPolicy.forward).
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant added inside the log of the tanh-squashing correction in
# GaussianPolicy.sample, keeping the argument away from zero when tanh saturates.
epsilon = 1e-6

# Initialize network weights (applied to both the critic and the policy via nn.Module.apply)
def weights_init_(m):
    """Initialise an ``nn.Linear`` layer in place; ignore every other module.

    Intended to be passed to ``nn.Module.apply``, which calls it once per
    sub-module.

    Args:
        m: Any module. Only ``nn.Linear`` instances are modified: the weight
            gets Xavier-uniform initialisation (gain 1) and the bias, when
            the layer has one, is set to zero.
    """
    if not isinstance(m, nn.Linear):
        return

    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # ``nn.Linear(..., bias=False)`` stores ``bias=None``; filling it would raise.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

# For SAC Critic
class QNetwork(nn.Module):
    """Twin critic (Q1, Q2) built from two skill-specific input branches.

    The state's last dimension is read as ``[obs | psd_skill | metra_skill]``.
    Each critic runs one MLP on ``(obs, metra_skill, action)`` and another on
    ``(obs, psd_skill, action)``, concatenates the two feature vectors and maps
    them to a scalar with a final linear layer.

    Args:
        num_inputs: Width of the full state vector, including both skill blocks.
        num_actions: Width of the action vector.
        hidden_dim: Hidden width of each branch; every branch emits
            ``hidden_dim // 2`` features.
        args: Object exposing ``metra_skill_dim`` and ``radius_input_dim``
            (the latter is used as the width of the psd-skill block).
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, args):
        super().__init__()

        self.metra_skill_dim = args.metra_skill_dim
        self.psd_skill_dim = args.radius_input_dim

        input_dims = [
            num_inputs - self.psd_skill_dim + num_actions,     # (obs, metra_skill, action)
            num_inputs - self.metra_skill_dim + num_actions,   # (obs, psd_skill, action)
        ]
        split_dims = [hidden_dim // 2, hidden_dim // 2]

        # Construction order (q1 branches, q2 branches, q1 head, q2 head) is
        # kept fixed so that seeded initialisation and state_dict keys are stable.
        self.q1_layers = self._make_branches(input_dims, split_dims, hidden_dim)
        self.q2_layers = self._make_branches(input_dims, split_dims, hidden_dim)
        self.q1_final = nn.Sequential(nn.Linear(sum(split_dims), 1))
        self.q2_final = nn.Sequential(nn.Linear(sum(split_dims), 1))

        self.apply(weights_init_)

    @staticmethod
    def _make_branches(input_dims, split_dims, hidden_dim):
        """Build one two-layer ReLU MLP per (input width, output width) pair."""
        return nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, split_dim),
                nn.ReLU(),
            )
            for input_dim, split_dim in zip(input_dims, split_dims)
        ])

    def _split_state(self, state):
        """Split ``state`` into ``(obs, psd_skill, metra_skill)`` column blocks.

        Positive offsets are used instead of negative indices because
        ``x[:, -0:]`` selects every column and ``x[:, -k:-0]`` selects none,
        which silently mis-slices the state when a skill width is 0.
        """
        obs_dim = state.shape[-1] - self.psd_skill_dim - self.metra_skill_dim
        psd_end = obs_dim + self.psd_skill_dim
        return state[:, :obs_dim], state[:, obs_dim:psd_end], state[:, psd_end:]

    @staticmethod
    def _q_head(branches, final, q_inputs):
        """Run each branch on its own input, concatenate, and apply the head."""
        features = [mlp(input_) for mlp, input_ in zip(branches, q_inputs)]
        return final(torch.cat(features, dim=-1))

    def forward(self, state, action):
        """Evaluate both critics.

        Args:
            state: 2-D tensor whose columns are ``[obs | psd_skill | metra_skill]``.
            action: 2-D tensor with the same number of rows as ``state``.

        Returns:
            Tuple ``(q1_value, q2_value)``; each has one column per row of input.
        """
        obs, psd_skill, metra_skill = self._split_state(state)

        q_inputs = [
            torch.cat([obs, metra_skill, action], dim=-1),  # (obs, metra_skill, action)
            torch.cat([obs, psd_skill, action], dim=-1),    # (obs, psd_skill, action)
        ]

        q1_value = self._q_head(self.q1_layers, self.q1_final, q_inputs)
        q2_value = self._q_head(self.q2_layers, self.q2_final, q_inputs)

        return q1_value, q2_value

# For SAC Actor
class GaussianPolicy(nn.Module):
    """Tanh-squashed Gaussian policy built from two skill-specific branches.

    The state's last dimension is read as ``[obs | psd_skill | metra_skill]``.
    One single-layer branch sees ``(obs, psd_skill)``, the other sees
    ``(obs, metra_skill)``; their features are concatenated, passed through a
    shared layer, and mapped to a per-action mean and log standard deviation.

    ``action_scale`` and ``action_bias`` are registered as non-persistent
    buffers: they follow the module through ``.to()``, ``.cuda()``, ``.cpu()``
    and dtype casts, but are not written to ``state_dict``, so checkpoint keys
    are limited to the learnable layers.

    Args:
        num_inputs: Width of the full state vector, including both skill blocks.
        num_actions: Width of the action vector.
        hidden_dim: Width of the shared hidden layer; every branch emits
            ``hidden_dim // 2`` features.
        action_space: Object exposing array-like ``high`` and ``low`` bounds
            used to rescale the squashed action. ``None`` means unit scale and
            zero bias, i.e. actions stay in ``[-1, 1]``.
        args: Object exposing ``metra_skill_dim`` and ``radius_input_dim``
            (the latter is used as the width of the psd-skill block).
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, action_space, args):
        super().__init__()

        self.metra_skill_dim = args.metra_skill_dim
        self.psd_skill_dim = args.radius_input_dim

        input_dims = [
            num_inputs - self.metra_skill_dim,   # (obs, psd_skill)
            num_inputs - self.psd_skill_dim,     # (obs, metra_skill)
        ]
        split_dims = [hidden_dim // 2, hidden_dim // 2]

        self.mlp_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, split_dim),
                nn.ReLU(),
            )
            for input_dim, split_dim in zip(input_dims, split_dims)
        ])

        self.final_mlp = nn.Sequential(
            nn.Linear(sum(split_dims), hidden_dim),
            nn.ReLU(),
        )

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

        # Action rescaling: tanh output in [-1, 1] -> [low, high].
        if action_space is None:
            action_scale = torch.tensor(1.0)
            action_bias = torch.tensor(0.0)
        else:
            action_scale = torch.tensor(
                (action_space.high - action_space.low) / 2.0, dtype=torch.float32
            )
            action_bias = torch.tensor(
                (action_space.high + action_space.low) / 2.0, dtype=torch.float32
            )
        # Buffers move with the module on any device/dtype change;
        # persistent=False keeps them out of state_dict.
        self.register_buffer("action_scale", action_scale, persistent=False)
        self.register_buffer("action_bias", action_bias, persistent=False)

    def _split_state(self, state):
        """Split ``state`` into ``(obs, psd_skill, metra_skill)`` column blocks.

        Positive offsets are used instead of negative indices because
        ``x[:, -0:]`` selects every column and ``x[:, -k:-0]`` selects none,
        which silently mis-slices the state when a skill width is 0.
        """
        obs_dim = state.shape[-1] - self.psd_skill_dim - self.metra_skill_dim
        psd_end = obs_dim + self.psd_skill_dim
        return state[:, :obs_dim], state[:, obs_dim:psd_end], state[:, psd_end:]

    def forward(self, state):
        """Compute the Gaussian parameters for ``state``.

        Args:
            state: 2-D tensor whose columns are ``[obs | psd_skill | metra_skill]``.

        Returns:
            Tuple ``(mean, log_std)``; ``log_std`` is clamped to
            ``[LOG_SIG_MIN, LOG_SIG_MAX]``.
        """
        obs, psd_skill, metra_skill = self._split_state(state)

        inputs = [
            torch.cat([obs, psd_skill], dim=-1),    # (obs, psd_skill)
            torch.cat([obs, metra_skill], dim=-1),  # (obs, metra_skill)
        ]

        features = [mlp(input_) for mlp, input_ in zip(self.mlp_layers, inputs)]
        x = self.final_mlp(torch.cat(features, dim=-1))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        return mean, log_std

    def sample(self, state):
        """Draw a reparameterised action and its log-probability.

        Args:
            state: 2-D tensor accepted by :meth:`forward`.

        Returns:
            Tuple ``(action, log_prob, mean)``:
            ``action`` is the tanh-squashed sample rescaled by
            ``action_scale``/``action_bias``; ``log_prob`` is its
            log-density summed over dimension 1 (kept as a column);
            ``mean`` is the squashed, rescaled distribution mean.
        """
        mean, log_std = self.forward(state)
        normal = Normal(mean, log_std.exp())

        # rsample() draws mean + std * N(0, 1), so gradients flow to mean/std.
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias

        # Change of variables for the tanh squashing and the rescaling;
        # epsilon keeps the log argument away from zero when tanh saturates.
        log_prob = normal.log_prob(x_t)
        log_prob = log_prob - torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        # Sum of per-dimension log-probs == log of the product of probs.
        log_prob = log_prob.sum(1, keepdim=True)

        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean
