"""Example script for training a zeroth-order model (i.e. with no error correction).
"""
import torch
import hydra
import wandb
import random
import numpy as np

from error_correction.diffeqs import *
from error_correction.models import NNDESolver
from error_correction.trainers import OperatorTrainer, FullBatchTrainer


def set_seed(seed):
    """Seed every random number generator used during training.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG,
    PyTorch's CPU generator and the generators of all CUDA devices, so
    repeated runs with the same seed are reproducible.

    Args:
        seed: integer seed passed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


@hydra.main(config_path='../config', config_name='config')
def main(hparams):
    """Train a zeroth-order ``NNDESolver`` on the configured equation.

    Hydra composes ``hparams`` from the ``config`` config in ``../config``
    (plus any command-line overrides) and injects it, which is why the
    ``__main__`` guard calls this function without arguments.

    Args:
        hparams: Hydra config. Fields read directly here are ``random_seed``,
            ``data.name``, ``optimizer.name`` and
            ``wandb.{log, project_name, entity}``. The whole config is also
            forwarded to the dataset, model and trainer constructors.

    Raises:
        ValueError: if ``hparams.data.name`` is not one of the known datasets.
    """
    # set random seed everywhere for reproducibility
    set_seed(hparams.random_seed)

    # map config names to the differential-equation dataset classes
    datasets = {
        'nPBE': NonlinearPBE,
        'HenonHeiles': HenonHeiles,
        'NonlinearOscillator': NonlinearOscillator,
    }
    name = hparams.data.name
    if name not in datasets:
        # a bare KeyError from the dict lookup would not tell the user
        # which names are valid
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(datasets)}"
        )
    dataset = datasets[name]
    train_ds = dataset(hparams, mode='train')
    test_ds = dataset(hparams, mode='test')

    # initialize the model; L-BFGS uses the full-batch trainer, every other
    # optimizer uses the operator trainer
    model = NNDESolver(hparams)
    if hparams.optimizer.name == 'LBFGS':
        trainer_cls = FullBatchTrainer
    else:
        trainer_cls = OperatorTrainer
    trainer = trainer_cls(model, train_ds, test_ds, hparams)

    # for remote train logging via weights & biases
    if hparams.wandb.log:
        wandb.init(
            project=hparams.wandb.project_name,
            entity=hparams.wandb.entity,
            group=trainer.enc,
        )

    trainer.train()

    # close the run explicitly rather than relying on process exit, so a
    # later wandb.init in the same process (e.g. a Hydra --multirun sweep)
    # does not continue this run
    if hparams.wandb.log:
        wandb.finish()

    
if __name__ == "__main__":
    # the @hydra.main decorator builds the config and passes it to main(),
    # so no arguments are supplied here
    main()
    
