from data_utils.cremad_dataset import CremaDDataset
import numpy as np
import torch


class DataManager:
    """Build shuffled train/validation(/test) DataLoaders over ``CremaDDataset``.

    Splits are made by reseeding NumPy's global RNG to a fixed value and then
    shuffling the dataset indices, so the same config and dataset produce the
    same partition on every call. Each loader reads its subset through a
    ``SubsetRandomSampler``; the order within a subset is drawn from torch's
    RNG at iteration time, which is not seeded here.
    """

    # Fixed seed used to reseed NumPy's global RNG before every split.
    _SPLIT_SEED = 707

    def __init__(self, config):
        """Store the configuration used by the loader factories.

        Args:
            config: Mapping passed unchanged to ``CremaDDataset`` and read for
                ``'batch_size'`` and the split fractions.
        """
        self.config = config

    @staticmethod
    def _shuffled_indices(dataset_size):
        """Return ``list(range(dataset_size))`` shuffled with NumPy's global RNG."""
        indices = list(range(dataset_size))
        np.random.shuffle(indices)
        return indices

    @staticmethod
    def _partition(indices, *sizes):
        """Cut ``indices`` into consecutive slices of ``sizes``, then the remainder.

        Returns a list of ``len(sizes) + 1`` slices. The last one holds whatever
        follows the requested chunks and may be empty. Plain slicing is used,
        so chunks are silently truncated when the sizes exceed ``len(indices)``.
        """
        parts = []
        start = 0
        for size in sizes:
            parts.append(indices[start:start + size])
            start += size
        parts.append(indices[start:])
        return parts

    @staticmethod
    def _make_loader(dataset, indices, batch_size):
        """Wrap ``dataset`` in a DataLoader that randomly samples only ``indices``."""
        sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           sampler=sampler)

    def get_train_eval_dataloaders(self):
        """Split the dataset into train/validation and return both loaders.

        Reads top-level ``config['train_size']``, the fraction of samples used
        for training. It is truncated to an int count, and every remaining
        sample goes to validation. Also reads top-level ``config['batch_size']``
        for the training loader.

        Note: this reseeds NumPy's global RNG to a fixed seed, which affects
        any later ``np.random`` use in the process.

        Returns:
            ``(train_loader, validation_loader)``. The validation loader uses
            batch size 1.
        """
        # Seed before building the dataset. If the dataset constructor draws
        # from np.random, the shuffle below depends on that, so keep this order.
        np.random.seed(self._SPLIT_SEED)

        dataset = CremaDDataset(self.config)
        dataset_size = len(dataset)

        # Fraction -> count. int() truncates, so any remainder lands in validation.
        train_size = int(self.config['train_size'] * dataset_size)

        indices = self._shuffled_indices(dataset_size)
        train_indices, val_indices = self._partition(indices, train_size)

        train_loader = self._make_loader(dataset, train_indices,
                                         self.config['batch_size'])
        validation_loader = self._make_loader(dataset, val_indices, 1)
        return train_loader, validation_loader

    def get_train_eval_test_dataloaders(self, dataset_name='cremad'):
        """Split the dataset into train/validation/test and return all three loaders.

        Reads ``config[dataset_name]['train_size']`` and
        ``config[dataset_name]['valid_size']``. These are fractions, each
        truncated to an int count. Top-level ``config['batch_size']`` sets the
        training batch size. Samples left after the train and validation chunks
        form the test set. If the two fractions sum past 1, the test set is
        empty and validation is truncated.

        Note: ``dataset_name`` only selects the config section; the dataset
        built is always ``CremaDDataset``. This method also reseeds NumPy's
        global RNG to a fixed seed before building the dataset.

        Args:
            dataset_name: Key of the config section holding the split fractions.

        Returns:
            ``(train_loader, validation_loader, test_loader)``. The validation
            and test loaders use batch size 1.
        """
        # Seed before building the dataset. If the dataset constructor draws
        # from np.random, the shuffle below depends on that, so keep this order.
        np.random.seed(self._SPLIT_SEED)

        dataset = CremaDDataset(self.config)
        dataset_size = len(dataset)

        split_config = self.config[dataset_name]
        train_size = int(split_config['train_size'] * dataset_size)
        valid_size = int(split_config['valid_size'] * dataset_size)

        indices = self._shuffled_indices(dataset_size)
        train_indices, valid_indices, test_indices = self._partition(
            indices, train_size, valid_size)

        train_loader = self._make_loader(dataset, train_indices,
                                         self.config['batch_size'])
        validation_loader = self._make_loader(dataset, valid_indices, 1)
        test_loader = self._make_loader(dataset, test_indices, 1)

        return train_loader, validation_loader, test_loader
