from typing import Union

import numpy as np
import torch
from torch import nn
from torch.distributions import Normal,MultivariateNormal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform

# Type alias: an activation is given either by name or as a module instance.
Activation = Union[str, nn.Module]

# Lookup table from an activation name to a ready-made module instance.
_str_to_activation = {
    'relu': nn.ReLU(),
    'tanh': nn.Tanh(),
    'leaky_relu': nn.LeakyReLU(),
    'sigmoid': nn.Sigmoid(),
    'selu': nn.SELU(),
    'softplus': nn.Softplus(),
    'identity': nn.Identity(),
}

# Module-wide torch device; stays None until init_gpu() assigns it.
device = None


def init_gpu(use_gpu=True, gpu_id=0):
    """Choose the module-level ``device`` and print which one was picked.

    CUDA device ``gpu_id`` is selected only when CUDA is available and
    ``use_gpu`` is truthy; in every other case the CPU is selected (the CPU
    message is printed even when the GPU was merely not requested).
    """
    global device
    wants_cuda = torch.cuda.is_available() and use_gpu
    if not wants_cuda:
        device = torch.device("cpu")
        print("GPU not detected. Defaulting to CPU.")
        return
    device = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU id {gpu_id}")


def set_device(gpu_id):
    """Make ``gpu_id`` the current CUDA device via ``torch.cuda.set_device``."""
    torch.cuda.set_device(gpu_id)

def from_numpy(*args, **kwargs):
    """Wrap ``torch.from_numpy``, cast to float32 and move to the module ``device``.

    All arguments are forwarded unchanged to ``torch.from_numpy``.
    """
    tensor = torch.from_numpy(*args, **kwargs)
    as_float = tensor.float()
    return as_float.to(device)

def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array: moved to CPU, then detached from autograd."""
    on_cpu = tensor.to('cpu')
    return on_cpu.detach().numpy()

# Replicate a tensor along a newly inserted axis.
def extend_and_repeat(tensor,dim,repeat):
    """Insert a new axis at ``dim`` and repeat ``tensor`` ``repeat`` times along it.

    The repetition is done by broadcasting: the unsqueezed tensor is multiplied
    by a ones-tensor (same dtype/device as ``tensor``) whose only non-unit
    dimension is ``dim``.
    """
    broadcast_shape = [1] * (tensor.ndim + 1)
    broadcast_shape[dim] = repeat
    expanded = tensor.unsqueeze(dim)
    return expanded * tensor.new_ones(broadcast_shape)

# Soft update of the target network towards the online network.
def soft_target_updata(network,target_network,soft_target_update_rate):
    """Blend ``network``'s parameters into ``target_network`` in place.

    For every named parameter of ``network`` the same-named target parameter
    becomes ``(1 - rate) * target + rate * source``. Parameters are matched by
    name, so a name missing from ``target_network`` raises ``KeyError``.
    """
    rate = soft_target_update_rate
    targets_by_name = dict(target_network.named_parameters())
    for name, source_param in network.named_parameters():
        target_param = targets_by_name[name]
        blended = (1 - rate) * target_param.data + rate * source_param.data
        target_param.data = blended

def multiple_action_q_function(forward):
    """Decorator letting a Q-function ``forward`` score several actions per state.

    When ``ac_n`` is 3-D and ``ob_n`` is 2-D, each observation row is repeated
    ``ac_n.shape[1]`` times, both inputs are flattened to 2-D, ``forward`` is
    called once, and the result is reshaped to ``(ob_n.shape[0], -1)``.
    Any other input ranks are passed to ``forward`` untouched.

    The wrapper keeps the decorated function's ``__name__`` / ``__doc__``.
    """
    import functools

    @functools.wraps(forward)
    def wrapped(self,ob_n,ac_n,**kwargs):
        multiple_actions = False
        batch_size = ob_n.shape[0]
        if ac_n.ndim == 3 and ob_n.ndim == 2:
            multiple_actions = True
            # Pair every candidate action with a copy of its observation.
            ob_n = extend_and_repeat(ob_n,1,ac_n.shape[1]).reshape(-1,ob_n.shape[-1])
            ac_n = ac_n.reshape(-1,ac_n.shape[-1])
        q_values = forward(self,ob_n,ac_n,**kwargs)
        if multiple_actions:
            # Restore the per-state grouping of the flattened results.
            q_values = q_values.reshape(batch_size,-1)
        return q_values
    return wrapped

def creat_fullconnect(input_dim,output_dim,arch,orthogonal_init=False):
    """Build the list of ``nn.Linear`` layers for an MLP (no activations included).

    ``arch`` is a dash-separated string of hidden widths, e.g. ``'64-64'``.
    The returned list holds one linear layer per hidden width followed by the
    output layer. With ``orthogonal_init`` every layer's weight is initialised
    orthogonally with gain ``sqrt(2)``.
    """
    hidden_widths = [int(width) for width in arch.split('-')]
    widths = [input_dim] + hidden_widths + [output_dim]

    layers = []
    # Layers are created and initialised one at a time, in network order.
    for in_features, out_features in zip(widths, widths[1:]):
        layer = nn.Linear(in_features, out_features)
        if orthogonal_init:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
        layers.append(layer)
    return layers

# Plain fully connected network.
class FullConnecteNetwork(nn.Module):
    """MLP with ReLU hidden activations and a linear output layer.

    ``arch`` is a dash-separated string of hidden-layer widths.
    ``orthogonal_init`` selects orthogonal weight initialisation (gain
    ``sqrt(2)``) for every linear layer. When ``dropout_rate`` is not None a
    ``nn.Dropout`` is placed between each hidden linear layer and its ReLU.
    """

    def __init__(self,input_dim,output_dim,arch='256-256',orthogonal_init=False,dropout_rate=None):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init

        hidden_widths = [int(width) for width in arch.split('-')]
        layers = []
        in_features = input_dim
        for width in hidden_widths:
            layers.append(self._linear(in_features, width, orthogonal_init))
            if dropout_rate is not None:
                layers.append(nn.Dropout(dropout_rate))
            layers.append(nn.ReLU())
            in_features = width
        # Output layer: linear only, no activation.
        layers.append(self._linear(in_features, output_dim, orthogonal_init))
        self.network = nn.Sequential(*layers)

    @staticmethod
    def _linear(in_features, out_features, orthogonal_init):
        """Create one linear layer, orthogonally initialised when requested."""
        layer = nn.Linear(in_features, out_features)
        if orthogonal_init:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
        return layer

    def forward(self,input_tensor):
        """Run ``input_tensor`` through the sequential network."""
        return self.network(input_tensor)

    def save(self,path):
        """Write this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(),path)

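# Reparameterized tanh-Gaussian model (tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)))
'''
TransformedDistribution extends a base distribution, given as its first argument.
Let X be the base variable, Y the output, and f the transform given as the second argument.
Then: log p(Y) = log p(X) + log |det(dX/dY)|
rsample: draws a reparameterized sample from X and then applies f to it.
log_prob: combines the log-abs-det-Jacobian of f with the base distribution X to obtain
log pi(a|s) for the transformed sample. This value is used to update the policy
(the original note adds that it plays the role of an accurate Q value).

TanhTransform uses tanh to transform the base distribution.
'''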
class ReparameterizedTanhGaussian(nn.Module):
    """Diagonal Gaussian, optionally squashed by tanh, with reparameterized sampling.

    ``log_std`` inputs are clamped to ``[log_std_min, log_std_max]`` before
    being exponentiated. With ``no_tanh=True`` the plain ``Normal`` is used;
    otherwise it is wrapped in a ``TanhTransform``.
    """

    def __init__(self,log_std_min=-20.0,log_std_max=2.0,no_tanh=False):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.no_tanh = no_tanh

    def _make_distribution(self, mean, log_std):
        """Build the (possibly tanh-transformed) action distribution."""
        # clamp limits every entry of log_std to the configured range
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        base = Normal(mean, torch.exp(log_std))
        if self.no_tanh:
            return base
        return TransformedDistribution(base, TanhTransform(cache_size=1))

    def log_prob(self,mean,log_std,sample):
        """Return the log-density of ``sample``, summed over the last axis."""
        action_distribution = self._make_distribution(mean, log_std)
        return torch.sum(action_distribution.log_prob(sample),dim=-1)

    def forward(self,mean,log_std,deterministic=False):
        """Return ``(action, log_prob)``.

        ``deterministic=False`` draws a reparameterized sample.
        ``deterministic=True`` returns the image of ``mean`` under the
        distribution's own transform: ``tanh(mean)`` normally, or ``mean``
        itself when ``no_tanh`` is set (previously ``tanh(mean)`` was returned
        even though the distribution was an unsquashed Normal).
        ``log_prob`` is summed over the last axis.
        """
        action_distribution = self._make_distribution(mean, log_std)
        if not deterministic:
            action_sample = action_distribution.rsample()
        elif self.no_tanh:
            action_sample = mean
        else:
            # Go through the distribution's own TanhTransform so that its
            # (x, y) cache is filled: log_prob below then reuses the pre-tanh
            # value instead of inverting tanh, which yields inf/nan once
            # tanh(mean) saturates to +/-1.
            action_sample = action_distribution.transforms[0](mean)
        log_prob = torch.sum(action_distribution.log_prob(action_sample),dim=-1)
        return action_sample,log_prob

# Policy model built on the tanh-Gaussian distribution.
class TanhGaussianPolicy(nn.Module):
    """Policy whose network outputs mean and log-std of a tanh-Gaussian.

    The base network emits ``2 * ac_dim`` values per observation; the first
    ``ac_dim`` are the mean and the rest the raw log-std, which is rescaled by
    the learnable ``log_std_multiplier`` and shifted by ``log_std_offset``.
    """

    def __init__(self,ob_dim,ac_dim,arch='256-256',log_std_multiplier=1.0,log_std_offset=-1.0,orthogonal_init=False,no_tanh=False,dropout=None):
        super().__init__()
        self.ob_dim = ob_dim
        self.ac_dim = ac_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init
        self.no_tanh = no_tanh
        self.base_network = FullConnecteNetwork(ob_dim,2*ac_dim,arch,orthogonal_init,dropout)
        self.log_std_multiplier = Scalar(log_std_multiplier)
        self.log_std_offset = Scalar(log_std_offset)
        self.tanh_gaussian = ReparameterizedTanhGaussian(no_tanh=no_tanh)

    def _mean_and_log_std(self, ob_n):
        """Run the base network and return ``(mean, rescaled log_std)``."""
        head_output = self.base_network(ob_n)
        mean, raw_log_std = torch.split(head_output, self.ac_dim, dim=-1)
        scaled_log_std = self.log_std_multiplier() * raw_log_std + self.log_std_offset()
        return mean, scaled_log_std

    @staticmethod
    def _maybe_repeat(ob_n, repeat):
        """Repeat observations along a new axis 1 when ``repeat`` is given."""
        if repeat is None:
            return ob_n
        return extend_and_repeat(ob_n, 1, repeat)

    def log_prob(self,ob_n,ac_n):
        """Log-probability of ``ac_n`` under the policy at ``ob_n``.

        A 3-D ``ac_n`` makes the observations be repeated along a new axis 1
        to match ``ac_n.shape[1]``.
        """
        has_action_axis = ac_n.ndim == 3
        if has_action_axis:
            ob_n = extend_and_repeat(ob_n,1,ac_n.shape[1])
        mean, log_std = self._mean_and_log_std(ob_n)
        return self.tanh_gaussian.log_prob(mean,log_std,ac_n)

    def forward(self,ob_n,deterministic=False,repeat=None):
        """Return what ``self.tanh_gaussian`` produces for the policy head output."""
        observations = self._maybe_repeat(ob_n, repeat)
        mean, log_std = self._mean_and_log_std(observations)
        return self.tanh_gaussian(mean,log_std,deterministic)

    def get_x_and_action(self,ob_n,deterministic=False,repeat=None):
        """Return ``(x_t, tanh_gaussian output)``.

        ``x_t`` is a reparameterized draw from ``Normal(mean, exp(log_std))``.
        NOTE(review): ``x_t`` is sampled separately from the returned action
        and uses the unclamped log-std -- confirm this is intended.
        """
        observations = self._maybe_repeat(ob_n, repeat)
        mean, log_std = self._mean_and_log_std(observations)
        pre_squash_normal = Normal(mean,log_std.exp())
        x_t = pre_squash_normal.rsample()
        return x_t,self.tanh_gaussian(mean,log_std,deterministic)

    def save(self,path):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(),path)

    def load(self,path):
        """Load a ``state_dict`` previously stored at ``path``."""
        state = torch.load(path)
        self.load_state_dict(state)

class NormalTahnNetwork(nn.Module):
    """Gaussian policy with a network-predicted mean and a state-independent log-std.

    The base network maps observations to the mean; ``self.log_std`` is a
    learnable vector of length ``ac_dim`` shared by all observations and
    clamped to ``[log_std_min, log_std_max]`` when the distribution is built.
    """

    def __init__(self, ob_dim, ac_dim, arch='256-256', log_std_min=-10.0, log_std_max=2.0,
                 orthogonal_init=False, no_tanh=False, dropout=None):
        super().__init__()
        self.ob_dim = ob_dim
        self.ac_dim = ac_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init
        self.no_tanh = no_tanh
        self.base_network = FullConnecteNetwork(ob_dim, ac_dim, arch, orthogonal_init, dropout)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.log_std = nn.Parameter(torch.zeros(ac_dim, dtype=torch.float32))

    def get_dist(self,ob_n,repeat=None):
        """Return a ``MultivariateNormal`` with diagonal scale for ``ob_n``.

        When ``repeat`` is given, observations are first repeated along a new
        axis 1.
        """
        observations = ob_n if repeat is None else extend_and_repeat(ob_n,1,repeat)
        mean = self.base_network(observations)
        clamped_log_std = self.log_std.clamp(self.log_std_min,self.log_std_max)
        # Diagonal lower-triangular factor built from the per-dimension std.
        scale_tril = torch.diag(torch.exp(clamped_log_std))
        return MultivariateNormal(mean, scale_tril=scale_tril)

    def forward(self,ob_n,deterministic=False,repeat=None):
        """Return ``(action, log_prob)``.

        The action is the distribution mean when ``deterministic`` is set,
        otherwise a draw from ``dist.sample()``.
        """
        dist = self.get_dist(ob_n,repeat)
        action = dist.mean if deterministic else dist.sample()
        action_log_prob = dist.log_prob(action)
        return action,action_log_prob

    def save(self,path):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(),path)

    def load(self,path):
        """Load a ``state_dict`` previously stored at ``path``."""
        state = torch.load(path)
        self.load_state_dict(state)


# Policy wrapper used for sampling.
class SamplerPolicy(object):
    """Adapter that feeds array-like observations to a torch policy.

    ``call`` converts the observations to a float32 tensor on ``device``,
    evaluates the policy without gradient tracking and returns the actions as
    a NumPy array.
    """

    def __init__(self,policy,device):
        self.policy = policy
        self.device = device

    def call(self,ob_n,deterministic=False):
        """Return the policy's actions for ``ob_n`` as a NumPy array."""
        with torch.no_grad():
            observations = torch.tensor(ob_n,dtype=torch.float32,device=self.device)
            # The policy's second output is discarded.
            policy_output = self.policy(observations,deterministic)
            actions = policy_output[0] if False else None
            actions,_ = policy_output
            return actions.cpu().numpy()

class SamplerPolicy_Ae(object):
    """Sampling adapter that passes observations through an encoder first.

    ``call`` converts the observations to a float32 tensor on ``device``,
    applies ``encoder``, evaluates the policy on the encoder output without
    gradient tracking and returns the actions as a NumPy array.
    """

    def __init__(self,policy,device,encoder):
        self.policy = policy
        self.device = device
        self.encoder = encoder

    def call(self,ob_n,deterministic=False):
        """Return the policy's actions for the encoded ``ob_n`` as a NumPy array."""
        with torch.no_grad():
            observations = torch.tensor(ob_n,dtype=torch.float32,device=self.device)
            encoded = self.encoder(observations)
            # The policy's second output is discarded.
            actions,_ = self.policy(encoded,deterministic)
            return actions.cpu().numpy()

# Fully connected network for the Q-function. It differs from the plain
# fully connected network in that forward concatenates state and action
# before passing them through the network.
class FullConnectedQFunction(nn.Module):
    """Q-function mapping a concatenated (observation, action) pair to one value."""

    def __init__(self,ob_dim,ac_dim,arch='256-256',orthogonal_init=False,dropout_rate=None):
        super().__init__()
        self.ob_dim = ob_dim
        self.ac_dim = ac_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init
        self.network = FullConnecteNetwork(ob_dim + ac_dim,1,arch,orthogonal_init,dropout_rate)

    def forward(self,ob_n,ac_n):
        """Concatenate inputs on the last axis and return Q with that axis squeezed."""
        joint_input = torch.cat([ob_n,ac_n],dim=-1)
        q_column = self.network(joint_input)
        return torch.squeeze(q_column,dim=-1)

    def save(self,path):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(),path)

    def load(self,path):
        """Load a ``state_dict`` previously stored at ``path``."""
        state = torch.load(path)
        self.load_state_dict(state)

# nn.Parameter turns a value into a trainable parameter bound to the module.
# It is used for quantities that are not part of a network but still need
# to be trained, e.g. the alpha parameter in SAC.
class Scalar(nn.Module):
    """Module holding a single learnable float32 value."""

    def __init__(self, init_value):
        super().__init__()
        initial = torch.tensor(init_value, dtype=torch.float32)
        self.constant = nn.Parameter(initial)

    def forward(self):
        """Return the underlying parameter itself (not a copy)."""
        return self.constant

class K_QNet(nn.Module):
    """Ensemble of ``K`` independent Q-networks over concatenated (ob, ac) inputs.

    The sub-networks are held in an ``nn.ModuleList`` so that ``parameters()``,
    ``state_dict()``, ``train()``/``eval()`` and ``to()`` all reach them.
    (A plain Python list would leave them unregistered.)
    """

    def __init__(self,ob_dim,ac_dim,arch='256-256',orthogonal_init=False,K=3,dropout_rate=None):
        super().__init__()
        self.ob_dim = ob_dim
        self.ac_dim = ac_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init
        self.K = K
        self.q_nets = nn.ModuleList([
            FullConnecteNetwork(self.ob_dim + self.ac_dim,1,self.arch,self.orthogonal_init,dropout_rate)
            for _ in range(K)
        ])
        # Flat parameter list kept for get_params(); same order as before
        # (all parameters of net 0, then net 1, ...).
        self.params = list(self.q_nets.parameters())

    def get_params(self):
        """Return the flat list of all sub-network parameters."""
        return self.params

    def forward(self,ob_n,ac_n):
        """Evaluate every Q-network and return the stacked values.

        The per-network outputs are stacked along a new leading axis and then
        fully squeezed. The result stays on the inputs' device (no hard-coded
        CUDA device) and is returned detached from autograd, matching the
        earlier implementation, which converted through Python lists; use
        ``Q`` when gradients for a single network are needed.
        """
        input_tensor = torch.cat([ob_n, ac_n], dim=-1)
        q_values = torch.stack([q_net(input_tensor) for q_net in self.q_nets], dim=0)
        return q_values.detach().squeeze()

    def Q(self,ob_n,ac_n,i):
        """Return the (squeezed, differentiable) output of the ``i``-th Q-network."""
        input_tensor = torch.cat([ob_n, ac_n], dim=-1)
        return self.q_nets[i](input_tensor).squeeze()

class DynamicsNet(nn.Module):
    """MLP dynamics model taking an (observation, action) pair.

    Inputs are normalised with the stored shift/scale statistics, passed
    through the fully connected layers (ReLU after every layer but the last),
    and the output is optionally rescaled, masked and added to the input
    observation (``residual``).
    """

    def __init__(self, ob_dim, ac_dim, arch='64-64',device = 'cuda',
                 s_shift = None,
                 s_scale = None,
                 a_shift = None,
                 a_scale = None,
                 out_shift = None,
                 out_scale = None,
                 out_dim = None,
                 residual = True,
                 use_mask = True,
                 ):
        super(DynamicsNet, self).__init__()

        self.ob_dim, self.ac_dim, self.arch = ob_dim, ac_dim, arch
        self.out_dim = ob_dim if out_dim is None else out_dim
        # hidden layers
        self.fc_layers = nn.ModuleList(creat_fullconnect(self.ob_dim + self.ac_dim,self.out_dim,self.arch))
        self.nonlinearity = torch.relu
        self.residual, self.use_mask = residual, use_mask
        self._apply_out_transforms = True
        self.device = device
        self.set_transformations(s_shift, s_scale, a_shift, a_scale, out_shift, out_scale)

    # Set the shift/scale transformations for state, action and output.
    def set_transformations(self, s_shift=None, s_scale=None,
                            a_shift=None, a_scale=None,
                            out_shift=None, out_scale=None):
        """Store normalisation statistics and move them to ``self.device``.

        The type of ``s_shift`` selects how all six values are interpreted:
        ``None`` installs identity transforms, a ``torch.Tensor`` uses the
        values as given, a ``numpy.ndarray`` converts each to a float32 tensor.

        Raises:
            TypeError: if ``s_shift`` is none of the above (previously the
                process was terminated with ``quit()``).
        """
        if s_shift is None:
            self.s_shift     = torch.zeros(self.ob_dim)
            self.s_scale    = torch.ones(self.ob_dim)
            self.a_shift     = torch.zeros(self.ac_dim)
            self.a_scale    = torch.ones(self.ac_dim)
            self.out_shift   = torch.zeros(self.out_dim)
            self.out_scale  = torch.ones(self.out_dim)
        elif isinstance(s_shift, torch.Tensor):
            self.s_shift, self.s_scale = s_shift, s_scale
            self.a_shift, self.a_scale = a_shift, a_scale
            self.out_shift, self.out_scale = out_shift, out_scale
        elif isinstance(s_shift, np.ndarray):
            self.s_shift     = torch.from_numpy(np.float32(s_shift))
            self.s_scale    = torch.from_numpy(np.float32(s_scale))
            self.a_shift     = torch.from_numpy(np.float32(a_shift))
            self.a_scale    = torch.from_numpy(np.float32(a_scale))
            self.out_shift   = torch.from_numpy(np.float32(out_shift))
            self.out_scale  = torch.from_numpy(np.float32(out_scale))
        else:
            # Library code must not terminate the interpreter; let the caller decide.
            raise TypeError(
                "Unknown type for transformations: {}".format(type(s_shift).__name__)
            )

        self.s_shift, self.s_scale = self.s_shift.to(self.device), self.s_scale.to(self.device)
        self.a_shift, self.a_scale = self.a_shift.to(self.device), self.a_scale.to(self.device)
        self.out_shift, self.out_scale = self.out_shift.to(self.device), self.out_scale.to(self.device)
        # if some state dimensions have very small variations, we will force it to zero
        self.mask = self.out_scale >= 1e-8

        self.transformations = dict(s_shift=self.s_shift, s_scale=self.s_scale,
                                    a_shift=self.a_shift, a_scale=self.a_scale,
                                    out_shift=self.out_shift, out_scale=self.out_scale)

    def forward(self, X):
        """Predict the output for ``X = (ob, ac)``.

        A rank mismatch between ``ob`` and ``ac`` only prints a warning.
        """
        ob,ac = X
        if ob.dim() != ac.dim():
            print("State and action inputs should be of the same size")
        # normalize inputs
        ob_in = (ob- self.s_shift)/(self.s_scale + 1e-8)
        ac_in = (ac - self.a_shift)/(self.a_scale + 1e-8)
        out = torch.cat([ob_in, ac_in], -1)
        # every layer except the last is followed by the nonlinearity
        for layer in self.fc_layers[:-1]:
            out = self.nonlinearity(layer(out))
        out = self.fc_layers[-1](out)
        if self._apply_out_transforms:
            out = out * (self.out_scale + 1e-8) + self.out_shift
            out = out * self.mask if self.use_mask else out
            out = out + ob if self.residual else out
        return out

    def get_params(self):
        """Return ``dict(weights=[...], transforms=(six statistics))``."""
        network_weights = [p.data for p in self.parameters()]
        transforms = (self.s_shift, self.s_scale,
                      self.a_shift, self.a_scale,
                      self.out_shift, self.out_scale)
        return dict(weights=network_weights, transforms=transforms)

    def set_params(self, new_params):
        """Install weights and transforms from a dict shaped like ``get_params()``'s result."""
        new_weights = new_params['weights']
        s_shift, s_scale, a_shift, a_scale, out_shift, out_scale = new_params['transforms']
        for idx, p in enumerate(self.parameters()):
            p.data = new_weights[idx]
        self.set_transformations(s_shift, s_scale, a_shift, a_scale, out_shift, out_scale)

class RewardNet(nn.Module):
    """MLP reward model parameterised as ``r = f(s, a, s')``.

    The three inputs are normalised with stored shift/scale statistics (the
    next-state statistics reuse the state ones), concatenated and passed
    through the fully connected layers; the scalar output is rescaled with
    ``out_scale`` / ``out_shift``.
    """

    def __init__(self, ob_dim, ac_dim,
                 arch = '64-64',device = 'cuda',
                 s_shift = None,
                 s_scale = None,
                 a_shift = None,
                 a_scale = None,
                 ):
        super(RewardNet, self).__init__()
        self.ob_dim, self.ac_dim, self.arch = ob_dim, ac_dim, arch
        self.fc_layers = nn.ModuleList(creat_fullconnect(ob_dim + ac_dim + ob_dim,1,arch))
        self.nonlinearity = torch.relu
        self.device =device
        self.set_transformations(s_shift, s_scale, a_shift, a_scale)

    def set_transformations(self, s_shift=None, s_scale=None,
                            a_shift=None, a_scale=None,
                            out_shift=None, out_scale=None):
        """Store normalisation statistics and move the tensors to ``self.device``.

        The type of ``s_shift`` selects the interpretation: ``None`` installs
        identity transforms, a ``torch.Tensor`` uses the values as given, a
        ``numpy.ndarray`` converts the state/action values to float32 tensors.
        ``out_shift`` / ``out_scale`` fall back to ``0.0`` / ``1.0`` when not
        supplied; they used to be stored as ``None`` in the tensor and ndarray
        cases, which made ``forward`` fail.

        Raises:
            TypeError: if ``s_shift`` is none of the above (previously the
                process was terminated with ``quit()``).
        """
        if s_shift is None:
            self.s_shift, self.s_scale       = torch.zeros(self.ob_dim), torch.ones(self.ob_dim)
            self.a_shift, self.a_scale       = torch.zeros(self.ac_dim), torch.ones(self.ac_dim)
            self.sp_shift, self.sp_scale     = torch.zeros(self.ob_dim), torch.ones(self.ob_dim)
            self.out_shift, self.out_scale   = 0.0, 1.0
        elif isinstance(s_shift, torch.Tensor):
            self.s_shift, self.s_scale       = s_shift, s_scale
            self.a_shift, self.a_scale       = a_shift, a_scale
            self.sp_shift, self.sp_scale     = s_shift, s_scale
            self.out_shift, self.out_scale   = out_shift, out_scale
        elif isinstance(s_shift, np.ndarray):
            self.s_shift, self.s_scale       = torch.from_numpy(s_shift).float(), torch.from_numpy(s_scale).float()
            self.a_shift, self.a_scale       = torch.from_numpy(a_shift).float(), torch.from_numpy(a_scale).float()
            self.sp_shift, self.sp_scale     = torch.from_numpy(s_shift).float(), torch.from_numpy(s_scale).float()
            self.out_shift, self.out_scale   = out_shift, out_scale
        else:
            # Library code must not terminate the interpreter; let the caller decide.
            raise TypeError(
                "Unknown type for transformations: {}".format(type(s_shift).__name__)
            )

        # Callers (including __init__ and set_params) may omit the output
        # statistics; use the identity transform so forward() stays usable.
        if self.out_shift is None:
            self.out_shift = 0.0
        if self.out_scale is None:
            self.out_scale = 1.0

        self.s_shift, self.s_scale   = self.s_shift.to(self.device), self.s_scale.to(self.device)
        self.a_shift, self.a_scale   = self.a_shift.to(self.device), self.a_scale.to(self.device)
        self.sp_shift, self.sp_scale = self.sp_shift.to(self.device), self.sp_scale.to(self.device)
        # Output statistics may be plain floats; only tensors need moving.
        if isinstance(self.out_shift, torch.Tensor):
            self.out_shift = self.out_shift.to(self.device)
        if isinstance(self.out_scale, torch.Tensor):
            self.out_scale = self.out_scale.to(self.device)

        self.transformations = dict(s_shift=self.s_shift, s_scale=self.s_scale,
                                    a_shift=self.a_shift, a_scale=self.a_scale,
                                    out_shift=self.out_shift, out_scale=self.out_scale)

    def forward(self, X):
        """Predict the reward for ``X = (ob, ac, ob_pre)``.

        A rank mismatch between ``ob`` and ``ac`` only prints a warning.
        """
        # The reward will be parameterized as r = f_theta(s, a, s').
        # If sp is unavailable, we can re-use s as sp, i.e. sp \approx s
        ob,ac,ob_pre = X
        if ob.dim() != ac.dim():
            print("State and action inputs should be of the same size")
        # normalize all the inputs
        ob = (ob - self.s_shift) / (self.s_scale + 1e-8)
        ac = (ac - self.a_shift) / (self.a_scale + 1e-8)
        ob_pre = (ob_pre - self.sp_shift) / (self.sp_scale + 1e-8)
        out = torch.cat([ob, ac, ob_pre], -1)
        # every layer except the last is followed by the nonlinearity
        for layer in self.fc_layers[:-1]:
            out = self.nonlinearity(layer(out))
        out = self.fc_layers[-1](out)
        out = out * (self.out_scale + 1e-8) + self.out_shift
        return out

    def get_params(self):
        """Return ``dict(weights=[...], transforms=(s_shift, s_scale, a_shift, a_scale))``."""
        network_weights = [p.data for p in self.parameters()]
        transforms = (self.s_shift, self.s_scale,
                      self.a_shift, self.a_scale)
        return dict(weights=network_weights, transforms=transforms)

    def set_params(self, new_params):
        """Install weights and transforms from a dict shaped like ``get_params()``'s result.

        The output shift/scale are not part of that dict, so they are reset to
        the identity defaults (``0.0`` / ``1.0``).
        """
        new_weights = new_params['weights']
        s_shift, s_scale, a_shift, a_scale = new_params['transforms']
        for idx, p in enumerate(self.parameters()):
            p.data = new_weights[idx]
        self.set_transformations(s_shift, s_scale, a_shift, a_scale)