from src import datasets, models, train


def main(device, model_name, **model_kwargs):
    """Build ``model_name`` and train it on the weights dataset.

    The run name is ``'weights/'`` followed by ``model_name`` and the string
    produced by ``train.dict_to_str(model_kwargs)``. Data comes from
    ``datasets.weights_data``. The time tensor it returns is moved to
    ``device``. The model is created with ``models.make_model``. All of this,
    plus a fixed set of GAN, classification and prediction hyperparameters,
    is handed to ``train.main``.

    Args:
        device: Target device; used for ``t.to(device)`` and forwarded to
            ``train.main``.
        model_name: Model identifier forwarded to ``models.make_model``.
            ``'ctfp'`` gets a batch size of 1024; every other name gets 4096.
        **model_kwargs: Extra keyword arguments for ``models.make_model``.
            They are also encoded into the run name.

    Returns:
        Whatever ``train.main`` returns.
    """
    run_name = 'weights/' + model_name + train.dict_to_str(model_kwargs)

    # ctfp is trained with a smaller batch than the other models.
    if model_name == 'ctfp':
        batch_size = 1024
    else:
        batch_size = 4096

    training_options = dict(
        # Adversarial training schedule.
        save=True,
        pre_epochs=10,
        epochs=100,
        ratio=5,
        generator_lr=4e-5,
        discriminator_lr=4e-5,
        epoch_per_metric=5,
        averaging=0,
        # Downstream classification metric.
        classification_epochs=50,
        classification_lr=1e-4,
        classification_kwargs=dict(hidden_size=32, hidden_hidden_size=32, num_layers=2),
        classification_plateau_terminate=20,
        # Downstream prediction metric.
        prediction_epochs=50,
        prediction_lr=1e-4,
        prediction_kwargs=dict(context_size=32, hidden_size=32, hidden_hidden_size=32, num_layers=2),
        prediction_plateau_terminate=20,
        prediction_split=0.8,
    )

    times, loader, input_channels, label_channels = datasets.weights_data(batch_size)
    times = times.to(device)

    network = models.make_model(model_name,
                                input_channels=input_channels,
                                label_channels=label_channels,
                                **model_kwargs)

    def silent_callback(model):
        """Callback handed to ``train.main``; intentionally does nothing."""
        return None

    return train.main(name=run_name,
                      t=times,
                      dataloader=loader,
                      model=network,
                      device=device,
                      print_callback=silent_callback,
                      input_channels=input_channels,
                      **training_options)


def default(device, model_name):
    """Run :func:`main` for ``model_name`` using its default model hyperparameters.

    Args:
        device: Forwarded unchanged to :func:`main`.
        model_name: Which model to build. Must be ``'nsde'``, ``'latent_ode'``
            or ``'ctfp'``. Each name selects its own set of keyword arguments,
            which are passed on to :func:`main`.

    Returns:
        The value returned by :func:`main`.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'nsde':
        model_kwargs = dict(generator_hidden_channels=96,
                            generator_hidden_hidden_channels=64,
                            generator_num_layers=2,
                            discriminator_hidden_channels=96,
                            discriminator_hidden_hidden_channels=64,
                            discriminator_num_layers=2,
                            noise_channels=3,
                            initial_noise_channels=40,
                            lipschitz=dict(gp=10),
                            adjoint=False,
                            method='midpoint',
                            adaptive=False)
    elif model_name == 'latent_ode':
        model_kwargs = dict(encoder_hidden_channels=96,
                            encoder_hidden_hidden_channels=64,
                            encoder_num_layers=2,
                            context_channels=40,
                            decoder_hidden_channels=96,
                            decoder_hidden_hidden_channels=64,
                            decoder_num_layers=2)
    elif model_name == 'ctfp':
        model_kwargs = dict(encoder_hidden_channels=96,
                            encoder_hidden_hidden_channels=64,
                            encoder_num_layers=2,
                            context_channels=40,
                            decoder_hidden_channels=64,
                            decoder_num_layers=2)
    else:
        # Name the rejected value and the accepted options so the failure is
        # actionable without reading this function.
        raise ValueError("Unrecognised model_name {!r}; expected one of "
                         "'nsde', 'latent_ode' or 'ctfp'.".format(model_name))
    return main(device=device, model_name=model_name, **model_kwargs)
