import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from tqdm import tqdm
import argparse
from torch.utils.tensorboard import SummaryWriter

from models_utils import regression_metrics, classification_metrics
from datasets import Regression, Banana, Uci, Mnist
from vi_model import VIBNN
from la_models import LABNN
from metric_model import AEBNN

import pickle
import timeit

from main_utils import setup

def train(loader, model):
    """Fit ``model`` on ``loader`` and save ``model.theta`` to disk.

    Runs ``training_params['EPOCHS']`` epochs. For every batch it performs one
    optimizer step and logs each entry of the loss dictionary returned by
    ``model.loss(..., logs=True)`` to TensorBoard under ``Loss/<key>``.
    After the last epoch ``model.theta`` is written to
    ``./saved_models/<model_file_name>.pt``.

    Reads the module-level names ``name_exp``, ``training_params`` and
    ``model_file_name``.

    Args:
        loader: Iterable yielding ``(x, y)`` batches.
        model: Object providing ``train()``, ``__call__(x)``,
            ``loss(y_pred, y, logs=True)``, ``optimizer`` and ``theta``.
    """
    writer = SummaryWriter("logs/" + name_exp)

    # try/finally so the event file is flushed and closed even when a
    # training step or the final save raises.
    try:
        tensorboard_idx = 0  # global step counter shared by all logged losses
        for epoch in tqdm(range(training_params['EPOCHS'])):

            model.train()
            for (x, y) in loader:

                y_pred = model(x)
                losses, loss = model.loss(y_pred, y, logs=True)

                model.optimizer.zero_grad()
                loss.backward()
                model.optimizer.step()

                for loss_key in losses:
                    writer.add_scalar('Loss/' + loss_key, losses[loss_key], tensorboard_idx)
                tensorboard_idx += 1

        # Make sure the target directory exists so a finished run is not
        # lost to a missing folder.
        os.makedirs("./saved_models", exist_ok=True)
        torch.save(model.theta, "./saved_models/" + model_file_name + ".pt")
    finally:
        writer.close()

# Command-line interface. Every option is listed once as (flag, keyword
# arguments) and registered in this order, so the help output keeps the
# same ordering as the table.
_CLI_OPTIONS = (
    # Which stages of the script to run.
    ('--do_train', dict(default=0, type=int)),
    ('--qual_plot', dict(default=1, type=int)),
    ('--quant_table', dict(default=0, type=int)),
    ('--comp_complx', dict(default=0, type=int)),

    # Dataset selection.
    ('--experiment', dict(default=0, type=int, help="0: regression, 1: banana, 2-7: UCI, 8: MNIST, 9: FashionMNIST")),
    ('--ood_regr', dict(default=0, type=int, help="")),

    # Network size and regularization.
    ('--model_size', dict(default=2, type=int, help="0: small, 1: big, 2: real")),
    ('--wd', dict(default=0.01, type=float, help="L2 regularization")),

    # Model family and its switches.
    ('--model_type', dict(default=1, type=int, help="0: VI_BNN, 1: Laplace_BNN, 2: Laplace_BNN_our, 3: MetricBNN")),
    ('--use_riemann', dict(default=0, type=int, help="0: don't use, 1: use")),
    ('--use_linear_network', dict(default=0, type=int, help="0: don't use, 1: use")),
    ('--tune_alpha', dict(default=1, type=int, help="0: don't use, 1: use")),

    ('--hessian_type', dict(default=0, type=int, help="0: full, 1: diag, 2: fisher, 3: kron, 4: lowrank, 5: gp, 6:gauss_newton")),

    ('--kl', dict(default=0.01, type=float, help="KL weighting term")),
    ('--std', dict(default=0, type=float, help="initial standard deviation")),
    ('--prob', dict(default=0, type=int, help="0: deterministic_out, 1: probabilistic_out")),

    ('--seed', dict(default=0, type=int, help="seed")),
    ('--device', dict(default=1, type=int, help="device")),

    ('--n_layers', dict(default=0, type=int, help="n layers")),

    # Values collected into ``mcmc_vars`` after parsing.
    ('--alpha', dict(default=0.001, type=float)),
    ('--T', dict(default=15, type=int)),
    ('--n_steps', dict(default=10, type=int)),
    ('--n_traj', dict(default=10, type=int)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)


# Parse the command line once; the resulting namespace drives every stage below.
args = parser.parse_args()

# Bundle the four trajectory-related CLI values into the dictionary that is
# handed to setup() as ``mcmc_vars``.
mcmc_vars = {
    'alpha': args.alpha,
    'T': args.T,
    'n_steps': args.n_steps,
    'n_traj': args.n_traj,
}


# setup() returns a 10-tuple; unpack it into the module-level names used by
# train() and by every stage below.
(
    model,
    loader,
    name_exp,
    model_file_name,
    all_x,
    training_params,
    plot_hypothesis,
    plot_var,
    get_metrics,
    device,
) = setup(args, mcmc_vars=mcmc_vars)

# Reuse the stored parameters only when training was not requested AND a
# checkpoint for this configuration already exists; otherwise train now.
if args.do_train != 1 and os.path.exists("./saved_models/" + model_file_name + ".pt"):
    model.theta = torch.load("./saved_models/" + model_file_name + ".pt", map_location=device).to(device)
else:
    train(loader, model)


# The stages below only evaluate the model.
model.eval()



if args.qual_plot == 1:
    # Qualitative stage: compute posterior predictions on ``all_x``, store
    # them under ./saved_posteriors/ and draw hypothesis / variance plots.

    # Ensure the output directory exists before the expensive work starts.
    os.makedirs('./saved_posteriors', exist_ok=True)

    if args.model_type == 1:

        # This branch overrides two CLI settings, rebuilds everything via
        # setup() and always retrains, regardless of --do_train.
        args.model_size = 2
        args.use_linear_network = 0
        model, loader, name_exp, model_file_name, all_x, training_params, plot_hypothesis, plot_var, get_metrics, device = setup(args, mcmc_vars=mcmc_vars)
        train(loader, model)
        # NOTE(review): the freshly trained model is not switched back with
        # model.eval() here, unlike after the initial train/load - confirm
        # whether that is intended.

        # Plot/file tag; computed once because the inputs do not change below.
        name = 'LA' if args.use_riemann == 0 else 'LA_Riem'
        name += '_reg' if args.experiment == 0 else '_class'

        y_map, y_mu, y_std, py = model.posterior(all_x, loader)
        plot_hypothesis(all_x, py, loader, name=name + '_hyp', color='tab:orange', save=False)

        # NOTE(review): posterior() is evaluated a second time; the saved
        # tensors and the variance plot use this second result. Kept as-is
        # because the two calls may not return identical samples.
        y_map, y_mu, y_std, py = model.posterior(all_x, loader)

        torch.save(y_map, './saved_posteriors/' + name_exp + '_map.pt')
        torch.save(py, './saved_posteriors/' + name_exp + '_py.pt')

        plot_var(all_x, py, loader, color='tab:orange', name=name + '_var', save=True)

    else:

        # Hard-coded settings pushed into the model for this stage.
        variables = {
            'alpha': 0.1,
            'T': 100,
            'n_steps': 10,
            'n_traj': 10,
            'use_brownian': True,
            'inner_lr': 0.001,

            'k': 32,
            'k2': 32,
            'batch_size': 1024,
            'epochs': 50000,
            'pos_lambda': 1.0,
            'neg_lambda': 0.1,
            'dec_lambda': 1.0,
        }

        model.set_global_variables(variables)

        # map_location keeps this load consistent with the checkpoint loads
        # elsewhere in the script and avoids failing when the tensor was
        # saved on a device that is not available on this machine.
        model.traj_theta = torch.load('./test_fiber_class.pt', map_location=device).to(device)

        model.train_metric(model.traj_theta)

        y_map, y_mu, y_std, py = model.posterior(all_x, loader)
        _, y_mu_naive, y_std_naive, py_naive = model.posterior_naive(all_x, loader)

        # Flatten all leading dimensions of traj_theta so each row is one
        # parameter vector, then evaluate the model once per row.
        all_theta = model.traj_theta.reshape([-1, model.traj_theta.shape[-1]])
        py_mcmc = torch.stack([model(all_x, theta_sample) for theta_sample in all_theta], 0)

        torch.save(y_map, './saved_posteriors/' + name_exp + '_map.pt')
        torch.save(py, './saved_posteriors/' + name_exp + '_py.pt')
        torch.save(py_naive, './saved_posteriors/' + name_exp + '_py_naive.pt')
        torch.save(py_mcmc, './saved_posteriors/' + name_exp + '_py_mcmc.pt')

        name = 'MCMC'
        name += '_reg' if args.experiment == 0 else '_class'
        plot_hypothesis(all_x, py_mcmc, loader, color='tab:orange', name=name + '_hyp_mc', save=True)
        plot_hypothesis(all_x, py, loader, color='tab:orange', name=name + '_hyp_sampled', save=True)
        plot_hypothesis(all_x, py_naive, loader, color='tab:orange', name=name + '_hyp_naive', save=True)
        plot_var(all_x, py_mcmc, loader, color='tab:orange', name=name + '_var_mc', save=True)
        plot_var(all_x, py, loader, color='tab:orange', name=name + '_var_sampled', save=True)
        plot_var(all_x, py_naive, loader, color='tab:orange', name=name + '_var_naive', save=True)



if args.quant_table == 1:
    # Quantitative stage: compute metrics once per configuration and pickle
    # them; the file name encodes the experiment and the mcmc_vars values.

    name_exp_2 = (
        name_exp
        + "_alpha" + str(mcmc_vars['alpha'])
        + "_T" + str(mcmc_vars['T'])
        + "_n_steps" + str(mcmc_vars['n_steps'])
        + "_n_traj" + str(mcmc_vars['n_traj'])
    )
    results_path = './results_metrics_mnist/' + name_exp_2 + '.p'

    # Results for this configuration already exist: stop the whole script.
    # ``raise SystemExit`` replaces the ``exit()`` helper, which is injected
    # by the ``site`` module for interactive use.
    if os.path.exists(results_path):
        raise SystemExit

    metrics = get_metrics(model, loader)

    os.makedirs('./results_metrics_mnist', exist_ok=True)
    # Context manager closes the file handle deterministically.
    with open(results_path, "wb") as results_file:
        pickle.dump(metrics, results_file)



if args.comp_complx == 1:
    # Computational-cost stage: for a series of increasingly deep networks,
    # time ``overall_posterior`` and compute a test NLL, then pickle both.

    def get_time_and_nll(model, model_file_name, loader, args, iters, device):
        """Time ``model.overall_posterior`` and compute a Gaussian test NLL.

        The model is trained first when ``args.do_train == 1`` or when no
        checkpoint ``./saved_models/<model_file_name>.pt`` exists; otherwise
        ``model.theta`` is loaded from that checkpoint.

        Args:
            model: Model exposing ``theta`` and ``overall_posterior``.
            model_file_name: Base name of the checkpoint file.
            loader: Data loader; its ``dataset`` must provide ``x_test`` and
                ``y_test``.
            args: Parsed CLI namespace (only ``do_train`` is read here).
            iters: Number of timed ``overall_posterior`` calls.
            device: Device passed to ``torch.load`` / ``.to``.

        Returns:
            Tuple ``(execution_time, nll_test)`` where ``execution_time`` is
            the total time in seconds for ``iters`` calls and ``nll_test``
            is a Python float.
        """
        if args.do_train == 1 or not os.path.exists("./saved_models/" + model_file_name + ".pt"):
            train(loader, model)
        else:
            model.theta = torch.load("./saved_models/" + model_file_name + ".pt", map_location=device).to(device)

        # Time a callable instead of a statement string that re-imports
        # names from ``__main__``: this measures the ``model`` and ``loader``
        # actually passed in and does not depend on how the file is run.
        # ``all_x`` is the module-level evaluation input.
        execution_time = timeit.timeit(
            lambda: model.overall_posterior(all_x, loader),
            number=iters,
        )

        x_test, y_test = loader.dataset.x_test, loader.dataset.y_test
        py = model.overall_posterior(x_test, loader)
        # Mean and variance over dim 0 of ``py``, first output column only.
        mu_py = torch.mean(py, 0)[:, 0]
        var_py = torch.std(py, 0)[:, 0] ** 2
        # Gaussian negative log-likelihood averaged over the test points.
        nll = 0.5 * torch.mean((torch.log(2 * torch.pi * var_py) + (y_test[:, 0] - mu_py) ** 2 / var_py))
        nll_test = nll.detach().cpu().item()

        return execution_time, nll_test


    iters = 2

    # Tanh networks with 1 to 5 weight layers of width 15; the output width
    # is 1 + args.prob.
    all_network_specs = [{'architecture': [[1, 15], [15, 1 + args.prob]], 'activation': nn.Tanh()},
                         {'architecture': [[1, 15], [15, 15], [15, 1 + args.prob]], 'activation': nn.Tanh()},
                         {'architecture': [[1, 15], [15, 15], [15, 15], [15, 1 + args.prob]], 'activation': nn.Tanh()},
                         {'architecture': [[1, 15], [15, 15], [15, 15], [15, 15], [15, 1 + args.prob]], 'activation': nn.Tanh()},
                         {'architecture': [[1, 15], [15, 15], [15, 15], [15, 15], [15, 15], [15, 1 + args.prob]], 'activation': nn.Tanh()}]


    # Only the *_riem lists are filled by the loop below; the others are
    # left empty in this script.
    execution_time_la = []
    nll_test_la = []
    execution_time_riem = []
    nll_test_riem = []
    execution_time_mcmc = []
    nll_test_mcmc = []

    os.makedirs('./results_computations', exist_ok=True)

    for i, network_specs in enumerate(all_network_specs):

        # The first two architectures are skipped.
        if i < 2:
            continue

        args.model_type = 1
        args.use_riemann = 1
        args.use_linear_network = 0
        setup_vars = setup(args, network_specs=network_specs)
        # Positions follow the order of setup()'s 10-tuple:
        # 0 = model, 1 = loader, 3 = model_file_name, 9 = device.
        model, model_file_name, loader, device = setup_vars[0], setup_vars[3], setup_vars[1], setup_vars[9]

        time_and_nll = get_time_and_nll(model, model_file_name, loader, args, iters, device)
        execution_time_riem.append(time_and_nll[0])
        nll_test_riem.append(time_and_nll[1])

        comp_results = {}
        # NOTE(review): the stored time is scaled by 5 while the in-memory
        # list above keeps the raw value; the reason for the factor is not
        # visible here - confirm.
        comp_results['time'] = time_and_nll[0] * 5
        comp_results['nll'] = time_and_nll[1]
        name = 'LA_riem=1' + '_layers=' + str(i)
        # Context manager closes the file handle deterministically.
        with open('./results_computations/' + name + '.p', "wb") as results_file:
            pickle.dump(comp_results, results_file)

    print()










