import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.distributions.categorical import Categorical

class QNetwork(nn.Module):
    """Three-layer ReLU MLP with one output per discrete action.

    The input layer is sized as the flattened observation length plus
    ``latent_dim``, so the tensor handed to ``forward`` must already carry
    ``latent_dim`` extra features on its last axis.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
        latent_dim: Number of extra input features expected besides the
            flattened observation.
    """

    def __init__(self, env, latent_dim=3):
        super().__init__()
        obs_size = np.array(env.single_observation_space.shape).prod()
        n_actions = env.single_action_space.n
        stack = [
            nn.Linear(obs_size + latent_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the MLP and return the per-action outputs."""
        outputs = self.network(x)
        return outputs
    
class LatentMLP(nn.Module):
    """ReLU MLP over the concatenation of an observation and a latent vector.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
        latent_dim: Width of the latent vector appended to the observation.
    """

    def __init__(self, env, latent_dim=3):
        super().__init__()
        obs_size = np.array(env.single_observation_space.shape).prod()
        n_actions = env.single_action_space.n
        stack = [
            nn.Linear(obs_size + latent_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x, z):
        """Join ``x`` and ``z`` along the last axis and return the MLP output."""
        joined = torch.cat((x, z), dim=-1)
        return self.network(joined)
    
class ClassMLP(nn.Module):
    """``latent_dim`` separate MLP heads whose outputs are blended by ``z``.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
        latent_dim: Number of heads to build.
    """

    def __init__(self, env, latent_dim=3):
        super().__init__()
        obs_size = np.array(env.single_observation_space.shape).prod()
        n_actions = env.single_action_space.n

        def build_head():
            return nn.Sequential(
                nn.Linear(obs_size, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, n_actions),
            )

        self.networks = nn.ModuleList([build_head() for _ in range(latent_dim)])

    def forward(self, x, z):
        """Return the ``z``-weighted sum of every head's output for ``x``.

        A 2-D ``z`` supplies one weight row per sample (column ``i`` scales
        head ``i``); any other rank is indexed as ``z[i]`` and scales the
        whole output of head ``i``. ``z`` is used as given - no softmax is
        applied here.
        """
        per_sample = len(z.shape) == 2
        blended = 0
        for index, head in enumerate(self.networks):
            gate = z[:, index].unsqueeze(1) if per_sample else z[index]
            blended = blended + head(x) * gate
        return blended
    
class NPMLP(nn.Module):
    """``latent_dim`` separate MLP heads blended by a weight vector ``z``.

    Besides the blended ``forward`` pass, ``select`` scores how well each
    head explains a batch of recorded (observation, action) episodes.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
        latent_dim: Number of heads to build.
        episode_limit: Number of consecutive steps that ``select`` groups
            into one episode.

    Attributes:
        eye: ``latent_dim x latent_dim`` identity matrix (presumably used by
            callers as one-hot ``z`` rows - confirm against call sites). It is
            a non-persistent buffer, so it follows ``.to(device)`` with the
            module and is not written to the state dict.
    """

    def __init__(self, env, latent_dim=3, episode_limit=100):
        super().__init__()
        self.networks = nn.ModuleList([nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, env.single_action_space.n),
        ) for _ in range(latent_dim)])
        self.episode_limit = episode_limit
        self.state_dim = np.array(env.single_observation_space.shape).prod()
        self.action_size = env.single_action_space.n
        # Previously created with a hard-coded .to('cuda'), which failed on
        # CPU-only hosts and never moved with the module. persistent=False
        # keeps existing checkpoints loadable (no new state-dict key).
        self.register_buffer("eye", torch.eye(latent_dim), persistent=False)

    def forward(self, x, z):
        """Return the ``z``-weighted sum of every head's output for ``x``.

        A 2-D ``z`` supplies one weight row per sample (column ``i`` scales
        head ``i``); any other rank is indexed as ``z[i]``. ``z`` is used as
        given - no softmax is applied here.
        """
        per_sample = len(z.shape) == 2
        y = 0
        for i, head in enumerate(self.networks):
            gate = z[:, i].unsqueeze(1) if per_sample else z[i]
            y = y + head(x) * gate
        return y

    def forward0(self, x):
        """Return the output of head 0 only."""
        return self.networks[0](x)

    def forward1(self, x):
        """Return the output of head 1 only."""
        return self.networks[1](x)

    def select(self, x, y, mask):
        """Score each head on recorded episodes and normalise across heads.

        For every head, the softmax probability it assigns to the recorded
        action is multiplied by ``mask`` and summed over the
        ``episode_limit`` steps of each episode; the per-head sums are then
        divided by their total so each row sums to one.

        Args:
            x: Observations, reshaped to ``(-1, state_dim)``.
            y: Recorded action indices, one per step.
            mask: Per-step weights (0 removes a step from the sum).

        Returns:
            Tensor of shape ``(episodes, latent_dim)``. A row whose steps are
            all masked out is 0/0 and therefore NaN, as before.
        """
        x = x.view(-1, self.state_dim)
        y = F.one_hot(y.view(-1).long(), num_classes=self.action_size)  # (bs*l, a)
        y = y * mask.view(-1).unsqueeze(-1)
        prob_list = []
        for head in self.networks:
            # Explicit dim: x is 2-D here, so the action axis is the last one.
            y_pred = F.softmax(head(x), dim=-1)  # (bs*l, a)
            prob = (y * y_pred).sum(1).view(-1, self.episode_limit).sum(1)  # (bs)
            prob_list.append(prob.unsqueeze(-1))
        prob_list = torch.cat(prob_list, 1)
        return prob_list / prob_list.sum(1, keepdim=True)
    
class NPMLPLORA(nn.Module):
    """Shared three-layer MLP with ``latent_dim`` sets of low-rank side branches.

    Head ``i`` is the shared ``linear1/2/3`` stack where every layer's output
    is augmented by a rank-``self.rank`` two-layer branch belonging to that
    head. Heads are blended by ``z`` in ``forward`` and scored against
    recorded episodes in ``select`` / ``selectv2``.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
        latent_dim: Number of heads (branch sets).
        episode_limit: Number of consecutive steps that ``select`` and
            ``selectv2`` group into one episode.

    Attributes:
        eye: ``latent_dim x latent_dim`` identity matrix (presumably used by
            callers as one-hot ``z`` rows - confirm against call sites). It is
            a non-persistent buffer, so it follows ``.to(device)`` with the
            module and is not written to the state dict.
    """

    def __init__(self, env, latent_dim=3, episode_limit=100):
        super().__init__()
        self.linear1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 120)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, env.single_action_space.n)

        self.latent_dim = latent_dim

        self.rank = 9  # 4
        # Naming note: each "*_up" layer maps features -> rank and each
        # "*_down" layer maps rank -> features. Names are kept as-is because
        # they are state-dict keys.
        self.lora1_up = nn.ModuleList([nn.Linear(np.array(env.single_observation_space.shape).prod(), self.rank) for _ in range(self.latent_dim)])
        self.lora1_down = nn.ModuleList([nn.Linear(self.rank, 120) for _ in range(self.latent_dim)])
        self.lora2_up = nn.ModuleList([nn.Linear(120, self.rank) for _ in range(self.latent_dim)])
        self.lora2_down = nn.ModuleList([nn.Linear(self.rank, 84) for _ in range(self.latent_dim)])
        self.lora3_up = nn.ModuleList([nn.Linear(84, self.rank) for _ in range(self.latent_dim)])
        self.lora3_down = nn.ModuleList([nn.Linear(self.rank, env.single_action_space.n) for _ in range(self.latent_dim)])

        self.episode_limit = episode_limit
        self.state_dim = np.array(env.single_observation_space.shape).prod()
        self.action_size = env.single_action_space.n
        # Previously created with a hard-coded .to('cuda'), which failed on
        # CPU-only hosts and never moved with the module. persistent=False
        # keeps existing checkpoints loadable (no new state-dict key).
        self.register_buffer("eye", torch.eye(latent_dim), persistent=False)

    def _head_logits(self, x, i, weight=1):
        """Return head ``i``'s output for ``x``; ``weight`` scales every branch."""
        h = self.relu(self.linear1(x) + self.lora1_down[i](self.lora1_up[i](x)) * weight)
        h = self.relu(self.linear2(h) + self.lora2_down[i](self.lora2_up[i](h)) * weight)
        return self.linear3(h) + self.lora3_down[i](self.lora3_up[i](h)) * weight

    def forward(self, x, z, weight=1):
        """Return the ``z``-weighted sum of all head outputs for ``x``.

        Args:
            x: Observations.
            z: Head weights. A 1-D ``z`` is indexed as ``z[i]``; otherwise
                column ``i`` (``z[:, i:i+1]``) scales head ``i`` per sample.
            weight: Scale applied to every low-rank branch output
                (``0`` leaves only the shared layers).
        """
        y_out = 0
        for i in range(self.latent_dim):
            y = self._head_logits(x, i, weight)
            gate = z[i] if len(z.shape) == 1 else z[:, i:i + 1]
            y_out = y_out + y * gate
        return y_out

    def forward_all_head(self, x, weight=1):
        """Return every head's output stacked on a new leading axis."""
        return torch.stack(
            [self._head_logits(x, i, weight) for i in range(self.latent_dim)],
            dim=0,
        )

    def select(self, x, y, mask):
        """Score heads by the summed probability of the recorded actions.

        For every head, the softmax probability it assigns to the recorded
        action is multiplied by ``mask`` and summed over the
        ``episode_limit`` steps of each episode; the per-head sums are then
        divided by their total so each row sums to one.

        Args:
            x: Observations, reshaped to ``(-1, state_dim)``.
            y: Recorded action indices, one per step.
            mask: Per-step weights (0 removes a step from the sum).

        Returns:
            Tensor of shape ``(episodes, latent_dim)``. A row whose steps are
            all masked out is 0/0 and therefore NaN, as before.
        """
        x = x.view(-1, self.state_dim)
        y_true = F.one_hot(y.view(-1).long(), num_classes=self.action_size)  # (bs*l, a)
        y_true = y_true * mask.view(-1).unsqueeze(-1)
        prob_list = []
        for i in range(self.latent_dim):
            # Explicit dim: x is 2-D here, so the action axis is the last one.
            y_pred = F.softmax(self._head_logits(x, i), dim=-1)  # (bs*l, a)
            prob = (y_true * y_pred).sum(1).view(-1, self.episode_limit).sum(1)  # (bs)
            prob_list.append(prob.unsqueeze(-1))
        prob_list = torch.cat(prob_list, 1)
        return prob_list / prob_list.sum(1, keepdim=True)

    def selectv2(self, x, y, mask):
        """Score heads by the product of recorded-action probabilities.

        Each head's per-step log-probabilities of the recorded actions are
        summed over the ``episode_limit`` steps of an episode, and the sums
        are turned into weights with a softmax across heads. That equals
        normalising the per-head likelihood products, but stays finite when
        those products would underflow.

        Steps where ``mask`` is zero are left out of the sum. The previous
        implementation took ``log(mask * p)``, so one masked step produced
        ``-inf`` for every head and a NaN row. For non-zero masks the result
        is unchanged, because a common per-step factor cancels in the
        normalisation.

        Args:
            x: Observations, reshaped to ``(-1, state_dim)``.
            y: Recorded action indices, one per step.
            mask: Per-step validity; zero marks a step to ignore.

        Returns:
            Tensor of shape ``(episodes, latent_dim)`` whose rows sum to one.
            An episode with every step masked gets uniform weights.
        """
        x = x.view(-1, self.state_dim)
        actions = y.view(-1).long().unsqueeze(-1)  # (bs*l, 1)
        ignored = mask.view(-1) == 0  # (bs*l)
        logprob_list = []
        for i in range(self.latent_dim):
            log_pred = F.log_softmax(self._head_logits(x, i), dim=-1)  # (bs*l, a)
            step_logprob = log_pred.gather(1, actions).squeeze(-1)  # (bs*l)
            # masked_fill rather than multiplying by the mask, so a -inf
            # log-probability on an ignored step cannot become NaN.
            step_logprob = step_logprob.masked_fill(ignored, 0.0)
            logprob = step_logprob.view(-1, self.episode_limit).sum(1)  # (bs)
            logprob_list.append(logprob.unsqueeze(-1))
        return F.softmax(torch.cat(logprob_list, 1), dim=1)

    def set_requires_grad_false(self):
        """Freeze the shared ``linear1/2/3`` layers; branches stay trainable."""
        for layer in (self.linear1, self.linear2, self.linear3):
            layer.weight.requires_grad = False
            layer.bias.requires_grad = False
    
# Disabled earlier ClassMLP variant with exactly two hard-wired networks.
# It is wrapped in a bare string literal, so it is evaluated as a no-op
# expression and never defines a class.
'''class ClassMLP(nn.Module):
    def __init__(self, env,latent_dim=3):
        super().__init__()
        self.network1 = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, env.single_action_space.n),
        )
        self.network2 = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, env.single_action_space.n),
        )

    def forward(self, x,z):
        if len(z.shape)==2:
            return self.network1(x)*z[:,0:1]+self.network2(x)*z[:,1:2]
        else:
            return self.network1(x)*z[0]+self.network2(x)*z[1]'''
    
class Encoder(nn.Module):
    """Map a sequence of (observation + action)-sized tokens to latent weights.

    One single-head transformer encoder layer is applied to the sequence,
    the result is summed over axis 1, and a small MLP followed by a softmax
    yields ``latent_dim`` non-negative weights that sum to one.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``; their sum fixes the token width.
        latent_dim: Size of the output distribution.
    """

    def __init__(self, env, latent_dim=3):
        super().__init__()
        obs_size = np.array(env.single_observation_space.shape).prod()
        token_dim = obs_size + env.single_action_space.n
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=token_dim,
            nhead=1,
            batch_first=True,
        )
        self.network = nn.Sequential(
            nn.Linear(token_dim, 120),
            nn.ReLU(),
            nn.Linear(120, latent_dim),
        )

    def forward(self, x, src_mask=None):
        """Encode ``x`` and return softmax-normalised latent weights.

        ``src_mask`` is forwarded as the layer's ``src_key_padding_mask``.
        The sum over axis 1 covers every position of the layer output; it
        is not restricted by ``src_mask``.
        """
        encoded = self.encoder_layer(x, src_key_padding_mask=src_mask)
        pooled = encoded.sum(1)
        scores = self.network(pooled)
        # Softmax normalisation; an L2-normalised variant was tried earlier.
        return F.softmax(scores, -1)

class Discriminator(nn.Module):
    """Produce two raw scores for a sequence of (observation + action) tokens.

    One single-head transformer encoder layer is applied to the sequence,
    the result is summed over axis 1, and a three-layer ReLU MLP maps the
    pooled vector to two outputs. No softmax or sigmoid is applied here.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``; their sum fixes the token width.
    """

    def __init__(self, env):
        super().__init__()
        obs_size = np.array(env.single_observation_space.shape).prod()
        token_dim = obs_size + env.single_action_space.n
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=token_dim,
            nhead=1,
            batch_first=True,
        )
        self.network = nn.Sequential(
            nn.Linear(token_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 120),
            nn.ReLU(),
            nn.Linear(120, 2),
        )

    def forward(self, x):
        """Encode ``x``, sum over axis 1 and return the two raw scores."""
        pooled = self.encoder_layer(x).sum(1)
        return self.network(pooled)



def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Re-initialise ``layer`` in place and return it.

    The weight receives an orthogonal initialisation with gain ``std``;
    every bias entry is set to ``bias_const``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer

class Agent(nn.Module):
    """Separate Tanh MLPs for a scalar critic and a categorical actor.

    Args:
        envs: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
    """

    def __init__(self, envs):
        super().__init__()
        obs_size = np.array(envs.single_observation_space.shape).prod()
        n_actions = envs.single_action_space.n
        # The critic is built before the actor, and each layer is initialised
        # right after it is created.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_size, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_size, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, n_actions), std=0.01),
        )

    def get_value(self, x):
        """Return the critic output for ``x``."""
        value = self.critic(x)
        return value

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, logits)`` for ``x``.

        An action is sampled from the actor's categorical distribution when
        ``action`` is ``None``; otherwise the given action is evaluated.
        Despite the method name, the fourth element is the actor logits, not
        the critic value.
        """
        logits = self.actor(x)
        policy = Categorical(logits=logits)
        chosen = policy.sample() if action is None else action
        return chosen, policy.log_prob(chosen), policy.entropy(), logits
