"""Visualization utilities for ARCOS experiments."""

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import os

# Apply a global plot style and colour palette for every figure made by this module.
# 'seaborn-v0_8' is the matplotlib >= 3.6 name of the seaborn style; older
# matplotlib releases only ship it as 'seaborn', and plt.style.use raises
# OSError for unknown style names, so fall back instead of failing at import.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')
sns.set_palette("husl")


class ARCOSPlotter:
    """Plotter for ARCOS experiment results.

    Each ``plot_*`` method draws one figure from a metrics DataFrame, saves it
    as ``<output_dir>/<save_name>.png`` at ``self.dpi`` and closes the figure.
    ``create_summary_table`` writes CSV and HTML tables to the same directory.
    """

    def __init__(self, output_dir: str = "./outputs"):
        """Initialize plotter.

        Args:
            output_dir: Output directory for plots; created (including parent
                directories) if it does not exist.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Figure size (inches) for single-axes plots, and the DPI used both to
        # create figures and to save them.
        self.figsize = (10, 6)
        self.dpi = 300

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _save_figure(self, fig, save_name: str) -> Path:
        """Save ``fig`` as ``<output_dir>/<save_name>.png`` and close it.

        The figure is closed even if saving fails, so repeated calls cannot
        accumulate open pyplot figures.

        Args:
            fig: Matplotlib figure to save.
            save_name: File name without extension.

        Returns:
            Path of the written PNG file.
        """
        save_path = self.output_dir / f"{save_name}.png"
        try:
            fig.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
        finally:
            plt.close(fig)
        print(f"Saved plot: {save_path}")
        return save_path

    @staticmethod
    def _plot_lines_by_policy(
        ax,
        metrics_df: pd.DataFrame,
        x_column: str,
        y_column: str,
        policy_column: str
    ) -> None:
        """Draw one marker line of ``y_column`` against ``x_column`` per policy.

        Policies are drawn in order of first appearance. Each policy's rows are
        sorted by ``x_column`` before plotting.
        """
        for policy in metrics_df[policy_column].unique():
            policy_data = metrics_df[metrics_df[policy_column] == policy]
            policy_data = policy_data.sort_values(x_column)
            ax.plot(
                policy_data[x_column],
                policy_data[y_column],
                marker='o',
                linewidth=2,
                markersize=6,
                label=policy
            )

    @staticmethod
    def _final_round_rows(
        metrics_df: pd.DataFrame,
        budget_column: str,
        policy_column: str,
        round_column: str
    ) -> List[Tuple[object, object, object, pd.Series]]:
        """Select the final-round row of every (budget, policy) combination.

        Budgets and policies are visited in order of first appearance;
        combinations without rows are skipped. The final round is the maximum
        of ``round_column`` within the combination, and the first row having
        that round is selected.

        Returns:
            List of ``(budget, policy, final_round, row)`` tuples.
        """
        results = []
        policies = metrics_df[policy_column].unique()
        for budget in metrics_df[budget_column].unique():
            for policy in policies:
                group = metrics_df[
                    (metrics_df[budget_column] == budget) &
                    (metrics_df[policy_column] == policy)
                ]
                if group.empty:
                    continue
                final_round = group[round_column].max()
                row = group[group[round_column] == final_round].iloc[0]
                results.append((budget, policy, final_round, row))
        return results

    # ------------------------------------------------------------------ #
    # Public plotting API
    # ------------------------------------------------------------------ #

    def plot_metrics_over_rounds(
        self,
        metrics_df: pd.DataFrame,
        metric_name: str,
        policy_column: str = "policy",
        round_column: str = "round",
        save_name: Optional[str] = None,
        title: Optional[str] = None,
        ylabel: Optional[str] = None
    ) -> None:
        """Plot metric over rounds for different policies.

        Args:
            metrics_df: DataFrame with metrics
            metric_name: Name of metric to plot
            policy_column: Column name for policy
            round_column: Column name for round
            save_name: Name to save plot (without extension); defaults to
                ``curves_<metric_name lower-cased, spaces -> underscores>``
            title: Plot title (defaults to ``"<metric_name> Over Rounds"``)
            ylabel: Y-axis label (defaults to ``metric_name``)
        """
        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
        self._plot_lines_by_policy(
            ax, metrics_df, round_column, metric_name, policy_column
        )

        ax.set_xlabel('Round', fontsize=12)
        ax.set_ylabel(ylabel or metric_name, fontsize=12)
        ax.set_title(title or f'{metric_name} Over Rounds', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        if save_name is None:
            save_name = f"curves_{metric_name.lower().replace(' ', '_')}"
        self._save_figure(fig, save_name)

    def plot_cost_vs_gain(
        self,
        metrics_df: pd.DataFrame,
        cost_column: str = "labeled_count",
        gain_column: str = "delta_R",
        policy_column: str = "policy",
        save_name: str = "cost_vs_gain"
    ) -> None:
        """Plot cost vs gain curve (one line per policy, sorted by cost).

        Args:
            metrics_df: DataFrame with metrics
            cost_column: Column name for cost (e.g., number of labels)
            gain_column: Column name for gain (e.g., risk reduction)
            policy_column: Column name for policy
            save_name: Name to save plot (without extension)
        """
        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
        self._plot_lines_by_policy(
            ax, metrics_df, cost_column, gain_column, policy_column
        )

        ax.set_xlabel('Number of Labels', fontsize=12)
        ax.set_ylabel('Risk Reduction (|ΔR|)', fontsize=12)
        ax.set_title('Cost vs Gain: Labels vs Risk Reduction', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        self._save_figure(fig, save_name)

    def plot_budget_comparison(
        self,
        metrics_df: pd.DataFrame,
        budget_column: str = "budget",
        metric_name: str = "delta_R",
        policy_column: str = "policy",
        save_name: str = "budget_comparison",
        round_column: str = "round"
    ) -> None:
        """Plot final-round metric as grouped bars per budget (one bar per policy).

        Budgets are shown in ascending order. A policy without data for a given
        budget gets an empty slot instead of shifting its other bars.

        Args:
            metrics_df: DataFrame with metrics
            budget_column: Column name for budget
            metric_name: Name of metric to compare
            policy_column: Column name for policy
            save_name: Name to save plot (without extension)
            round_column: Column name for round (used to pick the final round)
        """
        final_rows = self._final_round_rows(
            metrics_df, budget_column, policy_column, round_column
        )
        if not final_rows:
            print(f"Warning: no data to compare across budgets, skipping {save_name}.")
            return

        final_df = pd.DataFrame([
            {budget_column: budget, policy_column: policy, metric_name: row[metric_name]}
            for budget, policy, _, row in final_rows
        ])

        budgets = sorted(final_df[budget_column].unique())
        policies = final_df[policy_column].unique()
        x = np.arange(len(budgets))
        # Total width of one budget group; gives 0.35 per bar for two policies.
        group_width = 0.7
        width = group_width / len(policies)

        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
        for i, policy in enumerate(policies):
            # Align this policy's values to the shared budget axis; budgets the
            # policy lacks become NaN (no bar) instead of a length mismatch.
            values = (
                final_df[final_df[policy_column] == policy]
                .set_index(budget_column)[metric_name]
                .reindex(budgets)
            )
            # Centre the group of bars on each budget tick.
            offset = (i - (len(policies) - 1) / 2) * width
            ax.bar(
                x + offset,
                values.to_numpy(dtype=float),
                width,
                label=policy,
                alpha=0.8
            )

        ax.set_xlabel('Budget (%)', fontsize=12)
        ax.set_ylabel(metric_name, fontsize=12)
        ax.set_title(f'{metric_name} Comparison Across Budgets', fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(budgets)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        self._save_figure(fig, save_name)

    def plot_trace_components(
        self,
        metrics_df: pd.DataFrame,
        policy_column: str = "policy",
        round_column: str = "round",
        save_name: str = "trace_components"
    ) -> None:
        """Plot TRACE bound components over rounds on a 2x2 grid.

        A missing policy column is replaced by a single "default" policy and a
        missing round column by the row position. A missing component column
        produces a placeholder panel.

        Args:
            metrics_df: DataFrame with metrics
            policy_column: Column name for policy
            round_column: Column name for round
            save_name: Name to save plot (without extension)
        """
        if policy_column not in metrics_df.columns:
            print(f"Warning: {policy_column} column not found in metrics. Using single policy.")
            # Work on a copy so the caller's DataFrame is not modified.
            metrics_df = metrics_df.copy()
            metrics_df[policy_column] = "default"

        if round_column not in metrics_df.columns:
            print(f"Warning: {round_column} column not found in metrics. Using index as rounds.")
            metrics_df = metrics_df.copy()
            metrics_df[round_column] = range(len(metrics_df))

        components = ['W1', 'output_discrepancy', 'Lx', 'bound_proxy']
        # Single figure for all panels (no stray extra figure left open).
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        for ax, component in zip(axes.ravel(), components):
            if component not in metrics_df.columns:
                print(f"Warning: {component} column not found, skipping plot.")
                ax.text(0.5, 0.5, f'{component} data not available',
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{component} Over Rounds', fontsize=12)
                continue

            self._plot_lines_by_policy(
                ax, metrics_df, round_column, component, policy_column
            )
            ax.set_xlabel('Round', fontsize=11)
            ax.set_ylabel(component, fontsize=11)
            ax.set_title(f'{component} Over Rounds', fontsize=12)
            ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        self._save_figure(fig, save_name)

    def plot_training_curves(
        self,
        train_history: List[Dict[str, float]],
        val_history: Optional[List[Dict[str, float]]] = None,
        save_name: str = "training_curves"
    ) -> None:
        """Plot training (and optionally validation) loss/accuracy per epoch.

        Only metrics present in the first epoch's dict are plotted. Each history
        uses its own epoch axis, so train and validation may differ in length.

        Args:
            train_history: List of training metrics per epoch
            val_history: List of validation metrics per epoch
            save_name: Name to save plot (without extension)
        """
        if not train_history:
            print("Warning: train_history is empty, skipping training curves.")
            return

        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)

        series = [('Train', train_history, 'o')]
        if val_history:
            series.append(('Val', val_history, 's'))

        for prefix, history, marker in series:
            epochs = range(1, len(history) + 1)
            for metric in ['loss', 'accuracy']:
                if metric in history[0]:
                    values = [epoch[metric] for epoch in history]
                    ax.plot(epochs, values, marker=marker,
                            label=f'{prefix} {metric}', linewidth=2)

        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Metric Value', fontsize=12)
        ax.set_title('Training Curves', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        self._save_figure(fig, save_name)

    def create_summary_table(
        self,
        metrics_df: pd.DataFrame,
        save_name: str = "summary_table"
    ) -> None:
        """Create summary table of final-round metrics per budget and policy.

        Writes ``<save_name>.csv`` and a UTF-8 ``<save_name>.html``. Requires
        the ``budget``, ``policy`` and ``round`` columns; if any is missing a
        warning is printed and nothing is written. Metric columns that are
        absent are reported as ``"N/A"``.

        Args:
            metrics_df: DataFrame with metrics
            save_name: Name to save table (without extension)
        """
        required = ['budget', 'policy', 'round']
        missing = [column for column in required if column not in metrics_df.columns]
        if missing:
            print(f"Warning: columns {missing} not found in metrics, skipping summary table.")
            return

        # (source column in metrics_df, header in the summary table)
        metric_columns = [
            ('delta_R', 'Final |ΔR|'),
            ('W1', 'Final W1'),
            ('output_discrepancy', 'Final Output Discrepancy'),
            ('Lx', 'Final Lx'),
            ('bound_proxy', 'Final Bound Proxy'),
        ]
        headers = ['Budget (%)', 'Policy', 'Final Round'] + [h for _, h in metric_columns]

        final_metrics = []
        for budget, policy, final_round, row in self._final_round_rows(
            metrics_df, 'budget', 'policy', 'round'
        ):
            entry = {'Budget (%)': budget, 'Policy': policy, 'Final Round': final_round}
            for column, header in metric_columns:
                entry[header] = f"{row[column]:.4f}" if column in row.index else "N/A"
            final_metrics.append(entry)

        # Explicit columns keep the header row even when there are no rows.
        summary_df = pd.DataFrame(final_metrics, columns=headers)

        # Save as CSV
        csv_path = self.output_dir / f"{save_name}.csv"
        summary_df.to_csv(csv_path, index=False)
        print(f"Saved summary table: {csv_path}")

        # Create HTML table
        html_path = self.output_dir / f"{save_name}.html"
        html_table = summary_df.to_html(index=False, classes='table table-striped')

        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>ARCOS Summary Table</title>
            <style>
                .table {{
                    border-collapse: collapse;
                    width: 100%;
                    margin: 20px 0;
                }}
                .table th, .table td {{
                    border: 1px solid #ddd;
                    padding: 8px;
                    text-align: left;
                }}
                .table th {{
                    background-color: #f2f2f2;
                    font-weight: bold;
                }}
                .table tr:nth-child(even) {{
                    background-color: #f9f9f9;
                }}
                .table tr:hover {{
                    background-color: #f5f5f5;
                }}
            </style>
        </head>
        <body>
            <h1>ARCOS Experiment Summary</h1>
            {html_table}
        </body>
        </html>
        """

        # Headers contain non-ASCII ('Δ'); never rely on the platform encoding.
        html_path.write_text(html_content, encoding='utf-8')

        print(f"Saved HTML table: {html_path}")

    def plot_all_curves(self, metrics_df: pd.DataFrame) -> None:
        """Plot all standard curves for ARCOS metrics and write the summary table.

        Plots whose required columns are absent from ``metrics_df`` are skipped.

        Args:
            metrics_df: DataFrame with metrics
        """
        print("Creating all ARCOS curves...")
        columns = set(metrics_df.columns)

        # Plot individual metrics
        metrics_to_plot = [
            ('delta_R', '|ΔR|'),
            ('W1', 'W1 Distance'),
            ('output_discrepancy', 'Output Discrepancy'),
            ('bound_proxy', 'Bound Proxy')
        ]

        for metric, label in metrics_to_plot:
            if metric in columns:
                self.plot_metrics_over_rounds(
                    metrics_df, metric, ylabel=label,
                    save_name=f"curves_{metric}"
                )

        # Cost vs gain uses the default cost, gain and policy columns.
        if {'labeled_count', 'delta_R', 'policy'} <= columns:
            self.plot_cost_vs_gain(metrics_df)

        # Budget comparison uses the default budget, policy, round and metric columns.
        if {'budget', 'policy', 'round', 'delta_R'} <= columns:
            self.plot_budget_comparison(metrics_df)

        # Missing policy/round/component columns are handled inside.
        self.plot_trace_components(metrics_df)

        # Prints a warning and writes nothing if budget/policy/round are missing.
        self.create_summary_table(metrics_df)

        print("All curves created successfully!")


def create_plots_from_csv(
    csv_path: str,
    output_dir: str = "./outputs"
) -> None:
    """Render every standard ARCOS plot and summary table from a metrics CSV.

    The CSV is read before the plotter is constructed, so ``output_dir`` is
    only created once the metrics have been loaded successfully.

    Args:
        csv_path: Path to CSV file with metrics
        output_dir: Output directory for plots
    """
    frame = pd.read_csv(csv_path)
    ARCOSPlotter(output_dir).plot_all_curves(frame)
    print(f"Plots created from {csv_path} and saved to {output_dir}")


if __name__ == "__main__":
    # Command-line entry point: create_plots_from_csv(<csv_path>, [--output-dir DIR])
    import argparse

    cli = argparse.ArgumentParser(description="Create ARCOS plots from CSV")
    cli.add_argument("csv_path", help="Path to CSV metrics file")
    cli.add_argument(
        "--output-dir",
        default="./outputs",
        help="Output directory for plots",
    )
    options = cli.parse_args()
    create_plots_from_csv(options.csv_path, options.output_dir)

