import os

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Dataset, Subset
from torchvision.datasets import ImageFolder


class FaceScrubLatents(Dataset):
    """FaceScrub identities paired with precomputed latent vectors.

    Every identity folder becomes one integer class label, assigned in sorted
    folder order (for ``group='all'`` all actor folders come before all
    actress folders). For every non-GIF file in an identity folder, the last
    element of the entry stored under that file name in ``latent_file`` is
    used as the sample. The samples are shuffled with ``split_seed`` and
    split 90/10 into a train and a test part.

    Args:
        group: ``'actors'``, ``'actresses'`` or ``'all'``.
        train: If True keep the first 90 % of the shuffled samples,
            otherwise the remaining 10 %.
        latent_file: ``.npy`` file containing a pickled dict that maps image
            file names to indexable entries; the entry's last element is
            taken as the latent.
        split_seed: Seed of the train/test shuffle.
        transform: Optional callable applied to a latent in ``__getitem__``;
            its result is squeezed along dim 0.
        cropped: Read the ``faces`` sub-folders if True, else ``images``.
        root: FaceScrub root directory containing ``actors``/``actresses``.

    Raises:
        ValueError: If ``group`` is not one of the valid values.
        KeyError: If a file has no latent under its own name nor under the
            same stem with a ``.jpeg`` extension.
    """

    def __init__(self,
                 group,
                 train,
                 latent_file='datasets/facescrub_latents.npy',
                 split_seed=42,
                 transform=None,
                 cropped=True,
                 root='data/facescrub'):

        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # point latent_file at trusted files. The whole dict is kept on the
        # instance (self.latents_file), as before.
        self.latents_file = np.load(latent_file, allow_pickle=True).tolist()

        if group in ('actors', 'actresses'):
            subgroups = [group]
        elif group == 'all':
            subgroups = ['actors', 'actresses']
        else:
            raise ValueError(
                f'Dataset group {group} not found. Valid arguments are \'all\', \'actors\' and \'actresses\'.'
            )
        self.name = f'facescrub_{group}_latents'

        image_dir = 'faces' if cropped else 'images'
        folder_paths = self._class_folders(root, subgroups, image_dir)

        # TODO: Load latents after train-test split

        targets = []
        latents = []
        for target, folder in enumerate(folder_paths):
            for file_name in sorted(os.listdir(folder)):
                # splitext keeps the dot, so compare against '.gif'.
                if os.path.splitext(file_name)[1].lower() == '.gif':
                    continue
                latents.append(self._lookup_latent(file_name))
                targets.append(target)

        self.transform = transform
        indices = list(range(len(latents)))

        # A private RandomState yields the same permutation as
        # np.random.seed(split_seed) + np.random.shuffle, but does not reseed
        # NumPy's global RNG as a side effect of building the dataset.
        np.random.RandomState(split_seed).shuffle(indices)
        training_set_size = int(0.9 * len(latents))
        if train:
            selected = indices[:training_set_size]
        else:
            selected = indices[training_set_size:]

        self.latents = torch.from_numpy(np.array(latents)[selected])
        self.targets = np.array(targets)[selected].tolist()

        assert len(self.latents) == len(self.targets)

    @staticmethod
    def _class_folders(root, subgroups, image_dir):
        """Return identity folder paths, sorted within each subgroup, subgroups in the given order."""
        folder_paths = []
        for subgroup in subgroups:
            group_root = os.path.join(root, subgroup, image_dir)
            folder_paths += sorted(
                os.path.join(group_root, d) for d in os.listdir(group_root))
        return folder_paths

    def _lookup_latent(self, file_name):
        """Return the last element of the latent entry stored for ``file_name``.

        If ``file_name`` is not a key, retry with the same stem and a
        ``.jpeg`` extension; a ``KeyError`` propagates if that is missing too.
        """
        try:
            entry = self.latents_file[file_name]
        except KeyError:
            # Rebuild from the stem so only the extension changes
            # (str.replace would also rewrite matching substrings in the stem).
            entry = self.latents_file[os.path.splitext(file_name)[0] + '.jpeg']
        return entry[-1]

    def __len__(self):
        """Number of samples in the selected (train or test) split."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return ``(latent, target)``; the latent is transformed and squeezed on dim 0 if a transform is set."""
        latent = self.latents[idx]
        if self.transform:
            return self.transform(latent).squeeze(0), self.targets[idx]
        else:
            return latent, self.targets[idx]
