import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim=None, activation=nn.LeakyReLU, output_activation=None):
        super().__init__()
        hidden_dims = [input_dim] + list(hidden_dims)
        model = []
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            model += [nn.Linear(in_dim, out_dim), activation()]
        
        self.output_dim = hidden_dims[-1]
        if output_dim is not None:
            model += [nn.Linear(hidden_dims[-1], output_dim)]
            self.output_dim = output_dim
        if output_activation is not None:
            model += [output_activation()]
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        return self.model(x)

# for PPO/A2C/SAC
class ActorProb(nn.Module):
    def __init__(self, backbone, dist_net, device):
        super().__init__()

        self.device = device
        self.backbone = backbone.to(self.device)
        self.dist_net = dist_net.to(self.device)
    
    def get_dist(self, obs):
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        logits = self.backbone(obs)
        dist = self.dist_net(logits)
        return dist


class DeterministicActor(nn.Module):
    def __init__(self, obs_dim, action_dim, max_action, device, hidden_dims=[256,256], activation=nn.LeakyReLU):
        super().__init__()
        hidden_dims = [obs_dim] + list(hidden_dims)
        net = []
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            net += [nn.Linear(in_dim, out_dim), activation()]
        net += [nn.Linear(hidden_dims[-1], action_dim), nn.Tanh()]
        self.net = nn.Sequential(*net).to(device)
		
        self.max_action = max_action
        self.device = device

    def forward(self, obs):
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.max_action * self.net(obs)



class Critic(nn.Module):
    def __init__(self, backbone, device):
        super().__init__()

        self.device = device
        self.backbone = backbone.to(self.device)
        latent_dim = getattr(backbone, "output_dim")
        self.last = nn.Linear(latent_dim, 1).to(self.device)

    def forward(self, obs, actions=None):
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        if actions is not None:
            if not isinstance(actions, torch.Tensor):
                actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32).flatten(1)
            obs = torch.cat([obs, actions], dim=-1)
        logits = self.backbone(obs)
        values = self.last(logits)
        return values


class NormalWrapper(torch.distributions.Normal):
    def log_prob(self, actions):
        return super().log_prob(actions).sum(-1, keepdim=True)

    def entropy(self):
        return super().entropy().sum(-1)

    def mode(self):
        return self.mean


class DiagGaussian(nn.Module):
    def __init__(
        self, 
        latent_dim, 
        output_dim, 
        unbounded=False, 
        conditioned_sigma=False, 
        max_mu=1.0, 
        sigma_min=-20, 
        sigma_max=2
    ):
        super().__init__()
        self.mu = nn.Linear(latent_dim, output_dim)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = nn.Linear(latent_dim, output_dim)
        else:
            self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
        self._unbounded = unbounded
        self._max = max_mu
        self._sigma_min = sigma_min
        self._sigma_max = sigma_max

    def forward(self, logits):
        mu = self.mu(logits)
        if not self._unbounded:
            mu = self._max * torch.tanh(mu)
        if self._c_sigma:
            sigma = torch.clamp(self.sigma(logits), min=self._sigma_min, max=self._sigma_max).exp()
        else:
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        return NormalWrapper(mu, sigma)
