import torch
import torchvision
import yaml

from opts import parse_opts
from util_scripts.general_utils import setup_model, setup_data_module, setup_trainer, setup_evaluator

def get_opt():
    """Parse command-line options and overlay the dataset-specific YAML config.

    The config is read from ``configs/<opt.dataset>.yml``. Every top-level key
    in it is written into the options namespace, overriding any value that
    ``parse_opts`` produced under the same name.

    Returns:
        The options object returned by ``parse_opts``, with the YAML values
        merged in.

    Raises:
        FileNotFoundError: if the dataset config file does not exist.
    """
    opt, _ = parse_opts()
    dct = vars(opt)  # same object as opt.__dict__; writes show up as attributes
    config_path = f'configs/{opt.dataset}.yml'
    with open(config_path, 'r', encoding='utf-8') as f:
        # NOTE(review): yaml.load with FullLoader is kept for compatibility with
        # existing configs; only use it on trusted, repo-local files
        # (yaml.safe_load is the choice for anything external).
        additional_opt = yaml.load(f, Loader=yaml.FullLoader)
    # An empty YAML file loads as None; treat it as "no overrides" instead of
    # crashing with AttributeError on .items().
    if additional_opt is None:
        additional_opt = {}
    for k, v in additional_opt.items():
        dct[k] = v
    return opt

def train_model(opt):
    """Run the training stage of the pipeline.

    Builds the model and the data module from ``opt``, wires them into a
    trainer via ``setup_trainer`` and calls its ``fit()`` method.

    Args:
        opt: parsed options namespace (as produced by ``get_opt``).
    """
    net = setup_model(opt)
    datamodule = setup_data_module(opt)
    runner = setup_trainer(net, datamodule, opt)
    runner.fit()

def eval_classifier(opt):
    """Evaluate the classifier on the test split.

    Builds the model and the data module from ``opt`` and takes the data
    module's test dataloader. It then creates an evaluator through
    ``setup_evaluator``, passing ``opt.last_cpkt`` through unchanged, and runs
    ``evaluate()``.

    Args:
        opt: parsed options namespace (as produced by ``get_opt``).
    """
    net = setup_model(opt)
    loader = setup_data_module(opt).test_dataloader()

    evaluator = setup_evaluator(
        net,
        opt,
        test_dataloader=loader,
        last_cpkt=opt.last_cpkt,
    )
    evaluator.evaluate()

def eval_dca(opt):
    """Run the trainer's ``evaluate_dca()`` stage.

    Builds the model, the data module and the trainer from ``opt`` in the same
    way as the training stage, then calls ``evaluate_dca()`` instead of
    ``fit()``.

    Args:
        opt: parsed options namespace (as produced by ``get_opt``).
    """
    net = setup_model(opt)
    datamodule = setup_data_module(opt)
    setup_trainer(net, datamodule, opt).evaluate_dca()

def main(opt):
    """Dispatch to the pipeline stage named by ``opt.stage``.

    Supported stages are ``'train_model'``, ``'eval_classifier'`` and
    ``'eval_dca'``.

    Args:
        opt: parsed options namespace; only ``opt.stage`` is read here.

    Raises:
        ValueError: if ``opt.stage`` does not name a supported stage.
    """
    stage_table = (
        ('train_model', train_model),
        ('eval_classifier', eval_classifier),
        ('eval_dca', eval_dca),
    )
    for stage_name, run_stage in stage_table:
        if opt.stage == stage_name:
            run_stage(opt)
            return

    raise ValueError(
        "Incorrect stage of pipeline selected: " + str(opt.stage)
    )

def seed_everything(seed: int) -> None:
    """Seed the RNGs used by the pipeline and configure cuDNN for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    CUDA generators (on every visible GPU). It also puts cuDNN into
    deterministic mode with autotuning disabled.

    Args:
        seed: the seed value applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read only at interpreter startup. Setting it here
        affects subprocesses launched afterwards, not the hash seed of the
        current process.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current device: several devices may be
    # visible to the run.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN choose algorithms by timing them at runtime,
    # which can vary between runs and defeats deterministic=True above.
    torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    opt = get_opt()

    # Use the first GPU only when CUDA was requested via opt.device == 'cuda'
    # *and* a GPU is available; any other case falls back to the CPU.
    use_gpu = torch.cuda.is_available() and opt.device == 'cuda'
    opt.device = torch.device('cuda:0' if use_gpu else 'cpu')
    opt.ngpus_per_node = torch.cuda.device_count()

    seed_everything(opt.manual_seed)

    # Run the selected pipeline stage, then release cached GPU memory held by
    # PyTorch's caching allocator.
    main(opt)
    torch.cuda.empty_cache()