#
# Copyright 2022- IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache2.0
#
import os
import os.path as osp

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import pdb


class DomainNet(Dataset):
    """DomainNet image-classification dataset restricted to one domain.

    The split file ``<root>/<subset>/<subset>_{train,test}.txt`` is read line
    by line; every non-empty line holds an image path relative to
    ``<root>/<subset>`` followed by its integer class label. The visible
    samples are then narrowed either to a set of class ids (``index``) or to
    the images listed in ``index_path`` (see ``__init__``).

    ``__getitem__`` returns ``(image_tensor, label)``; images are resized to
    84x84, converted to tensors and normalized with a fixed mean/std (the
    standard ImageNet statistics).
    """

    def __init__(self, root='./data', subset='real', train=True,
                 transform=None,
                 index_path=None, index=None, base_sess=None,
                 do_augment=True):
        """Load the split file and select the samples for this session.

        Args:
            root: Dataset root directory; ``~`` is expanded.
            subset: Domain sub-directory name, e.g. ``'real'``.
            train: Read ``<subset>_train.txt`` if True, else
                ``<subset>_test.txt``.
            transform: Accepted for interface compatibility but currently
                ignored -- it is always replaced by the fixed pipeline built
                below.
            index_path: Text file listing the training images to keep. Only
                used when ``train`` is True and ``base_sess`` is falsy.
            index: Iterable of class ids to keep. Used for the test split and
                for the base training session.
            base_sess: If truthy (together with ``train``), select samples by
                ``index`` instead of by ``index_path``.
            do_augment: When training, add a random horizontal flip.

        Raises:
            OSError: if the split file or ``index_path`` cannot be opened.
            ValueError: if a non-empty line lacks a trailing label field.
            KeyError: if ``index_path`` lists an image absent from the split.
        """
        setname = subset + '_' + ('train' if train else 'test')

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train  # training set or test set
        # Build paths from the *expanded* root so that '~/...' roots work.
        self.IMAGE_PATH = os.path.join(self.root, subset)
        self.SPLIT_PATH = os.path.join(self.root, subset)

        csv_path = osp.join(self.SPLIT_PATH, setname + '.txt')

        self.data = []
        self.targets = []
        self.data2label = {}

        with open(csv_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                # Split off only the last field so names containing spaces
                # still parse; the label is always the final token.
                name, label = line.rsplit(maxsplit=1)
                path = osp.join(self.IMAGE_PATH, name)
                label = int(label)
                self.data.append(path)
                self.targets.append(label)
                self.data2label[path] = label

        image_size = 84
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if train and do_augment:
            self.transform = transforms.Compose([
                transforms.Resize([image_size, image_size]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize])
        else:
            self.transform = transforms.Compose([
                transforms.Resize([image_size, image_size]),
                transforms.ToTensor(),
                normalize])

        # Incremental training sessions pick images from an index file; the
        # base session and every test split pick whole classes by id.
        if train and not base_sess:
            self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
        else:
            self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)

    def SelectfromTxt(self, data2label, index_path):
        """Return the images listed in ``index_path`` and their labels.

        Each non-empty line starts with an image path relative to
        ``self.IMAGE_PATH``; the trailing field is ignored. Labels are looked
        up in ``data2label``, so every listed image must belong to the split
        loaded in ``__init__`` (otherwise ``KeyError``).

        Returns:
            ``(paths, labels)`` as two lists in file order.
        """
        data_tmp = []
        targets_tmp = []
        with open(index_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                pth, _ = line.rsplit(maxsplit=1)
                img_path = os.path.join(self.IMAGE_PATH, pth)
                data_tmp.append(img_path)
                targets_tmp.append(data2label[img_path])
        return data_tmp, targets_tmp

    def SelectfromClasses(self, data, targets, index):
        """Keep only the samples whose label appears in ``index``.

        Samples are grouped class by class in the order given by ``index``;
        within one class the original order of ``data`` is preserved.

        Args:
            data: Sequence of sample paths.
            targets: Sequence of integer labels aligned with ``data``.
            index: Iterable of class ids (Python ints or NumPy integers).

        Returns:
            ``(paths, labels)`` as two lists.
        """
        # Convert once: comparing a Python int with a plain list yields a
        # scalar False (selecting nothing), whereas an array compares
        # element-wise for both Python and NumPy integer class ids.
        targets_arr = np.asarray(targets)
        data_tmp = []
        targets_tmp = []
        for cls in index:
            for j in np.flatnonzero(targets_arr == cls):
                data_tmp.append(data[j])
                targets_tmp.append(targets[j])
        return data_tmp, targets_tmp

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(transformed_image, label)`` for sample ``i``."""
        path, target = self.data[i], self.targets[i]
        # Close the file handle promptly; convert() loads the pixels into a
        # new image, which stays valid after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), target


if __name__ == '__main__':
    # Manual smoke test: build the base-session training split for the first
    # ten classes of the default 'real' domain and print one sample's shape.
    dataroot = '/Data_PHD/phd22_yijie_hu/DomainNet'
    # Only read when base_sess is falsy, so it is unused below.
    # NOTE(review): this points at the mini_imagenet index list -- confirm
    # the intended DomainNet session file before using base_sess=False.
    txt_path = "../../../../data/index_list/mini_imagenet/session_1.txt"
    trainset = DomainNet(root=dataroot, train=True, transform=None,
                         index_path=txt_path, base_sess=True,
                         index=np.arange(10))
    sample, label = trainset[100]
    print(sample.shape)
