import json
import random

import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor


import numpy as np
import cv2

def calculate_translation_hw(image1, image2, hwxregion=None, hmyregion=None, trianing_size=(768,768)):
    """
    Calculate the translation Homography matrix between two images using specified horizontal (x)
    and vertical (y) regions.

    Args:
    image1 (np.array): The first image in numpy array format.
    image2 (np.array): The second image in numpy array format.
    hwxregion (tuple, optional): The horizontal region specified as (startx, endx). Default is None.
    hmyregion (tuple, optional): The vertical region specified as (starty, endy). Default is None.

    Returns:
    np.array: An array containing dx and dy as percentage of total width and height respectively.
    """
    if hwxregion is not None:
        image1 = image1[:, hwxregion[0]:hwxregion[1]]
        image2 = image2[:, hwxregion[0]:hwxregion[1]]
    
    if hmyregion is not None:
        image1 = image1[hmyregion[0]:hmyregion[1], :]
        image2 = image2[hmyregion[0]:hmyregion[1], :]

    # Convert images to grayscale
    img1_gray = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
    img2_gray = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)

    # Initiate ORB detector
    orb = cv2.ORB_create()

    # Find the keypoints and descriptors with ORB
    kp1, des1 = orb.detectAndCompute(img1_gray, None)
    kp2, des2 = orb.detectAndCompute(img2_gray, None)

    # Create BFMatcher object
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    # Match descriptors
    if des1 is None or des2 is None:
        print("someone")
        return np.ones((2, trianing_size[0], trianing_size[1]))
    matches = bf.match(des1, des2)

    # Sort them in the order of their distance
    matches = sorted(matches, key = lambda x: x.distance)

    # Extract location of good matches
    points1 = np.zeros((len(matches), 2), dtype=np.float32)
    points2 = np.zeros((len(matches), 2), dtype=np.float32)

    for i, match in enumerate(matches):
        points1[i, :] = kp1[match.queryIdx].pt
        points2[i, :] = kp2[match.trainIdx].pt

    # Find translation using median of differences to reduce the effect of outliers
    dx = np.median(points2[:, 0] - points1[:, 0])
    dy = np.median(points2[:, 1] - points1[:, 1])

    # Calculate the translation percentages
    dx_percent = dx / image1.shape[1]  # width
    dy_percent = dy / image1.shape[0]  # height

    np_array = np.zeros((2, trianing_size[0], trianing_size[1]))
    np_array[0] = dx_percent * 10
    np_array[1] = dy_percent * 10
    return np_array




class HumanDanceDataset(Dataset):
    """Dataset yielding (target frame, target pose + translation, reference frame) samples.

    Each item comes from one video listed in the meta JSON files. Each meta entry must
    provide ``video_path`` (the RGB video) and ``kps_path`` (a pose video with the same
    frame count). A reference frame and a target frame are sampled, preferring frames at
    least ``sample_margin`` apart. The estimated translation between them (see
    ``calculate_translation_hw``) is appended as two extra channels to the target pose.
    """

    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(1.0, 1.0),
        drop_ratio=0.1,
        data_meta_paths=("./data/fahsion_meta.json",),
        sample_margin=30,
        original_size_img=(720, 784),
        original_size_pos=(512, 558),
        translation_region=(-201, -1),
    ):
        """
        Args:
            img_size: output size passed to ``transforms.Resize`` for frames and poses.
            img_scale, img_ratio: stored for interface compatibility; the transforms built
                here are deterministic (no random crop) and do not use them.
            drop_ratio: stored as ``self.drop_ratio``; not used inside this class.
            data_meta_paths: iterable of JSON files, each holding a list of meta dicts.
            sample_margin: preferred minimum frame distance between reference and target.
            original_size_img: frame size; frames are center-cropped to a square of
                side ``min(original_size_img)`` before resizing.
            original_size_pos: pose-frame size, center-cropped the same way.
            translation_region: ``(start, end)`` column slice of the raw frames used to
                estimate the translation (default: the 200 columns before the last one).
        """
        super().__init__()

        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.drop_ratio = drop_ratio
        self.sample_margin = sample_margin
        self.translation_region = translation_region

        # vid_meta format:
        # [{'video_path': , 'kps_path': , 'other':},
        #  {'video_path': , 'kps_path': , 'other':}]
        vid_meta = []
        for data_meta_path in data_meta_paths:
            with open(data_meta_path, "r", encoding="utf-8") as meta_file:
                vid_meta.extend(json.load(meta_file))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # Frames: square center crop -> resize -> tensor in [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.CenterCrop((min(original_size_img), min(original_size_img))),
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Pose frames: square center crop -> resize -> tensor in [0, 1] (no normalization).
        self.cond_transform = transforms.Compose(
            [
                transforms.CenterCrop((min(original_size_pos), min(original_size_pos))),
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )

    def augmentation(self, image, transform, state=None):
        """Apply ``transform`` to ``image``, first restoring the torch RNG to ``state`` if given.

        Restoring the same state before each call makes any torch-RNG-driven random
        transforms draw identical parameters for every image of a sample.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def _sample_indices(self, video_length):
        """Return ``(ref_idx, tgt_idx)`` in ``[0, video_length)``.

        The target is placed at least ``sample_margin`` frames after the reference when
        possible, otherwise at least that far before it. If neither fits, it is drawn
        uniformly from the whole video.
        """
        margin = min(self.sample_margin, video_length)
        ref_idx = random.randint(0, video_length - 1)
        if ref_idx + margin < video_length:
            tgt_idx = random.randint(ref_idx + margin, video_length - 1)
        elif ref_idx - margin >= 0:
            # ">= 0" mirrors the forward branch, which also allows a distance of exactly margin.
            tgt_idx = random.randint(0, ref_idx - margin)
        else:
            tgt_idx = random.randint(0, video_length - 1)
        return ref_idx, tgt_idx

    def __getitem__(self, index):
        """Load one sample.

        Returns a dict with keys:
            video_dir: the video path.
            img: transformed target frame.
            tgt_pose: transformed target pose with the 2 translation channels appended.
            ref_img: transformed reference frame.
            clip_images: CLIP-processed reference frame.
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        ref_img_idx, tgt_img_idx = self._sample_indices(len(video_reader))

        # Decode each frame to numpy once; reused for PIL conversion and translation.
        ref_img_np = video_reader[ref_img_idx].asnumpy()
        tgt_img_np = video_reader[tgt_img_idx].asnumpy()
        ref_img_pil = Image.fromarray(ref_img_np)
        tgt_img_pil = Image.fromarray(tgt_img_np)
        tgt_pose_pil = Image.fromarray(kps_reader[tgt_img_idx].asnumpy())

        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        # Size the translation map from the pose tensor itself so the channel-wise
        # concatenation below always lines up, whatever form img_size takes.
        translation = calculate_translation_hw(
            ref_img_np,
            tgt_img_np,
            hwxregion=self.translation_region,
            trianing_size=tuple(tgt_pose_img.shape[-2:]),
        )
        translation = torch.from_numpy(translation).to(dtype=tgt_pose_img.dtype)
        tgt_pose_img = torch.cat((tgt_pose_img, translation), dim=0)

        return dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
        )

    def __len__(self):
        """Number of videos listed in the meta files."""
        return len(self.vid_meta)
