"""
MH Local Corrector - Simplified version matching toy_analysis.py exactly
"""

import torch
import numpy as np
from typing import Optional, Dict, Any, Tuple, List
from dataclasses import dataclass
import json
from pathlib import Path


@dataclass
class MHLocalConfig:
    """Configuration for ``MHLocalCorrector``, mirroring the toy implementation.

    There is intentionally no trigger mechanism: the corrector applies a
    correction on every step, as in the toy. Invalid ``method`` names and an
    inverted clipping range are rejected at construction time.
    """

    # Accepted values for ``method``. Deliberately unannotated so that
    # ``@dataclass`` does not turn it into a constructor field.
    _VALID_METHODS = ('vanilla', 'linear', 'local_adaptive')

    # Temperature schedule: the corrector interpolates linearly between these
    # two values based on timestep / total_timesteps.
    temperature_start: float = 1.0
    temperature_end: float = 1.0

    # Bounds applied to the (temperature-scaled) log acceptance ratio before
    # it is exponentiated, to keep exp() numerically well-behaved.
    clip_min: float = -10.0
    clip_max: float = 10.0

    # 'vanilla' (no correction), 'linear' (alpha from a reduction over dim 1)
    # or 'local_adaptive' (one alpha per element).
    method: str = 'local_adaptive'

    # When True, the corrector appends per-step statistics to its history.
    enable_history: bool = False

    def __post_init__(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: if ``method`` is not one of ``_VALID_METHODS`` or if
                ``clip_min`` is greater than ``clip_max``.
        """
        if self.method not in self._VALID_METHODS:
            raise ValueError(
                f"Unknown method: {self.method!r}; "
                f"expected one of {self._VALID_METHODS}"
            )
        # torch.clamp with min > max maps every value to max, which would
        # silently make the acceptance probability constant.
        if self.clip_min > self.clip_max:
            raise ValueError(
                f"clip_min ({self.clip_min}) must not exceed clip_max ({self.clip_max})"
            )
    

class MHLocalCorrector:
    """MH-style soft corrector mirroring the toy_analysis.py implementation.

    For a current sample ``x`` and a proposed sample ``x_proposed``, an
    acceptance probability ``alpha = min(1, exp(log_ratio))`` is computed from
    the two supplied score tensors. The result is then the blend
    ``alpha * x_proposed + (1 - alpha) * x``. ``config.method`` selects:

    * ``'vanilla'``        -- return the proposal unchanged;
    * ``'linear'``         -- alpha from a log ratio summed over dim 1;
    * ``'local_adaptive'`` -- one alpha per element.

    There is no trigger mechanism: every call to ``correct_step`` corrects.
    """

    def __init__(
        self,
        config: Optional["MHLocalConfig"] = None,
        device: str = "cuda",
        verbose: bool = False
    ):
        """Create a corrector.

        Args:
            config: Corrector configuration; a default ``MHLocalConfig`` is
                built when omitted.
            device: Stored on the instance; not read by the methods of this class.
            verbose: Stored on the instance; not read by the methods of this class.
        """
        self.config = config or MHLocalConfig()
        self.device = device
        self.verbose = verbose
        # One dict per correct_step call when config.enable_history is True.
        self.acceptance_history = []

    def _temperature(self, timestep: float, total_timesteps: float) -> float:
        """Linearly interpolated temperature for ``timestep``.

        Yields ``temperature_start`` at ``timestep == 0`` and
        ``temperature_end`` at ``timestep == total_timesteps``. This matches the
        toy; NOTE(review): whether that ordering fits the sampling direction
        depends on the caller's timestep convention.
        """
        t_ratio = timestep / total_timesteps
        start = self.config.temperature_start
        end = self.config.temperature_end
        return end + (start - end) * (1 - t_ratio)

    def _alpha_from_log_ratio(
        self, log_ratio: torch.Tensor, temperature: float
    ) -> torch.Tensor:
        """Scale by temperature, clip, and map to ``min(1, exp(log_ratio))``."""
        scaled = torch.clamp(
            log_ratio * temperature,
            min=self.config.clip_min,
            max=self.config.clip_max,
        )
        # exp(x) clamped at 1 is exactly min(1, exp(x)).
        return torch.exp(scaled).clamp(max=1.0)

    def compute_acceptance_probability(
        self,
        x: torch.Tensor,
        x_proposed: torch.Tensor,
        score_t: torch.Tensor,
        score_t_minus_1: torch.Tensor,
        timestep: float,
        total_timesteps: float
    ) -> torch.Tensor:
        """Acceptance probability with the log ratio summed over dim 1.

        ``log_ratio = 0.5 * sum_dim1((score_t + score_t_minus_1) * (x_proposed - x))``,
        scaled by the timestep temperature and clipped to
        ``[clip_min, clip_max]`` before ``alpha = min(1, exp(log_ratio))``.

        Returns:
            A tensor shaped like the inputs with dim 1 removed. For 2-D
            ``(batch, dim)`` inputs that is one alpha per row.
            NOTE(review): for higher-rank inputs the reduction is still over
            dim 1 only, so trailing dims are kept -- confirm this is intended.
        """
        delta = x_proposed - x
        temperature = self._temperature(timestep, total_timesteps)
        log_ratio = 0.5 * torch.sum((score_t + score_t_minus_1) * delta, dim=1)
        return self._alpha_from_log_ratio(log_ratio, temperature)

    def compute_acceptance_probability_per_dim(
        self,
        x: torch.Tensor,
        x_proposed: torch.Tensor,
        score_t: torch.Tensor,
        score_t_minus_1: torch.Tensor,
        timestep: float,
        total_timesteps: float
    ) -> torch.Tensor:
        """Element-wise acceptance probability.

        Same formula as ``compute_acceptance_probability`` but without the
        reduction, so the result broadcasts like the input tensors.
        """
        delta = x_proposed - x
        temperature = self._temperature(timestep, total_timesteps)
        log_ratio = 0.5 * (score_t + score_t_minus_1) * delta
        return self._alpha_from_log_ratio(log_ratio, temperature)

    def apply_soft_acceptance(
        self,
        x: torch.Tensor,
        x_proposed: torch.Tensor,
        alpha: torch.Tensor,
        method: str
    ) -> torch.Tensor:
        """Blend ``x_proposed`` and ``x`` as ``alpha * x_proposed + (1 - alpha) * x``.

        Args:
            method: ``'linear'`` expects ``alpha`` with dim 1 of ``x`` removed
                (as returned by ``compute_acceptance_probability``) and re-inserts
                that dim for broadcasting. ``'local_adaptive'`` uses ``alpha``
                as-is, so it must broadcast against ``x``.

        Raises:
            ValueError: for any other ``method``.
        """
        if method == 'linear':
            weight = alpha.unsqueeze(1)
        elif method == 'local_adaptive':
            weight = alpha
        else:
            raise ValueError(f"Unknown soft acceptance method: {method}")
        return weight * x_proposed + (1 - weight) * x

    @staticmethod
    def _alpha_stats(alpha: torch.Tensor) -> Dict[str, float]:
        """Summary statistics of an acceptance-probability tensor."""
        return {
            'acceptance_prob_mean': alpha.mean().item(),
            # torch.std is Bessel-corrected by default, so a single element
            # yields NaN; report 0.0 instead.
            'acceptance_prob_std': alpha.std().item() if alpha.numel() > 1 else 0.0,
            # Fraction of entries that would be "rejected" at a 0.5 threshold.
            'rejection_ratio': (alpha < 0.5).float().mean().item(),
        }

    def correct_step(
        self,
        x_t: torch.Tensor,
        x_t_minus_1_proposed: torch.Tensor,
        timestep: int,
        total_timesteps: int,
        score_t: torch.Tensor,
        score_t_minus_1: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Correct one proposed step according to ``config.method``.

        There is no trigger mechanism: a correction is applied on every call,
        so ``stats['triggered']`` is always True.

        Returns:
            ``(x_result, stats)``. ``stats`` has the keys ``triggered``,
            ``acceptance_prob_mean``, ``acceptance_prob_std``, ``accepted`` and
            ``rejection_ratio``. For ``'vanilla'`` the defaults (1.0, 0.0, 0.0)
            are reported.

        Raises:
            ValueError: if ``config.method`` is not recognised.
        """
        stats: Dict[str, Any] = {
            'triggered': True,
            'acceptance_prob_mean': 1.0,
            'acceptance_prob_std': 0.0,
            'accepted': True,
            'rejection_ratio': 0.0,
        }
        method = self.config.method
        timestep_float = float(timestep)

        if method == 'vanilla':
            x_result = x_t_minus_1_proposed
        elif method == 'local_adaptive':
            alpha = self.compute_acceptance_probability_per_dim(
                x_t, x_t_minus_1_proposed,
                score_t, score_t_minus_1,
                timestep_float, total_timesteps,
            )
            x_result = self.apply_soft_acceptance(
                x_t, x_t_minus_1_proposed, alpha, 'local_adaptive'
            )
            stats.update(self._alpha_stats(alpha))
        elif method == 'linear':
            alpha = self.compute_acceptance_probability(
                x_t, x_t_minus_1_proposed,
                score_t, score_t_minus_1,
                timestep_float, total_timesteps,
            )
            x_result = self.apply_soft_acceptance(
                x_t, x_t_minus_1_proposed, alpha, 'linear'
            )
            stats.update(self._alpha_stats(alpha))
        else:
            raise ValueError(f"Unknown method: {self.config.method}")

        if self.config.enable_history:
            self.acceptance_history.append({
                'timestep': timestep,
                'acceptance_prob': stats['acceptance_prob_mean'],
                'rejection_ratio': stats['rejection_ratio'],
            })

        return x_result, stats

    def get_statistics(self) -> Dict[str, Any]:
        """Summary over the recorded history.

        ``total_corrections`` counts all entries. ``recent_acceptance_rate`` and
        ``avg_rejection_ratio`` average the most recent 20 entries, or fewer if
        less history exists. All values are zero when the history is empty.
        """
        if not self.acceptance_history:
            return {
                'total_corrections': 0,
                'recent_acceptance_rate': 0.0,
                'avg_rejection_ratio': 0.0,
            }

        # Slicing already handles histories shorter than the window.
        recent = self.acceptance_history[-20:]
        return {
            'total_corrections': len(self.acceptance_history),
            'recent_acceptance_rate': float(np.mean([h['acceptance_prob'] for h in recent])),
            'avg_rejection_ratio': float(np.mean([h['rejection_ratio'] for h in recent])),
        }

    def save_history(self, path: str):
        """Write config, history and summary statistics to ``path`` as JSON.

        Parent directories are created as needed. Values JSON cannot encode
        natively (e.g. a timestep passed as a tensor) are stringified via
        ``default=str``.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            'config': {
                'temperature_start': self.config.temperature_start,
                'temperature_end': self.config.temperature_end,
                'clip_min': self.config.clip_min,
                'clip_max': self.config.clip_max,
                'method': self.config.method,
            },
            'history': self.acceptance_history,
            'statistics': self.get_statistics(),
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=str)

    def reset(self):
        """Discard the recorded history.

        The list is rebound rather than cleared, so any list previously
        obtained from ``acceptance_history`` is left untouched.
        """
        self.acceptance_history = []