import torch
import time

def kinetic(momentum, mass):
    """
    Kinetic energy ``sum_j p_j**2 / (2 * m_j)`` of every particle.

    Args:
        momentum: (N, d) tensor of momenta, one row per particle
        mass: (d,) tensor or scalar, broadcast against ``momentum``
    Returns:
        (N,) tensor holding one kinetic energy per row of ``momentum``
    """
    per_coordinate = momentum.pow(2) / (2 * mass)
    return per_coordinate.sum(dim=1)


class HMC:
    """
    Hamiltonian Monte Carlo sampler for a NeuralNetworkEnsemble.

    Runs ``N`` chains in parallel, one per row of ``particles``, with a
    diagonal mass matrix; every chain is accepted or rejected independently.
    """

    def __init__(
        self,
        particles,
        model,
        test_dataset=None,
        mass=1.0,
        leapfrog_steps=10,
    ):
        """
        Hamiltonian Monte Carlo sampler for a NeuralNetworkEnsemble.

        Args:
            particles (torch.Tensor): initial particles (N × d)
            model (NeuralNetworkEnsemble): ensemble of neural networks
            test_dataset (TensorDataset, optional): dataset for evaluation
            mass (float | int | torch.Tensor): scalar or (d,) mass for HMC
                (controls momentum variance); overwritten by
                ``update(estimate_mass=True)``
            leapfrog_steps (int): number of leapfrog steps per iteration (>= 1)
        """
        self.particles = particles.clone().detach()
        self.model = model
        self.N, self.d = self.particles.shape
        self.mass = mass
        self.leapfrog_steps = leapfrog_steps
        self.test_dataset = test_dataset

        # Metrics, filled in by update()
        self.time_seconds = 0
        self.acceptance_rate = 0
        self.ll = None
        self.rmse = None
        self.damv = None

    def _potential_and_grad(self, dataloader, temperature=1.0):
        """
        Return ``(U, grad_U)`` = (-log_prob, -score) at the model's current params.

        ``temperature`` is forwarded to ``model.score`` only; ``model.log_prob``
        is called without it.
        """
        logP = self.model.log_prob(dataloader)
        grad_logP = self.model.score(dataloader=dataloader, temperature=temperature)
        U = -logP
        grad_U = -grad_logP
        return U, grad_U

    def _estimate_mass(self, data_loader, num_samples=10):
        """
        Estimate a diagonal mass from the variance of potential gradients.

        Gradients are evaluated ``num_samples`` times at ``self.particles``;
        the per-coordinate variance over all samples and particles (+1e-6)
        is clipped to [1e-4, 1e4].

        Returns:
            mass: (d,) tensor of per-parameter masses
        """
        print("Estimating mass matrix...")
        grad_samples = []

        # NOTE(review): every draw uses the same parameters, so repeated draws
        # only differ if model.score is stochastic (e.g. minibatching) — confirm.
        for _ in range(num_samples):
            self.model.set_params(self.particles)
            _, grad_U = self._potential_and_grad(data_loader)
            grad_samples.append(grad_U)

        # (num_samples, N, d) -> (num_samples * N, d), variance per coordinate.
        grad_flat = torch.stack(grad_samples, dim=0).reshape(-1, self.d)
        mass = grad_flat.var(dim=0) + 1e-6

        # Clip extreme values
        mass = torch.clamp(mass, min=1e-4, max=1e4)

        print(f"Mass range: [{mass.min():.2e}, {mass.max():.2e}], median: {mass.median():.2e}")
        return mass

    def _leapfrog(self, particles, momentum, grad_U, step_size, data_loader, temperature):
        """
        Integrate ``self.leapfrog_steps`` leapfrog steps from (particles, momentum).

        ``grad_U`` is the potential gradient at ``particles``. Returns
        ``(particles_new, momentum_new, U_new)`` where ``U_new`` is the
        potential at ``particles_new``. Leaves the model's params set to
        ``particles_new``.
        """
        particles_new = particles.clone()
        # Initial half step for momentum (new tensor; `momentum` is untouched).
        momentum_new = momentum - 0.5 * step_size * grad_U

        last = self.leapfrog_steps - 1
        for step in range(self.leapfrog_steps):
            particles_new += step_size * momentum_new / self.mass
            self.model.set_params(particles_new)
            U_new, grad_U_new = self._potential_and_grad(data_loader, temperature=temperature)
            # Full momentum steps in between; the last one becomes a half step below.
            if step != last:
                momentum_new -= step_size * grad_U_new

        momentum_new -= 0.5 * step_size * grad_U_new
        # U_new already corresponds to particles_new: no extra evaluation needed.
        return particles_new, momentum_new, U_new

    @staticmethod
    def _adapt_step_size(step_size, acc_rate, low=1e-5, high=0.1):
        """Nudge the step size toward a 0.5–0.7 acceptance window and clamp it to [low, high]."""
        if acc_rate < 0.5:
            step_size *= 0.95
        elif acc_rate > 0.7:
            step_size *= 1.05
        return min(max(float(step_size), low), high)

    def update(self, num_samples, step_size, data_loader, burn_in=10,
               adaptive_step_size_interval=5, estimate_mass=True, verbose=True):
        """
        Run HMC sampling.

        Each iteration draws fresh momenta, integrates ``self.leapfrog_steps``
        leapfrog steps and applies a per-particle Metropolis test. Only the
        final state of each chain is kept in ``self.particles``.

        Args:
            num_samples: Number of iterations after burn-in
            step_size: Initial step size (adapted during burn-in only)
            data_loader: DataLoader for computing gradients
            burn_in: Number of burn-in iterations (temperature annealed 10 -> ~1)
            adaptive_step_size_interval: Adapt step size every N burn-in iterations
                (0/None disables adaptation)
            estimate_mass: If True, estimate mass matrix from gradients
            verbose: Print progress
        Returns:
            float: the final step size
        Raises:
            ValueError: if ``self.leapfrog_steps < 1``
        """
        if self.leapfrog_steps < 1:
            raise ValueError(f"leapfrog_steps must be >= 1, got {self.leapfrog_steps}")

        start_time = time.time()
        N = self.N
        particles = self.particles.clone().detach()
        accepted = 0
        total_iters = num_samples + burn_in

        if estimate_mass:
            self.mass = self._estimate_mass(data_loader, num_samples=10)
        elif torch.is_tensor(self.mass):
            self.mass = self.mass.to(particles.device)
        else:
            # Any real scalar (int or float), not only float.
            self.mass = torch.full((self.d,), float(self.mass), device=particles.device)

        for it in range(total_iters):
            in_burn_in = it < burn_in

            momentum = torch.randn_like(particles) * torch.sqrt(self.mass)

            # Annealing: temperature goes 10 -> ~1 during burn-in, 1.0 afterwards.
            temperature = 10 - (9 * it / burn_in) if in_burn_in else 1.0

            self.model.set_params(particles)
            U, grad_U = self._potential_and_grad(data_loader, temperature=temperature)

            # Same temperature for every gradient of this trajectory.
            particles_new, momentum_new, U_new = self._leapfrog(
                particles, momentum, grad_U, step_size, data_loader, temperature
            )

            H = U + kinetic(momentum, self.mass)
            H_new = U_new + kinetic(momentum_new, self.mass)

            # Metropolis acceptance, one decision per particle. A NaN delta_H
            # gives NaN probability, and `rand < NaN` is False, i.e. rejection.
            delta_H = H - H_new
            accept_prob = torch.exp(torch.minimum(delta_H, torch.zeros_like(delta_H)))
            accept = torch.rand_like(accept_prob) < accept_prob

            # torch.where (not a 0/1 arithmetic blend) so that a rejected,
            # diverged proposal cannot leak NaN/inf into particles (0 * NaN = NaN).
            particles = torch.where(accept.unsqueeze(1), particles_new, particles)
            accepted += accept.sum().item()

            self.model.set_params(particles)

            # Adapt only during burn-in: changing the step size while sampling
            # breaks the chain's detailed balance.
            if (in_burn_in and adaptive_step_size_interval
                    and (it + 1) % adaptive_step_size_interval == 0):
                step_size = self._adapt_step_size(step_size, accepted / ((it + 1) * N))

            if verbose and (it + 1) % 10 == 0:
                acc_rate = accepted / ((it + 1) * N)
                print(f"[{it+1}/{total_iters}] Acceptance rate: {acc_rate:.3f}, Step size: {step_size:.2e}")

        self.particles = particles
        self.acceptance_rate = accepted / (total_iters * N) if total_iters else 0.0
        self.time_seconds = time.time() - start_time

        # Burn-in is already realised by the loop (only each chain's final
        # state is kept). model.ensemble entries are reported as particles
        # below, so slicing it by burn_in would drop chains, not iterations.

        if verbose:
            print("\nFinal particle hyperparameters (last 3):")
            for i, net in enumerate(self.model.ensemble[-3:]):
                print(f"  Particle {i}: log_gamma={net.log_gamma.item():.3f}, "
                      f"exp(log_gamma)={torch.exp(net.log_gamma).item():.3e}")

        if self.test_dataset is not None:
            self.model.evaluate(self.test_dataset)
            self.ll = self.model.ll.item()
            self.rmse = self.model.rmse.item()

        if verbose:
            # Metrics stay None without a test set; don't format None with :.4f.
            rmse = "n/a" if self.rmse is None else f"{self.rmse:.4f}"
            ll = "n/a" if self.ll is None else f"{self.ll:.4f}"
            print(f"\n✅ HMC finished in {self.time_seconds:.2f}s | "
                  f"Acceptance: {self.acceptance_rate:.3f} | "
                  f"RMSE: {rmse} | LL: {ll}")

        return step_size  # Return final step size for reference

class HMCToy:
    """Minimal HMC sampler driven by a model's score function."""

    def __init__(self, particles, model, mass=1.0, leapfrog_steps=10):
        """
        Minimal HMC sampler.

        Args:
            particles: (N, d) tensor, initial particles
            model: object with model.score(particles) -> (N, d) gradient of
                log p. If it also has model.log_prob(particles) (expected to
                return (N,)), that is used as the potential in the Metropolis
                test; otherwise the potential is taken as 0.
            mass: scalar (int or float) or (d,) tensor
            leapfrog_steps: number of leapfrog steps (>= 1)
        """
        self.particles = particles.clone().detach()
        self.model = model
        self.N, self.d = self.particles.shape
        if torch.is_tensor(mass):
            self.mass = mass.to(self.particles.device)
        else:
            # Any real scalar, not only float; created on the particles'
            # device so it can be combined with them (e.g. on GPU).
            self.mass = torch.full((self.d,), float(mass), device=self.particles.device)
        self.leapfrog_steps = leapfrog_steps
        self.acceptance_rate = 0
        self.time_seconds = 0

    def _potential_and_grad(self, particles):
        """Return ``(U, grad_U)`` with U = -log p (or zeros) and grad_U = -score."""
        grad_U = -self.model.score(particles)  # score = grad logP
        log_prob = getattr(self.model, "log_prob", None)
        if callable(log_prob):
            U = -log_prob(particles)
        else:
            # Potential unknown: the Metropolis test then only sees the
            # kinetic-energy change, which is not an exact correction.
            U = torch.zeros(particles.shape[0], device=particles.device)
        return U, grad_U

    def update(self, num_samples, step_size=0.01, burn_in=100, adaptive_step_size_interval=5, verbose=True):
        """
        Run ``burn_in + num_samples`` HMC iterations.

        Afterwards ``self.particles`` holds the (num_samples, N, d) stack of
        post-burn-in states and ``self.acceptance_rate`` the fraction of
        accepted proposals over all iterations.

        Raises:
            ValueError: if ``self.leapfrog_steps < 1``
        """
        if self.leapfrog_steps < 1:
            raise ValueError(f"leapfrog_steps must be >= 1, got {self.leapfrog_steps}")

        start_time = time.time()
        N = self.N
        particles = self.particles.clone()
        accepted = 0
        total_iters = num_samples + burn_in
        samples = []

        for it in range(total_iters):
            momentum = torch.randn_like(particles) * torch.sqrt(self.mass)

            # Leapfrog: half momentum step, then `leapfrog_steps` full
            # position/momentum steps, then take back half of the last kick.
            U, grad_U = self._potential_and_grad(particles)
            particles_new = particles.clone()
            momentum_new = momentum - 0.5 * step_size * grad_U
            for _ in range(self.leapfrog_steps):
                particles_new += step_size * momentum_new / self.mass
                U_new, grad_U_new = self._potential_and_grad(particles_new)
                momentum_new -= step_size * grad_U_new
            momentum_new += 0.5 * step_size * grad_U_new

            H = U + kinetic(momentum, self.mass)
            H_new = U_new + kinetic(momentum_new, self.mass)

            # Metropolis acceptance per particle; NaN delta_H -> rejection.
            delta_H = H - H_new
            accept_prob = torch.exp(torch.minimum(delta_H, torch.zeros_like(delta_H)))
            accept = torch.rand_like(accept_prob) < accept_prob

            # torch.where so a rejected, diverged proposal cannot leak NaN/inf
            # into the kept particles (an arithmetic 0/1 blend would: 0 * NaN = NaN).
            particles = torch.where(accept.unsqueeze(1), particles_new, particles)
            accepted += accept.sum().item()

            if it >= burn_in:
                samples.append(particles.clone())

            # Adapt only during burn-in: changing the step size while sampling
            # breaks the chain's detailed balance.
            if (it < burn_in and adaptive_step_size_interval
                    and (it + 1) % adaptive_step_size_interval == 0):
                acc_rate = accepted / ((it + 1) * N)
                if acc_rate < 0.5:
                    step_size *= 0.95
                elif acc_rate > 0.7:
                    step_size *= 1.05
                step_size = min(max(float(step_size), 1e-5), 0.1)

            if verbose and (it + 1) % 20 == 0:
                print(f"[{it+1}/{total_iters}] Acceptance: {accepted / ((it+1)*N):.3f}, Step size: {step_size:.2e}")

        if samples:
            self.particles = torch.stack(samples)
        else:
            self.particles = particles.new_empty((0, N, self.d))
        # `accepted` counts every iteration, so divide by all of them.
        self.acceptance_rate = accepted / (total_iters * N) if total_iters else 0.0
        self.time_seconds = time.time() - start_time

        if verbose:
            print(f"\n✅ HMC finished in {self.time_seconds:.2f}s | Acceptance rate: {self.acceptance_rate:.3f}")
