import json
import random
from typing import List
import cv2
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
import torchvision.transforms.functional as TF

# Body skeleton limbs as pairs of 1-based joint indices (np_get_flow subtracts 1
# before indexing into the 18 body joints). The layout appears to follow the
# OpenPose 18-keypoint body model -- TODO confirm against the pose extractor.
limbSeq = [
    [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8],
    [2, 9], [9, 10], [10, 11],
    [2, 12], [12, 13], [13, 14],
    [2, 1], [1, 15], [15, 17], [1, 16], [16, 18],
]
# Hand skeleton bones as pairs of 0-based indices into a 21-point hand:
# point 0 is the root shared by five chains of four bones each.
edges = [
    [0, 1], [1, 2], [2, 3], [3, 4],
    [0, 5], [5, 6], [6, 7], [7, 8],
    [0, 9], [9, 10], [10, 11], [11, 12],
    [0, 13], [13, 14], [14, 15], [15, 16],
    [0, 17], [17, 18], [18, 19], [19, 20],
]


def calculate_and_plot_optical_flow(
    frame1_index,
    frame2_index,
    keypoints,
    keypoints_type,
    color,
    title_prefix="",
    pose_id=None,
    scale=(784, 720),
):
    """
    Compute per-keypoint displacement ("flow") between two frames.

    Despite the name, nothing is plotted: ``color`` and ``title_prefix`` are
    accepted for backward compatibility and ignored.

    Args:
        frame1_index: Index of the source frame.
        frame2_index: Index of the target frame.
        keypoints: Mapping from keypoint-group name to an array indexed
            ``[frame][point]`` -> (x, y). Coordinates are multiplied by
            ``scale``; presumably they are normalized to [0, 1] -- TODO confirm.
        keypoints_type: Iterable of group names to process.
        color: Unused.
        title_prefix: Unused.
        pose_id: Array indexed ``[frame][0][joint]``; a value of -1 marks a
            body joint as missing. Only consulted for the ``'body'`` group.
            ``None`` treats every body joint as present (previously this
            raised ``TypeError``).
        scale: (x, y) multipliers applied to keypoint coordinates. The default
            reproduces the previously hard-coded ``[784, 720]``.

    Returns:
        ``(flow_all, kp1_all, kp2_all)``: three dicts keyed by group name
        holding ``kp2 - kp1``, ``kp1`` and ``kp2`` in scaled coordinates. For
        ``'body'`` only the first 18 joints are used, and a joint missing in
        either frame gets zeros in all three arrays.
    """
    scale = np.asarray(scale, dtype=np.float64)
    flow_all = {}
    kp1_all = {}
    kp2_all = {}
    for kp_type in keypoints_type:
        if kp_type == 'body':
            kp1 = np.asarray(keypoints[kp_type][frame1_index][:18]) * scale
            kp2 = np.asarray(keypoints[kp_type][frame2_index][:18]) * scale
            if pose_id is not None:
                valid = (np.asarray(pose_id[frame1_index][0][:18]) != -1) & (
                    np.asarray(pose_id[frame2_index][0][:18]) != -1
                )
                # Joints missing in either frame get zero position (and hence
                # zero flow), exactly as the former per-joint loop did.
                kp1 = np.where(valid[:, None], kp1, 0.0)
                kp2 = np.where(valid[:, None], kp2, 0.0)
        else:
            kp1 = keypoints[kp_type][frame1_index] * scale
            kp2 = keypoints[kp_type][frame2_index] * scale

        flow_all[kp_type] = kp2 - kp1
        kp1_all[kp_type] = kp1
        kp2_all[kp_type] = kp2
    return flow_all, kp1_all, kp2_all


def calculate_and_visualize_blended_flow(point1, point2, flow1, flow2, thickness=20, canvas_size=(720, 784)):
    """
    Rasterize the segment point1 -> point2 and fill it with blended flow.

    Every pixel painted by ``cv2.line`` with the given ``thickness`` receives
    ``flow1 * (1 - r) + flow2 * r``. Here ``r`` is the pixel's distance from
    ``point1`` divided by the segment length, clamped to [0, 1] so pixels on
    the line's rounded end caps do not extrapolate beyond the endpoint flows.
    A zero-length segment gets the mean of the two flows; the division by
    zero used to write NaN/inf into the field.

    Args:
        point1, point2: (x, y) segment endpoints in pixel coordinates; they are
            truncated to int32 for rasterization.
        flow1, flow2: (dx, dy) flow at ``point1`` and ``point2``.
        thickness: Line thickness in pixels passed to ``cv2.line``.
        canvas_size: (height, width) of the output field.

    Returns:
        float64 array of shape (height, width, 2); zero outside the segment.
    """
    height, width = canvas_size[0], canvas_size[1]
    flow_field = np.zeros((height, width, 2))

    # Rasterize with OpenCV; the painted pixels decide where flow is written.
    canvas = np.zeros((height, width), dtype=np.uint8)
    pt1 = tuple(int(v) for v in np.asarray(point1).astype(np.int32))
    pt2 = tuple(int(v) for v in np.asarray(point2).astype(np.int32))
    cv2.line(canvas, pt1, pt2, 255, thickness=thickness)

    ys, xs = np.nonzero(canvas == 255)
    if ys.size == 0:  # segment lies entirely off-canvas
        return flow_field

    p1 = np.asarray(point1, dtype=np.float64)
    p2 = np.asarray(point2, dtype=np.float64)
    seg_len = np.hypot(p2[0] - p1[0], p2[1] - p1[1])
    if seg_len > 0:
        ratio = np.clip(np.hypot(xs - p1[0], ys - p1[1]) / seg_len, 0.0, 1.0)
    else:
        ratio = np.full(ys.shape, 0.5)
    ratio = ratio[:, None]

    flow_field[ys, xs] = (1.0 - ratio) * np.asarray(flow1, dtype=np.float64) + ratio * np.asarray(
        flow2, dtype=np.float64
    )
    return flow_field

def circle_blended_flow(point1, flow1, thickness=20, canvas_size=(720, 784)):
    """
    Fill a disc around ``point1`` with the constant flow ``flow1``.

    Args:
        point1: (x, y) centre in pixel coordinates (truncated to int32).
        flow1: (dx, dy) flow written to every pixel of the disc.
        thickness: Disc radius in pixels (passed as ``radius`` to ``cv2.circle``).
        canvas_size: (height, width) of the output field.

    Returns:
        float64 array of shape (height, width, 2); zero outside the disc.
    """
    height, width = canvas_size[0], canvas_size[1]
    flow_field = np.zeros((height, width, 2))

    canvas = np.zeros((height, width), dtype=np.uint8)
    center = tuple(int(v) for v in np.asarray(point1).astype(np.int32))
    cv2.circle(canvas, center=center, radius=thickness, color=255, thickness=-1)

    # Write flow to every painted pixel. The former half-open scan window
    # [c - r, c + r) silently dropped the disc's last row and column.
    flow_field[canvas == 255] = flow1
    return flow_field

def np_get_flow(np_path, start, end):
    """
    Build a dense keypoint-driven flow field between two frames of a video.

    Keypoints are loaded from four ``.npy`` files whose names are ``np_path``
    minus its last four characters (presumably a ``.mp4``-style extension)
    plus ``_pose_pos``, ``_pose_id``, ``_hands_pos`` and ``_face_pos``.
    The sparse per-keypoint displacements from ``start`` to ``end`` are drawn
    onto a (720, 784) canvas. Body limbs and hand bones become thick segments
    with linearly blended flow, and face points become discs of constant flow.
    Overlapping strokes are summed, and the result is Gaussian-blurred.

    Args:
        np_path: Path whose stem locates the keypoint files.
        start: Source frame index.
        end: Target frame index.

    Returns:
        float64 array of shape (720, 784, 2) holding (dx, dy) in canvas pixels.
    """
    stem = np_path[:-4]
    pose_pos = np.load(stem + "_pose_pos.npy")
    pose_id = np.load(stem + "_pose_id.npy")
    hands_pos = np.load(stem + "_hands_pos.npy")
    face_pos = np.load(stem + "_face_pos.npy")

    keypoints_data = {
        'body': pose_pos,
        'left_hand': hands_pos[:, 0, :, :],
        'right_hand': hands_pos[:, 1, :, :],
        'face': face_pos[:, 0, :, :],
    }
    # Only forwarded because calculate_and_plot_optical_flow takes the argument.
    keypoints_color = {
        'body': 'blue',
        'left_hand': 'red',
        'right_hand': 'green',
        'face': 'yellow'
    }

    flow_all, kp1_all, _ = calculate_and_plot_optical_flow(
        start, end, keypoints_data, keypoints_data.keys(), keypoints_color,
        title_prefix="OpenPose ", pose_id=pose_id,
    )
    full_flow_field = np.zeros((720, 784, 2))

    # A body joint is usable only if pose_id marks it present (!= -1) in both
    # frames, the same test calculate_and_plot_optical_flow applies. The old
    # "flow == (0, 0)" check also dropped limbs whose joint just did not move.
    body_valid = (np.asarray(pose_id[start][0]) != -1) & (np.asarray(pose_id[end][0]) != -1)
    for joint_a, joint_b in limbSeq:
        point1_idx, point2_idx = joint_a - 1, joint_b - 1  # limbSeq is 1-based
        if not (body_valid[point1_idx] and body_valid[point2_idx]):
            continue
        full_flow_field += calculate_and_visualize_blended_flow(
            kp1_all['body'][point1_idx], kp1_all['body'][point2_idx],
            flow_all['body'][point1_idx], flow_all['body'][point2_idx],
        )

    for point1_idx, point2_idx in edges:
        for hand in ('left_hand', 'right_hand'):
            full_flow_field += calculate_and_visualize_blended_flow(
                kp1_all[hand][point1_idx], kp1_all[hand][point2_idx],
                flow_all[hand][point1_idx], flow_all[hand][point2_idx],
                thickness=10,
            )

    for i in range(68):
        full_flow_field += circle_blended_flow(
            kp1_all['face'][i],
            flow_all['face'][i],
            thickness=10,
        )

    return cv2.GaussianBlur(full_flow_field, (5, 5), 0)



# Example keypoint file (the "_pose_pos.npy" companion read by np_get_flow): '/data/Moore-AnimateAnyone/configs/inference/test_videos_dwpose/214438-00_07_16-00_07_26_pose_pos.npy'

class RandomResizedCropTensor(torch.nn.Module):
    """Resize a tensor image to ``size``.

    Despite the name, no cropping happens. ``scale`` and ``ratio`` mirror the
    signature of ``transforms.RandomResizedCrop`` and are stored on the
    instance, but ``forward`` never reads them and draws no random numbers.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        super().__init__()
        self.size, self.scale, self.ratio = size, scale, ratio

    def forward(self, img):
        # Whole-image resize to self.size; a random crop was planned here but
        # was never implemented.
        return TF.resize(img, self.size)


class HumanDanceVideoDataset(Dataset):
    """
    Clip dataset pairing video frames with pose renderings and keypoint flow.

    Each item samples ``n_sample_frames`` frames (``sample_rate`` apart when
    the video is long enough, otherwise spread evenly over the whole video)
    plus a random reference frame from the same video. For every sampled
    frame it also computes a keypoint-driven flow field from the reference
    frame via ``np_get_flow``.

    Args:
        sample_rate: Stride between sampled frames.
        n_sample_frames: Number of frames per clip.
        width, height: Output spatial size.
        img_scale, img_ratio: ``scale``/``ratio`` for ``RandomResizedCrop``.
        drop_ratio: Stored on the instance; not used inside this class.
        data_meta_paths: JSON files, each holding a list of records with at
            least ``"video_path"`` and ``"kps_path"`` keys.
    """

    def __init__(
        self,
        sample_rate,
        n_sample_frames,
        width,
        height,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=("./data/fashion_meta.json",),
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_sample_frames = n_sample_frames
        self.width = width
        self.height = height
        self.img_scale = img_scale
        self.img_ratio = img_ratio

        vid_meta = []
        for data_meta_path in data_meta_paths:
            # Context manager so the metadata file handle is closed promptly.
            with open(data_meta_path, "r") as f:
                vid_meta.extend(json.load(f))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        self.pixel_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        # The flow field is a float ndarray, so it becomes a tensor first and
        # is then cropped with the same RandomResizedCrop settings as the
        # frames and pose maps. augmentation() replays the RNG state, so the
        # crop parameters match frame-by-frame, provided the video frames have
        # the flow canvas size (784x720) -- TODO confirm. A plain resize here
        # left the flow spatially misaligned whenever the frames were cropped.
        self.cond_transform_flow = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
            ]
        )
        self.drop_ratio = drop_ratio

    def augmentation(self, images, transform, state=None):
        """
        Apply ``transform`` to one image or to each image of a list.

        If ``state`` is given, the global torch RNG is reset to it first. Calls
        that make the same sequence of random draws therefore get identical
        random parameters (e.g. matching crops for frames, pose maps and flow).

        Returns:
            A tensor stacked along dim 0 (``(f, c, h, w)``) for a list input,
            otherwise ``transform``'s output for the single image.
        """
        if state is not None:
            torch.set_rng_state(state)
        if isinstance(images, list):
            return torch.stack([transform(img) for img in images], dim=0)
        return transform(images)

    def __getitem__(self, index):
        """
        Load one clip.

        Returns:
            dict with ``video_dir``, ``pixel_values_vid`` (normalized frames),
            ``pixel_values_pose``, ``pixel_values_ref_img``, ``clip_ref_img``
            (CLIP-preprocessed reference frame) and ``flow`` (float32, three
            channels: scaled keypoint flow dx, dy and a zero channel).
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)

        clip_length = min(
            video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
        )
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()

        ref_img_idx = random.randint(0, video_length - 1)
        ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())

        # Read frames, flow and pose maps. The loop variable is frame_idx so it
        # no longer shadows the ``index`` argument.
        vid_pil_image_list = []
        pose_pil_image_list = []
        flow_image_list = []
        for frame_idx in batch_index:
            vid_pil_image_list.append(Image.fromarray(video_reader[frame_idx].asnumpy()))

            # Flow from the reference frame to this frame, scaled by 1/200 and
            # clipped to [-0.5, 0.5], then padded to three channels (third
            # channel zero). float32 matches the other returned tensors;
            # previously this was float64.
            flow = np_get_flow(kps_path, ref_img_idx, frame_idx)
            flow = np.clip(flow / 200, -0.5, 0.5)
            flow_padded = np.zeros((*flow.shape[:2], 3), dtype=np.float32)
            flow_padded[:, :, :2] = flow
            flow_image_list.append(flow_padded)

            pose_pil_image_list.append(Image.fromarray(kps_reader[frame_idx].asnumpy()))

        # Transform every stream from the same RNG state so the crops agree.
        state = torch.get_rng_state()
        pixel_values_vid = self.augmentation(
            vid_pil_image_list, self.pixel_transform, state
        )
        pixel_values_flow = self.augmentation(
            flow_image_list, self.cond_transform_flow, state
        )
        pixel_values_pose = self.augmentation(
            pose_pil_image_list, self.cond_transform, state
        )
        pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
        clip_ref_img = self.clip_image_processor(
            images=ref_img, return_tensors="pt"
        ).pixel_values[0]

        return dict(
            video_dir=video_path,
            pixel_values_vid=pixel_values_vid,
            pixel_values_pose=pixel_values_pose,
            pixel_values_ref_img=pixel_values_ref_img,
            clip_ref_img=clip_ref_img,
            flow=pixel_values_flow,
        )

    def __len__(self):
        return len(self.vid_meta)
