import torch
import torch.nn as nn
import numpy as np
from scipy.spatial.transform import Rotation as R
import logging
from typing import Tuple, Optional, List, Dict
import time
import os
import sys
import tqdm

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from articulate.utils.rbdl import *
from articulate.math.angular import *
import articulate as art
from utils import smpl_to_rbdl, rbdl_to_smpl, Body, _smpl_to_rbdl, _rbdl_to_smpl

# Module-wide logging. NOTE(review): basicConfig configures the root logger as
# an import-time side effect of importing this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# SMPL body model used by ImprovedMotionDiversity.compute_sequence_kinematics
# for forward kinematics. It is loaded once, at import time, from a relative
# path (presumably resolved against the current working directory — confirm).
smpl_file = 'models/SMPL_male.pkl'
body_model = art.ParametricModel(smpl_file)


class ImprovedMotionDiversity:
    """Score motion diversity from finite-difference joint-torque Jacobians.

    For each sampled frame, every joint rotation of that frame is perturbed by
    +/-epsilon about the x, y and z axes. Joint positions, velocities and
    accelerations are recomputed, RBDL inverse dynamics turns them into joint
    torques, and the central difference of those torques forms one column of a
    per-frame Jacobian. Spectral and variance statistics over the stack of
    per-frame Jacobians are then blended into a single score.
    """

    # Furthest frame offset that can influence the finite-difference
    # derivatives at a frame: accelerations are central differences of
    # velocities, which are themselves differences of neighbouring frames.
    _STENCIL_RADIUS = 2

    def __init__(self, num_joints: int, num_frames: int,
                 body_enum_class, device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
                 fps: float = 60.0):
        """Load the RBDL physics model and store the clip dimensions.

        Args:
            num_joints: number of joints; this is both the number of joints
                perturbed and the number of torque entries kept per frame.
            num_frames: number of frames in the clip (used for frame sampling).
            body_enum_class: stored as ``self.Body``; not read by this class.
            device: device for the torque and Jacobian tensors.
            fps: frame rate of the clip; finite differences use dt = 1 / fps.
        """
        self.J = num_joints
        self.T = num_frames
        self.device = device
        self.Body = body_enum_class

        self.dt = 1.0 / fps

        physics_model_file = 'models/physics.urdf'
        self.model = RBDLModel(physics_model_file, update_kinematics_by_hand=True)
        self.dof = self.model.qdot_size

        # Not read by any method of this class; kept for external callers.
        self.batch_size = min(5, num_frames)

    @staticmethod
    def _finite_difference(x: torch.Tensor, dt: float) -> torch.Tensor:
        """Differentiate ``x`` along dim 0 with step ``dt``.

        Interior samples use central differences, while the first and last
        samples use one-sided differences. Inputs with fewer than two samples
        have no defined derivative and yield zeros.
        """
        out = torch.zeros_like(x)
        if x.shape[0] < 2:
            return out
        out[0] = (x[1] - x[0]) / dt
        out[-1] = (x[-1] - x[-2]) / dt
        out[1:-1] = (x[2:] - x[:-2]) / (2 * dt)
        return out

    def compute_sequence_kinematics(self, pose_rot: torch.Tensor, trans: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute joint positions and their time derivatives for a clip.

        Args:
            pose_rot: local joint rotation matrices, shape [T, J, 3, 3].
            trans: root translations, passed to ``body_model.forward_kinematics``.

        Returns:
            A dict with the following entries, all of the positions' shape:
            - 'positions': joint positions from forward kinematics.
            - 'lin_vel', 'lin_acc': finite differences of the positions.
            - 'ang_vel': axis-angle of the rotation between consecutive frames,
              divided by dt. The last frame repeats the previous value.
            - 'ang_acc': finite difference of 'ang_vel'.
            Clips with fewer than two frames get all-zero derivatives.
        """
        num_frames = pose_rot.shape[0]

        _, joint_positions = body_model.forward_kinematics(pose=pose_rot, shape=None, tran=trans, calc_mesh=False)

        lin_vel = self._finite_difference(joint_positions, self.dt)
        lin_acc = self._finite_difference(lin_vel, self.dt)

        ang_vel = torch.zeros_like(joint_positions)
        if num_frames >= 2:
            # Rotation taking frame t to frame t + 1 (forward difference).
            delta_rot = pose_rot[1:] @ pose_rot[:-1].transpose(-1, -2)
            ang_vel[:-1] = art.math.rotation_matrix_to_axis_angle(
                delta_rot.reshape(-1, 3, 3)).view(num_frames - 1, -1, 3) / self.dt
            ang_vel[-1] = ang_vel[-2]
        ang_acc = self._finite_difference(ang_vel, self.dt)

        return {
            'positions': joint_positions,
            'lin_vel': lin_vel,
            'lin_acc': lin_acc,
            'ang_vel': ang_vel,
            'ang_acc': ang_acc
        }

    def pose_to_rbdl_params(self, pose_rot: torch.Tensor, trans: torch.Tensor,
                            lin_vel: torch.Tensor, ang_vel: torch.Tensor,
                            lin_acc: torch.Tensor, ang_acc: torch.Tensor) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert one SMPL frame and its derivatives to RBDL (q, qdot, qddot).

        Only the first entry along the leading (batch) axis of each input is
        used. ``q`` comes from ``smpl_to_rbdl``. ``qdot`` and ``qddot`` hold the
        root joint's linear velocity / acceleration in DOFs 0:3, followed by
        per-joint angular velocity / acceleration, three DOFs per joint, in
        the joint order of ``ang_vel`` / ``ang_acc``.

        NOTE(review): this file does not establish the angular-DOF layout
        assumed here (joint j -> DOFs 3+3j:6+3j). Confirm it matches the
        ordering ``smpl_to_rbdl`` uses for ``q``.

        This is best effort: on any failure a warning is logged and all-zero
        vectors of length ``self.dof`` are returned.
        """
        try:
            poses_np = pose_rot.detach().cpu().numpy()
            trans_np = trans.detach().cpu().numpy()

            if poses_np.ndim != 4:  # expected [batch, J, 3, 3]
                raise ValueError(f"Unexpected pose shape: {poses_np.shape}")
            poses_np = poses_np[0:1]

            # Only take the first row of a batched translation; slicing a
            # 1-D [3] vector would drop two of its components.
            if trans_np.ndim >= 2:
                trans_np = trans_np[0:1]
            if not (trans_np.ndim == 2 and trans_np.shape[1] == 3):
                if trans_np.shape != (3,):
                    logger.warning(f"Unexpected trans shape: {trans_np.shape}, reshaping to [1, 3]")
                trans_np = trans_np.reshape(1, 3)

            q_np = smpl_to_rbdl(poses_np, trans_np)[0]

            qdot_np = np.zeros(self.dof)
            qddot_np = np.zeros(self.dof)

            # Root (joint 0) linear terms.
            qdot_np[0:3] = lin_vel[0, 0].detach().cpu().numpy()
            qddot_np[0:3] = lin_acc[0, 0].detach().cpu().numpy()

            # Per-joint angular terms, as many joints as fit in the DOF vector.
            n_rot = max(0, min(self.J, (self.dof - 3) // 3))
            qdot_np[3:3 + 3 * n_rot] = ang_vel[0, :n_rot].detach().cpu().numpy().reshape(-1)
            qddot_np[3:3 + 3 * n_rot] = ang_acc[0, :n_rot].detach().cpu().numpy().reshape(-1)

        except Exception as e:
            logger.warning(f"SMPL to RBDL conversion failed: {e}")
            # Return a consistent fallback rather than partially filled vectors.
            return np.zeros(self.dof), np.zeros(self.dof), np.zeros(self.dof)

        return q_np, qdot_np, qddot_np

    def compute_torque_from_rbdl(self, q: np.ndarray, qdot: np.ndarray, qddot: np.ndarray) -> torch.Tensor:
        """Inverse dynamics ``tau = M(q) @ qddot + h(q, qdot)``, reduced to ``self.J`` entries.

        ``qddot`` is clipped to [-1e3, 1e3] and ``tau`` to [-1e4, 1e4].
        When ``tau`` has at least ``6 + self.J`` entries, entries 6 through
        ``6 + self.J - 1`` are kept. Otherwise the first ``self.J`` entries are
        kept, with zero-padding if ``tau`` is shorter.

        NOTE(review): ``tau`` has one entry per DOF, not per joint. Skipping the
        first 6 presumably drops the floating-base DOFs. Confirm that keeping
        ``self.J`` consecutive DOF torques is the intended per-joint summary.

        This is best effort: on failure a warning is logged and zeros are
        returned. The result is a float32 tensor of shape [self.J] on
        ``self.device``.
        """
        try:
            qddot = np.clip(qddot, -1e3, 1e3)

            tau_np = self.model.calc_M(q) @ qddot + self.model.calc_h(q, qdot)
            tau_np = np.clip(tau_np, -1e4, 1e4)

            tau = torch.from_numpy(tau_np).to(self.device).float()

            n = len(tau)
            if n >= 6 + self.J:
                tau = tau[6:6 + self.J]
            elif n >= self.J:
                tau = tau[:self.J]
            else:
                tau = torch.nn.functional.pad(tau, (0, self.J - n))

        except Exception as e:
            logger.warning(f"RBDL torque computation failed: {e}")
            tau = torch.zeros(self.J, device=self.device)

        return tau

    @staticmethod
    def _perturbed_pose(pose_np: np.ndarray, frame: int, joint: int, axis: str, angle: float) -> np.ndarray:
        """Return a copy of ``pose_np`` with rotation [frame, joint] rotated by ``angle`` about ``axis``.

        The elementary rotation is applied on the left. The input array is
        never modified: the new matrix is computed from ``pose_np`` itself,
        not from a view into the copy, so successive +/- perturbations of
        the same input cannot contaminate each other. The result keeps the
        dtype of ``pose_np``.
        """
        perturbed = pose_np.copy()
        delta_rot = R.from_euler(axis, angle, degrees=False).as_matrix()
        perturbed[frame, joint] = delta_rot @ pose_np[frame, joint]
        return perturbed

    def _frame_torque(self, pose_np: np.ndarray, trans: torch.Tensor, frame: int) -> torch.Tensor:
        """Joint torques at ``frame`` of the (sub-)clip described by ``pose_np`` and ``trans``."""
        pose_t = torch.from_numpy(pose_np).to(self.device)
        kin = self.compute_sequence_kinematics(pose_t, trans)
        sl = slice(frame, frame + 1)
        q, qdot, qddot = self.pose_to_rbdl_params(
            pose_t[sl], trans[sl],
            kin['lin_vel'][sl], kin['ang_vel'][sl],
            kin['lin_acc'][sl], kin['ang_acc'][sl])
        return self.compute_torque_from_rbdl(q, qdot, qddot)

    def compute_torque_jacobian_with_updated_dynamics(self, pose_rot: torch.Tensor, trans: torch.Tensor,
                                                      target_frame: int, epsilon: float = 1e-6) -> torch.Tensor:
        """Finite-difference Jacobian of joint torques w.r.t. joint rotations at one frame.

        Column ``3 * j + a`` is d(tau) / d(theta) for a rotation of joint ``j``
        about axis ``a`` (x, y, z), applied on the left of its rotation matrix
        at ``target_frame``. It is estimated with a central difference of step
        ``epsilon``. Each perturbation recomputes positions, velocities and
        accelerations before running inverse dynamics.

        Only the frames within ``_STENCIL_RADIUS`` of ``target_frame`` can
        affect those derivatives, so kinematics are evaluated on that window
        instead of the whole clip. This assumes ``body_model.forward_kinematics``
        treats frames independently (TODO confirm).

        NOTE(review): if ``pose_rot`` is float32, the perturbed matrix is stored
        back in float32, where a 1e-6 rotation changes entries by only a few
        ulps. A larger ``epsilon`` may be needed for a stable estimate.

        Returns:
            Tensor of shape [self.J, 3 * self.J] on ``self.device``.
        """
        num_frames = pose_rot.shape[0]
        lo = max(0, target_frame - self._STENCIL_RADIUS)
        hi = min(num_frames, target_frame + self._STENCIL_RADIUS + 1)
        local_frame = target_frame - lo

        pose_win = pose_rot[lo:hi].detach().cpu().numpy()
        trans_win = trans[lo:hi]

        jacobian = torch.zeros(self.J, 3 * self.J, device=self.device)
        for joint in range(self.J):
            for axis_idx, axis in enumerate('xyz'):
                tau_plus = self._frame_torque(
                    self._perturbed_pose(pose_win, local_frame, joint, axis, epsilon),
                    trans_win, local_frame)
                tau_minus = self._frame_torque(
                    self._perturbed_pose(pose_win, local_frame, joint, axis, -epsilon),
                    trans_win, local_frame)
                jacobian[:, 3 * joint + axis_idx] = (tau_plus - tau_minus) / (2 * epsilon)

        return jacobian

    def compute_motion_dynamics(self, pose_rot: torch.Tensor, trans: torch.Tensor) -> float:
        """Heuristic "how dynamic is this clip" score in [0, 1].

        The score is built from two means over consecutive frame pairs:
        - the mean rotation angle between consecutive frames, over all joints;
        - the mean norm of the per-frame translation change.
        It is computed as ``min(1, 10 * (pose_term + 10 * trans_term))``.
        Clips with fewer than two frames score 0.
        """
        num_frames = pose_rot.shape[0]
        if num_frames < 2:
            pose_dynamics = 0.0
            trans_dynamics = 0.0
        else:
            delta_rot = pose_rot[1:] @ pose_rot[:-1].transpose(-1, -2)
            delta_angle = art.math.rotation_matrix_to_axis_angle(
                delta_rot.reshape(-1, 3, 3)).view(num_frames - 1, -1, 3).norm(dim=-1)
            pose_dynamics = delta_angle.mean().item()
            trans_dynamics = (trans[1:] - trans[:-1]).reshape(num_frames - 1, -1).norm(dim=1).mean().item()

        overall_dynamics = pose_dynamics + trans_dynamics * 10

        return min(1.0, overall_dynamics * 10)

    def improved_diversity_metrics(self, J_sequence: torch.Tensor) -> Dict[str, float]:
        """Diversity statistics over a stack of Jacobians of shape [N, rows, cols].

        Returns:
            A dict with the following entries:
            - 'spectral_diversity': sum of log singular values (those > 1e-10)
              of the [N, rows*cols] flattening.
            - 'variance_diversity': sum over rows of the log variance of each
              row across samples and columns.
            - 'segment_diversity': the spectral measure averaged over up to 4
              chunks along N, or equal to 'spectral_diversity' when N < 4.
            - 'dynamic_range': log of (max - min) over all entries.
            A 1e-6 offset is added inside every log.
        """
        num_samples, rows, cols = J_sequence.shape

        def log_spectrum(mat: torch.Tensor) -> float:
            # Only singular values are needed, so skip computing U and V.
            s = torch.linalg.svdvals(mat)
            return torch.sum(torch.log(s[s > 1e-10] + 1e-6)).item()

        spectral_diversity = log_spectrum(J_sequence.reshape(num_samples, -1))

        joint_variance = torch.var(J_sequence, dim=[0, 2])
        variance_diversity = torch.sum(torch.log(joint_variance + 1e-6)).item()

        if num_samples >= 4:
            # torch.chunk may return fewer than 4 (never empty) chunks.
            segments = torch.chunk(J_sequence, 4, dim=0)
            segment_diversity = sum(
                log_spectrum(seg.reshape(-1, rows * cols)) for seg in segments) / len(segments)
        else:
            segment_diversity = spectral_diversity

        dynamic_range = torch.log(J_sequence.max() - J_sequence.min() + 1e-6).item()

        return {
            'spectral_diversity': spectral_diversity,
            'variance_diversity': variance_diversity,
            'segment_diversity': segment_diversity,
            'dynamic_range': dynamic_range
        }

    def _select_frames(self, sampling_strategy: str) -> List[int]:
        """Frame indices to analyse for a sampling strategy.

        The strategies are:
        - "all": every frame.
        - "sparse": every ``max(1, T // 10)``-th frame.
        - Anything else (including the default "uniform"): at most 20 frames,
          strided evenly from frame 0.
        """
        if sampling_strategy == "all":
            return list(range(self.T))
        if sampling_strategy == "sparse":
            return list(range(0, self.T, max(1, self.T // 10)))
        max_frames = min(20, self.T)
        if self.T <= max_frames:
            return list(range(self.T))
        return list(range(0, self.T, self.T // max_frames))[:max_frames]

    def compute_enhanced_diversity(self, pose_rot: torch.Tensor, trans: torch.Tensor,
                                   sampling_strategy: str = "uniform") -> Dict[str, float]:
        """Compute the combined diversity score for a clip.

        Per-frame torque Jacobians are computed for the frames chosen by
        ``sampling_strategy``, stacked, and summarised by
        ``improved_diversity_metrics``. The spectral, variance and segment terms
        are blended with weights that depend on whether
        ``compute_motion_dynamics`` exceeds 0.5.

        Returns:
            A dict with keys 'final_score', 'enhanced_metrics',
            'motion_dynamics', 'processing_time' (seconds),
            'num_frames_processed', 'total_frames', 'sampling_strategy' and
            'jacobian_shape'.
        """
        start_time = time.perf_counter()

        frames_to_process = self._select_frames(sampling_strategy)
        num_frames_to_process = len(frames_to_process)

        logger.info(f"Processing {num_frames_to_process} frames using {sampling_strategy} sampling strategy")

        with tqdm.tqdm(frames_to_process, desc="Computing Jacobian matrices", unit="frame") as progress_bar:
            J_all = torch.stack([
                self.compute_torque_jacobian_with_updated_dynamics(pose_rot, trans, t)
                for t in progress_bar
            ])

        enhanced_metrics = self.improved_diversity_metrics(J_all)

        motion_dynamics = self.compute_motion_dynamics(pose_rot, trans)

        if motion_dynamics > 0.5:
            weights = {'spectral': 0.4, 'variance': 0.3, 'segment': 0.3}
            logger.info("High dynamic motion detected - adjusting weights")
        else:
            weights = {'spectral': 0.3, 'variance': 0.4, 'segment': 0.3}
            logger.info("Low dynamic motion detected - adjusting weights")

        final_score = (
                weights['spectral'] * enhanced_metrics['spectral_diversity'] +
                weights['variance'] * enhanced_metrics['variance_diversity'] +
                weights['segment'] * enhanced_metrics['segment_diversity']
        )

        processing_time = time.perf_counter() - start_time

        return {
            'final_score': final_score,
            'enhanced_metrics': enhanced_metrics,
            'motion_dynamics': motion_dynamics,
            'processing_time': processing_time,
            'num_frames_processed': num_frames_to_process,
            'total_frames': self.T,
            'sampling_strategy': sampling_strategy,
            'jacobian_shape': J_all.shape
        }


def cal_improved_diversity(pose, tran, sampling_strategy="uniform"):
    """Return the enhanced diversity score of an axis-angle SMPL clip.

    Args:
        pose: axis-angle joint rotations with 24 * 3 values per frame, either
            [T, 24, 3] or [T, 72].
        tran: per-frame root translations.
        sampling_strategy: frame-sampling strategy forwarded to
            ``ImprovedMotionDiversity.compute_enhanced_diversity``.

    Returns:
        The 'final_score' entry of the analysis results.
    """
    pose_rot = art.math.axis_angle_to_rotation_matrix(pose.reshape(-1, 3)).reshape(-1, 24, 3, 3)

    # Derive the dimensions from the converted [T, 24, 3, 3] tensor. With a
    # flattened [T, 72] input, pose.shape[1] would wrongly report 72 joints.
    num_frames, num_joints = pose_rot.shape[0], pose_rot.shape[1]

    optimizer = ImprovedMotionDiversity(num_joints, num_frames, Body)

    results = optimizer.compute_enhanced_diversity(pose_rot, tran, sampling_strategy)

    return results['final_score']


if __name__ == "__main__":
    # Clip names to evaluate; each one is read from <test_dir>/<name>.pt.
    # Alternative full list: 'run', 'walk', 't-pose', 'hand_waving', 'jump',
    # 'small_jump', 'twist', 'tennis', 'dodge'.
    test_names = ['small_jump', 'twist', 'tennis', 'dodge']
    test_dir = r'insert your path'  # placeholder: point this at the clip directory
    test_files = [os.path.join(test_dir, f"{name}.pt") for name in test_names]

    strategies = ["all"]

    for clip_path in test_files:
        if not os.path.exists(clip_path):
            print(f"File not found: {clip_path}")
            continue

        print(f"\n=== Testing file: {os.path.basename(clip_path)} ===")
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        clip = torch.load(clip_path)
        pose_aa = clip[0]
        tran = clip[1]
        print(f"Input pose shape: {pose_aa.shape}, tran shape: {tran.shape}")

        for strategy in strategies:
            print(f"\n--- Using {strategy} sampling strategy ---")
            score = cal_improved_diversity(pose_aa, tran, strategy)
            print(f"Final diversity score: {score}")