#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This script loads a JSON file of model definitions and runs one of them,
selected by its index in the file.

This is for training many models in parallel.

Also in this file: helper functions to add a model definition to the JSON file
and to create the standard list of model definitions.
"""

import os
import json
import torch
from script_general import main


def add_model(file='models.json', model='pixel', version=999, neigh=4,
              learning_rate=0.001, epochs=10, crop_size=256,
              batch_size=32, num_workers=6, downsample=0, lr_fact=10,
              pars_train='all', loss_type='pos', momentum=0.9,
              weight_decay=10**-4, n_report=10000, n_acc=5, n_pos=5, noise=0,
              **kwargs):
    """Append one model definition to the JSON list stored in ``file``.

    If ``file`` exists, its JSON content is loaded (it is expected to be a
    list), the new definition is appended and the whole list is written
    back with ``indent=2``. If ``file`` does not exist, a new file holding a
    one-element list is created.

    The stored definition is a dict whose keys are written in this order:
    model, version, neigh, learning_rate, lr_fact, epochs, crop_size,
    batch_size, num_workers, downsample, loss_type, momentum, weight_decay,
    n_report, n_acc, pars_train, noise.

    NOTE(review): ``n_pos`` and any extra ``**kwargs`` are accepted but are
    NOT written to the file.
    """
    models = []
    if os.path.exists(file):
        with open(file, 'r') as handle:
            models = json.load(handle)

    # dict(...) keeps keyword order, so the JSON key order is fixed.
    entry = dict(
        model=model,
        version=version,
        neigh=neigh,
        learning_rate=learning_rate,
        lr_fact=lr_fact,
        epochs=epochs,
        crop_size=crop_size,
        batch_size=batch_size,
        num_workers=num_workers,
        downsample=downsample,
        loss_type=loss_type,
        momentum=momentum,
        weight_decay=weight_decay,
        n_report=n_report,
        n_acc=n_acc,
        pars_train=pars_train,
        noise=noise,
    )
    models.append(entry)

    with open(file, 'w') as handle:
        json.dump(models, handle, indent=2)


def create_std_list(file='models.json'):
    """Append the standard grid of model definitions to ``file``.

    For each architecture ('pixel', 'linearbig', 'predseg1') this appends
    2 losses x 3 noise levels x 4 neighbourhood sizes = 24 entries via
    ``add_model``, numbered ``version`` 0..23 within each architecture.
    The loss loop is outermost and the neighbourhood loop innermost.
    """
    # (architecture, batch_size, n_acc)
    architectures = (
        ('pixel', 32, 5),
        ('linearbig', 6, 10),
        ('predseg1', 16, 10),
    )
    for arch, base_batch, accumulate in architectures:
        grid = ((loss, noise, neigh)
                for loss in ('pos', 'shuffle')
                for noise in (0, 0.1, 0.2)
                for neigh in (4, 8, 12, 20))
        for version, (loss, noise, neigh) in enumerate(grid):
            # predseg1 with the shuffle loss overrides the batch size to 24.
            if arch == 'predseg1' and loss == 'shuffle':
                batch = 24
            else:
                batch = base_batch
            add_model(file=file, model=arch, version=version,
                      neigh=neigh, loss_type=loss, noise=noise,
                      batch_size=batch, n_acc=accumulate)


def run_model_by_number(action='train', file='models.json', idx=0, **kwargs):
    """Load the model definitions in ``file`` and run entry ``idx`` via ``main``.

    The chosen definition's keys are passed to ``main`` as keyword
    arguments, together with ``action``, ``cont=True``, the device
    ('cuda' if available, otherwise 'cpu') and any extra ``kwargs``.

    Only one source of ``downsample`` is kept, so ``main`` does not receive
    it twice:
      * for 'save_BSD' the definition's value is dropped (the caller's
        ``downsample`` in kwargs wins) and ``num_workers`` is forced to 0;
      * for every other action the caller's ``downsample`` is dropped and
        the definition's value is used.
    """
    with open(file, 'r') as handle:
        definitions = json.load(handle)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if action != 'save_BSD':
        kwargs.pop('downsample', None)
    else:
        chosen = definitions[idx]
        chosen['num_workers'] = 0
        chosen.pop('downsample', None)
    main(action=action, cont=True, device=device, **definitions[idx], **kwargs)


if __name__ == '__main__':
    # Command-line entry point: parse options and run the selected model.
    # Every parsed option is forwarded as a keyword to run_model_by_number.
    import argparse
    parser = argparse.ArgumentParser(
        description='Run one model definition from a JSON file, selected '
                    'by its index.')
    parser.add_argument("-a", "--action",
                        help="what to do? [train, train_readout, save_eval, eval"
                        + ", reset, save_BSD]",
                        choices=['train', 'train_readout', 'save_eval',
                                 'eval', 'reset', 'save_BSD'],
                        default='train')
    parser.add_argument("--verbose", action='store_true')
    parser.add_argument("--separate", action='store_true')
    parser.add_argument("--interpolate", action='store_true')
    parser.add_argument("--split",
                        help="which data split to work on. Applies only to evaluations",
                        choices=['train', 'val', 'test'],
                        default='train')
    parser.add_argument("-f", "--file",
                        help="file of model definitions",
                        type=str, default='models.json')
    parser.add_argument("-d", "--downsample",
                        help="factor for downsampling input (0 = don't)",
                        type=int, default=0)
    # A positional argument only falls back to its default when it is
    # optional (nargs='?'); without it, default=0 was dead and idx was
    # still required on the command line.
    parser.add_argument("idx",
                        help="model index to run",
                        type=int, nargs='?', default=0)
    parser.add_argument("-t", "--transform",
                        help="which transform to edge_weights ['quant', 'expit']",
                        choices=['expit', 'quant'],
                        default='quant')
    args = parser.parse_args()
    run_model_by_number(**vars(args))
