import torch
import torch.nn as nn
import torch.nn.functional as F
from syn_lib.model.point import PointNetfeat
from syn_lib.model import loss
from syn_lib.utils import NetworkType
import numpy as np
from ..utils import decode_angle, generate_xyz_directional_anchors, select_res_by_index, get_gt_anchor_idx_and_residual
from smplx import MANO
from .common.body_models import build_layers, seal_mano_mesh, MANOJointRegressor
from .pointExtractNet import PointCloudFeatureExtractor
from .ResidualBlock import ResidualBlock
import trimesh
from syn_lib.model.pointbert.pointbert import PointTransformer
from .vae import PointVAE
from pytorch3d.transforms import matrix_to_axis_angle,rotation_6d_to_matrix
from .vqvae import VQVAE
from .objHandMixVqvae import SequenceVQVAE

# Passed as the first positional argument to the MANO layer constructor in
# sequenceNetwork.__init__.
# NOTE(review): the value is an empty string here — presumably it has to point
# at the MANO model files; confirm how it is meant to be configured.
MODEL_DIR = ""

from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes

from torch.distributions import Categorical

class sequenceNetwork(nn.Module):
    def __init__(self, cfg):
        """Create the feature extractors, the MANO layer and the output heads.

        Args:
            cfg: Configuration object. It is stored on ``self.cfg``; this
                constructor does not read any field from it.
        """
        super().__init__()
        self.cfg = cfg

        feat_dim = 512
        mid_dim = feat_dim // 2

        def make_head(out_dim):
            # All heads share the same input / hidden width and differ only
            # in the size of their output.
            return decodeMLP(in_dim=feat_dim, hidden_dim=mid_dim, out_dim=out_dim)

        # NOTE: sub-modules are created in the same order as before so that
        # parameter registration order (and seeded initialisation) is stable.
        self.VQVAE = SequenceVQVAE()
        self.HEATMAP = PointNetfeat()

        # Right-hand MANO layer with full (non-PCA) pose parameters.
        mano_options = dict(
            create_transl=False,
            use_pca=False,
            flat_hand_mean=False,
            is_rhand=True,
        )
        self.mano_layer = MANO(MODEL_DIR, **mano_options)

        # Coarse heads: 6 values for the global rotation (decoded with
        # rotation_6d_to_matrix in forward), 3 for the translation and
        # 90 for the hand pose (reshaped to groups of 6 in forward).
        self.rot_head = make_head(6)
        self.trans_head = make_head(3)
        self.global_dir_head = make_head(3)
        self.pose_head = make_head(90)

        # Refinement heads with the same output sizes as the coarse ones.
        self.rot_head_finetune = make_head(6)
        self.trans_head_finetune = make_head(3)
        self.pose_head_finetune = make_head(90)

        # Candidate angles from -100 to 110 in steps of 5; the object-degree
        # head produces one output per candidate.
        self.angle_list = list(range(-100, 111, 5))
        self.obj_degree = make_head(len(self.angle_list))
        
    def forward(self, input, gt):
        """Predict coarse and refined MANO hand meshes for a batch.

        Three hand variants are evaluated in a single MANO call:
        the refined result (refined rotation and pose), a rotation-only
        result (coarse rotation, zero pose) and the coarse result (coarse
        rotation and coarse pose).

        Args:
            input: Mapping that must provide the keys ``"InitContactPoint"``,
                ``"PointCloud"``, ``"heatMap"``, ``"PointCloudHand"``,
                ``"dir"`` and ``"PointCloud_init"``. Shapes are not validated
                here; the code indexes the point clouds as (B, N, C), the
                heat map as (B, N) and the direction / contact point as
                (B, 3) — TODO confirm against the data pipeline.
            gt: Not used by this method; kept so existing callers still work.

        Returns:
            dict: Predictions and intermediate tensors, with the keys
            ``trans_pred``, ``pred_pose``, ``pred_mesh_v3d``, ``obj_degree``,
            ``pred_mesh_f3d``, ``hottest_points``, ``pred_rot_angle``,
            ``diff_1``, ``diff_2``, ``diff_3``, ``diff_obj``,
            ``pred_mesh_v3d_rot``, ``pred_mesh_v3d_corase``,
            ``dirs_normalized``, ``tokenList``, ``heatmapFeature``,
            ``handPcFeature``, ``helpinfo`` and ``normFeature``.
        """
        contact_point = input["InitContactPoint"]
        obj_points = input["PointCloud"]
        heat = input["heatMap"]
        hand_points = input["PointCloudHand"]
        direction = input["dir"]
        init_points = input["PointCloud_init"]

        # F.normalize divides by max(||v||, eps): identical to v / ||v|| for
        # regular inputs, but a zero-length direction yields zeros instead of
        # NaN/Inf that would poison every downstream tensor.
        dirs_normalized = F.normalize(direction, p=2, dim=1)
        batch_num = obj_points.shape[0]

        # Min-max normalise the heat map along dim 1.
        heat_min = heat.min(dim=1, keepdim=True)[0]
        heat_max = heat.max(dim=1, keepdim=True)[0]
        normFeature = (heat - heat_min) / (heat_max - heat_min + 1e-8)

        # Object points plus the heat value as an extra channel, channels first.
        mixedFeature = torch.cat(
            [obj_points.permute(0, 2, 1), normFeature.unsqueeze(-1).permute(0, 2, 1)],
            dim=1,
        )
        heatmapFeature = self.HEATMAP(mixedFeature)
        device = heatmapFeature.device

        # The hand and initial object clouds have no heat map; a constant
        # channel of ones takes its place.
        hand_ones = torch.ones(
            (hand_points.shape[0], 1, hand_points.shape[1]), dtype=torch.float32, device=device
        )
        mixedFeature_hand = torch.cat([hand_points.permute(0, 2, 1), hand_ones], dim=1)
        handPcFeature = self.HEATMAP(mixedFeature_hand)

        init_ones = torch.ones(
            (init_points.shape[0], 1, init_points.shape[1]), dtype=torch.float32, device=device
        )
        mixedFeature_init = torch.cat([init_points.permute(0, 2, 1), init_ones], dim=1)
        # No gradient flows through the feature of the initial object cloud.
        with torch.no_grad():
            initOBJfeature = self.HEATMAP(mixedFeature_init)

        decList, diffList, tokenList, helpinfo = self.VQVAE(
            heatmapFeature, handPcFeature, dirs_normalized, contact_point, initOBJfeature
        )

        # Head calls stay in their original order.
        global_rot = self.rot_head(decList[0])
        raw_trans = self.trans_head(decList[0])
        raw_pose = self.pose_head(decList[1])
        rot_finetune = self.rot_head_finetune(decList[2])
        pose_finetune = self.pose_head_finetune(decList[2])
        trans_finetune = self.trans_head_finetune(decList[2])
        obj_degree = self.obj_degree(decList[3])

        # Global rotation: the refinement rotation is applied on top of the
        # coarse one.
        coarse_rot_mat = rotation_6d_to_matrix(global_rot)
        finetune_rot_mat = rotation_6d_to_matrix(rot_finetune)
        rot_angle = matrix_to_axis_angle(finetune_rot_mat @ coarse_rot_mat)
        coarse_angle = matrix_to_axis_angle(coarse_rot_mat)

        # Constant offset subtracted from every predicted translation.
        ref_r = torch.tensor(
            [0.09566994, 0.00638343, 0.0061863], dtype=torch.float32, device=device
        ).view(-1, 1, 3)
        trans_pred = raw_trans + trans_finetune

        # Hand pose: groups of 6 values are decoded as 6D rotations.
        finetune_pose_mat = rotation_6d_to_matrix(pose_finetune.reshape(batch_num, -1, 6))
        coarse_pose_mat = rotation_6d_to_matrix(raw_pose.reshape(batch_num, -1, 6))
        pred_pose = matrix_to_axis_angle(finetune_pose_mat @ coarse_pose_mat).reshape(batch_num, 45)
        zero_pose = torch.zeros(pred_pose.shape, device=device)
        coarse_pose = matrix_to_axis_angle(coarse_pose_mat).reshape(batch_num, 45)

        # Stack the three variants along the batch axis for one MANO call.
        mixed_rot_angle = torch.cat([rot_angle, coarse_angle, coarse_angle], dim=0)
        mixed_pose = torch.cat([pred_pose, zero_pose, coarse_pose], dim=0)
        manoMesh_pred = self.mano_layer(
            global_orient=mixed_rot_angle,
            hand_pose=mixed_pose,
            betas=torch.zeros((batch_num * 3, 10), device=device),
        )

        vertices = manoMesh_pred.vertices
        refined_vertices = vertices[0:batch_num, :, :]
        rot_only_vertices = vertices[batch_num:batch_num * 2, :, :]
        coarse_vertices = vertices[batch_num * 2:batch_num * 3, :, :]

        def place_and_seal(mesh_vertices, translation):
            # Move the mesh by the predicted translation residual, remove the
            # reference offset, anchor it at the contact point, then seal it.
            # A fresh face tensor is built per call, as before.
            placed = mesh_vertices + translation.unsqueeze(1) - ref_r + contact_point.unsqueeze(1)
            faces = torch.LongTensor(self.mano_layer.faces.astype(np.int64))
            return seal_mano_mesh(placed, faces, True)

        # trans_pred is a predicted residual.
        # NOTE(review): the original author asked whether the broadcast shape
        # used for the translation here is correct — still to be confirmed.
        pred_mesh_v3d, pred_mesh_f3d = place_and_seal(refined_vertices, trans_pred)
        pred_mesh_v3d_rot, _ = place_and_seal(rot_only_vertices, raw_trans)
        pred_mesh_v3d_corase, _ = place_and_seal(coarse_vertices, raw_trans)

        pred = {
            # Residual converted into the actual translation.
            "trans_pred": trans_pred - ref_r.squeeze(1) + contact_point,
            "pred_pose": pred_pose,
            "pred_mesh_v3d": pred_mesh_v3d,
            "obj_degree": obj_degree,
            "pred_mesh_f3d": pred_mesh_f3d.to(device),
            "hottest_points": contact_point,
            # Combined (coarse + refinement) rotation used to build the mesh.
            "pred_rot_angle": rot_angle,
            "diff_1": diffList[0],
            "diff_2": diffList[1],
            "diff_3": diffList[2],
            "diff_obj": diffList[3],
            "pred_mesh_v3d_rot": pred_mesh_v3d_rot,
            "pred_mesh_v3d_corase": pred_mesh_v3d_corase,
            "dirs_normalized": dirs_normalized,
            "tokenList": tokenList,
            "heatmapFeature": heatmapFeature,
            "handPcFeature": handPcFeature,
            "helpinfo": helpinfo,
            "normFeature": normFeature,
        }

        return pred
    
    def decodeCodebook(self, input):
        """Decode codebook entries into a posed, sealed MANO hand mesh.

        Args:
            input: Mapping that must provide ``"PointCloud_init"``,
                ``"hottest_points"``, ``"code1"``, ``"code2"``, ``"code3"``
                and ``"code_obj"``. The point cloud is indexed as (B, N, C)
                and the contact point is broadcast as (B, 3) — TODO confirm
                against the caller.

        Returns:
            tuple: ``(pred_mesh_v3d, pred_mesh_f3d)`` exactly as returned by
            ``seal_mano_mesh`` for the placed vertices and the MANO faces.
        """
        # TODO(review): the original author left an open question here about
        # how the ground truth should be computed for this path.
        init_points = input["PointCloud_init"]
        contact_point = input["hottest_points"]

        device = init_points.device
        batch_num = init_points.shape[0]

        code1 = input["code1"]
        code2 = input["code2"]
        code3 = input["code3"]
        code_obj = input["code_obj"]

        # Constant channel of ones appended to the points, channels first.
        init_ones = torch.ones(
            (init_points.shape[0], 1, init_points.shape[1]), dtype=torch.float32, device=device
        )
        mixedFeature_init = torch.cat([init_points.permute(0, 2, 1), init_ones], dim=1)
        # NOTE(review): forward() wraps this call in torch.no_grad(); here the
        # guard was deliberately commented out, so gradients are tracked.
        initOBJfeature = self.HEATMAP(mixedFeature_init)
        decList = self.VQVAE.decode_code(code1, code2, code3, code_obj, initOBJfeature)

        # Head calls stay in their original order.
        global_rot = self.rot_head(decList[0])
        raw_trans = self.trans_head(decList[0])
        raw_pose = self.pose_head(decList[1])
        rot_finetune = self.rot_head_finetune(decList[2])
        pose_finetune = self.pose_head_finetune(decList[2])
        trans_finetune = self.trans_head_finetune(decList[2])

        # Global rotation: the refinement rotation is applied on top of the
        # coarse one.
        coarse_rot_mat = rotation_6d_to_matrix(global_rot)
        finetune_rot_mat = rotation_6d_to_matrix(rot_finetune)
        rot_angle = matrix_to_axis_angle(finetune_rot_mat @ coarse_rot_mat)

        # Constant offset subtracted from the predicted translation.
        ref_r = torch.tensor(
            [0.09566994, 0.00638343, 0.0061863], dtype=torch.float32, device=device
        ).view(-1, 1, 3)
        trans_pred = raw_trans + trans_finetune

        # Hand pose: groups of 6 values are decoded as 6D rotations.
        finetune_pose_mat = rotation_6d_to_matrix(pose_finetune.reshape(batch_num, -1, 6))
        coarse_pose_mat = rotation_6d_to_matrix(raw_pose.reshape(batch_num, -1, 6))
        pred_pose = matrix_to_axis_angle(finetune_pose_mat @ coarse_pose_mat).reshape(batch_num, 45)

        manoMesh_pred = self.mano_layer(
            global_orient=rot_angle,  # presumably (B, 3)
            hand_pose=pred_pose,  # (B, 45)
            betas=torch.zeros((batch_num, 10), device=device),
        )

        # Move the mesh by the predicted translation, remove the reference
        # offset and anchor it at the contact point.
        placed_vertices = (
            manoMesh_pred.vertices + trans_pred.unsqueeze(1) - ref_r + contact_point.unsqueeze(1)
        )
        # NOTE(review): unlike forward(), the faces are not moved to `device`
        # before being returned — confirm callers expect that.
        faces = torch.LongTensor(self.mano_layer.faces.astype(np.int64))
        pred_mesh_v3d, pred_mesh_f3d = seal_mano_mesh(placed_vertices, faces, True)

        return pred_mesh_v3d, pred_mesh_f3d