import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import argparse


from datasets import Regression, Banana, Uci, Mnist
from vi_model import VIBNN
from la_models import LABNN
from metric_model import AEBNN

from models_utils import regression_metrics, classification_metrics
from plot_utils import plot_hypothesis_reg, plot_var_reg, plot_hypothesis_class, plot_var_class #, get_regression_fig, get_banana_fig


def setup(args, network_specs=None, mcmc_vars=None):


    assert (args.prob == 1 if args.model_type == 0 else True), "When using Variational Inference the model has to be probabilistic"

    device = torch.device('cuda:0' if torch.cuda.is_available() and args.device == 1 else 'cpu')
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    model_type = args.model_type
    model_names = {0: 'VI_BNN', 1: 'Laplace_BNN', 2: 'Laplace_BNN_our', 3: 'MetricBNN'}
    model_sizes = {0: 'small', 1: 'big', 2: 'real'}
    model_size = model_sizes[args.model_size]
    experiments_type = {0: 'regression',
                        1: 'banana',
                        2: 'UCI_australian',
                        3: 'UCI_breast',
                        4: 'UCI_glass',
                        5: 'UCI_ionosphere',
                        6: 'UCI_vehicle',
                        7: 'UCI_waveform',
                        8: 'MNIST',
                        9: 'FashionMNIST'}
    experiment = experiments_type[args.experiment]

    weight_decay = args.wd

    name_exp = model_names[model_type] + '_' + model_size + '_' + experiment + '_L2=' + str(weight_decay) + '_seed=' + str(seed)
    model_file_name = experiment + '_' + model_size + '_L2=' + str(weight_decay) + '_seed=' + str(seed)

    if network_specs is not None:
        name_exp += '_layers=' + str(len(network_specs['architecture']))
        model_file_name += '_layers=' + str(len(network_specs['architecture']))


    loss_category = 'classification'
    all_x = None

    ''' GENERAL HYPERPARAMETERS '''
    if experiment == 'regression':

        # small: L2 1e-3, epochs 7e5
        # big: L2 1e-2, epochs 3.5e4

        loss_category = 'regression'

        batch_size = 200
        n_test_samples = 100
        lr = 1e-3
        EPOCHS = 50000 if model_size == 'small' else 35000 #700000
        testing_epochs = int(EPOCHS/100)
        ood = args.ood_regr #False
        probabilistic = args.prob == 1
        loss_type = 'NLL' if probabilistic else 'mse'

        if network_specs is None:
            if model_size == 'small':
                network_specs = {'architecture': [[1, 15], [15, 1+args.prob]], 'activation': nn.Tanh()}
            elif model_size == 'big':
                network_specs = {'architecture': [[1, 10], [10, 10], [10, 1+args.prob]], 'activation': nn.Tanh()}
            else:
                network_specs = {'architecture': [[1, 32], [32, 32], [32, 32], [32, 1+args.prob]], 'activation': nn.ReLU()}

        # get_plot = get_regression_fig
        plot_hypothesis = plot_hypothesis_reg
        plot_var = plot_var_reg
        get_metrics = regression_metrics

        dataset = Regression(ood, device)

        name_exp += '_ood=' + str(ood)
        model_file_name += '_ood=' + str(ood)

        all_x = torch.linspace(-1, 7, 100).float().to(device).view(-1, 1)

    elif experiment == 'banana':

        # L2 1e-2
        network_specs = {'architecture': [[2, 16], [16, 16], [16, 2]], 'activation': nn.Tanh()} if network_specs is None else network_specs

        batch_size = 32
        n_test_samples = 100
        lr = 1e-3
        EPOCHS = 2500
        testing_epochs = 25
        probabilistic = False
        loss_type = 'CE'

        # get_plot = get_banana_fig
        plot_hypothesis = plot_hypothesis_class
        plot_var = plot_var_class
        get_metrics = classification_metrics

        dataset = Banana(device)

        name_exp += '_n_samples=' + str(n_test_samples)

        x_test, y_test = dataset.x_test.detach().cpu().numpy(), dataset.y_test.detach().cpu().numpy()

        N_grid = 100
        offset = 2
        x1min = x_test[:, 0].min() - offset
        x1max = x_test[:, 0].max() + offset
        x2min = x_test[:, 1].min() - offset
        x2max = x_test[:, 1].max() + offset

        x_grid = np.linspace(x1min, x1max, N_grid)
        y_grid = np.linspace(x2min, x2max, N_grid)
        XX1, XX2 = np.meshgrid(x_grid, y_grid)
        X_grid = np.column_stack((XX1.ravel(), XX2.ravel()))
        all_x = torch.from_numpy(X_grid).float().to(device)

    elif experiment[:3] == 'UCI':

        data_specs = {  # input, number of categories
            'UCI_australian': [14, 2],
            'UCI_breast': [9, 2],
            'UCI_glass': [9, 8],
            'UCI_ionosphere': [34, 2],
            'UCI_vehicle': [25, 2],
            'UCI_waveform': [21, 3]
        }

        # network_specs = {'architecture': [[data_specs[experiment][0], 50], [50, data_specs[experiment][1]]], 'activation': nn.Tanh()} if network_specs is None else network_specs
        network_specs = {'architecture': [[data_specs[experiment][0], 32], [32, 32], [32, data_specs[experiment][1]]], 'activation': nn.ReLU()} if network_specs is None else network_specs

        batch_size = 32
        n_test_samples = 30
        lr = 1e-3
        EPOCHS = 1000 #0
        testing_epochs = 100
        probabilistic = False
        loss_type = 'CE'

        # get_plot = lambda a, b, c: None
        plot_hypothesis = lambda a, b, c: None
        plot_var = lambda a, b, c: None
        get_metrics = classification_metrics

        dataset = Uci(experiment, device)

        name_exp += '_n_samples=' + str(n_test_samples)

    elif experiment[-5:] == 'MNIST':

        network_specs = {'architecture': [[1, 5, 4], [4, 5, 4], [4*(4**2), 16], [16, 10], [10, 10], [10, 10]], 'activation': nn.Tanh()} if network_specs is None else network_specs

        batch_size = 32
        n_test_samples = 25
        lr = 1e-3
        EPOCHS = 100
        testing_epochs = 10
        probabilistic = False
        loss_type = 'CE'

        # get_plot = lambda a, b, c: None
        plot_hypothesis = lambda a, b, c: None
        plot_var = lambda a, b, c: None
        get_metrics = classification_metrics

        dataset = Mnist(experiment, device)

        name_exp += '_n_samples=' + str(n_test_samples)


    ''' MODEL-SPECIFIC HYPERPARAMETERS '''
    if model_names[model_type] == 'VI_BNN':  # VI hyperparams

        beta_kl = args.kl
        init_vals = {'mu': 1e-2, 'std': args.std}
        diag = True

        model = VIBNN(network_specs, weight_decay, beta_kl, lr, loss_type, diag, n_test_samples, device, init_vals).to(device)

        name_exp += '_beta=' + str(beta_kl) + '_init_vals=' + str(init_vals)

    elif model_names[model_type][:11] == 'Laplace_BNN':  # LA hyperparams

        hessian_types = {0: 'full', 1: 'diag', 2: 'fisher', 3: 'kron', 4: 'lowrank', 5: 'gp', 6: 'gauss_newton'}
        hessian_type = hessian_types[args.hessian_type]

        implementation_type = 0 if model_names[model_type] == 'Laplace_BNN' else 1

        marginal_type = 'determinant'

        use_riemann = args.use_riemann == 1
        use_linear_network = args.use_linear_network == 1
        tune_alpha = args.tune_alpha == 1

        model = LABNN(implementation_type, loss_category, network_specs, weight_decay, lr, loss_type, n_test_samples, hessian_type, probabilistic, marginal_type, use_linear_network, use_riemann, tune_alpha, device).to(device)

        name_exp += '_hessian_type=' + hessian_type
        name_exp += '_riemannian=' + str(use_riemann) + '_lin=' + str(use_linear_network)
        name_exp += '_tune_alpha=' + str(tune_alpha)

    else:

        model = AEBNN(network_specs, 0, weight_decay, lr, loss_type, device).to(device)

        variables = {}
        variables['alpha'] = 0.001 if mcmc_vars is None else mcmc_vars['alpha']
        variables['T'] = 100 if mcmc_vars is None else mcmc_vars['T']
        variables['n_steps'] = 5 if mcmc_vars is None else mcmc_vars['n_steps']
        variables['n_traj'] = 100 if mcmc_vars is None else mcmc_vars['n_traj']
        variables['use_brownian'] = True
        variables['inner_lr'] = 0.001
        variables['k'] = 32
        variables['batch_size'] = 1024
        variables['epochs'] = 1000
        variables['pos_lambda'] = 1.0
        variables['neg_lambda'] = 1.0
        variables['dec_lambda'] = 1.0

        model.set_global_variables(variables)

    # writer = SummaryWriter("logs/" + name_exp + "")

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)

    training_params = {}

    training_params['EPOCHS'] = EPOCHS


    return model, loader, name_exp, model_file_name, all_x, training_params, plot_hypothesis, plot_var, get_metrics, device











