import random
import torch
import torch.nn as nn
from utils.file import load

from torch.utils.data import Dataset
from pathlib import Path

from utils.datasets.surnames_utils import NamesDataset


class _IndexedSurnames(Dataset):
    """Map-style dataset over one split of the stored surnames data.

    Each item is a ``(datum, target)`` pair.  ``datum`` is whatever the
    underlying dataset stores at ``samples[i][0]``; ``target`` is an
    ``int32`` tensor built from ``samples[i][1]``.

    For the training split with ``indexed=True`` every target is a
    two-element row ``[label, position]``, where ``position`` is the item's
    position inside this dataset (``0 .. len - 1``).  Otherwise the target
    is the bare label.
    """

    def __init__(self, train: bool = True, indexed: bool = True):
        """Load the stored splits and materialise the requested one.

        Args:
            train: Use the training split when truthy, else the test split.
            indexed: Attach each sample's position to its target.  Only
                honoured for the training split; test targets are always
                bare labels.
        """
        super().__init__()

        # NOTE(review): the ``.pkl`` suffix suggests ``load`` unpickles the
        # file -- confirm, and only point it at trusted data.
        train_set, test_set = load("Datasets/surnames.pkl")

        # Both splits expose the same ``.indices`` / ``.dataset.samples``
        # layout, so a single code path serves either one.
        split = train_set if train else test_set
        samples = split.dataset.samples
        indices = split.indices

        self.data = [samples[i][0] for i in indices]
        self.targets = self._build_targets(
            samples, indices, indexed=bool(train and indexed)
        )

    @staticmethod
    def _build_targets(samples, indices, indexed):
        """Return the ``int32`` target tensor for ``indices`` into ``samples``.

        The tensor is created directly as ``int32``.  Going through
        ``torch.Tensor(...)`` first would create a float32 tensor, which
        cannot represent every integer above 2**24 and would silently round
        large labels or positions.
        """
        if indexed:
            rows = [
                [samples[i][1], position]
                for position, i in enumerate(indices)
            ]
        else:
            rows = [samples[i][1] for i in indices]
        return torch.tensor(rows, dtype=torch.int32)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return the ``(datum, target)`` pair stored at ``index``."""
        # ``targets`` is built as int32 above; ``.int()`` is kept so the
        # returned dtype stays int32 even if ``targets`` is reassigned.
        return self.data[index], self.targets[index].int()


def load_surnames(which: str):
    """Return the surnames dataset for the requested split.

    Args:
        which: ``'train'`` for the training split (targets carry the sample
            position) or ``'test'`` for the test split (plain targets).

    Raises:
        ValueError: If ``which`` is neither ``'train'`` nor ``'test'``.
    """
    wants_train = which == 'train'

    # Reject anything that is not one of the two known split names up front.
    if not (wants_train or which == 'test'):
        raise ValueError("Can only choose between 'train' and 'test'.")

    # The training split is the only one served with indexed targets, so one
    # flag drives both constructor arguments.
    return _IndexedSurnames(train=wants_train, indexed=wants_train)
    