import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from gen_rl.policy.env_models import DIM_LATENT


class Actor(nn.Module):
    """Deterministic policy: a 400-300 ReLU MLP whose tanh output is scaled by ``max_action``.

    Args:
        args: config dict; reads ``state_dim``, ``action_dim`` and ``max_action``.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self._args = args
        in_width = args["state_dim"]
        out_width = args["action_dim"]
        scale = args["max_action"]

        self.l1 = nn.Linear(in_width, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, out_width)

        self.max_action = scale

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))`` for a batch of states."""
        hidden = state
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed


class Critic(nn.Module):
    """Value critic with a 400-300 MLP; in Q mode the action joins after the first layer.

    Input width of ``l1``:
      * ``DIM_LATENT`` if ``if_use_latent_state`` else ``state_dim``;
      * plus ``state_dim`` more when ``if_use_prev_state`` is set and
        ``if_use_act_val_fn`` is not (state-value mode).

    NOTE(review): with ``if_use_latent_state`` and ``if_use_prev_state`` both set
    (V mode), ``l1`` expects ``DIM_LATENT + state_dim`` features, whereas
    ``Critic2``/``Critic_TD3`` double the width (``2 * DIM_LATENT``) — confirm
    which layout the caller actually feeds.

    Args:
        args: config dict; reads ``state_dim``, ``action_dim``, ``max_action``,
            ``if_use_act_val_fn``, ``if_use_latent_state`` and (in V mode)
            ``if_use_prev_state``.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self._args = args
        self._if_use_act_val_fn = args["if_use_act_val_fn"]
        state_dim, action_dim, _ = (args[key] for key in ("state_dim", "action_dim", "max_action"))

        in_width = DIM_LATENT if args["if_use_latent_state"] else state_dim
        if not self._if_use_act_val_fn and args["if_use_prev_state"]:
            in_width += state_dim
        act_width = action_dim if self._if_use_act_val_fn else 0

        self.l1 = nn.Linear(in_width, 400)
        self.l2 = nn.Linear(400 + act_width, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state, action=None):
        """Return one scalar value per row; ``action`` is only used in Q mode."""
        features = F.relu(self.l1(state))
        if self._if_use_act_val_fn:
            features = torch.cat([features, action], 1)
        hidden = F.relu(self.l2(features))
        return self.l3(hidden)


def get_img_obs_encoder():
    """Build the CNN that maps a 3-channel image batch to ``DIM_LATENT`` features.

    Three Conv2d+ReLU stages, then Flatten -> Linear(1024, 512) -> ReLU ->
    Linear(512, DIM_LATENT). The 1024 flatten width equals 64 channels x 4 x 4,
    which is what a 64x64 input produces through these convolutions; other input
    sizes presumably break the first Linear — confirm the caller's image size.

    Returns:
        nn.Sequential: the encoder (child indices 0..9 in layer order).
    """
    from gen_rl.policy.cvae_cnn import Flatten

    # (in_channels, out_channels, kernel_size, stride) for each conv stage.
    conv_stages = ((3, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1))
    layers = []
    for in_ch, out_ch, kernel, stride in conv_stages:
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride))
        layers.append(nn.ReLU())
    layers.extend([Flatten(), nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, DIM_LATENT)])
    return nn.Sequential(*layers)


class Actor2(nn.Module):
    """Deterministic policy: optional image encoder, then a 256-256 MLP with a tanh head.

    When ``args["mjc_if_pomdp"]`` is set, observations go through
    ``get_img_obs_encoder()`` (``DIM_LATENT`` output features) before the MLP;
    otherwise the MLP consumes ``state_dim``-wide state vectors directly.

    Args:
        args: config dict; reads ``state_dim``, ``action_dim``, ``max_action`` and
            ``mjc_if_pomdp``.
    """

    def __init__(self, args, **kwargs):
        super(Actor2, self).__init__()

        self._args = args
        state_dim, action_dim, max_action = self._args["state_dim"], self._args["action_dim"], self._args["max_action"]

        if self._args["mjc_if_pomdp"]:
            self.obs_enc = get_img_obs_encoder()
        self.net = nn.Sequential(
            nn.Linear(DIM_LATENT if args["mjc_if_pomdp"] else state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )
        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(net(features))`` for a batch of observations.

        In POMDP mode ``state`` is a channels-last image batch (it is permuted
        ``(0, 3, 1, 2)`` before the conv encoder), presumably with 0-255 pixel
        values, which are scaled by 1/255 here. The input tensor is not modified.
        """
        if self._args["mjc_if_pomdp"]:
            # Out-of-place on purpose: an in-place ``/=`` rescales the caller's tensor
            # (re-feeding the same batch would divide it twice) and raises for integer
            # (e.g. uint8) observations.
            state = self.obs_enc((state / 255.0).permute(0, 3, 1, 2))
        a = self.net(state)
        return self.max_action * torch.tanh(a)


class Critic2(nn.Module):
    """Value critic: optional image encoder, then a 256-256 MLP over [state(, action)].

    Input width of the first Linear:
      * POMDP mode (``mjc_if_pomdp``): ``DIM_LATENT`` (encoder output);
      * otherwise ``DIM_LATENT`` if ``if_use_latent_state`` else ``state_dim``,
        doubled when ``if_use_prev_state`` is set and ``if_use_act_val_fn`` is not;
      * plus ``action_dim`` in Q mode (``if_use_act_val_fn``).

    Args:
        args: config dict; reads ``state_dim``, ``action_dim``, ``if_use_act_val_fn``,
            ``if_use_latent_state``, ``if_use_prev_state`` and ``mjc_if_pomdp``.
        num_outputs: optional keyword; width of the output layer (default 1).
    """

    def __init__(self, args, **kwargs):
        super(Critic2, self).__init__()
        self._args = args
        self._if_use_act_val_fn = args["if_use_act_val_fn"]
        state_dim, action_dim = self._args["state_dim"], self._args["action_dim"]

        _state_dim = DIM_LATENT if args["if_use_latent_state"] else state_dim
        # V mode with the previous state: presumably (prev_state, state) are
        # concatenated, so the width doubles — confirm against the caller.
        _state_dim *= 1 if args["if_use_act_val_fn"] or not args["if_use_prev_state"] else 2
        _action_dim = action_dim if args["if_use_act_val_fn"] else 0

        if self._args["mjc_if_pomdp"]:
            self.obs_enc = get_img_obs_encoder()
        # In POMDP mode the encoder's DIM_LATENT features replace the state features.
        in_dim = (DIM_LATENT if args["mjc_if_pomdp"] else _state_dim) + _action_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, kwargs.get("num_outputs", 1)),
        )

    def forward(self, state, action=None):
        """Return the critic output for a batch; ``action`` is only used in Q mode.

        In POMDP mode ``state`` is a channels-last image batch (permuted
        ``(0, 3, 1, 2)`` for the conv encoder), scaled by 1/255 here. The input
        tensor is not modified.
        """
        if self._args["mjc_if_pomdp"]:
            # Out-of-place on purpose: an in-place ``/=`` rescales the caller's tensor
            # (re-feeding the same batch would divide it twice) and raises for integer
            # (e.g. uint8) observations.
            state = self.obs_enc((state / 255.0).permute(0, 3, 1, 2))
        _in = torch.cat([state, action], 1) if self._if_use_act_val_fn else state
        return self.net(_in)


class Critic_TD3(nn.Module):
    """Twin-head critic: two independent 256-256 MLPs (Q1: l1-l3, Q2: l4-l6) on one input.

    Both heads take ``[state, action]`` in Q mode (``if_use_act_val_fn``) or just
    ``state`` otherwise. The state width is ``DIM_LATENT`` if ``if_use_latent_state``
    else ``state_dim``, doubled in V mode when ``if_use_prev_state`` is set.

    Args:
        args: config dict; reads ``state_dim``, ``action_dim``, ``max_action``,
            ``if_use_act_val_fn``, ``if_use_latent_state`` and (in V mode)
            ``if_use_prev_state``.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self._args = args
        self._if_use_act_val_fn = args["if_use_act_val_fn"]
        state_dim, action_dim, _ = (args[key] for key in ("state_dim", "action_dim", "max_action"))

        width = DIM_LATENT if args["if_use_latent_state"] else state_dim
        if not self._if_use_act_val_fn and args["if_use_prev_state"]:
            width *= 2
        if self._if_use_act_val_fn:
            width += action_dim

        # Q1 head
        self.l1 = nn.Linear(width, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 head
        self.l4 = nn.Linear(width, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def _joint_input(self, state, action):
        """Concatenate the action onto the state in Q mode; pass the state through otherwise."""
        if self._if_use_act_val_fn:
            return torch.cat([state, action], 1)
        return state

    @staticmethod
    def _head(x, first, second, last):
        """Run one 3-layer ReLU head."""
        return last(F.relu(second(F.relu(first(x)))))

    def forward(self, state, action=None):
        """Return ``(q1, q2)`` from both heads."""
        joint = self._joint_input(state, action)
        return (
            self._head(joint, self.l1, self.l2, self.l3),
            self._head(joint, self.l4, self.l5, self.l6),
        )

    def Q1(self, state, action=None):
        """Return only the first head's output."""
        return self._head(self._joint_input(state, action), self.l1, self.l2, self.l3)


class Actor_TD3(nn.Module):
    """Deterministic TD3 policy: a 256-256 ReLU MLP with a ``max_action``-scaled tanh head.

    The actor's input width is always ``state_dim``; the critic-side options
    (``if_use_prev_state`` / latent states) do not affect it.

    Args:
        args: config dict; reads ``state_dim``, ``action_dim``, ``max_action`` and
            ``if_use_act_val_fn`` (stored as ``self._if_use_act_val_fn``).
    """

    def __init__(self, args, **kwargs):
        super(Actor_TD3, self).__init__()

        self._args = args
        self._if_use_act_val_fn = args["if_use_act_val_fn"]
        state_dim, action_dim, max_action = self._args["state_dim"], self._args["action_dim"], self._args["max_action"]

        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)

        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))`` for a batch of states."""
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))


def weights_init_(m):
    """Initialise ``nn.Linear`` layers in place: Xavier-uniform weights and zero bias.

    Meant to be passed to ``module.apply``; modules that are not ``nn.Linear``
    are left untouched.

    Args:
        m: any ``nn.Module``.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        # Layers built with bias=False have m.bias == None; skip instead of crashing.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)


# Clamp range applied to the predicted log standard deviation in Actor_SAC.
LOG_SIG_MAX, LOG_SIG_MIN = 2, -20
# Small constant added inside the tanh log-Jacobian term in Actor_SAC to avoid log(0).
epsilon = 1e-6


class Actor_SAC(nn.Module):
    """Squashed-Gaussian policy for SAC.

    A 256-256 ReLU MLP predicts a per-dimension mean and a clamped log standard
    deviation; ``forward`` draws a reparameterised sample ``x`` and returns
    ``max_action * tanh(x)``.

    ``forward`` also stores on the module:
        self.log_prob: log-density of the returned action, summed over the last
            dimension (``keepdim=True``), including the tanh/scale correction.
        self.mean: deterministic action ``max_action * tanh(mean)``.
        self._mean, self._std: detached NumPy copies of the Gaussian parameters
            (for visualisation).

    Args:
        args: config dict; reads ``state_dim``, ``action_dim``, ``max_action`` and
            ``if_use_act_val_fn``.
    """

    def __init__(self, args, **kwargs):
        super(Actor_SAC, self).__init__()

        self._args = args
        self._if_use_act_val_fn = args["if_use_act_val_fn"]
        state_dim, action_dim, max_action = self._args["state_dim"], self._args["action_dim"], self._args["max_action"]

        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.max_action = max_action

        self.mean_linear = nn.Linear(256, action_dim)
        self.log_std_linear = nn.Linear(256, action_dim)
        self.apply(weights_init_)

    def forward(self, state):
        """Sample an action for each state; see the class docstring for stored attributes."""
        # Forward
        x = F.relu(self.l1(state))
        x = F.relu(self.l2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        # Sampling
        std = log_std.exp()
        self._mean, self._std = mean.detach().cpu().numpy(), std.detach().cpu().numpy()  # for visualisation purpose
        normal = Normal(mean, std)

        # Reparametrisation trick: rsample keeps the sample differentiable w.r.t. mean/std.
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.max_action
        log_prob = normal.log_prob(x_t)

        # Enforcing action bound (change of variables for a = max_action * tanh(x)):
        #   log pi(a) = log N(x) - log|da/dx|,  da/dx = max_action * (1 - tanh(x)^2).
        # The Jacobian must use tanh(x), not the scaled action: 1 - action^2 goes
        # negative (log -> NaN) when max_action > 1 and omits the log(max_action) term.
        log_prob = log_prob - torch.log((1 - y_t.pow(2)) * self.max_action + epsilon)
        self.log_prob = log_prob.sum(-1, keepdim=True)
        self.mean = torch.tanh(mean) * self.max_action
        return action
