from data_loaders.HHI.networks.modules import *
from os.path import join as pjoin
import copy
from pytorch3d import transforms
import random
from utils.quaternion import *

def build_models(cfg):
    """Build the InterCLIP evaluation model and load its checkpoint weights.

    Args:
        cfg: Configuration object. It is passed to ``InterCLIP`` and must
            expose ``eval_checkpoint``, the path of the checkpoint file.

    Returns:
        The ``InterCLIP`` instance with weights loaded (left on CPU; the
        caller is responsible for moving it to a device).

    Raises:
        RuntimeError: Propagated from ``load_state_dict(strict=True)`` when
            the (renamed) checkpoint keys do not match the model exactly.
    """
    model = InterCLIP(cfg)

    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be loaded here.
    checkpoint = torch.load(cfg.eval_checkpoint, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    # The checkpoint stores parameters under a leading "model." namespace
    # (presumably from a training wrapper -- confirm against the trainer).
    # Strip only that leading prefix: a global str.replace would also delete
    # "model." from nested names such as "clip_model.weight".
    prefix = "model."
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    model.load_state_dict(state_dict, strict=True)
    return model


class EvaluatorModelWrapper(object):
    """Inference wrapper around the evaluation model built by ``build_models``.

    It turns a two-person motion batch into the feature tensor expected by the
    model and returns motion (and optionally text) embeddings.

    ``cfg.process_mode`` selects the feature pipeline:
        * ``0`` -- features are computed here from raw joints and body-pose
          rotation matrices (see ``_single_person_features``).
        * ``1`` -- features are computed by the ``primitive_utility`` object
          passed to the public methods.
    """

    def __init__(self, cfg, device):
        """Build the model, move it to ``device`` and switch it to eval mode.

        Args:
            cfg: Configuration; must expose ``process_mode`` and whatever
                ``build_models`` needs.
            device: Device used for the model and for all intermediate tensors.
        """
        self.model = build_models(cfg)
        self.cfg = cfg
        self.device = device

        self.model = self.model.to(device)
        self.model.eval()

        self.process_mode = cfg.process_mode

    # Please note that the results do not follow the order of the inputs:
    # they are sorted by descending motion length.
    def get_co_embeddings(self, batch_data, primitive_utility, feet_thres=0.001):
        """Return matching text and motion embeddings for a batch.

        Args:
            batch_data: ``(name, text, motion, motion_lens)`` where ``motion``
                is a dict with ``'person1'`` / ``'person2'`` entries and
                ``motion_lens`` is a tensor of per-sample lengths.
            primitive_utility: Feature helper, used only when
                ``process_mode == 1``.
            feet_thres: Squared-velocity threshold for foot-contact detection,
                used only when ``process_mode == 0``.

        Returns:
            ``(text_embedding, motion_embedding)``, both ordered by descending
            motion length (NOT in input order).

        Raises:
            ValueError: If ``process_mode`` is neither 0 nor 1.
        """
        with torch.no_grad():
            batch, align_idx = self._prepare_batch(batch_data, primitive_utility, feet_thres)

            '''Motion Encoding'''
            motion_embedding = self.model.encode_motion(batch)['motion_emb']

            '''Text Encoding'''
            # The text list is encoded in input order, so the result is
            # re-indexed to line up with the length-sorted motions.
            text_embedding = self.model.encode_text(batch)['text_emb'][align_idx]

        return text_embedding, motion_embedding

    # Please note that the results do not follow the order of the inputs:
    # they are sorted by descending motion length.
    def get_motion_embeddings(self, batch_data, primitive_utility, feet_thres=0.001):
        """Return motion embeddings for a batch.

        Takes the same arguments as ``get_co_embeddings``.

        Returns:
            The motion embedding, ordered by descending motion length (NOT in
            input order).

        Raises:
            ValueError: If ``process_mode`` is neither 0 nor 1.
        """
        with torch.no_grad():
            batch, _ = self._prepare_batch(batch_data, primitive_utility, feet_thres)

            '''Motion Encoding'''
            motion_embedding = self.model.encode_motion(batch)['motion_emb']

        return motion_embedding

    def _prepare_batch(self, batch_data, primitive_utility, feet_thres):
        """Build the model input dict shared by both public entry points.

        Returns:
            ``(batch, align_idx)`` where ``batch`` has the keys ``"text"``
            (input order), ``"motions"`` and ``"motion_lens"`` (both sorted by
            descending length), and ``align_idx`` is the permutation that was
            applied to the motions.
        """
        if self.process_mode not in (0, 1):
            # Fail early with a clear message instead of hitting an unbound
            # local variable further down.
            raise ValueError(
                "Unsupported process_mode %r; expected 0 or 1." % (self.process_mode,)
            )

        name, text, motion, motion_lens = batch_data
        if self.process_mode == 0:
            motions = self._motions_from_joints(motion, feet_thres)
        else:
            motions = self._motions_from_primitive_utility(motion, primitive_utility)

        # motions: [B, max_lengths, D]
        B, T = motions.shape[:2]
        motion_lens = motion_lens.to(self.device)

        # Zero out every frame at or beyond each sample's length.
        frame_idx = torch.arange(T, device=self.device).unsqueeze(0).expand(B, T)
        valid_mask = (frame_idx < motion_lens.unsqueeze(1)).unsqueeze(-1)  # [B, T, 1]
        motions = (motions * valid_mask).detach().float()

        # Descending-length order; .copy() removes the negative stride left by
        # the [::-1] reversal so the array can be used as a tensor index.
        align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
        motions = motions[align_idx]
        motion_lens = motion_lens[align_idx]
        text = list(text)

        # Lengths may exceed the number of feature frames, so clamp before
        # choosing how many frames to keep.
        padded_len = int(motion_lens.clamp(max=T).max())

        batch = {}
        batch["text"] = text
        batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
        batch["motion_lens"] = motion_lens
        return batch, align_idx

    def _motions_from_joints(self, motion, feet_thres):
        """process_mode 0: build the two-person feature tensor from raw joints.

        Each person is canonicalised independently, person 2 is then expressed
        relative to person 1, and the two feature tensors are concatenated on
        the last axis. The person order is randomly swapped once before and
        once after that step (two ``random.random()`` draws, in that order).
        """
        first, second = motion['person1'], motion['person2']
        if random.random() > 0.5:
            # Swap local references only: the inputs are read, never written,
            # so no deep copy is needed and the caller's dict stays untouched.
            first, second = second, first

        motion1, root_quat_init1, root_pos_init1 = self._single_person_features(
            first, feet_thres, 0, 22, self.device)
        motion2, root_quat_init2, root_pos_init2 = self._single_person_features(
            second, feet_thres, 0, 22, self.device)

        # Relative placement of person 2 with respect to person 1, packed as
        # [angle, x, z]. NOTE(review): qmul / qinv / qrot and
        # rigid_transform_tensor are defined outside this file -- their exact
        # conventions should be confirmed there.
        r_relative = qmul(root_quat_init2, qinv(root_quat_init1))
        angle = torch.atan2(r_relative[:, 2:3], r_relative[:, 0:1])
        xz = qrot(root_quat_init1, root_pos_init2 - root_pos_init1)[:, [0, 2]]
        relative = torch.cat([angle, xz], dim=-1)
        motion2 = rigid_transform_tensor(relative, motion2)

        if random.random() > 0.5:
            motion1, motion2 = motion2, motion1

        return torch.cat([motion1, motion2], dim=-1)

    @staticmethod
    def _single_person_features(motion_dict, feet_thre, prev_frames, n_joints, device):
        """Canonicalise one person's motion and flatten it to per-frame features.

        Args:
            motion_dict: Dict with ``'joints'`` (first two dims are batch and
                time; reshaped to ``[B, T, n_joints, 3]``, so it must hold
                exactly ``n_joints * 3`` values per frame) and ``'body_pose'``
                (rotation matrices, converted to the 6D representation).
            feet_thre: Squared per-frame displacement below which a low foot
                joint counts as being in contact.
            prev_frames: Index of the frame used as the canonical reference.
            n_joints: Number of joints per frame.
            device: Device on which the features are computed.

        Returns:
            ``(data, quat, root_init_xz)``:
                * ``data`` ``[B, T-1, D']`` -- positions, velocities, 6D
                  rotations and left/right foot contacts per frame;
                * ``quat`` ``[B, 4]`` -- rotation applied to face the
                  reference frame towards ``[0, 0, 1]``;
                * ``root_init_xz`` ``[B, 3]`` -- reference-frame root position
                  with its second component zeroed, measured before rotation.
        """
        face_joint_indx = [2, 1, 17, 16]
        fid_l = [7, 10]
        fid_r = [8, 11]

        B, T, *_ = motion_dict['joints'].shape
        joints = motion_dict['joints'].reshape(B, T, n_joints, 3).to(device)  # [B, T, J, 3]
        rotations = transforms.matrix_to_rotation_6d(
            motion_dict['body_pose']).reshape(B, T, -1).to(device)  # [B, T, R]

        # Axis change: maps (x, y, z) to (x, z, -y). The matmul also creates a
        # fresh tensor, so the in-place updates below never touch the input.
        trans_matrix = torch.tensor([
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, -1.0, 0.0]
        ], dtype=joints.dtype, device=device)
        joints = joints @ trans_matrix.T  # [B, T, J, 3]

        # Put the lowest point of each sequence at height 0.
        floor_height = joints[..., 1].amin(dim=(1, 2), keepdim=True)  # [B, 1, 1]
        joints[..., 1] -= floor_height

        # Centre the reference-frame root (joint 0) on the horizontal plane.
        root_init = joints[:, prev_frames, 0, :]  # [B, 3]
        root_init_xz = root_init * torch.tensor([1, 0, 1], dtype=joints.dtype, device=device)  # [B, 3]
        joints -= root_init_xz[:, None, None, :]

        # Facing direction of the reference frame from the hip axis.
        r_hip, l_hip, *_ = face_joint_indx
        across = joints[:, prev_frames, r_hip] - joints[:, prev_frames, l_hip]  # [B, 3]
        across = across / across.norm(dim=-1, keepdim=True)

        up = torch.tensor([0, 1, 0], dtype=joints.dtype, device=device).expand(B, 3)
        forward = torch.cross(up, across, dim=-1)
        forward = forward / forward.norm(dim=-1, keepdim=True)
        target = torch.tensor([0, 0, 1], dtype=joints.dtype, device=device).expand(B, 3)

        # Rotate the whole sequence so the reference frame faces `target`.
        quat = qbetween(forward, target)  # [B, 4]
        root_quat_init_for_all = quat[:, None, None, :].expand(-1, T, n_joints, 4)
        joints = qrot(root_quat_init_for_all, joints)  # [B, T, J, 3]

        def detect_feet(joints, fid):
            # fid holds two joint indices, so every tensor here has a
            # trailing joint axis of size 2.
            vel = (joints[:, 1:, fid] - joints[:, :-1, fid]) ** 2  # [B, T-1, 2, 3]
            vel_sum = vel.sum(-1)  # [B, T-1, 2]
            height = joints[:, :-1, fid, 1]  # [B, T-1, 2]
            contact = ((vel_sum < feet_thre) & (height < 0.05)).float()
            return contact  # [B, T-1, 2]

        feet_l = detect_feet(joints, fid_l)
        feet_r = detect_feet(joints, fid_r)

        # The last frame is dropped because velocities need a following frame.
        pos_flat = joints[:, :-1].reshape(B, T - 1, -1)  # [B, T-1, J*3]
        vel_flat = (joints[:, 1:] - joints[:, :-1]).reshape(B, T - 1, -1)
        rot_flat = rotations[:, :-1]  # [B, T-1, R]
        data = torch.cat([pos_flat, vel_flat, rot_flat, feet_l, feet_r], dim=-1)  # [B, T-1, D']

        return data, quat, root_init_xz  # [B, T-1, D'], [B, 4], [B, 3]

    def _motions_from_primitive_utility(self, motion, primitive_utility):
        """process_mode 1: build the two-person feature tensor via ``primitive_utility``.

        Person 1 is canonicalised first; person 2 is canonicalised with the
        transform returned for person 1. Afterwards the two persons are
        swapped independently per sample with probability 0.5.

        Note: ``motion['person1']`` / ``motion['person2']`` are replaced in
        place by dicts whose tensors live on ``self.device``.
        """
        def tensor_to_device(tensor_dict, device):
            return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tensor_dict.items()}

        for person in ['person1', 'person2']:
            motion[person] = tensor_to_device(motion[person], self.device)

        # Deep copies keep primitive_utility from modifying the batch
        # (whether it would is not visible from here).
        canonicalized_dict = {}
        transf_rotmat, transf_transl, canonicalized_dict['person1'] = primitive_utility.canonicalize(
            copy.deepcopy(motion['person1']), use_predicted_joints=True)
        canonicalized_dict['person2'] = primitive_utility.relative_canonicalize(
            copy.deepcopy(motion['person2']), transf_rotmat, transf_transl)

        features = {}
        for person in ['person1', 'person2']:
            feature_dict = primitive_utility.calc_features(canonicalized_dict[person], use_predicted_joints=True)
            if primitive_utility.feature_dim == 276:
                # Drop the last frame of these entries (presumably so they
                # match the length of the remaining features -- TODO confirm).
                feature_dict['transl'] = feature_dict['transl'][:, :-1, :]      # [B, T, 3]
                feature_dict['poses_6d'] = feature_dict['poses_6d'][:, :-1, :]  # [B, T, 66]
                feature_dict['joints'] = feature_dict['joints'][:, :-1, :]      # [B, T, 22 * 3]
            features[person] = primitive_utility.dict_to_tensor(feature_dict)

        # Swap person1 and person2 for a random subset of the samples.
        swap_mask = torch.rand(features['person1'].shape[0], device=features['person1'].device) < 0.5  # shape: (B,)
        new_person1 = features['person1'].clone()
        new_person2 = features['person2'].clone()
        new_person1[swap_mask] = features['person2'][swap_mask]
        new_person2[swap_mask] = features['person1'][swap_mask]

        return torch.cat([new_person1, new_person2], dim=-1)
