"""Train a model on a Slurm HPC cluster.

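Each job is parameterised by ``--array_id`` (e.g. a Slurm job-array task ID):
it selects one (init, coupling) setting and a repetition index that is also
used as the random seed. To run outside the cluster, invoke the script directly
and set ``--array_id`` (and any other options) yourself.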
"""

import argparse
import os
import sys

import pytorch_lightning as pl

# '' on sys.path is the current working directory; this lets the local
# ``functions`` package be imported when the script is launched from the
# directory that contains it.
sys.path.append('')

import functions.datasets as datasets
import functions.sgc_resnets as sgc

def float_interval(lower=None, upper=None):
    """Build an argparse ``type`` callable accepting floats in ``[lower, upper]``.

    Both bounds are inclusive; either may be ``None`` to leave that side
    unbounded.

    Args:
        lower: smallest accepted value, or ``None``.
        upper: largest accepted value, or ``None``.

    Returns:
        A callable that converts its argument with ``float()``. It raises
        ``ValueError`` if the conversion fails, and
        ``argparse.ArgumentTypeError`` if the value lies outside the interval.
        In both cases argparse reports a usage error rather than a traceback.
    """
    def fun(x):
        x = float(x)
        # ``not x >= lower`` (rather than ``x < lower``) also rejects NaN,
        # matching the original assert-based checks. An explicit raise is
        # used because asserts vanish under ``python -O`` and argparse does
        # not convert AssertionError into a usage message.
        too_low = lower is not None and not x >= lower
        too_high = upper is not None and not x <= upper
        if too_low or too_high:
            raise argparse.ArgumentTypeError(
                '{} is not in the interval [{}, {}]'.format(
                    x,
                    '-inf' if lower is None else lower,
                    'inf' if upper is None else upper,
                )
            )
        return x
    return fun

def int_or_none(x):
    """Return ``int(x)``, or ``None`` unchanged when ``x`` is ``None``.

    Intended as an argparse ``type`` for optional integer options.
    """
    return None if x is None else int(x)

def main(args):
    """Train the model for the setting selected by ``args.array_id``.

    The job index decomposes as ``array_id = it * n_settings + id_setting``.
    ``id_setting`` picks the paired entries of ``args.coupling_arr`` and
    ``args.init_arr``. ``it`` is the repetition index and is also used as the
    random seed.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).
            It is mutated in place: ``stages``, ``coupling`` and ``init`` are
            filled in when they are not already set.

    Raises:
        ValueError: if ``coupling_arr`` and ``init_arr`` differ in length.
    """
    n_settings = len(args.coupling_arr)
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if n_settings != len(args.init_arr):
        raise ValueError(
            'coupling_arr and init_arr must have the same length, '
            'got {} and {}'.format(n_settings, len(args.init_arr))
        )
    id_setting = args.array_id % n_settings
    it = args.array_id // n_settings
    # Seed before building the loaders so any randomness in data loading is
    # also determined by the repetition index.
    pl.seed_everything(it)
    # The middle return value is unused here (presumably a test loader —
    # TODO confirm against datasets.get_data).
    train_loader, _, val_loader = datasets.get_data(
        args.data, batch_size=args.batch_size, trainset=args.train_size
    )
    # Fall back to the dataset's default stage configuration.
    args.stages = args.stages or datasets.stages[args.data]
    # Values already present on ``args`` take precedence over the array setting.
    if args.coupling is None:
        args.coupling = args.coupling_arr[id_setting]
    if args.init is None:
        args.init = args.init_arr[id_setting]
    model = sgc.SGCResNetModule(args)
    version = 'init={}_coupling={}_it={}'.format(args.init, args.coupling, it)
    # ``args.folder`` is a directory (it is also the CSVLogger save_dir), so
    # join it properly instead of relying on a trailing separator.
    checkpoint = pl.callbacks.ModelCheckpoint(
        os.path.join(args.folder, version + '_{epoch}'),
        monitor='val_acc', mode='max'
    )
    logger = pl.loggers.CSVLogger(args.folder, version=version)
    # Hand the logger and checkpoint callback to the Trainer constructor, so
    # the Trainer registers them itself. Assigning attributes after
    # construction skips that registration. These kwargs override the
    # ``--logger`` / ``--checkpoint_callback`` CLI values.
    trainer = pl.Trainer.from_argparse_args(
        args, logger=logger, checkpoint_callback=checkpoint
    )
    trainer.fit(model, train_loader, val_loader)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Positional arguments: dataset identifier and output folder.
    for positional in ('data', 'folder'):
        parser.add_argument(positional, type=str)
    # Job index; ``main`` splits it into a setting index and a repetition
    # index (e.g. a Slurm job-array task ID).
    parser.add_argument('--array_id', type=int, default=0)
    # Paired per-setting coupling strengths (each in [0, 1]) and init modes;
    # both lists must have the same length.
    parser.add_argument(
        '--coupling_arr', type=float_interval(0, 1), nargs='+',
        default=[0, 0, 0.5, 0.9, 1],
    )
    parser.add_argument(
        '--init_arr', choices=['r', 'nr'], nargs='+',
        default=['nr', 'r', 'r', 'r', 'r'],
    )
    parser.add_argument('--batch_size', type=int, default=128)
    # Forwarded to ``datasets.get_data`` as ``trainset``.
    parser.add_argument('--train_size', type=int_or_none, default=None)
    # Let the model and the Lightning Trainer register their own options.
    for extend in (sgc.SGCResNetModule.add_model_specific_args,
                   pl.Trainer.add_argparse_args):
        parser = extend(parser)
    main(parser.parse_args())
