"""
Linear Metropolis-Hastings Corrector for Diffusion Models
Global acceptance probability with soft rejection
"""

import torch
import numpy as np
from typing import Optional, Dict, Any, Tuple
from dataclasses import dataclass
import json
from pathlib import Path


@dataclass
class MHLinearConfig:
    """Tunable settings consumed by the linear MH corrector.

    The field order below defines the positional constructor signature,
    so new fields should only ever be appended.
    """

    # Temperature the corrector uses when timestep == total_timesteps.
    temperature_start: float = 1e-2
    # Temperature the corrector uses when timestep == 0.
    temperature_end: float = 1e-1
    # Whether per-step acceptance statistics are recorded.
    enable_history: bool = True
    # Size of the rolling history window kept by the corrector.
    history_window: int = 100
    # Emit a progress line every 10 correction steps.
    verbose: bool = False


class MHLinearCorrector:
    """
    Linear Metropolis-Hastings corrector.

    Computes a single global acceptance probability per batch element and
    blends the proposed state with the current one linearly by that
    probability ("soft" acceptance) instead of a hard accept/reject.
    """

    def __init__(
        self,
        config: Optional["MHLinearConfig"] = None,
        device: str = "cuda",
        verbose: bool = False
    ):
        """
        Args:
            config: Corrector settings; a default ``MHLinearConfig()`` is
                used when omitted (or falsy).
            device: Stored on the instance; not used by this class's methods.
            verbose: Print progress every 10 steps. Also enabled when
                ``config.verbose`` is true.
        """
        self.config = config or MHLinearConfig()
        self.device = device
        self.verbose = verbose or self.config.verbose

        # Most recent per-step summaries, bounded by config.history_window.
        self.acceptance_history = []
        self.step_count = 0

    def _temperature(self, timestep: int, total_timesteps: int) -> float:
        """Linearly interpolate the temperature for ``timestep``.

        Yields ``temperature_start`` at ``timestep == total_timesteps`` and
        ``temperature_end`` at ``timestep == 0``.

        Raises:
            ZeroDivisionError: if ``total_timesteps`` is 0.
        """
        t_ratio = timestep / total_timesteps
        return (
            self.config.temperature_end
            + (self.config.temperature_start - self.config.temperature_end) * t_ratio
        )

    def compute_acceptance_probability(
        self,
        x_t: torch.Tensor,
        x_t_minus_1: torch.Tensor,
        score_t: torch.Tensor,
        score_t_minus_1: torch.Tensor,
        timestep: int,
        total_timesteps: int
    ) -> torch.Tensor:
        """
        Compute global acceptance probability for each sample in batch.

        Args:
            x_t: Current state [batch_size, ...]
            x_t_minus_1: Proposed next state [batch_size, ...]
            score_t: Score at current state [batch_size, ...]
            score_t_minus_1: Score at proposed state [batch_size, ...]
            timestep: Current timestep
            total_timesteps: Total number of timesteps

        Returns:
            Acceptance probabilities in (0, 1], shape [batch_size]
            (also for 1-D inputs where each sample is a scalar).
        """
        delta = x_t_minus_1 - x_t
        temperature = self._temperature(timestep, total_timesteps)

        # If the scores are gradients of a log density, this is the
        # trapezoidal estimate 0.5 * <s_t + s_{t-1}, delta> of the log-density
        # change from x_t to x_t_minus_1, reduced over non-batch dims.
        integrand = (score_t + score_t_minus_1) * delta
        if integrand.dim() > 1:
            # Only reduce when there ARE non-batch dims: torch.sum(dim=())
            # reduces over every dim and would collapse the batch axis.
            integrand = torch.sum(integrand, dim=tuple(range(1, integrand.dim())))
        log_ratio = 0.5 * integrand
        log_ratio = log_ratio * temperature

        # Keep exp() well inside float range; values above 0 are capped at
        # alpha == 1 by the minimum below anyway.
        log_ratio = torch.clamp(log_ratio, min=-10, max=10)

        alpha = torch.minimum(torch.ones_like(log_ratio), torch.exp(log_ratio))

        return alpha

    def apply_linear_soft_acceptance(
        self,
        x_current: torch.Tensor,
        x_proposed: torch.Tensor,
        alpha: torch.Tensor
    ) -> torch.Tensor:
        """
        Blend proposal and current state: ``alpha * x_proposed + (1 - alpha) * x_current``.

        Args:
            x_current: Current sample [batch_size, ...]
            x_proposed: Proposed sample [batch_size, ...]
            alpha: Acceptance probabilities [batch_size]

        Returns:
            Mixed sample [batch_size, ...]
        """
        # Append trailing singleton dims so alpha broadcasts per sample.
        while alpha.dim() < x_current.dim():
            alpha = alpha.unsqueeze(-1)

        return alpha * x_proposed + (1 - alpha) * x_current

    def _trim_history(self) -> None:
        """Drop the oldest history entries beyond ``config.history_window``.

        A window of 0 (or a negative value) retains no entries.
        """
        window = max(0, self.config.history_window)
        excess = len(self.acceptance_history) - window
        if excess > 0:
            # Explicit count instead of lst[-window:], which keeps the whole
            # list when window == 0 because -0 == 0.
            del self.acceptance_history[:excess]

    def correct_step(
        self,
        x_t: torch.Tensor,
        x_t_minus_1_proposed: torch.Tensor,
        timestep: int,
        total_timesteps: int,
        score_t: torch.Tensor,
        score_t_minus_1: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Apply linear MH correction to a single denoising step.

        Args:
            x_t: Current sample at time t
            x_t_minus_1_proposed: Proposed sample at time t-1
            timestep: Current timestep
            total_timesteps: Total timesteps
            score_t: Score at current state
            score_t_minus_1: Score at proposed state

        Returns:
            Corrected sample and a statistics dictionary with the mean/std/
            min/max acceptance probability and the temperature used.
        """
        alpha = self.compute_acceptance_probability(
            x_t, x_t_minus_1_proposed,
            score_t, score_t_minus_1,
            timestep, total_timesteps
        )

        x_corrected = self.apply_linear_soft_acceptance(
            x_t, x_t_minus_1_proposed, alpha
        )

        alpha_mean = alpha.mean().item()
        # std of a single element is NaN in torch; report 0.0 instead.
        alpha_std = alpha.std().item() if alpha.numel() > 1 else 0.0

        stats = {
            'acceptance_prob_mean': alpha_mean,
            'acceptance_prob_std': alpha_std,
            'acceptance_prob_min': alpha.min().item(),
            'acceptance_prob_max': alpha.max().item(),
            'temperature': self._temperature(timestep, total_timesteps)
        }

        if self.config.enable_history:
            self.acceptance_history.append({
                'step': self.step_count,
                'timestep': timestep,
                'alpha_mean': alpha_mean,
                'alpha_std': alpha_std
            })
            self._trim_history()

        self.step_count += 1

        if self.verbose and self.step_count % 10 == 0:
            print(f"[Linear MH] Step {self.step_count}: α_mean={alpha_mean:.3f} ± {alpha_std:.3f}")

        return x_corrected, stats

    def get_statistics(self) -> Dict[str, Any]:
        """Get current statistics and history.

        Returns:
            Dict with ``total_steps``, a copy of ``history`` and, when history
            is non-empty, ``recent_mean`` / ``recent_std`` over the last 20
            entries.
        """
        stats: Dict[str, Any] = {
            'total_steps': self.step_count,
            # Copy so callers cannot mutate the corrector's internal history.
            'history': list(self.acceptance_history)
        }

        if self.acceptance_history:
            # Slicing already handles histories shorter than 20 entries.
            recent = self.acceptance_history[-20:]
            stats['recent_mean'] = np.mean([h['alpha_mean'] for h in recent])
            # Mean of per-step batch standard deviations (not the std of means).
            stats['recent_std'] = np.mean([h['alpha_std'] for h in recent])

        return stats

    def save_history(self, path: str):
        """Save config summary and statistics as JSON for later analysis.

        Parent directories are created as needed; an existing file is
        overwritten.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            'config': {
                'temperature_start': self.config.temperature_start,
                'temperature_end': self.config.temperature_end,
                'method': 'linear'
            },
            'statistics': self.get_statistics()
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def reset(self):
        """Reset history and statistics."""
        self.acceptance_history = []
        self.step_count = 0