from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os, torch, time

class CustomDataset(Dataset):
    """Image-classification dataset backed by a plain-text index file.

    Each non-blank line of ``txt_file`` holds whitespace-separated fields; the
    first is an image path relative to ``root_dir`` and the second is an
    integer class label. Any further fields on a line are ignored.

    Args:
        txt_file: Path to the index file.
        root_dir: Directory that the image paths in ``txt_file`` are joined onto.
        transform: Optional callable applied to each loaded RGB PIL image.
        training: Stored as ``self.training``; not otherwise used by this class.

    Attributes:
        image_list: Relative image paths, in file order.
        id_list: Integer labels, aligned index-for-index with ``image_list``.
        num_classes: ``max(id_list) + 1``, or 0 when the index file has no
            entries. This presumes labels are 0-based ids (TODO confirm).
    """

    def __init__(self, txt_file, root_dir, transform=None, training=False):
        self.image_list = []
        self.id_list = []
        self.root_dir = root_dir
        self.transform = transform
        self.training = training
        with open(txt_file, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                # Skip blank / whitespace-only lines (e.g. a trailing empty line)
                # instead of failing with IndexError on fields[0].
                if not fields:
                    continue
                self.image_list.append(fields[0])
                self.id_list.append(int(fields[1]))
        # Class count is the largest label + 1; an empty index yields 0 classes
        # rather than letting max() raise ValueError on an empty list.
        self.num_classes = max(self.id_list) + 1 if self.id_list else 0

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.id_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image at ``root_dir/image_list[idx]`` is opened and converted to
        RGB; ``transform`` is applied to it when one was supplied.
        """
        path = os.path.join(self.root_dir, self.image_list[idx])
        # convert() returns a new, fully loaded image, so the source handle can
        # be closed right away (matters for multi-frame formats that PIL would
        # otherwise keep open).
        with Image.open(path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, self.id_list[idx]


def iNat2018(scale_size=256, target_size=224, *,
             root_dir="./DataSets/",
             train_txt="./Dataset/iNat2018_train.txt",
             test_txt="./Dataset/iNat2018_val.txt"):
    """Build the iNaturalist 2018 train / validation datasets.

    Args:
        scale_size: Size passed to ``T.Resize`` for the validation pipeline
            before center cropping.
        target_size: Output crop size for both pipelines.
        root_dir: Directory the image paths in the index files are joined onto.
        train_txt: Index file for the training split.
        test_txt: Index file for the validation split (returned as ``test_data``).

    Returns:
        ``(train_data, test_data, num_class)``: two ``CustomDataset`` instances
        and the fixed class count 8142.
    """
    # NOTE(review): images default to "./DataSets/" while the index files default
    # to "./Dataset/" -- confirm both directory spellings are intended.

    # Standard ImageNet mean/std, shared by both pipelines (Normalize is stateless).
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Legacy integer code 3 maps to BICUBIC; use the enum instead of the
    # deprecated int form.
    bicubic = T.InterpolationMode.BICUBIC

    train_transforms = T.Compose([
        T.RandomResizedCrop(target_size, interpolation=bicubic),
        T.RandomHorizontalFlip(0.5),
        T.ToTensor(),
        normalize,
    ])

    test_transforms = T.Compose([
        T.Resize(scale_size, interpolation=bicubic),
        T.CenterCrop(target_size),
        T.ToTensor(),
        normalize,
    ])

    train_data = CustomDataset(txt_file=train_txt, root_dir=root_dir, transform=train_transforms)
    test_data = CustomDataset(txt_file=test_txt, root_dir=root_dir, transform=test_transforms)
    # Hard-coded iNaturalist 2018 class count (not derived from the index files).
    num_class = 8142

    return train_data, test_data, num_class

# train_data, test_data, num_class = iNat2018()

# figure = plt.figure(figsize=(8, 8))
# cols, rows = 8, 8
# for i in range(1, cols * rows +1):
#     sample_idx = torch.randint(len(test_data), size=(1, )).item()
#     img, label = test_data[sample_idx]
#     plt.subplot(cols, rows, i)
#     plt.imshow(img.permute(1,2,0))
#     plt.title(str(label))
#     plt.axis('off')
# plt.savefig("./iNat2018.png", dpi=800, bbox_inches='tight')

# time1 = time.time()
# train_data, test_data, num_class = iNat2018()

# train_dataloader = DataLoader(
#     train_data,
#     batch_size=512,
#     num_workers=0,
#     shuffle=True,
#     drop_last=True
# )

# test_dataloader = DataLoader(
#     test_data,
#     batch_size=512,
#     num_workers=0,
#     shuffle=True
# )
# time2 = time.time()
# print(time2-time1)
# print(len(train_dataloader))
# print(len(test_dataloader))