import torch
import random
import numpy as np
from .dataset import *


def seed_torch(seed=2024):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN settings.

    Args:
        seed: Integer seed applied to every random number generator below.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: the default generator, the current CUDA device and
    # all CUDA devices. Each is seeded exactly once.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Turn off cuDNN autotuning and request deterministic cuDNN algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_dataset(dataset_name, tokenizer):
    """Build the train/dev/test splits for a named dataset.

    Args:
        dataset_name: One of 'sst2', 'cola', 'mrpc', 'qqp', 'stsb', 'mnli',
            'mnli_mismatched', 'qnli', 'rte' or 'wnli'.
        tokenizer: Passed unchanged to each split object's ``get_dataset``
            method.

    Returns:
        A tuple ``(train_dataset, dev_dataset, test_dataset, label_num)``.
        ``label_num`` is the ``label_num`` attribute of the train-split
        object, except for 'stsb' where it is fixed at 1.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Thunks keep the class lookup lazy: a class is only resolved for the
    # dataset actually requested, as in a plain if/elif chain.
    dataset_classes = {
        'sst2': lambda: SST2Dataset,
        'cola': lambda: CoLADataset,
        'mrpc': lambda: MRPCDataset,
        'qqp': lambda: QQPDataset,
        'stsb': lambda: STSBDataset,
        'mnli': lambda: MNLIDataset,
        'mnli_mismatched': lambda: MnliMismatchedSDataset,
        'qnli': lambda: QNLIDataset,
        'rte': lambda: RTEDataset,
        'wnli': lambda: WNLIDataset,
    }
    # Datasets whose label count is hard-coded rather than read from the
    # dataset object.
    fixed_label_nums = {'stsb': 1}

    try:
        resolve_class = dataset_classes[dataset_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which are just as invalid.
        raise ValueError("{} is invalid!".format(dataset_name)) from None

    dataset_cls = resolve_class()

    # Build the train split object once and reuse it for both the label
    # count and the dataset. label_num is read before get_dataset() runs so
    # it reflects the freshly constructed object.
    train_source = dataset_cls(mode='train')
    if dataset_name in fixed_label_nums:
        label_num = fixed_label_nums[dataset_name]
    else:
        label_num = train_source.label_num

    train_dataset = train_source.get_dataset(tokenizer)
    dev_dataset = dataset_cls(mode='dev').get_dataset(tokenizer)
    test_dataset = dataset_cls(mode='test').get_dataset(tokenizer)

    return train_dataset, dev_dataset, test_dataset, label_num
