import matplotlib.pyplot as plt
import logging as logger
import os
import numpy as np
import pandas as pd
import seaborn as sns

def _plot_task_error_bars(ax, task_avg_metrics, overall_metrics):
    """Draw grouped per-task MSE/RMSE bars plus optional overall reference lines."""
    tasks = list(task_avg_metrics.keys())
    x = np.arange(len(tasks))
    width = 0.35

    ax.bar(x - width / 2, [task_avg_metrics[t]['mse'] for t in tasks], width,
           label='MSE', color='skyblue')
    ax.bar(x + width / 2, [task_avg_metrics[t]['rmse'] for t in tasks], width,
           label='RMSE', color='lightcoral')

    # `overall_metrics` defaults to {} in plot_results, so reference lines are
    # drawn only for the keys that are actually present.
    if 'mse' in overall_metrics:
        ax.axhline(y=overall_metrics['mse'], color='b', linestyle='--', alpha=0.5,
                   label=f'Overall MSE: {overall_metrics["mse"]:.4f}')
    if 'rmse' in overall_metrics:
        ax.axhline(y=overall_metrics['rmse'], color='r', linestyle='--', alpha=0.5,
                   label=f'Overall RMSE: {overall_metrics["rmse"]:.4f}')

    ax.set_xlabel('Task ID')
    ax.set_ylabel('Error')
    ax.set_title('Average MSE and RMSE by Task')
    ax.set_xticks(x)
    ax.set_xticklabels(tasks)
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)


def _plot_task_r2(ax, task_avg_metrics, overall_metrics):
    """Draw per-task average R² bars plus an optional overall R² reference line."""
    tasks = list(task_avg_metrics.keys())
    r2_values = [task_avg_metrics[t]['r2'] for t in tasks]
    x = np.arange(len(tasks))

    ax.bar(x, r2_values, color='lightgreen')
    shown_values = list(r2_values)
    if 'r2' in overall_metrics:
        ax.axhline(y=overall_metrics['r2'], color='r', linestyle='--',
                   label=f'Overall R²: {overall_metrics["r2"]:.4f}')
        shown_values.append(overall_metrics['r2'])
        ax.legend()

    ax.set_xlabel('Task ID')
    ax.set_ylabel('R² Score')
    ax.set_title('Average R² Score by Task')
    # Positional x with id labels, matching the MSE/RMSE panel.
    ax.set_xticks(x)
    ax.set_xticklabels(tasks)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # R² is unbounded below (worse-than-mean predictions score < 0); a fixed
    # (0, 1.1) range would hide those bars, so extend the lower limit when needed.
    lowest = min(shown_values, default=0.0)
    ax.set_ylim(lowest - 0.1 if lowest < 0 else 0, 1.1)


def _plot_client_metrics(ax, client_avg_metrics):
    """Draw per-client MSE bars (left axis) and R² line (twin right axis)."""
    clients = list(client_avg_metrics.keys())
    x = np.arange(len(clients))
    client_mse = [client_avg_metrics[c]['mse'] for c in clients]
    client_r2 = [client_avg_metrics[c]['r2'] for c in clients]

    # Secondary y-axis so R² and MSE (different scales) share one x-axis.
    ax_r2 = ax.twinx()
    bars = ax.bar(x, client_mse, color='lightblue', alpha=0.7, label='MSE')
    r2_lines = ax_r2.plot(x, client_r2, 'ro-', linewidth=2, markersize=8, label='R²')

    ax.set_xlabel('Client ID')
    ax.set_ylabel('MSE', color='b')
    ax_r2.set_ylabel('R² Score', color='r')
    ax.set_title('Average Metrics by Client')
    ax.set_xticks(x)
    ax.set_xticklabels(clients)
    ax.tick_params(axis='y', labelcolor='b')
    ax_r2.tick_params(axis='y', labelcolor='r')
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # The bars and the line live on different axes; build one combined legend.
    ax.legend([bars] + r2_lines, ['MSE', 'R²'], loc='best')


def _plot_client_task_heatmap(ax, client_task_metrics):
    """Draw a client x task heatmap of R² scores; leaves `ax` empty when there is no data."""
    rows = [
        {'Client': client_id, 'Task': task_id, 'R²': metrics['r2']}
        for client_id, task_metrics in client_task_metrics.items()
        for task_id, metrics in task_metrics.items()
    ]
    if not rows:
        return

    pivot_df = pd.DataFrame(rows).pivot(index='Client', columns='Task', values='R²')
    sns.heatmap(pivot_df, annot=True, cmap='viridis', fmt='.3f',
                cbar_kws={'label': 'R² Score'}, ax=ax)
    ax.set_title('Client-Task R² Score Heatmap')


def _plot_error_distribution(client_task_metrics, save_path):
    """Save a boxplot + histogram of MSE/MAE over all client-task pairs to `save_path`."""
    all_mse_values = []
    all_mae_values = []
    for task_metrics in client_task_metrics.values():
        for metrics in task_metrics.values():
            all_mse_values.append(metrics['mse'])
            all_mae_values.append(metrics['mae'])

    fig = plt.figure(figsize=(12, 6))
    try:
        if all_mse_values:
            plt.subplot(1, 2, 1)
            plt.boxplot([all_mse_values, all_mae_values])
            # Label the boxes via xticks: boxplot's `labels=` keyword was renamed
            # to `tick_labels` in Matplotlib 3.9, so avoid either spelling.
            plt.xticks([1, 2], ['MSE', 'MAE'])
            plt.ylabel('Error Value')
            plt.title('Error Distribution Across All Client-Task Pairs')
            plt.grid(axis='y', linestyle='--', alpha=0.7)

            plt.subplot(1, 2, 2)
            plt.hist(all_mse_values, bins=20, alpha=0.5, label='MSE', density=True)
            plt.hist(all_mae_values, bins=20, alpha=0.5, label='MAE', density=True)
            plt.xlabel('Error Value')
            plt.ylabel('Density')
            plt.title('Error Distribution Histogram')
            plt.legend()
            plt.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)


def plot_results(accuracy_data, output_dir):
    """
    Create and save visualizations of regression results.

    Writes two PNGs into ``<output_dir>/plots``:
    ``regression_metrics.png`` (per-task MSE/RMSE, per-task R², per-client
    MSE/R², client x task R² heatmap) and ``error_distribution.png``
    (MSE/MAE boxplot and histogram over all client-task pairs).

    Args:
        accuracy_data: Dictionary with optional keys ``'client_task_metrics'``
            ({client: {task: metrics}}), ``'task_avg_metrics'`` ({task: metrics}),
            ``'client_avg_metrics'`` ({client: metrics}) and ``'overall_metrics'``
            (metrics). Per-entry metric dicts are read for ``'mse'``, ``'rmse'``,
            ``'r2'`` and (client-task entries only) ``'mae'``.
        output_dir: Directory under which the ``plots`` folder is created.

    Returns:
        The path of the ``plots`` directory.
    """
    plots_dir = os.path.join(output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    client_task_metrics = accuracy_data.get('client_task_metrics', {})
    task_avg_metrics = accuracy_data.get('task_avg_metrics', {})
    client_avg_metrics = accuracy_data.get('client_avg_metrics', {})
    overall_metrics = accuracy_data.get('overall_metrics', {})

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    try:
        ax_errors, ax_r2, ax_clients, ax_heatmap = axes.flatten()
        _plot_task_error_bars(ax_errors, task_avg_metrics, overall_metrics)
        _plot_task_r2(ax_r2, task_avg_metrics, overall_metrics)
        _plot_client_metrics(ax_clients, client_avg_metrics)
        _plot_client_task_heatmap(ax_heatmap, client_task_metrics)

        fig.tight_layout()
        fig.savefig(os.path.join(plots_dir, 'regression_metrics.png'),
                    bbox_inches='tight', dpi=300)
    finally:
        # Close this specific figure even if plotting fails, to avoid leaking it.
        plt.close(fig)

    _plot_error_distribution(client_task_metrics,
                             os.path.join(plots_dir, 'error_distribution.png'))

    logger.info(f"Saved regression visualization plots to {plots_dir}")
    return plots_dir


def plot_training_curve(opt, round_metrics, round_labels, metric='r2'):
    """
    Create and save a plot of the training curve showing one regression metric per round.

    The figure is written to ``<opt.output_dir>/training_curve_<metric>.png``.

    Args:
        opt: Options object; only ``opt.output_dir`` is read (created if missing).
        round_metrics: Sequence of metric values per round (R², MSE, ...).
        round_labels: Sequence of round labels, one per entry of ``round_metrics``.
            The text before the first comma is treated as the task name; a change
            in it between consecutive rounds is drawn as a task boundary.
        metric: Which metric to plot ('r2', 'mse', 'mae', 'rmse'). Unknown values
            fall back to the R² labels/format (a warning is logged).

    Returns:
        None. Nothing is written when ``round_metrics`` is empty.
    """
    # np.argmax on an empty sequence raises; there is nothing meaningful to draw.
    if len(round_metrics) == 0:
        logger.warning(f"No {metric} values to plot; skipping training curve")
        return

    # Configure based on metric type
    metric_configs = {
        'r2': {
            'ylabel': 'R² Score',
            'title': 'GFedCL Training Curve: R² Score vs. Communication Round',
            'format': '{:.4f}',
            'better': 'higher'
        },
        'mse': {
            'ylabel': 'Mean Squared Error',
            'title': 'GFedCL Training Curve: MSE vs. Communication Round',
            'format': '{:.6f}',
            'better': 'lower'
        },
        'mae': {
            'ylabel': 'Mean Absolute Error',
            'title': 'GFedCL Training Curve: MAE vs. Communication Round',
            'format': '{:.6f}',
            'better': 'lower'
        },
        'rmse': {
            'ylabel': 'Root Mean Squared Error',
            'title': 'GFedCL Training Curve: RMSE vs. Communication Round',
            'format': '{:.6f}',
            'better': 'lower'
        }
    }
    if metric not in metric_configs:
        logger.warning(f"Unknown metric '{metric}'; using R² plot settings")
    config = metric_configs.get(metric, metric_configs['r2'])
    fmt = config['format']

    os.makedirs(opt.output_dir, exist_ok=True)
    save_path = os.path.join(opt.output_dir, f'training_curve_{metric}.png')

    fig = plt.figure(figsize=(12, 6))
    try:
        # Rounds are drawn at 1-based x positions.
        plt.plot(range(1, len(round_metrics) + 1), round_metrics, 'o-',
                 linewidth=2, markersize=8)

        plt.grid(True, linestyle='--', alpha=0.7)
        plt.xlabel('Communication Round', fontsize=12)
        plt.ylabel(config['ylabel'], fontsize=12)
        plt.title(config['title'], fontsize=14)

        # Show at most ~20 tick labels. `len // 20` alone yields step 1 for
        # 21-39 rounds (i.e. no thinning), so use the same `// 20 + 1` rule as
        # the other plots in this module.
        step = len(round_labels) // 20 + 1 if len(round_labels) > 20 else 1
        tick_indices = range(0, len(round_labels), step)
        plt.xticks([i + 1 for i in tick_indices],
                   [round_labels[i] for i in tick_indices], rotation=45)

        # Annotate roughly ten points so the values stay readable.
        annotation_step = max(1, len(round_metrics) // 10)
        for i in range(0, len(round_metrics), annotation_step):
            plt.annotate(fmt.format(round_metrics[i]),
                         (i + 1, round_metrics[i]),
                         textcoords="offset points",
                         xytext=(0, 10),
                         ha='center',
                         fontsize=8)

        # A change in the label prefix (before the first comma) marks a new task;
        # the boundary sits halfway between the two 1-based round positions.
        previous_task = round_labels[0].split(',')[0] if round_labels else ""
        for i, label in enumerate(round_labels):
            task = label.split(',')[0]
            if task != previous_task:
                plt.axvline(x=i + 0.5, color='r', linestyle='--', alpha=0.5)
                previous_task = task

        if config['better'] == 'higher':
            best_idx = np.argmax(round_metrics)
            worst_idx = np.argmin(round_metrics)
        else:
            best_idx = np.argmin(round_metrics)
            worst_idx = np.argmax(round_metrics)

        plt.plot(best_idx + 1, round_metrics[best_idx], 'g*', markersize=15,
                 label=f'Best: {fmt.format(round_metrics[best_idx])}')
        plt.plot(worst_idx + 1, round_metrics[worst_idx], 'rx', markersize=12,
                 label=f'Worst: {fmt.format(round_metrics[worst_idx])}')

        plt.legend()
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if a plotting call raises.
        plt.close(fig)

    logger.info(f"Saved {metric} training curve plot to {save_path}")


def _task_index_from_round_label(round_label):
    """Return the 0-based task index from a 'Task <n>, ...' label, or None if it cannot be parsed."""
    try:
        return int(str(round_label).split(',')[0].split(' ')[1]) - 1
    except (IndexError, ValueError):
        return None


def plot_all_tasks_accuracy(opt, all_tasks_metrics, plots_dir=None):
    """
    Plot per-task R² and MSE across training rounds and save the figure.

    Also calls ``plot_regression_forgetting`` when there is more than one task
    and at least one round of data.

    Args:
        opt: Options object; ``opt.num_task`` and (when ``plots_dir`` is None)
            ``opt.output_dir`` are read.
        all_tasks_metrics: List of per-round dicts with ``'round'`` (a label;
            task boundaries are detected from labels starting "Task <n>,"),
            ``'tasks'`` ({task_index: R² * 100}) and optionally ``'tasks_mse'``
            ({task_index: MSE}).
        plots_dir: Output directory; defaults to ``<opt.output_dir>/plots``.

    Returns:
        Path of the saved ``all_tasks_regression_metrics.png``.
    """
    if plots_dir is None:
        plots_dir = os.path.join(opt.output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    num_tasks = opt.num_task
    rounds = []
    # Per-task series aligned with `rounds`; None where a task was not evaluated.
    task_r2_data = {t: [] for t in range(num_tasks)}
    task_mse_data = {t: [] for t in range(num_tasks)}
    mse_estimated = False

    for round_data in all_tasks_metrics:
        rounds.append(round_data['round'])
        tasks_r2 = round_data['tasks']
        tasks_mse = round_data.get('tasks_mse') or {}

        for t in range(num_tasks):
            if t not in tasks_r2:
                task_r2_data[t].append(None)
                task_mse_data[t].append(None)
                continue
            # Stored value is R² * 100.
            r2 = tasks_r2[t] / 100.0
            task_r2_data[t].append(r2)
            if t in tasks_mse:
                task_mse_data[t].append(tasks_mse[t])
            else:
                # Rough proxy: MSE == 1 - R² only when the target variance is 1.
                task_mse_data[t].append(1.0 - r2)
                mse_estimated = True

    line_styles = ['-', '--', '-.', ':', '-', '--', '-.']
    markers = ['o', 's', '^', 'D', 'x', '+', '*']
    colors = plt.cm.tab10.colors

    file_path = os.path.join(plots_dir, 'all_tasks_regression_metrics.png')
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
    try:
        for ax, series in ((ax1, task_r2_data), (ax2, task_mse_data)):
            for t in range(num_tasks):
                valid_indices = [i for i, val in enumerate(series[t]) if val is not None]
                if not valid_indices:
                    continue
                ax.plot(
                    valid_indices,
                    [series[t][i] for i in valid_indices],
                    label=f'Task {t+1}',
                    linestyle=line_styles[t % len(line_styles)],
                    marker=markers[t % len(markers)],
                    color=colors[t % len(colors)],
                    linewidth=2,
                    markersize=8
                )

        ax1.set_ylabel('R² Score', fontsize=14)
        ax1.set_title('R² Score on All Tasks Throughout Training', fontsize=16)
        ax1.grid(True, linestyle='--', alpha=0.7)
        if ax1.get_lines():
            ax1.legend(fontsize=12, loc='lower right')
        ax1.set_ylim(-0.1, 1.1)

        ax2.set_xlabel('Training Round', fontsize=14)
        # Only flag the axis as estimated when the 1 - R² proxy was actually used.
        ax2.set_ylabel('MSE (Estimated)' if mse_estimated else 'MSE', fontsize=14)
        ax2.set_title('MSE on All Tasks Throughout Training', fontsize=16)
        ax2.grid(True, linestyle='--', alpha=0.7)
        if ax2.get_lines():
            ax2.legend(fontsize=12, loc='upper right')

        # Task boundaries: a vertical line before the first round whose label
        # names a later task. Unparseable labels are ignored instead of raising.
        current_task = 0
        for i, round_label in enumerate(rounds):
            task = _task_index_from_round_label(round_label)
            if task is not None and task > current_task:
                for ax in (ax1, ax2):
                    ax.axvline(x=i - 0.5, color='gray', linestyle='--', alpha=0.7)
                current_task = task

        # Thin x tick labels to at most ~20; set on the bottom axis (x is shared).
        skip = len(rounds) // 20 + 1 if len(rounds) > 20 else 1
        tick_positions = list(range(0, len(rounds), skip))
        ax2.set_xticks(tick_positions)
        ax2.set_xticklabels([rounds[i] for i in tick_positions],
                            rotation=45, ha='right', fontsize=10)

        fig.tight_layout()
        fig.savefig(file_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    logger.info(f"Saved all-tasks regression metrics visualization to {file_path}")

    if num_tasks > 1 and all_tasks_metrics:
        plot_regression_forgetting(opt, all_tasks_metrics, plots_dir)

    return file_path


def plot_regression_forgetting(opt, all_tasks_metrics, plots_dir):
    """
    Plot how much R² each earlier task loses while later tasks are trained.

    For every pair (previous task p, current task c > p), the baseline is the
    best R² observed on p in any round before the first round whose
    ``'current_task'`` equals c. Each plotted value is
    ``max(0, baseline - R²_p)`` for a round of task c in which p was evaluated.
    Pairs with no evaluation of p before task c started have no baseline and
    are skipped. (Previously the baseline was floored at 0, which was wrong for
    negative R² and invented a baseline when none existed.)

    Args:
        opt: Options object; only ``opt.num_task`` is read.
        all_tasks_metrics: List of per-round dicts with ``'round'`` (label),
            ``'current_task'`` (0-based index of the task being trained) and
            ``'tasks'`` ({task_index: R² * 100}).
        plots_dir: Existing directory where the PNG is written.

    Returns:
        Path of the saved ``regression_performance_degradation.png``.
    """
    num_tasks = opt.num_task
    file_path = os.path.join(plots_dir, 'regression_performance_degradation.png')

    fig = plt.figure(figsize=(12, 8))
    try:
        for current_task in range(1, num_tasks):
            current_task_rounds = [
                i for i, round_data in enumerate(all_tasks_metrics)
                if round_data['current_task'] == current_task
            ]
            if not current_task_rounds:
                continue
            # Rounds completed before training on `current_task` began.
            history = all_tasks_metrics[:current_task_rounds[0]]

            for prev_task in range(current_task):
                # Stored values are R² * 100; convert back to R².
                prior_r2 = [
                    round_data['tasks'][prev_task] / 100.0
                    for round_data in history
                    if prev_task in round_data['tasks']
                ]
                if not prior_r2:
                    continue
                best_prev_r2 = max(prior_r2)

                round_indices = []
                degradation_data = []
                for i in current_task_rounds:
                    tasks_r2 = all_tasks_metrics[i]['tasks']
                    if prev_task in tasks_r2:
                        current_r2 = tasks_r2[prev_task] / 100.0
                        round_indices.append(i)
                        # Improvements are clamped to 0: only losses are shown.
                        degradation_data.append(max(0, best_prev_r2 - current_r2))

                if degradation_data:
                    plt.plot(
                        round_indices,
                        degradation_data,
                        label=f'Task {prev_task+1} → Task {current_task+1}',
                        marker='o',
                        linewidth=2
                    )

        plt.xlabel('Training Round', fontsize=14)
        plt.ylabel('R² Degradation', fontsize=14)
        plt.title('Performance Degradation (R² Loss) Throughout Training', fontsize=16)
        plt.grid(True, linestyle='--', alpha=0.7)

        # Show the legend whenever any line exists; with two tasks there is a
        # single transition curve, which otherwise would be unlabelled.
        if plt.gca().get_lines():
            plt.legend(fontsize=12, loc='upper left')

        # x positions are indices into all_tasks_metrics; thin labels to ~20.
        round_labels = [d['round'] for d in all_tasks_metrics]
        skip = len(round_labels) // 20 + 1 if len(round_labels) > 20 else 1
        tick_positions = list(range(0, len(round_labels), skip))
        plt.xticks(
            tick_positions,
            [round_labels[i] for i in tick_positions],
            rotation=45,
            ha='right',
            fontsize=10
        )

        plt.tight_layout()
        plt.savefig(file_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    logger.info(f"Saved regression performance degradation visualization to {file_path}")

    return file_path