import torch
from tqdm import tqdm
from zmq import log

from ..targets.target_distribution import TargetDistribution

from ..kernels import Kernel
from ..utils.kernel_state import KernelState

from .base import Sampler

from sampling.targets import TargetDistribution

class DenoisingTarget(TargetDistribution):
    r"""
    Denoising posterior p(x | \tilde{x}) used in the DiGS denoising step.

    Up to an additive constant, its log density is

        log p(x | \tilde{x}) = log p(x) - ||\alpha x - \tilde{x}||^2 / (2 \sigma^2) + const,

    i.e. \log p(x|\tilde{x}) \propto \log p(x) + \log N(\tilde{x} | \alpha x, \sigma^2 I).

    Arguments:
        target: The original target distribution p(x).
        x_tilde: The noisy sample \tilde{x}; must broadcast against the
            samples passed to ``log_prob``.
        alpha: The contraction factor for the current noise level.
        sigma: The standard deviation of the Gaussian noise.
    """
    def __init__(self, target: TargetDistribution, x_tilde: torch.Tensor, alpha: float, sigma: float):
        super().__init__()
        self.target = target
        self.x_tilde = x_tilde
        self.alpha = alpha
        self.sigma = sigma

    @property
    def dim(self) -> int:
        """Dimensionality of the sample space, delegated to the wrapped target."""
        return self.target.dim

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the (unnormalised) log probability of the denoising posterior.

        Arguments:
            x: (batch_size, dim) tensor of clean samples.

        Returns:
            log_prob: (batch_size, 1) tensor of log probabilities.
        """
        log_p_x = self.target.log_prob(x)
        # A target returning one value per sample without a trailing singleton
        # axis, i.e. (batch_size,) instead of (batch_size, 1), would otherwise
        # broadcast against the (batch_size, 1) quadratic term below into a
        # (batch_size, batch_size) matrix.
        if log_p_x.dim() == x.dim() - 1:
            log_p_x = log_p_x.unsqueeze(-1)

        # Quadratic denoising term from Eq. 11: -||alpha * x - x_tilde||^2 / (2 sigma^2)
        norm_sq = torch.sum((self.alpha * x - self.x_tilde) ** 2, dim=-1, keepdim=True)
        denoising_term = -0.5 * norm_sq / (self.sigma ** 2)

        return log_p_x + denoising_term

class DiGS(Sampler):
    """
    Diffusive Gibbs Sampling (DiGS) algorithm.

    Implements the DiGS sampler as described in arXiv:2402.03008. This method is
    designed for effective sampling from multi-modal distributions by integrating
    diffusion models with Gibbs sampling. Each Gibbs sweep consists of a noising
    step (x -> x_tilde) followed by a denoising step (x_tilde -> x'). The
    denoising step is a Metropolis-within-Gibbs initialisation plus a short
    inner MCMC run on the denoising posterior.

    Arguments:
        target: The target distribution to sample from.
        denoising_kernel: MCMC kernel for the inner denoising run; its
            ``step(target, state, step_size=...)`` is called once per inner step.
        alpha_schedule: Per-noise-level contraction factors. Any indexable
            sequence of scalars (e.g. a 1-D tensor or a list of floats).
        sigma_schedule: Per-noise-level noise standard deviations; must have
            the same length as ``alpha_schedule``.
        n_gibbs_sweeps: Number of Gibbs sweeps to perform at each noise level.
        n_denoising_steps: Number of inner MCMC steps per denoising step.
        denoising_step_size: Step size passed to the denoising kernel.
        verbose: If True, display a progress bar during sampling.
        compile: If True, wrap the inner denoising step with ``torch.compile``.

    Raises:
        ValueError: If a schedule is missing or the two schedules differ in length.
    """
    def __init__(
        self,
        target: TargetDistribution,
        denoising_kernel: Kernel,
        alpha_schedule: torch.Tensor,
        sigma_schedule: torch.Tensor,
        n_gibbs_sweeps: int = 100,
        n_denoising_steps: int = 10,
        denoising_step_size: float = 0.1,
        verbose: bool = False,
        compile: bool = False
    ):
        super().__init__(target=target, verbose=verbose)

        if alpha_schedule is None or sigma_schedule is None:
            raise ValueError("Both alpha_schedule and sigma_schedule must be provided.")

        if len(alpha_schedule) != len(sigma_schedule):
            raise ValueError("alpha_schedule and sigma_schedule must have the same length.")

        self.denoising_kernel = denoising_kernel
        self.n_gibbs_sweeps = n_gibbs_sweeps
        self.n_denoising_steps = n_denoising_steps
        self.register_buffer(
            "denoising_step_size",
            torch.tensor(denoising_step_size, dtype=torch.float32)
        )

        self._compile = compile

        self.alpha_schedule = alpha_schedule
        self.sigma_schedule = sigma_schedule
        self.n_noise_levels = len(alpha_schedule)

        # Indirection so the inner step can optionally be replaced by its
        # torch.compile'd version without changing the call sites.
        self._denoising_step_fn = self._denoising_step
        if self._compile:
            self._denoising_step_fn = torch.compile(self._denoising_step_fn)

    def _denoising_step(self, state, x_tilde, alpha, sigma, step_size):
        """Run one ``denoising_kernel`` step on the denoising posterior p(x | x_tilde)."""
        return self.denoising_kernel.step(
            DenoisingTarget(self.target, x_tilde, alpha, sigma),
            state,
            step_size=step_size
        )

    def build_initial_point(
        self,
        n_samples: int = 1,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """
        Build an initial point for the kernel. For DiGS, this is a tensor
        of shape (n_samples, dim) drawn from a standard normal.
        """
        return torch.randn(n_samples, self.dim, device=device, dtype=dtype)

    @staticmethod
    def _schedule_value(schedule, idx: int, like: torch.Tensor) -> torch.Tensor:
        """Return ``schedule[idx]`` as a 0-dim tensor on the device of ``like``."""
        # as_tensor accepts tensor entries (equivalent to .to(device)) as well
        # as plain Python/NumPy scalars, so list schedules also work.
        return torch.as_tensor(schedule[idx], device=like.device)

    def _metropolis_within_gibbs_init(self, x_prev, x_tilde, alpha, sigma):
        """
        Metropolis-within-Gibbs initialisation of the denoising sampler.

        Draws an independence proposal x' ~ q(x | x_tilde) =
        N(x_tilde / alpha, (sigma / alpha)^2 I) (Eq. 14) and accepts it
        against the denoising posterior with the MH ratio of Eq. 15. Rows
        whose proposal is rejected keep ``x_prev``.

        Returns:
            x_init: Per-row mix of accepted proposals and ``x_prev``.
            acc_rate: 0-dim tensor, fraction of rows whose proposal was accepted.
        """
        # Proposal: q(x|~x) = N(x | ~x/α, (σ/α)²I) from Eq. 14
        proposal_mean = x_tilde / alpha
        proposal_std = sigma / alpha
        x_init_prime = proposal_mean + proposal_std * torch.randn_like(x_prev)

        denoising_target = DenoisingTarget(self.target, x_tilde, alpha, sigma)

        # Log-posteriors for MH acceptance probability (Eq. 15)
        log_p_prime = denoising_target.log_prob(x_init_prime)
        log_p_prev = denoising_target.log_prob(x_prev)

        # Log-proposal ratio: log(q(x_prev|~x)/q(x_init_prime|~x));
        # Gaussian normalising constants cancel, leaving only the quadratic forms.
        log_q_ratio = -0.5 * (
            torch.sum(((x_prev - proposal_mean) / proposal_std)**2, dim=-1, keepdim=True) -
            torch.sum(((x_init_prime - proposal_mean) / proposal_std)**2, dim=-1, keepdim=True)
        )

        log_acc_ratio = (log_p_prime - log_p_prev) + log_q_ratio

        # Accept iff log(U) < log acceptance ratio, U ~ Uniform(0, 1).
        accepted = (torch.rand_like(log_acc_ratio).log() < log_acc_ratio)

        # The keepdim=True mask broadcasts across the last (feature) axis.
        x_init = torch.where(accepted, x_init_prime, x_prev)

        return x_init, accepted.float().mean()

    def _denoising_mcmc_run(self, x_init, x_tilde, alpha, sigma):
        """
        Run ``n_denoising_steps`` inner MCMC steps on the denoising posterior
        starting from ``x_init`` and return the final samples.
        """
        denoising_target = DenoisingTarget(self.target, x_tilde, alpha, sigma)
        # Seed the kernel state with the log-prob and gradient at the start point.
        grad_init, log_prob_init = denoising_target.grad_log_prob(x_init, return_log_prob=True)
        state = KernelState(x=x_init, log_prob=log_prob_init, grad=grad_init)
        for _ in range(self.n_denoising_steps):
            state, _ = self._denoising_step_fn(state, x_tilde, alpha, sigma, self.denoising_step_size)
        return state.x

    def step(self, x: torch.Tensor, step_id: int = 0):
        """
        Performs a single step of the DiGS algorithm, which corresponds to one
        Gibbs sweep (noising + denoising).

        Arguments:
            x: Current clean samples (batch_size, dim).
            step_id: The global step index. Every ``n_gibbs_sweeps`` steps advance
                one noise level; indices past the end reuse the last level.

        Returns:
            x_new: The new clean samples after one Gibbs sweep.
            mwg_acc_rate: The acceptance rate of the Metropolis-within-Gibbs step
                (0-dim tensor).
        """
        noise_level_idx = min(step_id // self.n_gibbs_sweeps, self.n_noise_levels - 1)

        alpha = self._schedule_value(self.alpha_schedule, noise_level_idx, x)
        sigma = self._schedule_value(self.sigma_schedule, noise_level_idx, x)

        # 1. Noising step: x -> ~x (Eq. 10)
        x_tilde = alpha * x + sigma * torch.randn_like(x)

        # 2. Denoising step: ~x -> x'
        # 2a. Metropolis-within-Gibbs initialization
        x_init, mwg_acc_rate = self._metropolis_within_gibbs_init(x, x_tilde, alpha, sigma)

        # 2b. Denoising MCMC run
        x_new = self._denoising_mcmc_run(x_init, x_tilde, alpha, sigma)

        return x_new, mwg_acc_rate

    def forward(
        self,
        x0: torch.Tensor,
        return_trajectory: bool = False,
        return_log_acceptance: bool = False,
    ):
        """
        Runs the full DiGS sampling chain (n_gibbs_sweeps * n_noise_levels steps).

        Arguments:
            x0: The starting point (batch_size, dim). It is cloned, not modified.
            return_trajectory: If True, returns the entire history of samples
                (stacked on CPU, including ``x0``) instead of the final samples.
            return_log_acceptance: If True, also returns a CPU tensor of the
                per-step MwG acceptance rates (rates, not log-probabilities).

        Returns:
            The final samples or the trajectory. If ``return_log_acceptance``
            is True, this is a tuple ``(samples_or_trajectory, acceptance_rates)``.
        """
        n_steps = self.n_gibbs_sweeps * self.n_noise_levels

        x = x0.clone()

        xs = [x.clone().cpu()] if return_trajectory else None
        log_acceptances = [] if return_log_acceptance else None

        # The context manager closes the progress bar even if a step raises;
        # with verbose=False the bar is disabled and prints nothing.
        with tqdm(total=n_steps, desc="DiGS", disable=not self.verbose) as pbar:
            for i in range(n_steps):
                x, mwg_acc_rate = self.step(x, i)

                if self.verbose:
                    # Formatting converts the rate tensor to a Python float, so
                    # only do it when the bar is actually shown.
                    pbar.set_postfix(mwg_acc=f"{mwg_acc_rate:.2f}")
                    pbar.update(1)

                if return_trajectory:
                    xs.append(x.clone().cpu())
                if return_log_acceptance:
                    # Note: This is the acceptance rate, not log probability, for monitoring purposes.
                    log_acceptances.append(mwg_acc_rate)

        result = [torch.stack(xs) if return_trajectory else x]

        if return_log_acceptance:
            result.append(torch.tensor(log_acceptances))

        return tuple(result) if len(result) > 1 else result[0]

    def sample(
        self,
        n_samples: int,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Generate samples from the target distribution.

        Arguments:
            n_samples: The number of particles (samples) to generate.
            device: The device to place the tensors on.
            dtype: The data type of the tensors.

        Returns:
            (n_samples, dim) tensor of final samples.
        """
        x0 = self.build_initial_point(n_samples, device=device, dtype=dtype)
        return self.forward(x0, return_trajectory=False)

    def sample_trajectory(
        self,
        n_samples: int,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Generate a trajectory of samples at each annealing step.

        Arguments:
            n_samples: The number of particles (samples) to generate.
            device: The device to place the tensors on.
            dtype: The data type of the tensors.

        Returns:
            (n_steps + 1, n_samples, dim) tensor of sample trajectories (on CPU).
        """
        x0 = self.build_initial_point(n_samples, device=device, dtype=dtype)
        return self.forward(x0, return_trajectory=True)