#for multiview
import os
import os.path as osp
import numpy as np
import torchvision.transforms as standard
from torch.utils.data import Dataset
import random
import cv2
import json
from utils.laban_encoder import laban_encoder

def load_img_from_id(img_path):
    """Read the image stored at ``img_path`` with OpenCV.

    Args:
        img_path (string): Path of the image file to read.

    Returns:
        The array returned by ``cv2.imread`` for that path.

    Raises:
        OSError: If ``cv2.imread`` returns ``None``, which is how OpenCV
            reports a file that is missing or cannot be decoded.
    """
    img = cv2.imread(img_path)
    if img is None:
        # Fail here, with the offending path, instead of letting a later
        # blur / ``astype`` call crash on ``None`` with an unrelated message.
        raise OSError('Failed to read image: {!r}'.format(img_path))
    return img
def gen_frame_shift(view_num, img_num, clip_step):
    """Build randomly shifted frame pairs for every ordered pair of distinct views.

    The number of clips is ``int(img_num / clip_step - 1)``, reduced by one
    more when ``img_num % clip_step`` is below 5. For each clip ``k`` the
    reference frame is ``k * clip_step + 1``, and one random shift in
    ``[0, 5]`` is drawn to produce two tuples.

    Args:
        view_num (int): Number of views; views are numbered from 1.
        img_num (int): Number of frames available.
        clip_step (int): Distance between the reference frames of two clips.

    Returns:
        list: Tuples ``(view_a, frame_a, view_b, frame_b)``. They come in
        consecutive couples sharing ``view_a``, ``frame_a`` and ``view_b``:
        the first has ``frame_b = frame_a + shift`` and the second has
        ``frame_b = frame_a + clip_step - 1 - shift``.
    """
    leftover = img_num % clip_step
    n_clips = int(img_num / clip_step - 1)
    if leftover < 5:
        n_clips -= 1

    pairs = []
    for src_view in range(1, view_num + 1):
        for dst_view in range(1, view_num + 1):
            if src_view != dst_view:
                for clip in range(n_clips):
                    start = clip * clip_step + 1
                    # One draw per clip, shared by both tuples of the couple.
                    shift = random.randint(0, 5)
                    pairs.extend([
                        (src_view, start, dst_view, start + shift),
                        (src_view, start, dst_view, start + clip_step - 1 - shift),
                    ])
    return pairs

# def load_img_from_id(img_dir, img_id):
#     img_name = str(img_id)

#     for i in range(4 - len(img_name)):
#         img_name = '0' + img_name
    
#     img_name += '.jpg'
#     img_path = osp.join(img_dir, img_name)
#     img = cv2.imread(img_path)
#     img = cv2.resize(img, dsize = (960, 512))
#     return img


def blurAugmentation(img, prob):
    """Blur ``img`` with a randomly chosen OpenCV filter, with chance ``prob``.

    A uniform draw from ``random.random()`` is compared against ``prob``; when
    the draw is not ``<= prob`` the image is returned untouched. Otherwise a
    filter (box, median or Gaussian) and an odd kernel size from 3 to 13 are
    picked at random. The Gaussian filter also draws a sigma from 0 to 5,
    used for both axes.

    Args:
        img: Image passed straight to the OpenCV blur functions.
        prob (float): Threshold compared against the uniform draw.

    Returns:
        The blurred image, or ``img`` itself when no blur was applied.
    """
    if not (random.random() <= prob):
        return img

    # Draw order (filter kind, then kernel size) is kept so that a seeded
    # random stream yields the same augmentation sequence.
    blur_kind = random.randint(1, 3)
    ksize = random.choice([3, 5, 7, 9, 11, 13])

    if blur_kind == 3:
        sigma = random.choice([0, 1, 2, 3, 4, 5])
        return cv2.GaussianBlur(img, (ksize, ksize), sigmaX=sigma, sigmaY=sigma)
    if blur_kind == 2:
        return cv2.medianBlur(img, ksize)
    if blur_kind == 1:
        return cv2.blur(img, (ksize, ksize))
    return img

class HandLanbanDataset(Dataset):
    """Multi-view hand image dataset paired with Laban annotations.

    Every sample is one frame seen from four camera views. Images are read
    from ``<root_dir>/images/<split>/Capture<i>/<motion>/<camera>/<frame>``
    and annotations from
    ``<root_dir>/annotations/<split>/InterHand2.6M_<split>_{laban,img2frame}.json``.
    """

    # Number of ``Capture<i>`` directories (i = 0 .. n-1) scanned per split.
    _CAPTURES_PER_SPLIT = {'train': 27, 'test': 8, 'val': 1}
    # Positions, in the sorted camera-directory listing, of the views used.
    _CAMERA_CHOICES = (0, 3, 8, 10)
    # Probability handed to blurAugmentation for every view.
    _BLUR_PROB = 0.5
    # Mask layout: the leading block is switched on for 'right', the trailing
    # block for 'left', and both for 'interacting'.
    _MASK_HEAD_LEN = 21
    _MASK_TAIL_LEN = 26

    def __init__(self, root_dir, split, aug = True, clip_step = 10):
        """
        Args:
            root_dir (string): Directory holding the ``images`` and
                ``annotations`` sub-directories.
            split (string): Name of the split ('train', 'test' or 'val').
            aug (bool): Whether blur augmentation is applied to the images
                returned by ``__getitem__``.
            clip_step (int): Stored on the instance; not used by this class.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        # Validate first (and with a real exception rather than ``assert``)
        # so that an unknown split can never yield a silently empty dataset.
        if split not in self._CAPTURES_PER_SPLIT:
            raise ValueError('Split "{}" not supported'.format(split))

        # NOTE(review): cv2.imread yields BGR channel order, while these look
        # like the usual ImageNet RGB statistics — confirm this is intended.
        mean_std = ([0.485,0.456,0.406],[0.229,0.224,0.225])
        self.root_dir = root_dir
        self.split = split
        self.aug = aug
        self.clip_step = clip_step
        self.video_clip_num = 0
        self.images_name_list = []
        # Lazily filled by _load_annotations().
        self._annotations = None

        self.transform = standard.Compose(
            [
                standard.ToTensor(),
                standard.Normalize(*mean_std)
            ]
        )

        self.image_dir = osp.join(self.root_dir, 'images')
        self.split_image_dir = osp.join(self.image_dir, self.split)

        for i in range(self._CAPTURES_PER_SPLIT[self.split]):
            capture_image_dir = osp.join(self.split_image_dir, 'Capture' + str(i))
            for motion_name in sorted(os.listdir(capture_image_dir)):
                motion_dir = osp.join(capture_image_dir, motion_name)
                camera_list = sorted(os.listdir(motion_dir))
                # Frame names are taken from the first camera directory and
                # reused for every selected view.
                image_list = sorted(os.listdir(osp.join(motion_dir, camera_list[0])))
                if not image_list:
                    continue
                self.video_clip_num += len(image_list)
                view_dirs = [osp.join(motion_dir, camera_list[k])
                             for k in self._CAMERA_CHOICES]
                for item in image_list:
                    self.images_name_list.append(
                        [osp.join(view_dir, item) for view_dir in view_dirs])

        print('Loading', split, 'data......')

        if self.aug:
            print('Using blur augmentation!')

    def __len__(self):
        """
        Returns:
            int: Number of samples in the dataset.
        """
        return self.video_clip_num

    def _load_annotations(self):
        """Return ``(laban_data, img2frame)``, parsing the JSON files only once.

        The files used to be re-parsed for every sample; they are now cached
        on the instance after the first access and both handles are closed.
        """
        if getattr(self, '_annotations', None) is None:
            split_label_dir = osp.join(self.root_dir, 'annotations', self.split)
            laban_path = osp.join(
                split_label_dir, 'InterHand2.6M_' + self.split + '_laban.json')
            img2frame_path = osp.join(
                split_label_dir, 'InterHand2.6M_' + self.split + '_img2frame.json')
            with open(laban_path, 'r', encoding='utf8') as fp:
                laban_data = json.load(fp)
            # Opened with the platform default encoding, as before.
            with open(img2frame_path, 'r') as fp:
                img2frame = json.load(fp)
            self._annotations = (laban_data, img2frame)
        return self._annotations

    def _build_mask(self, encoding, hand_type):
        """Derive the hand mask from a copy of ``encoding``.

        For a list encoding the result is 21 leading entries followed by 26
        trailing entries, set to 1 or 0 according to ``hand_type``. An
        unrecognised ``hand_type`` prints a warning and returns the plain copy.
        """
        mask = encoding.copy()
        if hand_type == 'right':
            head, tail = 1, 0
        elif hand_type == 'left':
            head, tail = 0, 1
        elif hand_type == 'interacting':
            head, tail = 1, 1
        else:
            print("Unexpected value of 'hand_type'. It should be 'right', 'left', or 'interacting'.")
            return mask
        # The leading slice receives exactly as many values as it spans. For
        # lists this gives the same result as before; for fixed-size arrays it
        # avoids a shape mismatch.
        mask[:self._MASK_HEAD_LEN] = [head] * self._MASK_HEAD_LEN
        mask[self._MASK_HEAD_LEN:] = [tail] * self._MASK_TAIL_LEN
        return mask

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the sample.

        Returns:
            tuple: ``(inputs, targets)``. ``inputs`` maps ``'img_0'`` ..
            ``'img_3'`` to the transformed image of each selected view.
            ``targets`` maps ``'annotation_laban_0'`` to the encoded Laban
            annotation of the first view and ``'laban_0_mask'`` to its mask.
        """
        import copy

        view_paths = self.images_name_list[index]

        inputs = {}
        for view_idx, img_path in enumerate(view_paths):
            img = load_img_from_id(img_path)
            if self.aug:
                img = blurAugmentation(img, self._BLUR_PROB)
            inputs['img_{}'.format(view_idx)] = self.transform(img.astype(np.uint8))

        laban_data, img2frame = self._load_annotations()
        # Annotations are looked up through the first view's image path.
        frame_info = img2frame[view_paths[0]]
        capture_id_0 = frame_info['capture_id']
        frame_idx_0 = frame_info['frame_idx']
        hand_type_0 = frame_info['hand_type']

        # The annotation now comes from a shared cache, so hand the encoder a
        # private copy in case it modifies its argument.
        annotation_laban_0 = copy.deepcopy(laban_data[capture_id_0][frame_idx_0])
        annotation_laban_0_encode = laban_encoder(annotation_laban_0)

        # generate laban encode mask
        laban_0_mask = self._build_mask(annotation_laban_0_encode, hand_type_0)

        targets = {'annotation_laban_0':annotation_laban_0_encode,
                   'laban_0_mask':laban_0_mask}

        return inputs, targets
         
            
if __name__ == '__main__':
    # Smoke test: build the 'test' split from ./data under the current
    # working directory and iterate over it in batches of five samples.
    import torch

    data_dir = osp.join(os.getcwd(), 'data')
    hand_data = HandLanbanDataset(root_dir=data_dir, split='test')
    loader = torch.utils.data.DataLoader(hand_data, batch_size=5)
    for _batch_inputs, _batch_targets in loader:
        print("test")
    
    
            
                
        
        
        
        
        
    

            
            
            
                