import json
import random
import cv2
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
import numpy as np
import torchvision.transforms.functional as TF

# Body skeleton: each pair names two body keypoints joined by a limb.
# The numbers are 1-based; np_get_flow subtracts 1 from both entries before
# indexing the body keypoint arrays.
# NOTE(review): looks like the OpenPose 18-keypoint limb list -- confirm.
limbSeq = [
    [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8],
    [2, 9], [9, 10], [10, 11],
    [2, 12], [12, 13], [13, 14],
    [2, 1], [1, 15], [15, 17], [1, 16], [16, 18],
]

# Hand skeleton: 0-based index pairs that np_get_flow applies to both the
# left-hand and the right-hand keypoint arrays. Node 0 is the root of five
# chains of four edges each.
edges = [
    [0, 1], [1, 2], [2, 3], [3, 4],
    [0, 5], [5, 6], [6, 7], [7, 8],
    [0, 9], [9, 10], [10, 11], [11, 12],
    [0, 13], [13, 14], [14, 15], [15, 16],
    [0, 17], [17, 18], [18, 19], [19, 20],
]


def calculate_and_plot_optical_flow(frame1_index, frame2_index, keypoints, keypoints_type, color, title_prefix="", pose_id=None, frame_size=(784, 720)):
    """
    Compute per-keypoint displacement ("flow") between two frames.

    Despite the name, nothing is plotted: the function only scales the
    keypoints of both frames by ``frame_size`` and subtracts them.

    Args:
        frame1_index: Index of the source frame in every keypoint array.
        frame2_index: Index of the destination frame.
        keypoints: Mapping ``kp_type -> array`` indexed as
            ``keypoints[kp_type][frame]`` and yielding ``(x, y)`` rows.
            Assumes coordinates are normalised to the frame -- TODO confirm.
        keypoints_type: Iterable of keys of ``keypoints`` to process.
        color: Unused; kept so existing callers keep working.
        title_prefix: Unused; kept so existing callers keep working.
        pose_id: Optional array indexed as ``pose_id[frame][0][i]``. A value
            of ``-1`` marks body keypoint ``i`` as missing in that frame.
            When ``None`` every body keypoint is treated as present (this
            used to raise ``TypeError``).
        frame_size: ``(x_scale, y_scale)`` multipliers applied to the
            keypoints; the default matches the previously hard-coded
            ``[784, 720]``.

    Returns:
        ``(flow_all, kp1_all, kp2_all)``: three dicts keyed by keypoint type
        holding the displacement, the scaled frame-1 keypoints and the scaled
        frame-2 keypoints. For ``'body'`` only the first 18 keypoints are
        returned, and a keypoint missing in either frame has all three
        entries set to ``[0.0, 0.0]``.
    """
    num_body_keypoints = 18
    scale = np.asarray(frame_size, dtype=np.float64)

    flow_all = {}
    kp1_all = {}
    kp2_all = {}
    for kp_type in keypoints_type:
        kp1 = np.asarray(keypoints[kp_type][frame1_index]) * scale
        kp2 = np.asarray(keypoints[kp_type][frame2_index]) * scale

        if kp_type == 'body':
            kp1 = kp1[:num_body_keypoints]
            kp2 = kp2[:num_body_keypoints]
            if pose_id is not None:
                ids1 = np.asarray(pose_id[frame1_index][0])[:num_body_keypoints]
                ids2 = np.asarray(pose_id[frame2_index][0])[:num_body_keypoints]
                # A joint must be detected in BOTH frames to carry a flow.
                present = (ids1 != -1) & (ids2 != -1)
                kp1 = np.where(present[:, None], kp1, 0.0)
                kp2 = np.where(present[:, None], kp2, 0.0)

        flow_all[kp_type] = kp2 - kp1
        kp1_all[kp_type] = kp1
        kp2_all[kp_type] = kp2
    return flow_all, kp1_all, kp2_all


def calculate_and_visualize_blended_flow(point1, point2, flow1, flow2, thickness=20, canvas_size=(720, 784)):
    """
    Rasterise a thick segment and fill it with a flow blended between its ends.

    A line of the given ``thickness`` is drawn from ``point1`` to ``point2``
    on a blank mask. Every covered pixel receives
    ``flow1 * (1 - ratio) + flow2 * ratio`` where ``ratio`` is the pixel's
    distance to ``point1`` divided by the segment length. The ratio is not
    clamped, so pixels further from ``point1`` than ``point2`` is extrapolate.

    Args:
        point1: ``(x, y)`` start of the segment, in pixels.
        point2: ``(x, y)`` end of the segment, in pixels.
        flow1: 2-vector flow at ``point1``.
        flow2: 2-vector flow at ``point2``.
        thickness: Line thickness passed to ``cv2.line``.
        canvas_size: ``(rows, cols)`` of the output field.

    Returns:
        Float array of shape ``(rows, cols, 2)``, zero outside the segment.
    """
    height, width = canvas_size[0], canvas_size[1]
    flow_field = np.zeros((height, width, 2))

    p1 = np.asarray(point1, dtype=np.float64)
    p2 = np.asarray(point2, dtype=np.float64)
    p1_px = p1.astype(np.int32)
    p2_px = p2.astype(np.int32)
    x1, y1 = int(p1_px[0]), int(p1_px[1])
    x2, y2 = int(p2_px[0]), int(p2_px[1])

    # Only pixels inside this window are inspected (same window as before:
    # the segment's bounding box grown by twice the thickness, then clipped).
    pad = 2 * thickness
    x_start = max(min(x1, x2) - pad, 0)
    x_end = min(max(x1, x2) + pad, width)
    y_start = max(min(y1, y2) - pad, 0)
    y_end = min(max(y1, y2) + pad, height)
    if x_end <= x_start or y_end <= y_start:
        # Segment lies entirely off-canvas; also prevents a negative end
        # bound from being interpreted as a from-the-end slice.
        return flow_field

    canvas = np.zeros((height, width), dtype=np.uint8)
    cv2.line(canvas, (x1, y1), (x2, y2), 255, thickness=thickness)

    rows, cols = np.nonzero(canvas[y_start:y_end, x_start:x_end] == 255)
    rows = rows + y_start
    cols = cols + x_start

    segment_length = np.linalg.norm(p2 - p1)
    if segment_length > 0:
        ratio = np.hypot(cols - p1[0], rows - p1[1]) / segment_length
    else:
        # Coincident endpoints: the blend position is undefined and dividing
        # by zero would write NaN/inf into the field, so use the mean flow.
        ratio = np.full(rows.shape, 0.5)

    start_flow = np.asarray(flow1, dtype=np.float64)
    end_flow = np.asarray(flow2, dtype=np.float64)
    flow_field[rows, cols] = np.outer(1.0 - ratio, start_flow) + np.outer(ratio, end_flow)
    return flow_field

def circle_blended_flow(point1, flow1, thickness=20, canvas_size=(720, 784)):
    """
    Fill a disc around ``point1`` with the constant flow ``flow1``.

    Args:
        point1: ``(x, y)`` centre of the disc, in pixels.
        flow1: 2-vector flow written to every pixel of the disc.
        thickness: Radius of the disc passed to ``cv2.circle``.
        canvas_size: ``(rows, cols)`` of the output field.

    Returns:
        Float array of shape ``(rows, cols, 2)``, zero outside the disc.
    """
    height, width = canvas_size[0], canvas_size[1]
    flow_field = np.zeros((height, width, 2))

    center_px = np.asarray(point1).astype(np.int32)
    cx, cy = int(center_px[0]), int(center_px[1])

    # The window must include centre + radius: the previous exclusive bound
    # dropped the bottom row and right-most column of the drawn disc.
    y_start = max(cy - thickness, 0)
    y_end = min(cy + thickness + 1, height)
    x_start = max(cx - thickness, 0)
    x_end = min(cx + thickness + 1, width)
    if x_end <= x_start or y_end <= y_start:
        # Disc lies entirely off-canvas.
        return flow_field

    # Rasterise the filled disc into a mask, then copy the flow under it.
    canvas = np.zeros((height, width), dtype=np.uint8)
    cv2.circle(canvas, center=(cx, cy), radius=thickness, color=255, thickness=-1)

    inside = canvas[y_start:y_end, x_start:x_end] == 255
    flow_field[y_start:y_end, x_start:x_end][inside] = flow1
    return flow_field

def np_get_flow(np_path, start, end):
    """
    Build a dense, smoothed flow field between two frames from keypoint files.

    Four arrays are loaded from files that share ``np_path``'s name without
    its extension: ``*_pose_pos.npy``, ``*_pose_id.npy``, ``*_hands_pos.npy``
    and ``*_face_pos.npy``. Keypoint displacements between frame ``start``
    and frame ``end`` are painted onto a 720x784 canvas: body limbs and hand
    edges as thick segments with flow blended along them, and the first 68
    face keypoints as discs. Overlapping contributions are summed, and the
    result is blurred with a 5x5 Gaussian kernel.

    NOTE(review): hand and face keypoints are drawn unconditionally; if
    undetected hands/faces are encoded with a sentinel value they still
    contribute -- confirm against the keypoint extractor.

    Args:
        np_path: Path whose extension is replaced by the ``_*.npy`` suffixes.
        start: Index of the source frame.
        end: Index of the destination frame.

    Returns:
        Float array of shape ``(720, 784, 2)`` holding the per-pixel flow.
    """
    import os

    # splitext handles extensions of any length (the old ``[:-4]`` slice only
    # worked for three-letter ones such as ".mp4").
    stem = os.path.splitext(np_path)[0]
    pose_pos = np.load(stem + "_pose_pos.npy")
    pose_id = np.load(stem + "_pose_id.npy")
    hands_pos = np.load(stem + "_hands_pos.npy")
    face_pos = np.load(stem + "_face_pos.npy")

    keypoints_data = {
        'body': pose_pos,
        'left_hand': hands_pos[:, 0, :, :],
        'right_hand': hands_pos[:, 1, :, :],
        'face': face_pos[:, 0, :, :],
    }

    # The ``color`` argument is not read by the callee, so None is passed.
    flow_all, kp1_all, kp2_all = calculate_and_plot_optical_flow(
        start, end, keypoints_data, keypoints_data.keys(), None,
        title_prefix="OpenPose ", pose_id=pose_id,
    )
    full_flow_field = np.zeros((720, 784, 2))

    # Body limbs. A limb is skipped when either joint is flagged missing
    # (-1 in pose_id) in either frame. Testing the flag rather than
    # "flow == 0" keeps limbs whose joints are visible but did not move.
    body_missing = (np.asarray(pose_id[start][0]) == -1) | (np.asarray(pose_id[end][0]) == -1)
    body_kp, body_flow = kp1_all['body'], flow_all['body']
    for joint_a, joint_b in limbSeq:
        idx_a, idx_b = joint_a - 1, joint_b - 1  # limbSeq is 1-based
        if body_missing[idx_a] or body_missing[idx_b]:
            continue
        full_flow_field += calculate_and_visualize_blended_flow(
            body_kp[idx_a], body_kp[idx_b],
            body_flow[idx_a], body_flow[idx_b],
        )

    # Hand edges, thinner than body limbs; same topology for both hands.
    for idx_a, idx_b in edges:
        for hand in ('left_hand', 'right_hand'):
            full_flow_field += calculate_and_visualize_blended_flow(
                kp1_all[hand][idx_a], kp1_all[hand][idx_b],
                flow_all[hand][idx_a], flow_all[hand][idx_b],
                thickness=10,
            )

    # Face keypoints: one constant-flow disc per landmark.
    for i in range(68):
        full_flow_field += circle_blended_flow(
            kp1_all['face'][i],
            flow_all['face'][i],
            thickness=10,
        )

    return cv2.GaussianBlur(full_flow_field, (5, 5), 0)



# '/data/Moore-AnimateAnyone/configs/inference/test_videos_dwpose/214438-00_07_16-00_07_26_pose_pos.npy'

class RandomResizedCropTensor(torch.nn.Module):
    """Transform module that resizes its input to a fixed ``size``.

    Despite the name, no random crop is applied: ``scale`` and ``ratio`` are
    stored on the instance but ``forward`` does not read them.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3./4., 4./3.)):
        super().__init__()
        self.size, self.scale, self.ratio = size, scale, ratio

    def forward(self, img):
        """Return ``img`` resized to ``self.size`` with ``TF.resize``."""
        return TF.resize(img, self.size)

class HumanDanceDataset(Dataset):
    """Dataset of (reference frame, target frame, target pose, flow) samples.

    Each item is drawn from one video and its keypoint video. A reference
    frame is picked at random; roughly 10% of the time the target frame is
    the reference frame itself with an all-zero flow, otherwise the target
    is sampled within ``sample_margin`` frames of the reference and the flow
    comes from ``np_get_flow``.

    NOTE(review): the flow map is only resized, whereas the images go through
    ``RandomResizedCrop``; if that crop removes part of the frame the flow is
    not spatially aligned with the images -- confirm this is intended.
    NOTE(review): the flow canvas is hard-coded to 720x784 pixels.
    """

    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=("./data/fahsion_meta.json",),
        sample_margin=30,
    ):
        """
        Args:
            img_size: Output size given to the resize/crop transforms.
            img_scale: ``scale`` range for ``RandomResizedCrop``.
            img_ratio: ``ratio`` range for ``RandomResizedCrop``.
            drop_ratio: Stored as ``self.drop_ratio``; not read by this class.
            data_meta_paths: Iterable of JSON files, each containing a list of
                dicts with at least ``'video_path'`` and ``'kps_path'``.
            sample_margin: Maximum frame distance between the reference and
                the target frame.
        """
        super().__init__()

        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin

        # vid_meta format:
        # [{'video_path': , 'kps_path': , 'other':},
        #  {'video_path': , 'kps_path': , 'other':}]
        vid_meta = []
        for data_meta_path in data_meta_paths:
            # Context manager so the metadata file is closed promptly.
            with open(data_meta_path, "r") as meta_file:
                vid_meta.extend(json.load(meta_file))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # Target / reference images: crop+resize, then normalise with
        # mean 0.5 and std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Pose images: same crop+resize, no normalisation.
        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        # Flow maps: converted to a tensor and resized only (no random crop).
        self.cond_transform_flow = transforms.Compose(
            [
                transforms.ToTensor(),
                RandomResizedCropTensor(self.img_size),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        """Apply ``transform`` to ``image``.

        When ``state`` is given, the torch RNG is first reset to it so that
        several calls sharing one state draw the same random parameters.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    @staticmethod
    def _sample_target_index(ref_img_idx, video_length, margin):
        """Pick a target frame index within ``margin`` frames of the reference.

        Samples forward from the reference when a full margin fits before the
        end of the video, otherwise backward. The backward window is clamped
        at 0 so the result is always a valid, non-negative frame index.
        """
        if margin <= 0:
            return ref_img_idx
        if ref_img_idx + margin < video_length:
            return random.randint(ref_img_idx, ref_img_idx + margin - 1)
        return random.randint(max(ref_img_idx - margin, 0), ref_img_idx)

    def __getitem__(self, index):
        """Return one training sample as a dict.

        Keys: ``video_dir`` (the video path), ``img`` (target frame),
        ``tgt_pose`` (target pose frame), ``ref_img`` (reference frame),
        ``clip_images`` (CLIP-processed reference frame) and ``flow``
        (3-channel map whose first two channels hold the scaled flow).
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)
        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)
        margin = min(self.sample_margin, video_length)
        ref_img_idx = random.randint(0, video_length - 1)
        pob = random.randint(0, 99)

        # Third channel stays zero; the first two receive the flow below.
        flow_zero = np.zeros((720, 784, 3))
        if pob >= 90:
            # Identity pair: target equals reference, flow is all zeros.
            tgt_img_idx = ref_img_idx
        else:
            tgt_img_idx = self._sample_target_index(
                ref_img_idx, video_length, margin
            )
            flow = np_get_flow(kps_path, ref_img_idx, tgt_img_idx)
            # Scale down, then limit the flow to [-0.5, 0.5].
            flow_zero[:, :, :2] = np.clip(flow / 200, -0.5, 0.5)

        ref_img = video_reader[ref_img_idx]
        ref_img_pil = Image.fromarray(ref_img.asnumpy())

        tgt_img = video_reader[tgt_img_idx]
        tgt_img_pil = Image.fromarray(tgt_img.asnumpy())

        tgt_pose = kps_reader[tgt_img_idx]
        tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())

        # Re-using one RNG state makes every random transform draw the same
        # parameters for the target, pose and reference images.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        flow_img = self.augmentation(flow_zero, self.cond_transform_flow, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        tgt_pose_img = tgt_pose_img.to(dtype=ref_img_vae.dtype)

        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
            flow=flow_img,
        )
        return sample

    def __len__(self):
        """Return the number of videos listed in the metadata."""
        return len(self.vid_meta)
