import numpy as np
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from random import sample, random
from arguments import args, parse_args
import os
from os.path import join, dirname


def get_random_subset(names, labels, percent):
    """
    Randomly split parallel lists of names and labels into train/val parts.

    ``int(len(names) * percent)`` distinct positions are drawn with
    ``random.sample``; those entries form the validation split (in draw
    order) and all remaining entries form the training split (in their
    original order).

    :param names: list of names
    :param labels:  list of labels, parallel to ``names``
    :param percent: fraction of the samples to hold out, 0 < float < 1
    :return: tuple ``(name_train, name_val, labels_train, labels_val)``
    """
    total = len(names)
    held_out = int(total * percent)
    val_positions = sample(range(total), held_out)
    # Set lookup keeps the train filtering linear; testing membership against
    # the list itself made the split quadratic for large datasets.
    val_lookup = set(val_positions)

    name_val = [names[pos] for pos in val_positions]
    labels_val = [labels[pos] for pos in val_positions]
    name_train = [name for pos, name in enumerate(names) if pos not in val_lookup]
    labels_train = [label for pos, label in enumerate(labels) if pos not in val_lookup]
    return name_train, name_val, labels_train, labels_val


def _dataset_info(txt_labels):
    """
    Parse an image-list file into parallel lists of file names and labels.

    Each non-blank line is expected to hold a file name followed by an
    integer label, separated by whitespace. Blank lines are skipped.

    :param txt_labels: path of the text file to read
    :return: tuple ``(file_names, labels)`` where ``labels`` are ints
    """
    file_names = []
    labels = []
    with open(txt_labels, 'r') as f:
        for row in f:
            # Whitespace-agnostic split: tolerates tabs and repeated spaces,
            # which a plain split(' ') turns into empty fields.
            fields = row.split()
            if not fields:
                # Blank line (e.g. a trailing empty line at end of file).
                continue
            file_names.append(fields[0])
            labels.append(int(fields[1]))

    return file_names, labels


def get_split_dataset_info(txt_list, val_percentage):
    """
    Read an image-list file and split its entries into train/val parts.

    :param txt_list: path of the list file, passed to ``_dataset_info``
    :param val_percentage: fraction handed to ``get_random_subset``
    :return: whatever ``get_random_subset`` returns for the parsed entries
    """
    file_names, class_labels = _dataset_info(txt_list)
    split = get_random_subset(file_names, class_labels, val_percentage)
    return split


class CustomDataset(data.Dataset):
    """Dataset yielding ``(transformed image, one-hot label, class index)``.

    The image root directory is picked from ``args.dataset`` and resolved
    relative to the directory containing this file: 'VLCS' and 'PACS' have
    dedicated folders, any other value uses the "OfficeHome" folder.
    """

    def __init__(self, names, labels, jig_classes=100, img_transformer=None, patches=True):
        """
        :param names: image paths, appended to the dataset root with '/'
        :param labels: labels, parallel to ``names``
        :param jig_classes: not used by this class; kept for interface
            compatibility
        :param img_transformer: callable applied to every loaded PIL image
        :param patches: when true, ``patch_size`` and ``returnFunc`` are set
        """
        base_dir = dirname(__file__)
        if args.dataset == 'VLCS':
            self.data_path = join(base_dir, "VLCS/Raw images")
        elif args.dataset == 'PACS':
            self.data_path = join(base_dir, "pacs/images")
        else:
            self.data_path = join(base_dir, "OfficeHome")

        self.names = names
        self.labels = labels

        self.N = len(self.names)

        self._image_transformer = img_transformer
        if patches:
            self.patch_size = 64
            self.returnFunc = lambda x: x

    def _load_rgb(self, index):
        """Open the image at ``index`` and return it converted to RGB."""
        # Plain concatenation (not os.path.join) so that a name starting
        # with '/' is still resolved below data_path.
        framename = self.data_path + '/' + self.names[index]
        # Context manager releases the file handle right after the pixel
        # data has been converted into a new image.
        with Image.open(framename) as img:
            return img.convert('RGB')

    def _num_classes_and_target(self, index):
        """Return ``(number of classes, class index)`` for sample ``index``."""
        label = self.labels[index]
        if args.dataset == 'PACS':
            # PACS labels are shifted down by one to get the class index.
            return 7, int(label - 1)
        if args.dataset == 'OfficeHome':
            # The one-hot position must match the returned class index; the
            # previous code set the one-hot at label - 1 while returning
            # label, so the two disagreed (and label 0 wrapped to the last
            # class). Aligned with TestCustomDataset, which uses the label
            # unshifted for both.
            return 65, int(label)
        return 5, int(label)

    def get_image(self, index):
        """Return the transformed RGB image for sample ``index``."""
        return self._image_transformer(self._load_rgb(index))

    def __getitem__(self, index):
        """Return ``(transformed image, one-hot float tensor, class index)``."""
        img = self._load_rgb(index)
        num_classes, target = self._num_classes_and_target(index)
        onehot = torch.FloatTensor(num_classes).fill_(0)
        onehot[target] = 1
        return self._image_transformer(img), onehot, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.names)

class TestCustomDataset(CustomDataset):
    """Test-time variant of ``CustomDataset`` with its own ``__getitem__``.

    Construction is delegated unchanged to ``CustomDataset``.
    """

    # Var-args are deliberately not called ``args`` so the module-level
    # ``args`` object is not shadowed inside the constructor.
    def __init__(self, *pos_args, **kw_args):
        super().__init__(*pos_args, **kw_args)

    def __getitem__(self, index):
        """Return ``(transformed image, one-hot float tensor, class index)``.

        PACS labels are shifted down by one; for every other dataset the
        stored label is used directly as the class index.
        """
        image_path = self.data_path + '/' + self.names[index]
        image = Image.open(image_path).convert('RGB')

        raw_label = self.labels[index]
        if args.dataset == 'PACS':
            num_classes, target = 7, int(raw_label - 1)
        elif args.dataset == 'OfficeHome':
            num_classes, target = 65, int(raw_label)
        else:
            num_classes, target = 5, int(raw_label)

        onehot = torch.FloatTensor(num_classes).fill_(0)
        onehot[target] = 1
        return self._image_transformer(image), onehot, target
        
