import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from utils.embed import polynomial_embed, binary_embed
from utils.transformer import Transformer


from torch.autograd import Function    
class VectorQuantization(Function):
    """Nearest-neighbour lookup of input vectors in a codebook (non-differentiable).

    ``forward(inputs, codebook)`` treats the last dimension of ``inputs`` as the
    embedding dimension (it must equal ``codebook.size(1)``) and returns, for
    every input vector, the index of the codebook row with the smallest squared
    Euclidean distance. The returned index tensor has shape
    ``inputs.shape[:-1]`` and is marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        with th.no_grad():
            embedding_size = codebook.size(1)
            inputs_size = inputs.size()
            # reshape (not view) so non-contiguous inputs (e.g. transposed or
            # permuted tensors) are accepted; it still returns a view when possible.
            inputs_flatten = inputs.reshape(-1, embedding_size)

            # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, evaluated for all pairs at once.
            codebook_sqr = th.sum(codebook ** 2, dim=1)
            inputs_sqr = th.sum(inputs_flatten ** 2, dim=1, keepdim=True)

            distances = th.addmm(codebook_sqr + inputs_sqr,
                inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)

            _, indices_flatten = th.min(distances, dim=1)
            # Pass the torch.Size itself (not unpacked) so a single 1-D input
            # vector yields a 0-d index tensor instead of failing.
            indices = indices_flatten.reshape(inputs_size[:-1])
            ctx.mark_non_differentiable(indices)

            return indices

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError('Trying to call `.grad()` on graph containing '
            '`VectorQuantization`. The function `VectorQuantization` '
            'is not differentiable. Use `VectorQuantizationStraightThrough` '
            'if you want a straight-through estimator of the gradient.')

class VectorQuantizationStraightThrough(Function):
    """Vector quantization with a straight-through gradient estimator.

    ``forward(inputs, codebook)`` returns ``(codes, indices_flatten)``:
    ``codes`` has the shape of ``inputs`` and holds the nearest codebook rows
    (as chosen by ``vq``); ``indices_flatten`` is the 1-D, non-differentiable
    index tensor.

    ``backward`` passes the incoming gradient to ``inputs`` unchanged
    (straight-through), and for ``codebook`` scatter-adds the gradient of each
    code into the row it was taken from.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        flat_idx = vq(inputs, codebook).view(-1)
        ctx.save_for_backward(flat_idx, codebook)
        ctx.mark_non_differentiable(flat_idx)

        picked = codebook.index_select(0, flat_idx)
        return picked.view_as(inputs), flat_idx

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        want_inputs, want_codebook = ctx.needs_input_grad

        # Straight-through: the quantization step is treated as identity.
        grad_inputs = grad_output.clone() if want_inputs else None

        grad_codebook = None
        if want_codebook:
            flat_idx, codebook = ctx.saved_tensors
            dim = codebook.size(1)
            flat_grad = grad_output.contiguous().view(-1, dim)
            # Accumulate per-row gradient; rows selected several times sum up.
            grad_codebook = th.zeros_like(codebook).index_add_(0, flat_idx, flat_grad)

        return grad_inputs, grad_codebook

# Functional entry points for the two autograd Functions defined above.
vq, vq_st = VectorQuantization.apply, VectorQuantizationStraightThrough.apply


class VQEmbedding(nn.Module):
    """Learnable codebook of ``K`` vectors of size ``D`` used for vector quantization.

    The last dimension of every tensor passed to the methods below is matched
    against the codebook rows, so it must equal ``D``.
    """

    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        # Codebook rows start uniformly in [-1/K, 1/K].
        self.embedding.weight.data.uniform_(-1./K, 1./K)

    def forward(self, z_e_x):
        """Return the nearest-codebook index for each vector in ``z_e_x`` (shape ``z_e_x.shape[:-1]``)."""
        z_e_x_ = z_e_x.contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def staight_through4_indices(self, z_e_x, agent_num):
        """Return the code index of each vector in ``z_e_x``, repeated ``agent_num`` times.

        The indices are flattened, every index is repeated ``agent_num`` times
        consecutively, and the result has shape ``(-1, 1)``.

        Only the indices are needed here, so ``vq`` is called directly instead
        of ``vq_st`` (which would also gather the unused codes); the resulting
        indices are identical.
        """
        z_e_x_ = z_e_x.contiguous()
        indices = vq(z_e_x_, self.embedding.weight.detach()).view(-1)

        return indices.unsqueeze(1).repeat(1, agent_num).reshape(-1, 1)

    # Correctly spelled alias; the misspelled name above is kept for existing callers.
    straight_through4_indices = staight_through4_indices

    def straight_through(self, z_e_x):
        """Quantize ``z_e_x`` and return two copies of the selected codes.

        Returns:
            z_q_x: codes produced by ``vq_st`` against a *detached* codebook, so
                gradients only flow straight through to ``z_e_x``.
            z_q_x_bar: the same codes gathered from the live codebook weight
                (indices are non-differentiable), so gradients only flow into
                the codebook.
        Both have the shape of ``z_e_x``.
        """
        z_e_x_ = z_e_x.contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_

        z_q_x_bar_flatten = th.index_select(self.embedding.weight,
            dim=0, index=indices)
        z_q_x_bar = z_q_x_bar_flatten.view_as(z_e_x_).contiguous()

        return z_q_x, z_q_x_bar
        
    
class VQVAEAgent(nn.Module):
    """Multi-task VQ-VAE agent.

    ``forward`` encodes the global state at one timestep with ``StateEncoder``,
    quantizes the encoding against a learned ``VQEmbedding`` codebook, and feeds
    the straight-through code plus the per-agent inputs to ``Decoder``.
    """

    def __init__(self, task2input_shape_info, args, task2decomposer, task2n_agents, decomposer, task2args):
        super(VQVAEAgent, self).__init__()
        # basic info
        self.task2last_action_shape = {task: task2input_shape_info[task]["last_action_shape"] for task in
                                       task2input_shape_info}
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.task2args = task2args
        self.args = args
        self.z_e_dim = args.z_e_dim

        self.hidden_states_enc = None
        self.hidden_states_dec = None
        self.state_encoder = StateEncoder(task2input_shape_info, task2decomposer, task2n_agents, decomposer, args)
        self.obs_encoder = ObsEncoder(task2input_shape_info, task2decomposer, task2n_agents, decomposer, args)
        self.encoder = Encoder(args)
        self.decoder = Decoder(task2input_shape_info, task2decomposer, task2n_agents, decomposer, args)
        # NOTE(review): the codebook is pinned to CUDA here regardless of args.
        self.codebook = VQEmbedding(self.args.vqvae_K, self.args.vqvae_D).to('cuda')

    def _build_inputs(self, batch, t, task):
        """Build the flat per-agent input matrix for timestep index ``t``.

        Concatenates the observation, optionally the previous action one-hot
        (zeros when ``t == 0``) and optionally a one-hot agent id, each reshaped
        to ``(batch_size * n_agents, -1)``.
        """
        bs = batch.batch_size
        inputs = []
        inputs.append(batch["obs"][:, t])
        task_args, n_agents = self.task2args[task], self.task2n_agents[task]
        if task_args.obs_last_action:
            if t == 0:
                inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
            else:
                inputs.append(batch["actions_onehot"][:, t - 1])
        if task_args.obs_agent_id:
            inputs.append(th.eye(n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1))

        inputs = th.cat([x.reshape(bs * n_agents, -1) for x in inputs], dim=1)
        return inputs

    def init_hidden(self):
        """Return two zero tensors of shape ``(1, entity_embed_dim)``.

        They share dtype/device with the decoder's ``skill_i`` weight.
        """
        weight = self.decoder.skill_i.weight
        return (
            weight.new_zeros(1, self.args.entity_embed_dim),
            weight.new_zeros(1, self.args.entity_embed_dim),
        )

    def decoder_forward(self, batch, z_q_st, hidden_state_dec, task, t_env):
        """Decode from the straight-through code ``z_q_st``.

        ``t_env`` is used as the timestep index into ``batch`` when building the
        per-agent inputs. Returns the decoder output and its new hidden state.
        """
        agent_inputs = self._build_inputs(batch, t_env, task)
        x_tilde_i, hidden_states_dec = self.decoder(batch, agent_inputs, z_q_st, hidden_state_dec, task)

        return x_tilde_i, hidden_states_dec

    def forward_indices(self, batch, t_env, task, hidden_states_enc):
        """Encode the state at ``t_env`` and return its codebook indices, one per agent."""
        z_e, hidden_states_enc = self.state_encoder(batch, hidden_states_enc, task, t_env)
        indices = self.codebook.staight_through4_indices(z_e, self.task2n_agents[task])

        return indices, hidden_states_enc

    def forward(self, batch, t_env, task, hidden_states_enc, hidden_states_dec):
        """Run encode -> quantize -> decode for timestep ``t_env``.

        Returns ``(z_e, z_q, x_tildes, hidden_states_enc, hidden_states_dec)``
        where ``z_q`` is the codebook-side output of
        ``VQEmbedding.straight_through`` and the decoder receives the
        straight-through code.
        """
        z_e, hidden_states_enc = self.state_encoder(batch, hidden_states_enc, task, t_env)
        z_q_st, z_q = self.codebook.straight_through(z_e)

        x_tildes, hidden_states_dec = self.decoder_forward(batch, z_q_st, hidden_states_dec, task, t_env)

        return z_e, z_q, x_tildes, hidden_states_enc, hidden_states_dec

    def forward_z_e(self, inputs, hidden_state_enc, hidden_state_dec, task):
        """Encode observations with ``obs_encoder`` and return the quantized code ``z_q``.

        NOTE(review): the other ``obs_encoder`` callers in this class unpack its
        result as ``(attn_out, hidden_state)``, but here it is used directly as
        ``z_e``; one of the two usages must disagree with ``ObsEncoder.forward``
        - confirm. ``hidden_state_dec`` is unused.
        """
        z_e = self.obs_encoder(inputs, hidden_state_enc, task)
        z_q_st, z_q = self.codebook.straight_through(z_e)

        return z_q

    def save_embedding_to_npy(self, file_path):
        """Save the codebook weights as ``<file_path>/codebook.npy``."""
        embedding_weights = self.codebook.embedding.weight.detach().cpu().numpy()
        np.save("{}/codebook".format(file_path), embedding_weights)

    def forward_seq_action(self, seq_inputs, hidden_state_dec, task, skill):
        """Run ``forward_action`` over the first ``self.c`` steps of ``seq_inputs``.

        Returns the actions stacked along dim 1 and the decoder hidden state
        produced after the FIRST step (not the last).

        NOTE(review): ``self.c`` is not assigned anywhere in this class; it must
        be set externally before calling - confirm.
        """
        seq_act = []
        for i in range(self.c):
            act, hidden_state_dec = self.forward_action(seq_inputs[:, i, :], hidden_state_dec, task, skill)
            if i == 0:
                hidden_state = hidden_state_dec
            seq_act.append(act)
        seq_act = th.stack(seq_act, dim=1)

        return seq_act, hidden_state

    def forward_action(self, inputs, hidden_state_dec, task, skill):
        """Single decoder step.

        NOTE(review): the decoder is called here as ``(inputs, hidden, task, skill)``
        but ``decoder_forward`` calls it as ``(batch, agent_inputs, z_q_st, hidden, task)``;
        at most one matches ``Decoder.forward`` - confirm.
        """
        act, h_dec = self.decoder(inputs, hidden_state_dec, task, skill)
        return act, h_dec

    def forward_skill(self, inputs, hidden_state_enc, task, actions=None, t_env=None):
        """Encode the state with ``state_encoder`` and return ``(skill_logits, hidden_state_enc)``.

        ``StateEncoder.forward`` requires the timestep index, so ``t_env`` must
        be supplied (keyword, kept last for backward compatibility).

        Raises:
            TypeError: if ``t_env`` is not given.
        """
        if t_env is None:
            raise TypeError("forward_skill() requires t_env (the timestep passed to StateEncoder.forward)")
        attn_out, hidden_state_enc = self.state_encoder(inputs, hidden_state_enc, task, t_env, actions=actions)
        skill_logits = self.encoder(attn_out)
        return skill_logits, hidden_state_enc

    def forward_obs_skill(self, inputs, hidden_state_enc, task):
        """Encode observations with ``obs_encoder`` and return ``(skill_logits, hidden_state_enc)``."""
        attn_out, hidden_state_enc = self.obs_encoder(inputs, hidden_state_enc, task)
        skill_logits = self.encoder(attn_out)
        return skill_logits, hidden_state_enc

    def forward_qvalue(self, inputs, hidden_state_enc, task, pre_hidden=False):
        """Return ``(self.q(attn_out), hidden_state_enc)`` for observation inputs.

        NOTE(review): ``self.q`` is not created in ``__init__``; it must be
        attached externally - confirm. ``pre_hidden`` is unused.
        """
        attn_out, hidden_state_enc = self.obs_encoder(inputs, hidden_state_enc, task)
        skill_logits = self.q(attn_out)
        return skill_logits, hidden_state_enc

    def forward_both(self, inputs, hidden_state_enc, task):
        """Return ``(self.q(attn_out), self.encoder(attn_out), hidden_state_enc)``.

        NOTE(review): relies on ``self.q``, which is not created in ``__init__``.
        """
        attn_out, hidden_state_enc = self.obs_encoder(inputs, hidden_state_enc, task)
        skill_logits = self.q(attn_out)
        p_skill = self.encoder(attn_out)
        return skill_logits, p_skill, hidden_state_enc

class StateEncoder(nn.Module):
    """Entity-attention encoder of the global state at a single timestep.

    Each ally entity is embedded together with its current compact action
    features and its skill vector; each enemy is embedded (plus an
    "is attacked" flag when the task has attack actions). One attention pass
    over all entities is taken, and the outputs of the ally entities are
    averaged into a single ``(batch_size, entity_embed_dim)`` encoding.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, decomposer, args):
        super(StateEncoder, self).__init__()

        self.task2last_action_shape = {task: task2input_shape_info[task]["last_action_shape"] for task in
                                       task2input_shape_info}
        self.task2decomposer = task2decomposer
        # Layer sizes come from the first task's decomposer; presumably all tasks
        # share the same per-entity feature sizes - TODO confirm.
        try:
            task2decomposer_ = next(iter(task2decomposer.values()))
        except StopIteration:
            raise ValueError("task2decomposer must contain at least one task") from None

        self.task2n_agents = task2n_agents
        self.args = args

        self.skill_dim = args.skill_dim

        self.embed_dim = args.mixing_embed_dim
        self.attn_embed_dim = args.attn_embed_dim
        self.entity_embed_dim = args.entity_embed_dim

        state_nf_al, state_nf_en = task2decomposer_.state_nf_al, task2decomposer_.state_nf_en
        self.state_last_action, self.state_timestep_number = task2decomposer_.state_last_action, task2decomposer_.state_timestep_number

        self.n_actions_no_attack = task2decomposer_.n_actions_no_attack

        has_attack_action = task2decomposer_.n_actions_no_attack != task2decomposer_.n_actions
        if has_attack_action:
            # +1 on each side for the attack flag appended in forward().
            ally_dim = state_nf_al + (self.n_actions_no_attack + 1)
            enemy_dim = state_nf_en + 1
        else:
            ally_dim = state_nf_al + self.n_actions_no_attack
            enemy_dim = state_nf_en
        self.has_attack_action = has_attack_action

        # NOTE(review): layers are pinned to CUDA.
        self.ally_encoder = nn.Linear(ally_dim + args.skill_dim, self.entity_embed_dim).to('cuda')
        self.enemy_encoder = nn.Linear(enemy_dim, self.entity_embed_dim).to('cuda')

        self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim).to('cuda')
        self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim).to('cuda')

    def forward(self, batch, hidden_state, task, t_env, actions=None):
        """Encode ``batch['state']`` at timestep index ``t_env``.

        ``actions`` is ignored: actions are always read from ``batch['actions']``.
        ``hidden_state`` is returned unchanged.

        Returns:
            ``(attn_out, hidden_state)`` with ``attn_out`` of shape
            ``(batch_size, entity_embed_dim)``.
        """
        states = batch['state'][:, t_env, :]
        states = states.unsqueeze(1)
        task_decomposer = self.task2decomposer[task]

        actions = batch['actions'][:, t_env, :, :]
        skills = batch['skill'][:, t_env, :].unsqueeze(-2).permute(1, 0, 2, 3).to('cuda')
        bs = states.size(0)
        n_agents = task_decomposer.n_agents
        n_enemies = task_decomposer.n_enemies
        n_entities = n_agents + n_enemies

        ally_states, enemy_states, last_action_states, timestep_number_state = task_decomposer.decompose_state(states)
        ally_states = th.stack(ally_states, dim=0)

        _, current_attack_action_info, current_compact_action_states = task_decomposer.decompose_action_info(F.one_hot(actions.reshape(-1), num_classes=self.task2last_action_shape[task]))
        current_compact_action_states = current_compact_action_states.reshape(bs, n_agents, -1).permute(1, 0, 2).unsqueeze(2)

        ally_states = th.cat([ally_states, current_compact_action_states, skills], dim=-1)
        if self.has_attack_action:
            # Flag each enemy that is targeted by at least one ally this step.
            current_attack_action_info = current_attack_action_info.reshape(bs, n_agents, n_enemies).sum(dim=1)
            attack_action_states = (current_attack_action_info > 0).type(ally_states.dtype).reshape(bs, n_enemies, 1, 1).permute(1, 0, 2, 3)
            enemy_states = th.stack(enemy_states, dim=0)
            enemy_states = th.cat([enemy_states, attack_action_states], dim=-1)
        else:
            enemy_states = th.stack(enemy_states, dim=0)

        ally_embed = self.ally_encoder(ally_states)
        enemy_embed = self.enemy_encoder(enemy_states)

        # Entities first on dim 0: allies, then enemies.
        entity_embed = th.cat([ally_embed, enemy_embed], dim=0)

        proj_query = self.query(entity_embed).permute(1, 2, 0, 3).reshape(bs, n_entities, self.attn_embed_dim)
        proj_key = self.key(entity_embed).permute(1, 2, 3, 0).reshape(bs, self.attn_embed_dim, n_entities)
        energy = th.bmm(proj_query / (self.attn_embed_dim ** (1 / 2)), proj_key)
        # Softmax over dim 1, which is the index contracted by the bmm below.
        attn_score = F.softmax(energy, dim=1)
        proj_value = entity_embed.permute(1, 2, 3, 0).reshape(bs, self.entity_embed_dim, n_entities)
        # (bs, entity_embed_dim, n_entities) -> (bs, n_entities, entity_embed_dim), allies only.
        # No squeeze here: it was a no-op except when entity_embed_dim == 1,
        # where it dropped the embedding dim and broke the permute.
        attn_out = th.bmm(proj_value, attn_score).permute(0, 2, 1)[:, :n_agents, :]

        attn_out = attn_out.mean(1)
        return attn_out, hidden_state

class ResBlock(nn.Module):
    """Residual conv block: ``x + BN(Conv1x1(ReLU(BN(Conv3x3(ReLU(x))))))``.

    Channel count and spatial size are preserved (the 3x3 conv uses padding 1).
    The input tensor is not modified.
    """

    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            # Not in-place: this ReLU receives the caller's ``x`` directly. An
            # in-place ReLU would overwrite ``x``, so the skip connection would
            # add relu(x) instead of x and the caller's tensor would be mutated.
            nn.ReLU(),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),  # in-place is safe here: it acts on an intermediate BN output
            nn.Conv2d(dim, dim, 1),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x):
        return x + self.block(x)



