import torch
import numpy as np
import time
from tqdm import tqdm

from src.models.bnn import NeuralNetworkEnsemble
from src.models.mvnmixture import MultivariateNormalMixture

class NUTS:
    """No-U-Turn Sampler that advances ``N`` particles (chains) in lock-step.

    All particles share the tree direction and depth of every doubling; the
    slice indicator ``n`` and the stop indicator ``s`` are tracked per particle.

    The ``model`` object is duck-typed. The methods of this class call:

    * ``model.set_params(particles)`` -- load an ``(N, d)`` tensor of parameters,
    * ``model.score(dataloader=..., temperature=...)`` -- gradient of the log
      density at the currently loaded parameters,
    * ``model.log_prob(dataloader)`` -- log density at the currently loaded
      parameters (assumed to return one value per particle, shape ``(N,)``,
      since it is combined with ``(N,)``-shaped kinetic energies),
    * ``model.log_prob_params(particles, dataloader)`` -- only when
      ``store_logp_history`` is enabled.

    Attributes written by :meth:`update`: ``samples``, ``damv``,
    ``accept_rate``, ``time_seconds``, ``total_iterations`` and, for
    ``NeuralNetworkEnsemble`` models with a test set, ``rmse``/``ll``/``ll_array``.
    """

    def __init__(
            self,
            particles,
            model,
            store_particles_history=False,
            store_logp_history=False,
            step_size=0.01,
            target_accept=0.8,
            max_tree_depth=10,
            test_dataset=None,
            device='cpu'
        ):
        """Store the sampler configuration and the initial particles.

        Args:
            particles: ``(N, d)`` ``torch.Tensor`` or ``numpy.ndarray`` with the
                initial positions; it is cloned, detached and moved to ``device``.
            model: Target distribution (see the class docstring for the methods used).
            store_particles_history: Append the particles to
                ``self.particles_history`` after every iteration.
            store_logp_history: Append the mean of ``model.log_prob_params`` to
                ``self.logp_history`` after every iteration.
            step_size: Leapfrog step size.
            target_accept: Stored on the instance; not used by this class
                (there is no step-size adaptation here).
            max_tree_depth: Maximum number of tree doublings per iteration.
            test_dataset: Optional dataset handed to ``model.evaluate`` at the
                end of :meth:`update`.
            device: Device on which particles and random draws live.
        """
        if isinstance(particles, np.ndarray):
            particles = torch.from_numpy(particles)
        assert isinstance(particles, torch.Tensor) and particles.dim() == 2, \
            "particles must be a 2-D tensor/array of shape (N, d)"
        self.N, self.d = particles.shape
        self.model = model
        self.device = device
        self.step_size = step_size
        self.target_accept = target_accept
        self.max_tree_depth = max_tree_depth
        self.store_particles_history = store_particles_history
        self.store_logp_history = store_logp_history

        self.particles_current = particles.clone().detach().to(device)
        # Always define the attribute so later reads never raise AttributeError.
        self.test_dataset = test_dataset
        if store_particles_history:
            self.particles_history = self.particles_current.unsqueeze(0)
        if store_logp_history:
            self.logp_history = []

    # ----------------------------------------------------------------------
    # Diagonal-metric helpers (inv_mass=None means the identity metric)
    # ----------------------------------------------------------------------
    @staticmethod
    def _velocity(momentum, inv_mass=None):
        """Return ``M^{-1} p`` for a diagonal inverse mass (identity if ``None``)."""
        return momentum if inv_mass is None else inv_mass * momentum

    @staticmethod
    def _kinetic_energy(momentum, inv_mass=None):
        """Return ``0.5 * p^T M^{-1} p`` per particle, shape ``(N,)``."""
        squared = momentum ** 2
        if inv_mass is not None:
            squared = inv_mass * squared
        return 0.5 * squared.sum(dim=1)

    @classmethod
    def _no_u_turn(cls, pos_minus, pos_plus, mom_minus, mom_plus, inv_mass=None):
        """Return 1.0 per particle while the trajectory ends still move apart.

        The trajectory is flagged (0.0) as soon as the velocity at either end
        has a negative component along ``pos_plus - pos_minus``.
        """
        span = pos_plus - pos_minus
        forward = (span * cls._velocity(mom_plus, inv_mass)).sum(dim=1) >= 0
        backward = (span * cls._velocity(mom_minus, inv_mass)).sum(dim=1) >= 0
        return (forward & backward).float()

    # ----------------------------------------------------------------------
    # Use model's score() for ∇ log p
    # ----------------------------------------------------------------------
    def compute_grad_logp(self, particles, dataloader=None, temperature=1.0):
        """Return ``model.score(...)`` evaluated at ``particles``.

        ``model.score`` takes no position argument, so ``particles`` is loaded
        into the model with ``set_params`` first; otherwise the gradient would
        be taken at whatever parameters the model happened to hold.

        Raises:
            ValueError: If the model has no ``score`` attribute.
        """
        if not hasattr(self.model, "score"):
            raise ValueError("Model must implement a score(dataloader=...) method.")
        self.model.set_params(particles)
        return self.model.score(dataloader=dataloader, temperature=temperature)

    # ----------------------------------------------------------------------
    def leapfrog(self, position, momentum, grad_logp, step_size, dataloader=None, temperature=1.0, inv_mass=None):
        """Take one leapfrog step of (possibly negative) size ``step_size``.

        Args:
            position, momentum: ``(N, d)`` state at the start of the step.
            grad_logp: Gradient of the log density at ``position``.
            step_size: Signed step size.
            dataloader, temperature: Forwarded to :meth:`compute_grad_logp`.
            inv_mass: Optional diagonal inverse mass (broadcastable to
                ``(N, d)``); ``None`` means the identity metric.

        Returns:
            ``(position_new, momentum_new, grad_logp_new)``. As a side effect
            the model is left holding ``position_new``.
        """
        momentum_half = momentum + 0.5 * step_size * grad_logp
        position_new = position + step_size * self._velocity(momentum_half, inv_mass)
        grad_logp_new = self.compute_grad_logp(position_new, dataloader=dataloader, temperature=temperature)
        momentum_new = momentum_half + 0.5 * step_size * grad_logp_new
        return position_new, momentum_new, grad_logp_new

    def hamiltonian(self, logp, momentum, inv_mass=None):
        """Return ``-logp + kinetic energy`` per particle, shape ``(N,)``."""
        return -logp + self._kinetic_energy(momentum, inv_mass)

    # ----------------------------------------------------------------------
    def build_tree(self, position, momentum, grad_logp, logu, v, j, step_size, position0, momentum0, logp0, dataloader=None, temperature=1.0, inv_mass=None):
        """Recursively build a balanced subtree of ``2**j`` leapfrog steps.

        Args:
            position, momentum, grad_logp: State of the trajectory edge that is
                being extended.
            logu: Log slice variable per particle, shape ``(N,)``.
            v: Direction of integration, ``-1`` or ``+1``.
            j: Depth of the subtree.
            step_size: Unsigned leapfrog step size.
            position0: Initial position of the iteration (kept for interface
                compatibility; not read).
            momentum0, logp0: Initial momentum and log density; they define the
                reference energy for the acceptance statistic.
            dataloader, temperature: Forwarded to the model calls.
            inv_mass: Optional diagonal inverse mass; ``None`` means identity.

        Returns:
            A 10-tuple ``(pos_minus, mom_minus, grad_minus, pos_plus, mom_plus,
            grad_plus, pos_prime, n_prime, s_prime, alpha_prime)``: the two
            subtree edges, the proposal drawn from the subtree, the per-particle
            count of leaves inside the slice, the per-particle continue flag
            (1.0 = keep going) and the mean acceptance statistic of the leaves.
        """
        if j == 0:
            # Base case: a single leapfrog step in direction v.
            new_pos, new_mom, new_grad = self.leapfrog(
                position, momentum, grad_logp, v * step_size, dataloader, temperature, inv_mass)
            self.model.set_params(new_pos)
            logp_new = self.model.log_prob(dataloader)
            joint_new = logp_new - self._kinetic_energy(new_mom, inv_mass)
            joint0 = logp0 - self._kinetic_energy(momentum0, inv_mass)
            n = (logu < joint_new).float()
            # Energy error above 1000 nats is treated as a divergence.
            s = (logu - 1000 < joint_new).float()
            # Metropolis acceptance statistic on the full Hamiltonian, not on
            # the potential alone.
            alpha = torch.exp(joint_new - joint0).clamp(max=1.0)
            return new_pos, new_mom, new_grad, new_pos, new_mom, new_grad, new_pos, n, s, alpha

        # Recursive case: first half-subtree ...
        pos_minus, mom_minus, grad_minus, pos_plus, mom_plus, grad_plus, pos_prime, n_prime, s_prime, alpha_prime = \
            self.build_tree(position, momentum, grad_logp, logu, v, j - 1, step_size,
                            position0, momentum0, logp0, dataloader, temperature, inv_mass)

        # Every particle already stopped: no point in integrating further.
        if s_prime.sum() == 0:
            return pos_minus, mom_minus, grad_minus, pos_plus, mom_plus, grad_plus, pos_prime, n_prime, s_prime, alpha_prime

        # ... then the second half, grown from the outer edge of the first.
        if v == -1:
            pos_minus, mom_minus, grad_minus, _, _, _, pos_prime2, n_prime2, s_prime2, alpha_prime2 = \
                self.build_tree(pos_minus, mom_minus, grad_minus, logu, v, j - 1, step_size,
                                position0, momentum0, logp0, dataloader, temperature, inv_mass)
        else:
            _, _, _, pos_plus, mom_plus, grad_plus, pos_prime2, n_prime2, s_prime2, alpha_prime2 = \
                self.build_tree(pos_plus, mom_plus, grad_plus, logu, v, j - 1, step_size,
                                position0, momentum0, logp0, dataloader, temperature, inv_mass)

        # Pick the proposal uniformly among the in-slice leaves of both halves.
        take_second = torch.rand(self.N, device=self.device) < (n_prime2 / torch.clamp(n_prime + n_prime2, min=1e-10))
        # Out-of-place select: pos_prime may share storage with a subtree edge.
        pos_prime = torch.where(take_second.unsqueeze(1), pos_prime2, pos_prime)
        n_prime = n_prime + n_prime2
        s_prime = s_prime * s_prime2 * self._no_u_turn(pos_minus, pos_plus, mom_minus, mom_plus, inv_mass)
        alpha_prime = (alpha_prime + alpha_prime2) / 2
        return pos_minus, mom_minus, grad_minus, pos_plus, mom_plus, grad_plus, pos_prime, n_prime, s_prime, alpha_prime

    # ----------------------------------------------------------------------
    def update(self, burn_in=1000, num_samples=20, thinning_factor=10, progress=True, dataloader=None, temperature=1.0, **kwargs):
        """Run ``burn_in + int(num_samples * thinning_factor)`` NUTS iterations.

        During burn-in the identity metric is used. Once ``burn_in`` (>= 2)
        samples exist, their per-particle variance (plus ``1e-6``) becomes the
        diagonal inverse mass for all remaining iterations, so typical
        velocities scale with the estimated standard deviation.

        NOTE(review): ``temperature`` is forwarded to ``model.score`` only;
        ``model.log_prob`` is called without it. Confirm that this is intended
        when ``temperature != 1``.

        Args:
            burn_in: Number of initial iterations excluded from ``self.samples``.
            num_samples: Number of thinned samples to keep.
            thinning_factor: Keep every ``thinning_factor``-th post-burn-in sample.
            progress: Show a tqdm progress bar.
            dataloader: Data handed to the model calls; required.
            temperature: Forwarded to ``model.score``.
            **kwargs: Accepted and ignored.

        Returns:
            The final particles, shape ``(N, d)``.

        Raises:
            ValueError: If ``dataloader`` is ``None`` or no iteration would run.
        """
        if dataloader is None:
            raise ValueError("Must pass a dataloader for computing model scores.")

        time_start = time.time()
        self.total_iterations = burn_in + int(num_samples * thinning_factor)
        if self.total_iterations <= 0:
            raise ValueError("burn_in + num_samples * thinning_factor must be positive.")

        samples_all = []
        inv_mass = None          # identity metric until burn-in has finished
        alpha_sum = 0.0          # acceptance statistic accumulated over the whole run
        n_alpha = 0
        for iteration in tqdm(range(self.total_iterations), disable=not progress):

            # p ~ N(0, M) with M = 1 / inv_mass.
            momentum0 = torch.randn_like(self.particles_current)
            if inv_mass is not None:
                momentum0 = momentum0 / torch.sqrt(inv_mass)

            self.model.set_params(self.particles_current)
            logp0 = self.model.log_prob(dataloader)
            grad_logp = self.compute_grad_logp(self.particles_current, dataloader=dataloader, temperature=temperature)
            # Slice variable u ~ Uniform(0, exp(joint0))  =>  log u = joint0 + log U.
            joint0 = logp0 - self._kinetic_energy(momentum0, inv_mass)
            logu = joint0 + torch.rand(self.N, device=self.device).log()

            pos_minus = self.particles_current.clone()
            pos_plus = self.particles_current.clone()
            mom_minus = momentum0.clone()
            mom_plus = momentum0.clone()
            grad_minus = grad_logp.clone()
            grad_plus = grad_logp.clone()
            pos_prime = self.particles_current.clone()

            j = 0
            n = torch.ones(self.N, device=self.device)
            s = torch.ones(self.N, device=self.device)

            while (s.sum() > 0) and (j < self.max_tree_depth):
                v = torch.randint(0, 2, (1,)).item() * 2 - 1  # ±1 direction
                if v == -1:
                    pos_minus, mom_minus, grad_minus, _, _, _, pos_prime_new, n_new, s_new, alpha_new = \
                        self.build_tree(pos_minus, mom_minus, grad_minus, logu, v, j, self.step_size,
                                        self.particles_current, momentum0, logp0, dataloader, temperature, inv_mass)
                else:
                    _, _, _, pos_plus, mom_plus, grad_plus, pos_prime_new, n_new, s_new, alpha_new = \
                        self.build_tree(pos_plus, mom_plus, grad_plus, logu, v, j, self.step_size,
                                        self.particles_current, momentum0, logp0, dataloader, temperature, inv_mass)

                accept_mask = torch.rand(self.N, device=self.device) < (n_new / torch.clamp(n + n_new, min=1e-10))
                # Particles that already stopped, or whose new subtree hit a
                # stop condition, must not take a proposal from that subtree.
                accept_mask = accept_mask & (s > 0) & (s_new > 0)
                pos_prime = torch.where(accept_mask.unsqueeze(1), pos_prime_new, pos_prime)
                n = n + n_new
                s = s * s_new * self._no_u_turn(pos_minus, pos_plus, mom_minus, mom_plus, inv_mass)
                alpha_sum += alpha_new.mean().item()
                n_alpha += 1
                j += 1

            self.particles_current = pos_prime.clone().detach()

            if self.store_particles_history:
                self.particles_history = torch.vstack([self.particles_history, self.particles_current.unsqueeze(0)])
            if self.store_logp_history:
                with torch.no_grad():
                    self.logp_history.append(self.model.log_prob_params(self.particles_current, dataloader).mean().item())

            samples_all.append(self.particles_current.clone().detach())

            # Estimate the metric once, right after the last burn-in iteration.
            # A variance needs at least two samples; otherwise keep identity.
            if iteration == burn_in - 1 and burn_in >= 2:
                inv_mass = torch.var(torch.stack(samples_all), dim=0) + 1e-6  # avoid zero

        samples_all = torch.stack(samples_all)
        post_burn_in = samples_all[burn_in:]
        self.samples = post_burn_in[::thinning_factor][:num_samples]
        self.damv = torch.var(self.samples.reshape(-1, self.d), dim=0).mean().item()

        if isinstance(self.model, NeuralNetworkEnsemble):
            # Evaluate the model at the mean of the retained samples; skipped
            # when there is nothing to average or no test set was supplied.
            if len(self.samples) > 0 and self.test_dataset is not None:
                mean_params = torch.mean(self.samples, dim=0)
                self.model.set_params(mean_params)
                self.model.evaluate(self.test_dataset)
                self.evaluate()

        self.time_seconds = time.time() - time_start
        self.accept_rate = alpha_sum / max(n_alpha, 1)
        return self.particles_current

    def evaluate(self):
        """Copy ``rmse``, ``ll`` and ``ll_array`` from the model onto the sampler."""
        self.rmse = self.model.rmse
        self.ll = self.model.ll
        self.ll_array = self.model.ll_array
