from glob import glob
import os
from itertools import islice
from omegaconf import OmegaConf
import torch
import numpy as np

from bayes_dip.dip.deep_image_prior import DeepImagePriorReconstructor
from bayes_dip.inference.sample_based_predictive_posterior import SampleBasedPredictivePosterior
from bayes_dip.marginal_likelihood_optim.preconditioner import get_preconditioner
from bayes_dip.marginal_likelihood_optim.sample_based_mll_optim import sample_based_marginal_likelihood_optim
from bayes_dip.marginal_likelihood_optim.utils import get_ordered_nn_params_vec
from bayes_dip.probabilistic_models.linearized_dip.default_unet_priors import get_default_unet_gprior_dicts
from bayes_dip.probabilistic_models.linearized_dip.image_cov import ImageCov
from bayes_dip.probabilistic_models.linearized_dip.neural_basis_expansion.getter import get_neural_basis_expansion
from bayes_dip.probabilistic_models.linearized_dip.parameter_cov import ParameterCov
from bayes_dip.probabilistic_models.observation_cov import ObservationCov
from bayes_dip.marginal_likelihood_optim.sample_based_mll_optim_utils import sample_then_optimise
from bayes_dip.data.trafo import BaseRayTrafo


def get_dip_data(ground_truth, ray_trafo, device):
    """
    Simulate the measurement and the filtered back-projection for an image.

    The ground truth gets two leading singleton dimensions, is moved to
    ``device`` and is passed through ``ray_trafo``; the resulting observation
    is then fed to ``ray_trafo.fbp``.

    Returns the tuple ``(observation, expanded_ground_truth, fbp)``, with all
    three tensors moved to ``device``.
    """
    # Prepend two singleton dimensions before applying the transform.
    image = ground_truth[None, None].to(device)
    sinogram = ray_trafo(image).to(device)
    backprojection = ray_trafo.fbp(sinogram).to(device)
    return sinogram, image, backprojection

def convert_to_dip_angle(our_angle):
    """
    Map a rotation angle (degrees) to the dip angle convention (degrees).

    The result is ``90 - our_angle`` wrapped into ``[0, 180)``.
    """
    mirrored = 90 - our_angle
    return mirrored % 180

def get_dip_mask(our_angles, ray_trafo):
    """
    Build a boolean mask over ``ray_trafo.angles`` selecting the given angles.

    ``our_angles`` is converted with ``convert_to_dip_angle`` and compared, in
    degrees, against the angles of ``ray_trafo`` (converted from radians).
    An entry of the mask is True when it lies within 1e-4 degrees of any
    requested angle. An ``AssertionError`` is raised unless the number of
    selected entries equals ``len(our_angles)``.
    """
    selected = np.zeros(len(ray_trafo.angles), dtype=bool)
    available_deg = np.rad2deg(ray_trafo.angles)
    requested_deg = convert_to_dip_angle(our_angles).cpu().numpy()
    # Angles are matched with an absolute tolerance instead of exact equality
    # (as e.g. np.isin would do).
    for target in requested_deg:
        is_close = np.abs(available_deg - target) < 1e-4
        selected = selected | is_close
    # Every requested angle must correspond to exactly one entry in total.
    assert len(our_angles) == selected.sum()
    return selected

class SubsetRayTrafo(BaseRayTrafo):
    """
    Ray transform restricted to a chosen subset of a full transform's angles.

    All operators are delegated to the wrapped ``full_ray_trafo`` and the
    result is indexed along dimension 2, so every call pays for the full
    transform. ``set_angular_mask`` must be called before the operators are
    used; until then ``self.mask`` is ``None``.
    """
    def __init__(self, full_ray_trafo):
        super().__init__(full_ray_trafo.im_shape, full_ray_trafo.obs_shape)
        self.full_ray_trafo = full_ray_trafo
        # No subset selected yet.
        self.mask = None

    def set_angular_mask(self, mask):
        """
        Select the angles to keep from a boolean ``mask``.

        Stores the indices of the True entries in ``self.mask``, the indices
        of the False entries in ``self.inverse_mask`` and updates the first
        entry of ``self.obs_shape`` to the number of kept indices.
        """
        kept = np.where(mask)[0]
        self.mask = kept
        self.inverse_mask = np.where(~mask)[0]
        second_dim = self.obs_shape[1]
        self.obs_shape = (len(kept), second_dim)

    @property
    def angles(self):
        """Angles of the full transform at the selected indices."""
        all_angles = self.full_ray_trafo.angles
        return all_angles[self.mask]

    def trafo(self, x):
        """Apply the full ``trafo`` and keep the selected indices of dim 2."""
        full_obs = self.full_ray_trafo.trafo(x)
        return full_obs[:, :, self.mask].contiguous()

    def forward(self, x):
        """Apply the full ``forward`` and keep the selected indices of dim 2."""
        full_obs = self.full_ray_trafo.forward(x)
        return full_obs[:, :, self.mask].contiguous()

    def fbp(self, observation):
        """Index ``observation`` along dim 2 with the mask and call the full ``fbp``."""
        # NOTE(review): ``self.mask`` holds indices into the *full* set of
        # angles, so this indexing presumably expects a full-size observation
        # -- confirm against callers that pass the output of ``trafo``/``forward``.
        picked = observation[:, :, self.mask]
        return self.full_ray_trafo.fbp(picked)

    def trafo_adjoint(self, observation):
        """
        Apply the full adjoint to ``observation`` scattered into a zero tensor.

        The zero tensor has the full transform's ``obs_shape[0]`` entries
        along dim 2; the selected indices are filled with ``observation``.
        """
        dims = observation.shape
        full_len = self.full_ray_trafo.obs_shape[0]
        padded = torch.zeros(
            (dims[0], dims[1], full_len, dims[3]),
            device=observation.device,
            dtype=observation.dtype,
        )
        padded[:, :, self.mask] = observation
        return self.full_ray_trafo.trafo_adjoint(padded)

    # Flat variants reuse the generic implementations of the base class.
    trafo_flat = BaseRayTrafo._trafo_flat_via_trafo
    trafo_adjoint_flat = BaseRayTrafo._trafo_adjoint_flat_via_trafo_adjoint

def get_unet_reconstructor(im_shape, cfg, device):
    """
    Instantiate the deep image prior reconstructor from the configuration.

    Network keyword arguments come from ``cfg.dip.net``, the seed from
    ``cfg.dip.torch_manual_seed`` and the parameter path from
    ``cfg.load_pretrained_dip_params``.
    """
    dip_cfg = cfg.dip
    # Convert the config node into plain Python containers.
    unet_kwargs = OmegaConf.to_object(dip_cfg.net)
    return DeepImagePriorReconstructor(
        im_shape,
        torch_manual_seed=dip_cfg.torch_manual_seed,
        device=device,
        net_kwargs=unet_kwargs,
        load_params_path=cfg.load_pretrained_dip_params,
    )
    

def train_unet(reconstructor : DeepImagePriorReconstructor, observation, cfg, unet_input, ray_trafo=None, lr=0.003, iterations=1000, ground_truth=None, log_path=None, verbose=False):
    """
    Run ``reconstructor.reconstruct`` on ``observation`` and return its result.

    ``lr`` and ``iterations`` are taken from the arguments, the loss function
    and ``gamma`` from ``cfg.dip.optim``. ``unet_input`` is passed as
    ``filtbackproj`` and ``verbose`` controls the progress bar.
    """
    dip_optim_cfg = cfg.dip.optim
    optim_kwargs = dict(
        lr=lr,
        iterations=iterations,
        loss_function=dip_optim_cfg.loss_function,
        gamma=dip_optim_cfg.gamma,
    )

    return reconstructor.reconstruct(
        observation,
        filtbackproj=unet_input,
        ground_truth=ground_truth,
        recon_from_randn=False,
        log_path=log_path,
        optim_kwargs=optim_kwargs,
        ray_trafo=ray_trafo,
        show_pbar=verbose,
    )


def get_posterior(reconstructor, ray_trafo, cfg, unet_input, device, verbose=False, sample_idx=0):
    """
    Create Laplace posterior

    Builds the g-prior parameter covariance, the neural basis expansion, the
    image and observation covariances for ``reconstructor.nn_model`` and wraps
    the observation covariance in a ``SampleBasedPredictivePosterior``.

    Args:
        reconstructor: object exposing the trained network as ``nn_model``.
        ray_trafo: forward operator handed to the basis expansion and the
            observation covariance.
        cfg: configuration; reads ``cfg.priors.gprior``, ``cfg.mll_optim`` and
            ``cfg.load_sample_based_precon_state_from_path``.
        unet_input: network input; its shape is also used as the network
            output shape.
        device: device for the covariance objects.
        verbose: forwarded to the scale computation and the preconditioner.
        sample_idx: index used in the file names ``preconditioner_{idx}.pt``
            and ``observation_cov_{idx}.pt`` when a preconditioner state is
            loaded from ``cfg.load_sample_based_precon_state_from_path``.
            Previously these names referenced an undefined variable, which
            raised ``NameError``; the default ``0`` is an assumption, adjust
            it to the file naming actually used.

    Returns:
        Tuple ``(predictive_posterior, optim_kwargs)`` where ``optim_kwargs``
        is the dict of settings assembled from ``cfg.mll_optim`` (including
        ``'sample_kwargs'`` and ``'cg_preconditioner'``).
    """
    prior_assignment_dict, hyperparams_init_dict = get_default_unet_gprior_dicts(
        nn_model=reconstructor.nn_model,
        gprior_hyperparams_init={'variance': cfg.priors.gprior.init_prior_variance_value}
    )

    parameter_cov = ParameterCov(
        reconstructor.nn_model,
        prior_assignment_dict,
        hyperparams_init_dict,
        device=device
    )

    # Loading a pre-computed g-prior scale vector from disk is not supported
    # here; the scale is always obtained via ``scale_kwargs``.
    load_scale_from_path = None
    scale_kwargs = OmegaConf.to_object(cfg.priors.gprior.scale)
    scale_kwargs['verbose'] = verbose
    neural_basis_expansion = get_neural_basis_expansion(
        nn_model=reconstructor.nn_model,
        nn_input=unet_input,
        ordered_nn_params=parameter_cov.ordered_nn_params,
        nn_out_shape=unet_input.shape,
        use_gprior=True, # requires the g-prior assumption
        trafo=ray_trafo,
        load_scale_from_path=load_scale_from_path,
        scale_kwargs=scale_kwargs,
    )

    image_cov = ImageCov(
        parameter_cov=parameter_cov,
        neural_basis_expansion=neural_basis_expansion
    )

    # sample-based MLL based methods do not optimise noise variance, i.e. fixed to 1.
    observation_cov = ObservationCov(
        trafo=ray_trafo,
        image_cov=image_cov,
        device=device
    )
    # Resuming from a previous EM step (loading an optimised observation
    # covariance or earlier linearised weights) is not supported here.

    optim_kwargs = {
        'iterations': cfg.mll_optim.iterations,
        'activate_debugging_mode': cfg.mll_optim.activate_debugging_mode,
        'num_samples': cfg.mll_optim.num_samples,
        'use_sample_then_optimise': cfg.mll_optim.use_sample_then_optimise
    }
    optim_kwargs['sample_kwargs'] = OmegaConf.to_object(cfg.mll_optim.sampling)
    precon_kwargs = OmegaConf.to_object(cfg.mll_optim.preconditioner)

    precon_state_dir = cfg.load_sample_based_precon_state_from_path
    if precon_state_dir is not None:
        # File names are indexed by ``sample_idx`` (formerly an undefined name).
        precon_kwargs['load_approx_basis'] = os.path.join(
            precon_state_dir, f'preconditioner_{sample_idx}.pt')
        precon_kwargs['load_state_dict'] = os.path.join(
            precon_state_dir, f'observation_cov_{sample_idx}.pt')

    cg_preconditioner = None
    if cfg.mll_optim.use_preconditioner:
        cg_preconditioner = get_preconditioner(
            observation_cov=observation_cov, kwargs=precon_kwargs, verbose=verbose)
        # The CG solver used while sampling receives the preconditioner closure.
        optim_kwargs['sample_kwargs']['cg_kwargs']['precon_closure'] = cg_preconditioner.get_closure()
    optim_kwargs['cg_preconditioner'] = cg_preconditioner

    if cfg.mll_optim.activate_debugging_mode:
        optim_kwargs['debugging_mode_kwargs'] = OmegaConf.to_object(
            cfg.mll_optim.debugging_mode_kwargs)

    predictive_posterior = SampleBasedPredictivePosterior(observation_cov)

    return predictive_posterior, optim_kwargs


def optimize_unet(recon, predictive_posterior, prev_linear_weights, observation,  optim_kwargs, cfg, ground_truth=None, iterations=None, verbose=False):
    """
    Run the sample-based marginal likelihood optimisation from EM step 0.

    The MAP weights are a clone of the ordered network parameter vector of
    the posterior's parameter covariance. When ``iterations`` is truthy it
    overwrites ``optim_kwargs['iterations']`` in place. ``cfg`` is accepted
    but not read. No log path is used and the squared-sum accumulator of the
    posterior observation samples starts out empty.

    Returns the tuple ``(linearized_weights, linearized_recon)`` produced by
    ``sample_based_marginal_likelihood_optim``.
    """
    inner_cov = predictive_posterior.observation_cov.image_cov.inner_cov

    if iterations:
        # Note: this modifies the dictionary owned by the caller.
        optim_kwargs['iterations'] = iterations

    map_weights = get_ordered_nn_params_vec(inner_cov).clone()
    result = sample_based_marginal_likelihood_optim(
        predictive_posterior=predictive_posterior,
        map_weights=map_weights,
        observation=observation,
        nn_recon=recon,
        ground_truth=ground_truth,
        optim_kwargs=optim_kwargs,
        log_path=None,
        posterior_obs_samples_sq_sum={},
        em_start_step=0,
        prev_linear_weights=prev_linear_weights,
        verbose=verbose,
    )
    linearized_weights, linearized_recon = result
    return linearized_weights, linearized_recon


def sample_unet(predictive_posterior, optim_kwargs, num_samples=10, sample_then_optimize=False, batch_size=None, verbose=False):
    """
    Draw zero-mean image samples from the predictive posterior.

    Args:
        predictive_posterior: posterior object; provides ``sample_zero_mean``
            and ``observation_cov``.
        optim_kwargs: settings dict; ``optim_kwargs['sample_kwargs']`` is read
            but never modified.
        num_samples: number of samples to draw.
        sample_then_optimize: if True, samples are obtained with
            ``sample_then_optimise`` and mapped to image space through the
            Jacobian-vector product of the neural basis expansion; otherwise
            ``predictive_posterior.sample_zero_mean`` is used.
        batch_size: batch size for ``sample_zero_mean``; falls back to
            ``num_samples`` when falsy. Ignored when ``sample_then_optimize``
            is True.
        verbose: forwarded to ``sample_zero_mean``; ignored when
            ``sample_then_optimize`` is True.

    Returns:
        The image samples.
    """
    sample_kwargs = optim_kwargs['sample_kwargs']

    if not sample_then_optimize:
        # ``batch_size`` is passed explicitly below, so a configured value is
        # filtered out on a copy; the caller's dictionary stays untouched
        # (it used to be popped in place).
        forwarded_kwargs = {
            key: value for key, value in sample_kwargs.items() if key != 'batch_size'
        }
        return predictive_posterior.sample_zero_mean(
            num_samples=num_samples,
            batch_size=batch_size if batch_size else num_samples,
            verbose=verbose,
            **forwarded_kwargs
        )

    observation_cov = predictive_posterior.observation_cov
    neural_basis_expansion = observation_cov.image_cov.neural_basis_expansion
    sto_optim_kwargs = sample_kwargs['hyperparams_update']['optim_kwargs']

    unscaled_weights_sample_from_prior = torch.randn(
        num_samples, neural_basis_expansion.num_params, device=observation_cov.device)
    unscaled_eps = torch.randn(
        num_samples, 1, *observation_cov.trafo.obs_shape, device=observation_cov.device)

    weight_sample = sample_then_optimise(
        observation_cov=observation_cov,
        neural_basis_expansion=neural_basis_expansion,
        noise_variance=observation_cov.log_noise_variance.exp().detach(),
        variance_coeff=observation_cov.image_cov.inner_cov.priors.gprior.log_variance.exp().detach(),
        num_samples=num_samples,
        optim_kwargs=sto_optim_kwargs,
        unscaled_weights_sample_from_prior=unscaled_weights_sample_from_prior,
        unscaled_eps=unscaled_eps,
        # A single call has no earlier samples to warm-start from. The former
        # code referenced the not-yet-assigned result here, which raised
        # UnboundLocalError whenever ``use_warm_start`` was enabled.
        init_at_previous_samples=None,
    )

    # Zero mean samples.
    return neural_basis_expansion.jvp(weight_sample).squeeze(dim=1)
