import torch
from torch import nn
from torch.nn import functional as F

import distributed as dist_fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .tools.mlps import *
from .tools.quat import *

class SequenceVQVAE(nn.Module):
    """Three-level vector-quantised autoencoder for hand features, conditioned on an object branch.

    The object feature (plus fused direction / contact-point hints) is encoded,
    quantised and decoded first; its decoded feature then conditions three
    chained hand encoders, each followed by its own quantiser, and three
    chained decoders that run from the deepest level back to the first.

    ``Encoder``, ``Decoder`` and ``QuantizeReset`` are not defined in this file
    (presumably provided by the ``.tools`` star imports — verify). Each
    quantiser is called as ``quant, code_idx, diff = quantizer(x)`` and exposes
    ``dequantize(code_idx)``.
    """

    def __init__(self,
                 in_dim=1536,
                 obj_dim=768,
                 hidden_dim=1024,
                 embed_dim=512,
                 n_embed=1024,
                 n_res_blocks=2):
        """Build the encoders, quantisers and decoders.

        Args:
            in_dim: Base width; the hand encoders work at ``in_dim // 2``,
                ``in_dim // 4`` and ``in_dim // 8``.
            obj_dim: Width of the object feature and of the object codebook.
            hidden_dim: Second argument passed to every ``Decoder``.
            embed_dim: Width of the three hand codebooks; also the third
                argument passed to every ``Decoder`` (assumed to be its output
                width, which the concatenation sizes below rely on — TODO confirm).
            n_embed: Number of entries in each codebook.
            n_res_blocks: Passed through to every ``Encoder`` / ``Decoder``.
        """
        super().__init__()

        # NOTE: submodule names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.

        # Hand encoders. Every level is also fed the decoded object feature
        # (sized embed_dim here) and the raw object feature (obj_dim); the
        # first level additionally receives the two 32-wide fused hints.
        self.enc_1 = Encoder(in_dim // 2 + embed_dim + 32 * 2 + obj_dim, in_dim // 2, n_res_blocks)
        self.enc_2 = Encoder(in_dim // 2 + embed_dim + obj_dim, in_dim // 4, n_res_blocks)
        self.enc_3 = Encoder(in_dim // 4 + embed_dim + obj_dim, in_dim // 8, n_res_blocks)

        # Projections from each encoder width to the codebook width.
        self.to_quant_1 = nn.Linear(in_dim // 2, embed_dim)
        self.to_quant_2 = nn.Linear(in_dim // 4, embed_dim)
        self.to_quant_3 = nn.Linear(in_dim // 8, embed_dim)

        # One codebook per hand level.
        self.quantize_1 = QuantizeReset(n_embed, embed_dim)
        self.quantize_2 = QuantizeReset(n_embed, embed_dim)
        self.quantize_3 = QuantizeReset(n_embed, embed_dim)

        # Lift the 3-component direction / contact point to 32 features each.
        self.dir_fuse = nn.Linear(3, 32)
        self.contact_point_fuse = nn.Linear(3, 32)

        # Hand decoders, applied from level 3 down to level 1. Levels 2 and 1
        # also consume the previous decoder output, hence the extra embed_dim.
        self.decoder_3 = Decoder(embed_dim + embed_dim + obj_dim, hidden_dim, embed_dim, n_res_blocks)
        self.decoder_2 = Decoder(embed_dim * 2 + embed_dim + obj_dim, hidden_dim, embed_dim, n_res_blocks)
        self.decoder_1 = Decoder(embed_dim * 2 + embed_dim + obj_dim, hidden_dim, embed_dim, n_res_blocks)

        # Object branch: encode (object feature + hints), quantise, decode.
        self.enc_obj = Encoder(obj_dim + 32 * 2, obj_dim // 2, n_res_blocks)
        self.to_quant_obj = nn.Linear(obj_dim // 2, obj_dim)
        self.quantize_obj = QuantizeReset(n_embed, obj_dim)
        self.decoder_obj = Decoder(obj_dim, hidden_dim, embed_dim, n_res_blocks)

    def forward(self, objPcFeature, handPcFeature, dir, hottest_points, initOBJfeature):
        """Run the full encode -> quantise -> decode pass.

        Args:
            objPcFeature: Object feature fed to the object encoder.
            handPcFeature: Hand feature fed to the first hand encoder.
            dir: Direction input with 3 components in the last dimension
                (the name shadows the builtin but is part of the public signature).
            hottest_points: Contact-point input with 3 components in the last
                dimension.
            initOBJfeature: Object feature added to the quantised object code
                and concatenated into every encoder / decoder input.

        Returns:
            A 4-tuple ``(outputs, diffList, codeList, helpinfo)`` where
            ``outputs`` is ``[dec_2_input, dec_1_input, dec_0_input,
            dec_obj_input]``, ``diffList`` / ``codeList`` come from
            :meth:`encode`, and ``helpinfo`` is the concatenated fused hints.
        """
        dir_ex = self.dir_fuse(dir)
        hottest_points_ex = self.contact_point_fuse(hottest_points)

        # squeeze(1) drops dimension 1 when it has size 1
        # (presumably inputs are (batch, 1, 3) — TODO confirm).
        helpinfo = torch.cat([dir_ex.squeeze(1), hottest_points_ex.squeeze(1)], dim=-1)

        quatList, diffList, codeList = self.encode(objPcFeature, handPcFeature, helpinfo, initOBJfeature)
        dec_2_input, dec_1_input, dec_0_input, dec_obj_input = self.decode(quatList, initOBJfeature)
        return [dec_2_input, dec_1_input, dec_0_input, dec_obj_input], diffList, codeList, helpinfo

    def encode(self, objPcFeature, handPcFeature, helpinfo, initOBJfeature):
        """Encode and quantise the object branch and the three hand levels.

        Returns:
            A 3-tuple of
            ``[quant_1, quant_2, quant_3, quant_obj, dec_obj_input]``,
            ``[diff_1, diff_2, diff_3, diff_obj]`` and the code indices stacked
            along ``dim=1`` in the order ``(code_3, code_2, code_1, code_obj)``.
            Note that this order is the reverse of the ``code_1, code_2,
            code_3`` parameter order of :meth:`decode_code`.
        """
        # --- object branch ---
        obj_with_hints = torch.cat([objPcFeature, helpinfo], dim=-1)
        obj_latent = self.enc_obj(obj_with_hints)
        quant_obj_input = self.to_quant_obj(obj_latent)
        quant_obj, code_idxobj, diff_obj = self.quantize_obj(quant_obj_input)
        mixobjFeature = quant_obj + initOBJfeature
        dec_obj_input = self.decoder_obj(mixobjFeature)

        # --- hand branch: three chained encoders, each conditioned on the object ---
        hand_with_hints = torch.cat([handPcFeature, helpinfo], dim=-1)
        enc_1 = self.enc_1(torch.cat([hand_with_hints, dec_obj_input, initOBJfeature], dim=-1))
        enc_2 = self.enc_2(torch.cat([enc_1, dec_obj_input, initOBJfeature], dim=-1))
        enc_3 = self.enc_3(torch.cat([enc_2, dec_obj_input, initOBJfeature], dim=-1))

        quant_1_input = self.to_quant_1(enc_1)
        quant_2_input = self.to_quant_2(enc_2)
        quant_3_input = self.to_quant_3(enc_3)

        # Each level uses its own codebook so that the indices returned here
        # match the codebooks decode_code / getDecodeFeature look them up in.
        # NOTE: earlier revisions routed all three levels through quantize_1;
        # checkpoints trained that way carry untrained quantize_2 / quantize_3
        # codebooks and need retraining (or copying quantize_1's weights).
        quant_1, code_idx1, diff_1 = self.quantize_1(quant_1_input)  # x_d, code_idx, commit_loss
        quant_2, code_idx2, diff_2 = self.quantize_2(quant_2_input)
        quant_3, code_idx3, diff_3 = self.quantize_3(quant_3_input)

        quantized = [quant_1, quant_2, quant_3, quant_obj, dec_obj_input]
        diffs = [diff_1, diff_2, diff_3, diff_obj]
        codes = torch.stack([code_idx3, code_idx2, code_idx1, code_idxobj], dim=1)
        return quantized, diffs, codes

    def decode(self, quatList, initOBJfeature):
        """Decode from the deepest hand level back to the first.

        Args:
            quatList: ``[quant_1, quant_2, quant_3, quant_obj, dec_obj_input]``
                as produced by :meth:`encode`; index 3 is not read here.
            initOBJfeature: Object feature concatenated into every decoder input.

        Returns:
            ``(dec_2_input, dec_1_input, dec_0_input, dec_obj_input)`` — the
            outputs of ``decoder_3``, ``decoder_2``, ``decoder_1`` and the
            decoded object feature passed through unchanged.
        """
        dec_obj_input = quatList[4]

        dec_3 = torch.cat([quatList[2], dec_obj_input, initOBJfeature], dim=-1)
        dec_2_input = self.decoder_3(dec_3)

        dec_2 = torch.cat([quatList[1], dec_obj_input, dec_2_input, initOBJfeature], dim=-1)
        dec_1_input = self.decoder_2(dec_2)

        dec_1 = torch.cat([quatList[0], dec_obj_input, dec_1_input, initOBJfeature], dim=-1)
        dec_0_input = self.decoder_1(dec_1)

        # dec_obj_input is handed back as-is as the fourth output.
        return dec_2_input, dec_1_input, dec_0_input, dec_obj_input

    def decode_code(self, code_1, code_2, code_3, code_obj, initOBJfeature):
        """Decode directly from code indices instead of encoder outputs.

        Each ``code_k`` is looked up in ``quantize_k``; the object code is
        looked up in ``quantize_obj``, added to ``initOBJfeature`` and decoded
        before the hand decoders run.

        Returns:
            ``[dec_2_input, dec_1_input, dec_0_input, dec_obj_input]``.
        """
        quat_1 = self.quantize_1.dequantize(code_1)
        quat_2 = self.quantize_2.dequantize(code_2)
        quat_3 = self.quantize_3.dequantize(code_3)
        quat_obj = self.quantize_obj.dequantize(code_obj)

        mixobjFeature = quat_obj + initOBJfeature
        dec_obj_input = self.decoder_obj(mixobjFeature)

        quatList = [quat_1, quat_2, quat_3, quat_obj, dec_obj_input]
        dec_2_input, dec_1_input, dec_0_input, dec_obj_input = self.decode(quatList, initOBJfeature)
        return [dec_2_input, dec_1_input, dec_0_input, dec_obj_input]

    @staticmethod
    def _mean_code_embedding(quantizer, codes, batch):
        """Look up ``codes`` in ``quantizer`` and average over their last axis."""
        embedded = quantizer.dequantize(codes.reshape(-1))
        return embedded.reshape(batch, codes.shape[-1], -1).mean(dim=1)

    def getDecodeFeature(self, code_1, code_2, code_3, code_obj):
        """Return the per-sample mean codebook embedding of every level, concatenated.

        Each code tensor is flattened, dequantised with its own codebook,
        reshaped to ``(code_1.shape[0], codes.shape[-1], -1)`` and averaged over
        dimension 1. The four results are concatenated along ``dim=1`` in the
        order level 1, level 2, level 3, object.
        """
        # The leading size is taken from code_1 for all four inputs, as before.
        batch = code_1.shape[0]

        # Level 1 is reshaped with its own code length; the previous extra
        # reshape using code_2's length failed whenever the lengths differed.
        quat_1 = self._mean_code_embedding(self.quantize_1, code_1, batch)
        quat_2 = self._mean_code_embedding(self.quantize_2, code_2, batch)
        quat_3 = self._mean_code_embedding(self.quantize_3, code_3, batch)
        quat_obj = self._mean_code_embedding(self.quantize_obj, code_obj, batch)

        mix_quat = torch.cat([quat_1, quat_2, quat_3, quat_obj], dim=1)
        return mix_quat