"""
Provides the kernelised sampling-based linearised NN inference routine for
optimising the g-prior variance hyperparameter, :func:`sample_based_marginal_likelihood_optim`.
"""
from typing import Dict, Optional
import os
import socket
import datetime
import torch
import numpy as np
import tensorboardX
from tqdm import tqdm
from torch import Tensor
from .sample_based_mll_optim_utils import (
        PCG_based_linear_map, sample_then_optim_linear_map,
        sample_then_optimise, estimate_effective_dimension, gprior_variance_mackay_update,
        debugging_loglikelihood_estimation, debugging_histogram_tensorboard,
        debugging_uqviz_tensorboard
    )
from bayes_dip.utils import get_mid_slice_if_3d
from bayes_dip.utils import PSNR, SSIM, normalize
from bayes_dip.inference import SampleBasedPredictivePosterior

def sample_based_marginal_likelihood_optim(
    predictive_posterior: SampleBasedPredictivePosterior,
    map_weights: Tensor,
    observation: Tensor,
    nn_recon: Tensor,
    optim_kwargs: Dict,
    log_path: str = './',
    em_start_step: int = 0,
    ground_truth: Optional[Tensor] = None,
    posterior_obs_samples_sq_sum: Optional[Dict] = None,
    prev_linear_weights: Optional[Tensor] = None,
    verbose: bool = True
    ):
    '''
    Kernelised sampling-based linearised NN inference.

    ``sample_based_marginal_likelihood_optim`` implements Algo. 3
    in https://arxiv.org/abs/2210.04994. Each iteration:

    1. computes the linearised weights and reconstruction, either with PCG
       (``PCG_based_linear_map``) or with sample-then-optimise
       (``sample_then_optim_linear_map``), selected by
       ``optim_kwargs['use_sample_then_optimise']``;
    2. draws zero-mean posterior samples, maps them to observation space and
       estimates the effective dimension (or reuses
       ``posterior_obs_samples_sq_sum`` when resuming);
    3. updates the g-prior variance in place on
       ``observation_cov.image_cov.inner_cov.priors.gprior.log_variance`` via
       ``gprior_variance_mackay_update``.

    Parameters
    ----------
    predictive_posterior : SampleBasedPredictivePosterior
        Provides ``observation_cov`` (read and updated in place) and, in PCG
        mode, the zero-mean image samples.
    map_weights : Tensor
        Network weights at the MAP estimate; divided by
        ``neural_basis_expansion.scale`` before use.
    observation : Tensor
        Measured data.
    nn_recon : Tensor
        Network reconstruction; used to build the affine offset of the
        linearised model.
    optim_kwargs : dict
        Configuration. Keys read: ``'iterations'``,
        ``'use_sample_then_optimise'``, ``'num_samples'``, ``'sample_kwargs'``,
        ``'cg_preconditioner'`` (PCG mode only), ``'activate_debugging_mode'``
        and ``'debugging_mode_kwargs'``. The dict is not modified.
    log_path : str, optional
        If truthy, a tensorboardX writer is created in a timestamped
        sub-directory of this path; it is closed before returning.
    em_start_step : int, optional
        Offset added to the iteration index used in log tags and file names.
    ground_truth : Tensor, optional
        Only used for logging and debugging diagnostics (debugging mode
        requires it).
    posterior_obs_samples_sq_sum : dict, optional
        ``{'value': ..., 'num_samples': ...}``; when given, sampling is
        skipped, ``prev_linear_weights`` is used as the initial linearised
        weights and exactly one iteration is allowed.
    prev_linear_weights : Tensor, optional
        Linearised weights from a previous run; used when resuming and as the
        warm start of the first sample-then-optimise linearisation.
    verbose : bool, optional
        Toggles progress bars and verbose sampling / preconditioner updates.

    Returns
    -------
    (linearized_weights, linearized_recon)
        Values from the last iteration (``linearized_recon`` is ``None`` if no
        iteration was run).

    Notes
    -----
    Every iteration writes ``observation_cov_iter_{i}.pt`` and
    ``gprior_variance_iter_{i}.pt`` (plus ``weight_sample_iter_{i}.pt`` in
    sample-then-optimise mode) to the current working directory.
    '''

    writer = None
    if log_path:
        writer = tensorboardX.SummaryWriter(
            logdir=os.path.join(log_path, '_'.join((
                datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
                socket.gethostname(),
                'marginal_likelihood_sample_based_hyperparams_optim')))
            )

    linearized_weights = None
    linearized_recon = None

    try:
        if writer is not None:
            writer.add_image('nn_recon.', normalize(get_mid_slice_if_3d(nn_recon)[0]), 0)
            # ground_truth is optional; only log it when it was provided.
            if ground_truth is not None:
                writer.add_image(
                    'ground_truth', normalize(get_mid_slice_if_3d(ground_truth)[0]), 0)

        observation_cov = predictive_posterior.observation_cov
        neural_basis_expansion = observation_cov.image_cov.neural_basis_expansion
        gprior = observation_cov.image_cov.inner_cov.priors.gprior

        with torch.no_grad():

            scale = neural_basis_expansion.scale.pow(-1)
            scale_corrected_map_weights = scale * map_weights
            # recon_offset = jvp(w_map) - nn_recon, so that
            # jvp(w) - recon_offset reproduces nn_recon at w = w_map.
            recon_offset = - nn_recon + neural_basis_expansion.jvp(
                scale_corrected_map_weights[None, :])
            assert recon_offset.shape[2] == 1
            observation_offset = observation_cov.trafo(recon_offset.squeeze(2)).unsqueeze(2)
            observation_for_lin_optim = observation + observation_offset

            weight_sample = None

            if posterior_obs_samples_sq_sum:
                assert optim_kwargs['iterations'] == 1, 'Only one iteration is allowed when resuming from a checkpoint.'
                linearized_weights = prev_linear_weights

            if optim_kwargs['use_sample_then_optimise']:
                # Standard-normal draws shared by all EM iterations.
                unscaled_weights_sample_from_prior = torch.randn(
                    optim_kwargs['num_samples'], neural_basis_expansion.num_params,
                    device=observation_cov.device)
                unscaled_eps = torch.randn(
                    optim_kwargs['num_samples'], 1, *observation_cov.trafo.obs_shape,
                    device=observation_cov.device)

            em_steps = range(em_start_step, em_start_step + optim_kwargs['iterations'])
            with tqdm(em_steps, desc='sample_based_marginal_likelihood_optim',
                      disable=not verbose) as pbar:
                for i in pbar:
                    if not optim_kwargs['use_sample_then_optimise']:
                        # The returned observation is discarded; it is
                        # recomputed from the reconstruction below.
                        linearized_weights, _, linearized_recon = PCG_based_linear_map(
                            observation_cov=observation_cov,
                            observation=observation_for_lin_optim,
                            cg_kwargs=optim_kwargs['sample_kwargs']['cg_kwargs'],
                        )
                    else:
                        # wd = A = variance_coeff^{-1}
                        wd = gprior.log_variance.exp().pow(-1)
                        # Shallow copy so the caller's optim_kwargs is not mutated.
                        lin_optim_kwargs = {
                            **optim_kwargs['sample_kwargs']['weights_linearisation']['optim_kwargs'],
                            'wd': wd,
                        }
                        # Warm start from the most recent estimate (falling back to
                        # prev_linear_weights on the first iteration), mirroring the
                        # warm start of the weight samples below.
                        warm_start_weights = (
                            linearized_weights if linearized_weights is not None
                            else prev_linear_weights)
                        with torch.enable_grad():
                            linearized_weights, linearized_recon = sample_then_optim_linear_map(
                                trafo=observation_cov.trafo,
                                neural_basis_expansion=neural_basis_expansion,
                                map_weights=scale_corrected_map_weights,
                                observation=observation_for_lin_optim,
                                optim_kwargs=lin_optim_kwargs,
                                aux={'ground_truth': ground_truth, 'recon_offset': recon_offset},
                                init_at_previous_weights=(
                                    warm_start_weights if lin_optim_kwargs['use_warm_start'] else None),
                                name_prefix=f'weight_linearisation_em={i}'
                                )

                    linearized_observation = observation_cov.trafo.trafo(linearized_recon)
                    linearized_recon = linearized_recon - recon_offset.squeeze(dim=0)
                    linearized_observation = linearized_observation - observation_offset

                    if not posterior_obs_samples_sq_sum:
                        if not optim_kwargs['use_sample_then_optimise']:
                            image_samples = predictive_posterior.sample_zero_mean(
                                num_samples=optim_kwargs['num_samples'],
                                verbose=verbose,
                                **optim_kwargs['sample_kwargs']
                                )
                        else:
                            hyperparams_optim_kwargs = optim_kwargs['sample_kwargs']['hyperparams_update']['optim_kwargs']
                            weight_sample = sample_then_optimise(
                                observation_cov=observation_cov,
                                neural_basis_expansion=neural_basis_expansion,
                                noise_variance=observation_cov.log_noise_variance.exp().detach(),
                                variance_coeff=gprior.log_variance.exp().detach(),
                                num_samples=optim_kwargs['num_samples'],
                                optim_kwargs=hyperparams_optim_kwargs,
                                unscaled_weights_sample_from_prior=unscaled_weights_sample_from_prior,
                                unscaled_eps=unscaled_eps,
                                init_at_previous_samples=(
                                    weight_sample if hyperparams_optim_kwargs['use_warm_start'] else None),
                                )
                            # NOTE: written to the current working directory.
                            torch.save(weight_sample, f'weight_sample_iter_{i}.pt')
                            # Zero mean samples.
                            image_samples = neural_basis_expansion.jvp(weight_sample).squeeze(dim=1)

                        obs_samples = observation_cov.trafo(image_samples)
                        posterior_obs_samples_sq_mean = obs_samples.pow(2).sum(dim=0) / obs_samples.shape[0]
                    else:
                        posterior_obs_samples_sq_mean = (
                            posterior_obs_samples_sq_sum['value']
                            / posterior_obs_samples_sq_sum['num_samples'])

                    eff_dim = estimate_effective_dimension(
                        posterior_obs_samples_sq_mean=posterior_obs_samples_sq_mean,
                        noise_variance=observation_cov.log_noise_variance.exp().detach()
                        ).clamp(min=1, max=np.prod(observation_cov.trafo.obs_shape)-1)

                    variance_coeff = gprior_variance_mackay_update(
                        eff_dim=eff_dim, map_linearized_weights=linearized_weights
                        )
                    # In-place hyperparameter update on the shared observation_cov.
                    gprior.log_variance = variance_coeff.log()
                    se_loss = (linearized_observation - observation).pow(2).sum()

                    if not optim_kwargs['use_sample_then_optimise'] and optim_kwargs['cg_preconditioner'] is not None:
                        optim_kwargs['cg_preconditioner'].update(verbose=verbose)

                    # NOTE: per-iteration checkpoints go to the current working directory.
                    torch.save(observation_cov.state_dict(), f'observation_cov_iter_{i}.pt')
                    torch.save(variance_coeff, f'gprior_variance_iter_{i}.pt')

                    if writer is not None:
                        writer.add_scalar('variance_coeff', variance_coeff.item(), i)
                        writer.add_scalar('noise_variance', observation_cov.log_noise_variance.data.exp().item(), i)
                        writer.add_image('linearized_model_recon', normalize(get_mid_slice_if_3d(linearized_recon)[0]), i)
                        writer.add_scalar('effective_dimension', eff_dim.item(), i)
                        writer.add_scalar('se_loss', se_loss.item(), i)

                    if optim_kwargs['activate_debugging_mode'] and not posterior_obs_samples_sq_sum:
                        if optim_kwargs['use_sample_then_optimise']:
                            print('Log-likelihood is calculated using previous samples, and only mll_optim.num_samples are used.')
                        loglik_nn_model, image_samples_diagnostic = debugging_loglikelihood_estimation(
                            predictive_posterior=predictive_posterior,
                            mean=get_mid_slice_if_3d(nn_recon),
                            ground_truth=get_mid_slice_if_3d(ground_truth),
                            image_samples=None if not optim_kwargs['use_sample_then_optimise'] else image_samples,
                            sample_kwargs=optim_kwargs['sample_kwargs'],
                            loglikelihood_kwargs=optim_kwargs['debugging_mode_kwargs']['loglikelihood_kwargs']
                        )
                        loglik_lin_model, _ = debugging_loglikelihood_estimation(
                            predictive_posterior=predictive_posterior,
                            mean=get_mid_slice_if_3d(linearized_recon),
                            ground_truth=get_mid_slice_if_3d(ground_truth),
                            image_samples=get_mid_slice_if_3d(image_samples_diagnostic),
                            loglikelihood_kwargs=optim_kwargs['debugging_mode_kwargs']['loglikelihood_kwargs']
                        )
                        if writer is not None:
                            writer.add_image('debugging_histogram_nn_model', debugging_histogram_tensorboard(
                                get_mid_slice_if_3d(ground_truth), get_mid_slice_if_3d(nn_recon),
                                get_mid_slice_if_3d(image_samples_diagnostic))[0], i)
                            writer.add_image('debugging_histogram_lin_model', debugging_histogram_tensorboard(
                                get_mid_slice_if_3d(ground_truth), get_mid_slice_if_3d(linearized_recon),
                                get_mid_slice_if_3d(image_samples_diagnostic))[0], i)
                            writer.add_image('debugging_histogram_uqviz_nn_model', debugging_uqviz_tensorboard(
                                get_mid_slice_if_3d(ground_truth), get_mid_slice_if_3d(nn_recon),
                                get_mid_slice_if_3d(image_samples_diagnostic))[0], i)
                            writer.add_scalar('loglik_nn_model',  loglik_nn_model.item(), i)
                            writer.add_scalar('loglik_lin_model', loglik_lin_model.item(), i)

                        if optim_kwargs['debugging_mode_kwargs']['verbose']:
                            print('\n\033[1m' + f'iter: {i}, variance_coeff: {variance_coeff.item():.2E}, ',
                                f'noise_variance: {observation_cov.log_noise_variance.data.exp().item():.2E}, ',
                                f'eff_dim: {eff_dim.item():.2E}, se_loss: {se_loss.item():.2E} ',
                                f'l2: {linearized_weights.pow(2).sum().item():.2E}' + '\033[0m')
                            print('\033[1m' + f'iter: {i}, linearized_recon PSNR: {PSNR(linearized_recon.cpu().numpy(), ground_truth.cpu().numpy()):.2E}, '
                                f'SSIM: {SSIM(linearized_recon.cpu().numpy()[0, 0], ground_truth.cpu().numpy()[0, 0]):.2E}' + '\033[0m')
                            print('\033[1m' + f'iter: {i}, loglik_nn_model: {loglik_nn_model:.2E}, loglik_lin_model: {loglik_lin_model:.2E}\n' + '\033[0m')
    finally:
        # Flush and release the event file even if an iteration raises.
        if writer is not None:
            writer.close()

    return linearized_weights, linearized_recon