from training.erm import erm
from training.ours import ours
from training.oracle import oracle
from training.eiil import eiil
from training.george import george
from training.lff import lff
from training.dg_mmld import dg_mmld
from training.dann import dann
from training.mmd import mmd
from training.metric_learn import metric_learning
from training.cdann import cdann


def train_val_test(train_data, test_data, model, opt, args,
                   train_partition_loaders=None,
                   val_partition_loaders=None,
                   partition_model=None,
                   train_ebd=None):
    """Run the training routine selected by ``args.method``.

    ``train_data``, ``test_data``, ``model``, ``opt`` and ``args`` are
    forwarded unchanged to every routine.  The optional arguments are only
    forwarded to the routines that accept them:

    * ``train_ebd`` -- ``'erm'``, ``'metric'`` and ``'ours'``.
    * ``train_partition_loaders`` / ``val_partition_loaders`` -- ``'ours'``.
    * ``partition_model`` -- ``'ours'`` (by keyword) and ``'oracle'``
      (as the sixth positional argument).

    All other methods (``'eiil'``, ``'george'``, ``'dg_mmld'``, ``'lff'``,
    ``'dann'``, ``'mmd'``, ``'cdann'``) receive the five common arguments
    only; any optional argument supplied for them is ignored.

    Args:
        train_data: forwarded to the selected routine.
        test_data: forwarded to the selected routine.
        model: forwarded to the selected routine.
        opt: forwarded to the selected routine.
        args: object exposing a ``method`` attribute naming the routine;
            also forwarded to the selected routine.
        train_partition_loaders: see above.
        val_partition_loaders: see above.
        partition_model: see above.
        train_ebd: see above.

    Returns:
        Whatever the selected training routine returns.

    Raises:
        ValueError: if ``args.method`` is not one of the supported names.
    """
    method = args.method

    # Routines that can consume precomputed training embeddings.
    if method == 'erm':
        return erm(train_data, test_data, model, opt, args,
                   train_ebd=train_ebd)

    if method == 'metric':
        return metric_learning(train_data, test_data, model, opt, args,
                               train_ebd=train_ebd)

    # The only routine that uses the partition loaders.
    if method == 'ours':
        return ours(train_data, test_data, model, opt, args,
                    partition_model=partition_model,
                    train_partition_loaders=train_partition_loaders,
                    val_partition_loaders=val_partition_loaders,
                    train_ebd=train_ebd)

    # Baselines that take only the five common arguments.
    if method == 'eiil':
        return eiil(train_data, test_data, model, opt, args)

    if method == 'george':
        return george(train_data, test_data, model, opt, args)

    if method == 'dg_mmld':
        return dg_mmld(train_data, test_data, model, opt, args)

    if method == 'lff':
        return lff(train_data, test_data, model, opt, args)

    if method == 'dann':
        return dann(train_data, test_data, model, opt, args)

    if method == 'mmd':
        return mmd(train_data, test_data, model, opt, args)

    if method == 'cdann':
        return cdann(train_data, test_data, model, opt, args)

    # partition_model is passed positionally, matching the original call.
    if method == 'oracle':
        return oracle(train_data, test_data, model, opt, args,
                      partition_model)

    raise ValueError('method {} is not implemented'.format(method))

