import torch
import wandb
import pytorch_lightning as pl
from lightning.pytorch.loggers import WandbLogger

from data import build_data
from model import build_model, load_model
from baseline_trainer import build_baseline_trainer
from utils import LOGS_DIR, ENTITY, PROJECT, build_fairret


def execute(config=None):
    """Run one fit + test experiment, logged to Weights & Biases.

    The steps are: optionally limit the torch thread count, create the
    ``WandbLogger``, seed everything, build the data, the optional fairret
    and the model, build the trainer, then fit, test and close the wandb run.

    Args:
        config: Mapping with the experiment settings, or ``None``. When it is
            ``None`` the settings are taken from the wandb run created by the
            logger (presumably populated externally, e.g. by a sweep agent --
            confirm against the caller). Keys read here:

            * ``num_threads`` (optional, only honoured when ``config`` is
              passed in): forwarded to ``torch.set_num_threads``.
            * ``seed``: forwarded to ``pl.seed_everything``.
            * ``data``: keyword arguments for ``build_data``.
            * ``fairret`` (optional): keyword arguments for ``build_fairret``.
            * ``model``: keyword arguments for ``build_model``/``load_model``.
            * ``load_weights`` (optional): first argument to ``load_model``;
              when absent the model is created with ``build_model``.
            * ``baseline`` (optional): keyword arguments for
              ``build_baseline_trainer``; when absent a plain ``pl.Trainer``
              is used.
            * ``trainer``: keyword arguments for the trainer.

    Raises:
        ValueError: If the configuration held by the wandb run is empty.
    """
    # ``config`` defaults to None and ``in`` on None raises TypeError, so the
    # thread-count lookup must be skipped when no mapping was passed in.
    if config is not None and 'num_threads' in config:
        print(f"Setting number of threads to {config['num_threads']}")
        torch.set_num_threads(config['num_threads'])

    logger = WandbLogger(entity=ENTITY, project=PROJECT, save_dir=LOGS_DIR, config=config, log_model=True)
    # From here on, use the config object held by the wandb run instead of
    # the argument that was passed in.
    config = logger.experiment.config

    resolved = dict(config)
    if not resolved:
        raise ValueError("No config provided! Please provide a .yaml config file path.")
    print(resolved)

    pl.seed_everything(config['seed'])

    data = build_data(**config['data'])

    fairret = build_fairret(**config['fairret']) if 'fairret' in config else None

    if 'load_weights' in config:
        model = load_model(config['load_weights'], **config['model'], fairret=fairret, **data.info())
    else:
        model = build_model(**config['model'], fairret=fairret, **data.info())

    if 'baseline' in config:
        # The baseline trainer receives the trainer settings as one dict, so
        # the logger and the accelerator are merged into it.
        trainer_kwargs = config['trainer'] | {'logger': logger, 'accelerator': 'cpu'}
        trainer = build_baseline_trainer(**config['baseline'], trainer=trainer_kwargs)
    else:
        trainer = pl.Trainer(**config['trainer'], logger=logger, accelerator='cpu')

    trainer.fit(model, data)
    trainer.test(model, data.test_dataloader())

    wandb.finish()
