import hydra
import torch
import time
import os

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_vocab_size
from multiguide.training.helpers import get_data_loaders, setup_wandb, \
                                      construct_model, \
                                      train_property_predictor_epoch, \
                                      save_checkpoint, evaluate_property_predictor

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _validate_loss_config(train_cfg):
    """Check that the configured loss matches the ``model_log_var`` setting.

    When ``model_log_var`` is true the loss must be 'nll'; otherwise only
    'mse' or 'ce' are accepted.

    Raises:
        ValueError: if ``train_cfg.loss`` is incompatible with
            ``train_cfg.model_log_var``. Raised explicitly (not via ``assert``)
            so the check still runs under ``python -O``.
    """
    if train_cfg.model_log_var:
        if train_cfg.loss != 'nll':
            raise ValueError(
                f'Should have nll loss if model_log_var is true, got {train_cfg.loss}')
    elif train_cfg.loss not in ('mse', 'ce'):
        raise ValueError(
            f'Should have mse or ce loss if model_log_var is false, got {train_cfg.loss}')


@hydra.main(config_path='../configs', config_name='config.yaml')
def train_property_predictor(config):
    """Train the classifier-guidance property predictor described by ``config``.

    Hydra entry point. Reads its settings from ``config.classifier_guidance``
    (``dataset``, ``train``, ``experiment_name``) and ``config.general.wandb.mode``.

    Steps:
      1. validate the loss / ``model_log_var`` combination;
      2. set up wandb, the data loaders (plus target mean/std) and the model;
      3. build Adam and, if ``train.use_scheduler``, a ReduceLROnPlateau scheduler;
      4. optionally resume model/optimizer/scheduler state from
         ``PROJECT_ROOT/experiments/<experiment_name>/checkpoints/<resume_path>``;
      5. train for ``train.num_epochs`` epochs, evaluating and checkpointing every
         ``train.eval_interval`` epochs and always after the final epoch.

    Raises:
        ValueError: if ``train.loss`` is incompatible with ``train.model_log_var``.
    """
    cg_cfg = config.classifier_guidance
    train_cfg = cg_cfg.train
    _validate_loss_config(train_cfg)

    print('Starting training for classifier_guidance: ', cg_cfg)
    print(f'Batch sizes: {cg_cfg.dataset.batch_sizes}')
    print(f'device: {device}')
    print(f'num_epochs: {train_cfg.num_epochs}')
    run = setup_wandb(config)

    print('Loading data...')
    start_time = time.time()
    train_loader, val_loader, target_mean, target_std = get_data_loaders(config)
    print(f'Time taken to construct datasets: {time.time() - start_time}')

    print('Constructing model...')
    start_time = time.time()
    vocab_size = get_vocab_size(config)
    print(f'Vocab size: {vocab_size}')
    model = construct_model(config, vocab_size).to(device)
    print(f'Time taken to construct model: {time.time() - start_time}')

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=train_cfg.learning_rate,
                                 weight_decay=train_cfg.weight_decay)

    # NOTE(review): nothing in this function calls scheduler.step(...), and the
    # scheduler is not passed to the train/eval helpers, so as written the LR is
    # never reduced. ReduceLROnPlateau expects scheduler.step(<val metric>) after
    # each evaluation. TODO: step it with the validation loss from `metrics`
    # once the structure returned by evaluate_property_predictor is confirmed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.2, patience=3, verbose=True
    ) if train_cfg.use_scheduler else None

    start_epoch = 0
    if train_cfg.resume:
        if train_cfg.resume_path is None:
            # Previously this case fell through silently; make it visible.
            print('resume is set but resume_path is None; starting from scratch')
        else:
            checkpoint_path = os.path.join(PROJECT_ROOT,
                                           'experiments',
                                           cg_cfg.experiment_name,
                                           'checkpoints',
                                           train_cfg.resume_path)
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # The checkpoint may have been written without a scheduler (entry
            # missing or None, depending on save_checkpoint), so only restore
            # when both sides actually have scheduler state.
            scheduler_state = checkpoint.get('scheduler_state_dict')
            if scheduler is not None and scheduler_state is not None:
                scheduler.load_state_dict(scheduler_state)
            # NOTE(review): if save_checkpoint stores the index of the last
            # *completed* epoch, resuming here repeats that epoch; consider
            # checkpoint['epoch'] + 1. Confirm against save_checkpoint.
            start_epoch = checkpoint['epoch']
            print(f'Resuming training from epoch {start_epoch}')

    print('Starting training....')
    num_epochs = train_cfg.num_epochs
    for epoch in range(start_epoch, num_epochs):
        train_property_predictor_epoch(model, train_loader, optimizer,
                                       target_mean, target_std, config, device, epoch, run)
        # Also evaluate after the final epoch so the last trained weights are
        # always checkpointed, even when num_epochs - 1 is not a multiple of
        # eval_interval.
        is_last_epoch = epoch == num_epochs - 1
        if epoch % train_cfg.eval_interval == 0 or is_last_epoch:
            print(f'======= Evaluating at epoch {epoch} =======')
            metrics = evaluate_property_predictor(model, val_loader, target_mean, target_std,
                                                  config, run, epoch, device)
            # Hand model/optimizer/scheduler state, the target normalization
            # statistics and the validation metrics to save_checkpoint.
            save_checkpoint(config, run, model, optimizer, scheduler, epoch,
                            target_mean, target_std, metrics)

    if config.general.wandb.mode == 'enabled':
        run.finish()

if __name__ == '__main__':
    # When run as a script, the @hydra.main decorator supplies `config`.
    train_property_predictor()