from statistics import mode

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# Upper / lower clamp applied to the predicted log standard deviation in the
# continuous Gaussian policies' forward passes.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant added inside log() in the tanh-squashing correction of the
# continuous policies so the argument never reaches zero.
epsilon = 1e-6

# Initialize Policy weights
def weights_init_(m):
    """Initialise an ``nn.Linear`` layer in place; leave any other module alone.

    Meant to be handed to ``nn.Module.apply``, which calls it once for every
    sub-module.

    Args:
        m: Any module. Only ``nn.Linear`` instances are touched: the weight
            is re-drawn from a Xavier-uniform distribution (gain 1) and the
            bias, when the layer has one, is set to zero.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=1)
    # A layer created with ``bias=False`` has ``m.bias is None``; calling
    # ``constant_`` on it would raise, so skip it.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)


class ValueNetwork(nn.Module):
    """Two-hidden-layer ReLU MLP that maps a state to one scalar per sample.

    Args:
        num_inputs: Size of the last dimension of the input state.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, num_inputs, hidden_dim):
        super().__init__()

        # Attribute names double as state_dict keys, so they stay as they are.
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state):
        """Return the value estimate; the last dimension has size 1."""
        hidden = F.relu(self.linear1(state))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)

class LamadaNetwork(nn.Module):
    """Two-hidden-layer ReLU MLP whose output is one element narrower than its input.

    Args:
        num_inputs: Size of the last dimension of the input; the network
            emits ``num_inputs - 1`` values per sample.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, num_inputs, hidden_dim):
        super().__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        # Output width is input width minus one.
        self.linear3 = nn.Linear(hidden_dim, num_inputs - 1)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state):
        """Return the raw (un-activated) output of the last linear layer."""
        return self.linear3(F.relu(self.linear2(F.relu(self.linear1(state)))))

        
class LamadaNetwork_hardconstraint(nn.Module):
    """MLP whose output is forced to satisfy one linear equality per sample.

    Every input row is read as ``[a_1, ..., a_n, b]`` with
    ``n = input_dim - 1``. The MLP predicts ``n - p`` free coordinates
    (``p = 1``); these are mapped through a null-space basis of the row
    ``a`` and shifted by a particular solution, so the returned ``y``
    satisfies ``a . y = b`` up to floating-point error.

    Args:
        input_dim: Width of the input rows (coefficients plus right-hand side).
        output_dim: ``fc_3`` emits ``output_dim - p`` values. For the matrix
            products in ``forward`` to line up this must equal
            ``input_dim - 1 - p``, i.e. ``output_dim == input_dim - 1``.
        args: Object exposing a boolean ``cuda`` attribute; only used to
            populate ``self.device``.
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, input_dim, output_dim, args, hidden_size=64):
        super().__init__()
        # Kept for backward compatibility; ``forward`` allocates its helper
        # tensors on the input's own device instead of relying on this.
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self.in_dim = input_dim
        self.o_dim = output_dim
        self.p = 1  # number of equality constraints (A^T is n x p)
        self.fc_1 = nn.Linear(input_dim, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, output_dim - self.p)

    def forward(self, x):
        """Return a vector per row of ``x`` that satisfies that row's constraint.

        Args:
            x: 2-D tensor whose rows are ``[a_1, ..., a_n, b]``. It is not
                modified.

        Returns:
            Tensor with ``n`` columns; row ``i`` solves ``a_i . y_i = b_i``.
        """
        # Rows whose coefficients are all exactly zero would give a singular
        # (zero) LU pivot below, so nudge the whole row by 1e-9. This is done
        # out of place: an in-place ``x += ...`` would silently modify the
        # caller's tensor and fails on leaf tensors that require grad.
        zero_rows = (x[:, :-1] == 0.0).all(dim=-1, keepdim=True)
        x = x + zero_rows.float() * 1e-9

        hidden = F.relu(self.fc_1(x))
        hidden = F.relu(self.fc_2(hidden))
        out = self.fc_3(hidden)

        p = self.p
        n = self.in_dim - 1

        # A: (batch, 1, n) coefficient row, B: (batch, 1, 1) right-hand side.
        A = x[:, :-1].unsqueeze(1)
        B = x[:, -1:].unsqueeze(1)

        # Factor A^T = P L U and split L into its top p rows and the rest.
        A_lu, pivots = A.transpose(-1, -2).lu()
        P, A_L, A_U = torch.lu_unpack(A_lu, pivots)
        L1, L2 = A_L[:, :p, :], A_L[:, p:, :]
        L1_inv_t = torch.linalg.inv(L1).transpose(-1, -2)

        # Particular solution: x_hat = P [L1^-T U^-T B ; 0]  =>  A x_hat = B.
        top = L1_inv_t.matmul(torch.linalg.inv(A_U).transpose(-1, -2)).matmul(B)
        batch = top.shape[0]
        pad = torch.zeros(
            [batch, n - p, top.shape[-1]], dtype=top.dtype, device=top.device
        )
        x_hat = P.matmul(torch.cat([top, pad], dim=-2))

        # Null-space basis: N = P [-L1^-T L2^T ; I]  =>  A N = 0.
        null_top = -1 * L1_inv_t.matmul(L2.transpose(-1, -2))
        eye = torch.eye(n - p, dtype=null_top.dtype, device=null_top.device)
        eye = eye.expand([batch, n - p, n - p])
        null_basis = P.matmul(torch.cat([null_top, eye], dim=-2))

        # Any x_hat + N z satisfies the constraint; z is the MLP output.
        x_real = null_basis.matmul(out.unsqueeze(-1)) + x_hat
        return x_real.squeeze(-1)

class LamadaNetwork_hardconstraint_s(nn.Module):
    """State-conditioned MLP whose output satisfies one linear equality per sample.

    The free coordinates are predicted from ``state``; the constraint itself
    comes from ``x``, whose rows are read as ``[a_1, ..., a_n, b]`` with
    ``n = args.agentnum - 1``. The returned ``y`` satisfies ``a . y = b`` up
    to floating-point error.

    Args:
        input_dim: Width of ``state`` (input of the first linear layer).
        output_dim: ``fc_3`` emits ``output_dim - p`` values (``p = 1``); this
            must equal ``n - p`` for the matrix products in ``forward`` to
            line up, i.e. ``output_dim == args.agentnum - 1``.
        args: Object exposing ``cuda`` (bool) and ``agentnum`` (int).
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, input_dim, output_dim, args, hidden_size=64):
        super().__init__()
        # Kept for backward compatibility; ``forward`` allocates its helper
        # tensors on the input's own device instead of relying on this.
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self.agentnum = args.agentnum
        self.in_dim = input_dim
        self.o_dim = output_dim
        self.p = 1  # number of equality constraints (A^T is n x p)
        self.fc_1 = nn.Linear(input_dim, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, output_dim - self.p)

    def forward(self, x, state):
        """Return a vector per row of ``x`` that satisfies that row's constraint.

        Args:
            x: 2-D tensor whose rows are ``[a_1, ..., a_n, b]``. It is not
                modified.
            state: Input of the MLP that predicts the free coordinates.

        Returns:
            Tensor with ``n`` columns; row ``i`` solves ``a_i . y_i = b_i``.
        """
        hidden = F.relu(self.fc_1(state))
        hidden = F.relu(self.fc_2(hidden))
        out = self.fc_3(hidden)

        # Rows whose coefficients are all exactly zero would give a singular
        # (zero) LU pivot below, so nudge the whole row by 1e-9. This is done
        # out of place: an in-place ``x += ...`` would silently modify the
        # caller's tensor and fails on leaf tensors that require grad.
        zero_rows = (x[:, :-1] == 0.0).all(dim=-1, keepdim=True)
        x = x + zero_rows.float() * 1e-9

        p = self.p
        n = self.agentnum - 1

        # A: (batch, 1, n) coefficient row, B: (batch, 1, 1) right-hand side.
        A = x[:, :-1].unsqueeze(1)
        B = x[:, -1:].unsqueeze(1)

        # Factor A^T = P L U and split L into its top p rows and the rest.
        A_lu, pivots = A.transpose(-1, -2).lu()
        P, A_L, A_U = torch.lu_unpack(A_lu, pivots)
        L1, L2 = A_L[:, :p, :], A_L[:, p:, :]
        L1_inv_t = torch.linalg.inv(L1).transpose(-1, -2)

        # Particular solution: x_hat = P [L1^-T U^-T B ; 0]  =>  A x_hat = B.
        top = L1_inv_t.matmul(torch.linalg.inv(A_U).transpose(-1, -2)).matmul(B)
        batch = top.shape[0]
        pad = torch.zeros(
            [batch, n - p, top.shape[-1]], dtype=top.dtype, device=top.device
        )
        x_hat = P.matmul(torch.cat([top, pad], dim=-2))

        # Null-space basis: N = P [-L1^-T L2^T ; I]  =>  A N = 0.
        null_top = -1 * L1_inv_t.matmul(L2.transpose(-1, -2))
        eye = torch.eye(n - p, dtype=null_top.dtype, device=null_top.device)
        eye = eye.expand([batch, n - p, n - p])
        null_basis = P.matmul(torch.cat([null_top, eye], dim=-2))

        # Any x_hat + N z satisfies the constraint; z is the MLP output.
        x_real = null_basis.matmul(out.unsqueeze(-1)) + x_hat
        return x_real.squeeze(-1)


class LamadaNetwork_hardconstraint_sa(nn.Module):
    """State-action-conditioned MLP whose output satisfies one linear equality.

    The free coordinates are predicted from ``[state, action]``; the
    constraint comes from ``x``, whose rows are read as
    ``[a_1, ..., a_n, b]`` with ``n = args.agentnum - 1``. The returned ``y``
    satisfies ``a . y = b`` up to floating-point error.

    Args:
        input_dim: Width of the concatenated ``[state, action]`` vector.
        output_dim: ``fc_3`` emits ``output_dim - p`` values (``p = 1``); this
            must equal ``n - p`` for the matrix products in ``forward`` to
            line up, i.e. ``output_dim == args.agentnum - 1``.
        args: Object exposing ``cuda`` (bool) and ``agentnum`` (int).
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, input_dim, output_dim, args, hidden_size=64):
        super().__init__()
        # Kept for backward compatibility; ``forward`` allocates its helper
        # tensors on the input's own device instead of relying on this.
        self.device = torch.device("cuda" if args.cuda else "cpu")
        # Note: ``in_dim`` holds the agent count, not ``input_dim``.
        self.in_dim = args.agentnum
        self.o_dim = output_dim
        self.p = 1  # number of equality constraints (A^T is n x p)
        self.fc_1 = nn.Linear(input_dim, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, output_dim - self.p)

    def forward(self, x, state, action):
        """Return a vector per row of ``x`` that satisfies that row's constraint.

        Args:
            x: 2-D tensor whose rows are ``[a_1, ..., a_n, b]``.
            state: State tensor, concatenated with ``action`` on the last axis.
            action: Action tensor.

        Returns:
            Tensor with ``n`` columns; row ``i`` solves ``a_i . y_i = b_i``.
        """
        features = torch.cat([state, action], -1)
        hidden = F.relu(self.fc_1(features))
        hidden = F.relu(self.fc_2(hidden))
        out = self.fc_3(hidden)

        p = self.p
        n = self.in_dim - 1

        # A: (batch, 1, n) coefficient row, B: (batch, 1, 1) right-hand side.
        # NOTE(review): a row whose coefficients are all exactly zero makes
        # the 1x1 U factor zero, so the inverse below would fail; callers
        # presumably guarantee non-zero rows -- confirm.
        A = x[:, :-1].unsqueeze(1)
        B = x[:, -1:].unsqueeze(1)

        # Factor A^T = P L U and split L into its top p rows and the rest.
        A_lu, pivots = A.transpose(-1, -2).lu()
        P, A_L, A_U = torch.lu_unpack(A_lu, pivots)
        L1, L2 = A_L[:, :p, :], A_L[:, p:, :]
        L1_inv_t = torch.linalg.inv(L1).transpose(-1, -2)

        # Particular solution: x_hat = P [L1^-T U^-T B ; 0]  =>  A x_hat = B.
        # Helper tensors follow the input's device/dtype rather than the
        # ``args.cuda`` flag, so the module works wherever its inputs live.
        top = L1_inv_t.matmul(torch.linalg.inv(A_U).transpose(-1, -2)).matmul(B)
        batch = top.shape[0]
        pad = torch.zeros(
            [batch, n - p, top.shape[-1]], dtype=top.dtype, device=top.device
        )
        x_hat = P.matmul(torch.cat([top, pad], dim=-2))

        # Null-space basis: N = P [-L1^-T L2^T ; I]  =>  A N = 0.
        null_top = -1 * L1_inv_t.matmul(L2.transpose(-1, -2))
        eye = torch.eye(n - p, dtype=null_top.dtype, device=null_top.device)
        eye = eye.expand([batch, n - p, n - p])
        null_basis = P.matmul(torch.cat([null_top, eye], dim=-2))

        # Any x_hat + N z satisfies the constraint; z is the MLP output.
        x_real = null_basis.matmul(out.unsqueeze(-1)) + x_hat
        return x_real.squeeze(-1)


class QNetwork(nn.Module):
    """Twin Q-networks that score every discrete action from the state alone.

    Two independent two-hidden-layer ReLU MLPs (layers 1-3 and 4-6) each map
    the state to ``num_actions`` values.

    Args:
        num_inputs: Size of the last dimension of the state.
        num_actions: Number of values each Q head outputs.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()

        # First critic.
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, num_actions)

        # Second critic, same architecture with separate weights.
        self.linear4 = nn.Linear(num_inputs, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, num_actions)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state):
        """Return ``(q1, q2)``, the two critics' outputs for ``state``."""
        towers = (
            (self.linear1, self.linear2, self.linear3),
            (self.linear4, self.linear5, self.linear6),
        )
        estimates = []
        for first, second, last in towers:
            hidden = F.relu(first(state))
            hidden = F.relu(second(hidden))
            estimates.append(last(hidden))
        q1, q2 = estimates
        return q1, q2


class GaussianPolicy(nn.Module):
    """Discrete-action policy that outputs softmax probabilities.

    Despite the name, no Gaussian is involved: the network ends in a softmax
    over ``num_actions`` logits.

    Args:
        num_inputs: Size of the last dimension of the state.
        num_actions: Number of discrete actions (stored as ``dim_a``).
        hidden_dim: Width of the two hidden layers.
        action_space: Accepted but not used by this class.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
        super().__init__()
        self.dim_a = num_actions

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, num_actions)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state):
        """Return action probabilities (softmax over the last dimension)."""
        hidden = F.relu(self.linear1(state))
        hidden = F.relu(self.linear2(hidden))
        logits = self.linear3(hidden)
        return F.softmax(logits, dim=-1)

    def get_action_info(self, state):
        """Return ``(probabilities, log-probabilities)`` for ``state``.

        Entries that are exactly zero get 1e-8 added before the log so the
        result never contains ``-inf``.
        """
        action_prob = self.forward(state)
        zero_guard = (action_prob == 0.0).float() * 1e-8
        return action_prob, torch.log(action_prob + zero_guard)

    def to(self, device):
        """Move the module to ``device``; plain delegation to ``nn.Module.to``."""
        return super().to(device)

class GaussianPolicy_discrete_Multihead(nn.Module):
    """Shared-trunk discrete policy with one softmax head per entry in ``prob_out``.

    Args:
        num_inputs: Size of the last dimension of the state.
        num_actions: Number of discrete actions per head.
        hidden_dim: Width of the two shared hidden layers.
        num_head: Number of output heads.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, num_head):
        super().__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        # Each head outputs the probabilities over the action dimension.
        self.prob_out = nn.ModuleList(
            [nn.Linear(hidden_dim, num_actions) for _ in range(num_head)]
        )
        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state):
        """Return per-head action probabilities, heads stacked on dim 0."""
        hidden = F.relu(self.linear1(state))
        hidden = F.relu(self.linear2(hidden))
        per_head = [F.softmax(head(hidden), dim=-1) for head in self.prob_out]
        return torch.stack(per_head, dim=0)

    def actions_logprob(self, state):
        """Return ``(probabilities, log-probabilities)`` for every head.

        Exactly-zero probabilities get 1e-9 added before the log.
        """
        action_prob = self.forward(state)
        zero_guard = (action_prob == 0.0).float() * 1e-9
        return action_prob, torch.log(action_prob + zero_guard)

    def sample(self, state):
        """Sample one action per head and sample row.

        Returns:
            ``(action, log_prob_a, max_action)``: the sampled action indices,
            the log-probability of each sampled action, and the index of the
            most probable action; each keeps a trailing dimension of size 1.
        """
        action_prob = self.forward(state)
        # Exactly-zero probabilities get 1e-8 added before the log.
        zero_guard = (action_prob == 0).float() * 1e-8
        log_actions_prob = torch.log(action_prob + zero_guard)

        # multinomial needs a 2-D input: flatten the leading dims, draw one
        # index per row, then restore the leading dims.
        num_choices = action_prob.shape[-1]
        flat_draw = torch.multinomial(
            action_prob.view(-1, num_choices), 1, replacement=True
        )
        action = flat_draw.view(*action_prob.shape[:-1], 1)

        log_prob_a = log_actions_prob.gather(-1, action)
        max_action = torch.max(action_prob, dim=-1, keepdim=True).indices
        return action, log_prob_a, max_action

    def to(self, device):
        """Move the module to ``device``; plain delegation to ``nn.Module.to``."""
        return super().to(device)

class QNetwork_continue(nn.Module):
    """Twin Q-networks for continuous actions: ``(state, action) -> scalar``.

    Two independent two-hidden-layer ReLU MLPs (layers 1-3 and 4-6) each take
    the concatenated state and action and output one value.

    Args:
        num_inputs: Width of the state.
        num_actions: Width of the action.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()
        joint_dim = num_inputs + num_actions

        # First critic.
        self.linear1 = nn.Linear(joint_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Second critic, same architecture with separate weights.
        self.linear4 = nn.Linear(joint_dim, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for the given state/action pair.

        ``state`` and ``action`` are concatenated along dimension 1.
        """
        joint = torch.cat([state, action], 1)

        q1 = self.linear3(F.relu(self.linear2(F.relu(self.linear1(joint)))))
        q2 = self.linear6(F.relu(self.linear5(F.relu(self.linear4(joint)))))

        return q1, q2


class QNetwork_continue_Multihead(nn.Module):
    """Ensemble of continuous-action critics, each with several scalar heads.

    Every ensemble member owns its own trunk (``linear1`` -> ``linear2``) and
    ``num_head`` one-unit output layers.

    Args:
        num_inputs: Width of the state.
        num_actions: Width of the action.
        hidden_dim: Width of the hidden layers.
        num_head: Number of output heads per ensemble member.
        ensamble_num: Number of ensemble members.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim,num_head,ensamble_num=2):
        super().__init__()

        self.ensamble_num = ensamble_num
        members = range(ensamble_num)

        self.linear1 = nn.ModuleList(
            [nn.Linear(num_inputs + num_actions, hidden_dim) for _ in members]
        )
        self.linear2 = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in members]
        )
        # Created (and therefore part of the state_dict) but not applied in
        # ``forward``.
        self.linear3 = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in members]
        )
        self.q_modellist = nn.ModuleList(
            [
                nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(num_head)])
                for _ in members
            ]
        )

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state, action):
        """Return Q-values stacked as ``(ensemble, head, ..., 1)``.

        ``state`` and ``action`` are concatenated along the last dimension.
        """
        # The joint input is the same for every ensemble member.
        joint = torch.cat([state, action], -1)

        per_member = []
        for first, second, heads in zip(self.linear1, self.linear2, self.q_modellist):
            hidden = F.relu(second(F.relu(first(joint))))
            # One scalar output per head, stacked on a new leading dim.
            per_member.append(torch.stack([head(hidden) for head in heads], dim=0))

        # Stack the members on another new leading dim.
        return torch.stack(per_member, dim=0)


class QNetwork_discrete_Multihead(nn.Module):
    """Ensemble of discrete-action critics, each with several output heads.

    Every ensemble member owns a three-layer ReLU trunk and ``num_head``
    output layers that each score ``num_actions`` actions.

    Args:
        num_inputs: Width of the state.
        num_actions: Number of values each head outputs.
        hidden_dim: Width of the hidden layers.
        num_head: Number of output heads per ensemble member.
        ensamble_num: Number of ensemble members.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim,num_head,ensamble_num=2):
        super().__init__()

        self.ensamble_num = ensamble_num
        members = range(ensamble_num)

        self.linear1 = nn.ModuleList(
            [nn.Linear(num_inputs, hidden_dim) for _ in members]
        )
        self.linear2 = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in members]
        )
        self.linear3 = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in members]
        )
        self.q_modellist = nn.ModuleList(
            [
                nn.ModuleList(
                    [nn.Linear(hidden_dim, num_actions) for _ in range(num_head)]
                )
                for _ in members
            ]
        )

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

    def forward(self, state):
        """Return Q-values stacked as ``(ensemble, head, ..., num_actions)``."""
        per_member = []
        trunks = zip(self.linear1, self.linear2, self.linear3, self.q_modellist)
        for first, second, third, heads in trunks:
            hidden = F.relu(first(state))
            hidden = F.relu(second(hidden))
            hidden = F.relu(third(hidden))
            # One output per head, stacked on a new leading dim.
            per_member.append(torch.stack([head(hidden) for head in heads], dim=0))

        # Stack the members on another new leading dim.
        return torch.stack(per_member, dim=0)




class GaussianPolicy_continue(nn.Module):
    """Tanh-squashed diagonal Gaussian policy for continuous actions.

    A two-hidden-layer ReLU trunk feeds two linear heads that predict the
    mean and the (clamped) log standard deviation of a Normal distribution.
    Samples are squashed with ``tanh`` and rescaled to the action range.

    Args:
        num_inputs: Width of the state.
        num_actions: Width of the action.
        hidden_dim: Width of the two hidden layers.
        action_space: Optional object with ``high`` and ``low`` attributes
            used to derive the action scale and bias. When ``None`` the
            scale is 1 and the bias is 0.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
        super(GaussianPolicy_continue, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

        # Action rescaling: action = tanh(u) * scale + bias.
        if action_space is None:
            self.action_scale = torch.tensor(1.)
            self.action_bias = torch.tensor(0.)
        else:
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)

    def forward(self, state):
        """Return ``(mean, log_std)``; ``log_std`` is clamped to the module bounds."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterised action.

        Returns:
            ``(action, log_prob, mean)``: the squashed and rescaled sample,
            its log-probability summed over dimension 1 (kept as size 1),
            and the squashed and rescaled mean.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Change-of-variables correction for the tanh squashing and rescaling.
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def calculate_prob(self, state, action):
        """Return the log-probability of ``action`` under the policy at ``state``.

        The action is mapped back to the pre-squash space with ``atanh``.
        The normalised action is clamped to the open interval (-1, 1) first,
        and the clamped value is used for both the ``atanh`` and the Jacobian
        term. Using the unclamped value in the Jacobian would make
        ``1 - y**2`` negative for actions slightly outside the bounds and
        turn the result into NaN.

        Returns:
            Log-probability summed over dimension 1 (kept as size 1).
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = Normal(mean, std)
        y_t = (action - self.action_bias) / (self.action_scale + 1e-6)
        y_t = torch.clamp(y_t, min=-1 + 1e-6, max=1 - 1e-6)
        x_t = torch.atanh(y_t)
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        return log_prob

    def to(self, device):
        """Move the module and its scale/bias tensors to ``device``.

        ``action_scale`` and ``action_bias`` are plain tensor attributes, so
        they are moved by hand before delegating to ``nn.Module.to``.
        """
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        return super(GaussianPolicy_continue, self).to(device)



class GaussianPolicy_continue_Multihead(nn.Module):
    """Shared-trunk tanh-squashed Gaussian policy with several output heads.

    A two-hidden-layer ReLU trunk feeds ``num_head`` pairs of linear layers;
    each pair predicts the mean and the (clamped) log standard deviation of
    one head's Normal distribution. Head outputs are stacked on dimension 0.

    Args:
        num_inputs: Width of the state.
        num_actions: Width of each head's action.
        hidden_dim: Width of the two shared hidden layers.
        num_head: Number of output heads.
        action_space: Optional object with ``high`` and ``low`` attributes
            used to derive the action scale and bias. When ``None`` the
            scale is 1 and the bias is 0.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, num_head, action_space=None):
        super(GaussianPolicy_continue_Multihead, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linears = nn.ModuleList(
            [nn.Linear(hidden_dim, num_actions) for _ in range(num_head)]
        )
        self.log_std_linears = nn.ModuleList(
            [nn.Linear(hidden_dim, num_actions) for _ in range(num_head)]
        )

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

        # Action rescaling: action = tanh(u) * scale + bias.
        if action_space is None:
            self.action_scale = torch.tensor(1.)
            self.action_bias = torch.tensor(0.)
        else:
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)

    def forward(self, state):
        """Return ``(means, log_stds)``, each stacked as ``(num_head, ..., num_actions)``.

        ``log_stds`` is clamped to the module-level bounds.
        """
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        means = torch.stack([head(x) for head in self.mean_linears], dim=0)
        log_stds = torch.stack(
            [
                torch.clamp(head(x), min=LOG_SIG_MIN, max=LOG_SIG_MAX)
                for head in self.log_std_linears
            ],
            dim=0,
        )
        return means, log_stds

    def sample(self, state):
        """Draw a reparameterised action from every head.

        Returns:
            ``(action, log_prob, means)``: the squashed and rescaled samples,
            their log-probabilities summed over the last dimension (kept as
            size 1), and the squashed and rescaled means.
        """
        means, log_stds = self.forward(state)
        std = log_stds.exp()
        normal = Normal(means, std)

        x_t = normal.rsample()  # reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Change-of-variables correction for the tanh squashing and rescaling.
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        means = torch.tanh(means) * self.action_scale + self.action_bias
        return action, log_prob, means

    def calculate_prob(self, state, action):
        """Return the log-probability of ``action`` under every head at ``state``.

        The action is mapped back to the pre-squash space with ``atanh``.
        The normalised action is clamped to the open interval (-1, 1) first,
        and the clamped value is used for both the ``atanh`` and the Jacobian
        term. Using the unclamped value in the Jacobian would make
        ``1 - y**2`` negative for actions slightly outside the bounds and
        turn the result into NaN.

        Returns:
            Log-probability summed over the last dimension (kept as size 1).
        """
        means, log_stds = self.forward(state)
        std = log_stds.exp()

        normal = Normal(means, std)

        y_t = (action - self.action_bias) / (self.action_scale + 1e-6)
        y_t = torch.clamp(y_t, min=-1 + 1e-6, max=1 - 1e-6)
        x_t = torch.atanh(y_t)
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        return log_prob

    def to(self, device):
        """Move the module and its scale/bias tensors to ``device``.

        ``action_scale`` and ``action_bias`` are plain tensor attributes, so
        they are moved by hand before delegating to ``nn.Module.to``.
        """
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        return super(GaussianPolicy_continue_Multihead, self).to(device)


class DeterministicPolicy(nn.Module):
    """Deterministic tanh policy with clipped Gaussian exploration noise.

    Args:
        num_inputs: Width of the state.
        num_actions: Width of the action.
        hidden_dim: Width of the two hidden layers.
        action_space: Optional object with ``high`` and ``low`` attributes
            used to derive the action scale and bias. When ``None`` the
            scale is 1 and the bias is 0.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
        super(DeterministicPolicy, self).__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean = nn.Linear(hidden_dim, num_actions)
        # Scratch buffer that ``sample`` refills with fresh noise on each call.
        self.noise = torch.Tensor(num_actions)

        # Re-initialise every linear layer via the module-level initialiser.
        self.apply(weights_init_)

        # Action rescaling: action = tanh(u) * scale + bias.
        # The defaults are 0-dim tensors rather than Python floats so that
        # ``to()`` can call ``.to(device)`` on them; with floats it raised
        # AttributeError whenever no action_space was given.
        if action_space is None:
            self.action_scale = torch.tensor(1.)
            self.action_bias = torch.tensor(0.)
        else:
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)

    def forward(self, state):
        """Return the deterministic action for ``state`` (squashed and rescaled)."""
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
        return mean

    def sample(self, state):
        """Return ``(noisy_action, 0, mean)``.

        The noise is drawn from N(0, 0.1) into the ``num_actions``-sized
        buffer, clipped to [-0.25, 0.25] and added to the deterministic
        action. The middle element is a constant zero tensor standing in for
        a log-probability.
        """
        mean = self.forward(state)
        noise = self.noise.normal_(0., std=0.1)
        noise = noise.clamp(-0.25, 0.25)
        action = mean + noise
        return action, torch.tensor(0.), mean

    def to(self, device):
        """Move the module plus its scale, bias and noise tensors to ``device``.

        These are plain tensor attributes, so they are moved by hand before
        delegating to ``nn.Module.to``.
        """
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        self.noise = self.noise.to(device)
        return super(DeterministicPolicy, self).to(device)
