import train
import os
import time
import csv
import sys
import warnings
import random
import numpy as np
import time
import pprint
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from loguru import logger
import smplx

from utils import config, logger_tools, other_tools, metric
from utils import rotation_conversions as rc
from dataloaders import data_tools
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from scipy.spatial.transform import Rotation


class CustomTrainer(train.BaseTrainer):
    """
    motion representation learning
    """
    def __init__(self, args):
        super().__init__(args)
        self.joints = self.train_data.joints
        self.tracker = other_tools.EpochTracker(["rec", "vel", "acc", "com", "face", "face_vel", "face_acc", "ver", "ver_vel", "ver_acc"], [False, False, False, False, False, False, False, False, False, False])
        self.rec_loss = get_loss_func("GeodesicLoss")
        self.mse_loss = torch.nn.MSELoss(reduction='mean')
        self.vel_loss = torch.nn.MSELoss(reduction='mean') #torch.nn.L1Loss(reduction='mean')
        self.vectices_loss = torch.nn.MSELoss(reduction='mean')
    
    def inverse_selection(self, filtered_t, selection_array, n):
        # 创建一个全为零的数组，形状为 n*165
        original_shape_t = np.zeros((n, selection_array.size))
        
        # 找到选择数组中为1的索引位置
        selected_indices = np.where(selection_array == 1)[0]
        
        # 将 filtered_t 的值填充到 original_shape_t 中相应的位置
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
            
        return original_shape_t

    def train(self, epoch):
        self.model.train()
        t_start = time.time()
        self.tracker.reset()
        for its, dict_data in enumerate(self.train_loader):
            tar_pose = dict_data["pose"]
            tar_caption = dict_data["caption"].cuda()
            tar_beta = dict_data["beta"].cuda()
            tar_trans = dict_data["trans"].cuda()
            tar_pose = tar_pose.cuda()  
            bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
            tar_exps = dict_data["facial"].to(self.rank)
            tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
            tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
            in_tar_pose = torch.cat([tar_pose, tar_exps], -1) # 103
            t_data = time.time() - t_start 
            
            self.opt.zero_grad()
            g_loss_final = 0
            net_out = self.model(in_tar_pose, tar_caption)
            # jaw open 6d loss
            rec_pose = net_out["rec_pose"][:, :, :j*6]
            rec_pose = rec_pose.reshape(bs, n, j, 6)
            rec_pose = rc.rotation_6d_to_matrix(rec_pose)#
            tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
            loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight
            self.tracker.update_meter("rec", "train", loss_rec.item())
            g_loss_final += loss_rec
            # jaw open 6d vel and acc loss
            velocity_loss =  self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight
            acceleration_loss =  self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight
            self.tracker.update_meter("vel", "train", velocity_loss.item())
            self.tracker.update_meter("acc", "train", acceleration_loss.item())
            g_loss_final += velocity_loss 
            g_loss_final += acceleration_loss 
            # face parameter l1 loss
            rec_exps = net_out["rec_pose"][:, :, j*6:]
            loss_face = self.mse_loss(rec_exps, tar_exps) * self.args.rec_weight
            self.tracker.update_meter("face", "train", loss_face.item())
            g_loss_final += loss_face
            # face parameter l1 vel and acc loss
            face_velocity_loss =  self.vel_loss(rec_exps[:, 1:] - rec_exps[:, :-1], tar_exps[:, 1:] - tar_exps[:, :-1]) * self.args.rec_weight
            face_acceleration_loss =  self.vel_loss(rec_exps[:, 2:] + rec_exps[:, :-2] - 2 * rec_exps[:, 1:-1], tar_exps[:, 2:] + tar_exps[:, :-2] - 2 * tar_exps[:, 1:-1]) * self.args.rec_weight
            self.tracker.update_meter("face_vel", "train", face_velocity_loss.item())
            self.tracker.update_meter("face_acc", "train", face_acceleration_loss.item())
            g_loss_final += face_velocity_loss
            g_loss_final += face_acceleration_loss

             # vertices loss
            if self.args.rec_ver_weight > 0:
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                vertices_rec = self.smplx(
                    betas=tar_beta.reshape(bs*n, 300), 
                    transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), 
                    expression=tar_exps.reshape(bs*n, 100), 
                    jaw_pose=rec_pose, 
                    global_orient=torch.zeros(bs*n, 3).cuda(), 
                    body_pose=torch.zeros(bs*n, 21*3).cuda(), 
                    left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                    right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                    return_verts=True,
                    # return_joints=True,
                    leye_pose=torch.zeros(bs*n, 3).cuda(), 
                    reye_pose=torch.zeros(bs*n, 3).cuda(),
                )
                vertices_tar = self.smplx(
                    betas=tar_beta.reshape(bs*n, 300), 
                    transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), 
                    expression=rec_exps.reshape(bs*n, 100), 
                    jaw_pose=tar_pose, 
                    global_orient=torch.zeros(bs*n, 3).cuda(), 
                    body_pose=torch.zeros(bs*n, 21*3).cuda(), 
                    left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                    right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                    return_verts=True,
                    # return_joints=True,
                    leye_pose=torch.zeros(bs*n, 3).cuda(), 
                    reye_pose=torch.zeros(bs*n, 3).cuda(),
                )  
                vectices_loss = self.mse_loss(vertices_rec['vertices'], vertices_tar['vertices'])
                self.tracker.update_meter("ver", "train", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight)
                g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight
                # vertices vel and acc loss
                vert_velocity_loss =  self.vel_loss(vertices_rec['vertices'][:, 1:] - vertices_rec['vertices'][:, :-1], vertices_tar['vertices'][:, 1:] - vertices_tar['vertices'][:, :-1]) * self.args.rec_weight * self.args.rec_ver_weight
                vert_acceleration_loss =  self.vel_loss(vertices_rec['vertices'][:, 2:] + vertices_rec['vertices'][:, :-2] - 2 * vertices_rec['vertices'][:, 1:-1], vertices_tar['vertices'][:, 2:] + vertices_tar['vertices'][:, :-2] - 2 * vertices_tar['vertices'][:, 1:-1]) * self.args.rec_weight * self.args.rec_ver_weight
                self.tracker.update_meter("ver_vel", "train", vert_velocity_loss.item())
                self.tracker.update_meter("ver_acc", "train", vert_acceleration_loss.item())
                g_loss_final += vert_velocity_loss
                g_loss_final += vert_acceleration_loss
            
            # ---------------------- vae -------------------------- #
            if "VQVAE" in self.args.g_name:
                loss_embedding = net_out["embedding_loss"].mean()
                g_loss_final += loss_embedding
                self.tracker.update_meter("com", "train", loss_embedding.item())
            # elif "VAE" in self.args.g_name:
            #     pose_mu, pose_logvar = net_out["pose_mu"], net_out["pose_logvar"] 
            #     KLD = -0.5 * torch.sum(1 + pose_logvar - pose_mu.pow(2) - pose_logvar.exp())
            #     if epoch < 0:
            #         KLD_weight = 0
            #     else:
            #         KLD_weight = min(1.0, (epoch - 0) * 0.05) * 0.01
            #     loss += KLD_weight * KLD
            #     self.tracker.update_meter("kl", "train", KLD_weight * KLD.item())    
            g_loss_final.backward()
            if self.args.grad_norm != 0: 
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm)
            self.opt.step()
            t_train = time.time() - t_start - t_data
            t_start = time.time()
            mem_cost = torch.cuda.memory_cached() / 1E9
            lr_g = self.opt.param_groups[0]['lr']
            if its % self.args.log_period == 0:
                self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g)   
            if self.args.debug:
                if its == 1: break
        self.opt_s.step(epoch)
                    
    def val(self, epoch):
        self.model.eval()
        t_start = time.time()
        with torch.no_grad():
            for its, dict_data in enumerate(self.val_loader):
                tar_pose = dict_data["pose"]
                tar_beta = dict_data["beta"].cuda()
                tar_trans = dict_data["trans"].cuda()
                tar_pose = tar_pose.cuda()  
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
                tar_exps = dict_data["facial"].to(self.rank)
                tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
                tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
                in_tar_pose = torch.cat([tar_pose, tar_exps], -1) # 103
                tar_caption = dict_data["caption"].cuda()
                # print(tar_pose.shape, in_tar_pose.shape, tar_exps.shape)
                t_data = time.time() - t_start 

                #self.opt.zero_grad()
                #g_loss_final = 0
                net_out = self.model(in_tar_pose, tar_caption)
                            # jaw open 6d loss
                rec_pose = net_out["rec_pose"][:, :, :j*6]
                rec_pose = rec_pose.reshape(bs, n, j, 6)
                rec_pose = rc.rotation_6d_to_matrix(rec_pose)#
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
                loss_rec = self.rec_loss(rec_pose, tar_pose) * self.args.rec_weight * self.args.rec_pos_weight
                self.tracker.update_meter("rec", "val", loss_rec.item())
                # g_loss_final += loss_rec
                # jaw open 6d vel and acc loss
                velocity_loss =  self.vel_loss(rec_pose[:, 1:] - rec_pose[:, :-1], tar_pose[:, 1:] - tar_pose[:, :-1]) * self.args.rec_weight
                acceleration_loss =  self.vel_loss(rec_pose[:, 2:] + rec_pose[:, :-2] - 2 * rec_pose[:, 1:-1], tar_pose[:, 2:] + tar_pose[:, :-2] - 2 * tar_pose[:, 1:-1]) * self.args.rec_weight
                self.tracker.update_meter("vel", "val", velocity_loss.item())
                self.tracker.update_meter("acc", "val", acceleration_loss.item())
                # g_loss_final += velocity_loss 
                # g_loss_final += acceleration_loss 
                # face parameter l1 loss
                rec_exps = net_out["rec_pose"][:, :, j*6:]
                loss_face = self.vel_loss(rec_exps, tar_exps) * self.args.rec_weight
                self.tracker.update_meter("face", "val", loss_face.item())
                # g_loss_final += loss_face
                # face parameter l1 vel and acc loss
                face_velocity_loss =  self.vel_loss(rec_exps[:, 1:] - rec_exps[:, :-1], tar_exps[:, 1:] - tar_exps[:, :-1]) * self.args.rec_weight
                face_acceleration_loss =  self.vel_loss(rec_exps[:, 2:] + rec_exps[:, :-2] - 2 * rec_exps[:, 1:-1], tar_exps[:, 2:] + tar_exps[:, :-2] - 2 * tar_exps[:, 1:-1]) * self.args.rec_weight
                self.tracker.update_meter("face_vel", "val", face_velocity_loss.item())
                self.tracker.update_meter("face_acc", "val", face_acceleration_loss.item())
                # g_loss_final += face_velocity_loss
                # g_loss_final += face_acceleration_loss

                # vertices loss
                if self.args.rec_ver_weight > 0:
                    tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
                    rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                    vertices_rec = self.smplx(
                        betas=tar_beta.reshape(bs*n, 300), 
                        transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), 
                        expression=tar_exps.reshape(bs*n, 100), 
                        jaw_pose=rec_pose, 
                        global_orient=torch.zeros(bs*n, 3).cuda(), 
                        body_pose=torch.zeros(bs*n, 21*3).cuda(), 
                        left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                        right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                        return_verts=True,
                        # return_joints=True,
                        leye_pose=torch.zeros(bs*n, 3).cuda(), 
                        reye_pose=torch.zeros(bs*n, 3).cuda(),
                    )
                    vertices_tar = self.smplx(
                        betas=tar_beta.reshape(bs*n, 300), 
                        transl=tar_trans.reshape(bs*n, 3)-tar_trans.reshape(bs*n, 3), 
                        expression=rec_exps.reshape(bs*n, 100), 
                        jaw_pose=tar_pose, 
                        global_orient=torch.zeros(bs*n, 3).cuda(), 
                        body_pose=torch.zeros(bs*n, 21*3).cuda(), 
                        left_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                        right_hand_pose=torch.zeros(bs*n, 15*3).cuda(), 
                        return_verts=True,
                        # return_joints=True,
                        leye_pose=torch.zeros(bs*n, 3).cuda(), 
                        reye_pose=torch.zeros(bs*n, 3).cuda(),
                    )  
                    vectices_loss = self.mse_loss(vertices_rec['vertices'], vertices_tar['vertices'])
                    self.tracker.update_meter("ver", "val", vectices_loss.item()*self.args.rec_weight * self.args.rec_ver_weight)
                    # g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight
                    # vertices vel and acc loss
                    vert_velocity_loss =  self.vel_loss(vertices_rec['vertices'][:, 1:] - vertices_rec['vertices'][:, :-1], vertices_tar['vertices'][:, 1:] - vertices_tar['vertices'][:, :-1]) * self.args.rec_weight * self.args.rec_ver_weight
                    vert_acceleration_loss =  self.vel_loss(vertices_rec['vertices'][:, 2:] + vertices_rec['vertices'][:, :-2] - 2 * vertices_rec['vertices'][:, 1:-1], vertices_tar['vertices'][:, 2:] + vertices_tar['vertices'][:, :-2] - 2 * vertices_tar['vertices'][:, 1:-1]) * self.args.rec_weight * self.args.rec_ver_weight
                    self.tracker.update_meter("ver_vel", "val", vert_velocity_loss.item())
                    self.tracker.update_meter("ver_acc", "val", vert_acceleration_loss.item())
                    # g_loss_final += vert_velocity_loss
                    # g_loss_final += vert_acceleration_loss
                    if "VQVAE" in self.args.g_name:
                        loss_embedding = net_out["embedding_loss"].mean()
                        self.tracker.update_meter("com", "val", loss_embedding.item())
                    #g_loss_final += vectices_loss*self.args.rec_weight*self.args.rec_ver_weight
        self.val_recording(epoch)
            
    def test(self, epoch):
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path): 
            return 0
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        mse_joints = []
        self.model.eval()
        with torch.no_grad():
            for its, dict_data in enumerate(self.test_loader):
                tar_pose = dict_data["pose"].cuda()
                tar_caption = dict_data["caption"].cuda()
                #tar_pose = tar_pose.cuda()
                tar_exps = dict_data["facial"].to(self.rank)
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
                tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
                tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
                remain = n%self.args.pose_length
                tar_pose = tar_pose[:, :n-remain, :]
                # print(tar_exps.shape)
                in_tar_pose = torch.cat([tar_pose, tar_exps[:, :n-remain, :]], -1) # 103
                tar_caption = tar_caption[:, :n-remain, :]
                #print(tar_pose.shape)
                if True:
                    net_out = self.model(in_tar_pose, tar_caption)
                    rec_pose = net_out["rec_pose"][:, :, :j*6]
                    n = rec_pose.shape[1]
                    tar_pose = tar_pose[:, :n, :]
                    rec_pose = rec_pose.reshape(bs, n, j, 6) 
                    rec_pose = rc.rotation_6d_to_matrix(rec_pose)#
                    rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                    rec_pose = rec_pose.cpu().numpy()
                    rec_exps = net_out["rec_pose"][:, :, j*6:]
                    rec_exps = rec_exps.cpu().numpy().reshape(bs*n, 100)
                else:
                    pass
#                     for i in range(tar_pose.shape[1]//(self.args.vae_test_len)):
#                         tar_pose_new = tar_pose[:,i*(self.args.vae_test_len):i*(self.args.vae_test_len)+self.args.vae_test_len,:]
#                         net_out = self.model(**dict(inputs=tar_pose_new))
#                         rec_pose = net_out["rec_pose"]
#                         rec_pose = (rec_pose.reshape(rec_pose.shape[0], rec_pose.shape[1], -1, 6) * self.joint_level_mask_cuda).reshape(rec_pose.shape[0], rec_pose.shape[1], -1)
#                         if "rot6d" in self.args.pose_rep:
#                             rec_pose = data_transfer.rotation_6d_to_matrix(rec_pose.reshape(tar_pose.shape[0], self.args.vae_test_len, -1, 6))
#                             rec_pose = data_transfer.matrix_to_euler_angles(rec_pose, "XYZ").reshape(rec_pose.shape[0], rec_pose.shape[1], -1)
#                             if "smplx" not in self.args.pose_rep:
#                                 rec_pose = torch.rad2deg(rec_pose)
#                             rec_pose = rec_pose * self.joint_mask_cuda
                            
#                         out_sub = rec_pose.cpu().numpy().reshape(-1, rec_pose.shape[2])
#                         if i != 0:
#                             out_final = np.concatenate((out_final,out_sub), 0)
#                         else:
#                             out_final = out_sub
                            
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs, n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)
                tar_pose = tar_pose.cpu().numpy()
                mse_joints.append(float(np.mean(np.abs(tar_pose - rec_pose))))
                total_length += n 
                # --- save --- #
                if 'smplx' in self.args.pose_rep:
                    gt_npz = np.load(self.args.data_path+self.args.pose_rep+"/"+test_seq_list.iloc[its]['id']+'.npz', allow_pickle=True)
                    stride = int(30 / self.args.pose_fps)
                    tar_pose = self.inverse_selection(tar_pose, self.test_data.joint_mask, tar_pose.shape[0])
                    np.savez(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz',
                        betas=gt_npz["betas"],
                        poses=tar_pose[:n],
                        expressions=gt_npz["expressions"],
                        trans=gt_npz["trans"][::stride][:n] - gt_npz["trans"][::stride][:n],
                        model='smplx2020',
                        gender='neutral',
                        mocap_frame_rate = 30 ,
                    )
                    rec_pose = self.inverse_selection(rec_pose, self.test_data.joint_mask, rec_pose.shape[0])
                    np.savez(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.npz',
                        betas=gt_npz["betas"],
                        poses=rec_pose,
                        expressions=rec_exps,
                        trans=gt_npz["trans"][::stride][:n] - gt_npz["trans"][::stride][:n],
                        model='smplx2020',
                        gender='neutral',
                        mocap_frame_rate = 30 ,
                    )       
                else:
                    rec_pose = rc.axis_angle_to_matrix(torch.from_numpy(rec_pose.reshape(bs*n, j, 3)))
                    rec_pose = np.rad2deg(rc.matrix_to_euler_angles(rec_pose, "XYZ")).reshape(bs*n, j*3).numpy()                
                    tar_pose = rc.axis_angle_to_matrix(torch.from_numpy(tar_pose.reshape(bs*n, j, 3)))
                    tar_pose = np.rad2deg(rc.matrix_to_euler_angles(tar_pose, "XYZ")).reshape(bs*n, j*3).numpy() 
                    #trans="0.000000 0.000000 0.000000"
                    
                    with open(f"{self.args.data_path}{self.args.pose_rep}/{test_seq_list.iloc[its]['id']}.bvh", "r") as f_demo:
                        with open(results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_gt:
                            with open(results_save_path+"res_"+test_seq_list.iloc[its]['id']+'.bvh', 'w+') as f_real:
                                for i, line_data in enumerate(f_demo.readlines()):
                                    if i < 431:
                                        f_real.write(line_data)
                                        f_gt.write(line_data)
                                    else: break
                                for line_id in range(n): #,args.pre_frames, args.pose_length
                                    line_data = np.array2string(rec_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
                                    f_real.write(line_data[1:-2]+'\n')
                                for line_id in range(n): #,args.pre_frames, args.pose_length
                                    line_data = np.array2string(tar_pose[line_id], max_line_width=np.inf, precision=6, suppress_small=False, separator=' ')
                                    f_gt.write(line_data[1:-2]+'\n')
                # with open(results_save_path+"gt_"+test_seq_list[its]+'.pkl', 'wb') as fw:
                #     pickle.dump(new_dict, fw)
                # #new_dict2["fullpose"] = out_final
                # with open(results_save_path+"res_"+test_seq_list[its]+'.pkl', 'wb') as fw1:
                #     pickle.dump(new_dict2, fw1)

                # other_tools.render_one_sequence(
                #     results_save_path+"res_"+test_seq_list[its]+'.pkl',
                #     results_save_path+"gt_"+test_seq_list[its]+'.pkl',
                #     results_save_path,
                #     self.args.data_path + self.args.test_data_path + 'wave16k/' + test_seq_list[its]+'.npy',
                # )
                                                                                                
                #if its == 1:break
        end_time = time.time() - start_time
        logger.info(f"joints mse: {np.mean(mse_joints)}")
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")