import os
import os.path as osp
from os import PathLike
import pickle
import numpy as np
import torch.distributed as dist
from torchvision.datasets import ImageNet as ImageNet_pytorch
from PIL import Image

from torchvision.transforms import Compose
from .base_dataset import BaseDataset
from utils.parallel import get_dist_info
from .utils import download_and_extract_archive, check_integrity

def expanduser(path):
    """Expand a leading ``~`` in *path*; pass non-path values through unchanged.

    ``str`` and ``os.PathLike`` inputs are expanded with ``os.path.expanduser``
    (which returns a ``str``). Anything else (e.g. ``None``) is returned as-is,
    so optional path arguments can be forwarded without a ``None`` check.
    """
    if not isinstance(path, (str, PathLike)):
        return path
    return osp.expanduser(path)


def construct_data_list(root, classes, augs, tag, test_mode, train_ratio=0.8):
    """Collect image entries for one modality of the Sketchy dataset.

    Files are read from ``{root}/rendered_256x256/256x256/{tag}/{aug}/{cls}/``
    for every ``aug`` in ``augs`` and ``cls`` in ``classes``. Inside each class
    folder the ``.png``/``.jpg`` file names are sorted and split. The first
    ``int(n * train_ratio)`` go to the training split and the rest to the test
    split.

    Args:
        root: dataset root directory.
        classes: class sub-folder names to scan.
        augs: augmentation sub-folder names to scan.
        tag: ``'photo'`` or ``'sketch'``. Selects the sub-tree and how the
            pairing key is derived from a file name.
        test_mode: if True return the test split, otherwise the training split.
        train_ratio: fraction of each class folder assigned to the training
            split (default 0.8).

    Returns:
        ``(datas, datas_dict)``:

        - ``datas`` is a list of ``[aug, cls, file_name]`` entries.
        - ``datas_dict`` groups copies of the same entries by pairing key. For
          photos the key is the file stem (``'x_5.jpg'`` -> ``'x_5'``). For
          sketches it is the stem up to the first ``'-'``
          (``'x_5-3.png'`` -> ``'x_5'``).

    Raises:
        ValueError: if ``tag`` is neither ``'photo'`` nor ``'sketch'``.
    """
    if tag not in ('photo', 'sketch'):
        raise ValueError(f"tag must be 'photo' or 'sketch', got {tag!r}")

    datas = []
    datas_dict = {}
    for data_aug in augs:
        prefix = os.path.join(root, f'rendered_256x256/256x256/{tag}/{data_aug}')
        for cls_tag in classes:
            cls_dir = os.path.join(prefix, cls_tag)
            # os.listdir order is arbitrary. Sorting makes the split
            # reproducible and identical across augmentation folders, so one
            # image cannot be in train for one aug and in test for another.
            fig_names = sorted(name for name in os.listdir(cls_dir)
                               if name.endswith(('.png', '.jpg')))
            split_num = int(len(fig_names) * train_ratio)
            fig_names = fig_names[split_num:] if test_mode else fig_names[:split_num]
            for fig_name in fig_names:
                datas.append([data_aug, cls_tag, fig_name])
                stem = os.path.splitext(fig_name)[0]
                # Sketch names carry a "-<n>" suffix per drawing of the same photo.
                # partition() keeps the whole stem when no '-' is present.
                fig_idx = stem if tag == 'photo' else stem.partition('-')[0]
                datas_dict.setdefault(fig_idx, []).append([data_aug, cls_tag, fig_name])
    return datas, datas_dict

class SketchyPair(BaseDataset):
    """Paired photo/sketch dataset built from the Sketchy ``rendered_256x256`` tree.

    Each sample is a ``(photo_img, sketch_img)`` tuple. A photo and a sketch
    are paired when the photo's file stem equals the sketch's file-name prefix
    before the first ``'-'``. Every photo entry is paired with every matching
    sketch entry across all selected augmentation folders.

    Files are read from::

        {root}/rendered_256x256/256x256/{photo,sketch}/{aug}/{class}/{file}
    """

    def __init__(self,
                 root,
                 photo_transforms,
                 sketch_transforms,
                 target_transforms=None,
                 used_ratio=1,
                 classes=None,
                 photo_augs=None,
                 sketch_augs=None,
                 ann_file=None,
                 test_mode=False):
        """
        Args:
            root: dataset root. A leading ``~`` is expanded.
            photo_transforms: callable applied to each photo PIL image. ``None``
                falls back to ``self.DEFAULT_TRNASFORMS``.
            sketch_transforms: same as ``photo_transforms``, for sketches.
            target_transforms: stored but not used by this paired dataset.
            used_ratio: if < 1, randomly keep this fraction of the pairs.
            classes: passed to ``self.get_classes``. ``self.CLASSES`` is then
                overwritten by ``load_annotations``.
            photo_augs: photo augmentation folder names. ``None`` selects the
                default in ``load_annotations``.
            sketch_augs: sketch augmentation folder names. ``None`` selects the
                defaults in ``load_annotations``.
            ann_file: stored (``~``-expanded) but not read by this class.
            test_mode: use the held-out split instead of the training split.
        """
        # NOTE(review): super(BaseDataset, self) skips BaseDataset.__init__ and
        # calls the next class in the MRO. Presumably intentional because
        # BaseDataset's constructor takes different arguments. TODO confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if photo_transforms is None:
            photo_transforms = self.DEFAULT_TRNASFORMS
        self.photo_transforms = photo_transforms
        if sketch_transforms is None:
            sketch_transforms = self.DEFAULT_TRNASFORMS
        self.sketch_transforms = sketch_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.photo_augs = photo_augs
        self.sketch_augs = sketch_augs
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.used_ratio = used_ratio
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Build the list of ``{'photo': entry, 'sketch': entry}`` pairs.

        Each entry is ``[aug, class, file_name]`` as returned by
        ``construct_data_list``. This method also sets ``self.CLASSES`` to the
        sorted, non-hidden class folder names under the default photo
        augmentation folder.
        """
        photo_augs = self.photo_augs
        if photo_augs is None:
            photo_augs = ['tx_000000000000']
        sketch_augs = self.sketch_augs
        if sketch_augs is None:
            sketch_augs = ['tx_000000000000', 'tx_000000000010', 'tx_000000000110',
                           'tx_000000001010', 'tx_000000001110']

        class_root = os.path.join(self.root, 'rendered_256x256/256x256/photo/tx_000000000000')
        # Sorted because os.listdir order is arbitrary. This keeps the class
        # list and the resulting pair order reproducible.
        self.CLASSES = sorted(cls_tag for cls_tag in os.listdir(class_root)
                              if not cls_tag.startswith('.'))
        photos_list, photos_dict = construct_data_list(self.root, self.CLASSES, photo_augs, 'photo', self.test_mode)
        sketchs_list, sketchs_dict = construct_data_list(self.root, self.CLASSES, sketch_augs, 'sketch', self.test_mode)
        print(f'Photo List Num:{len(photos_list)}, Photo Dict Num:{len(photos_dict)}')
        print(f'Sketch List Num:{len(sketchs_list)}, Sketch Dict Num:{len(sketchs_dict)}')

        # Cartesian product of photos and sketches that share a pairing key.
        # Photos without any sketch contribute nothing.
        data_infos = [
            {'photo': photo, 'sketch': sketch}
            for fig_idx, photos in photos_dict.items()
            for photo in photos
            for sketch in sketchs_dict.get(fig_idx, [])
        ]
        if self.used_ratio < 1:
            selected_idxs = np.random.choice(len(data_infos), int(self.used_ratio * len(data_infos)), replace=False)
            data_infos = np.array(data_infos)[selected_idxs]
        return data_infos

    def __getitem__(self, idx):
        """Return ``(photo_img, sketch_img)`` for pair *idx*.

        Each image is opened with PIL and passed through its transform when
        one is set. Otherwise the PIL image is returned as-is.
        """
        info = self.data_infos[idx]
        photo, sketch = info['photo'], info['sketch']
        photo_path = os.path.join(self.root, 'rendered_256x256/256x256/photo/{}/{}/{}'.format(*photo))
        sketch_path = os.path.join(self.root, 'rendered_256x256/256x256/sketch/{}/{}/{}'.format(*sketch))
        # Open unconditionally so a None transform still yields an image
        # instead of an unbound local.
        photo_img = Image.open(photo_path)
        if self.photo_transforms is not None:
            photo_img = self.photo_transforms(photo_img)
        sketch_img = Image.open(sketch_path)
        if self.sketch_transforms is not None:
            sketch_img = self.sketch_transforms(sketch_img)
        return photo_img, sketch_img


class SketchyPhoto(BaseDataset):
    """Photo-only classification dataset built from the Sketchy ``rendered_256x256`` tree.

    Each sample is ``(photo_img, gt_label)``. ``gt_label`` is the int32 index
    of the photo's class folder in ``self.CLASSES``. Class folders are sorted
    by name, so the class→label mapping does not depend on filesystem listing
    order.

    Files are read from::

        {root}/rendered_256x256/256x256/photo/{aug}/{class}/{file}
    """

    def __init__(self,
                 root,
                 photo_transforms,
                 target_transforms=None,
                 used_ratio=1,
                 classes=None,
                 photo_augs=None,
                 ann_file=None,
                 test_mode=False):
        """
        Args:
            root: dataset root. A leading ``~`` is expanded.
            photo_transforms: callable applied to each photo PIL image. ``None``
                falls back to ``self.DEFAULT_TRNASFORMS``.
            target_transforms: optional callable applied to ``gt_label``.
            used_ratio: if < 1, randomly keep this fraction of the samples.
            classes: passed to ``self.get_classes``. ``self.CLASSES`` is then
                overwritten by ``load_annotations``.
            photo_augs: augmentation folder names. ``None`` selects the default
                in ``load_annotations``.
            ann_file: stored (``~``-expanded) but not read by this class.
            test_mode: use the held-out split instead of the training split.
        """
        # NOTE(review): super(BaseDataset, self) skips BaseDataset.__init__ and
        # calls the next class in the MRO. Presumably intentional because
        # BaseDataset's constructor takes different arguments. TODO confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if photo_transforms is None:
            photo_transforms = self.DEFAULT_TRNASFORMS
        self.photo_transforms = photo_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.photo_augs = photo_augs
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.used_ratio = used_ratio
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Build ``{'photo': entry, 'gt_label': int32 array}`` sample dicts.

        Each entry is ``[aug, class, file_name]`` as returned by
        ``construct_data_list``. This method also sets ``self.CLASSES``
        (sorted, non-hidden class folders) and ``self.class_dict``
        (class name -> label index).
        """
        photo_augs = self.photo_augs
        if photo_augs is None:
            photo_augs = ['tx_000000000000']

        class_root = os.path.join(self.root, 'rendered_256x256/256x256/photo/tx_000000000000')
        # Sorted because os.listdir order is arbitrary. Without sorting, label
        # indices would change between machines/filesystems.
        self.CLASSES = sorted(cls_tag for cls_tag in os.listdir(class_root)
                              if not cls_tag.startswith('.'))
        self.class_dict = {cls_tag: i for i, cls_tag in enumerate(self.CLASSES)}
        photos_list, photos_dict = construct_data_list(self.root, self.CLASSES, photo_augs, 'photo', self.test_mode)
        print(f'Photo List Num:{len(photos_list)}, Photo Dict Num:{len(photos_dict)}')

        data_infos = [
            {'photo': photo, 'gt_label': np.array(self.class_dict[photo[1]]).astype(np.int32)}
            for photo in photos_list
        ]
        if self.used_ratio < 1:
            selected_idxs = np.random.choice(len(data_infos), int(self.used_ratio * len(data_infos)), replace=False)
            data_infos = np.array(data_infos)[selected_idxs]
        return data_infos

    def __getitem__(self, idx):
        """Return ``(photo_img, gt_label)`` for sample *idx*, with transforms applied when set."""
        info = self.data_infos[idx]
        photo, gt_label = info['photo'], info['gt_label']
        photo_path = os.path.join(self.root, 'rendered_256x256/256x256/photo/{}/{}/{}'.format(*photo))
        # Open unconditionally so a None transform still yields an image
        # instead of an unbound local.
        photo_img = Image.open(photo_path)
        if self.photo_transforms is not None:
            photo_img = self.photo_transforms(photo_img)
        if self.target_transforms is not None:
            gt_label = self.target_transforms(gt_label)
        return photo_img, gt_label


class SketchySketch(BaseDataset):
    """Sketch-only classification dataset built from the Sketchy ``rendered_256x256`` tree.

    Each sample is ``(sketch_img, gt_label)``. ``gt_label`` is the int32 index
    of the sketch's class folder in ``self.CLASSES``. Class folders are sorted
    by name, so the class→label mapping does not depend on filesystem listing
    order.

    Files are read from::

        {root}/rendered_256x256/256x256/sketch/{aug}/{class}/{file}
    """

    def __init__(self,
                 root,
                 sketch_transforms,
                 target_transforms=None,
                 used_ratio=1,
                 classes=None,
                 sketch_augs=None,
                 ann_file=None,
                 test_mode=False):
        """
        Args:
            root: dataset root. A leading ``~`` is expanded.
            sketch_transforms: callable applied to each sketch PIL image.
                ``None`` falls back to ``self.DEFAULT_TRNASFORMS``.
            target_transforms: optional callable applied to ``gt_label``.
            used_ratio: if < 1, randomly keep this fraction of the samples.
            classes: passed to ``self.get_classes``. ``self.CLASSES`` is then
                overwritten by ``load_annotations``.
            sketch_augs: augmentation folder names. ``None`` selects the default
                in ``load_annotations``.
            ann_file: stored (``~``-expanded) but not read by this class.
            test_mode: use the held-out split instead of the training split.
        """
        # NOTE(review): super(BaseDataset, self) skips BaseDataset.__init__ and
        # calls the next class in the MRO. Presumably intentional because
        # BaseDataset's constructor takes different arguments. TODO confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if sketch_transforms is None:
            sketch_transforms = self.DEFAULT_TRNASFORMS
        self.sketch_transforms = sketch_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.sketch_augs = sketch_augs
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.used_ratio = used_ratio
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Build ``{'sketch': entry, 'gt_label': int32 array}`` sample dicts.

        Each entry is ``[aug, class, file_name]`` as returned by
        ``construct_data_list``. This method also sets ``self.CLASSES``
        (sorted, non-hidden class folders of the default sketch folder) and
        ``self.class_dict`` (class name -> label index).
        """
        sketch_augs = self.sketch_augs
        if sketch_augs is None:
            sketch_augs = ['tx_000000000000']

        class_root = os.path.join(self.root, 'rendered_256x256/256x256/sketch/tx_000000000000')
        # Sorted because os.listdir order is arbitrary. Without sorting, label
        # indices would change between machines/filesystems.
        self.CLASSES = sorted(cls_tag for cls_tag in os.listdir(class_root)
                              if not cls_tag.startswith('.'))
        self.class_dict = {cls_tag: i for i, cls_tag in enumerate(self.CLASSES)}
        sketchs_list, sketchs_dict = construct_data_list(self.root, self.CLASSES, sketch_augs, 'sketch', self.test_mode)
        print(f'sketch List Num:{len(sketchs_list)}, sketch Dict Num:{len(sketchs_dict)}')

        data_infos = [
            {'sketch': sketch, 'gt_label': np.array(self.class_dict[sketch[1]]).astype(np.int32)}
            for sketch in sketchs_list
        ]
        if self.used_ratio < 1:
            selected_idxs = np.random.choice(len(data_infos), int(self.used_ratio * len(data_infos)), replace=False)
            data_infos = np.array(data_infos)[selected_idxs]
        return data_infos

    def __getitem__(self, idx):
        """Return ``(sketch_img, gt_label)`` for sample *idx*, with transforms applied when set."""
        info = self.data_infos[idx]
        sketch, gt_label = info['sketch'], info['gt_label']
        sketch_path = os.path.join(self.root, 'rendered_256x256/256x256/sketch/{}/{}/{}'.format(*sketch))
        # Open unconditionally so a None transform still yields an image
        # instead of an unbound local.
        sketch_img = Image.open(sketch_path)
        if self.sketch_transforms is not None:
            sketch_img = self.sketch_transforms(sketch_img)
        if self.target_transforms is not None:
            gt_label = self.target_transforms(gt_label)
        return sketch_img, gt_label
