import json
import random
from typing import List

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor

import numpy as np
import cv2


def calculate_translation_hw(image1, image2, hwxregion=None, hmyregion=None, trianing_size=(768,768)):
    """
    Estimate the translation between two images from ORB feature matches.

    Both images are optionally cropped to the same horizontal / vertical
    region, converted to grayscale, and matched with ORB descriptors using a
    brute-force Hamming matcher with cross-checking. The translation is the
    median of the per-match coordinate differences (image2 - image1), which
    limits the influence of outlier matches.

    Args:
    image1 (np.array): The first image; must be accepted by
        ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)`` (i.e. 3-channel).
    image2 (np.array): The second image, same layout as ``image1``.
    hwxregion (tuple, optional): Column slice ``(startx, endx)`` applied to
        both images. Default is None (no horizontal crop).
    hmyregion (tuple, optional): Row slice ``(starty, endy)`` applied to both
        images. Default is None (no vertical crop).
    trianing_size (tuple, optional): The two trailing dimensions of the
        returned array. Default is (768, 768).

    Returns:
    np.array: Array of shape ``(2, trianing_size[0], trianing_size[1])``.
        Channel 0 is filled with ``10 * dx / cropped_width`` and channel 1
        with ``10 * dy / cropped_height``. If no descriptors are found in
        either image, or no matches survive cross-checking, an array of ones
        of the same shape is returned instead.
    """
    if hwxregion is not None:
        image1 = image1[:, hwxregion[0]:hwxregion[1]]
        image2 = image2[:, hwxregion[0]:hwxregion[1]]

    if hmyregion is not None:
        image1 = image1[hmyregion[0]:hmyregion[1], :]
        image2 = image2[hmyregion[0]:hmyregion[1], :]

    out_shape = (2, trianing_size[0], trianing_size[1])

    # Convert images to grayscale
    img1_gray = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
    img2_gray = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)

    # Find the keypoints and descriptors with ORB
    orb = cv2.ORB_create()
    kp1, des1 = orb.detectAndCompute(img1_gray, None)
    kp2, des2 = orb.detectAndCompute(img2_gray, None)

    # No descriptors in at least one image: nothing to match.
    if des1 is None or des2 is None:
        print("someone")
        return np.ones(out_shape)

    # Match descriptors (cross-check keeps only mutual best matches)
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = bf.match(des1, des2)

    # Cross-checking can reject every candidate. The median of an empty array
    # is NaN, which would silently poison the output, so use the same
    # fallback as the "no descriptors" case.
    if len(matches) == 0:
        return np.ones(out_shape)

    # Extract the matched keypoint locations. Ordering is irrelevant because
    # only the median of the differences is used below.
    points1 = np.array([kp1[m.queryIdx].pt for m in matches], dtype=np.float32)
    points2 = np.array([kp2[m.trainIdx].pt for m in matches], dtype=np.float32)

    # Find translation using median of differences to reduce the effect of outliers
    dx = np.median(points2[:, 0] - points1[:, 0])
    dy = np.median(points2[:, 1] - points1[:, 1])

    # Express the translation as a fraction of the (cropped) image size
    dx_percent = dx / image1.shape[1]  # width
    dy_percent = dy / image1.shape[0]  # height

    # Broadcast the two scalars into constant-valued planes, scaled by 10.
    np_array = np.zeros(out_shape)
    np_array[0] = dx_percent * 10
    np_array[1] = dy_percent * 10
    return np_array


class HumanDanceVideoDataset(Dataset):
    """Dataset yielding sampled video clips with their paired keypoint clips.

    Each metadata entry must provide a ``"video_path"`` and a ``"kps_path"``;
    both are opened with ``decord.VideoReader`` and must contain the same
    number of frames. For every item a window of frames is sampled at a
    random start position, and the first sampled frame doubles as the
    reference image.
    """

    def __init__(
        self,
        sample_rate,
        n_sample_frames,
        width,
        height,
        img_scale=(1.0, 1.0),
        img_ratio=(1.0, 1.0),
        drop_ratio=0.1,
        data_meta_paths=("./data/fashion_meta.json",),
        original_size_img=(768, 768),
        original_size_pos=(768, 768),
    ):
        """Load the metadata files and build the image transforms.

        Args:
            sample_rate: Frame stride used to size the sampling window.
            n_sample_frames: Number of frames returned per clip.
            width: Output frame width passed to ``Resize``.
            height: Output frame height passed to ``Resize``.
            img_scale: Stored on the instance; not used by the active
                transforms (kept for the disabled RandomResizedCrop).
            img_ratio: Stored on the instance; not used by the active
                transforms (kept for the disabled RandomResizedCrop).
            drop_ratio: Stored on the instance; not used inside this class.
            data_meta_paths: Iterable of JSON file paths (a single path
                string is also accepted). Each file must hold a list of
                metadata entries; the lists are concatenated.
            original_size_img: Source size of the video frames; the smaller
                value is used as the side of the square center crop.
            original_size_pos: Source size of the keypoint frames; the
                smaller value is used as the side of the square center crop.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.n_sample_frames = n_sample_frames
        self.width = width
        self.height = height
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.img_size = (width, height)

        # A bare string would otherwise be iterated character by character.
        if isinstance(data_meta_paths, str):
            data_meta_paths = [data_meta_paths]

        vid_meta = []
        for data_meta_path in data_meta_paths:
            # Context manager so the file handle is closed deterministically.
            with open(data_meta_path, "r") as meta_file:
                vid_meta.extend(json.load(meta_file))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # NOTE: a RandomResizedCrop((height, width), scale=img_scale,
        # ratio=img_ratio, bilinear) was previously used here in place of the
        # deterministic CenterCrop + Resize pair; it is currently disabled.
        crop_img = min(original_size_img)
        self.pixel_transform = transforms.Compose(
            [
                transforms.CenterCrop((crop_img, crop_img)),
                transforms.Resize((height, width)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Same geometry as pixel_transform, but without normalization.
        crop_pos = min(original_size_pos)
        self.cond_transform = transforms.Compose(
            [
                transforms.CenterCrop((crop_pos, crop_pos)),
                transforms.Resize((height, width)),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, images, transform, state=None):
        """Apply ``transform`` to one image or to every image in a list.

        Args:
            images: A single image, or a ``list`` of images.
            transform: Callable applied to each image.
            state: Optional torch RNG state. When given, it is restored
                before transforming so that separate calls sharing the same
                state draw the same random numbers.

        Returns:
            The transformed image for a single input, or the per-image
            results stacked along a new leading dimension for a list input.
        """
        if state is not None:
            torch.set_rng_state(state)
        if isinstance(images, list):
            transformed_images = [transform(img) for img in images]
            ret_tensor = torch.stack(transformed_images, dim=0)  # (f, c, h, w)
        else:
            ret_tensor = transform(images)  # (c, h, w)
        return ret_tensor

    def __getitem__(self, index):
        """Sample one clip from the video at ``index``.

        Returns:
            dict with keys ``video_dir`` (the video path),
            ``pixel_values_vid``, ``pixel_values_pose``,
            ``pixel_values_ref_img`` and ``clip_ref_img``. The two reference
            entries each contain the same tensor stacked twice along a new
            leading dimension.

        Raises:
            AssertionError: If the video and keypoint files differ in length.
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)

        # Window covering n_sample_frames at the requested stride, clamped to
        # the video length for short videos.
        clip_length = min(
            video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
        )
        start_idx = random.randint(0, video_length - clip_length)
        # Evenly spaced frame indices inside the window, e.g. 0, 4, 8, 12
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()

        # Shared RNG state so every augmentation call below sees the same
        # random draws.
        state = torch.get_rng_state()

        # read frames and kps
        vid_pil_image_list = [
            Image.fromarray(video_reader[frame_idx].asnumpy())
            for frame_idx in batch_index
        ]
        pose_pil_image_list = [
            Image.fromarray(kps_reader[frame_idx].asnumpy())
            for frame_idx in batch_index
        ]

        # The reference image is the first sampled frame; reuse the decoded
        # frame instead of reading it from the video a second time.
        start_ref_img = vid_pil_image_list[0]

        # NOTE: a per-frame translation map (calculate_translation_hw against
        # the reference frame, concatenated onto the pose channels) was
        # previously computed here; it is currently disabled.

        # transform
        pixel_values_start_ref_img = self.augmentation(
            start_ref_img, self.pixel_transform, state
        )
        pixel_values_vid = self.augmentation(
            vid_pil_image_list, self.pixel_transform, state
        )
        pixel_values_pose = self.augmentation(
            pose_pil_image_list, self.cond_transform, state
        )

        clip_start_ref_img = self.clip_image_processor(
            images=start_ref_img, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            pixel_values_vid=pixel_values_vid,
            pixel_values_pose=pixel_values_pose,
            pixel_values_ref_img=torch.stack(
                [pixel_values_start_ref_img, pixel_values_start_ref_img], dim=0
            ),
            clip_ref_img=torch.stack(
                [clip_start_ref_img, clip_start_ref_img], dim=0
            ),
        )

        return sample

    def __len__(self):
        """Return the number of metadata entries (one per video)."""
        return len(self.vid_meta)
