from .base import *

import numpy as np, os, sys, pandas as pd, csv, copy
import torch
import torchvision
import PIL.Image


class Inshop_Dataset_simple_split(torch.utils.data.Dataset):
    """In-shop Clothes dataset with a simple train / eval split.

    Rows of ``list_eval_partition.txt`` whose status column is ``train`` form
    the training split. Rows tagged ``query`` or ``gallery`` are merged into a
    single evaluation split. Item ids are remapped to dense 0-based class
    labels independently within each split.

    Expected layout::

        <root>/InShop/list_eval_partition.txt
        <root>/InShop/<image_name>      # image_name as listed in the file

    For the alternative layout ``<root>/Inshop_Clothes/Eval/list_eval_partition.txt``
    with images under ``Img/``, adjust the paths built in ``__init__``.

    Args:
        root: Directory that contains the ``InShop`` folder.
        mode: ``"train"`` selects the training split, ``"eval"`` the merged
            query + gallery split.
        transform: Optional callable applied to each RGB ``PIL.Image``.

    Raises:
        ValueError: If ``mode`` is not ``"train"`` or ``"eval"``.
    """

    _MODES = ("train", "eval")

    def __init__(self, root, mode, transform=None):
        # Fail fast: an unknown mode would otherwise leave im_paths/ys unset.
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.root = root + "/InShop"
        self.mode = mode
        self.transform = transform

        # header=1 takes the second line as the column header and skips the
        # first one (presumably the sample-count line of the standard file).
        # sep=r"\s+" replaces delim_whitespace=True, which is deprecated in
        # pandas 2.2 and removed in 3.0.
        data_info = pd.read_table(
            self.root + "/list_eval_partition.txt",
            header=1,
            sep=r"\s+",
        ).to_numpy()

        # Columns: 0 = image path relative to self.root, 1 = item-id string
        # whose suffix after the last "_" is an integer, 2 = split status.
        status = data_info[:, 2]
        train = data_info[status == "train"][:, :2]
        test = data_info[(status == "query") | (status == "gallery")][:, :2]

        # Flat, index-aligned lists of image paths and integer labels.
        self.train_im_paths = [os.path.join(self.root, p) for p in train[:, 0]]
        self.train_ys = self._remap_labels(train[:, 1])
        self.test_im_paths = [os.path.join(self.root, p) for p in test[:, 0]]
        self.test_ys = self._remap_labels(test[:, 1])

        if mode == "train":
            self.im_paths, self.ys = self.train_im_paths, self.train_ys
        else:  # mode == "eval" (validated above)
            self.im_paths, self.ys = self.test_im_paths, self.test_ys

    @staticmethod
    def _remap_labels(item_ids):
        """Map item-id strings to dense integer labels ``0..K-1``.

        The integer after the last ``_`` of each id is the key. Keys are
        ranked in ascending numeric order, so the smallest id becomes label 0.
        """
        keys = [int(s.split("_")[-1]) for s in item_ids]
        rank = {k: i for i, k in enumerate(sorted(set(keys)))}
        return [rank[k] for k in keys]

    def nb_classes(self):
        """Return the number of distinct labels in the selected split."""
        return len(set(self.ys))

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.ys)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is converted to 3-channel RGB (grayscale, palette, RGBA,
        CMYK, ... all become RGB) before ``transform`` is applied.
        """
        with PIL.Image.open(self.im_paths[index]) as img:
            # convert() always returns a new, fully loaded image, so the
            # underlying file can be closed when the with-block exits.
            im = img.convert("RGB")
        if self.transform is not None:
            im = self.transform(im)
        return im, self.ys[index]
