import json
import random

import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor

import cv2
import numpy as np
import matplotlib.pyplot as plt


# Module-level DIS (Dense Inverse Search) optical-flow estimator, built with
# OpenCV's FAST preset. It is created once at import time and handed to
# get_wrapped_img() from HumanDanceDataset.__getitem__.
optical_flow_dis = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_FAST)

def get_wrapped_img(first_img_path, second_img_path, optical_flow_dis, verbose=False):
    """Warp the first image towards the second one using dense optical flow.

    Args:
        first_img_path: Despite the name, this is an image array, not a file
            path: it is passed directly to ``cv2.cvtColor`` and ``cv2.remap``.
            It is converted with ``COLOR_RGB2GRAY``, i.e. treated as RGB.
        second_img_path: Second image array, handled the same way as the
            first (only its grayscale version is used).
        optical_flow_dis: Object exposing ``calc(img_a, img_b, None)`` that
            returns a dense flow field; the caller in this file passes a
            ``cv2.DISOpticalFlow`` instance.
        verbose: Currently unused; kept so existing call sites keep working.

    Returns:
        The first image resampled with ``cv2.remap`` (bilinear interpolation)
        through the sampling map described below.
    """
    gray_first = cv2.cvtColor(first_img_path, cv2.COLOR_RGB2GRAY)
    gray_second = cv2.cvtColor(second_img_path, cv2.COLOR_RGB2GRAY)

    # Flow estimated in both directions between the two grayscale frames.
    flow_first_to_second = optical_flow_dis.calc(gray_first, gray_second, None)
    flow_second_to_first = optical_flow_dis.calc(gray_second, gray_first, None)

    height, width = flow_first_to_second.shape[:2]

    # Average of the negated first->second flow and the second->first flow;
    # the arithmetic allocates a fresh array, so the flow fields are untouched.
    sample_map = 0.5 * -flow_first_to_second + 0.5 * flow_second_to_first

    # Turn the relative displacements into absolute sampling coordinates by
    # adding the pixel grid: row indices on channel 1, column indices on
    # channel 0.
    sample_map[..., 1] += np.arange(height)[:, np.newaxis]
    sample_map[..., 0] += np.arange(width)

    # cv2.remap expects a float32 map when both coordinates are packed into
    # a single two-channel array.
    sample_map = sample_map.astype(np.float32)

    return cv2.remap(first_img_path, sample_map, None, cv2.INTER_LINEAR)

# Replace these paths with the paths to your images
# first_img_path = '/data/Moore-AnimateAnyone/configs/inference/talkshow_ref_imges/frame_0_214438-00_07_16-00_07_26.jpg'
# second_img_path = '/data/Moore-AnimateAnyone/configs/inference/talkshow_ref_imges/frame_8_214438-00_07_16-00_07_26.jpg'

class HumanDanceDataset(Dataset):
    """Dataset of (reference frame, target frame, target pose) training samples.

    Every item corresponds to one metadata entry holding a ``video_path`` and
    a ``kps_path`` (keypoint video) that must contain the same number of
    frames. For each item a reference frame and a target frame are sampled,
    an optical-flow warp is computed between the two pose frames, and the
    resulting images are passed through random-resized-crop transforms that
    are all started from the same torch RNG state.
    """

    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=("./data/fahsion_meta.json",),
        sample_margin=30,
    ):
        """Load the metadata files and build the image transforms.

        Args:
            img_size: Output size given to ``RandomResizedCrop``.
            img_scale: ``scale`` range given to ``RandomResizedCrop``.
            img_ratio: ``ratio`` range given to ``RandomResizedCrop``.
            drop_ratio: Stored as ``self.drop_ratio``; it is not read anywhere
                else in this class.
            data_meta_paths: Iterable of JSON file paths. Each file must hold
                a list of entries; the lists are concatenated. The default is
                an immutable tuple so it cannot be shared and mutated across
                instances.
            sample_margin: Maximum frame distance used when a target frame
                different from the reference frame is sampled.
        """
        super().__init__()

        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin

        # -----
        # vid_meta format:
        # [{'video_path': , 'kps_path': , 'other':},
        #  {'video_path': , 'kps_path': , 'other':}]
        # -----
        vid_meta = []
        for data_meta_path in data_meta_paths:
            # Context manager so the file handle is closed deterministically.
            with open(data_meta_path, "r") as meta_file:
                vid_meta.extend(json.load(meta_file))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # Transform for RGB frames: crop, convert to tensor, then normalise
        # with mean 0.5 and std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Transform for the pose condition: same crop, but no normalisation.
        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        """Apply ``transform`` to ``image``, optionally after restoring RNG state.

        Args:
            image: Input passed unchanged to ``transform``.
            transform: Callable applied to ``image``.
            state: Optional torch RNG state. When given, it is restored with
                ``torch.set_rng_state`` first, so consecutive calls with the
                same state draw the same random numbers.

        Returns:
            Whatever ``transform(image)`` returns.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def _sample_frame_indices(self, video_length):
        """Sample a ``(reference, target)`` pair of frame indices.

        The reference index is drawn uniformly from the clip. With a 50%
        chance the target is the reference itself; otherwise it is drawn from
        a window of at most ``sample_margin`` frames after the reference, or
        before the reference when the forward window would reach past the end
        of the clip.

        Args:
            video_length: Number of frames in the clip.

        Returns:
            Tuple ``(ref_img_idx, tgt_img_idx)``, both within
            ``[0, video_length - 1]``.
        """
        margin = min(self.sample_margin, video_length)
        ref_img_idx = random.randint(0, video_length - 1)

        if random.randint(0, 99) >= 50:
            return ref_img_idx, ref_img_idx

        if ref_img_idx + margin < video_length:
            low, high = ref_img_idx, ref_img_idx + margin - 1
        else:
            # Clamp at frame 0: on clips shorter than 2 * margin the backward
            # window used to extend below zero, and a negative index does not
            # address a frame before the reference.
            low, high = max(0, ref_img_idx - margin), ref_img_idx

        # max() keeps the range non-empty when the margin is 0.
        return ref_img_idx, random.randint(low, max(low, high))

    def __getitem__(self, index):
        """Build one training sample from the metadata entry at ``index``.

        Returns:
            dict with keys:
                ``video_dir``: the entry's ``video_path``.
                ``img``: transformed target frame.
                ``tgt_pose``: transformed target pose frame.
                ``ref_img``: transformed reference frame.
                ``clip_images``: ``pixel_values[0]`` from the CLIP image
                    processor applied to the untransformed reference frame.
                ``wrapped_img``: transformed optical-flow warp of the
                    reference pose frame towards the target pose frame.

        Raises:
            AssertionError: If the video and the keypoint video differ in
                frame count.
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        ref_img_idx, tgt_img_idx = self._sample_frame_indices(len(video_reader))

        ref_img_pil = Image.fromarray(video_reader[ref_img_idx].asnumpy())
        tgt_img_pil = Image.fromarray(video_reader[tgt_img_idx].asnumpy())

        ref_pose = kps_reader[ref_img_idx].asnumpy()
        tgt_pose = kps_reader[tgt_img_idx].asnumpy()
        tgt_pose_pil = Image.fromarray(tgt_pose)

        # NOTE(review): the warp is computed from the pose frames, not the RGB
        # frames, and is normalised below with self.transform like an RGB
        # image - confirm this is intended.
        wrapped_img_pil = Image.fromarray(
            get_wrapped_img(ref_pose, tgt_pose, optical_flow_dis)
        )

        # Every augmentation call restarts from this RNG state so that the
        # random crop draws the same numbers for each image (the crops only
        # coincide if the images share a size - presumably they do).
        state = torch.get_rng_state()

        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        wrapped_img = self.augmentation(wrapped_img_pil, self.transform, state)

        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
            wrapped_img=wrapped_img,
        )

        return sample

    def __len__(self):
        """Return the number of loaded metadata entries."""
        return len(self.vid_meta)
