"""Check a variety of DARTS schemes against a given search space."""
import numpy as np

import os
import json
import torch

from collections import defaultdict

import datetime

import time
import logging


from regression import RegressionSearchSpace, BroadRegressionSearchSpace, BroadRegressionSearchSpaceSignalRecovery
from regression_ops import OP_LIST, PRIMITIVES, GOOD_PRIMITIVES, LEARNABLE_PRIMITIVES, BAD_PRIMITIVES, GOOD_LEARNABLE_PRIMITIVES
from regression_ops import LEARNABLE_IMPROVED_PRIMITIVES, GOOD_IMPROVED_LEARNABLE_PRIMITIVES

import inverse_problems
# Global backend configuration, taken from the constants of the inverse_problems package.
torch.backends.cudnn.benchmark = inverse_problems.consts.BENCHMARK
torch.multiprocessing.set_sharing_strategy(inverse_problems.consts.SHARING_STRATEGY)
# We use this fixed seed to control data generation; a different seed will
# initialize the search runs later (a fresh seed is drawn per variant in the main block).
inverse_problems.utils.set_random_seed(233)

# Parse input arguments. This happens at import time, so `args` is a module-level global.
args = inverse_problems.options().parse_args()
# Optionally switch to deterministic execution (for full reproducibility).
if args.deterministic:
    inverse_problems.utils.set_deterministic()
# A user-supplied seed replaces the fixed default seed set above.
if args.seed is not None:
    inverse_problems.utils.set_random_seed(args.seed)


# Named groups of candidate operations selectable through `args.operations`.
# The '*-random' names map to the same operation lists as their non-random
# counterparts; the 'random' substring is inspected separately by the main
# script to decide whether a search is run at all.
_OPERATION_GROUPS = {
    'all': PRIMITIVES,
    'all-switched-grad': LEARNABLE_IMPROVED_PRIMITIVES,
    'all-good-switched-grad': GOOD_IMPROVED_LEARNABLE_PRIMITIVES,
    'all-good': GOOD_PRIMITIVES,
    'all-learnable': LEARNABLE_PRIMITIVES,
    'all-bad': BAD_PRIMITIVES,
    'all-good-learnable': GOOD_LEARNABLE_PRIMITIVES,
    'all-random': PRIMITIVES,
    'all-learnable-random': LEARNABLE_PRIMITIVES,
    'all-good-learnable-random': GOOD_LEARNABLE_PRIMITIVES,
}

# Resolve the requested operations: a single known operation name takes
# precedence, then the named groups; anything else is rejected.
if args.operations in OP_LIST:
    ops = [args.operations]
elif args.operations in _OPERATION_GROUPS:
    ops = _OPERATION_GROUPS[args.operations]
else:
    raise ValueError(f'Invalid operator argument {args.operations}.')


if __name__ == "__main__":
    # Every launch writes into its own timestamped directory:
    # experiments/<dataset>/<variant>/<operations>/<YYYYmmdd-HHMMSS>
    log_dir = os.path.join('experiments', args.dataset, args.variant, args.operations,
                           '{}'.format(time.strftime("%Y%m%d-%H%M%S")))
    # exist_ok avoids the check-then-create race when several jobs start within the same second.
    os.makedirs(log_dir, exist_ok=True)
    print('Experiment dir : {}'.format(log_dir))

    inverse_problems.utils.start_logging(args.exp_name, args.dryrun, save_runs=False)
    # `setup` is passed to the model constructors and unpacked into `model.to(**setup)` below.
    setup, _ = inverse_problems.utils.system_startup(args)
    # setup['dtype'] = torch.double
    launch_time = time.time()

    # Get Model configs
    config = inverse_problems.utils.get_config(args)

    # In grid-search mode the optimizer hyperparameters come from the command line.
    if args.grid_search:
        config['alpha_lr'] = args.alpha_lr
        config['alpha_weight_decay'] = args.alpha_weight_decay
        config['param_lr'] = args.param_lr
        config['param_weight_decay'] = args.param_weight_decay
    # With args.md set, the architecture learning rate is scaled down by a factor of 10.
    if args.md:
        # config['param_lr'] = config['param_lr'] * 0.1
        config['alpha_lr'] = config['alpha_lr'] * 0.1
    # Define hyperparameter variants
    variants = inverse_problems.utils.load_hyperparameters(args, config)

    # Prepare data
    trainloader, metaloader, validloader = inverse_problems.construct_dataloaders(args.dataset, batch_size=128,
                                                                                  data_path=args.data, noise_level=args.noise_level)
    # The operator is taken from the training dataset and handed to every model.
    operator = trainloader.dataset.operator

    # For BOHB: record the full argument set next to the results.
    if args.bohb:
        with open(os.path.join(args.save_dir, 'config.json'), 'w') as fp:
            json.dump(args.__dict__, fp)

    # Run one search (or random draw) per hyperparameter variant.
    for defs in variants:
        logging.info(f'Searching parameters: {repr(defs)}')
        start_time = time.time()

        # Constructor arguments shared by all search-space classes.
        search_options = dict(operators=ops, layers=args.layers, sublayers=args.sublayers, channels=args.channels,
                              softmax_normalization=not defs.mirror_descent, deep_supervised=args.deep_supervision,
                              norm=defs.norm, update=defs.update, dataset=args.dataset,
                              randomized_init=args.randomize_init or 'random' in args.operations,
                              noise_level=args.noise_level)

        np.random.seed(None)  # reset randgen
        # Draw a new seed. int64 is requested explicitly because the bound 2**32 - 1 does not
        # fit NumPy's default integer type on platforms where that type is 32-bit, in which
        # case randint raises ValueError. int(...) keeps the seed a plain Python int.
        initializer_seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
        inverse_problems.utils.set_random_seed(initializer_seed)    # propagate new seed

        # Pick the search space: a plain one for cell_steps <= 1, otherwise a "broad" one,
        # with a dedicated class when args.recover_signal is set.
        if args.cell_steps <= 1:
            model = RegressionSearchSpace(operator=operator, setup=setup, **search_options)
        else:
            if args.recover_signal:
                model = BroadRegressionSearchSpaceSignalRecovery(operator=operator, setup=setup,
                                                                 steps=args.cell_steps, multiplier=2,
                                                                 **search_options)
            else:
                model = BroadRegressionSearchSpace(operator=operator, setup=setup,
                                                   steps=args.cell_steps, multiplier=2,
                                                   **search_options)

        model.to(**setup)

        # Dispatch training
        if 'random' not in args.operations:
            logging.info('Learning an optimal architecture ..........')
            genotype_final, stats = inverse_problems.search(
                model, trainloader, metaloader, validloader, defs, callback=args.callback)
        else:
            logging.info('Choosing a random architecture ..........')
            # No search is run: every statistic looked up later resolves to [None],
            # so stats[...][-1] yields None.
            stats = defaultdict(lambda: [None])
            genotype_final = model.genotype()

        # Reset model to random parameters, only the architecture parameters are kept:
        if not args.bohb:
            final_alphas = model.arch_parameters()
            # NOTE(review): the evaluation model is always a RegressionSearchSpace, even when
            # the search above used one of the Broad* search spaces (cell_steps > 1). Confirm
            # that the transplanted architecture parameters are compatible in that case.
            model = RegressionSearchSpace(operator=operator, setup=setup, **search_options)
            model.to(**setup)
            model._arch_parameters = final_alphas

            # Evaluate found architecture
            logging.info(f'Evaluating {genotype_final} ..........')
            eval_stats = inverse_problems.evaluate(model, trainloader, validloader, defs)

            # Save model parameters and buffers (as lists of detached tensor copies) to log dir
            filepath = os.path.join(log_dir, 'model_parameter.obj')
            torch.save([p.detach().clone() for p in model.parameters()], filepath)
            filepath = os.path.join(log_dir, 'model_buffers.obj')
            torch.save([p.detach().clone() for p in model.buffers()], filepath)
            # Final entries of the search statistics (stats) and of the evaluation
            # of the found architecture (eval_stats).
            config_dict = {
                'Train_Loss': stats["train_loss"][-1],
                'Validation_Loss': stats['valid_loss'][-1],
                'Arch_Train': eval_stats["train_loss"][-1],
                'Arch_Validation': eval_stats["valid_loss"][-1],
                'Trunc_Loss': stats["test_loss"][-1],
                'Train_Metric': stats["train_psnr"][-1],
                'Validation_Metric': stats["valid_psnr"][-1],
                'Trunc_Metric': stats["test_psnr"][-1],
            }

            # NOTE(review): both files are opened with 'w' inside the variants loop, so each
            # variant overwrites the files written for the previous one.
            # The payload is the str() of a dict, stored as a single JSON string.
            with open(os.path.join(log_dir, 'log.txt'), 'w') as file:
                json.dump('json_stats: ' + str(args.__dict__), file)

            with open(os.path.join(log_dir, 'results.txt'), 'w') as file:
                json.dump(str(config_dict), file)
                file.write('\n')

            # Save results to table:
            to_table = dict(Method=defs.name,
                            Train_Loss=stats["train_loss"][-1],
                            Validation_Loss=stats['valid_loss'][-1],
                            Trunc_Loss=stats["test_loss"][-1],
                            Arch_Train=eval_stats["train_loss"][-1],
                            Arch_Validation=eval_stats["valid_loss"][-1],
                            Train_Metric=stats["train_psnr"][-1],
                            Validation_Metric=stats['valid_psnr'][-1],
                            Trunc_Metric=stats["test_psnr"][-1],
                            Arch_Train_Metric=eval_stats["train_psnr"][-1],
                            Arch_Validation_Metric=eval_stats["valid_psnr"][-1],
                            Entropy=model.mean_entropy(),
                            Normalized_Entropy=model.normalized_entropy(),
                            Genotype=str(genotype_final),
                            Possible_Operations=ops,
                            # Wall-clock time of this variant; commas are stripped from the
                            # timedelta string (they appear once the duration exceeds a day).
                            timestamp=str(datetime.timedelta(
                                seconds=time.time() - start_time)).replace(',', ''),
                            **defs.asdict(),
                            initializer_seed=initializer_seed,
                            ops=args.operations,
                            layers=args.layers,
                            sublayers=args.sublayers,
                            )
            inverse_problems.utils.save_to_table(
                'tables/', f'{args.dataset}_{args.exp_name}', **to_table)

        else:
            # BOHB mode: no separate evaluation is run; only the search statistics
            # are written to args.save_dir.
            #            config_dict = {
            #                'loss': stats["valid_loss"][-1],
            #                'Arch_Train': None,
            #                'Train_Loss': stats["train_loss"][-1],
            #                'Validation_Loss': stats['valid_loss'][-1],
            #                'Trunc_Loss': stats["test_loss"][-1],
            #            }
            # Note that 'loss' holds the final validation PSNR, not the validation loss.
            config_dict = {
                'loss': stats["valid_psnr"][-1],
                'Train_Loss': stats["train_loss"][-1],
                'Validation_Loss': stats["valid_loss"][-1],
                'Trunc_Loss': stats["test_loss"][-1],
                'Train_Metric': stats["train_psnr"][-1],
                'Trunc_Metric': stats['test_psnr'][-1],
            }

            # NOTE(review): as above, 'w' mode means each variant overwrites the previous output.
            with open(os.path.join(args.save_dir, 'log.txt'), 'w') as file:
                json.dump('json_stats: ' + str(args.__dict__), file)

            with open(os.path.join(args.save_dir, 'results.txt'), 'w') as file:
                json.dump(str(config_dict), file)
                file.write('\n')

    # Final summary with the total wall-clock time since launch.
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('---------------------------------------------------')
    print(f'Finished computations with total train time: '
          f'{str(datetime.timedelta(seconds=time.time() - launch_time))}')
    print('-------------Job finished.-------------------------')
