import torch

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class MCMCConfig:
    """Configuration bundle consumed by `mcmc`.

    Attributes:
        n_iter: number of kernel steps to run.
        kernel_fn: transition kernel, called on every step as
            ``kernel_fn(state, logpdf_fn, params, grad_logpdf_fn)``.
        kernel_params: parameter dictionary handed to the kernel on every step
            (e.g. "step_size" for Langevin kernels, "noise_dist" for the
            random-walk kernel).
        grad_logpdf_fn: optional function returning the gradient of the
            log-density; `mcmc` requires it for the ULA and MALA kernels.
    """
    n_iter: int
    kernel_fn: Callable
    kernel_params: dict
    grad_logpdf_fn: Optional[Callable] = None

def ula_kernel(state, logpdf_fn, params, grad_logpdf_fn):
    """
    One step of the Unadjusted Langevin Algorithm (ULA).

    Computes ``x' = x + eps * clip(grad log pi(x)) + sqrt(2 * eps) * z`` with
    ``z ~ N(0, I)``. There is no Metropolis correction: the move is always
    taken.

    Args:
        state: Tensor of shape, e.g. [batch_size, r, mvn_dim] for GM models.
        logpdf_fn: ignored for the ULA kernel (kept for a uniform kernel signature).
        params: dictionary with keys
            "step_size": Langevin step size epsilon,
            "sqrt_2eps": precomputed sqrt(2 * epsilon),
            "grad_clip" (optional): bound for element-wise clamping of the
                gradient; defaults to 100.0, pass None to disable clamping.
        grad_logpdf_fn: function taking `x` and returning the gradient of log_prob(x).

    Returns:
        new_state: Tensor with the same shape as ``state``.
    """
    # Draw the noise before evaluating the gradient (keeps the RNG stream
    # identical to previous versions of this kernel).
    z = torch.randn_like(state)
    grad = grad_logpdf_fn(state)
    grad_clip = params.get("grad_clip", 100.0)
    if grad_clip is not None:
        # Element-wise clamp bounds the drift term of the update.
        grad = torch.clamp(grad, min=-grad_clip, max=grad_clip)
    return state + params["step_size"] * grad + params["sqrt_2eps"] * z


def mh_accept_step(state, proposal, logpdf_fn, params, extra_log_ratio=0.0, grad=None):
    """
    Metropolis-Hastings accept/reject step, vectorised over chains.

    Each proposal is kept with probability ``min(1, alpha)``, where
    ``log alpha = log pi(x') - log pi(x) + extra_log_ratio``.

    Args:
        state: current state x.
        proposal: proposed state x', same shape as ``state``.
        logpdf_fn: function taking `x` and returning log_prob(x).
        params: dictionary holding "log_prob", the cached log pi(x) of
            ``state``. Updated in place: "log_prob" (and "grad_log_prob" when
            ``grad`` is given) are overwritten with the values of the kept states.
        extra_log_ratio: proposal correction log q(x|x') - log q(x'|x);
            0.0 for symmetric proposals.
        grad: optional {"prev": ..., "prop": ...} gradients of log pi at
            ``state`` and ``proposal``, used to refresh the cached gradient.

    Returns:
        next_state: ``proposal`` where accepted, ``state`` elsewhere.
    """
    log_pi_current = params["log_prob"]
    log_pi_candidate = logpdf_fn(proposal)
    log_alpha = log_pi_candidate - log_pi_current + extra_log_ratio

    # Accept iff log U < log alpha, with U ~ Uniform(0, 1).
    accepted = torch.rand_like(log_alpha).log() < log_alpha

    # The mask has the log-prob's per-chain shape; pad it with trailing
    # singleton dims so it broadcasts against full state tensors.
    n_pad = state.ndim - accepted.ndim
    mask = accepted.view(tuple(accepted.shape) + (1,) * n_pad)
    next_state = torch.where(mask, proposal, state)

    # Cache the kept log-density (and gradient) so the caller's next step
    # does not need to recompute them.
    params["log_prob"] = torch.where(accepted, log_pi_candidate, log_pi_current)
    if grad is not None:
        params["grad_log_prob"] = torch.where(mask, grad["prop"], grad["prev"])
    return next_state
    
    
def rw_kernel(state, logpdf_fn, params, grad_logpdf_fn=None):
    """
    One step of the Random Walk Metropolis-Hastings algorithm.
    Random-Walk Metropolis kernel supporting batch states of shape [N, d].

    The proposal is ``x' = x + noise`` with ``noise`` drawn from
    ``params["noise_dist"]``. No Hastings correction is passed to the
    acceptance step, so the kernel is only exact when the noise distribution
    is symmetric about zero.

    Args:
        state: Tensor of shape [N, d]
        logpdf_fn: function taking `x` and returning log_prob(x)
        params: dictionary {"noise_dist": ..., "log_prob": ...}: the noise
            distribution (anything with a ``sample(sample_shape)`` method, e.g.
            a ``torch.distributions.Distribution``) and the cached log_prob of
            ``state``, which ``mh_accept_step`` refreshes in place.
        grad_logpdf_fn: ignored for RW kernel

    Returns:
        new_state: Tensor with same shape as state
    """
    noise_dist = params["noise_dist"]

    # torch Distributions return samples of shape
    # sample_shape + batch_shape + event_shape. If the distribution already
    # describes the trailing dims of the state (e.g. a MultivariateNormal over
    # d), draw only over the remaining leading dims; otherwise (e.g. a scalar
    # distribution) draw one value per state element as before.
    dist_shape = tuple(getattr(noise_dist, "batch_shape", ())) + tuple(
        getattr(noise_dist, "event_shape", ())
    )
    n_dist_dims = len(dist_shape)
    trailing = tuple(state.shape[state.ndim - n_dist_dims:]) if n_dist_dims else ()
    if 0 < n_dist_dims <= state.ndim and trailing == dist_shape:
        sample_shape = state.shape[: state.ndim - n_dist_dims]
    else:
        sample_shape = state.shape

    noise = noise_dist.sample(sample_shape).to(state.device)
    proposal = state + noise

    return mh_accept_step(state, proposal, logpdf_fn, params)


def _sum_event_dims(x, n_batch_dims):
    """Sum `x` over every dimension after its leading `n_batch_dims` (per-chain) dims."""
    if x.ndim <= n_batch_dims:
        # Nothing to reduce: every element is already a per-chain value.
        return x
    return x.flatten(start_dim=n_batch_dims).sum(dim=-1)


def mala_kernel(state, logpdf_fn, params, grad_logpdf_fn):
    """
    One step of the Metropolis Adjusted Langevin Algorithm (MALA).

    Proposal: ``x' = x + eps * grad log pi(x) + sqrt(2 * eps) * z`` with
    ``z ~ N(0, I)``, followed by a Metropolis-Hastings correction using the
    Gaussian proposal log-densities (shared normalising constants cancel):
    ``log q(x'|x) = -||z||^2 / 2`` and
    ``log q(x|x') = -||x - x' - eps * grad log pi(x')||^2 / (4 * eps)``.

    Args:
        state: Tensor of shape, e.g. [batch_size, r, mvn_dim] for GM models.
        logpdf_fn: function taking `x` and returning log_prob(x).
        params: dictionary {"step_size", "sqrt_2eps", "log_prob", "grad_log_prob"}:
            Langevin step size epsilon, sqrt(2 * epsilon), and the cached
            log-density and gradient at ``state`` (refreshed in place by
            ``mh_accept_step``).
        grad_logpdf_fn: function taking `x` and returning gradient of log_prob(x).

    Returns:
        new_state: Tensor with same shape as state.
    """
    step_size = params["step_size"]
    sqrt_2eps = params["sqrt_2eps"]
    # grad log pi(x), cached by mcmc() on the first step and by mh_accept_step afterwards.
    grad_log_prob_prev = params["grad_log_prob"]

    z = torch.randn_like(state)
    proposal = state + step_size * grad_log_prob_prev + sqrt_2eps * z
    grad_log_prob_prop = grad_logpdf_fn(proposal)  # grad log pi(x')

    # mh_accept_step builds its accept mask with the shape of the cached
    # log-prob, so the proposal log-densities must be reduced over exactly the
    # remaining trailing ("event") dims of the state. Summing over dim 1 only
    # is wrong for states with more than two dimensions.
    n_batch_dims = torch.as_tensor(params["log_prob"]).ndim

    residual_backward = state - proposal - step_size * grad_log_prob_prop
    log_q_backward = -0.25 / step_size * _sum_event_dims(residual_backward ** 2, n_batch_dims)
    log_q_forward = -0.5 * _sum_event_dims(z ** 2, n_batch_dims)
    extra_log_ratio = log_q_backward - log_q_forward

    return mh_accept_step(
        state,
        proposal,
        logpdf_fn,
        params,
        extra_log_ratio=extra_log_ratio,
        grad={"prev": grad_log_prob_prev, "prop": grad_log_prob_prop},
    )


def mcmc(init_state, logpdf_fn, config: MCMCConfig):
    """
    Wrapper function for MCMC sampling: run the configured kernel and record the trace.

    Args:
        init_state (tensor): initial state of the chain(s).
        logpdf_fn (function): function taking `x` and returning log_prob(x).
        config (MCMCConfig): configuration object containing MCMC parameters
            n_iter (int): number of iterations.
            kernel_fn (function): kernel function to use (e.g. rw_kernel, ula_kernel, mala_kernel).
            kernel_params (dict): parameters for the kernel. The dict is copied,
                not mutated, and the cached "log_prob" / "grad_log_prob" entries
                are always recomputed from `init_state`, so a config can be
                reused safely across runs.
            grad_logpdf_fn (function): function taking `x` and returning the
                gradient of log_prob(x). Defaults to None.

    Raises:
        ValueError: if the kernel is ULA or MALA and grad_logpdf_fn is None.

    Returns:
        Tensor of shape [n_iter + 1, *init_state.shape] (assuming the kernel
        preserves the state's shape): the initial state followed by the state
        after each iteration.
    """
    n_iter = config.n_iter
    kernel = config.kernel_fn
    grad_logpdf_fn = config.grad_logpdf_fn

    # Safety checks
    if kernel in {ula_kernel, mala_kernel} and grad_logpdf_fn is None:
        raise ValueError(f"{kernel.__name__} requires `grad_logpdf_fn`, but none was provided.")

    # Work on a private copy: the kernels write the running log-prob and
    # gradient back into this dict, and those caches must describe *this*
    # init_state, not a previous run that reused the same config.
    params = dict(config.kernel_params)
    state = init_state

    # Only MH-based kernels read the cached log-density.
    if kernel is not ula_kernel:
        params["log_prob"] = logpdf_fn(state)

    if grad_logpdf_fn is not None:
        params["grad_log_prob"] = grad_logpdf_fn(state)

    # Langevin noise scale; only derivable when a step size is configured.
    # Evaluated lazily so kernels without "step_size" (e.g. RW) don't fail.
    if "sqrt_2eps" not in params and "step_size" in params:
        params["sqrt_2eps"] = torch.sqrt(
            torch.as_tensor(2.0 * params["step_size"], device=state.device)
        )

    trace = [state.clone()]  # successive positions of the chain(s)
    for _ in range(n_iter):
        state = kernel(state, logpdf_fn, params, grad_logpdf_fn)
        trace.append(state.clone())

    return torch.stack(trace)  # shape [n_iter + 1, *init_state.shape]