import numpy as np
import torch
from scipy.stats import norm
from tqdm import tqdm

from subspace_inference.posteriors.importance_sampler import ImportanceSampler
from subspace_inference import models, losses, utils


# Task-type codes that select the UCI regression evaluation settings.
_UCI_TASK_TYPES = (20, 21, 22, 23, 24, 30, 31, 32, 33, 34, 35)


def _predict_with_samples(sampler, test_loader, samples, num_samples, output_dim):
    """Run ``sampler.gen_data_with_loader`` on the test set once per sample.

    Args:
        sampler: object exposing ``gen_data_with_loader(loader, sample)``.
        test_loader: loader handed through to ``gen_data_with_loader``.
        samples: indexable collection; ``samples[i]`` is the i-th sample.
        num_samples: number of leading samples to evaluate.
        output_dim: number of test points (second dimension of the result).

    Returns:
        A CPU float tensor of shape ``(num_samples, output_dim, 2)``.
    """
    outputs = torch.zeros((num_samples, output_dim, 2))
    for i in tqdm(range(num_samples)):
        outputs[i, :] = sampler.gen_data_with_loader(test_loader, samples[i])
    return outputs


def _regression_metrics(outputs, y_test, y_std, weights=None):
    """Summarise per-sample predictions into regression test metrics.

    ``outputs[:, :, 0]`` is treated as the per-sample predictive mean and
    ``outputs[:, :, 1]`` as the per-sample predictive variance; the mixture
    variance is obtained from the averaged second-order moments.

    Args:
        outputs: CPU tensor of shape ``(num_samples, output_dim, 2)``.
        y_test: test targets in normalised units.
        y_std: scalar used to map normalised quantities back to raw units.
        weights: optional CPU tensor of shape ``(num_samples,)``. When given,
            samples are combined with a weighted sum instead of a plain mean.
            NOTE(review): the weights are used as-is and are presumably
            already normalised to sum to one -- confirm against the sampler.

    Returns:
        dict with the keys ``test_loglik``, ``test_loglik_unnormalized``,
        ``test_mae``, ``test_mae_unnormalized``, ``test_rmse``,
        ``test_rmse_unnormalized`` and ``test_calibration``.
    """
    sample_means = outputs[:, :, 0]
    second_moments = sample_means ** 2 + outputs[:, :, 1]
    if weights is None:
        m = torch.mean(sample_means, dim=0).numpy()
        second = torch.mean(second_moments, dim=0).numpy()
    else:
        # weights: (num_samples,) -> broadcast over the output dimension
        w = weights[:, None]
        m = torch.sum(sample_means * w, dim=0).numpy()
        second = torch.sum(second_moments * w, dim=0).numpy()
    # Var[y] = E[y^2] - E[y]^2
    v = second - m ** 2
    std = v ** 0.5

    loglik = norm.logpdf(y_test, loc=m, scale=std)
    loglik_raw = norm.logpdf(y_test * y_std, loc=m * y_std, scale=std * y_std)

    d = y_test - m
    du = d * y_std
    # fraction of targets strictly inside the central +/-1.96 std interval
    cal = (d < 1.96 * std) * (d > -1.96 * std)

    return {
        'test_loglik': np.average(loglik),
        'test_loglik_unnormalized': np.average(loglik_raw),
        'test_mae': np.average(np.abs(d)),
        'test_mae_unnormalized': np.average(np.abs(du)),
        'test_rmse': np.average(d ** 2) ** 0.5,
        'test_rmse_unnormalized': np.average(du ** 2) ** 0.5,
        'test_calibration': np.average(cal),
    }


def eval_predict(space_model, train_loader, test_loader, device, args):
    """Draw subspace posterior samples with several methods and evaluate them.

    Depending on the ``args.calc_ess`` / ``calc_qmc`` / ``calc_nuts`` /
    ``calc_vi`` flags, samples are drawn with elliptical slice sampling,
    QMC importance sampling, NUTS and variational inference. For the UCI
    task types, regression test metrics are computed as well.

    Args:
        space_model: subspace model; called on a subspace point to obtain a
            flat parameter vector and passed to every sampler as ``subspace``.
        train_loader: training loader, or ``None``. Required when
            ``args.calc_vi`` is set.
        test_loader: test loader; ``len(test_loader.dataset)`` gives the
            number of test points.
        device: torch device used for samplers and tensors.
        args: configuration namespace. Reads ``task_type``, ``model_cfg``
            (``base``, ``args``, ``kwargs``), ``inference_criterion``,
            ``proposal_var``, ``temperature``, ``prior_scale``, the four
            ``calc_*`` flags and, for UCI task types,
            ``full_dataset.Y_test`` / ``full_dataset.Y_std``.

    Returns:
        Tuple ``(result_ess, result_qmc, result_nuts, result_vi)`` of dicts.
        A dict is empty when its method is disabled; otherwise it holds
        ``'cost'`` and ``'samples'`` (plus ``'weights'`` for QMC) and, for
        UCI task types, the ``test_*`` metrics.

    Raises:
        ValueError: if ``args.calc_vi`` is set but ``train_loader`` is None.
    """
    output_dim = len(test_loader.dataset)
    result_ess = {}
    result_qmc = {}
    result_nuts = {}
    result_vi = {}

    if args.task_type in _UCI_TASK_TYPES:
        # for UCI datasets
        uci_eval = True
        ess_sample_size = 30
        ess_burn_in = 270
        qmc_sample_size = 1024
        nuts_sample_size = 30
        nuts_thinning = 5
        nuts_burn_in = 300
        vi_num_samples = 2000
        vi_epochs = 2000
        regression = True
        y_test = args.full_dataset.Y_test.squeeze()
        y_std = args.full_dataset.Y_std.item()
    else:
        # for CIFAR datasets
        uci_eval = False
        ess_sample_size = 30
        ess_burn_in = 30
        qmc_sample_size = 128
        nuts_sample_size = 30
        nuts_thinning = 3
        nuts_burn_in = 30
        vi_num_samples = 30
        vi_epochs = 100
        regression = False
        # only consumed on the UCI path
        y_test = None
        y_std = None

    sampler = ImportanceSampler(
        *args.model_cfg.args,
        base=args.model_cfg.base,
        criterion=args.inference_criterion,
        proposal_var=args.proposal_var,
        temperature=args.temperature,
        loader=train_loader,
        subspace=space_model,
        data=None,
        proposal_type="gaussian",
        deg_f=None,
        device=device,
        prior_scale=args.prior_scale,
        **args.model_cfg.kwargs
    )

    if train_loader is not None:
        # Load the parameters at the subspace origin into the base model,
        # then call utils.bn_update on the full training set.
        ipt = torch.tensor([0., 0.], device=device)
        w = space_model(ipt)
        offset = 0
        for param in sampler.base_model.parameters():
            numel = param.numel()
            param.data.copy_(w[offset:offset + numel].view(param.size()).to(device))
            offset += numel
        utils.bn_update(train_loader, sampler.base_model, subset=1.0, device=device)

    # generate ESS samples
    if args.calc_ess:
        from subspace_inference.posteriors.ess import EllipticalSliceSampling
        print("draw posterior predictive with ESS.")
        ess_model = EllipticalSliceSampling(
            *args.model_cfg.args,
            base=args.model_cfg.base,
            subspace=space_model,
            var=None,
            loader=train_loader,
            criterion=args.inference_criterion,
            num_samples=ess_burn_in + ess_sample_size,
            use_cuda=True,
            device=device,
            **args.model_cfg.kwargs
        )
        ess_model.fit(temperature=args.temperature, scale=args.proposal_var)
        # keep only the last `ess_sample_size` draws (everything before is burn-in)
        ess_samples = (
            torch.from_numpy(ess_model.all_samples)
            .t()[-ess_sample_size:]
            .type(torch.FloatTensor)
            .to(device)
        )
        result_ess['cost'] = ess_model.density_calc_times
        result_ess['samples'] = ess_samples.cpu().numpy()
        if uci_eval:
            outputs_ess = _predict_with_samples(
                sampler, test_loader, ess_samples, ess_sample_size, output_dim)
            result_ess.update(_regression_metrics(outputs_ess, y_test, y_std))

    # QMC-IS
    if args.calc_qmc:
        print("draw posterior predictive with QMC-IS.")
        proposal_sample, weights = sampler.sampling_with_weights(
            qmc_sample_size, enable_qmc=True, enable_tqdm=True)
        weights = weights.cpu()
        result_qmc['cost'] = qmc_sample_size
        result_qmc['samples'] = proposal_sample.cpu().numpy()
        result_qmc['weights'] = weights.numpy()
        if uci_eval:
            outputs_qmc = _predict_with_samples(
                sampler, test_loader, proposal_sample, qmc_sample_size, output_dim)
            result_qmc.update(
                _regression_metrics(outputs_qmc, y_test, y_std, weights=weights))

    # NUTS
    if args.calc_nuts:
        print("draw posterior predictive with NUTS.")
        nuts_samples = sampler.sampling_with_nuts(
            nuts_sample_size, test_loader, nuts_thinning, nuts_burn_in)
        result_nuts['cost'] = sampler.post_eval_times
        result_nuts['samples'] = nuts_samples.cpu().numpy()
        if uci_eval:
            outputs_nuts = _predict_with_samples(
                sampler, test_loader, nuts_samples, nuts_sample_size, output_dim)
            result_nuts.update(_regression_metrics(outputs_nuts, y_test, y_std))

    # VI
    if args.calc_vi:
        if train_loader is None:
            raise ValueError("calc_vi requires a train_loader, got None")
        print("draw posterior predictive with VI.")
        from subspace_inference.posteriors.vi_model import VIModel, ELBO
        import math
        init_sigma = 1.
        prior_sigma = 5.
        vi_model = VIModel(
            *args.model_cfg.args,
            subspace=space_model,
            # inverse softplus, so that softplus(param) == init_sigma
            init_inv_softplus_sigma=math.log(math.exp(init_sigma) - 1.0),
            prior_log_sigma=math.log(prior_sigma),
            base=args.model_cfg.base,
            device=device,
            **args.model_cfg.kwargs
        )
        vi_model = vi_model.to(device)
        elbo = ELBO(args.inference_criterion, len(train_loader.dataset),
                    temperature=args.temperature)
        optimizer = torch.optim.Adam(list(vi_model.parameters()), lr=.1)
        # the learning rate is adjusted once, right after the midpoint epoch
        lr_switch_epoch = vi_epochs // 2
        for epoch in tqdm(range(vi_epochs)):
            utils.train_epoch(train_loader, vi_model, elbo, optimizer,
                              device=device, regression=regression)
            if epoch == lr_switch_epoch:
                utils.adjust_learning_rate(optimizer, 0.01)
        with torch.no_grad():
            vi_samples = vi_model.sample_from_post(vi_num_samples)
        result_vi['cost'] = vi_num_samples
        result_vi['samples'] = vi_samples.cpu().numpy()
        if uci_eval:
            outputs_vi = _predict_with_samples(
                sampler, test_loader, vi_samples, vi_num_samples, output_dim)
            result_vi.update(_regression_metrics(outputs_vi, y_test, y_std))

    return result_ess, result_qmc, result_nuts, result_vi
