import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np
import ipdb
from dm_env import specs

# Upper / lower clamp applied to the predicted log standard deviation in the
# forward() methods of the Gaussian heads defined below.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant added inside torch.log() when correcting the log-probability
# for the tanh squashing, so the logarithm never receives exactly zero.
epsilon = 1e-6

# Initialize nn.Linear weights (passed to Module.apply by the networks below)
def weights_init_(m):
    """Initialise an ``nn.Linear`` layer in place; ignore every other module.

    Meant to be handed to ``nn.Module.apply``. The weight matrix receives
    Xavier-uniform initialisation with gain 1 and the bias, when the layer
    has one, is set to zero.

    Args:
        m: Any module. Only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return

    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # A layer created with ``bias=False`` has ``m.bias is None``; calling
    # constant_ on None would raise, so only touch a bias that exists.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)


class ValueNetwork(nn.Module):
    """MLP with two ReLU hidden layers that maps a state to a single output.

    Args:
        num_inputs: Number of input features of the first linear layer.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, num_inputs, hidden_dim):
        super().__init__()

        # input -> hidden -> hidden -> 1
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Re-initialise every linear layer created above.
        self.apply(weights_init_)

    def forward(self, state):
        """Return the output of the final linear layer for ``state``."""
        hidden = self.linear1(state)
        hidden = self.linear2(F.relu(hidden))
        return self.linear3(F.relu(hidden))

        
class QNetwork(nn.Module):
    """Two independent Q heads evaluated on the concatenated state and action.

    Each head is an MLP with three hidden layers (ELU after the first layer,
    ReLU after the next two) followed by a linear layer with one output.

    Args:
        num_inputs: Number of state features.
        num_actions: Number of action features.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()
        joint_dim = num_inputs + num_actions

        # First head (Q1).
        self.linear1 = nn.Linear(joint_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2_2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Second head (Q2), same layout with its own parameters.
        self.linear4 = nn.Linear(joint_dim, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear5_5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for ``state`` and ``action`` joined along dim 1."""
        joint = torch.cat([state, action], 1)

        q1 = F.elu(self.linear1(joint))
        for layer in (self.linear2, self.linear2_2):
            q1 = F.relu(layer(q1))
        q1 = self.linear3(q1)

        q2 = F.elu(self.linear4(joint))
        for layer in (self.linear5, self.linear5_5):
            q2 = F.relu(layer(q2))
        q2 = self.linear6(q2)

        return q1, q2

class GaussianCausalPolicy(nn.Module):
    """Tanh-squashed Gaussian policy with a per-dimension weighted log-probability.

    An MLP with three hidden layers produces the mean and the clamped log
    standard deviation of a diagonal Normal. Samples are squashed with
    ``tanh`` and mapped to the action range through ``action_scale`` and
    ``action_bias``.

    Args:
        num_inputs: Number of input features of the first linear layer.
        num_actions: Number of action dimensions (size of both output heads).
        hidden_dim: Width of the three hidden layers.
        action_space: ``None`` for a scale of 1 and a bias of 0, a
            ``dm_env.specs.BoundedArray`` (``minimum`` / ``maximum`` are
            read), or any other object exposing ``low`` / ``high``.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
        super(GaussianCausalPolicy, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        # Weight used by sample() when the caller does not supply one:
        # every action dimension counts equally.
        self.causal_default_weight = np.ones(num_actions, dtype=np.float32)

        self.apply(weights_init_)

        # Action rescaling: tanh output in [-1, 1] -> [low, high].
        if action_space is None:
            scale, bias = 1., 0.
        else:
            if isinstance(action_space, specs.BoundedArray):
                low, high = action_space.minimum, action_space.maximum
            else:
                low, high = action_space.low, action_space.high
            low, high = np.asarray(low), np.asarray(high)
            scale, bias = (high - low) / 2., (high + low) / 2.

        # Registered as non-persistent buffers so they follow the module
        # through .to() / .cuda() / .cpu() and replication, while staying out
        # of state_dict (checkpoints keep the same keys as before).
        # torch.tensor also accepts the NumPy scalars produced by 0-d bounds.
        self.register_buffer(
            "action_scale", torch.tensor(scale, dtype=torch.float32), persistent=False)
        self.register_buffer(
            "action_bias", torch.tensor(bias, dtype=torch.float32), persistent=False)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-tanh Normal for ``state``.

        ``log_std`` is clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``.
        """
        x = F.elu(self.linear1(state))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state, causal_weight=None):
        """Draw a reparameterised action and its weighted log-probability.

        Args:
            state: Input to ``forward``. Assumed to be batched as
                ``(batch, num_inputs)``, since the log-probability is summed
                over dimension 1.
            causal_weight: Optional per-action-dimension weight (NumPy array,
                tensor or sequence). Defaults to ``causal_default_weight``.
                It is treated as a constant: no gradient flows into it.

        Returns:
            Tuple ``(action, log_prob, mean)``: the squashed and rescaled
            sample, the weighted log-probability summed over dimension 1
            (kept as a size-1 dimension), and the squashed, rescaled mean.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound: change-of-variables term for tanh + scaling.
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)

        # Compute the causal-weighted log-probability.
        if causal_weight is None:
            causal_weight = self.causal_default_weight
        # as_tensor accepts arrays, tensors and sequences; forcing the dtype
        # and device of log_prob avoids a silent float64 promotion or a
        # device mismatch in the product below.
        causal_weight = torch.as_tensor(
            causal_weight, dtype=log_prob.dtype, device=log_prob.device).detach()
        log_prob = log_prob * causal_weight.unsqueeze(0)

        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def get_log_density(self, state, action):
        """Return the element-wise log-density of ``action`` under ``Normal(mean, std)``.

        NOTE(review): unlike ``sample``, no tanh / rescaling correction is
        applied here, so ``action`` is evaluated directly under the
        un-squashed Gaussian. Confirm this is what callers expect.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        log_density = normal.log_prob(action)
        return log_density

    def to(self, *args, **kwargs):
        """Move / cast the module; accepts everything ``nn.Module.to`` does.

        ``action_scale`` and ``action_bias`` are buffers, so the base
        implementation already carries them along.
        """
        return super(GaussianCausalPolicy, self).to(*args, **kwargs)


class GaussianDynamics(nn.Module):
    """Tanh-squashed Gaussian model that predicts a next state.

    An MLP with three hidden layers produces the mean and the clamped log
    standard deviation of a diagonal Normal. Samples are squashed with
    ``tanh`` and mapped to the state range through ``state_scale`` and
    ``state_bias``.

    Args:
        num_inputs: Number of input features of the first linear layer
            (the size of the ``state_action`` input).
        num_state: Number of predicted state dimensions.
        hidden_dim: Width of the three hidden layers.
        state_space: ``None`` for a scale of 1 and a bias of 0, a
            ``dm_env.specs.BoundedArray`` (``minimum`` / ``maximum`` are
            read), or any other object exposing ``low`` / ``high``.
    """

    def __init__(self, num_inputs, num_state, hidden_dim, state_space=None):
        super(GaussianDynamics, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_state)
        self.log_std_linear = nn.Linear(hidden_dim, num_state)

        # Kept as a public attribute; sample() in this class does not use it.
        self.causal_default_weight = np.ones(num_state, dtype=np.float32)

        self.apply(weights_init_)

        # State rescaling: tanh output in [-1, 1] -> [low, high].
        if state_space is None:
            scale, bias = 1., 0.
        else:
            if isinstance(state_space, specs.BoundedArray):
                low, high = state_space.minimum, state_space.maximum
            else:
                low, high = state_space.low, state_space.high
            low, high = np.asarray(low), np.asarray(high)
            scale, bias = (high - low) / 2., (high + low) / 2.

        # Registered as non-persistent buffers so they follow the module
        # through .to() / .cuda() / .cpu() and replication, while staying out
        # of state_dict (checkpoints keep the same keys as before).
        # torch.tensor also accepts the NumPy scalars produced by 0-d bounds.
        self.register_buffer(
            "state_scale", torch.tensor(scale, dtype=torch.float32), persistent=False)
        self.register_buffer(
            "state_bias", torch.tensor(bias, dtype=torch.float32), persistent=False)

    def forward(self, state_action):
        """Return ``(mean, log_std)`` of the pre-tanh Normal for ``state_action``.

        ``log_std`` is clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``.
        """
        x = F.elu(self.linear1(state_action))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state_action):
        """Draw a reparameterised next state and its log-probability.

        Args:
            state_action: Input to ``forward``. Assumed to be batched as
                ``(batch, num_inputs)``, since the log-probability is summed
                over dimension 1.

        Returns:
            Tuple ``(next_state, log_prob, mean)``: the squashed and rescaled
            sample, its log-probability summed over dimension 1 (kept as a
            size-1 dimension), and the squashed, rescaled mean.
        """
        mean, log_std = self.forward(state_action)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        next_state = y_t * self.state_scale + self.state_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing State Bound: change-of-variables term for tanh + scaling.
        log_prob -= torch.log(self.state_scale * (1 - y_t.pow(2)) + epsilon)
        # No per-dimension causal weighting is applied to this log-probability.

        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.state_scale + self.state_bias
        return next_state, log_prob, mean

    def to(self, *args, **kwargs):
        """Move / cast the module; accepts everything ``nn.Module.to`` does.

        ``state_scale`` and ``state_bias`` are buffers, so the base
        implementation already carries them along.
        """
        return super(GaussianDynamics, self).to(*args, **kwargs)
