import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

def get_output_size_with_batch(model, input_size, dtype=torch.float):
    """Infer ``model``'s output shape by running one all-zeros sample through it.

    The batch slot ``input_size[0]`` is ignored (a batch of 1 is used) and is
    reported back as ``None``.

    Args:
        model: callable module to probe.
        input_size: input shape including the batch slot, e.g. ``[None, 30]``.
        dtype: dtype of the dummy input (e.g. ``torch.long`` for embeddings).

    Returns:
        List ``[None, *output.shape[1:]]``.
    """
    dummy = torch.zeros([1, *input_size[1:]], dtype=dtype)
    with torch.no_grad():
        probe = model(dummy)
    return [None, *probe.shape[1:]]

def embedding_layer(input_size, num_embeddings, embedding_dim, **kwargs):
    """Build an ``nn.Embedding`` wrapper together with its output shape.

    Args:
        input_size: accepted for signature parity with the other layer
            factories in this file; not used.
        num_embeddings: number of rows in the lookup table.
        embedding_dim: width of each embedding vector.
        **kwargs: forwarded unchanged to ``nn.Embedding``.

    Returns:
        Tuple ``(layer, [None, embedding_dim])``.
    """
    class EmbeddingLayer(nn.Module):
        # Thin wrapper: the table lives under the ``layer`` attribute, so its
        # state_dict key is ``layer.weight``.
        def __init__(self, num_embeddings, embedding_dim, **kwargs):
            super().__init__()
            self.layer = nn.Embedding(
                num_embeddings=num_embeddings,
                embedding_dim=embedding_dim,
                **kwargs,
            )

        def forward(self, x: torch.Tensor):
            # No dtype conversion here; x must already hold integer indices.
            return self.layer(x)

    wrapped = EmbeddingLayer(num_embeddings, embedding_dim, **kwargs)
    return wrapped, [None, embedding_dim]

def linear_layer(input_size, layer_dim):
    """Build a ``Linear -> ReLU`` block and report its output shape.

    Args:
        input_size: input shape ``[batch, features]``; only ``input_size[1]``
            is read.
        layer_dim: number of output features.

    Returns:
        Tuple ``(nn.Sequential(Linear, ReLU), [None, layer_dim])``.
    """
    in_features = input_size[1]
    block = nn.Sequential(
        nn.Linear(in_features, layer_dim),
        nn.ReLU(),
    )
    return block, [None, layer_dim]
class GlobalStateLayer(nn.Module):
    """Project the global-state vector through one ``Linear(30 -> 64) + ReLU`` block."""

    def __init__(self) -> None:
        super().__init__()
        # Width of the global-state feature vector.
        self.global_state_len = 30
        in_shape = [None, self.global_state_len]
        self.linear_layer, self.linear_layer_out_dim = linear_layer(in_shape, 64)

    def forward(self, x):
        """Cast ``x`` to float32 and apply the projection."""
        return self.linear_layer(x.float())

class AgentStateLayer(nn.Module):
    """Encode one agent's state row into a 128-dim feature.

    Columns 0, 1 and 2 are treated as categorical ids (character type, role
    type, buff type) and looked up in embedding tables of width 16, 8 and 6.
    The remaining columns are used as float features. Everything is
    concatenated and projected with ``Linear -> ReLU`` to ``out_dim``.
    Assumes ``x`` is ``(batch, agent_state_len)`` — TODO confirm against callers.
    """

    def __init__(self) -> None:
        super().__init__()
        # Width of the raw agent-state row (3 id columns + continuous features).
        self.agent_state_len = 73
        # Creation order is kept stable: it determines parameter init order.
        self.my_character_type_embed_layer, self.my_character_type_embed_layer_out_dim = embedding_layer([None], 100, 16)
        self.my_role_type_embed_layer, self.my_role_type_embed_layer_out_dim = embedding_layer([None], 8, 8)
        self.my_buff_type_embed_layer, self.my_buff_type_embed_layer_out_dim = embedding_layer([None], 50, 6)
        # Embedding widths replace the 3 id columns in the concatenated vector.
        embed_total = (
            self.my_character_type_embed_layer_out_dim[1]
            + self.my_role_type_embed_layer_out_dim[1]
            + self.my_buff_type_embed_layer_out_dim[1]
        )
        self.agent_state_dim = embed_total - 3 + self.agent_state_len
        self.out_dim = 128
        self.linear_layer, self.linear_layer_out_dim = linear_layer([None, self.agent_state_dim], self.out_dim)

    def forward(self, x):
        """Embed the three id columns, append the float tail, and project."""
        ids = x[:, :3].long()
        pieces = [
            self.my_character_type_embed_layer(ids[:, 0]),
            self.my_role_type_embed_layer(ids[:, 1]),
            self.my_buff_type_embed_layer(ids[:, 2]),
            x[:, 3:].float(),
        ]
        joined = torch.cat(pieces, dim=1).float()
        return self.linear_layer(joined)

class Model(nn.Module):
    """Policy network with a shared trunk and ``latent_dim`` LoRA-adapted heads.

    Seven observation groups (global state, self, two allies, three enemies)
    are encoded by per-group layers and concatenated into one feature vector.
    That vector goes through a shared 3-layer trunk
    (``share_layer`` -> ``share_layer2`` -> ``action_layer``). Head ``i`` adds a
    rank-``self.rank`` correction ``lora*_down[i](lora*_up[i](x))`` to each of
    those three layers. ``value_layer`` is defined but not used by the methods
    of this class.

    Args:
        latent_dim: number of LoRA heads.
        device: device that ``select`` moves its result to.
    """

    def __init__(self, latent_dim=3, device='cuda') -> None:
        super().__init__()
        self.global_state_layer_dim = 64
        self.agent_state_layer_dim = 128
        self.global_state_layer = GlobalStateLayer()
        self.self_state_layer = AgentStateLayer()
        self.ally0_state_layer = AgentStateLayer()
        self.ally1_state_layer = AgentStateLayer()
        self.enemy0_state_layer = AgentStateLayer()
        self.enemy1_state_layer = AgentStateLayer()
        self.enemy2_state_layer = AgentStateLayer()
        # self.skill_layer = SkillLayer()
        self.share_layer_dim = self.global_state_layer_dim + self.agent_state_layer_dim * 6
        self.share_layer = nn.Linear(self.share_layer_dim, 256)
        self.share_layer2 = nn.Linear(256, 128)
        self.value_layer = nn.Sequential(nn.Linear(128, 1))
        self.action_layer = nn.Linear(128, 52)

        self.latent_dim = latent_dim
        self.rank = 9  # LoRA rank (4 was noted as an alternative)
        self.lora1_up = nn.ModuleList([nn.Linear(self.share_layer_dim, self.rank) for _ in range(self.latent_dim)])
        self.lora1_down = nn.ModuleList([nn.Linear(self.rank, 256) for _ in range(self.latent_dim)])
        self.lora2_up = nn.ModuleList([nn.Linear(256, self.rank) for _ in range(self.latent_dim)])
        self.lora2_down = nn.ModuleList([nn.Linear(self.rank, 128) for _ in range(self.latent_dim)])
        self.lora3_up = nn.ModuleList([nn.Linear(128, self.rank) for _ in range(self.latent_dim)])
        self.lora3_down = nn.ModuleList([nn.Linear(self.rank, 52) for _ in range(self.latent_dim)])

        self.relu = nn.ReLU()

        self.device = device
        # Built last so self.parameters() also yields the LoRA weights
        # registered above; creating it earlier left them out of the optimizer.
        self.opt = optim.Adam(self.parameters(), lr=1e-3)

    def _encode(self, states):
        """Encode the seven observation groups and concatenate them along dim 1."""
        features = [
            self.global_state_layer(states[0].float()),
            self.self_state_layer(states[1]),
            self.ally0_state_layer(states[2]),
            self.ally1_state_layer(states[3]),
            self.enemy0_state_layer(states[4]),
            self.enemy1_state_layer(states[5]),
            self.enemy2_state_layer(states[6]),
        ]
        return torch.cat(features, dim=1)

    def _head_logits(self, x, i):
        """Run the shared trunk with head ``i``'s LoRA corrections; return action logits."""
        x = x.float()
        x = self.relu(self.share_layer(x) + self.lora1_down[i](self.lora1_up[i](x)))
        x = self.relu(self.share_layer2(x) + self.lora2_down[i](self.lora2_up[i](x)))
        return self.action_layer(x) + self.lora3_down[i](self.lora3_up[i](x))  # (bs, a)

    def select(self, states, actions, tags, n_sampled_traj=5):
        """Score how well each head explains each sampled trajectory.

        For head ``i`` and trajectory ``j`` the raw score is
        ``sum_b p_i(actions[b] | states[b]) * tags[b, j]``. Each trajectory's
        row of scores is then divided by its sum (+1e-4 for stability).

        Args:
            states: sequence of 7 per-group observation tensors.
            actions: index tensor used with ``gather(1, actions)``; presumably
                ``(bs, 1)`` action ids — TODO confirm.
            tags: tensor whose column ``j`` weights each sample for trajectory ``j``.
            n_sampled_traj: number of ``tags`` columns to score.

        Returns:
            Detached ``(n_sampled_traj, latent_dim)`` tensor on ``self.device``.
        """
        x_original = self._encode(states)
        scores = [[] for _ in range(n_sampled_traj)]
        for i in range(self.latent_dim):
            prob_p = self._head_logits(x_original, i).softmax(1).gather(1, actions).squeeze(-1)  # (bs)
            for j in range(n_sampled_traj):
                scores[j].append((prob_p * tags[:, j]).sum())
        # torch.stack keeps data on-device; torch.tensor(list_of_tensors) copied
        # element by element and warned. detach() keeps the result grad-free as before.
        prob_list = torch.stack([torch.stack(row) for row in scores]).detach().to(self.device)
        return prob_list / (prob_list.sum(1, keepdim=True) + 0.0001)

    def forward(self, states, tags, z, n_sampled_traj=5):
        """Mix the heads' action logits per sample.

        Sample ``b`` gets ``sum_i sum_j z[j, i] * tags[b, j] * logits_i[b]``.

        Args:
            states: sequence of 7 per-group observation tensors.
            tags: tensor whose column ``j`` weights each sample for trajectory ``j``.
            z: ``(>= n_sampled_traj, latent_dim)`` mixing weights, indexed ``z[j, i]``.
            n_sampled_traj: number of trajectories to mix over.

        Returns:
            ``(len(tags), action_layer.out_features)`` logits on the features' device.
        """
        x_original = self._encode(states)
        # Allocate on the features' device and dtype rather than self.device,
        # so a model living on a different device than `device` still works.
        y = x_original.new_zeros((len(tags), self.action_layer.out_features))
        for i in range(self.latent_dim):
            logits_p = self._head_logits(x_original, i)
            for j in range(n_sampled_traj):
                y += logits_p * z[j, i] * tags[:, j].unsqueeze(1)
        return y

    def forward_all_head(self, states, n_sampled_traj=5):
        """Return a list holding each head's ``(bs, a)`` action logits.

        ``n_sampled_traj`` is unused and kept only for signature compatibility.
        """
        x_original = self._encode(states)
        return [self._head_logits(x_original, i) for i in range(self.latent_dim)]