import torch
import torch.nn as nn


class Actor(nn.Module):
    """Deterministic policy MLP mapping a state (optionally with a goal) to an action.

    The network has three LeakyReLU hidden layers of width ``units``, followed by
    a linear head. The head is squashed with ``tanh`` and scaled by ``max_action``.

    Args:
        state_dim: number of state features.
        action_dim: number of action outputs.
        goal_dim: number of goal features. The first layer expects
            ``state_dim + goal_dim`` inputs, so pass ``0`` when no goal is used.
        max_action: multiplier applied after ``tanh``. Presumably a scalar or a
            tensor broadcastable to ``(batch, action_dim)`` -- confirm with caller.
        use_batch_norm: if truthy, a ``BatchNorm1d`` follows the first two hidden
            activations.
        units: width of every hidden layer.
    """

    def __init__(self, state_dim, action_dim, goal_dim, max_action, use_batch_norm=0, units=128):
        super().__init__()

        # Modules are created in the same order as an explicit Sequential literal,
        # so module indices (state_dict keys) and parameter-init RNG order match.
        stack = []
        fan_in = state_dim + goal_dim
        for depth in range(3):
            stack += [nn.Linear(fan_in, units), nn.LeakyReLU()]
            # Only the first two hidden blocks are normalized; the third is not.
            if use_batch_norm and depth < 2:
                stack.append(nn.BatchNorm1d(units))
            fan_in = units
        stack.append(nn.Linear(units, action_dim))
        self.layers = nn.Sequential(*stack)

        self.max_action = max_action

    def forward(self, state, goal=None):
        """Return ``max_action * tanh(layers(input))``.

        When ``goal`` is given, it is concatenated to ``state`` along dim 1. This
        assumes ``(batch, features)`` inputs.
        """
        net_in = state if goal is None else torch.cat([state, goal], dim=1)
        raw = self.layers(net_in)
        return self.max_action * torch.tanh(raw)


class Critic(nn.Module):
    """Ensemble of ``n_qs`` independent Q-network MLPs over (state, action[, goal]).

    Args:
        state_dim: number of state features.
        action_dim: number of action features.
        goal_dim: number of goal features. Each Q-network's first layer expects
            ``state_dim + action_dim + goal_dim`` inputs, so pass ``0`` when no
            goal is used.
        n_qs: number of Q-networks in the ensemble.
        use_batch_norm: if truthy, a ``BatchNorm1d`` follows the first two hidden
            activations of every Q-network.
        units: width of every hidden layer.
    """

    def __init__(self, state_dim, action_dim, goal_dim, n_qs=2, use_batch_norm=0, units=128):
        super(Critic, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.goal_dim = goal_dim
        self.units = units
        self.n_qs = n_qs

        self.models = nn.ModuleList(
            [self.make_q_network(use_batch_norm=use_batch_norm) for _ in range(n_qs)]
        )

    def make_q_network(self, use_batch_norm):
        """Build one Q-network: three LeakyReLU hidden layers, then a scalar output."""
        in_features = self.state_dim + self.action_dim + self.goal_dim
        if use_batch_norm:
            return nn.Sequential(
                nn.Linear(in_features, self.units),
                nn.LeakyReLU(),
                nn.BatchNorm1d(self.units),
                nn.Linear(self.units, self.units),
                nn.LeakyReLU(),
                nn.BatchNorm1d(self.units),
                nn.Linear(self.units, self.units),
                nn.LeakyReLU(),
                nn.Linear(self.units, 1),
            )
        return nn.Sequential(
            nn.Linear(in_features, self.units),
            nn.LeakyReLU(),
            nn.Linear(self.units, self.units),
            nn.LeakyReLU(),
            nn.Linear(self.units, self.units),
            nn.LeakyReLU(),
            nn.Linear(self.units, 1),
        )

    def compute_ensemble_diversity(self):
        """Return the Theil index of the per-network parameter-vector L2 norms.

        Each network's parameters are flattened into one vector and its L2 norm
        ``l_i`` is taken. With ``x_i = l_i / mean(l)``, the result is
        ``sum(x_i * log(x_i)) / n``. It is 0 when all norms are equal (always the
        case for a single network) and grows as the norms become more unequal.
        The result stays attached to the autograd graph.
        """
        params = torch.stack(
            [nn.utils.parameters_to_vector(model.parameters()) for model in self.models],
            0,
        )

        # Theil index
        norms = torch.linalg.norm(params, dim=-1)
        normalized = norms / torch.mean(norms)
        return torch.sum(normalized * torch.log(normalized)) / norms.shape[0]

    def forward(self, state, action, goal=None):
        """Return a list with one ``(batch, 1)`` Q-value tensor per ensemble member.

        Inputs are concatenated along dim 1, so ``(batch, features)`` tensors are
        expected. ``goal`` must be given whenever ``goal_dim > 0``.
        """
        if goal is not None:
            sa = torch.cat([state, action, goal], 1)
        else:
            sa = torch.cat([state, action], 1)

        return [model(sa) for model in self.models]

    def q_min(self, state, action, goal=None):
        """Return the element-wise minimum Q-value across the ensemble.

        ``goal`` is forwarded to :meth:`forward`. Without it, a goal-conditioned
        critic (``goal_dim > 0``) would receive too few input features and fail.
        """
        qs = torch.stack(self.forward(state, action, goal), dim=0)
        return torch.min(qs, 0)[0]