import os
import torch
import pickle
import numpy as np
from torch.utils.data import DataLoader, Dataset

from .extra.affect_dataset import AffectDataset
from .extra.mmimdb_dataset import MMIMDBDataset
from .extra.food101_dataset import Food101Dataset
from .extra.hatememes_dataset import HatememesDataset

class ClassificationDataModule(Dataset):
    """Own the dataset splits and DataLoaders used for classifier training/evaluation.

    Although it subclasses ``torch.utils.data.Dataset``, this object is not
    indexed directly. ``setup`` builds ``train_data``/``val_data`` (and
    ``test_data`` when the stage is ``"eval_classifier"``), and the
    ``*_dataloader`` methods wrap them in ``DataLoader`` objects.

    Supported ``dataset`` names: ``"mosei"``/``"mosi"`` (``AffectDataset``),
    ``"mmimdb"``, ``"food101"`` and ``"hatememes"``.
    """

    # Loaded through AffectDataset; their DataLoaders do not pin memory.
    _AFFECT_DATASETS = ("mosei", "mosi")
    # Loaded through dataset-specific classes; their DataLoaders use
    # pin_memory=True and the test loader uses data_config.inference_batch_size.
    _PINNED_DATASETS = ("mmimdb", "food101", "hatememes")

    def __init__(self, dataset, data_dir, device, data_config):
        """Store configuration and immediately run ``setup(data_config.stage)``.

        Args:
            dataset: Name of the dataset to load (see class docstring).
            data_dir: Directory forwarded to the underlying dataset classes.
            device: Forwarded unchanged to every dataset constructor.
            data_config: Config object. Depending on the dataset this class reads
                ``classification``, ``stage``, ``labeled_ratio``,
                ``train_classifier_only``, ``transfer_experiment``,
                ``batch_size``, ``inference_batch_size`` and ``num_workers``.

        Raises:
            ValueError: if ``dataset`` is not supported (raised from ``setup``).
        """
        super().__init__()

        # DataModule variables;
        self.dataset = dataset
        self.data_dir = data_dir
        self.device = device
        self.classification = data_config.classification
        self.data_config = data_config
        # Data-specific variables - filled in by setup();
        self.transform = None
        self.train_data, self.val_data, self.test_data = None, None, None
        self.setup(stage=data_config.stage)

    def _unsupported(self):
        """Build the error raised for a dataset name this module does not handle."""
        return ValueError(
            "[Classification Dataset] Selected dataset: " + str(self.dataset) + " not implemented.")

    def prepare_data(self):
        """Check that the raw data for the selected dataset is present on disk.

        Only MOSEI/MOSI have an on-disk check here. The other supported datasets
        have no check in this method.

        Raises:
            RuntimeError: if the MOSEI/MOSI training file is missing.
            ValueError: if ``self.dataset`` is not a supported dataset.
        """
        if self.dataset == "mosei":
            train_data_file = os.path.join(self.data_dir, "mosei_train_a.dt")
            if not os.path.exists(train_data_file):
                raise RuntimeError('MOSEI Dataset not found.')
        elif self.dataset == "mosi":
            train_data_file = os.path.join(self.data_dir, "mosi_train_a.dt")
            if not os.path.exists(train_data_file):
                raise RuntimeError('MOSI Dataset not found.')
        elif self.dataset not in self._PINNED_DATASETS:
            raise self._unsupported()

    def setup(self, stage=None):
        """Instantiate the dataset splits for ``self.dataset``.

        Train and validation splits are always built. The test split is only
        built when ``stage == "eval_classifier"``; otherwise ``test_data`` keeps
        its previous value (``None`` after ``__init__``).

        Raises:
            ValueError: if ``self.dataset`` is not supported.
        """
        cfg = self.data_config
        load_test = stage == "eval_classifier"

        if self.dataset in self._AFFECT_DATASETS:
            self.train_data = AffectDataset(
                self.data_dir, dataset=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio, classification=self.classification,
                train_classifier_only=cfg.train_classifier_only, transfer=cfg.transfer_experiment)
            self.val_data = AffectDataset(
                self.data_dir, dataset=self.dataset, split_type='valid', device=self.device,
                classification=self.classification, transfer=cfg.transfer_experiment)
            if load_test:
                self.test_data = AffectDataset(
                    self.data_dir, dataset=self.dataset, split_type='test', device=self.device,
                    classification=self.classification, transfer=cfg.transfer_experiment)
        elif self.dataset == 'mmimdb':
            self.train_data = MMIMDBDataset(
                self.data_dir, data=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio)
            self.val_data = MMIMDBDataset(self.data_dir, data=self.dataset, split_type='dev', device=self.device)
            if load_test:
                self.test_data = MMIMDBDataset(self.data_dir, data=self.dataset, split_type='test', device=self.device)
        elif self.dataset == 'food101':
            self.train_data = Food101Dataset(
                self.data_dir, data=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio)
            self.val_data = Food101Dataset(self.data_dir, data=self.dataset, split_type='val', device=self.device)
            if load_test:
                self.test_data = Food101Dataset(self.data_dir, data=self.dataset, split_type='test', device=self.device)
        elif self.dataset == 'hatememes':
            self.train_data = HatememesDataset(
                self.data_dir, data=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio, train_classifier_only=cfg.train_classifier_only)
            self.val_data = HatememesDataset(self.data_dir, data=self.dataset, split_type='dev', device=self.device)
            if load_test:
                self.test_data = HatememesDataset(self.data_dir, data=self.dataset, split_type='test', device=self.device)
        else:
            raise self._unsupported()

    def _make_loader(self, data, shuffle, inference=False):
        """Wrap ``data`` in a DataLoader using the options for ``self.dataset``.

        Affect datasets always use ``batch_size`` and no pinned memory. The other
        datasets pin memory and use ``inference_batch_size`` when ``inference``
        is true.

        Raises:
            ValueError: if ``self.dataset`` is not supported.
        """
        cfg = self.data_config
        if self.dataset in self._AFFECT_DATASETS:
            return DataLoader(
                data,
                batch_size=cfg.batch_size,
                shuffle=shuffle,
                num_workers=cfg.num_workers,
            )
        if self.dataset in self._PINNED_DATASETS:
            return DataLoader(
                data,
                batch_size=cfg.inference_batch_size if inference else cfg.batch_size,
                shuffle=shuffle,
                num_workers=cfg.num_workers,
                pin_memory=True,
            )
        raise self._unsupported()

    def train_dataloader(self):
        """Return a shuffled DataLoader over ``train_data``."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over ``val_data``."""
        return self._make_loader(self.val_data, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled DataLoader over ``test_data``.

        ``test_data`` is only populated when ``setup`` ran with
        ``stage == "eval_classifier"``.
        """
        return self._make_loader(self.test_data, shuffle=False, inference=True)


class DCADataModule(Dataset):
    """Own the dataset splits and DataLoaders used for DCA evaluation.

    Although it subclasses ``torch.utils.data.Dataset``, this object is not
    indexed directly. ``setup`` builds ``train_data``/``val_data`` (and
    ``test_data`` when the stage is ``"eval_dca"``). When a test split exists,
    ``setup`` also draws a random subset of test indices. ``test_dataloader``
    then iterates only that subset.

    Supported ``dataset`` names: ``"mosei"``/``"mosi"`` (``AffectDataset``),
    ``"mmimdb"``, ``"food101"`` and ``"hatememes"``.
    """

    # Loaded through AffectDataset; their DataLoaders do not pin memory.
    _AFFECT_DATASETS = ("mosei", "mosi")
    # Loaded through dataset-specific classes; their DataLoaders use
    # pin_memory=True and the test loader uses data_config.inference_batch_size.
    _PINNED_DATASETS = ("mmimdb", "food101", "hatememes")

    def __init__(self, dataset, data_dir, device, data_config):
        """Store configuration and immediately run ``setup(data_config.stage)``.

        Args:
            dataset: Name of the dataset to load (see class docstring).
            data_dir: Directory forwarded to the underlying dataset classes.
            device: Forwarded unchanged to every dataset constructor.
            data_config: Config object. Depending on the dataset this class reads
                ``classification``, ``stage``, ``labeled_ratio``,
                ``transfer_experiment``, ``n_dca_samples``, ``batch_size``,
                ``inference_batch_size`` and ``num_workers``.

        Raises:
            ValueError: if ``dataset`` is not supported (raised from ``setup``).
        """
        super().__init__()

        # DataModule variables;
        self.dataset = dataset
        self.data_dir = data_dir
        self.device = device
        self.classification = data_config.classification
        self.data_config = data_config
        # Data-specific variables - filled in by setup();
        self.transform = None
        self.train_data, self.val_data, self.test_data = None, None, None
        self.dca_partial_eval_indices = None
        # Sampler over dca_partial_eval_indices; only created once a test split is loaded.
        self.partial_test_sampler = None
        self.setup(stage=data_config.stage)

    def _unsupported(self):
        """Build the error raised for a dataset name this module does not handle."""
        return ValueError(
            "[DCA Dataset] Selected dataset: " + str(self.dataset) + " not implemented.")

    def set_dca_eval_sample_indices(self):
        """Draw, once, the random subset of test indices used for DCA evaluation.

        Samples ``data_config.n_dca_samples`` distinct indices, without
        replacement, from ``range(len(self.test_data))``. If the test split is
        smaller than that, every test index is used instead of failing. Later
        calls keep the indices that were already drawn.
        """
        print('Setting DCA eval sample indices')
        if self.dca_partial_eval_indices is None:
            n_test = len(self.test_data)
            # np.random.choice(..., replace=False) raises if asked for more items
            # than the population holds, so cap the request at the split size.
            n_samples = min(self.data_config.n_dca_samples, n_test)
            # Passing the population size directly draws the same permutation as
            # passing list(range(n_test)), without building the list.
            self.dca_partial_eval_indices = np.random.choice(n_test, n_samples, replace=False)

    def setup(self, stage=None):
        """Instantiate the dataset splits and, if a test split exists, its sampler.

        Train and validation splits are always built. The test split is only
        built when ``stage == "eval_dca"``. If ``test_data`` is set afterwards,
        the DCA index subset is drawn (once) and wrapped in a
        ``SubsetRandomSampler``.

        Raises:
            ValueError: if ``self.dataset`` is not supported.
        """
        cfg = self.data_config
        load_test = stage == "eval_dca"

        if self.dataset in self._AFFECT_DATASETS:
            self.train_data = AffectDataset(
                self.data_dir, dataset=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio, classification=self.classification,
                transfer=cfg.transfer_experiment)
            self.val_data = AffectDataset(
                self.data_dir, dataset=self.dataset, split_type='valid', device=self.device,
                classification=self.classification, transfer=cfg.transfer_experiment)
            if load_test:
                self.test_data = AffectDataset(
                    self.data_dir, dataset=self.dataset, split_type='test', device=self.device,
                    classification=self.classification, transfer=cfg.transfer_experiment)
        elif self.dataset == 'mmimdb':
            self.train_data = MMIMDBDataset(
                self.data_dir, data=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio)
            self.val_data = MMIMDBDataset(self.data_dir, data=self.dataset, split_type='dev', device=self.device)
            if load_test:
                self.test_data = MMIMDBDataset(self.data_dir, data=self.dataset, split_type='test', device=self.device)
        elif self.dataset == 'food101':
            self.train_data = Food101Dataset(
                self.data_dir, data=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio)
            self.val_data = Food101Dataset(self.data_dir, data=self.dataset, split_type='val', device=self.device)
            if load_test:
                self.test_data = Food101Dataset(self.data_dir, data=self.dataset, split_type='test', device=self.device)
        elif self.dataset == 'hatememes':
            self.train_data = HatememesDataset(
                self.data_dir, data=self.dataset, split_type='train', device=self.device,
                labeled_ratio=cfg.labeled_ratio)
            self.val_data = HatememesDataset(self.data_dir, data=self.dataset, split_type='dev', device=self.device)
            if load_test:
                self.test_data = HatememesDataset(self.data_dir, data=self.dataset, split_type='test', device=self.device)
        else:
            raise self._unsupported()

        if self.test_data is not None:
            self.set_dca_eval_sample_indices()
            self.partial_test_sampler = torch.utils.data.SubsetRandomSampler(
                self.dca_partial_eval_indices
            )

    def _make_loader(self, data, shuffle, inference=False, **extra):
        """Wrap ``data`` in a DataLoader using the options for ``self.dataset``.

        Affect datasets always use ``batch_size`` and no pinned memory. The other
        datasets pin memory and use ``inference_batch_size`` when ``inference``
        is true. ``extra`` is forwarded to ``DataLoader`` (e.g. ``sampler``).

        Raises:
            ValueError: if ``self.dataset`` is not supported.
        """
        cfg = self.data_config
        if self.dataset in self._AFFECT_DATASETS:
            return DataLoader(
                data,
                batch_size=cfg.batch_size,
                shuffle=shuffle,
                num_workers=cfg.num_workers,
                **extra,
            )
        if self.dataset in self._PINNED_DATASETS:
            return DataLoader(
                data,
                batch_size=cfg.inference_batch_size if inference else cfg.batch_size,
                shuffle=shuffle,
                num_workers=cfg.num_workers,
                pin_memory=True,
                **extra,
            )
        raise self._unsupported()

    def train_dataloader(self):
        """Return a shuffled DataLoader over ``train_data``."""
        return self._make_loader(self.train_data, shuffle=True)

    def test_dataloader(self):
        """Return a DataLoader over the sampled DCA subset of ``test_data``.

        Raises:
            RuntimeError: if no test split (and therefore no subset sampler) was
                set up, i.e. ``setup`` did not run with ``stage == "eval_dca"``.
            ValueError: if ``self.dataset`` is not supported.
        """
        sampler = getattr(self, "partial_test_sampler", None)
        if sampler is None:
            raise RuntimeError(
                "[DCA Dataset] No test split loaded; run setup(stage='eval_dca') before test_dataloader().")
        # shuffle must stay False: ordering is delegated to the subset sampler.
        return self._make_loader(
            self.test_data, shuffle=False, inference=True, sampler=sampler, drop_last=False)