from random import randint

import yaml
from pytorch_lightning.utilities.seed import seed_everything

from mind_the_pad.train_mnist import parse_args
from mind_the_pad.paths import experiments_folder
from mind_the_pad.train_mnist.model import MnistPadding
from mind_the_pad.train_mnist.trainer_pl import run_training


def new_experiment_folder():
    """Create a fresh, numbered run directory for an EMNIST experiment.

    Run directories live under ``experiments_folder / 'EMNIST'`` and are
    named ``0``, ``1``, ``2``, ...; the category folder is created on first
    use. The lowest index whose path does not exist yet is chosen, created,
    and returned.

    Returns:
        The path of the newly created run directory.
    """
    category = experiments_folder / 'EMNIST'
    if not category.exists():
        category.mkdir()

    index = 0
    candidate = category / str(index)
    # Probe upward until we hit a name that is not taken.
    while candidate.exists():
        index += 1
        candidate = category / str(index)

    candidate.mkdir()
    return candidate


def num_runs_with_same_hparams(hparams: dict):
    """Count the EMNIST runs on disk that were trained with ``hparams``.

    Every sub-directory of ``experiments_folder / 'EMNIST'`` is treated as one
    run; the category folder is created if it does not exist yet. A run
    matches when:

    * every key of ``hparams`` other than ``'batch_norm'`` and ``'seed'`` is
      present in the run's ``config.yaml`` with an equal value, and
    * ``hparams['batch_norm']`` equals the ``batch_norm`` entry of the run's
      ``lightning_logs/version_0/hparams.yaml``.

    ``'seed'`` is ignored because runs that differ only by seed are repeats of
    the same configuration. Run folders missing either YAML file are skipped
    rather than aborting the count. Such folders can be left behind by a launch
    that was interrupted before training logged its hparams. Runs whose YAML
    lacks a compared key do not match.

    Args:
        hparams: Hyper-parameters to look for; must contain ``'batch_norm'``.
            The dict is not modified.

    Returns:
        The number of matching runs.
    """
    emnist_runs = experiments_folder / 'EMNIST'
    if not emnist_runs.exists():
        emnist_runs.mkdir()

    bn = hparams['batch_norm']
    # 'batch_norm' is checked against Lightning's logged hparams instead of
    # config.yaml, and 'seed' never takes part in the comparison.
    keys_to_compare = [k for k in hparams if k not in ('batch_norm', 'seed')]

    counter = 0
    for run in emnist_runs.dirs():
        try:
            with open(run / 'lightning_logs' / 'version_0' / 'hparams.yaml') as f:
                # safe_load returns None for an empty file; treat it as no keys.
                logged_hparams = yaml.safe_load(f) or {}
            with open(run / 'config.yaml') as f:
                actual_hparams = yaml.safe_load(f) or {}
        except FileNotFoundError:
            # Unfinished/interrupted run: not a trained model, so do not count it.
            continue

        same_bn = 'batch_norm' in logged_hparams and logged_hparams['batch_norm'] == bn
        same_config = all(k in actual_hparams and actual_hparams[k] == hparams[k]
                          for k in keys_to_compare)
        if same_bn and same_config:
            counter += 1
    return counter


def main():
    """Run the EMNIST padding sweep, training only the models still missing.

    The sweep covers every combination of padding mode ('zeros', 'reflect',
    'replicate', 'circular'), batch norm (off/on) and extra input padding
    (0 or 1 pixel, giving 28x28 or 29x29 inputs). For each combination up to
    ``ntrials`` models are wanted. Runs already on disk are counted with
    ``num_runs_with_same_hparams``, and only the remainder is trained. Each new
    run gets a fresh random seed, which is stored together with the other
    hyper-parameters in the run's ``config.yaml``.
    """
    ntrials = 3
    args = parse_args()
    for padding_mode in ['zeros', 'reflect', 'replicate', 'circular']:
        for batch_norm_enabled in [False, True]:  # todo set at the end [True, False]
            for pad in [0, 1]:
                hparams = dict(padding_mode=padding_mode, padding='same',
                               input_size=[28 + pad, 28 + pad], batch_norm=batch_norm_enabled)
                # Pass a copy so the lookup can never see the per-run 'seed' added below.
                num_trained_models = num_runs_with_same_hparams(hparams.copy())
                assert num_trained_models <= ntrials
                if num_trained_models == ntrials:
                    print('skip', hparams, f'because {ntrials} models are already trained')
                    continue
                for _ in range(ntrials - num_trained_models):
                    # randint includes both endpoints. seed_everything only
                    # accepts seeds in numpy's uint32 range [0, 2**32 - 1], so
                    # the upper bound must be 2**32 - 1, not 2**32.
                    seed = randint(0, 2 ** 32 - 1)
                    # Record the seed seed_everything reports it used, so that
                    # config.yaml always matches the real run.
                    hparams['seed'] = seed_everything(seed)
                    model = MnistPadding('same', padding_mode, random_pad_input=pad,
                                         batch_norm=batch_norm_enabled, bn_affine=True)
                    expr_folder = new_experiment_folder()
                    with open(expr_folder / 'config.yaml', 'w') as f:
                        yaml.safe_dump(hparams, f)
                    run_training(expr_folder, model, args.device)
                    # Drop references before building the next model.
                    del expr_folder, model


if __name__ == '__main__':
    # Launch the sweep only when executed as a script, not when imported.
    main()
