import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal


# class NormalizeDensity():
#     def __init__(self, density, size_x=4, size_y=4):
#         self.density = density
#         self.size_x, self.size_y = size_x, size_y
    
#         n_pts = 101
#         x = np.linspace(0, size_x, n_pts)
#         xx, yy = np.meshgrid((x, x))
#         zz = np.stack((xx.flatten(), yy.flatten()), axis=1)
#         zz = zz.reshape((-1, 2))
        
#         self.total_sum = np.sum(density.score_samples(zz))

#     def score_samples(self, x):
#         scores = self.density.score_samples(x)
#         return scores / self.total_sum

class Clamp(nn.Module):
    """Activation module that bounds every element of its input to [-10, 10]."""

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used.
        super().__init__()

    def forward(self, input):
        """Return ``input`` with each element limited to the range [-10, 10]."""
        lower, upper = -10, 10
        return input.clamp(min=lower, max=upper)

class ScaledTanh(nn.Module):
    """Activation module computing ``10 * tanh(x)`` element-wise."""

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used.
        super().__init__()

    def forward(self, input):
        """Return the hyperbolic tangent of ``input`` multiplied by 10."""
        squashed = input.tanh()
        return squashed * 10

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Consecutive entries of ``sizes`` give the input/output width of each
    ``nn.Linear`` layer. Every linear layer is followed by a freshly
    constructed ``activation()``, except the last one, which is followed by
    ``output_activation()`` (``nn.Identity`` by default).
    """
    final_index = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        nonlinearity = output_activation if idx >= final_index else activation
        modules.append(nonlinearity())
    return nn.Sequential(*modules)

def tanh_mlp(sizes, activation, output_activation=nn.Tanh):
    """Build an MLP whose final layer is followed by ``nn.Tanh`` by default.

    ``sizes`` lists the layer widths; hidden linear layers are followed by
    ``activation()`` and the last linear layer by ``output_activation()``.
    """
    n_linear = len(sizes) - 1
    stack = []
    for pos in range(n_linear):
        is_output = pos == n_linear - 1
        stack.append(nn.Linear(sizes[pos], sizes[pos + 1]))
        stack.append(output_activation() if is_output else activation())
    return nn.Sequential(*stack)

def sigmoid_mlp(sizes, activation, output_activation=nn.Sigmoid):
    """Build an MLP whose final layer is followed by ``nn.Sigmoid`` by default.

    ``sizes`` lists the layer widths; hidden linear layers are followed by
    ``activation()`` and the last linear layer by ``output_activation()``.
    """
    widths = list(zip(sizes[:-1], sizes[1:]))
    components = []
    for layer_no, (fan_in, fan_out) in enumerate(widths, start=1):
        components.append(nn.Linear(fan_in, fan_out))
        if layer_no == len(widths):
            components.append(output_activation())
        else:
            components.append(activation())
    return nn.Sequential(*components)

def clamped_mlp(sizes, activation, output_activation=Clamp):
    """Build an MLP whose final layer is followed by ``Clamp`` by default.

    ``sizes`` lists the layer widths; hidden linear layers are followed by
    ``activation()`` and the last linear layer by ``output_activation()``.
    """
    last = len(sizes) - 2
    parts = []
    for k, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        make_act = activation if k < last else output_activation
        parts.extend((nn.Linear(d_in, d_out), make_act()))
    return nn.Sequential(*parts)

def scaled_tanh(sizes, activation, output_activation=ScaledTanh):
    """Build an MLP whose final layer is followed by ``ScaledTanh`` by default.

    ``sizes`` lists the layer widths; hidden linear layers are followed by
    ``activation()`` and the last linear layer by ``output_activation()``.
    """
    total = len(sizes) - 1
    seq = []
    layer = 0
    while layer < total:
        seq.append(nn.Linear(sizes[layer], sizes[layer + 1]))
        chosen = activation if layer < total - 1 else output_activation
        seq.append(chosen())
        layer += 1
    return nn.Sequential(*seq)


# Upper and lower bounds passed to torch.clamp on the log standard deviation
# predicted by the policy networks in this file, before it is exponentiated.
LOG_STD_MAX = 2
LOG_STD_MIN = -20



class SquashedGaussianMLPActor(nn.Module):
    """Gaussian policy whose samples are squashed with tanh and scaled.

    A shared MLP trunk feeds two linear heads producing the mean and the
    log standard deviation of a diagonal Gaussian. The log standard
    deviation is clamped to ``[LOG_STD_MIN, LOG_STD_MAX]``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def _distribution(self, obs):
        """Return the pre-squash ``Normal`` distribution for ``obs``."""
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return Normal(mu, torch.exp(log_std))

    @staticmethod
    def _tanh_correction(u):
        """Return the tanh change-of-variables term, summed over the last axis.

        This is the numerically stable form of Eq. 21 in appendix C of the
        SAC paper (arXiv 1801.01290). The reduction runs over the last axis
        so that it always matches the axis used for the Gaussian
        log-density, whatever the number of leading batch dimensions.
        """
        return (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(axis=-1)

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Produce an action and, optionally, its log-probability.

        Args:
            obs: observation tensor fed to the trunk network.
            deterministic: when true the Gaussian mean is used instead of a
                reparameterised sample (only used for evaluating the policy
                at test time).
            with_logprob: when false, ``None`` is returned in place of the
                log-probability.

        Returns:
            ``(pi_action, logp_pi)`` where ``pi_action`` is
            ``act_limit * tanh(u)`` for the pre-squash value ``u`` and
            ``logp_pi`` is the log-density of ``u`` with the tanh correction
            applied (or ``None``).
        """
        pi_distribution = self._distribution(obs)
        if deterministic:
            pre_squash = pi_distribution.mean
        else:
            pre_squash = pi_distribution.rsample()

        logp_pi = None
        if with_logprob:
            gaussian_logp = pi_distribution.log_prob(pre_squash).sum(axis=-1)
            logp_pi = gaussian_logp - self._tanh_correction(pre_squash)

        pi_action = self.act_limit * torch.tanh(pre_squash)
        return pi_action, logp_pi

    def log_prob_unclipped(self, obs, action):
        """Return the Gaussian log-density of ``action`` without tanh correction."""
        return self._distribution(obs).log_prob(action).sum(axis=-1)

    def log_prob(self, obs, act):
        """Return the tanh-corrected log-density evaluated at ``act``.

        Both the Gaussian term and the correction are evaluated directly at
        ``act``; that is, ``act`` is used as the pre-tanh value.
        NOTE(review): callers holding squashed/scaled actions should confirm
        this is the quantity they want.
        """
        gaussian_logp = self._distribution(obs).log_prob(act).sum(axis=-1)
        return gaussian_logp - self._tanh_correction(act)


class awacMLPActor(nn.Module):
    """Gaussian policy with a tanh-bounded mean and a state-independent std.

    The mean is ``tanh(mu_layer(net(obs))) * act_limit``. The log standard
    deviation is a learned per-dimension parameter mapped through a sigmoid
    into ``[min_log_std, max_log_std]``. Sampled actions are not squashed.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)

        self.log_std_logits = nn.Parameter(
                    torch.zeros(act_dim, requires_grad=True))
        self.min_log_std = -6
        self.max_log_std = 0
        self.act_limit = act_limit

    def _distribution(self, obs):
        """Return the ``Normal`` action distribution for ``obs``."""
        mu = torch.tanh(self.mu_layer(self.net(obs))) * self.act_limit
        # The sigmoid maps the free logits to (0, 1); rescale to the allowed
        # log-std interval.
        fraction = torch.sigmoid(self.log_std_logits)
        log_std = self.min_log_std + fraction * (
                        self.max_log_std - self.min_log_std)
        return Normal(mu, torch.exp(log_std))

    def _gaussian_log_prob(self, obs, actions):
        """Return the log-density of ``actions`` summed over the last axis."""
        return self._distribution(obs).log_prob(actions).sum(axis=-1)

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Produce an action and, optionally, its Gaussian log-probability.

        Args:
            obs: observation tensor fed to the trunk network.
            deterministic: when true the distribution mean is returned
                instead of a reparameterised sample.
            with_logprob: when false, ``None`` is returned in place of the
                log-probability.

        Returns:
            ``(pi_action, logp_pi)``. No tanh correction is applied to the
            log-probability because the sample itself is not squashed.
        """
        pi_distribution = self._distribution(obs)
        if deterministic:
            pi_action = pi_distribution.mean
        else:
            pi_action = pi_distribution.rsample()

        logp_pi = None
        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
        return pi_action, logp_pi

    def log_prob_unclipped(self, obs, actions):
        """Return the Gaussian log-density of ``actions`` given ``obs``."""
        return self._gaussian_log_prob(obs, actions)

    def get_logprob(self, obs, actions):
        """Return the Gaussian log-density of ``actions`` given ``obs``."""
        return self._gaussian_log_prob(obs, actions)

    def log_prob(self, obs, act):
        """Return the Gaussian log-density of ``act`` given ``obs``.

        Same value as ``get_logprob``; provided under this name so the
        policy exposes the ``log_prob(obs, act)`` method that
        ``MLPActorCritic.log_prob`` calls on its policy.
        """
        return self._gaussian_log_prob(obs, act)



class SquashedGmmMLPActor(nn.Module):
    """Tanh-squashed policy built on an equal-weight mixture of ``k`` Gaussians.

    The two linear heads each output ``k * act_dim`` values, interpreted as
    the means and log standard deviations of ``k`` diagonal Gaussians per
    observation. The log standard deviations are clamped to
    ``[LOG_STD_MIN, LOG_STD_MAX]``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit, k):
        super().__init__()
        print("gmm")
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], k*act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], k*act_dim)
        self.act_limit = act_limit
        self.k = k

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Produce an action and, optionally, its mixture log-probability.

        Args:
            obs: 2-D observation tensor; the first axis is the batch.
            deterministic: when true the mean of the randomly chosen
                component is used instead of a sample from it. The component
                itself is still drawn at random.
            with_logprob: when false, ``None`` is returned in place of the
                log-probability.

        Returns:
            ``(pi_action, logp_pi)`` where ``pi_action`` is
            ``act_limit * tanh(u)`` and ``logp_pi`` is the log of the
            average component density of ``u``, with the tanh correction
            applied (or ``None``).
        """
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = torch.clamp(self.log_std_layer(net_out), LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # n = batch size
        n, _ = mu.shape

        # Reshape to (batch, k, act_dim).
        mu = mu.view(n, self.k, -1)
        std = std.view(n, self.k, -1)

        # NOTE: fixed equal weight. The indices come from NumPy's global RNG;
        # cast to int64 and move to the parameters' device so they are valid
        # index tensors regardless of NumPy's platform default integer type.
        component = torch.from_numpy(
            np.random.randint(0, self.k, n)).long().to(mu.device)
        rows = torch.arange(n, device=mu.device)
        mu_sampled = mu[rows, component]      # (n, act_dim)
        std_sampled = std[rows, component]    # (n, act_dim)

        if deterministic:
            pre_squash = mu_sampled
        else:
            pre_squash = Normal(mu_sampled, std_sampled).rsample()  # (n, act_dim)

        logp_pi = None
        if with_logprob:
            # component_logp[i, j]: log-density of the i-th action under the
            # j-th component; the action is broadcast over the k axis.
            component_logp = Normal(mu, std).log_prob(
                pre_squash.unsqueeze(1)).sum(dim=-1)
            # tanh change-of-variables term (numerically stable form of
            # Eq. 21, appendix C of arXiv 1801.01290); identical for every
            # component.
            correction = (2*(np.log(2) - pre_squash
                             - F.softplus(-2*pre_squash))).sum(dim=-1)
            component_logp = component_logp - correction.unsqueeze(1)
            # log((sum over components of p) / k), computed stably.
            logp_pi = torch.logsumexp(component_logp, dim=1) - float(np.log(self.k))

        pi_action = self.act_limit * torch.tanh(pre_squash)
        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """State-action value network built from ``mlp`` with a single output unit."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q(obs, act) with the trailing singleton axis removed."""
        joint = torch.cat((obs, act), dim=-1)
        value = self.q(joint)
        # Dropping the last axis is critical to ensure q has the right shape.
        return value.squeeze(-1)

class MLPActorCritic(nn.Module):
    """Container holding a policy ``pi`` and two Q-functions ``q1`` / ``q2``.

    The policy class is selected from ``k`` and ``special_policy``:
    ``k == 1`` with ``special_policy == "awac"`` gives ``awacMLPActor``,
    ``k == 1`` otherwise gives ``SquashedGaussianMLPActor``, and any other
    ``k`` gives ``SquashedGmmMLPActor`` with ``k`` components.
    """

    def __init__(self, observation_space, action_space, k, hidden_sizes=(256,256), add_time=False,
                 activation=nn.ReLU, device=torch.device("cpu"),special_policy=""):
        """Build the policy and both Q-functions on ``device``.

        ``obs_dim`` and ``act_dim`` are read from ``shape[0]`` of the two
        spaces and ``act_limit`` from ``action_space.high[0]``.
        ``add_time`` is accepted but not used by this constructor.
        """
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]
        self.device = device

        # build policy and value functions
        if k == 1:
            if special_policy == "awac":
                policy = awacMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
            else:
                policy = SquashedGaussianMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        else:
            policy = SquashedGmmMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit, k)
        self.pi = policy.to(self.device)
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(self.device)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(self.device)

    def act(self, obs, deterministic=False, get_logprob = False):
        """Return a flattened NumPy action, plus its log-prob when requested.

        The log-probability is only computed by the policy when
        ``get_logprob`` is true, so the default call does no discarded work.
        """
        with torch.no_grad():
            a, logpi = self.pi(obs, deterministic, get_logprob)
            action = a.detach().cpu().numpy().flatten()
            if get_logprob:
                return action, logpi.detach().cpu().numpy()
            return action

    def act_batch(self, obs, deterministic=False):
        """Return NumPy actions (not flattened) and their log-probabilities."""
        with torch.no_grad():
            a, logpi = self.pi(obs, deterministic, True)
            return a.detach().cpu().numpy(), logpi.detach().cpu().numpy()

    def log_prob(self, obs, act):
        """Delegate to ``self.pi.log_prob(obs, act)``.

        The selected policy class must define ``log_prob`` for this to work.
        """
        return self.pi.log_prob(obs, act)


class MLPReward(nn.Module):
    """MLP reward model whose scalar output is clamped to ``±clamp_magnitude``.

    Architecture: ``first_fc`` (linear, no activation), then one block per
    consecutive pair in ``hidden_sizes`` (linear, optional batch norm,
    activation), then ``last_fc`` producing one value per input row.
    """

    def __init__(
        self,
        input_dim,
        hidden_sizes=(256,256),
        hid_act='tanh',
        use_bn=False,
        residual=False,
        clamp_magnitude=10.0,
        device=torch.device('cpu'),
        **kwargs
    ):
        """Create the layers.

        Args:
            input_dim: width of each input row.
            hidden_sizes: widths of the hidden layers.
            hid_act: ``'relu'``, ``'leaky_relu'`` or ``'tanh'``.
            use_bn: insert ``nn.BatchNorm1d`` after each block's linear layer.
            residual: add each block's input to its output in ``forward``.
            clamp_magnitude: bound applied to the output in ``forward``.
            device: device inputs are moved to in ``get_scalar_reward``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``hid_act`` is not one of the names above.
        """
        super().__init__()

        if hid_act == 'relu':
            hid_act_class = nn.ReLU
        elif hid_act == 'leaky_relu':
            hid_act_class = nn.LeakyReLU
        elif hid_act == 'tanh':
            hid_act_class = nn.Tanh
        else:
            raise NotImplementedError()

        self.clamp_magnitude = clamp_magnitude
        self.input_dim = input_dim
        self.device = device
        self.residual = residual

        self.first_fc = nn.Linear(input_dim, hidden_sizes[0])
        self.blocks_list = nn.ModuleList()

        for width_in, width_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers = [nn.Linear(width_in, width_out)]
            if use_bn:
                layers.append(nn.BatchNorm1d(width_out))
            layers.append(hid_act_class())
            self.blocks_list.append(nn.Sequential(*layers))

        self.last_fc = nn.Linear(hidden_sizes[-1], 1)

    def forward(self, batch):
        """Return the clamped reward for each row of ``batch``."""
        x = self.first_fc(batch)
        for block in self.blocks_list:
            x = x + block(x) if self.residual else block(x)
        output = self.last_fc(x)
        return torch.clamp(output, min=-1.0*self.clamp_magnitude, max=self.clamp_magnitude)

    def r(self, batch):
        """Alias for ``forward``."""
        return self.forward(batch)

    def get_scalar_reward(self, obs):
        """Evaluate the reward in eval mode and return a flat NumPy array.

        Non-tensor ``obs`` is reshaped to ``(-1, input_dim)`` and converted
        to a float tensor. The module's previous train/eval mode is restored
        afterwards, even if the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                if not torch.is_tensor(obs):
                    obs = torch.FloatTensor(obs.reshape(-1,self.input_dim))
                obs = obs.to(self.device)
                reward = self.forward(obs).cpu().detach().numpy().flatten()
        finally:
            # Put the module back in whichever mode the caller had it in.
            self.train(was_training)
        return reward



class MLPTrexReward(nn.Module):
    """MLP reward model (leaky-ReLU by default) with a clamped scalar output.

    Architecture: ``first_fc`` (linear, no activation), then one block per
    consecutive pair in ``hidden_sizes`` (linear, optional batch norm,
    activation), then ``last_fc`` producing one value per input row, which
    is clamped to ``±clamp_magnitude``.
    """

    def __init__(
        self,
        input_dim,
        hidden_sizes=(256,256),
        hid_act='leaky_relu',
        use_bn=False,
        residual=False,
        clamp_magnitude=10.0,
        device=torch.device('cpu'),
        **kwargs
    ):
        """Create the layers.

        Args:
            input_dim: width of each input row.
            hidden_sizes: widths of the hidden layers.
            hid_act: ``'relu'``, ``'leaky_relu'`` or ``'tanh'``.
            use_bn: insert ``nn.BatchNorm1d`` after each block's linear layer.
            residual: add each block's input to its output in ``forward``.
            clamp_magnitude: bound applied to the output in ``forward``.
            device: device inputs are moved to in ``get_scalar_reward``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``hid_act`` is not one of the names above.
        """
        super().__init__()

        if hid_act == 'relu':
            hid_act_class = nn.ReLU
        elif hid_act == 'leaky_relu':
            hid_act_class = nn.LeakyReLU
        elif hid_act == 'tanh':
            hid_act_class = nn.Tanh
        else:
            raise NotImplementedError()

        self.clamp_magnitude = clamp_magnitude
        self.input_dim = input_dim
        self.device = device
        self.residual = residual

        self.first_fc = nn.Linear(input_dim, hidden_sizes[0])
        self.blocks_list = nn.ModuleList()

        for width_in, width_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers = [nn.Linear(width_in, width_out)]
            if use_bn:
                layers.append(nn.BatchNorm1d(width_out))
            layers.append(hid_act_class())
            self.blocks_list.append(nn.Sequential(*layers))

        self.last_fc = nn.Linear(hidden_sizes[-1], 1)

    def forward(self, batch):
        """Return the clamped reward for each row of ``batch``."""
        x = self.first_fc(batch)
        for block in self.blocks_list:
            x = x + block(x) if self.residual else block(x)
        output = self.last_fc(x)
        return torch.clamp(output, min=-1.0*self.clamp_magnitude, max=self.clamp_magnitude)

    def r(self, batch):
        """Alias for ``forward``."""
        return self.forward(batch)

    def get_scalar_reward(self, obs):
        """Evaluate the reward in eval mode and return a flat NumPy array.

        Non-tensor ``obs`` is reshaped to ``(-1, input_dim)`` and converted
        to a float tensor. The module's previous train/eval mode is restored
        afterwards, even if the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                if not torch.is_tensor(obs):
                    obs = torch.FloatTensor(obs.reshape(-1,self.input_dim))
                obs = obs.to(self.device)
                reward = self.forward(obs).cpu().detach().numpy().flatten()
        finally:
            # Put the module back in whichever mode the caller had it in.
            self.train(was_training)
        return reward



class MLPSigmoidReward(nn.Module):
    """MLP reward model whose scalar output is passed through a sigmoid.

    Architecture: ``first_fc`` (linear, no activation), then one block per
    consecutive pair in ``hidden_sizes`` (linear, optional batch norm,
    activation), then ``last_fc`` followed by ``torch.sigmoid``.
    ``clamp_magnitude`` is stored but not applied by ``forward``.
    """

    def __init__(
        self,
        input_dim,
        hidden_sizes=(256,256),
        hid_act='tanh',
        use_bn=False,
        residual=False,
        clamp_magnitude=10.0,
        device=torch.device('cpu'),
        **kwargs
    ):
        """Create the layers.

        Args:
            input_dim: width of each input row.
            hidden_sizes: widths of the hidden layers.
            hid_act: ``'relu'``, ``'leaky_relu'`` or ``'tanh'``.
            use_bn: insert ``nn.BatchNorm1d`` after each block's linear layer.
            residual: add each block's input to its output in ``forward``.
            clamp_magnitude: stored on the instance only.
            device: device inputs are moved to in ``get_scalar_reward``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``hid_act`` is not one of the names above.
        """
        super().__init__()

        if hid_act == 'relu':
            hid_act_class = nn.ReLU
        elif hid_act == 'leaky_relu':
            hid_act_class = nn.LeakyReLU
        elif hid_act == 'tanh':
            hid_act_class = nn.Tanh
        else:
            raise NotImplementedError()

        self.clamp_magnitude = clamp_magnitude
        self.input_dim = input_dim
        self.device = device
        self.residual = residual

        self.first_fc = nn.Linear(input_dim, hidden_sizes[0])
        self.blocks_list = nn.ModuleList()

        for width_in, width_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers = [nn.Linear(width_in, width_out)]
            if use_bn:
                layers.append(nn.BatchNorm1d(width_out))
            layers.append(hid_act_class())
            self.blocks_list.append(nn.Sequential(*layers))

        self.last_fc = nn.Linear(hidden_sizes[-1], 1)

    def forward(self, batch):
        """Return the sigmoid-squashed reward for each row of ``batch``."""
        x = self.first_fc(batch)
        for block in self.blocks_list:
            x = x + block(x) if self.residual else block(x)
        return torch.sigmoid(self.last_fc(x))

    def r(self, batch):
        """Alias for ``forward``."""
        return self.forward(batch)

    def get_scalar_reward(self, obs):
        """Evaluate the reward in eval mode and return a flat NumPy array.

        Non-tensor ``obs`` is reshaped to ``(-1, input_dim)`` and converted
        to a float tensor. The module's previous train/eval mode is restored
        afterwards, even if the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                if not torch.is_tensor(obs):
                    obs = torch.FloatTensor(obs.reshape(-1,self.input_dim))
                obs = obs.to(self.device)
                reward = self.forward(obs).cpu().detach().numpy().flatten()
        finally:
            # Put the module back in whichever mode the caller had it in.
            self.train(was_training)
        return reward


class MeshReward(nn.Module):
    """Linear reward over RBF features centred on an ``n_ptr x n_ptr`` grid.

    The grid centres are placed at ``(radius_x*(i+1), radius_y*(j+1))`` for
    ``i, j`` in ``range(n_ptr)``, where ``radius_x = size_x / (n_ptr + 1)``
    and ``radius_y = size_y / (n_ptr + 1)``.
    """

    def __init__(self, n_ptr=10, size_x=4, size_y=4, device=torch.device("cpu")):
        super(MeshReward, self).__init__()
        self.device = device
        self.fc1 = nn.Linear(n_ptr**2, 1).to(self.device)
        self.n_ptr = n_ptr
        radius_x, radius_y = float(size_x) / (n_ptr+1), float(size_y) / (n_ptr+1)

        centers = torch.zeros((n_ptr**2, 2)).float()
        for i in range(n_ptr):
            for j in range(n_ptr):
                # The y coordinate uses the y spacing, so the grid is correct
                # when size_x != size_y.
                centers[i+n_ptr*j, :] = torch.tensor([radius_x*(i+1), radius_y*(j+1)])
        # Stored as a plain attribute (not a parameter or buffer).
        self.centers = centers.unsqueeze(0).to(self.device) # (1, n_ptr**2, 2)

    def forward(self, x):
        """Return the reward for 2-D points, clamped to ``[-1, 1]``.

        ``x`` is viewed as ``(B, 1, 2)`` and broadcast against the centres.
        """
        offsets = x.view(-1, 1, 2) - self.centers # (B, n_ptr**2, 2)
        distance = torch.sum(offsets**2, dim=2) # (B, n_ptr**2)
        phi = torch.exp(-distance) # RBF kernel
        score = self.fc1(phi)
        return torch.clamp(score, -1, 1) # normalize to [-1, 1]

    def r(self, x):
        """Alias for ``forward``."""
        return self.forward(x)

    def get_scalar_reward(self, obs):
        """Evaluate the reward without gradients and return a flat NumPy array.

        Non-tensor ``obs`` is reshaped to ``(-1, 2)`` and converted to a
        float tensor before being moved to ``self.device``.
        """
        with torch.no_grad():
            if not torch.is_tensor(obs):
                obs = torch.FloatTensor(obs.reshape(-1,2))
            obs = obs.to(self.device)
            reward = self.forward(obs).cpu().detach().numpy().flatten()
        return reward

class Discriminator(nn.Module):
    """Wraps an expert density object together with a sigmoid-output MLP.

    ``agent_density`` is built with ``sigmoid_mlp`` and maps an input of
    width ``observation_dim`` to a single value.
    """

    def __init__(self, expert_density, observation_dim, hidden_sizes=(256,256),
                 activation=nn.ReLU, device=torch.device("cpu")):
        super(Discriminator, self).__init__()
        self.expert_density = expert_density
        self.device = device
        # Must be assigned before it is read below; without this the
        # constructor raises AttributeError.
        self.obs_dim = observation_dim

        self.agent_density = sigmoid_mlp([self.obs_dim] + list(hidden_sizes) + [1], activation).to(self.device)

    def forward(self, s):
        """Return ``agent_density(s)``."""
        return self.agent_density(s)

    def foward(self, s):
        """Misspelled alias of ``forward`` kept for backward compatibility."""
        return self.forward(s)

