import copy
import torch
import os
import random
import logging
import numpy as np

from torch.utils.data import Dataset, ConcatDataset, DataLoader

from trainer.trainers import BaseTrainer
from trainer.metrics import MetricsSummary
from trainer.utils import DatasetShell
from trainer.callbacks import EarlyStopping


class CrossValidation:
    """Entry point for running cross-validation with a trainer on a dataset.

    NOTE(review): only the constructor signature is visible here and its
    body is an empty stub, so the actual cross-validation behaviour (fold
    construction, state saving, metric aggregation) cannot be described
    from this code -- confirm against the rest of the project.
    """

    def __init__(self, trainer: BaseTrainer, dataset: Dataset, save_states=None):
        """Accept the cross-validation inputs; currently a no-op.

        None of the arguments are stored or validated: the instance is left
        without any attributes set by this constructor.

        Args:
            trainer: The ``BaseTrainer`` instance to cross-validate with.
            dataset: The ``Dataset`` to be used for cross-validation.
            save_states: Optional, defaults to ``None``. Presumably selects
                which states get persisted -- TODO confirm, it is unused here.
        """