from models.gmc import AffectGMC as GMC
from models.vilt import ViLT
from data_modules.classification_dataset import ClassificationDataModule, DCADataModule
from models.trainers.model_trainer import ModelTrainer
from models.trainers.dca_evaluation_trainer import DCAEvaluator
from models.trainers.model_evaluation import ModelEvaluation

def setup_model(opt):
    """Build the model matching ``opt.dataset`` and move it to ``opt.device``.

    'mosi' and 'mosei' get an ``AffectGMC`` (imported as ``GMC``). 'mmimdb',
    'food101' and 'hatememes' get a ``ViLT``. Both are configured from
    ``opt.common_dim``, ``opt.latent_dim`` and ``opt.n_classes``. GMC also
    reads ``opt.transfer_experiment``.

    Args:
        opt: Options/config object exposing the attributes listed above plus
            ``dataset`` and ``device``.

    Returns:
        The constructed model, as returned by its ``.to(opt.device)`` call.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported dataset.
    """
    if opt.dataset in ['mosi', 'mosei']:
        model = GMC(name='gmc',
                    common_dim=opt.common_dim,
                    latent_dim=opt.latent_dim,
                    dataset=opt.dataset,
                    n_classes=opt.n_classes,
                    transfer=opt.transfer_experiment
                    )
    elif opt.dataset in ['mmimdb', 'food101', 'hatememes']:
        model = ViLT(
            name='vilt',
            num_classes=opt.n_classes,
            common_dim=opt.common_dim,
            latent_dim=opt.latent_dim,
            dataset=opt.dataset,
        )
    else:
        # Without this branch `model` is unbound and the return below fails
        # with an opaque UnboundLocalError. Raise the same kind of
        # descriptive error the other setup_* helpers use.
        raise ValueError(
            "[Model Setup] Selected Model not yet implemented: " + str(opt.dataset)
        )

    return model.to(opt.device)

def setup_data_module(opt):
    """Create the data module for ``opt.dataset``.

    When ``opt.stage`` is ``"eval_dca"`` a ``DCADataModule`` is returned.
    Otherwise a ``ClassificationDataModule`` is returned. Both are built from
    ``opt.root_path`` and ``opt.device``, and receive ``opt`` itself as
    ``data_config``.

    Args:
        opt: Options/config object exposing ``dataset``, ``stage``,
            ``root_path`` and ``device``.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported dataset.
    """
    known_datasets = ('mosi', 'mosei', 'mmimdb', 'food101', 'hatememes')
    if opt.dataset not in known_datasets:
        raise ValueError(
            "[Data Module Setup] Selected Module not yet implemented: " + str(opt.dataset)
        )

    if opt.stage == "eval_dca":
        return DCADataModule(
            dataset=opt.dataset,
            data_dir=opt.root_path,
            device=opt.device,
            data_config=opt,
        )

    # Note: ClassificationDataModule takes the dataset name positionally.
    return ClassificationDataModule(
        opt.dataset,
        data_dir=opt.root_path,
        device=opt.device,
        data_config=opt,
    )
    
def setup_trainer(model, data_module, opt):
    """Create the trainer or evaluator driving ``model`` on ``data_module``.

    When ``opt.stage`` is ``"eval_dca"`` a ``DCAEvaluator`` is built.
    Otherwise a ``ModelTrainer`` is built. Both receive the same keyword
    arguments: ``model``, ``dataset``, ``data_module`` and ``opt``.

    Args:
        model: The model to train or evaluate.
        data_module: Data module supplying the data.
        opt: Options/config object exposing ``dataset`` and ``stage``.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported dataset.
    """
    known_datasets = ('mosi', 'mosei', 'mmimdb', 'food101', 'hatememes')
    if opt.dataset not in known_datasets:
        raise ValueError(
            "[Trainer Setup] Selected Trainer not yet implemented: " + str(opt.dataset)
        )

    runner_cls = DCAEvaluator if opt.stage == "eval_dca" else ModelTrainer
    return runner_cls(
        model=model,
        dataset=opt.dataset,
        data_module=data_module,
        opt=opt,
    )
    
def setup_evaluator(model, opt, test_dataloader=None, class_names=None, train_dataloader=None, last_cpkt=False):
    """Create a ``ModelEvaluation`` for ``model`` on ``test_dataloader``.

    The modalities to evaluate are taken from ``opt.evaluation_modals``.

    Args:
        model: The model to evaluate.
        opt: Options/config object exposing ``dataset`` and
            ``evaluation_modals``.
        test_dataloader: Loader passed to the evaluator as ``test_loader``.
        class_names: Accepted for call-site compatibility; not used here.
        train_dataloader: Accepted for call-site compatibility; not used here.
        last_cpkt: Forwarded unchanged to ``ModelEvaluation``.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported dataset.
    """
    known_datasets = ('mosi', 'mosei', 'mmimdb', 'food101', 'hatememes')
    if opt.dataset not in known_datasets:
        raise ValueError(
            "[Evaluator Setup] Selected Evaluator not yet implemented: " + str(opt.dataset)
        )

    return ModelEvaluation(
        model=model,
        dataset=opt.dataset,
        test_loader=test_dataloader,
        opt=opt,
        modalities=opt.evaluation_modals,
        last_cpkt=last_cpkt,
    )