import os
# Suppress tensorflow cuda errors
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import random
import warnings
import click
import colorama
import collections
import numpy as np

import torch
from convexrobust.data import datamodules
from convexrobust.model.base_certifiable import Norm
# from convexrobust.model.linf_certifiable import LInfCertifiable
from convexrobust.model.insts.convex import (
    ConvexCifar, ConvexMnist, ConvexKaggle, ConvexMalimg, ConvexSimple
)
from convexrobust.model.insts.randsmooth import (
    RandsmoothCifar, RandsmoothMnist, RandsmoothKaggle, RandsmoothMalimg, RandsmoothSimple
)
from convexrobust.model.insts.cayley import (
    CayleyCifar, CayleyMnist
)

from convexrobust.utils import file_utils, pretty, dirs

from convexrobust.main import core_manager, plot
from convexrobust.main.core_manager import ModelBlueprint, BlueprintDict, ModelDict, ResultDict


def randsmooth_blueprints(randsmooth_class, epochs, sigma_scale, nb=100, low_n=True):
    """Build one randomized-smoothing blueprint per supported noise type.

    Args:
        randsmooth_class: Model class handed to every ``ModelBlueprint``.
        epochs: Second positional argument of every ``ModelBlueprint``
            (presumably the number of training epochs -- confirm against
            ``ModelBlueprint``).
        sigma_scale: Value stored under the ``'sigma'`` key of each blueprint's
            parameter dict.
        nb: Value stored under the ``'nb'`` key of each parameter dict.
        low_n: Selects ``n=10000, cert_n_scale=10`` when true and
            ``n=100000, cert_n_scale=1`` otherwise.

    Returns:
        A dict mapping blueprint names (``'randsmooth<noise>_a'``) to
        ``ModelBlueprint`` instances, in a fixed order.
    """
    sampling = (
        {'n': 10000, 'cert_n_scale': 10, 'nb': nb} if low_n
        else {'n': 100000, 'cert_n_scale': 1, 'nb': nb}
    )

    # (blueprint name, noise identifier) pairs; the order fixes the dict order.
    variants = (
        ('randsmoothsplitderandom_a', 'split_derandomized'),
        ('randsmoothlaplace_a', 'laplace'),
        ('randsmoothgauss_a', 'gaussian'),
        ('randsmoothuniform_a', 'uniform'),
    )

    # Each blueprint receives its own freshly built parameter dict.
    return {
        label: ModelBlueprint(
            randsmooth_class, epochs, False,
            {'noise': noise_kind, 'sigma': sigma_scale, **sampling}, 'simple'
        )
        for label, noise_kind in variants
    }


def get_blueprints(global_params) -> BlueprintDict:
    """Return the model blueprints for the selected dataset and experiment.

    Args:
        global_params: Object exposing a ``data`` attribute (dataset name) and
            an ``experiment`` attribute (experiment name).

    Returns:
        A dict mapping model names to ``ModelBlueprint`` instances. For the
        non-CIFAR-10 datasets the experiment name does not change the result.

    Raises:
        RuntimeError: If the dataset is unknown, if the experiment is unknown
            for a CIFAR-10 dataset, or if an ablation experiment is requested
            for a dataset other than the CIFAR-10 ones.
    """
    data = global_params.data
    experiment = global_params.experiment
    cifar_datasets = ['cifar10_catsdogs', 'cifar10_dogscats']

    # Raised explicitly (not asserted) so the check is not stripped by
    # ``python -O``; otherwise an ablation run on another dataset would
    # silently fall through to that dataset's standard blueprints.
    if experiment in ['ablation', 'ablation_noaugment'] and data not in cifar_datasets:
        raise RuntimeError(
            f'Experiment {experiment!r} is only defined for the cifar10 datasets, '
            f'not {data!r}'
        )

    if data == 'mnist_38':
        return {
            'convex_noreg': ModelBlueprint(ConvexMnist, 30, False, {'reg': 0.0}),
            'convex_reg': ModelBlueprint(ConvexMnist, 30, False, {'reg': 0.01}),
            'cayley': ModelBlueprint(CayleyMnist, 30, False, {}),
            # Commented out by default since install is tricky -- see lib/linf_dist
            # for install instructions and uncomment to run this baseline
            # 'linf': ModelBlueprint(LInfCertifiable, 0, False, {
                # 'model': 'MLPModel(depth=5,width=5120,identity_val=10.0,scalar=True)',
                # 'load_path': dirs.pretrain_path('mnist_38', 'model.pth'),
                # 'input_shape': [1, 28, 28]
            # }),
            **randsmooth_blueprints(RandsmoothMnist, 30, 0.75, low_n=False)
        }
    elif data == 'malimg':
        return {
            'convex_noreg': ModelBlueprint(ConvexMalimg, 50, False, {}),
            'convex_reg': ModelBlueprint(ConvexMalimg, 50, False, {'reg': 0.1}),
            **randsmooth_blueprints(RandsmoothMalimg, 50, 3.5, nb=32)
        }
    elif data in cifar_datasets:
        if experiment == 'standard':
            return {
                'convex_noreg': ModelBlueprint(ConvexCifar, 150, False, {}),
                'convex_reg': ModelBlueprint(ConvexCifar, 150, False, {'reg': 0.01}),
                'cayley': ModelBlueprint(CayleyCifar, 150, False, {}),
                **randsmooth_blueprints(RandsmoothCifar, 600, 0.75),
            }
        elif experiment == 'ablation':
            blueprints = {
                'convex_noaugment': ModelBlueprint(
                    ConvexCifar, 150, False, {'augment_input': False, 'reg': 0.1}
                )
            }

            # One convex model per regularization strength in the sweep.
            for i, reg in enumerate([0.0, 0.0025, 0.005, 0.0075, 0.01]):
                blueprints[f'convex_reg_{i}'] = ModelBlueprint(
                    ConvexCifar, 150, False, {'reg': reg}
                )

            return blueprints
        elif experiment == 'ablation_noaugment':
            return {
                'convex_noaugment': ModelBlueprint(
                    ConvexCifar, 500, False, {'augment_input': False}
                )
            }
        else:
            raise RuntimeError('Bad experiment type')
    elif data == 'kaggle_catsdogs':
        return {
            'convex_noreg': ModelBlueprint(ConvexKaggle, 225, False, {}),
            'convex_reg': ModelBlueprint(ConvexKaggle, 225, False, {'reg': 0.01}),
            **randsmooth_blueprints(RandsmoothKaggle, 225, 3.5)
        }
    elif data == 'circles':
        return {
            'convex_noreg': ModelBlueprint(ConvexSimple, 30, False, {})
        }
    else:
        raise RuntimeError('Bad dataset')

@click.command()
@click.option('--data', type=click.Choice(datamodules.names), default='cifar10_catsdogs')
@click.option('--experiment', type=click.Choice(['standard', 'ablation', 'ablation_noaugment']),
              default='standard')
@click.option('--train/--no_train', default=True)
@click.option('--augment_data/--no_augment_data', default=True)
@click.option('--balance/--no_balance', default=True)
@click.option('--rebalance/--no_rebalance', default=False)

@click.option('--eval/--no_eval', default=True)
@click.option('--eval_n', default=1000)
@click.option('--verify_cert/--no_verify_cert', default=False)
@click.option('--empirical_cert/--no_empirical_cert', default=False)
@click.option('--tensorboard/--no_tensorboard', default=True)
@click.option('--clear_models/--no_clear_models', default=False)

# PLOTTING
@click.option('--clear_figs/--no_clear_figs', default=False)
@click.option('--figsize', type=click.Choice(plot.figsize_dict.keys()), default='large')
@click.option('--labels', type=click.Choice(plot.labels_dict.keys()), default='standard_a')
@click.option('--x_label/--no_x_label', default=True)
@click.option('--y_label/--no_y_label', default=True)
@click.option('--x_log/--no_x_log', default=False)
@click.option('--label_acc/--no_label_acc', default=True)
@click.option('--title', type=str, default=None)

@click.option('--seed', default=1)
def run(data, experiment, train, augment_data, balance, rebalance,
        eval, eval_n, verify_cert, empirical_cert, tensorboard, clear_models,
        clear_figs, figsize, labels, x_label, y_label, x_log, label_acc, title, seed):
    """Create, evaluate, and plot the models for one dataset and experiment."""
    # ``eval`` shadows the builtin, but the name is dictated by the --eval option.
    # Raised as a usage error (not asserted) so the check survives ``python -O``.
    if train and not eval:
        raise click.UsageError('--train cannot be combined with --no_eval')

    # Seed every global RNG the run may draw from. NumPy only accepts seeds in
    # [0, 2**32), so the value is reduced into that range.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed % 2**32)

    colorama.init()

    pretty.section_print('Loading datamodules')
    warnings.filterwarnings('ignore')
    datamodule = datamodules.get_datamodule(data, no_transforms=not augment_data)

    pretty.section_print('Assembling parameters')
    # Snapshot of all locals defined so far (every CLI option plus ``datamodule``).
    # Each one becomes a field of ``global_params``, so do not introduce new
    # local variables above this line.
    local_vars = locals()
    global_params = collections.namedtuple('Params', local_vars.keys())(*local_vars.values())

    blueprints: BlueprintDict = get_blueprints(global_params)

    pretty.section_print('Creating models')
    models: ModelDict = core_manager.create_models(blueprints, global_params)

    # The figure directory depends only on the dataset and experiment, so it is
    # prepared once rather than once per model.
    file_utils.ensure_created_directory(f'./figs/{data}-{experiment}', clear=clear_figs)

    for model in models.values():
        model.eval()

    pretty.section_print('Evaluating models')
    results: ResultDict = core_manager.eval_models(models, blueprints, global_params)

    pretty.section_print('Plotting results')
    plot.clean_confusion_plot(results, global_params)
    plot.certified_radius_plot(results, global_params, norm=Norm.L1)
    plot.certified_radius_plot(results, global_params, norm=Norm.L2)
    plot.certified_radius_plot(results, global_params, norm=Norm.LInf)

    # Decision plots are only produced for the 'circles' dataset.
    if global_params.data in ['circles']:
        for name, model in models.items():
            plot.make_decision_plot(model, name, global_params, threshold=False, plot_cert=True)

    pretty.section_print('Done plotting (ctrl+c to exit)')


# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    run()
