import torch
import os
import tqdm
import numpy as np
from typing import List, Dict, Any
import sys
import logging
import argparse
import multiprocessing as mp
from multiprocessing import Pool
import time


# Raise the root logger threshold so only WARNING and above are emitted.
logging.getLogger().setLevel(logging.WARNING)

# Put this file's own directory on sys.path before the imports that follow;
# presumably cal_seq_debug / articulate are resolved from there when the
# script is launched from another working directory — TODO confirm.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from cal_seq_debug import ImprovedMotionDiversity, cal_improved_diversity, Body
import articulate as art


def process_single_segment(args):
    """Score one motion segment and save it, with its scores, to a ``.pt`` file.

    The arguments arrive as a single tuple so the function can be mapped
    directly over a task list (e.g. by ``Pool.imap_unordered``).

    Args:
        args: ``(seq_idx, seg_idx, segment_pose, segment_tran, start_frame,
            end_frame, output_path)``.

    Returns:
        ``{'success': True, 'filepath': ..., 'seq_idx': ..., 'seg_idx': ...}``
        when the file was written, otherwise
        ``{'success': False, 'error': ..., 'seq_idx': ..., 'seg_idx': ...}``.
        Exceptions raised while scoring or saving are never propagated; they
        are reported through the ``'error'`` string instead.
    """
    seq_idx, seg_idx, segment_pose, segment_tran, start_frame, end_frame, output_path = args

    try:
        results = cal_improved_diversity_with_metrics(segment_pose, segment_tran, sampling_strategy="all")
        enhanced = results['enhanced_metrics']

        # The callers in this file pass slices of a whole sequence.  torch.save
        # serialises the complete storage behind a tensor view, so without a
        # copy every segment file would carry the full sequence.  Cloning makes
        # the saved tensors own exactly their own data.
        pose_to_save = segment_pose.clone() if torch.is_tensor(segment_pose) else segment_pose
        tran_to_save = segment_tran.clone() if torch.is_tensor(segment_tran) else segment_tran

        output_data = {
            'pose': pose_to_save,
            'tran': tran_to_save,
            'idx': [seq_idx, start_frame, end_frame],
            'scores': {
                'Motion Dynamics': results['motion_dynamics'],
                'Sampling Strategy': results['sampling_strategy'],
                'Spectral Diversity': enhanced['spectral_diversity'],
                'Variance Diversity': enhanced['variance_diversity'],
                'Segment Diversity': enhanced['segment_diversity'],
                'Dynamic Range': enhanced['dynamic_range'],
                'Final Score': results['final_score'],
            },
        }

        # The truncated score is embedded in the file name:
        # <seq>_<start>_<end>_<score>.pt
        final_score_int = int(results['final_score'])
        filename = f"{seq_idx:05d}_{start_frame:05d}_{end_frame:05d}_{final_score_int:04d}.pt"
        filepath = os.path.join(output_path, filename)

        torch.save(output_data, filepath)

        return {'success': True, 'filepath': filepath, 'seq_idx': seq_idx, 'seg_idx': seg_idx}

    except Exception as e:
        # Worker boundary: one bad segment must not abort the whole run, so the
        # failure is handed back to the caller.  The exception type is included
        # because str(e) alone is often meaningless (e.g. a bare KeyError key).
        return {
            'success': False,
            'error': f"{type(e).__name__}: {e}",
            'seq_idx': seq_idx,
            'seg_idx': seg_idx,
        }


def process_amass_segments_parallel(amass_path: str, start_idx: int, end_idx: int,
                                    output_path: str, segment_length: int = 100,
                                    num_processes: int = None):
    """Segment AMASS sequences and score/save the segments in a process pool.

    Loads ``pose.pt`` and ``tran.pt`` from ``amass_path``, cuts every sequence
    with index in ``[start_idx, end_idx)`` into chunks of ``segment_length``
    frames, drops chunks shorter than 10 frames and maps
    ``process_single_segment`` over the rest.

    Args:
        amass_path: Directory containing ``pose.pt`` and ``tran.pt``.
        start_idx: Index of the first sequence to process.
        end_idx: One past the last sequence; clamped to the number of sequences.
        output_path: Directory for the per-segment files (created if missing).
        segment_length: Frames per segment; must be positive.
        num_processes: Worker count; defaults to ``mp.cpu_count()``.

    Raises:
        ValueError: If ``segment_length`` is not positive, or the two files
            hold a different number of sequences.
        FileNotFoundError: If either input file is missing.
    """
    if segment_length <= 0:
        # Previously 0 surfaced as ZeroDivisionError and negative values
        # silently produced no segments at all.
        raise ValueError(f"segment_length must be positive, got {segment_length}")

    os.makedirs(output_path, exist_ok=True)

    pose_file = os.path.join(amass_path, 'pose.pt')
    tran_file = os.path.join(amass_path, 'tran.pt')

    if not os.path.exists(pose_file) or not os.path.exists(tran_file):
        raise FileNotFoundError(f"error: {amass_path}")

    # NOTE: torch.load unpickles its input; only point this at trusted files.
    all_poses = torch.load(pose_file)
    all_trans = torch.load(tran_file)

    if len(all_poses) != len(all_trans):
        raise ValueError("error")

    if end_idx > len(all_poses):
        # An end index past the data is clamped rather than treated as fatal.
        end_idx = len(all_poses)
        print(f"error{end_idx}")

    tasks = []
    for seq_idx in range(start_idx, end_idx):
        seq_pose = all_poses[seq_idx]
        seq_tran = all_trans[seq_idx]

        if seq_pose is None or seq_tran is None:
            continue

        if len(seq_pose) == 0 or len(seq_tran) == 0:
            continue

        # Truncate both streams to their common length so frames stay aligned.
        min_length = min(len(seq_pose), len(seq_tran))
        seq_pose = seq_pose[:min_length]
        seq_tran = seq_tran[:min_length]

        for seg_idx, start_frame in enumerate(range(0, min_length, segment_length)):
            end_frame = min(start_frame + segment_length, min_length)

            segment_pose = seq_pose[start_frame:end_frame]
            segment_tran = seq_tran[start_frame:end_frame]

            # Segments shorter than 10 frames (typically the tail) are skipped.
            if len(segment_pose) < 10:
                continue

            tasks.append((seq_idx, seg_idx, segment_pose, segment_tran,
                          start_frame, end_frame, output_path))

    if num_processes is None:
        num_processes = mp.cpu_count()

    successful_tasks = 0
    failed_tasks = 0

    # Nothing to do -> do not pay for spawning a pool.
    if tasks:
        with Pool(processes=num_processes) as pool:
            with tqdm.tqdm(total=len(tasks), desc="processing", unit="segment") as pbar:
                for result in pool.imap_unordered(process_single_segment, tasks):
                    if result['success']:
                        successful_tasks += 1
                    else:
                        failed_tasks += 1
                        # tqdm.write keeps the bar intact; include which
                        # segment failed and why instead of a bare "fail".
                        tqdm.tqdm.write(
                            f"fail: seq {result['seq_idx']} seg {result['seg_idx']}: "
                            f"{result.get('error')}"
                        )

                    pbar.update(1)
                    pbar.set_postfix({
                        'succ': successful_tasks,
                        'fail': failed_tasks,
                        'proc': f"{successful_tasks + failed_tasks}/{len(tasks)}"
                    })

    print(f"complete! succ: {successful_tasks}, fail: {failed_tasks}")
    print(f"save: {output_path}")


def process_amass_segments_sequential(amass_path: str, start_idx: int, end_idx: int,
                                      output_path: str, segment_length: int = 100):
    """Segment AMASS sequences and score/save the segments one at a time.

    Single-process counterpart of ``process_amass_segments_parallel``: the same
    files are loaded, the same segmentation is applied (chunks shorter than 10
    frames are dropped), and each segment is handed to
    ``process_single_segment`` so both modes write identical files.

    Args:
        amass_path: Directory containing ``pose.pt`` and ``tran.pt``.
        start_idx: Index of the first sequence to process.
        end_idx: One past the last sequence; clamped to the number of sequences.
        output_path: Directory for the per-segment files (created if missing).
        segment_length: Frames per segment; must be positive.

    Raises:
        ValueError: If ``segment_length`` is not positive, or the two files
            hold a different number of sequences.
        FileNotFoundError: If either input file is missing.
    """
    if segment_length <= 0:
        # Previously 0 surfaced as ZeroDivisionError and negative values
        # silently produced no segments at all.
        raise ValueError(f"segment_length must be positive, got {segment_length}")

    os.makedirs(output_path, exist_ok=True)

    pose_file = os.path.join(amass_path, 'pose.pt')
    tran_file = os.path.join(amass_path, 'tran.pt')

    if not os.path.exists(pose_file) or not os.path.exists(tran_file):
        raise FileNotFoundError(f"error{amass_path}")

    # NOTE: torch.load unpickles its input; only point this at trusted files.
    all_poses = torch.load(pose_file)
    all_trans = torch.load(tran_file)

    if len(all_poses) != len(all_trans):
        raise ValueError("error")

    if end_idx > len(all_poses):
        # An end index past the data is clamped rather than treated as fatal.
        end_idx = len(all_poses)
        print(f"error{end_idx}")

    seq_progress = tqdm.tqdm(range(start_idx, end_idx), desc="processing", unit="seq")

    total_segments = 0      # every segment enumerated, including skipped ones
    successful_segments = 0
    failed_segments = 0

    for seq_idx in seq_progress:
        seq_pose = all_poses[seq_idx]
        seq_tran = all_trans[seq_idx]

        if seq_pose is None or seq_tran is None:
            continue

        if len(seq_pose) == 0 or len(seq_tran) == 0:
            continue

        # Truncate both streams to their common length so frames stay aligned.
        min_length = min(len(seq_pose), len(seq_tran))
        seq_pose = seq_pose[:min_length]
        seq_tran = seq_tran[:min_length]

        num_segments = (min_length + segment_length - 1) // segment_length
        total_segments += num_segments

        # The context manager closes the inner bar even if an iteration raises.
        with tqdm.tqdm(range(num_segments), desc=f"{seq_idx} proccessing",
                       unit="segment", leave=False, position=1) as segment_progress:
            for seg_idx in segment_progress:
                start_frame = seg_idx * segment_length
                end_frame = min(start_frame + segment_length, min_length)

                segment_pose = seq_pose[start_frame:end_frame]
                segment_tran = seq_tran[start_frame:end_frame]

                # Segments shorter than 10 frames (typically the tail) are skipped.
                if len(segment_pose) < 10:
                    continue

                # Same worker as the parallel path; it never raises and
                # reports failures through the returned dict.
                result = process_single_segment(
                    (seq_idx, seg_idx, segment_pose, segment_tran,
                     start_frame, end_frame, output_path)
                )

                if result['success']:
                    successful_segments += 1
                else:
                    failed_segments += 1
                    # tqdm.write keeps the bars intact; say which segment
                    # failed and why instead of a bare "fail".
                    tqdm.tqdm.write(
                        f"fail: seq {result['seq_idx']} seg {result['seg_idx']}: "
                        f"{result.get('error')}"
                    )

    print(f"complete！total: {total_segments} ")
    print(f"succ: {successful_segments}, fail: {failed_segments}")
    print(f"save: {output_path}")


def cal_improved_diversity_with_metrics(pose, tran, sampling_strategy="all"):
    """Run ``ImprovedMotionDiversity.compute_enhanced_diversity`` on one clip.

    Args:
        pose: Axis-angle pose data; dimension 0 is used as the frame count and
            dimension 1 as the joint count.  The reshape below hard-codes 24
            joints, so presumably the layout is (frames, 24, 3) — TODO confirm.
        tran: Translation data, forwarded unchanged to the scorer.
        sampling_strategy: Forwarded unchanged to the scorer.

    Returns:
        Whatever ``compute_enhanced_diversity`` returns.
    """
    frame_count, joint_count = pose.shape[0], pose.shape[1]
    scorer = ImprovedMotionDiversity(joint_count, frame_count, Body)

    # Flatten to one axis-angle vector per row, convert, then regroup the
    # resulting 3x3 matrices per frame.
    flat_axis_angle = pose.reshape(-1, 3)
    rotation_matrices = art.math.axis_angle_to_rotation_matrix(flat_axis_angle)
    rotation_matrices = rotation_matrices.reshape(-1, 24, 3, 3)

    return scorer.compute_enhanced_diversity(rotation_matrices, tran, sampling_strategy)


def _select_frame_indices(total_frames, sampling_strategy):
    """Return the list of frame indices to evaluate for ``sampling_strategy``.

    ``"all"`` keeps every frame, ``"sparse"`` takes every
    ``max(1, total_frames // 10)``-th frame, and any other value keeps at most
    20 evenly stepped frames.
    """
    if sampling_strategy == "all":
        return list(range(total_frames))

    if sampling_strategy == "sparse":
        step = max(1, total_frames // 10)
        return list(range(0, total_frames, step))

    max_frames = min(20, total_frames)
    if total_frames <= max_frames:
        return list(range(total_frames))
    step = total_frames // max_frames
    return list(range(0, total_frames, step))[:max_frames]


def patch_mds_progress_bar():
    """Monkey-patch ``ImprovedMotionDiversity.compute_enhanced_diversity``.

    The replacement iterates the selected frames under a tqdm progress bar,
    stacks the per-frame Jacobians, derives the diversity metrics and combines
    three of them into ``final_score`` with weights that depend on
    ``motion_dynamics``.  The patch is applied to the class, so it affects all
    instances in the current process.
    """

    def patched_compute_enhanced_diversity(self, pose_rot, trans, sampling_strategy="uniform"):
        """Compute the metrics dict for one clip; see the enclosing docstring.

        Returns a dict with the keys ``final_score``, ``enhanced_metrics``,
        ``motion_dynamics`` and ``sampling_strategy``.
        """
        import tqdm

        # self.T is presumably the clip's frame count — TODO confirm.
        frames_to_process = _select_frame_indices(self.T, sampling_strategy)

        J_sequence = []

        # Context manager: the bar is closed even when a frame raises, so a
        # failing clip does not leave a dangling progress bar behind.
        with tqdm.tqdm(frames_to_process, desc="计算雅可比矩阵", unit="帧",
                       bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as progress_bar:
            for t in progress_bar:
                J_frame = self.compute_torque_jacobian_with_updated_dynamics(pose_rot, trans, t)
                J_sequence.append(J_frame)

        J_all = torch.stack(J_sequence)

        enhanced_metrics = self.improved_diversity_metrics(J_all)

        motion_dynamics = self.compute_motion_dynamics(pose_rot, trans)

        # Above the 0.5 threshold the spectral term gets the larger weight,
        # otherwise the variance term does.
        if motion_dynamics > 0.5:
            weights = {'spectral': 0.4, 'variance': 0.3, 'segment': 0.3}
        else:
            weights = {'spectral': 0.3, 'variance': 0.4, 'segment': 0.3}

        final_score = (
                weights['spectral'] * enhanced_metrics['spectral_diversity'] +
                weights['variance'] * enhanced_metrics['variance_diversity'] +
                weights['segment'] * enhanced_metrics['segment_diversity']
        )

        return {
            'final_score': final_score,
            'enhanced_metrics': enhanced_metrics,
            'motion_dynamics': motion_dynamics,
            'sampling_strategy': sampling_strategy,
        }

    ImprovedMotionDiversity.compute_enhanced_diversity = patched_compute_enhanced_diversity


if __name__ == "__main__":
    # Command-line entry point: parse options, install the patched scorer,
    # run the selected processing mode and report the wall-clock time.
    parser = argparse.ArgumentParser(description="compute MDS of AMASS")
    for flag, options in (
        ("--amass_path", dict(type=str, required=True, help="AMASS path")),
        ("--start_idx", dict(type=int, default=0, help="start index")),
        ("--end_idx", dict(type=int, required=True, help="end index")),
        ("--output_path", dict(type=str, required=True, help="output path")),
        ("--segment_length", dict(type=int, default=100, help="seq length")),
        ("--num_processes", dict(type=int, default=None, help="num_processes")),
        ("--sequential", dict(action="store_true", help="not parallel")),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Must happen before any segment is scored.
    patch_mds_progress_bar()

    start_time = time.time()

    # Options common to both processing modes.
    shared_kwargs = {
        'amass_path': args.amass_path,
        'start_idx': args.start_idx,
        'end_idx': args.end_idx,
        'output_path': args.output_path,
        'segment_length': args.segment_length,
    }

    if not args.sequential:
        process_amass_segments_parallel(num_processes=args.num_processes, **shared_kwargs)
    else:
        process_amass_segments_sequential(**shared_kwargs)

    elapsed = time.time() - start_time
    print(f"total time: {elapsed:.2f} sec")