import numpy as np
import math

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
# from algorithms.utils.mani_skill_learn.networks.backbones.pointnet import getPointNet
from algorithms.legorl.ppobc.cnn import CNNLayer


class ActorCritic(nn.Module):
    """MLP actor-critic with a state-independent diagonal Gaussian policy.

    The actor maps observations to action means; a learned ``log_std``
    parameter (one entry per action dimension, shared across the batch)
    provides the exploration noise. The critic maps observations (or
    ``states`` when ``asymmetric`` is true) to a scalar value.

    Args:
        envs: Unused by this class; kept for signature compatibility.
        obs_shape: Sequence whose first entry is the observation size.
        states_shape: Sequence whose first entry is the state size. Only
            read when ``asymmetric`` is true.
        actions_shape: Sequence whose first entry is the action size.
        initial_std: Initial standard deviation of the action noise.
        model_cfg: ``None`` for the defaults (three hidden layers of 256
            units, SELU), or a mapping with the keys ``'pi_hid_sizes'``,
            ``'vf_hid_sizes'`` and ``'activation'``.
        asymmetric: When true, the critic consumes ``states`` instead of
            ``observations``.
    """

    def __init__(self, envs, obs_shape, states_shape, actions_shape, initial_std, model_cfg, asymmetric=False):
        super(ActorCritic, self).__init__()

        self.asymmetric = asymmetric

        if model_cfg is None:
            actor_hidden_dim = [256, 256, 256]
            critic_hidden_dim = [256, 256, 256]
            activation = get_activation("selu")
        else:
            actor_hidden_dim = model_cfg['pi_hid_sizes']
            critic_hidden_dim = model_cfg['vf_hid_sizes']
            activation = get_activation(model_cfg['activation'])

        # Policy
        self.actor = self._build_mlp(obs_shape[0], actor_hidden_dim, actions_shape[0], activation)

        # Value function; states_shape is only touched in the asymmetric case.
        critic_input_dim = states_shape[0] if self.asymmetric else obs_shape[0]
        self.critic = self._build_mlp(critic_input_dim, critic_hidden_dim, 1, activation)

        print(self.actor)
        print(self.critic)

        # Action noise
        self.log_std = nn.Parameter(np.log(initial_std) * torch.ones(*actions_shape))

        # Initialize the weights like in stable baselines: sqrt(2) gain on the
        # hidden layers, a small gain on the policy head and 1.0 on the value head.
        actor_weights = [np.sqrt(2)] * len(actor_hidden_dim) + [0.01]
        critic_weights = [np.sqrt(2)] * len(critic_hidden_dim) + [1.0]
        self.init_weights(self.actor, actor_weights)
        self.init_weights(self.critic, critic_weights)

    @staticmethod
    def _build_mlp(in_dim, hidden_dims, out_dim, activation):
        """Stack ``Linear`` layers in_dim -> hidden_dims... -> out_dim.

        ``activation`` (one shared module instance) is inserted after every
        layer except the last one.
        """
        widths = [in_dim, *hidden_dims, out_dim]
        last_idx = len(widths) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx != last_idx:
                layers.append(activation)
        return nn.Sequential(*layers)

    @staticmethod
    def init_weights(sequential, scales):
        """Orthogonally initialise every ``nn.Linear`` weight in ``sequential``.

        Args:
            sequential: Iterable of modules; non-``Linear`` entries are skipped.
            scales: Gains indexed by the position of each ``Linear`` layer
                among the ``Linear`` layers of ``sequential``.
        """
        linear_layers = [mod for mod in sequential if isinstance(mod, nn.Linear)]
        for idx, layer in enumerate(linear_layers):
            torch.nn.init.orthogonal_(layer.weight, gain=scales[idx])

    def forward(self):
        raise NotImplementedError

    def _action_distribution(self, actions_mean):
        """Return the diagonal Gaussian centred on ``actions_mean``."""
        std = self.log_std.exp()
        # scale_tril is the lower Cholesky factor of the covariance, which for
        # a diagonal Gaussian is diag(std) -- not diag(std ** 2).
        return MultivariateNormal(actions_mean, scale_tril=torch.diag(std))

    def act(self, observations, states):
        """Sample actions and evaluate the critic, with gradients detached.

        Args:
            observations: Input to the actor (and to the critic unless
                ``asymmetric``); dimension 0 is treated as the batch.
            states: Input to the critic when ``asymmetric`` is true;
                ignored otherwise.

        Returns:
            Tuple ``(actions, actions_log_prob, value, actions_mean,
            log_std)`` of detached tensors, where ``log_std`` is the
            parameter repeated once per row of ``actions_mean``.
        """
        actions_mean = self.actor(observations)
        distribution = self._action_distribution(actions_mean)

        actions = distribution.sample()
        actions_log_prob = distribution.log_prob(actions)

        critic_input = states if self.asymmetric else observations
        value = self.critic(critic_input)

        batch_log_std = self.log_std.repeat(actions_mean.shape[0], 1)
        return (
            actions.detach(),
            actions_log_prob.detach(),
            value.detach(),
            actions_mean.detach(),
            batch_log_std.detach(),
        )

    def act_inference(self, observations):
        """Return the deterministic action (the actor's mean output)."""
        actions_mean = self.actor(observations)
        return actions_mean

    def evaluate(self, observations, states, actions):
        """Return the log-probability of ``actions`` under the current policy.

        ``states`` is accepted for signature compatibility but not used.
        Gradients flow through both the actor and ``log_std``.
        """
        actions_mean = self.actor(observations)
        distribution = self._action_distribution(actions_mean)
        return distribution.log_prob(actions)

class BCTransformer(nn.Module):
    """Behaviour-cloning policy: MLP encoder -> Transformer encoder -> MLP decoder.

    Observations are embedded to ``d_model`` features, passed through a
    4-layer ``nn.TransformerEncoder`` and decoded to an action vector.
    An ``action_encoder`` is also constructed, but ``forward`` does not
    currently use it.

    Args:
        envs: Unused by this class; kept for signature compatibility.
        obs_shape: Sequence whose first entry is the observation size.
        states_shape: Unused by this class.
        actions_shape: Sequence whose first entry is the action size.
        initial_std: Unused by this class.
        model_cfg: Mapping with an ``'activation'`` key, or ``None`` to use
            ``"selu"`` (the same default ``ActorCritic`` applies).
        d_model: Feature width of the encoders and the transformer. The
            transformer uses 2 attention heads, so it must be even.
        asymmetric: Stored on the instance; not otherwise used here.
    """

    def __init__(self, envs, obs_shape, states_shape, actions_shape, initial_std, model_cfg, d_model=256, asymmetric=False):
        super(BCTransformer, self).__init__()

        self.asymmetric = asymmetric
        num_actions_encoder_layers = 3
        num_obs_encoder_layers = 3

        activation_name = "selu" if model_cfg is None else model_cfg['activation']
        activation = get_activation(activation_name)

        # Encoder for the previous action (built, but unused in forward()).
        self.action_encoder = self._build_mlp(
            [actions_shape[0]] + [d_model] * (num_actions_encoder_layers + 1), activation)

        # Encoder that embeds raw observations into d_model features.
        self.observation_encoder = self._build_mlp(
            [obs_shape[0]] + [d_model] * (num_obs_encoder_layers + 1), activation)

        # self.transformer_model = nn.Transformer(d_model=256, nhead=2, num_encoder_layers=4, num_decoder_layers=4, dim_feedforward=256, batch_first=True)
        # The transformer width must follow d_model so that it accepts the
        # observation encoder's output for any d_model, not only 256.
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=2, dim_feedforward=256, dropout=0.01, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=4)

        # Decoder from transformer features to an action vector; its depth
        # reuses the observation-encoder layer count.
        self.decoder = self._build_mlp(
            [d_model] * (num_obs_encoder_layers + 1) + [actions_shape[0]], activation)

    @staticmethod
    def _build_mlp(widths, activation):
        """Stack ``Linear`` layers through consecutive ``widths``.

        ``activation`` (one shared module instance) follows every layer
        except the last one.
        """
        last_idx = len(widths) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx != last_idx:
                layers.append(activation)
        return nn.Sequential(*layers)

    def forward(self, obs, last_actions):
        """Predict actions from observations.

        Args:
            obs: Observation tensor whose last dimension is ``obs_shape[0]``.
                The transformer is built with ``batch_first=True``;
                presumably ``obs`` is (batch, seq, obs_dim) -- TODO confirm
                against the caller.
            last_actions: Currently unused (the action-encoder path is
                disabled).

        Returns:
            Tensor shaped like ``obs`` with the last dimension replaced by
            ``actions_shape[0]``.
        """
        obs_encode = self.observation_encoder(obs)
        # last_action_encode = self.action_encoder(last_actions)
        # tgt = torch.cat([obs_encode, last_action_encode], dim=1)
        out = self.transformer_encoder(obs_encode, None)
        out = self.decoder(out)
        return out

def get_activation(act_name):
    """Return a new activation module for the given short name.

    Recognised names are ``"elu"``, ``"selu"``, ``"relu"``, ``"crelu"``,
    ``"lrelu"``, ``"tanh"`` and ``"sigmoid"``. For any other name a
    message is printed and ``None`` is returned.
    """
    known_activations = (
        ("elu", nn.ELU),
        ("selu", nn.SELU),
        ("relu", nn.ReLU),
        ("crelu", nn.ReLU),  # "crelu" is served by a plain ReLU
        ("lrelu", nn.LeakyReLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
    )
    for name, module_cls in known_activations:
        if act_name == name:
            # Instantiate on demand so every call yields a fresh module.
            return module_cls()

    print("invalid activation function!")
    return None
