import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Fully connected feed-forward network.

    The stack is: one input ``Linear`` (``input_dim -> n_neurons``) followed
    by the activation, then ``n_layers`` further ``Linear`` layers
    (``n_neurons -> n_neurons``) each followed by the activation, and finally
    a ``Linear`` output layer (``n_neurons -> output_dim``) with no activation
    after it.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        architecture_params: Mapping with the keys ``'n_neurons'`` (layer
            width), ``'n_layers'`` (number of hidden layers added after the
            input layer) and ``'activation_type'`` (``0`` selects ReLU, any
            other value selects Tanh).
    """

    def __init__(self, input_dim, output_dim, architecture_params):
        super().__init__()

        width = architecture_params['n_neurons']
        depth = architecture_params['n_layers']

        if architecture_params['activation_type'] == 0:
            activation = nn.ReLU()
        else:
            activation = nn.Tanh()

        # Linear layers are created in input -> hidden order so that the
        # random initialisation consumes the RNG in a fixed sequence.
        linears = [nn.Linear(input_dim, width)]
        linears.extend(nn.Linear(width, width) for _ in range(depth))

        # A single activation module is reused after every linear layer.
        modules = []
        for linear in linears:
            modules.extend((linear, activation))
        modules.append(nn.Linear(width, output_dim))

        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.network(x)


class CNN(nn.Module):
    """Three-layer convolutional encoder followed by a linear projection.

    Each convolution has 32 output channels, no padding, and is followed by a
    ReLU. The feature map is then flattened and projected to ``output_dim``
    by a ``LazyLinear`` layer, whose input size is inferred on the first
    forward pass.

    Args:
        input_channels: Number of channels of the input.
        output_dim: Number of output features.
        architecture_params: Accepted for signature parity with the other
            networks in this file; it is not read by this class.
    """

    def __init__(self, input_channels, output_dim, architecture_params):
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride) for each conv stage.
        conv_specs = (
            (input_channels, 32, 4, 2),
            (32, 32, 4, 2),
            (32, 32, 3, 1),
        )

        stages = []
        for in_ch, out_ch, kernel, stride in conv_specs:
            stages.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=0)
            )
            stages.append(nn.ReLU())

        stages.append(nn.Flatten())
        stages.append(nn.LazyLinear(out_features=output_dim))

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` and return the projected features."""
        # assumes x is laid out as (batch, channels, height, width) -- TODO confirm
        return self.network(x)

class Actor(nn.Module):
    """Gaussian policy with tanh squashing and affine rescaling of actions.

    An ``MLP`` trunk (followed by a ReLU) feeds two linear heads that produce
    the mean and the log standard deviation of a diagonal Normal. Samples are
    squashed with ``tanh`` and mapped through
    ``action * bound_scale + bound_mean``.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Number of action components.
        architecture_params: Mapping with the keys required by ``MLP`` plus
            ``'action_bounds'``. The bounds are averaged over axis 0 and
            row 1 is used as the upper limit, so they are presumably stacked
            as ``(low, high)`` with shape ``(2, action_dim)`` -- confirm
            against the caller. A tensor, NumPy array or nested sequence is
            accepted.

    Attributes set by the constructor:
        bound_mean, bound_scale: Non-persistent buffers. They follow
            ``.to()``, ``.cuda()`` and dtype casts together with the
            parameters, but are not written to ``state_dict``, so the
            checkpoint layout is unchanged.
    """

    def __init__(self, state_dim, action_dim, architecture_params):
        super(Actor, self).__init__()

        n_neurons = architecture_params['n_neurons']
        self.net = MLP(state_dim, n_neurons, architecture_params)
        self.mean_head = nn.Linear(n_neurons, action_dim)
        self.log_std_head = nn.Linear(n_neurons, action_dim)

        self._register_action_bounds(architecture_params['action_bounds'])

    def _register_action_bounds(self, action_bounds):
        """Derive the action offset/scale from ``action_bounds`` and store them as buffers."""
        if torch.is_tensor(action_bounds):
            bounds = action_bounds
            # ``mean`` is undefined for integer tensors; promote them.
            if not bounds.is_floating_point():
                bounds = bounds.to(torch.get_default_dtype())
        else:
            bounds = torch.as_tensor(action_bounds, dtype=torch.get_default_dtype())

        bound_mean = bounds.mean(0)
        # Buffers (rather than plain attributes) so that moving the module to
        # another device also moves the bounds. ``persistent=False`` keeps
        # them out of ``state_dict`` so existing checkpoints still load.
        self.register_buffer('bound_mean', bound_mean, persistent=False)
        self.register_buffer('bound_scale', bounds[1] - bound_mean, persistent=False)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Normal for ``state``.

        ``log_std`` is clamped to the interval ``[-20, 2]``.
        """
        x = F.relu(self.net(state))
        mean = self.mean_head(x)
        log_std = self.log_std_head(x)

        log_std = torch.clamp(log_std, min=-20, max=2)
        return mean, log_std

    def sample(self, state, test=False):
        """Draw a rescaled action and its log-probability.

        Args:
            state: Input passed to ``forward``.
            test: When true, the distribution mean is used instead of a
                reparameterised sample.

        Returns:
            A pair ``(scaled_action, (log_prob,))``. ``log_prob`` is summed
            over the last dimension (kept as a size-1 axis) and is wrapped in
            a one-element tuple.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = torch.distributions.Normal(mean, std)
        x_t = mean if test else normal.rsample()

        action = torch.tanh(x_t)

        # Change of variables for the tanh squash; 1e-6 guards log(0) when
        # |action| saturates at 1.
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)
        # Change of variables for the affine rescaling to the action range.
        # NOTE: a zero-width bound gives log(0) here.
        log_scale = torch.log(self.bound_scale.abs()).sum().unsqueeze(0)
        log_prob = log_prob.sum(dim=-1, keepdim=True) - log_scale

        scaled_action = action * self.bound_scale + self.bound_mean

        return scaled_action, (log_prob,)


class Critic(nn.Module):
    """Scalar value network built on ``MLP``.

    When ``action_dim`` is non-zero the network input width is
    ``state_dim + action_dim``; otherwise it is ``state_dim`` alone. The
    output width is always 1.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Size of the last dimension of the action input, or ``0``
            when no action is fed to the network.
        architecture_params: Mapping forwarded to ``MLP``. Although the
            default is ``None``, ``MLP`` subscripts this argument, so a
            mapping must be supplied.
    """

    def __init__(self, state_dim, action_dim=0, architecture_params=None):
        super().__init__()

        input_dim = state_dim + action_dim if action_dim != 0 else state_dim
        self.net = MLP(input_dim, 1, architecture_params)

    def forward(self, state, action=None):
        """Evaluate the network on ``state``, optionally joined with ``action``.

        When ``action`` is given it is concatenated to ``state`` along the
        last dimension before being passed to the network.
        """
        if action is None:
            return self.net(state)

        joint = torch.cat([state, action], dim=-1)
        return self.net(joint)













