import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from utils.embed import polynomial_embed, binary_embed
from utils.transformer import Transformer

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer (non-EMA) following the original VQ-VAE paper.

    Args:
        num_embeddings: number of discrete codebook vectors (K).
        embedding_dim: dimensionality of each code vector (D, the latent dim).
        commitment_cost: beta, the weight of the commitment loss.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook of shape (num_embeddings, embedding_dim), initialised
        # uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

    def forward(self, inputs: th.Tensor):
        """Quantize ``inputs`` to their nearest codebook vectors.

        Args:
            inputs: tensor of shape (B, D) or (B, N, D) with D == embedding_dim.

        Returns:
            A 4-tuple ``(quantized_st, vq_loss, perplexity, encoding_indices)``:

            * ``quantized_st``: quantized vectors, same shape as ``inputs``,
              wired with the straight-through estimator so that gradients
              flowing into it reach ``inputs`` unchanged.
            * ``vq_loss``: element-wise (unreduced) loss with the same shape
              as ``inputs``; the caller is expected to mask/reduce it.
            * ``perplexity``: scalar diagnostic of codebook usage.
            * ``encoding_indices``: LongTensor of shape ``inputs.shape[:-1]``.

        Raises:
            ValueError: if ``inputs`` is neither 2D nor 3D.
        """
        if inputs.dim() not in (2, 3):
            raise ValueError("inputs must be 2D or 3D tensor")

        # reshape (not view) so that non-contiguous inputs are accepted too
        flat_input = inputs.reshape(-1, inputs.shape[-1])  # (M, D)
        embedding_weight = self.embedding.weight  # (K, D)

        # Squared distances via ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e -> (M, K)
        distances = (
            flat_input.pow(2).sum(1, keepdim=True)
            + embedding_weight.pow(2).sum(1)
            - 2.0 * th.matmul(flat_input, embedding_weight.t())
        )

        # Index of the nearest code for every input vector.
        encoding_indices = th.argmin(distances, dim=1)  # (M,)

        # Look the codes up directly and restore the caller's shape.
        quantized = self.embedding(encoding_indices).view(inputs.shape)

        # Codebook (embedding) loss: pulls the selected codes towards the
        # encoder outputs, so the encoder side is detached.
        embedding_loss = F.mse_loss(quantized, inputs.detach(), reduction='none')
        # Commitment loss: pulls the encoder outputs towards their codes, so
        # the codebook side is detached; weighted by beta.
        commitment_loss = self.commitment_cost * F.mse_loss(inputs, quantized.detach(), reduction='none')
        vq_loss = embedding_loss + commitment_loss

        # Straight-through estimator: forward value is the code, backward
        # gradient is passed to the encoder output as-is.
        quantized_st = inputs + (quantized - inputs).detach()

        # Perplexity of the empirical code-usage distribution (diagnostic).
        encodings = F.one_hot(encoding_indices, num_classes=self.num_embeddings).type(flat_input.dtype)  # (M, K)
        avg_probs = encodings.mean(dim=0)
        perplexity = th.exp(-th.sum(avg_probs * th.log(avg_probs + 1e-10)))

        return quantized_st, vq_loss, perplexity, encoding_indices.view(inputs.shape[:-1])

class VQCVAEAgent(nn.Module):
    """VQ-CVAE agent: state encoder -> linear projection -> vector quantizer -> decoder.

    The interface mirrors CVAEAgent. ``forward`` returns
    ``(q, h, z_e, vq_loss, perplexity, encoding_indices)`` where ``z_e`` is the
    continuous (un-quantized) encoder output; the decoder is fed the quantized
    vector ``z_q``.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, args):
        super(VQCVAEAgent, self).__init__()
        self.args = args
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.task2input_shape_info = task2input_shape_info
        self.latent_dim = args.latent_dim
        self.entity_embed_dim = args.entity_embed_dim

        # Same encoder / decoder as the plain CVAE agent.
        self.state_encoder = StateEncoder(task2input_shape_info, task2decomposer, task2n_agents, args)
        self.decoder = Decoder(task2input_shape_info, task2decomposer, task2n_agents, args)

        # Projects the state-encoder output to latent_dim (continuous code z_e).
        self.encoder_proj = nn.Linear(self.entity_embed_dim, self.latent_dim)

        # Vector quantizer; codebook size and beta fall back to defaults when
        # the config does not define them.
        num_embeddings = getattr(args, "num_embeddings", 5)
        commitment_cost = getattr(args, "commitment_cost", 0.25)
        self.num_embeddings = num_embeddings
        self.vq = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=self.latent_dim, commitment_cost=commitment_cost)

    def init_hidden(self):
        """Return the decoder's initial hidden state."""
        return self.decoder.init_hidden()

    def encode(self, states, task, actions):
        """Encode states/actions into the continuous latent ``z_e``.

        The state encoder's output has ``entity_embed_dim`` as its last axis;
        the projection maps that axis to ``latent_dim``.
        """
        attn_out = self.state_encoder(states, task, actions)
        return self.encoder_proj(attn_out)

    def quantize(self, z_e):
        """Run the vector quantizer; returns (z_q, vq_loss, perplexity, encoding_indices)."""
        return self.vq(z_e)

    def decode(self, inputs, hidden_state, task, latent_z):
        """Decode Q-values from observations conditioned on ``latent_z``; returns the decoder output."""
        return self.decoder(inputs, hidden_state, task, latent_z)

    def decode_id(self, inputs, hidden_state, task, encoding_onehot):
        """Decode from one-hot code ids instead of latent vectors.

        Args:
            encoding_onehot: one-hot codes of shape (B, K) or (B, N, K). Any
                dtype is accepted (e.g. the integer output of ``F.one_hot``);
                it is cast to the codebook dtype before the lookup.

        Raises:
            ValueError: if ``encoding_onehot`` is neither 2D nor 3D.
        """
        if encoding_onehot.dim() not in (2, 3):
            raise ValueError("encoding_onehot must be 2D (B, K) or 3D (B, N, K)")

        codebook = self.vq.embedding.weight  # (K, D) with D == self.latent_dim
        # matmul broadcasts over leading axes: (..., K) @ (K, D) -> (..., D)
        z_q = th.matmul(encoding_onehot.to(codebook.dtype), codebook)
        return self.decode(inputs, hidden_state, task, z_q)

    def get_encoding(self, states, task, actions):
        """Return the quantized latent ``z_q`` (straight-through) for the given states/actions."""
        z_q, _, _, _ = self.quantize(self.encode(states, task, actions))
        return z_q

    def get_encoding_id(self, states, task, actions):
        """Return the selected code ids as a one-hot tensor.

        The one-hot is cast to the dtype of ``z_e`` so it can be used directly
        in later matrix multiplications.
        """
        z_e = self.encode(states, task, actions)
        _, _, _, encoding_indices = self.quantize(z_e)
        return F.one_hot(encoding_indices, num_classes=self.vq.num_embeddings).type(z_e.dtype)

    def get_encoding_id_wo_onehot(self, states, task, actions):
        """Return the selected code ids as integer indices (no one-hot)."""
        _, _, _, encoding_indices = self.quantize(self.encode(states, task, actions))
        return encoding_indices

    def forward(self, states, task, actions, inputs, hidden_state):
        """Encode, quantize and decode.

        Returns:
            ``(q, h, z_e, vq_loss, perplexity, encoding_indices)``.
        """
        z_e = self.encode(states, task, actions)
        z_q, vq_loss, perplexity, encoding_indices = self.quantize(z_e)
        q, h = self.decode(inputs, hidden_state, task, z_q)
        return q, h, z_e, vq_loss, perplexity, encoding_indices

class CVAEAgent(nn.Module):
    """Conditional VAE agent: state encoder -> Gaussian latent -> decoder.

    ``forward`` returns ``(q, h, mu, logvar)``; the decoder is conditioned on a
    latent sampled with the reparameterization trick.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, args):
        super().__init__()
        self.args = args
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.task2input_shape_info = task2input_shape_info
        self.latent_dim = args.latent_dim
        self.entity_embed_dim = args.entity_embed_dim

        # Encoder and decoder share the same constructor arguments.
        shared_ctor_args = (task2input_shape_info, task2decomposer, task2n_agents, args)
        self.state_encoder = StateEncoder(*shared_ctor_args)
        self.decoder = Decoder(*shared_ctor_args)

        # Two linear heads producing the Gaussian posterior parameters.
        self.encoder_mu = nn.Linear(self.entity_embed_dim, self.latent_dim)
        self.encoder_logvar = nn.Linear(self.entity_embed_dim, self.latent_dim)

    def init_hidden(self):
        """Return the decoder's initial hidden state."""
        return self.decoder.init_hidden()

    def encode(self, states, task, actions):
        """Return the posterior mean and log-variance for the given states/actions."""
        features = self.state_encoder(states, task, actions)
        return self.encoder_mu(features), self.encoder_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = th.exp(0.5 * logvar)
        noise = th.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, inputs, hidden_state, task, latent_z):
        """Run the decoder on the observations conditioned on ``latent_z``."""
        return self.decoder(inputs, hidden_state, task, latent_z)

    def get_encoding(self, states, task, actions):
        """Return one sampled latent for the given states/actions."""
        mean, log_variance = self.encode(states, task, actions)
        return self.reparameterize(mean, log_variance)

    def forward(self, states, task, actions, inputs, hidden_state):
        """Encode, sample a latent and decode; returns ``(q, h, mu, logvar)``."""
        mu, logvar = self.encode(states, task, actions)
        sampled = self.reparameterize(mu, logvar)
        q, h = self.decode(inputs, hidden_state, task, sampled)
        return q, h, mu, logvar

class StateEncoder(nn.Module):
    """Encodes a global state plus the agents' current actions into per-agent features.

    Every ally and enemy entity is embedded with a linear layer, an entity-index
    embedding is added, and a single attention step mixes the entities. The
    rows belonging to the allies are returned.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, args):
        super(StateEncoder, self).__init__()

        self.task2last_action_shape = {task: task2input_shape_info[task]["last_action_shape"] for task in
                                       task2input_shape_info}
        self.task2decomposer = task2decomposer
        if not task2decomposer:
            raise ValueError("task2decomposer must contain at least one task")
        # Feature sizes are read from the first decomposer only.
        # NOTE(review): presumably all tasks share these sizes — confirm.
        reference_decomposer = next(iter(task2decomposer.values()))

        self.task2n_agents = task2n_agents
        self.args = args

        self.latent_dim = args.latent_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.entity_embed_dim = args.entity_embed_dim

        # detailed state shape information
        state_nf_al = reference_decomposer.state_nf_al
        state_nf_en = reference_decomposer.state_nf_en
        self.state_last_action = reference_decomposer.state_last_action
        self.state_timestep_number = reference_decomposer.state_timestep_number

        self.n_actions_no_attack = reference_decomposer.n_actions_no_attack
        has_attack_action = reference_decomposer.n_actions_no_attack != reference_decomposer.n_actions
        self.has_attack_action = has_attack_action

        # Width of the compact action encoding: the non-attack actions plus,
        # when attack actions exist, one extra slot.
        compact_action_dim = self.n_actions_no_attack + (1 if has_attack_action else 0)

        # Allies carry their current compact action (and, optionally, the last
        # action stored in the state); enemies carry one extra attack column
        # when attack actions exist.
        ally_dim = state_nf_al + compact_action_dim
        enemy_dim = state_nf_en + (1 if has_attack_action else 0)
        if self.state_last_action:
            ally_dim += compact_action_dim

        self.ally_encoder = nn.Linear(ally_dim, self.entity_embed_dim)
        self.enemy_encoder = nn.Linear(enemy_dim, self.entity_embed_dim)

        # Learned embeddings indexed by the entity's position in the ally /
        # enemy list.
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, self.entity_embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, self.entity_embed_dim)

        # attention projections
        self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)

    @staticmethod
    def _add_index_embedding(entity_embed, table):
        """Add ``table[i]`` to the i-th entity of ``entity_embed`` (entities on dim 0)."""
        positions = th.arange(entity_embed.shape[0], device=entity_embed.device)
        # (seq, E) -> (seq, 1, 1, E) so it broadcasts over the remaining axes
        return entity_embed + table(positions).view(-1, 1, 1, entity_embed.shape[-1])

    def forward(self, states, task, actions):
        """Return per-agent state features of shape (bs * n_agents, entity_embed_dim).

        Args:
            states: global states with the batch on dim 0; a singleton axis is
                inserted at dim 1 before decomposition.
            task: key selecting the task decomposer.
            actions: integer action ids, one per agent per batch element
                (flattened and one-hot encoded internally).
        """
        states = states.unsqueeze(1)
        task_decomposer = self.task2decomposer[task]

        bs = states.size(0)
        n_agents = task_decomposer.n_agents
        n_enemies = task_decomposer.n_enemies
        n_entities = n_agents + n_enemies

        # decomposed state information
        ally_states, enemy_states, last_action_states, _ = task_decomposer.decompose_state(states)
        ally_states = th.stack(ally_states, dim=0)    # entities stacked on dim 0
        enemy_states = th.stack(enemy_states, dim=0)  # entities stacked on dim 0

        # append every ally's current (compact) action to its state features
        action_onehot = F.one_hot(actions.reshape(-1), num_classes=self.task2last_action_shape[task])
        _, attack_info, compact_actions = task_decomposer.decompose_action_info(action_onehot)
        compact_actions = compact_actions.reshape(bs, n_agents, -1).permute(1, 0, 2).unsqueeze(2)
        ally_states = th.cat([ally_states, compact_actions], dim=-1)

        if self.has_attack_action:
            # flag every enemy that is attacked by at least one ally
            attacked = attack_info.reshape(bs, n_agents, n_enemies).sum(dim=1) > 0
            attacked = attacked.type(ally_states.dtype).reshape(bs, n_enemies, 1, 1).permute(1, 0, 2, 3)
            enemy_states = th.cat([enemy_states, attacked], dim=-1)

        # optionally append the last actions stored in the state
        if self.state_last_action:
            last_action_states = th.stack(last_action_states, dim=0)
            _, _, compact_last_actions = task_decomposer.decompose_action_info(last_action_states)
            ally_states = th.cat([ally_states, compact_last_actions], dim=-1)

        # entity embeddings plus entity-index embeddings
        ally_embed = self._add_index_embedding(self.ally_encoder(ally_states), self.ally_time_embed)
        enemy_embed = self._add_index_embedding(self.enemy_encoder(enemy_states), self.enemy_time_embed)
        entity_embed = th.cat([ally_embed, enemy_embed], dim=0)

        # single attention step over all entities
        proj_query = self.query(entity_embed).permute(1, 2, 0, 3).reshape(bs, n_entities, self.attn_embed_dim)
        proj_key = self.key(entity_embed).permute(1, 2, 3, 0).reshape(bs, self.attn_embed_dim, n_entities)
        energy = th.bmm(proj_query / (self.attn_embed_dim ** 0.5), proj_key)
        # normalised over dim=1, the axis that the bmm below sums over
        attn_score = F.softmax(energy, dim=1)
        proj_value = entity_embed.permute(1, 2, 3, 0).reshape(bs, self.entity_embed_dim, n_entities)
        # (bs, E, n_entities) -> (bs, n_entities, E); keep the ally rows only.
        # No squeeze here: it would drop the feature axis when E == 1.
        attn_out = th.bmm(proj_value, attn_score).permute(0, 2, 1)[:, :n_agents, :]

        return attn_out.reshape(bs * n_agents, self.entity_embed_dim)


class Decoder(nn.Module):
    """Decodes per-agent Q-values from observations conditioned on a latent vector.

    The agent's own observation, every observed enemy, every observed ally and
    (optionally) the recurrent hidden state are embedded as tokens and passed
    through a transformer. Non-attack Q-values come from a head over the
    latent embedding, the own token and the max-pooled enemy/ally tokens; one
    attack Q-value per enemy comes from a head over the latent embedding and
    that enemy's token.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, args):
        super(Decoder, self).__init__()
        self.args = args
        self.task2last_action_shape = {task: task2input_shape_info[task]["last_action_shape"] for task in
                                       task2input_shape_info}

        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents

        self.entity_embed_dim = args.entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim

        # Observation feature sizes are read from the first task's decomposer.
        # NOTE(review): presumably all tasks share these sizes — confirm.
        task_0 = list(task2decomposer.keys())[0]
        decomposer = task2decomposer[task_0]
        obs_own_dim = decomposer.own_obs_dim
        obs_en_dim, obs_al_dim = decomposer.obs_nf_en, decomposer.obs_nf_al
        n_actions_no_attack = decomposer.n_actions_no_attack
        has_attack_action = n_actions_no_attack != decomposer.n_actions

        # When both flags are set, the own observation is extended with the
        # agent-id embedding and the compact last action, and the enemy
        # features get one extra attack column if attack actions exist.
        # NOTE(review): forward() appends attack info to the enemy features
        # whenever it is non-empty, independent of these flags — confirm that
        # the flags-off + attack-action combination is never used.
        wrapped_obs_own_dim = obs_own_dim
        if args.obs_agent_id and args.obs_last_action:
            wrapped_obs_own_dim += args.id_length + n_actions_no_attack
            if has_attack_action:
                wrapped_obs_own_dim += 1
                obs_en_dim += 1

        self.latent_dim = args.latent_dim
        self.latent_enc = nn.Linear(self.latent_dim, self.entity_embed_dim)
        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, self.entity_embed_dim)

        # Learned embeddings indexed by the entity's position in the ally /
        # enemy list.
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, self.entity_embed_dim)

        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, self.entity_embed_dim)

        self.transformer = Transformer(self.entity_embed_dim, args.head, args.depth, self.entity_embed_dim)
        self.q_skill = nn.Linear(4*self.entity_embed_dim, n_actions_no_attack)
        self.attack_skill = nn.Linear(2*self.entity_embed_dim, 1)

    def init_hidden(self):
        """Return a zero hidden state of shape (1, entity_embed_dim) on the model's device/dtype."""
        return self.q_skill.weight.new_zeros(1, self.entity_embed_dim)

    @staticmethod
    def _add_index_embedding(hidden, table):
        """Add ``table[i]`` to the i-th token of ``hidden`` (shape (rows, seq, E))."""
        positions = th.arange(hidden.shape[1], device=hidden.device)
        # (seq, E) broadcasts over the leading row axis
        return hidden + table(positions)

    def forward(self, inputs, hidden_state, task, latent_z):
        """Compute Q-values and the next hidden state.

        Args:
            inputs: 2D tensor, one row per agent; columns are the observation
                followed by the last-action encoding (any trailing columns
                are ignored).
            hidden_state: recurrent state, reshaped to (-1, 1, entity_embed_dim).
            task: key selecting the task decomposer.
            latent_z: latent vector per input row, last axis ``latent_dim``.

        Returns:
            ``(q, h)``: ``q`` has one row per input row and one column per
            action (non-attack actions first, then one per enemy); ``h`` is the
            last transformer output token with shape (rows, 1, entity_embed_dim).
        """
        hidden_state = hidden_state.view(-1, 1, self.entity_embed_dim)
        # decomposer, n_agents and last_action_shape of this specific task
        task_decomposer = self.task2decomposer[task]
        task_n_agents = self.task2n_agents[task]
        last_action_shape = self.task2last_action_shape[task]

        # split inputs into the observation and last-action parts
        obs_dim = task_decomposer.obs_dim
        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:obs_dim + last_action_shape]

        own_obs, enemy_feats, ally_feats = task_decomposer.decompose_obs(obs_inputs)
        _, attack_action_info, compact_action_states = task_decomposer.decompose_action_info(last_action_inputs)

        # wrap own_obs with the agent-id embedding and the compact last action
        if self.args.obs_last_action and self.args.obs_agent_id:
            bs = own_obs.shape[0] // task_n_agents
            agent_id_inputs = th.stack([
                th.as_tensor(binary_embed(i + 1, self.args.id_length, self.args.max_agent), dtype=own_obs.dtype)
                for i in range(task_n_agents)
            ], dim=0)
            # the rows are laid out as bs consecutive groups of n_agents
            agent_id_inputs = agent_id_inputs.repeat(bs, 1).to(own_obs.device)
            own_obs = th.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # incorporate attack_action_info into enemy_feats (entities on dim 0)
        enemy_feats = th.stack(enemy_feats, dim=0)
        if attack_action_info.numel() > 0:
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # token embeddings, each (rows, seq, E)
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self._add_index_embedding(self.ally_value(ally_feats).permute(1, 0, 2), self.ally_time_embed)
        enemy_hidden = self._add_index_embedding(self.enemy_value(enemy_feats).permute(1, 0, 2), self.enemy_time_embed)

        latent_emb = self.latent_enc(latent_z)

        # token order: own, enemies, allies, (history)
        tokens = [own_hidden, enemy_hidden, ally_hidden]
        if self.args.use_hidden:
            tokens.append(hidden_state)
        outputs = self.transformer(th.cat(tokens, dim=1), None)

        enemy_length = enemy_hidden.shape[1]
        ally_length = ally_hidden.shape[1]

        h = outputs[:, -1:, :]
        own_out = outputs[:, 0, :]
        enemy_out = outputs[:, 1:1 + enemy_length, :]
        ally_out = outputs[:, 1 + enemy_length:1 + enemy_length + ally_length, :]

        obs_out = th.cat([latent_emb, own_out, enemy_out.max(dim=1)[0], ally_out.max(dim=1)[0]], dim=-1)
        q_base = self.q_skill(obs_out)

        if task_decomposer.n_actions_no_attack == task_decomposer.n_actions:
            return q_base, h

        # One attack Q-value per enemy, from the latent embedding and that
        # enemy's output token. Only the trailing singleton axis is squeezed,
        # so a single row or a single enemy keeps a 2D (rows, n_enemies) result.
        latent_per_enemy = latent_emb.unsqueeze(1).expand(-1, enemy_length, -1)
        q_attack = self.attack_skill(th.cat([latent_per_enemy, enemy_out], dim=-1)).squeeze(-1)

        q = th.cat([q_base, q_attack], dim=-1)
        return q, h
