import os
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from data.base_dataset import NewsDataset
from models import load_backbone

from common import CKPT_PATH

# Module-wide compute device: CUDA when torch reports it available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_base_dataset(data_name, tokenizer, data_ratio=1.0, seed=0):
    """Build the base dataset selected by ``data_name``.

    Args:
        data_name: Dataset identifier. Only ``'news'`` is handled here.
        tokenizer: Forwarded unchanged to the dataset constructor.
        data_ratio: Forwarded unchanged to the dataset constructor
            (presumably the fraction of the data to use -- confirm
            against ``NewsDataset``).
        seed: Forwarded unchanged to the dataset constructor
            (presumably a sampling seed -- confirm against ``NewsDataset``).

    Returns:
        The ``NewsDataset`` instance built for ``data_name == 'news'``.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset name.
    """
    print('Initializing base dataset... (name: {})'.format(data_name))

    # Text Classifications
    if data_name == 'news':
        return NewsDataset(tokenizer, data_ratio, seed)

    # Fail with an explicit message instead of falling through to an
    # unassigned local, which would surface as an opaque UnboundLocalError.
    raise ValueError(
        "Unsupported dataset name: {!r} (supported: 'news')".format(data_name)
    )