import os
import sys
import traceback

import torch
import random
import pickle
import numpy as np
from decord import VideoReader
from copy import deepcopy
from PIL import Image

from torchvision.utils import save_image
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader

sys.path.append('./')
from .draw_pose import draw_pose
from .pos_embed import get_2d_sincos_pos_embed, get_2d_local_sincos_pos_embed, RandomResizedCropCoord
import json


class VideoDataset(Dataset):
    """Single-frame dataset of images paired with DWPose annotations.

    Despite its name, each item is one image. The expected layout is::

        <root_dir>/<split>/<sub_dir>/image/<frame>.<ext>
        <root_dir>/<split>/<sub_dir>/dwpose/<frame>.pkl

    Each item is a dict with:

    * ``image``: the augmented image, normalized with mean/std 0.5;
    * ``pose``: the rendered pose map (scaled by 1/255), cropped with the same
      random crop as ``image``;
    * ``lpose``: a local crop of ``pose`` taken by ``self.local_rcr``;
    * ``lpos``: the local crop box, every entry rounded down to a multiple of 8;
    * ``text`` (only when ``text_prompt`` is enabled): a randomly chosen
      description from the prompt JSON file.

    Several constructor arguments (``clip_size``, ``sample_stride``,
    ``sample_n_frames``, ``ref_mode``, ``pos_embed``, ``pos_embed_dim``,
    ``image_finetune``, ``do_human_crop``, ``use_hamer``) are validated and
    stored but are not used by ``__getitem__``.
    """

    # Attempts (the requested sample plus random replacements) made by
    # __getitem__ before giving up.
    _MAX_LOAD_TRIES = 10

    def __init__(
            self,
            root_dir, split,
            sample_size=(960, 512), clip_size=(224, 224), scale_aug=(1.0, 1.0),
            sample_stride=4, sample_n_frames=16, ref_mode="random", pos_embed=False, pos_embed_dim=320,
            image_finetune=False, do_human_crop=False, local_body='hands', use_hamer=False, text_prompt=False,
            *, text_prompt_path='training_humanart_dance.json', **kwargs,
    ):
        """Index the dataset and build the augmentation pipelines.

        Args:
            root_dir: dataset root directory.
            split: sub-directory of ``root_dir`` to load (may be ``""``).
            sample_size: output ``(height, width)`` of ``image`` and ``pose``.
            clip_size: output ``(height, width)`` of ``self.clip_transform``.
            scale_aug: ``scale`` range passed to ``RandomResizedCrop``.
            local_body: ``'hands'`` to centre the local box on a random hand;
                anything else uses the first face.
            text_prompt: if true, attach a random prompt from
                ``text_prompt_path`` to every sample.
            text_prompt_path: JSON file with an ``images`` list whose entries
                carry a ``description`` field. The path is resolved relative to
                the current working directory when not absolute.
            **kwargs: ignored; accepted for config compatibility.
        """
        super().__init__()
        self.root_dir = root_dir
        self.split = split
        self.load_dir = os.path.join(self.root_dir, self.split)
        assert os.path.exists(self.load_dir), f"the path {self.load_dir} of the dataset is wrong"

        self.sample_size = sample_size
        self.clip_size = clip_size
        assert sample_stride >= 1
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.at_least_n_frames = (self.sample_n_frames - 1) * self.sample_stride + 1
        # set where the reference frame comes from, which could be "first" or "random"
        assert ref_mode in ["first", "random"], \
            f"the ref_mode could only be \"first\" or \"random\". However \"ref_mode = {ref_mode}\" is given."
        self.ref_mode = ref_mode
        self.image_finetune = image_finetune
        self.do_human_crop = do_human_crop
        self.pos_embed = pos_embed
        self.pos_embed_dim = pos_embed_dim
        self.use_hamer = use_hamer
        self.local_body = local_body
        self.text_prompt = text_prompt

        if self.text_prompt:
            with open(text_prompt_path, encoding='utf-8') as f:
                self.texts = json.load(f)['images']

        # Build the flat list of image paths. Entries without an ``image``
        # sub-folder (e.g. an auxiliary ``smpl`` folder or stray files) are
        # skipped instead of crashing os.listdir.
        self.data_dirs = sorted(os.listdir(self.load_dir))
        self.data_keys = []
        for sub_dir in self.data_dirs:
            image_dir = os.path.join(self.load_dir, sub_dir, 'image')
            if not os.path.isdir(image_dir):
                continue
            self.data_keys.extend(
                os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir)))
        self.length = len(self.data_keys)

        # RandomResizedCrop's ``ratio`` is width / height; pinning both ends
        # keeps the crop at the output aspect ratio.
        sample_ratio = sample_size[1] / sample_size[0]
        clip_ratio = clip_size[1] / clip_size[0]
        self.img_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                sample_size, scale=scale_aug, ratio=(sample_ratio, sample_ratio), antialias=True),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.clip_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                clip_size, scale=scale_aug, ratio=(clip_ratio, clip_ratio), antialias=True),
            transforms.Normalize([0.485, 0.456, 0.406],  # used for dino
                                 [0.229, 0.224, 0.225],  # used for dino
                                 inplace=True),
        ])
        self.pose_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                sample_size, scale=scale_aug, ratio=(sample_ratio, sample_ratio), antialias=True),
        ])

        self.local_rcr = RandomResizedCropCoord(
            self.sample_size,
            scale=(0.05, 0.25),
            ratio=(sample_ratio, sample_ratio), interpolation=Image.BICUBIC)

    def __len__(self):
        return len(self.data_keys)

    @staticmethod
    def crop_pose_dict(pose_dict, x_min, x_max, y_min, y_max):
        """Return a copy of ``pose_dict`` re-expressed relative to a pixel crop.

        Keypoints are assumed to be normalized by the dict's ``W``/``H``
        (they are multiplied by them here). They are converted to pixels,
        shifted by ``(x_min, y_min)`` and re-normalized by the crop size;
        ``H``/``W`` are replaced by the crop size. The input is not mutated.
        """
        pose_dict = deepcopy(pose_dict)
        bodies_nodes = np.array(pose_dict['bodies']['candidate'])
        hands_nodes = np.array(pose_dict['hands'])
        H = pose_dict['H']
        W = pose_dict['W']

        # normalized -> absolute pixels
        bodies_nodes[:, 0] *= W
        bodies_nodes[:, 1] *= H
        hands_nodes[:, :, 0] *= W
        hands_nodes[:, :, 1] *= H

        # shift into the crop's frame
        bodies_nodes[:, 0] -= x_min
        bodies_nodes[:, 1] -= y_min
        hands_nodes[:, :, 0] -= x_min
        hands_nodes[:, :, 1] -= y_min

        # epsilon avoids division by zero for a degenerate crop
        crop_w = x_max - x_min + 1e-10
        crop_h = y_max - y_min + 1e-10

        bodies_nodes[:, 0] /= crop_w
        bodies_nodes[:, 1] /= crop_h
        hands_nodes[:, :, 0] /= crop_w
        hands_nodes[:, :, 1] /= crop_h

        pose_dict['H'] = crop_h
        pose_dict['W'] = crop_w

        pose_dict['bodies']['candidate'] = bodies_nodes
        pose_dict['hands'] = hands_nodes
        return pose_dict

    def _local_box_from_pose(self, pose):
        """Compute the local ``(top, left, height, width)`` box from a pose dict.

        The box is centred on one randomly chosen hand (``local_body ==
        'hands'``) or on the first face, has at most half the sample size,
        and is snapped to the ``sample_size`` aspect ratio. Coordinates are in
        ``sample_size`` pixels. Returns ``None`` when a keypoint of the chosen
        part is negative (presumably "not detected" in DWPose output — confirm)
        or the box collapses to zero size.
        """
        h, w = self.sample_size
        if self.local_body == 'hands':
            points = pose['hands'][np.random.choice(2)]
        else:
            points = pose['faces'][0]
        x0, y0 = np.min(points, axis=0)
        x1, y1 = np.max(points, axis=0)
        if not (x0 >= 0 and x1 >= 0 and y1 >= 0 and y0 >= 0):
            return None
        # keypoints are treated as normalized coordinates
        x0, x1 = int(x0 * w), int(x1 * w)
        y0, y1 = int(y0 * h), int(y1 * h)

        # Fixed window of half the sample size, centred on the keypoints' bbox
        # and clipped to the image.
        h_l, w_l = h // 2, w // 2
        cx, cy = (x0 + x1) // 2, (y0 + y1) // 2
        x0_n = np.clip(cx - w_l // 2, 0, w)
        x1_n = np.clip(cx + w_l // 2, 0, w)
        y0_n = np.clip(cy - h_l // 2, 0, h)
        y1_n = np.clip(cy + h_l // 2, 0, h)

        # Snap to the reduced aspect ratio h_f : w_f of the sample size.
        g = np.gcd(h, w)
        h_f, w_f = h // g, w // g
        w_l = (x1_n - x0_n) // w_f * w_f
        # When the clipped height cannot hold the width-derived height and the
        # box would run past the bottom edge, derive the box from the height.
        if (y1_n - y0_n) // h_f * h_f < w_l * h_f // w_f and (y0_n + w_l * h // w) >= h:
            h_l = (y1_n - y0_n) // h_f * h_f
            w_l = h_l * w // h
        else:
            h_l = w_l * h // w

        if not (h_l > 0 and w_l > 0):
            return None
        return (int(y0_n), int(x0_n), int(h_l), int(w_l))

    def get_dwpose_hand_pos(self, pkl_path):
        """Load a DWPose pickle and return its local box (see ``_local_box_from_pose``)."""
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted datasets.
        with open(pkl_path, 'rb') as f:
            pkl = pickle.load(f)
        return self._local_box_from_pose(pkl)

    def get_human_crop_xy(self, data_id, batch_index):
        """Return ``(x_min, x_max, y_min, y_max)`` of the non-zero SMPL render.

        The mask is the union over the frames in ``batch_index`` of the video
        ``<load_dir>/smpl/<data_id>``. Frames are assumed to be (N, H, W, C) —
        TODO confirm against decord's ``get_batch`` output.

        Raises:
            ValueError: if the mask is empty.
        """
        smpl_path = os.path.join(self.load_dir, 'smpl', data_id)
        smpl_reader = VideoReader(smpl_path)
        frames = smpl_reader.get_batch(batch_index).asnumpy().astype(dtype=np.float32)
        mask = frames.sum(0).sum(-1) > 0
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            raise ValueError(f"empty SMPL mask in {smpl_path}")
        # np.nonzero indices already lie inside [0, w-1] / [0, h-1].
        return xs.min(), xs.max(), ys.min(), ys.max()

    @staticmethod
    def _pose_path(data_id):
        """Map ``<sample>/image/<name>.<ext>`` to ``<sample>/dwpose/<name>.pkl``.

        Only the ``image`` folder that directly contains the file is swapped,
        so ancestor directories whose names contain "image" are left intact.
        """
        image_dir, file_name = os.path.split(data_id)
        sample_dir = os.path.dirname(image_dir)
        stem = os.path.splitext(file_name)[0]
        return os.path.join(sample_dir, 'dwpose', stem + '.pkl')

    def get_batch(self, data_id):
        """Load one image with its rendered pose and local box.

        Returns:
            ``(image, pose, lpos)``: ``image`` is a (3, H, W) float tensor in
            [0, 1] (the file is converted to RGB); ``pose`` is the output of
            ``draw_pose`` at the image size, permuted to channels-first and
            scaled by 1/255; ``lpos`` is the local box or ``None``.
        """
        pose_path = self._pose_path(data_id)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted datasets.
        with open(pose_path, 'rb') as pose_file:
            pose_dict = pickle.load(pose_file)

        with Image.open(data_id) as img:
            image = np.array(img.convert('RGB'))
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous() / 255.0

        # Computed from the raw annotation before draw_pose touches the dict.
        lpos = self._local_box_from_pose(pose_dict)

        pose = draw_pose(pose_dict, image.shape[-2], image.shape[-1])
        pose = torch.from_numpy(
            np.stack(pose, axis=0)).permute(2, 0, 1).contiguous() / 255.0
        return image, pose, lpos

    @staticmethod
    def augmentation(frame, transform, state=None):
        """Apply ``transform`` to ``frame``, first restoring torch RNG ``state`` if given."""
        if state is not None:
            torch.set_rng_state(state)
        return transform(frame)

    def _load_sample(self, data_id):
        """Load, augment and package one sample; raises on any failure."""
        image, pose, lpos = self.get_batch(data_id)
        # Replaying the same RNG state gives image and pose the identical random crop.
        state = torch.get_rng_state()
        image = self.augmentation(image, self.img_transform, state)
        pose = self.augmentation(pose, self.pose_transform, state)

        if lpos is None:
            print('no hand detected')
            lpos = self.local_rcr.get_params(image, self.local_rcr.scale, self.local_rcr.ratio)

        # Snap the box to multiples of 8 (presumably the latent downsampling factor — confirm).
        lpos = [v // 8 * 8 for v in lpos]
        # NOTE(review): the local image crop is computed but not returned; either
        # add it to the sample or drop this call.
        _, limg = self.local_rcr(image, lpos)
        _, ldensepose = self.local_rcr(pose, lpos)
        sample = dict(
            image=image,
            pose=pose,
            lpose=ldensepose,
            lpos=torch.tensor(lpos),
        )
        if self.text_prompt:
            # The prompt is drawn at random and is not tied to this image.
            text_idx = np.random.choice(len(self.texts))
            sample['text'] = self.texts[text_idx]['description']
        return sample

    def __getitem__(self, idx):
        """Return sample ``idx``; on failure, retry with random indices.

        Raises:
            RuntimeError: if ``_MAX_LOAD_TRIES`` consecutive attempts fail
                (previously ``None`` was returned, which breaks collation).
        """
        for _ in range(self._MAX_LOAD_TRIES):
            try:
                return self._load_sample(self.data_keys[idx])
            except Exception as e:  # best-effort: skip unreadable samples
                print(f"read idx:{idx} error, {type(e).__name__}: {e}")
                print(traceback.format_exc())
                idx = random.randint(0, self.length - 1)
        raise RuntimeError(
            f"failed to load a valid sample after {self._MAX_LOAD_TRIES} attempts")


if __name__ == "__main__":
    # Debug entry point: dump a few batches to disk for visual inspection.
    # NOTE: this module uses relative imports, so run it as a module
    # (``python -m <package>.<module>``) rather than as a plain script.
    save_dir = "./debug/video_dataset/img/"
    os.makedirs(save_dir, exist_ok=True)
    val_data = VideoDataset(
        root_dir="/groupnas/zhoujingkai.zjk/data/video_dataset/taobao_dance/",
        split="", do_human_crop=False, ref_mode='random', pos_embed=True, use_hamer=True, image_finetune=True)
    val_loader = DataLoader(val_data, batch_size=3, shuffle=False, num_workers=1)
    for i, sample in enumerate(val_loader):
        # Only keys actually produced by VideoDataset.__getitem__ are used here.
        print(f"image shape is {sample['image'].shape}")
        print(f"pose shape is {sample['pose'].shape}")
        print(f"local pose shape is {sample['lpose'].shape}")
        print(f"local box is {sample['lpos'].tolist()}")
        # image is normalized with mean/std 0.5, so map it back before saving;
        # pose is already scaled by 1/255.
        save_obj = torch.cat([
            (sample["image"].cpu() / 2 + 0.5).clamp(0, 1),
            sample["pose"].cpu(),
        ], dim=-1)
        save_image(save_obj, os.path.join(save_dir, f"sample_{i}.png"))
        save_image(sample["lpose"].cpu(), os.path.join(save_dir, f"local_pose_{i}.png"))
        if i > 20:
            break