import os
import smplx
import torch
import numpy as np
import cv2
import librosa
import igraph
import json
import utils.rotation_conversions as rc
from moviepy.editor import VideoClip, AudioFileClip
from tqdm import tqdm
import imageio
import tempfile

def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device='cuda'):
    """Compute per-joint motion features for a batch of SMPL-X pose sequences.

    Poses are fed through ``smplx_model`` to get joint positions. Shape,
    expression, translation, jaw, eye and global orientation inputs are all
    zero. Velocities are then estimated with finite differences over time.

    Args:
        motion_tensor: Tensor of shape ``(bs, n, 165)``, i.e. 55 joints x 3
            axis-angle values per frame. ``n`` must be at least 2 for the
            finite differences to be defined.
        smplx_model: SMPL-X model callable. It must accept 300 betas and 100
            expression coefficients (the sizes of the zero inputs below).
        pose_fps: Frame rate of the sequence, used to scale the velocities.
        device: Device to which the poses are moved and on which the zero
            model inputs are created.

    Returns:
        dict with:
            ``position``         (bs, n, 55, 3) joint positions,
            ``velocity``         (bs, n, 55, 3) joint-position velocities,
            ``rotation``         (bs, n, 55, 6) 6D rotation representation,
            ``axis_angle``       the input as float on ``device``, (bs, n, 165),
            ``angular_velocity`` (bs, n, 55, 3) finite difference of the raw
                                 axis-angle values,
            ``rep15d``           (bs, n, 55 * 15) per-joint concatenation of
                                 position, velocity, rotation and angular
                                 velocity.
    """
    bs, n, _ = motion_tensor.shape
    motion_tensor = motion_tensor.float().to(device)
    motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165)
    num_poses = bs * n
    dt = 1 / pose_fps

    def time_derivative(x):
        # Finite differences along the time axis (dim 1): one-sided at the
        # first and last frame, central in between.
        first = (x[:, 1:2] - x[:, 0:1]) / dt
        middle = (x[:, 2:] - x[:, :-2]) / (2 * dt)
        last = (x[:, -1:] - x[:, -2:-1]) / dt
        return torch.cat([first, middle, last], dim=1)

    output = smplx_model(
        betas=torch.zeros(num_poses, 300, device=device),
        transl=torch.zeros(num_poses, 3, device=device),
        expression=torch.zeros(num_poses, 100, device=device),
        jaw_pose=torch.zeros(num_poses, 3, device=device),
        global_orient=torch.zeros(num_poses, 3, device=device),
        body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(num_poses, 3, device=device),
        reye_pose=torch.zeros(num_poses, 3, device=device),
    )

    # Keep the first 55 joints. The total number of joints returned depends on
    # how the model was configured, so it is inferred rather than hard-coded.
    joints = output['joints'].reshape(bs, n, -1, 3)[:, :, :55, :]
    position = joints
    vel = time_derivative(joints)

    rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)
    angular_velocity = time_derivative(motion_tensor).reshape(bs, n, 55, 3)

    rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)

    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }


# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Default SMPL-X body model used by get_motion_reps. The model directory is a
# relative path, so it is resolved against the current working directory.
# 300 betas / 100 expression coefficients match the zero tensors passed to it.
smplx_model = smplx.create(
    "./beat2/smplx_models/",
    model_type='smplx',
    gender='NEUTRAL_2020',
    ext='npz',
    num_betas=300,
    num_expression_coeffs=100,
    use_face_contour=False,
    use_pca=False,
).to(device).eval()

def get_motion_reps(motion, smplx_model=smplx_model, pose_fps=30):
    """Compute per-joint motion features for a single SMPL-X pose sequence.

    This is the NumPy counterpart of ``get_motion_reps_tensor`` for one clip
    (batch size 1).

    Args:
        motion: Mapping, e.g. an ``np.load``-ed ``.npz`` archive. Its
            ``"poses"`` entry is an array of shape ``(n, 165)``: 55 joints x 3
            axis-angle values per frame, with ``n >= 2``.
        smplx_model: SMPL-X model; defaults to the module-level model.
        pose_fps: Frame rate used to scale the finite-difference velocities.

    Returns:
        dict of NumPy arrays:
            ``position``         (n, 55, 3) joint positions,
            ``velocity``         (n, 55, 3) joint-position velocities,
            ``rotation``         (n, 55, 6) 6D rotation representation,
            ``axis_angle``       the ``"poses"`` array as loaded,
            ``angular_velocity`` (n, 55, 3) finite difference of the poses,
            ``rep15d``           (n, 55 * 15) per-joint concatenation of the
                                 four features above.
    """
    # Read the poses once: indexing an np.load-ed .npz archive re-reads and
    # decompresses the array on every access.
    poses = motion["poses"]
    n = poses.shape[0]
    dt = 1 / pose_fps

    def time_derivative(x):
        # Finite differences along axis 0: one-sided at the first and last
        # frame, central in between.
        first = (x[1:2] - x[0:1]) / dt
        middle = (x[2:] - x[:-2]) / (2 * dt)
        last = (x[-1:] - x[-2:-1]) / dt
        return np.concatenate([first, middle, last], axis=0)

    poses_tensor = torch.from_numpy(poses).float().to(device)
    flat_poses = poses_tensor.reshape(n, 165)

    # Pure inference: results are converted to NumPy, so no autograd graph is needed.
    with torch.no_grad():
        output = smplx_model(
            betas=torch.zeros(n, 300, device=device),
            transl=torch.zeros(n, 3, device=device),
            expression=torch.zeros(n, 100, device=device),
            jaw_pose=torch.zeros(n, 3, device=device),
            global_orient=torch.zeros(n, 3, device=device),
            body_pose=flat_poses[:, 3:21 * 3 + 3],
            left_hand_pose=flat_poses[:, 25 * 3:40 * 3],
            right_hand_pose=flat_poses[:, 40 * 3:55 * 3],
            return_joints=True,
            leye_pose=torch.zeros(n, 3, device=device),
            reye_pose=torch.zeros(n, 3, device=device),
        )
        rot_matrices = rc.axis_angle_to_matrix(poses_tensor.reshape(1, n, 55, 3))[0]
        rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    # Keep the first 55 joints. The total number of joints returned depends on
    # the model configuration, so it is inferred rather than hard-coded.
    joints = output["joints"].detach().cpu().numpy().reshape(n, -1, 3)[:, :55, :]
    position = joints
    vel = time_derivative(joints)
    angular_velocity = time_derivative(poses).reshape(n, 55, 3)

    rep15d = np.concatenate(
        [position, vel, rot6d, angular_velocity],
        axis=2,
    ).reshape(n, 55 * 15)
    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": poses,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }

def create_graph(json_path):
    """Build a directed igraph graph with one vertex per aligned motion/video frame.

    ``json_path`` holds a JSON list of clips. Each entry provides
    ``video_path``, ``audio_path`` and ``motion_path`` without extensions
    (``.mp4``, ``.wav`` and ``.npz`` are appended) and optionally ``video_id``.
    Each clip is truncated to the shorter of its video and motion sequences.

    Vertex attributes:
        ``idx``       global vertex index,
        ``name``      clip's ``video_id`` ("" if absent),
        ``motion``    the clip's full motion-representation dict,
        ``position``, ``velocity``, ``axis_angle``  per-frame motion features,
        ``video``     decoded video frame,
        ``previous``/``next``  global index of the neighbouring frame of the
                      same clip, or -1 at the clip's first/last frame,
        ``frame``     index of the frame within its clip,
        ``fps``       30.

    No edges are added here; see ``create_edges``.
    """
    fps = 30
    with open(json_path, "r") as meta_file:
        data_meta = json.load(meta_file)
    graph = igraph.Graph(directed=True)
    global_i = 0
    for data_item in data_meta:
        video_path = data_item['video_path'] + ".mp4"
        audio_path = data_item['audio_path'] + ".wav"
        motion_path = data_item['motion_path'] + ".npz"
        video_id = data_item.get("video_id", "")

        # The returned features are standalone arrays, so the archive can be
        # closed right after they are computed.
        with np.load(motion_path, allow_pickle=True) as motion:
            motion_reps = get_motion_reps(motion)
        position = motion_reps['position']
        velocity = motion_reps['velocity']
        axis_angle = motion_reps['axis_angle']

        # NOTE(review): the audio is loaded and resampled to 16 kHz but is not
        # attached to any vertex yet (the ``audio=`` attribute is disabled).
        audio, sr = librosa.load(audio_path, sr=None)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

        reader = imageio.get_reader(video_path)
        try:
            video_frames = np.array(list(reader))
        finally:
            reader.close()

        min_frames = min(len(video_frames), position.shape[0])
        position = position[:min_frames]
        velocity = velocity[:min_frames]
        video_frames = video_frames[:min_frames]

        for i in tqdm(range(min_frames)):
            # The two boundary checks are independent, so a single-frame clip
            # gets -1 on both sides instead of linking into the next clip.
            previous = -1 if i == 0 else global_i - 1
            next_node = -1 if i == min_frames - 1 else global_i + 1
            graph.add_vertex(
                idx=global_i,
                name=video_id,
                motion=motion_reps,
                position=position[i],
                velocity=velocity[i],
                axis_angle=axis_angle[i],
                video=video_frames[i],
                previous=previous,
                next=next_node,
                frame=i,
                fps=fps,
            )
            global_i += 1
    return graph

def create_edges(graph):
    """Add transition edges between frames with similar pose and velocity.

    For every vertex ``i``:

    * Thresholds are the mean position and velocity distances between ``i``
      and the frames up to 4 steps before and after it within the same clip.
      A distance is the Euclidean norm of the flattened difference of the
      per-frame arrays. A vertex with no such neighbour gets no outgoing edges.
    * If ``j`` is the vertex's ``previous`` or ``next`` frame, an edge
      ``i -> j`` is added with ``is_continue=1``.
    * Otherwise, an edge ``i -> j`` is added with ``is_continue=0`` when both
      distances are strictly below the thresholds.

    Vertices of a clip are assumed to be stored contiguously in frame order,
    as ``create_graph`` produces them. Prints degree statistics and returns the
    mutated graph.
    """
    adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
    num_nodes = len(graph.vs)
    edges = []
    edge_flags = []

    if num_nodes:
        # Flatten the per-frame arrays once, so that all distances from vertex i
        # come out of a single vectorised norm call.
        positions = np.stack([np.asarray(p).reshape(-1) for p in graph.vs['position']])
        velocities = np.stack([np.asarray(v).reshape(-1) for v in graph.vs['velocity']])
        frames = graph.vs['frame']
        names = graph.vs['name']
        previous_ids = graph.vs['previous']
        next_ids = graph.vs['next']

        for i in range(num_nodes):
            # A temporal neighbour must belong to the same clip. Its frame
            # counter has to differ by exactly the offset, which rejects frames
            # that only sit next to i in vertex order but come from an adjacent
            # clip.
            neighbours = [
                i + offset
                for offset in adaptive_length
                if 0 <= i + offset < num_nodes
                and frames[i + offset] - frames[i] == offset
                and names[i + offset] == names[i]
            ]
            if not neighbours:
                continue

            pos_dist = np.linalg.norm(positions - positions[i], axis=1)
            vel_dist = np.linalg.norm(velocities - velocities[i], axis=1)
            threshold_position = pos_dist[neighbours].mean()
            threshold_velocity = vel_dist[neighbours].mean()

            similar = (pos_dist < threshold_position) & (vel_dist < threshold_velocity)
            targets = set(np.flatnonzero(similar).tolist())
            linked = {j for j in (previous_ids[i], next_ids[i]) if 0 <= j < num_nodes}
            targets |= linked
            targets.discard(i)
            for j in sorted(targets):
                edges.append((i, j))
                edge_flags.append(1 if j in linked else 0)

    # One batched insertion is much cheaper than repeated add_edge calls.
    if edges:
        graph.add_edges(edges, attributes={"is_continue": edge_flags})

    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    if num_nodes == 0:
        return graph
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return graph

def random_walk(graph, walk_length, start_node=None):
    """Sample a random walk of at most ``walk_length`` steps along out-edges.

    Args:
        graph: Directed igraph graph whose edges carry an ``is_continue`` attribute.
        walk_length: Maximum number of steps. The walk stops early at a vertex
            with no outgoing neighbours.
        start_node: Vertex to start from. When None, a vertex is drawn with
            ``np.random.choice``.

    Returns:
        ``(walk, is_continue)``. ``walk`` lists the visited vertices.
        ``is_continue`` holds, for each vertex, the flag of the edge used to
        reach it; it is 1 for the start vertex.
    """
    current = np.random.choice(graph.vs) if start_node is None else start_node
    visited = [current]
    flags = [1]
    for _step in range(walk_length):
        successors = graph.neighbors(current.index, mode='OUT')
        if not successors:
            break
        chosen = np.random.choice(successors)
        edge = graph.es[graph.get_eid(current.index, chosen)]
        current = graph.vs[chosen]
        visited.append(current)
        flags.append(edge['is_continue'])
    return visited, flags

def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """Render the video frames of a walk to ``save_path`` (libx264 video, AAC audio).

    Args:
        graph: Graph the walk came from. ``graph.vs[0]['fps']`` sets the frame rate.
        path: Sequence of vertices, e.g. from ``random_walk``. Each vertex's
            ``video`` frame is cropped with ``[:, 510:1230]``.
        is_continue: Continuity flags matching ``path``. Must be non-empty.
        save_path: Output video file.
        verbose_continue: If True, print the fraction of discontinuous steps.
        audio_path: Optional audio file to attach to the video.
        return_motion: If True, return the stacked ``axis_angle`` of the walk.

    Returns:
        The stacked ``axis_angle`` arrays when ``return_motion`` is True,
        otherwise None.
    """
    fps = graph.vs[0]['fps']
    all_frames = [node['video'][:, 510:1230] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    duration = len(all_frames) / fps

    def make_frame(t):
        # The small epsilon keeps t = k / fps from flooring to k - 1 because of
        # floating-point error in t * fps.
        idx = min(int(t * fps + 1e-6), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    audio_clip = None
    try:
        if audio_path is not None:
            audio_clip = AudioFileClip(audio_path)
            video_clip = video_clip.set_audio(audio_clip)
        video_clip.write_videofile(save_path, codec='libx264', fps=fps, audio_codec='aac')
    finally:
        # Release the reader that AudioFileClip keeps open.
        if audio_clip is not None:
            audio_clip.close()

    if return_motion:
        return np.stack([node['axis_angle'] for node in path], 0)



def generate_transition_video(
    frame_start_path,
    frame_end_path,
    output_video_path,
    *,
    model_path="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/frame-interpolation-pytorch/film_net_fp32.pt",
    inference_script="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/frame-interpolation-pytorch/inference.py",
    num_frames=3,
    fps=30,
    use_gpu=True,
    python_executable=None,
):
    """Interpolate between two images by running an external FiLM inference script.

    The command is
    ``<python> <inference_script> <model_path> <start> <end> --save_path <out>
    [--gpu] --frames <num_frames> --fps <fps>``.
    A failing script is reported on stdout instead of raising, so callers
    should check the return value or the output file before reading it.

    Args:
        frame_start_path, frame_end_path: Paths of the two key-frame images.
        output_video_path: Where the script is asked to save its video.
        model_path: Model file passed as the script's first argument.
        inference_script: Path of the inference script.
        num_frames: Value passed as ``--frames``.
        fps: Value passed as ``--fps``.
        use_gpu: Whether to pass ``--gpu``.
        python_executable: Interpreter used to run the script. Defaults to the
            interpreter running this code (``sys.executable``), so the script
            sees the same installed packages instead of whatever ``python`` is
            on PATH.

    Returns:
        True if the script exited with status 0, False otherwise.
    """
    import subprocess
    import sys

    command = [
        python_executable or sys.executable,
        inference_script,
        model_path,
        frame_start_path,
        frame_end_path,
        "--save_path", output_video_path,
    ]
    if use_gpu:
        command.append("--gpu")
    command += ["--frames", str(num_frames), "--fps", str(fps)]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while generating transition video: {e}")
        return False
    print(f"Generated transition video saved at {output_video_path}")
    return True


def _select_blend_positions(is_continue, num_frames):
    """Return the discontinuity indices that frame interpolation can smooth.

    ``is_continue[i] == 0`` marks a jump into frame ``i``, following the
    convention of ``random_walk``. A jump is kept when:

    * both anchor frames ``i - 2`` and ``i + 2`` exist, and
    * the frames it would replace (``i - 1 .. i + 1``) do not overlap those of
      an earlier kept jump.
    """
    blend_positions = []
    processed_frames = set()
    for i, cont in enumerate(is_continue):
        if cont != 0:
            continue
        if i - 2 < 0 or i + 2 >= num_frames:
            continue  # an anchor frame would fall outside the path
        window = range(i - 1, i + 2)
        if any(idx in processed_frames for idx in window):
            continue  # would overwrite frames of an earlier blend
        processed_frames.update(window)
        blend_positions.append(i)
    return blend_positions


def _interpolate_transition(all_frames, i, temp_dir):
    """Replace ``all_frames[i - 1:i + 2]`` in place with FiLM-interpolated frames.

    The interpolation runs between anchors ``all_frames[i - 2]`` and
    ``all_frames[i + 2]``. If no usable video is produced, the frames are left
    untouched and the reason is printed.
    """
    start_frame_idx = i - 2
    end_frame_idx = i + 2
    frame_start_path = os.path.join(temp_dir, f'frame_{start_frame_idx}.png')
    frame_end_path = os.path.join(temp_dir, f'frame_{end_frame_idx}.png')
    imageio.imwrite(frame_start_path, all_frames[start_frame_idx])
    imageio.imwrite(frame_end_path, all_frames[end_frame_idx])

    generated_video_path = os.path.join(temp_dir, f'generated_{start_frame_idx}_{end_frame_idx}.mp4')
    generate_transition_video(frame_start_path, frame_end_path, generated_video_path)
    # generate_transition_video reports failures instead of raising, so check
    # that the video exists before trying to read it.
    if not os.path.exists(generated_video_path):
        print(f"Transition video was not generated. Skipping blending at position {i}.")
        return

    reader = imageio.get_reader(generated_video_path)
    try:
        generated_frames = list(reader)
    finally:
        reader.close()

    total_generated_frames = len(generated_frames)
    if total_generated_frames < 5:
        print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
        return
    # Take the three frames around the middle of the generated clip.
    middle_start = total_generated_frames // 2 - 1
    middle_frames = generated_frames[middle_start:middle_start + 3]
    # The external script re-encodes the frames and may change their size.
    # Mixing frame sizes would break the final clip, so skip in that case.
    expected_shape = all_frames[start_frame_idx].shape
    if any(frame.shape != expected_shape for frame in middle_frames):
        print(f"Generated frames do not match the source frame size. Skipping blending at position {i}.")
        return
    all_frames[i - 1:i + 2] = middle_frames


def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """Render a walk like ``path_visualization``, smoothing jumps with FiLM interpolation.

    At each selected discontinuity ``i``, frames ``i - 1 .. i + 1`` are
    replaced by frames interpolated between frames ``i - 2`` and ``i + 2``.
    Only the rendered video changes; the returned motion is not modified.

    Args:
        graph: Graph the walk came from. ``graph.vs[0]['fps']`` sets the frame rate.
        path: Sequence of vertices. Each vertex's ``video`` frame is cropped
            with ``[:, 510:1230]``.
        is_continue: Continuity flags matching ``path``. Must be non-empty.
        save_path: Output video file (libx264 video, AAC audio).
        verbose_continue: If True, print the fraction of discontinuous steps.
        audio_path: Optional audio file to attach to the video.
        return_motion: If True, return the stacked ``axis_angle`` of the walk.

    Returns:
        The stacked ``axis_angle`` arrays when ``return_motion`` is True,
        otherwise None.
    """
    fps = graph.vs[0]['fps']
    all_frames = [node['video'][:, 510:1230] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    duration = len(all_frames) / fps

    blend_positions = _select_blend_positions(is_continue, len(all_frames))
    # Key-frame images and generated videos are deleted once blending is done.
    with tempfile.TemporaryDirectory(prefix='blending_frames_') as temp_dir:
        for i in tqdm(blend_positions):
            _interpolate_transition(all_frames, i, temp_dir)

    def make_frame(t):
        # The small epsilon keeps t = k / fps from flooring to k - 1 because of
        # floating-point error in t * fps.
        idx = min(int(t * fps + 1e-6), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    audio_clip = None
    try:
        if audio_path is not None:
            audio_clip = AudioFileClip(audio_path)
            video_clip = video_clip.set_audio(audio_clip)
        video_clip.write_videofile(save_path, codec='libx264', fps=fps, audio_codec='aac')
    finally:
        # Release the reader that AudioFileClip keeps open.
        if audio_clip is not None:
            audio_clip.close()

    if return_motion:
        return np.stack([node['axis_angle'] for node in path], 0)


def graph_pruning(graph):
    """Return the largest strongly connected component of ``graph``.

    The input graph is not modified. ``giant()`` renumbers the surviving
    vertices, so the ``previous``/``next`` links are translated to the new
    numbering; these links hold vertex indices of the input graph. Links that
    pointed to removed vertices become -1. Prints degree statistics of the
    result.
    """
    work = graph.copy()
    # Record each vertex's index in the input graph so links can be remapped
    # after the renumbering done by giant().
    work.vs["_pre_prune_index"] = list(range(work.vcount()))
    if hasattr(work, "connected_components"):
        components = work.connected_components(mode="STRONG")
    else:  # python-igraph < 0.10 only offers the older name
        components = work.clusters(mode="STRONG")
    graph = components.giant()

    old_to_new = {old: new for new, old in enumerate(graph.vs["_pre_prune_index"])}
    for attr in ("previous", "next"):
        if attr in graph.vertex_attributes():
            graph.vs[attr] = [old_to_new.get(old, -1) for old in graph.vs[attr]]
    del graph.vs["_pre_prune_index"]

    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return graph

if __name__ == '__main__':
    # Demo: load a pre-built motion graph, sample a 10-step walk and render it.
    # To build the graph from scratch instead of loading the pickle:
    #   graph = create_edges(create_graph('./datasets/data_json/show-oliver-test.json'))
    # Optionally keep only the largest strongly connected component:
    #   graph = graph_pruning(graph)
    # A built graph can be saved with graph.write_pickle(fname=...).
    pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl"
    motion_graph = igraph.Graph.Read_Pickle(fname=pool_path)
    walk_nodes, continuity = random_walk(motion_graph, 10)
    sampled_motion = path_visualization_v2(
        motion_graph,
        walk_nodes,
        continuity,
        "/content/test.mp4",
        audio_path=None,
        verbose_continue=True,
        return_motion=True,
    )
    print(sampled_motion.shape)