# utils/visualization.py
"""
Visualization utilities for PI-ConvNP.
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List, Dict
import os


def plot_predictions(
    x_context: torch.Tensor,
    y_context: torch.Tensor,
    x_target: torch.Tensor,
    y_target: torch.Tensor,
    y_pred_mean: torch.Tensor,
    y_pred_std: torch.Tensor,
    save_path: Optional[str] = None,
    title: str = "Predictions",
    figsize: Tuple[int, int] = (12, 6),
    show: bool = False,
    confidence_levels: Optional[List[float]] = None
) -> None:
    """
    Plot model predictions with uncertainty bands for 1D spatial inputs.

    One subplot is drawn per output dimension. Each subplot shows:
    - the ground truth;
    - the predicted mean;
    - a shaded interval ``mean ± z * std`` for every requested confidence
      level, where ``z`` is the standard-normal quantile at
      ``(1 + level) / 2``;
    - the context points.

    Args:
        x_context: Context input locations [N, spatial_dim]; a 1D tensor is
            treated as [N, 1]
        y_context: Context observations [N, obs_dim]
        x_target: Target input locations [M, spatial_dim]
        y_target: Target ground truth [M, output_dim]
        y_pred_mean: Predicted mean [M, output_dim]
        y_pred_std: Predicted standard deviation [M, output_dim]
        save_path: Path to save the figure. Missing parent directories are
            created. A bare filename is saved in the working directory.
        title: Plot title
        figsize: Figure size
        show: Whether to show the plot; otherwise the figure is closed
        confidence_levels: Confidence levels strictly inside (0, 1) to shade.
            Defaults to [0.90, 0.95]. Fill styles are reused cyclically when
            more levels than styles are given.

    Raises:
        ValueError: If a confidence level is not strictly between 0 and 1.
        NotImplementedError: If the inputs have spatial_dim > 1.
    """
    # None sentinel instead of a (shared) mutable list default
    if confidence_levels is None:
        confidence_levels = [0.90, 0.95]
    for conf_level in confidence_levels:
        # A level of 0, 1 or beyond yields a NaN/infinite z-score
        if not 0.0 < conf_level < 1.0:
            raise ValueError(
                "confidence levels must be strictly between 0 and 1, "
                f"got {conf_level}"
            )

    def _to_2d_numpy(tensor: torch.Tensor) -> np.ndarray:
        # Detach from autograd, move to host, and promote [K] to [K, 1]
        array = tensor.detach().cpu().numpy()
        return array[:, None] if array.ndim == 1 else array

    x_context = _to_2d_numpy(x_context)
    y_context = _to_2d_numpy(y_context)
    x_target = _to_2d_numpy(x_target)
    y_target = _to_2d_numpy(y_target)
    y_pred_mean = _to_2d_numpy(y_pred_mean)
    y_pred_std = _to_2d_numpy(y_pred_std)

    if x_context.shape[1] != 1:
        # For higher dimensional inputs, a different plot type is needed
        raise NotImplementedError(
            "Visualization for spatial_dim > 1 not yet implemented"
        )

    from scipy import stats  # local import: only needed for z-scores

    num_outputs = y_pred_mean.shape[1]
    # squeeze=False always returns a 2D array of Axes, even for one output
    fig, axes = plt.subplots(1, num_outputs, figsize=figsize, squeeze=False)

    # Sort target points once so lines are drawn left to right
    sort_idx = np.argsort(x_target[:, 0])
    x_sorted = x_target[sort_idx, 0]

    colors = ['lightblue', 'lightskyblue']
    alphas = [0.3, 0.2]
    z_scores = [stats.norm.ppf((1 + level) / 2) for level in confidence_levels]

    for i, ax in enumerate(axes[0]):
        y_mean_sorted = y_pred_mean[sort_idx, i]
        y_std_sorted = y_pred_std[sort_idx, i]
        y_true_sorted = y_target[sort_idx, i]

        ax.plot(x_sorted, y_true_sorted, 'k-', linewidth=2,
                label='Ground Truth', alpha=0.7)
        ax.plot(x_sorted, y_mean_sorted, 'b-', linewidth=2,
                label='Prediction')

        for j, (conf_level, z_score) in enumerate(
            zip(confidence_levels, z_scores)
        ):
            ax.fill_between(
                x_sorted,
                y_mean_sorted - z_score * y_std_sorted,
                y_mean_sorted + z_score * y_std_sorted,
                alpha=alphas[j % len(alphas)],
                color=colors[j % len(colors)],
                # ':g' avoids int() truncation (0.29 * 100 -> 28.999...)
                label=f'{conf_level * 100:g}% CI'
            )

        ax.scatter(x_context[:, 0], y_context[:, i],
                   c='red', s=50, marker='o',
                   label='Context', zorder=5, alpha=0.8)

        ax.set_xlabel('x', fontsize=12)
        ax.set_ylabel(f'y[{i}]', fontsize=12)
        ax.set_title(f'{title} - Output {i}' if num_outputs > 1 else title)
        ax.legend(loc='best', fontsize=10)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when one is given
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)


def plot_uncertainty(
    x_target: torch.Tensor,
    y_pred_std: torch.Tensor,
    save_path: Optional[str] = None,
    title: str = "Predictive Uncertainty",
    figsize: Tuple[int, int] = (10, 5),
    show: bool = False
) -> None:
    """
    Plot the predicted standard deviation against a 1D input coordinate.

    One subplot is drawn per output dimension. Each subplot shows the standard
    deviation curve, shaded down to zero.

    Args:
        x_target: Target input locations [M, spatial_dim]; a 1D tensor is
            treated as [M, 1]
        y_pred_std: Predicted standard deviation [M, output_dim]
        save_path: Path to save the figure. Missing parent directories are
            created. A bare filename is saved in the working directory.
        title: Plot title
        figsize: Figure size
        show: Whether to show the plot; otherwise the figure is closed

    Raises:
        NotImplementedError: If the inputs have spatial_dim > 1.
    """
    x_target = x_target.detach().cpu().numpy()
    y_pred_std = y_pred_std.detach().cpu().numpy()

    # Promote [M] to [M, 1] so indexing below is uniform
    if x_target.ndim == 1:
        x_target = x_target[:, None]
    if y_pred_std.ndim == 1:
        y_pred_std = y_pred_std[:, None]

    if x_target.shape[1] != 1:
        raise NotImplementedError(
            "Visualization for spatial_dim > 1 not yet implemented"
        )

    num_outputs = y_pred_std.shape[1]
    # squeeze=False always returns a 2D array of Axes, even for one output
    fig, axes = plt.subplots(1, num_outputs, figsize=figsize, squeeze=False)

    # The sort order depends only on x, so compute it once
    sort_idx = np.argsort(x_target[:, 0])
    x_sorted = x_target[sort_idx, 0]

    for i, ax in enumerate(axes[0]):
        std_sorted = y_pred_std[sort_idx, i]

        ax.plot(x_sorted, std_sorted, 'b-', linewidth=2)
        ax.fill_between(x_sorted, 0, std_sorted, alpha=0.3)

        ax.set_xlabel('x', fontsize=12)
        ax.set_ylabel('Standard Deviation', fontsize=12)
        ax.set_title(
            f'{title} - Output {i}' if num_outputs > 1 else title
        )
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)

    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when one is given
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)


def _flatten_training_history(history: Dict[str, list]) -> Dict[str, list]:
    """
    Normalise a training history into a flat dict of per-epoch values.

    A non-empty list of dicts (per-epoch loss breakdowns, as produced by the
    Trainer) is expanded as follows:
    - ``key`` holds each epoch's ``'total'`` entry;
    - ``f'{key}_data'``, ``f'{key}_physics'`` and ``f'{key}_reg'`` are added
      for each component present in any epoch, with 0.0 filled in for epochs
      that lack it.
    Every other value is passed through unchanged.
    """
    flat = {}
    for key, values in history.items():
        is_dict_list = (
            isinstance(values, list) and values and isinstance(values[0], dict)
        )
        if not is_dict_list:
            # Already a list of numbers (or something we leave untouched)
            flat[key] = values
            continue

        flat[key] = [epoch['total'] for epoch in values]
        for component in ('data', 'physics', 'reg'):
            if any(component in epoch for epoch in values):
                flat[f'{key}_{component}'] = [
                    epoch.get(component, 0.0) for epoch in values
                ]
    return flat


def plot_training_curves(
    history: Dict[str, List[float]],
    save_path: Optional[str] = None,
    title: str = "Training Curves",
    figsize: Tuple[int, int] = (12, 5),
    show: bool = False
) -> None:
    """
    Plot training curves: total loss, plus loss components when available.

    Args:
        history: Dictionary containing training history
                 Keys: 'train_loss', 'val_loss', etc.
                 Can be dict of lists OR dict of list-of-dicts (from Trainer).
                 In the latter form each epoch dict must contain 'total'
                 and may contain 'data', 'physics' and 'reg'.
        save_path: Path to save the figure. Missing parent directories are
            created. A bare filename is saved in the working directory.
        title: Plot title
        figsize: Figure size
        show: Whether to show the plot; otherwise the figure is closed
    """
    history = _flatten_training_history(history)

    # (history key, legend label, line style) for each component series, in
    # data / physics / reg order with train before val
    component_series = [
        (f'{split}_loss_{component}', f'{split_label} {component_label}', style)
        for component, component_label, style in (
            ('data', 'Data', '-'),
            ('physics', 'Physics', '--'),
            ('reg', 'Reg', ':'),
        )
        for split, split_label in (('train', 'Train'), ('val', 'Val'))
    ]
    present_series = [s for s in component_series if s[0] in history]

    # A second panel is only worth drawing when it will contain a curve
    n_plots = 2 if present_series else 1
    fig, axes = plt.subplots(1, n_plots, figsize=figsize, squeeze=False)
    axes = axes[0]

    total_ax = axes[0]
    for key, label in (('train_loss', 'Train'), ('val_loss', 'Validation')):
        if key in history:
            total_ax.plot(history[key], label=label, linewidth=2)

    total_ax.set_xlabel('Epoch', fontsize=12)
    total_ax.set_ylabel('Total Loss', fontsize=12)
    total_ax.set_title('Total Loss')
    # Avoid matplotlib's "No artists with labels" warning on an empty panel
    if total_ax.get_legend_handles_labels()[0]:
        total_ax.legend(loc='best')
    total_ax.grid(True, alpha=0.3)

    if present_series:
        comp_ax = axes[1]
        for key, label, style in present_series:
            comp_ax.plot(history[key], label=label,
                         linewidth=2, linestyle=style)

        comp_ax.set_xlabel('Epoch', fontsize=12)
        comp_ax.set_ylabel('Loss Components', fontsize=12)
        comp_ax.set_title('Loss Components')
        comp_ax.legend(loc='best')
        comp_ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when one is given
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Training curves saved to {save_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)


def plot_calibration_curve(
    confidence_levels: List[float],
    empirical_coverage: List[float],
    save_path: Optional[str] = None,
    title: str = "Calibration Curve",
    figsize: Tuple[int, int] = (8, 8),
    show: bool = False
) -> None:
    """
    Plot calibration curve comparing theoretical vs empirical coverage.

    The dashed diagonal marks perfect calibration. Each model point is
    annotated with its empirical coverage: above the point when it exceeds
    the nominal level, below it otherwise.

    Args:
        confidence_levels: List of theoretical confidence levels
        empirical_coverage: List of empirical coverage probabilities, one per
            confidence level
        save_path: Path to save the figure. Missing parent directories are
            created. A bare filename is saved in the working directory.
        title: Plot title
        figsize: Figure size
        show: Whether to show the plot; otherwise the figure is closed

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(confidence_levels) != len(empirical_coverage):
        raise ValueError(
            "confidence_levels and empirical_coverage must have the same "
            f"length, got {len(confidence_levels)} and "
            f"{len(empirical_coverage)}"
        )

    fig, ax = plt.subplots(figsize=figsize)

    ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Perfect Calibration')
    ax.plot(confidence_levels, empirical_coverage, 'bo-',
            linewidth=2, markersize=8, label='Model')

    # Annotate each point on the side away from the diagonal
    for conf, emp in zip(confidence_levels, empirical_coverage):
        above = emp > conf
        offset = 0.02 if above else -0.02
        ax.text(conf, emp + offset, f'{emp:.3f}',
                ha='center', va='bottom' if above else 'top', fontsize=9)

    ax.set_xlabel('Theoretical Coverage Probability', fontsize=12)
    ax.set_ylabel('Empirical Coverage Probability', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_aspect('equal')

    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when one is given
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)


def plot_multiple_predictions(
    predictions: List[Dict[str, torch.Tensor]],
    save_path: Optional[str] = None,
    title: str = "Multiple Predictions",
    figsize: Optional[Tuple[int, int]] = None,
    show: bool = False,
    max_plots: int = 6
) -> None:
    """
    Plot multiple prediction samples in a grid of at most three columns.

    Each panel shows output dimension 0 of one sample: the ground truth, the
    predicted mean, a 90% interval (``mean ± z * std`` with the two-sided
    90% normal z-score), and the context points. The legend is drawn on the
    first panel only.

    Args:
        predictions: List of dictionaries containing:
                    - 'x_context', 'y_context'
                    - 'x_target', 'y_target'
                    - 'y_pred_mean', 'y_pred_std'
        save_path: Path to save the figure. Missing parent directories are
            created. A bare filename is saved in the working directory.
        title: Plot title
        figsize: Figure size (auto-calculated if None)
        show: Whether to show the plot; otherwise the figure is closed
        max_plots: Maximum number of plots to show

    Raises:
        ValueError: If there is nothing to plot (empty ``predictions`` or
            ``max_plots < 1``).
    """
    n_plots = min(len(predictions), max_plots)
    # Previously this fell through to a ZeroDivisionError in the grid maths
    if n_plots < 1:
        raise ValueError(
            "plot_multiple_predictions needs at least one prediction "
            "and max_plots >= 1"
        )

    from scipy import stats  # local import: only needed for the interval
    z_score = stats.norm.ppf(0.95)  # two-sided 90% interval

    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols  # ceiling division

    if figsize is None:
        figsize = (5 * n_cols, 4 * n_rows)

    # squeeze=False always returns a 2D [n_rows, n_cols] array of Axes
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    def _to_2d_numpy(tensor: torch.Tensor) -> np.ndarray:
        # Detach from autograd, move to host, and promote [K] to [K, 1]
        array = tensor.detach().cpu().numpy()
        return array[:, None] if array.ndim == 1 else array

    for idx in range(n_plots):
        ax = axes[idx // n_cols, idx % n_cols]
        pred = predictions[idx]

        x_context = _to_2d_numpy(pred['x_context'])
        y_context = _to_2d_numpy(pred['y_context'])
        x_target = _to_2d_numpy(pred['x_target'])
        y_target = _to_2d_numpy(pred['y_target'])
        y_pred_mean = _to_2d_numpy(pred['y_pred_mean'])
        y_pred_std = _to_2d_numpy(pred['y_pred_std'])

        # Sort so lines are drawn left to right
        sort_idx = np.argsort(x_target[:, 0])
        x_sorted = x_target[sort_idx, 0]
        y_mean_sorted = y_pred_mean[sort_idx, 0]
        y_std_sorted = y_pred_std[sort_idx, 0]
        y_true_sorted = y_target[sort_idx, 0]

        ax.plot(x_sorted, y_true_sorted, 'k-', linewidth=2,
                label='Truth', alpha=0.7)
        ax.plot(x_sorted, y_mean_sorted, 'b-', linewidth=2,
                label='Prediction')
        ax.fill_between(x_sorted,
                        y_mean_sorted - z_score * y_std_sorted,
                        y_mean_sorted + z_score * y_std_sorted,
                        alpha=0.3, color='lightblue', label='90% CI')

        ax.scatter(x_context[:, 0], y_context[:, 0],
                   c='red', s=30, marker='o',
                   label='Context', zorder=5, alpha=0.8)

        ax.set_title(f'Sample {idx + 1}', fontsize=10)
        ax.grid(True, alpha=0.3)

        if idx == 0:
            ax.legend(loc='best', fontsize=8)

    # Hide grid cells left over in the last row
    for unused_ax in axes.ravel()[n_plots:]:
        unused_ax.axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when one is given
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)


def plot_error_distribution(
    errors: torch.Tensor,
    save_path: Optional[str] = None,
    title: str = "Prediction Error Distribution",
    figsize: Tuple[int, int] = (10, 5),
    show: bool = False
) -> None:
    """
    Plot the distribution of prediction errors: a histogram and a Q-Q plot.

    The left panel shows a density-normalised histogram. When the sample
    standard deviation is positive and finite, it is overlaid with a normal
    density using the sample mean and standard deviation. The right panel is
    a normal Q-Q plot.

    Args:
        errors: Prediction errors [N] (any shape is flattened); a torch
            tensor or any array-like accepted by ``np.asarray``
        save_path: Path to save the figure. Missing parent directories are
            created. A bare filename is saved in the working directory.
        title: Plot title
        figsize: Figure size
        show: Whether to show the plot; otherwise the figure is closed

    Raises:
        ValueError: If ``errors`` is empty.
    """
    if isinstance(errors, torch.Tensor):
        errors = errors.detach().cpu().numpy()
    errors = np.asarray(errors).ravel()
    if errors.size == 0:
        raise ValueError("plot_error_distribution needs at least one error value")

    from scipy import stats  # local import: normal pdf and Q-Q plot

    mu, std = errors.mean(), errors.std()

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    axes[0].hist(errors, bins=50, density=True, alpha=0.7,
                 color='blue', edgecolor='black')

    # A zero (or non-finite) spread would divide by zero in the normal pdf
    if np.isfinite(std) and std > 0:
        x = np.linspace(errors.min(), errors.max(), 100)
        axes[0].plot(x, stats.norm.pdf(x, loc=mu, scale=std),
                     'r-', linewidth=2, label=f'N({mu:.3f}, {std:.3f}²)')

    axes[0].set_xlabel('Prediction Error', fontsize=12)
    axes[0].set_ylabel('Density', fontsize=12)
    axes[0].set_title('Error Distribution')
    # Only the fitted curve carries a label; skip legend when it is absent
    if axes[0].get_legend_handles_labels()[0]:
        axes[0].legend(loc='best')
    axes[0].grid(True, alpha=0.3)

    stats.probplot(errors, dist="norm", plot=axes[1])
    axes[1].set_title('Q-Q Plot')
    axes[1].grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when one is given
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)