import numpy as np
import scipy.signal
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import itertools
from stable_baselines3 import SAC


def combined_shape(length, shape=None):
    """Return a shape tuple whose leading dimension is ``length``.

    Args:
        length: Size of the leading dimension.
        shape: ``None`` for no trailing dimensions, a scalar (as judged by
            ``np.isscalar``) for a single trailing dimension, or an iterable
            of trailing dimensions.

    Returns:
        A tuple starting with ``length`` followed by the trailing dimensions.
    """
    if shape is None:
        trailing = ()
    elif np.isscalar(shape):
        trailing = (shape,)
    else:
        trailing = tuple(shape)
    return (length,) + trailing

def mlp(sizes, activation, output_activation=nn.Identity):
    """Assemble a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: Sequence of layer widths; each consecutive pair becomes one
            ``nn.Linear``.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class)
            producing the module placed after every linear layer but the last.
        output_activation: Zero-argument callable producing the module placed
            after the last linear layer.

    Returns:
        ``nn.Sequential`` alternating linear layers and activation modules.
        It is empty when ``sizes`` has fewer than two entries.
    """
    final_index = len(sizes) - 2
    modules = []
    for index, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        # Only the last linear layer is followed by the output activation.
        factory = output_activation if index == final_index else activation
        modules.append(factory())
    return nn.Sequential(*modules)

def count_vars(module):
    """Return the total number of scalar parameters in ``module``.

    Args:
        module: Object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        The summed element count of all parameters as a Python ``int``
        (0 when the module has no parameters).
    """
    # ``numel`` yields exact Python ints, including for 0-d parameters where
    # a product over an empty shape would come back as the float 1.0.
    return sum(p.numel() for p in module.parameters())


# Bounds applied to the predicted log standard deviation before exponentiation.
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class SquashedGaussianMLPActor(nn.Module):
    """Tanh-squashed Gaussian policy that can blend in frozen source layers.

    Two trainable hidden layers are paired with frozen copies of two layers
    taken from ``source_model``. When a source observation is supplied to
    ``forward``, each hidden pre-activation is the convex mix
    ``p * target + (1 - p) * source``; otherwise only the trainable layers
    are used.

    Args:
        obs_dim: Size of the observation vector fed to the trainable layers.
        act_dim: Size of the action vector.
        hidden_sizes: Width (a single int) of both hidden layers.
        activation: Callable applied to hidden pre-activations; it is called
            directly on tensors, so pass an instance such as ``nn.ReLU()``.
        act_limit: Scale applied to the tanh-squashed action.
        source_model: Object exposing ``latent_pi``, an indexable container
            whose entries 0 and 2 are the layers to copy (presumably the
            actor of a Stable-Baselines3 SAC policy - confirm with caller).
        device: Device the whole module is moved to.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit, source_model, device):
        super().__init__()
        self.layer_1 = nn.Linear(obs_dim, hidden_sizes)
        self.activation = activation
        self.layer_2 = nn.Linear(hidden_sizes, hidden_sizes)

        # Copy first and let ``self.to(device)`` below move the copies:
        # ``Module.to`` works in place, so calling it on the source layers
        # would relocate part of the caller's model as a side effect.
        self.source_layer_1 = copy.deepcopy(source_model.latent_pi[0])
        self.source_layer_2 = copy.deepcopy(source_model.latent_pi[2])

        # The copied source layers are never trained.
        for param in itertools.chain(self.source_layer_1.parameters(),
                                     self.source_layer_2.parameters()):
            param.requires_grad = False

        self.mu_layer = nn.Linear(hidden_sizes, act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes, act_dim)
        self.act_limit = act_limit

        # Not read anywhere in this class; kept for external users.
        self.test_p = 0.0

        self.to(device)

    def forward(self, obs, src_obs, p, deterministic=False, with_logprob=True):
        """Compute an action and, optionally, its log-probability.

        Args:
            obs: Observation tensor for the trainable layers.
            src_obs: Observation tensor for the frozen source layers, or
                ``None`` to bypass them entirely.
            p: Mixing weight; ``p`` scales the trainable branch and
                ``1 - p`` the source branch. Ignored when ``src_obs`` is
                ``None``.
            deterministic: If true, use the distribution mean instead of a
                reparameterised sample.
            with_logprob: If true, also return the log-probability of the
                squashed action.

        Returns:
            ``(action, logp)`` where ``action`` is ``act_limit * tanh(u)``
            and ``logp`` is summed over the last axis, or ``None`` when
            ``with_logprob`` is false.
        """
        if src_obs is not None:
            src_hidden = self.source_layer_1(src_obs)
            hidden = self.activation(p * self.layer_1(obs) + (1.0 - p) * src_hidden)

            # The source branch keeps its own unmixed activations.
            src_hidden = self.source_layer_2(self.activation(src_hidden))
            hidden = self.activation(p * self.layer_2(hidden) + (1.0 - p) * src_hidden)
        else:
            hidden = self.activation(self.layer_1(obs))
            hidden = self.activation(self.layer_2(hidden))

        return self._squashed_sample(hidden, deterministic, with_logprob)

    def _squashed_sample(self, features, deterministic, with_logprob):
        """Turn hidden features into a squashed action and its log-prob."""
        mu = self.mu_layer(features)
        log_std = torch.clamp(self.log_std_layer(features), LOG_STD_MIN, LOG_STD_MAX)
        pi_distribution = Normal(mu, torch.exp(log_std))

        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        logp_pi = None
        if with_logprob:
            # Gaussian log-likelihood minus the tanh change-of-variables
            # term, written as 2*(log 2 - u - softplus(-2u)). Both
            # reductions use the last axis so unbatched (1-D) inputs work
            # as well as batched ones.
            logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1)
            correction = 2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))
            logp_pi = logp_pi - correction.sum(dim=-1)

        pi_action = self.act_limit * torch.tanh(pi_action)
        return pi_action, logp_pi

class MLPQFunction(nn.Module):
    """Q-value network that can blend in frozen layers from a source critic.

    Two trainable hidden layers are paired with frozen copies of two layers
    taken from ``source_model``. When source inputs are supplied to
    ``forward``, each hidden pre-activation is the convex mix
    ``p * target + (1 - p) * source``; otherwise only the trainable layers
    are used.

    Args:
        obs_dim: Size of the observation vector.
        act_dim: Size of the action vector.
        hidden_sizes: Width (a single int) of both hidden layers.
        activation: Hidden activation. May be a callable applied to tensors
            (e.g. ``nn.ReLU()``), a class that is instantiated with no
            arguments, or ``None`` for ``nn.ReLU()``.
        source_model: Indexable container whose entries 0 and 2 are the
            layers to copy (presumably one Q-network of a Stable-Baselines3
            SAC critic - confirm with caller).
        device: Device the whole module is moved to.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, source_model, device):
        super().__init__()
        self.layer_1 = nn.Linear(obs_dim + act_dim, hidden_sizes)

        # Honour the requested activation rather than hard-coding ReLU.
        if activation is None:
            activation = nn.ReLU()
        elif isinstance(activation, type):
            activation = activation()
        self.activation = activation

        self.layer_2 = nn.Linear(hidden_sizes, hidden_sizes)
        self.out_layer = nn.Linear(hidden_sizes, 1)

        # Copy first and let ``self.to(device)`` below move the copies:
        # ``Module.to`` works in place, so calling it on the source layers
        # would relocate part of the caller's model as a side effect.
        self.source_layer_1 = copy.deepcopy(source_model[0])
        self.source_layer_2 = copy.deepcopy(source_model[2])

        # The copied source layers are never trained.
        for param in itertools.chain(self.source_layer_1.parameters(),
                                     self.source_layer_2.parameters()):
            param.requires_grad = False

        self.to(device)

    def forward(self, obs, act, src_obs, act_obs, p):
        """Estimate Q-values for observation/action pairs.

        Args:
            obs: Observation tensor for the trainable layers.
            act: Action tensor for the trainable layers.
            src_obs: Observation tensor for the frozen source layers, or
                ``None``.
            act_obs: Action tensor for the frozen source layers, or ``None``.
            p: Mixing weight; ``p`` scales the trainable branch and
                ``1 - p`` the source branch. Ignored unless both source
                inputs are given.

        Returns:
            Tensor of Q-values with the trailing size-1 axis squeezed away.
        """
        x = torch.cat([obs, act], dim=-1)

        if src_obs is not None and act_obs is not None:
            src_x = torch.cat([src_obs, act_obs], dim=-1)
            src_hidden = self.source_layer_1(src_x)
            hidden = self.activation(p * self.layer_1(x) + (1.0 - p) * src_hidden)

            # The source branch keeps its own unmixed activations.
            src_hidden = self.source_layer_2(self.activation(src_hidden))
            hidden = self.activation(p * self.layer_2(hidden) + (1.0 - p) * src_hidden)
        else:
            hidden = self.activation(self.layer_1(x))
            hidden = self.activation(self.layer_2(hidden))

        return torch.squeeze(self.out_layer(hidden), -1)

class MLPActorCritic(nn.Module):
    """Bundle of one squashed-Gaussian actor and two Q-functions.

    Each sub-network is built with the matching component of
    ``source_model.policy``: ``actor`` for the policy, ``critic.qf0`` and
    ``critic.qf1`` for the two Q-functions.

    Args:
        observation_space: Object whose ``shape[0]`` is the observation size.
        action_space: Object whose ``shape[0]`` is the action size and whose
            ``high[0]`` is used as the action scale.
        source_model: Object exposing ``policy.actor`` and
            ``policy.critic.qf0`` / ``policy.critic.qf1``.
        device: Device handed to every sub-network.
        hidden_sizes: Hidden width handed to every sub-network.
        activation: Activation handed to every sub-network.
    """

    def __init__(self, observation_space, action_space, source_model, device, hidden_sizes=256,
                 activation=nn.ReLU()):
        super().__init__()

        n_obs = observation_space.shape[0]
        n_act = action_space.shape[0]
        # The first entry of the upper bound is used as the scale for all
        # action dimensions.
        scale = action_space.high[0]

        source_policy = source_model.policy
        source_critic = source_policy.critic

        # build policy and value functions
        self.pi = SquashedGaussianMLPActor(
            n_obs, n_act, hidden_sizes, activation, scale, source_policy.actor, device
        )
        self.q1 = MLPQFunction(
            n_obs, n_act, hidden_sizes, activation, source_critic.qf0, device
        )
        self.q2 = MLPQFunction(
            n_obs, n_act, hidden_sizes, activation, source_critic.qf1, device
        )

    def act(self, obs, src_obs, p, deterministic=False):
        """Return the policy's action as a NumPy array, without gradients.

        Args:
            obs: Observation passed to the policy.
            src_obs: Source observation passed to the policy (may be ``None``).
            p: Mixing weight passed to the policy.
            deterministic: Forwarded to the policy.

        Returns:
            The action moved to CPU and converted with ``.numpy()``.
        """
        with torch.no_grad():
            # Log-probabilities are not needed when only acting.
            action, _ = self.pi(obs, src_obs, p, deterministic, False)
        return action.cpu().numpy()

class decoder_network(nn.Module):
    """Three-layer tanh MLP with an optional bounded output.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output features.
        hidden_size: Width of both hidden layers.
        device: Device the module is moved to.
        variation: If true, a learnable ``logstd`` parameter of shape
            ``(1, output_dim)`` initialised to zeros is registered. It is
            not used by ``forward``.
        outputScale: If not ``None``, the output is passed through tanh and
            multiplied by this value.
    """

    def __init__(self, input_dim, output_dim, hidden_size, device, variation=False, outputScale=None):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_dim),
        ]
        self.model = nn.Sequential(*stack)

        if variation:
            self.logstd = nn.Parameter(torch.zeros(1, output_dim))

        self.outputScale = outputScale
        self.to(device)

    def forward(self, input):
        """Run the MLP and, when a scale is configured, bound the result.

        Args:
            input: Tensor whose last dimension is ``input_dim``.

        Returns:
            The raw MLP output, or ``tanh(output) * outputScale`` when
            ``outputScale`` is set.
        """
        raw = self.model(input)
        if self.outputScale is None:
            return raw
        return torch.tanh(raw) * self.outputScale