"""
Visualization utilities for medical image segmentation results
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from typing import List, Tuple, Dict, Optional
import os
from sklearn.metrics import confusion_matrix
import pandas as pd

# Set style
# The 'seaborn-v0_8-*' style names were introduced in matplotlib 3.6; older
# releases ship the same style sheet under the name 'seaborn-paper'. Fall back
# to the legacy name so that importing this module does not fail there.
try:
    plt.style.use('seaborn-v0_8-paper')
except OSError:
    plt.style.use('seaborn-paper')
sns.set_palette("husl")

class SegmentationVisualizer:
    """Visualization utilities for segmentation results

    Masks are handled as one channel per class. Tensors passed to the public
    methods are expected channel-first, i.e. (C, H, W); they are converted to
    channel-last (H, W, C) NumPy arrays before being handed to the private
    colouring helpers, which unpack their input as (H, W, C).
    """
    
    def __init__(self, num_classes: int = 5, class_names: Optional[List[str]] = None):
        """
        Args:
            num_classes: Number of segmentation classes
            class_names: Display name per class; defaults to 'Class 0', 'Class 1', ...
        """
        self.num_classes = num_classes
        self.class_names = class_names or [f'Class {i}' for i in range(num_classes)]
        
        # Define colors for each class
        self.colors = [
            [1.0, 0.0, 0.0],  # Red - Heart
            [0.0, 1.0, 0.0],  # Green - Liver
            [0.0, 0.0, 1.0],  # Blue - Kidney
            [1.0, 1.0, 0.0],  # Yellow - Lung
            [1.0, 0.0, 1.0]   # Magenta - Brain
        ]
    
    def _color_for(self, class_idx: int) -> List[float]:
        """Return the RGB color of a class, cycling through the palette.

        Cycling keeps masks with more channels than palette entries from
        raising IndexError; colors repeat after the last palette entry.
        """
        return self.colors[class_idx % len(self.colors)]
    
    @staticmethod
    def _prepare_image(image: torch.Tensor) -> np.ndarray:
        """Convert a (C, H, W) tensor into an (H, W, C) array scaled to [0, 1].

        A constant image maps to all zeros instead of dividing by zero, and a
        single-channel image is repeated to three channels so that it can be
        blended with RGB class colors.
        """
        image_np = image.detach().permute(1, 2, 0).cpu().numpy()
        
        lowest, highest = image_np.min(), image_np.max()
        if highest > lowest:
            image_np = (image_np - lowest) / (highest - lowest)
        else:
            # Zero dynamic range: min-max scaling would yield NaN
            image_np = np.zeros_like(image_np, dtype=np.float32)
        
        if image_np.shape[-1] == 1:
            image_np = np.repeat(image_np, 3, axis=-1)
        return image_np
    
    @staticmethod
    def _mask_to_hwc(mask: torch.Tensor) -> np.ndarray:
        """Convert a (C, H, W) mask tensor into a float (H, W, C) array."""
        return mask.detach().float().permute(1, 2, 0).cpu().numpy()
    
    def _flatten_class_last(self, tensor: torch.Tensor) -> np.ndarray:
        """Flatten a per-class tensor into a (num_pixels, num_classes) array.

        A 3-D input is treated as a single sample. If dimension 1 then has
        size ``num_classes`` the tensor is taken to be channel-first
        ((N, C, ...)) and the class axis is moved to the end, so that every
        output row holds the class values of one pixel. Any other layout is
        flattened as it is, which is correct for channel-last input.
        """
        tensor = tensor.detach()
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        if tensor.dim() >= 3 and tensor.shape[1] == self.num_classes:
            tensor = tensor.movedim(1, -1)
        # reshape (not view) because the moved tensor is no longer contiguous
        return tensor.reshape(-1, self.num_classes).cpu().numpy()
    
    @staticmethod
    def _save_figure(fig: plt.Figure, save_path: Optional[str]) -> None:
        """Save ``fig`` to ``save_path`` when a path is given."""
        if save_path:
            # Save the figure object itself rather than pyplot's current figure
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
    
    @staticmethod
    def _label_bars(ax, bars, values) -> None:
        """Write each value just above its bar."""
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')
    
    def visualize_prediction(self, 
                           image: torch.Tensor, 
                           prediction: torch.Tensor, 
                           target: torch.Tensor,
                           threshold: float = 0.5,
                           save_path: Optional[str] = None) -> plt.Figure:
        """
        Visualize segmentation prediction vs ground truth
        
        Args:
            image: (C, H, W) - input image
            prediction: (C, H, W) - predicted mask
            target: (C, H, W) - ground truth mask
            threshold: Threshold for binary conversion
            save_path: Path to save figure
        Returns:
            fig: Matplotlib figure
        """
        # Convert to channel-last numpy arrays, the layout the helpers unpack
        image_np = self._prepare_image(image)
        pred_np = self._mask_to_hwc(prediction > threshold)
        target_np = self._mask_to_hwc(target)
        
        # Panels in row-major order of the 2x3 grid
        panels = [
            ('Original Image', image_np),
            ('Ground Truth', self._create_colored_mask(target_np)),
            ('Prediction', self._create_colored_mask(pred_np)),
            ('GT Overlay', self._overlay_mask(image_np, target_np, alpha=0.6)),
            ('Prediction Overlay', self._overlay_mask(image_np, pred_np, alpha=0.6)),
            ('Difference', self._create_colored_mask(np.abs(pred_np - target_np))),
        ]
        
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        for ax, (title, picture) in zip(axes.flat, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')
        
        fig.tight_layout()
        self._save_figure(fig, save_path)
        
        return fig
    
    def _create_colored_mask(self, mask: np.ndarray) -> np.ndarray:
        """Create colored mask from multi-class segmentation

        Args:
            mask: (H, W, C) array with one channel per class
        Returns:
            (H, W, 3) RGB array clipped to [0, 1]; overlapping classes add up
        """
        h, w, c = mask.shape
        colored_mask = np.zeros((h, w, 3))
        
        for class_idx in range(c):
            class_mask = mask[:, :, class_idx]
            if class_mask.max() > 0:
                colored_mask += class_mask[..., np.newaxis] * np.asarray(self._color_for(class_idx))
        
        return np.clip(colored_mask, 0, 1)
    
    def _overlay_mask(self, image: np.ndarray, mask: np.ndarray, alpha: float = 0.6) -> np.ndarray:
        """Overlay mask on image

        Each class color is blended into the image with weight
        ``alpha * mask``, so pixels outside a class mask keep their original
        value instead of being darkened once per class.

        Args:
            image: (H, W, 3) image; it is copied, not modified
            mask: (H, W, C) array with one channel per class
            alpha: Opacity of the class color where the mask equals 1
        Returns:
            Float array with the shape of ``image``
        """
        h, w, c = mask.shape
        overlay = image.astype(float)
        
        for class_idx in range(c):
            class_mask = mask[:, :, class_idx]
            if class_mask.max() > 0:
                weight = alpha * np.clip(class_mask, 0.0, 1.0)[..., np.newaxis]
                color = np.asarray(self._color_for(class_idx))
                # Only the first three (RGB) channels are tinted
                overlay[..., :3] = (1.0 - weight) * overlay[..., :3] + weight * color
        
        return overlay
    
    def plot_training_curves(self, history: Dict[str, List[float]], save_path: Optional[str] = None) -> plt.Figure:
        """Plot training curves

        Args:
            history: Per-epoch values; must contain the keys 'train_loss',
                'val_loss', 'val_dice', 'val_iou', 'val_hausdorff',
                'val_boundary_f1' and 'learning_rate' (KeyError otherwise)
            save_path: Path to save figure
        Returns:
            fig: Matplotlib figure
        """
        # (title, y label, [(history key, legend label, color or None)]) per panel
        panels = [
            ('Training and Validation Loss', 'Loss',
             [('train_loss', 'Train Loss', None), ('val_loss', 'Val Loss', None)]),
            ('Validation Dice Score', 'Dice Score',
             [('val_dice', 'Val Dice', 'green')]),
            ('Validation IoU Score', 'IoU Score',
             [('val_iou', 'Val IoU', 'blue')]),
            ('Validation Hausdorff Distance', 'Hausdorff Distance',
             [('val_hausdorff', 'Val Hausdorff', 'red')]),
            ('Validation Boundary F1 Score', 'Boundary F1 Score',
             [('val_boundary_f1', 'Val Boundary F1', 'purple')]),
            ('Learning Rate Schedule', 'Learning Rate',
             [('learning_rate', 'Learning Rate', 'orange')]),
        ]
        
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        
        for ax, (title, ylabel, curves) in zip(axes.flat, panels):
            for key, label, color in curves:
                line_style = {'linewidth': 2}
                if color is not None:
                    # Curves without a color use matplotlib's default cycle
                    line_style['color'] = color
                ax.plot(history[key], label=label, **line_style)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)
        
        fig.tight_layout()
        self._save_figure(fig, save_path)
        
        return fig
    
    def plot_metrics_comparison(self, metrics_dict: Dict[str, Dict[str, float]], save_path: Optional[str] = None) -> plt.Figure:
        """Plot metrics comparison between different models

        Args:
            metrics_dict: Maps model name to its metrics; each entry must
                contain 'mean_dice', 'mean_iou', 'mean_hausdorff',
                'mean_boundary_f1' and 'pixel_accuracy'
            save_path: Path to save figure
        Returns:
            fig: Matplotlib figure
        """
        metrics = ['mean_dice', 'mean_iou', 'mean_hausdorff', 'mean_boundary_f1', 'pixel_accuracy']
        
        # Long-format table: one row per (model, metric) pair
        rows = []
        for model, scores in metrics_dict.items():
            for metric in metrics:
                rows.append({
                    'Model': model,
                    'Metric': metric,
                    'Value': scores[metric]
                })
        df = pd.DataFrame(rows)
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        flat_axes = axes.flatten()
        
        # Plot each metric
        for ax, metric in zip(flat_axes, metrics):
            metric_data = df[df['Metric'] == metric]
            sns.barplot(data=metric_data, x='Model', y='Value', ax=ax)
            ax.set_title(f'{metric.replace("_", " ").title()}')
            ax.tick_params(axis='x', rotation=45)
            ax.grid(True, alpha=0.3)
        
        # Remove empty subplot
        for ax in flat_axes[len(metrics):]:
            ax.remove()
        
        fig.tight_layout()
        self._save_figure(fig, save_path)
        
        return fig
    
    def plot_confusion_matrix(self, 
                            predictions: torch.Tensor, 
                            targets: torch.Tensor,
                            threshold: float = 0.5,
                            save_path: Optional[str] = None) -> plt.Figure:
        """Plot confusion matrix for segmentation

        Predictions are thresholded, then every pixel is assigned the first
        class whose channel holds the row maximum (argmax). A pixel with no
        class above the threshold is therefore counted as class 0.

        Args:
            predictions: Per-class scores, channel-first ((C, H, W) or
                (N, C, H, W)); channel-last input is flattened unchanged
            targets: Per-class ground truth in the same layout
            threshold: Threshold for binary conversion
            save_path: Path to save figure
        Returns:
            fig: Matplotlib figure
        """
        # Convert to binary and flatten to one row of class values per pixel
        pred_flat = self._flatten_class_last((predictions > threshold).float())
        target_flat = self._flatten_class_last(targets.float())
        
        # Convert to class labels
        pred_labels = np.argmax(pred_flat, axis=1)
        target_labels = np.argmax(target_flat, axis=1)
        
        # Compute confusion matrix
        cm = confusion_matrix(target_labels, pred_labels, labels=list(range(self.num_classes)))
        
        # Plot
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                   xticklabels=self.class_names, yticklabels=self.class_names, ax=ax)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        
        self._save_figure(fig, save_path)
        
        return fig
    
    def plot_attention_maps(self, 
                          image: torch.Tensor,
                          attention_maps: List[torch.Tensor],
                          save_path: Optional[str] = None) -> plt.Figure:
        """Plot attention maps for visualization

        Args:
            image: (C, H, W) - input image
            attention_maps: Attention maps; a 3-D map is averaged over its
                first axis before plotting
            save_path: Path to save figure
        Returns:
            fig: Matplotlib figure
        """
        image_np = self._prepare_image(image)
        
        num_maps = len(attention_maps)
        # One panel for the image plus one per map, laid out on two rows
        num_panels = num_maps + 1
        num_cols = (num_panels + 1) // 2
        # squeeze=False always yields a 2-D axes array, so flatten() is safe
        fig, axes = plt.subplots(2, num_cols, figsize=(15, 8), squeeze=False)
        axes = axes.flatten()
        
        # Original image
        axes[0].imshow(image_np)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        # Attention maps
        for i, att_map in enumerate(attention_maps):
            att_np = att_map.detach().cpu().numpy()
            if att_np.ndim == 3:
                att_np = att_np.mean(axis=0)
            
            ax = axes[i + 1]
            im = ax.imshow(att_np, cmap='hot')
            ax.set_title(f'Attention Map {i + 1}')
            ax.axis('off')
            fig.colorbar(im, ax=ax)
        
        # Remove empty subplots
        for ax in axes[num_panels:]:
            ax.remove()
        
        fig.tight_layout()
        self._save_figure(fig, save_path)
        
        return fig
    
    def create_results_summary(self, 
                             results: Dict[str, float],
                             save_path: Optional[str] = None) -> plt.Figure:
        """Create results summary visualization

        Args:
            results: Must contain 'mean_dice', 'mean_iou', 'mean_hausdorff',
                'mean_boundary_f1', 'pixel_accuracy' and 'dice_class_<i>' for
                every class index. If 'baseline_dice' is present,
                'baseline_iou' and 'baseline_hausdorff' are read as well and
                a comparison panel is drawn.
            save_path: Path to save figure
        Returns:
            fig: Matplotlib figure
        """
        # Create figure
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Main metrics
        main_metrics = ['mean_dice', 'mean_iou', 'mean_hausdorff', 'mean_boundary_f1']
        main_values = [results[metric] for metric in main_metrics]
        main_labels = [metric.replace('_', ' ').title() for metric in main_metrics]
        
        # Bar plot of main metrics
        bars = axes[0, 0].bar(main_labels, main_values, color=['green', 'blue', 'red', 'purple'])
        axes[0, 0].set_title('Main Performance Metrics')
        axes[0, 0].set_ylabel('Score')
        axes[0, 0].tick_params(axis='x', rotation=45)
        self._label_bars(axes[0, 0], bars, main_values)
        
        # Per-class dice scores
        class_dice = [results[f'dice_class_{i}'] for i in range(self.num_classes)]
        class_colors = [self._color_for(i) for i in range(self.num_classes)]
        bars = axes[0, 1].bar(self.class_names, class_dice, color=class_colors)
        axes[0, 1].set_title('Per-Class Dice Scores')
        axes[0, 1].set_ylabel('Dice Score')
        axes[0, 1].tick_params(axis='x', rotation=45)
        self._label_bars(axes[0, 1], bars, class_dice)
        
        # Performance comparison (if available)
        if 'baseline_dice' in results:
            comparison_metrics = ['Dice Score', 'IoU Score', 'Hausdorff Distance']
            msa_values = [results['mean_dice'], results['mean_iou'], results['mean_hausdorff']]
            baseline_values = [results['baseline_dice'], results['baseline_iou'], results['baseline_hausdorff']]
            
            x = np.arange(len(comparison_metrics))
            width = 0.35
            
            axes[1, 0].bar(x - width/2, msa_values, width, label='MSA-UNet', color='skyblue')
            axes[1, 0].bar(x + width/2, baseline_values, width, label='Baseline', color='lightcoral')
            axes[1, 0].set_title('Model Comparison')
            axes[1, 0].set_ylabel('Score')
            axes[1, 0].set_xticks(x)
            axes[1, 0].set_xticklabels(comparison_metrics)
            axes[1, 0].legend()
        
        # Summary statistics
        summary_text = f"""
        Model Performance Summary:
        
        Mean Dice Score: {results['mean_dice']:.4f}
        Mean IoU Score: {results['mean_iou']:.4f}
        Mean Hausdorff Distance: {results['mean_hausdorff']:.4f}
        Mean Boundary F1: {results['mean_boundary_f1']:.4f}
        Pixel Accuracy: {results['pixel_accuracy']:.4f}
        
        Best Performing Class: {self.class_names[np.argmax(class_dice)]}
        Worst Performing Class: {self.class_names[np.argmin(class_dice)]}
        """
        
        axes[1, 1].text(0.1, 0.5, summary_text, transform=axes[1, 1].transAxes,
                        fontsize=10, verticalalignment='center',
                        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
        axes[1, 1].set_title('Summary Statistics')
        axes[1, 1].axis('off')
        
        fig.tight_layout()
        self._save_figure(fig, save_path)
        
        return fig

def create_visualization_report(results_dir: str, 
                              model_name: str = 'MSA-UNet',
                              num_classes: int = 5) -> None:
    """Create comprehensive visualization report

    Reads ``metrics.json`` from ``results_dir``, renders the results summary
    and saves it as ``results_summary.png`` in the same directory. If the
    metrics file does not exist, a message is printed and nothing is written.

    Args:
        results_dir: Directory holding metrics.json; the figure is saved here
        model_name: Not used by this function; kept so existing callers work
        num_classes: Number of segmentation classes
    """
    # json is not among this module's top-level imports, so import it here
    import json
    
    # Load results
    results_file = os.path.join(results_dir, 'metrics.json')
    if not os.path.exists(results_file):
        print(f"No metrics file found at {results_file}; skipping visualization report")
        return
    
    with open(results_file, 'r', encoding='utf-8') as f:
        results = json.load(f)
    
    # Create visualizer; the name list is sized to num_classes so that the
    # number of names always matches the number of per-class scores
    organ_names = ['Heart', 'Liver', 'Kidney', 'Lung', 'Brain']
    class_names = [
        organ_names[i] if i < len(organ_names) else f'Class {i}'
        for i in range(num_classes)
    ]
    visualizer = SegmentationVisualizer(num_classes, class_names)
    
    # Create results summary
    summary_fig = visualizer.create_results_summary(results)
    summary_fig.savefig(os.path.join(results_dir, 'results_summary.png'), 
                       dpi=300, bbox_inches='tight')
    plt.close(summary_fig)
    
    print(f"Visualization report created in {results_dir}")

if __name__ == "__main__":
    # Smoke test: run the visualizer on random tensors and a toy history
    organ_names = ['Heart', 'Liver', 'Kidney', 'Lung', 'Brain']
    visualizer = SegmentationVisualizer(5, organ_names)
    
    # Dummy sample: 3-channel image, 5-channel scores and binary targets
    side = 128
    image = torch.randn(3, side, side)
    prediction = torch.randn(5, side, side)
    target = torch.randint(0, 2, (5, side, side)).float()
    
    # Prediction vs ground-truth figure
    fig = visualizer.visualize_prediction(image, prediction, target)
    plt.show()
    
    # Training-curve figure from five made-up epochs
    history = dict(
        train_loss=[1.0, 0.8, 0.6, 0.4, 0.3],
        val_loss=[1.1, 0.9, 0.7, 0.5, 0.4],
        val_dice=[0.6, 0.7, 0.8, 0.85, 0.88],
        val_iou=[0.5, 0.6, 0.7, 0.75, 0.8],
        val_hausdorff=[10.0, 8.0, 6.0, 5.0, 4.0],
        val_boundary_f1=[0.6, 0.7, 0.8, 0.85, 0.88],
        learning_rate=[0.001, 0.001, 0.001, 0.0001, 0.0001],
    )
    
    fig = visualizer.plot_training_curves(history)
    plt.show()
    
    print("Visualization tests completed successfully!")

