import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from torch.distributions.multivariate_normal import MultivariateNormal


class Policy(nn.Module):
    """Actor and critic MLP heads over a latent feature vector.

    The actor head emits ``2 * a_dim`` values that are split into a bounded
    mean and a positive spread for a Gaussian policy; the critic head emits a
    single state-value estimate.

    Args:
        Representation: Object stored as ``self.body``. It is *not* called in
            ``forward``; if it is an ``nn.Module`` it is still registered as a
            sub-module, so its parameters appear in ``parameters()`` and
            ``state_dict()``.
        z_dim: Size of the last dimension of the input features.
        a_dim: Number of action dimensions.
        a_max: Scale applied to the ``tanh``-squashed mean, so the mean lies
            in ``[-a_max, a_max]``.
        hidden_fc: Width of every hidden layer in both heads (default 64).
    """

    def __init__(self, Representation, z_dim, a_dim, a_max, hidden_fc=64):
        super().__init__()

        self.a_dim = a_dim
        self.a_max = a_max

        # Kept for callers that reach the encoder through the policy.
        # NOTE(review): forward() consumes features directly, so presumably
        # the caller applies the encoder beforehand -- confirm.
        self.body = Representation

        # Actor is built before the critic so that parameter creation order
        # (and therefore RNG consumption during init) stays the same.
        self.actor = self._mlp(z_dim, hidden_fc, a_dim * 2)
        self.value = self._mlp(z_dim, hidden_fc, 1)

    @staticmethod
    def _mlp(in_dim, hidden, out_dim):
        """Build a ReLU MLP with three hidden layers of width ``hidden``."""
        return nn.Sequential(nn.Linear(in_dim, hidden),
                             nn.ReLU(),
                             nn.Linear(hidden, hidden),
                             nn.ReLU(),
                             nn.Linear(hidden, hidden),
                             nn.ReLU(),
                             nn.Linear(hidden, out_dim))

    def forward(self, x):
        """Compute the policy parameters and the value estimate.

        Args:
            x: Feature tensor whose last dimension has size ``z_dim``. Any
                leading shape is accepted (e.g. ``(batch, z_dim)`` or a single
                unbatched ``(z_dim,)`` vector).

        Returns:
            Tuple ``(mu, sigma, v)``: ``mu`` and ``sigma`` have the leading
            shape of ``x`` plus a trailing ``a_dim``; ``v`` has the leading
            shape of ``x``. ``mu`` is in ``[-a_max, a_max]`` and ``sigma`` is
            in ``(0.0001, 1.0001)``.
        """
        head = self.actor(x)

        # First a_dim outputs -> mean, squashed by tanh and scaled to a_max.
        mu = torch.tanh(head[..., :self.a_dim]) * self.a_max
        # Next a_dim outputs -> spread; the small offset keeps it strictly
        # positive even when the sigmoid saturates at zero.
        sigma = torch.sigmoid(head[..., self.a_dim:2 * self.a_dim]) + 0.0001
        # Drop the trailing singleton dimension of the critic output.
        v = self.value(x)[..., -1]

        return mu, sigma, v


class ActorCritic(nn.Module):
    """Gaussian actor-critic wrapper around a ``Policy`` network.

    Args:
        Representation: Forwarded unchanged to ``Policy``.
        z_dim: Size of the feature vector fed to the policy heads.
        a_dim: Number of action dimensions.
        a_max: Bound on the magnitude of the action mean.
    """

    def __init__(self, Representation, z_dim, a_dim, a_max):
        super().__init__()

        self.network = Policy(Representation, z_dim, a_dim, a_max)

    def forward(self, st):
        """Return ``(mu, sigma, v)`` produced by the underlying network."""
        mu, sigma, v = self.network(st)

        return mu, sigma, v

    @staticmethod
    def _action_distribution(mu, sigma):
        """Build the diagonal Gaussian over actions.

        ``sigma`` is placed on the diagonal of the *covariance* matrix, so it
        acts as a per-dimension variance rather than a standard deviation.
        """
        return MultivariateNormal(mu, torch.diag_embed(sigma))

    def get_action(self, st, test=False):
        """Choose an action for ``st``.

        Args:
            st: Input passed to ``forward``.
            test: When true, skip sampling and return the mean action.

        Returns:
            ``(mu, None, None)`` when ``test`` is true; otherwise
            ``(action, log_prob, sigma)`` where ``action`` is sampled from the
            Gaussian, ``log_prob`` is its log-density, and ``sigma`` is a
            detached NumPy copy of the network's spread output.
        """
        mu, sigma, _ = self.forward(st)

        if test:
            return mu, None, None

        dist = self._action_distribution(mu, sigma)
        a = dist.sample()
        logprob = dist.log_prob(a)

        return a, logprob, sigma.detach().cpu().numpy()

    def evaluate(self, st, at):
        """Score given actions under the current policy.

        Args:
            st: Input passed to ``forward``.
            at: Actions whose log-density is evaluated.

        Returns:
            ``(log_prob, v, entropy)``: log-density of ``at``, the value
            estimate from the network, and the entropy of the Gaussian.
        """
        mu, sigma, v = self.forward(st)

        dist = self._action_distribution(mu, sigma)

        return dist.log_prob(at), v, dist.entropy()