"""
Analysis script for multi-modal baseline methods results.

Reads all baseline results from 040_multimodal_param_sim_baselines.py
and creates heatmap visualizations showing classification accuracies
for each method across different datasets, subspaces, and label types.
"""

import numpy as np
import pandas as pd
import os
import glob
import re
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# Paths (relative to the directory the script is launched from)
RESULTS_DIR = "03_results/reports"
OUTPUT_DIR = "03_results/processed"
PLOTS_DIR = "03_results/plots"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# Find all baseline result files
baseline_files = glob.glob(os.path.join(RESULTS_DIR, "baselines_mm-parametric*.csv"))
print(f"Found {len(baseline_files)} baseline result files")

# Find all FiGURO result files
figuro_files = glob.glob(os.path.join(RESULTS_DIR, "larrp_mm-parametric-sim4*.csv"))
print(f"Found {len(figuro_files)} FiGURO result files")

# Nothing to analyse: stop with a non-zero exit status.
if not baseline_files and not figuro_files:
    print("No results found!")
    # The bare `exit()` helper is injected by the `site` module for interactive
    # use and is missing under `python -S` or in frozen/embedded interpreters;
    # raising SystemExit is the portable equivalent.
    raise SystemExit(1)

# Read and combine all results: one DataFrame per input file is collected here
all_results = []

for fpath in baseline_files:
    # Run metadata is encoded in the filename, e.g.
    #   baselines_mm-parametric_n-10000_rseed-0_paired-True_v-small.csv
    fname = os.path.basename(fpath)

    version_hit = re.search(r'v-([^.]+)\.csv', fname)
    count_hit = re.search(r'n-(\d+)', fname)
    seed_hit = re.search(r'rseed-(\d+)', fname)
    paired_hit = re.search(r'paired-(True|False)', fname)

    # Load the per-method rows and tag each of them with the file's metadata;
    # a filename part that cannot be found becomes "unknown" / None.
    frame = pd.read_csv(fpath)
    frame['dataset'] = version_hit.group(1) if version_hit else "unknown"
    frame['n_samples'] = int(count_hit.group(1)) if count_hit else None
    frame['seed'] = int(seed_hit.group(1)) if seed_hit else None
    frame['paired'] = (paired_hit.group(1) == 'True') if paired_hit else None
    frame['filepath'] = fpath

    all_results.append(frame)

# Process FiGURO result files: each file contributes a single row shaped like
# the baseline rows so that both can be concatenated below.
for fpath in figuro_files:
    # Run metadata is encoded in the filename, e.g.
    #   larrp_mm-parametric-sim4_n-10000_rseed-42_paired-True_v-small.csv
    fname = os.path.basename(fpath)

    version_hit = re.search(r'v-([^.]+)\.csv', fname)
    dataset_name = version_hit.group(1) if version_hit else "unknown"

    count_hit = re.search(r'n-(\d+)', fname)
    n_samples = int(count_hit.group(1)) if count_hit else None

    seed_hit = re.search(r'rseed-(\d+)', fname)
    seed = int(seed_hit.group(1)) if seed_hit else None

    paired_hit = re.search(r'paired-(True|False)', fname)
    paired = (paired_hit.group(1) == 'True') if paired_hit else None

    # Keep only the rows recorded for r_square_threshold = 0.05
    runs = pd.read_csv(fpath)
    runs = runs[runs['r_square_threshold'] == 0.05]
    if runs.empty:
        print(f"Warning: No data with r_square_threshold=0.05 in {fname}")
        continue

    # The last row for this threshold is taken as the final result
    final_row = runs.iloc[-1]

    # Each of these cells is a comma-separated triple ordered as
    # (shared, specific_1, specific_2)
    final_ranks = [int(tok.strip()) for tok in final_row['final_ranks'].split(',')]
    class_accs = [float(tok.strip()) for tok in final_row['classification_accuracy'].split(',')]
    label1_preds = [float(tok.strip()) for tok in final_row['label_1_pred'].split(',')]
    label2_preds = [float(tok.strip()) for tok in final_row['label_2_pred'].split(',')]

    # FiGURO has 3 spaces while the baselines report 4 subspaces. The shared
    # space is therefore emitted twice (as joint_X and joint_Y), followed by
    # Specific 1 (individual_X) and Specific 2 (individual_Y).
    slot_source = (0, 0, 1, 2)
    classification_accuracy = ", ".join(str(class_accs[k]) for k in slot_source)
    label_1_pred = ", ".join(str(label1_preds[k]) for k in slot_source)
    label_2_pred = ", ".join(str(label2_preds[k]) for k in slot_source)

    figuro_record = {
        'method': 'FiGURO',
        'joint_rank': final_ranks[0],  # Shared rank
        'individual_rank_X': final_ranks[1],  # Specific 1 rank
        'individual_rank_Y': final_ranks[2],  # Specific 2 rank
        'subspace_names': "joint_X, joint_Y, individual_X, individual_Y",
        'classification_accuracy': classification_accuracy,
        'silhouette_score': final_row.get('silhouette_score', ''),
        'label_1_pred': label_1_pred,
        'label_2_pred': label_2_pred,
        'dataset': dataset_name,
        'n_samples': n_samples,
        'seed': seed,
        'paired': paired,
        'filepath': fpath
    }

    all_results.append(pd.DataFrame([figuro_record]))

# Combine all results (baselines + FiGURO) into one table and summarise it
combined_df = pd.concat(all_results, ignore_index=True)
print(f"\nTotal rows: {len(combined_df)}")
print(f"Methods: {combined_df['method'].unique()}")
print(f"Datasets: {combined_df['dataset'].unique()}")

# Split a row's comma-separated accuracy string into one value per subspace
def parse_accuracies(row):
    """Map each subspace name of *row* to its classification accuracy.

    ``row['subspace_names']`` and ``row['classification_accuracy']`` are
    comma-separated strings; entries are paired by position after stripping
    surrounding whitespace. If the two lists differ in length, the surplus
    entries of the longer one are ignored.

    Returns a ``pd.Series`` indexed by subspace name with float values.

    NOTE(review): no call to this function appears in the visible part of
    the script.
    """
    raw_accuracies = row['classification_accuracy']
    raw_names = row['subspace_names']

    names = [token.strip() for token in raw_names.split(',')]
    values = [float(token.strip()) for token in raw_accuracies.split(',')]

    return pd.Series(dict(zip(names, values)))

# `combined_df` is built once, and its row/method/dataset summary printed once,
# right after the result files are loaded (just before `parse_accuracies`).
# It is deliberately not rebuilt or re-printed here.

# Expand a result row's comma-separated metric strings into per-(subspace, label) values
def parse_metrics_by_subspace_and_label(row):
    """
    Expand the per-subspace metric strings of one result row.

    Reads four comma-separated string fields from *row*: 'subspace_names',
    'classification_accuracy' (stored under 'label_0'), 'label_1_pred'
    ('label_1', R²) and 'label_2_pred' ('label_2', R²). Entries are matched
    to subspaces by position.

    Returns a dict mapping (subspace, label) to the float value. An
    IndexError is raised if a metric string has fewer entries than there
    are subspaces.
    """
    raw_names = row['subspace_names']
    raw_label0 = row['classification_accuracy']
    raw_label1 = row['label_1_pred']
    raw_label2 = row['label_2_pred']

    def _floats(text):
        return [float(token.strip()) for token in text.split(',')]

    names = [token.strip() for token in raw_names.split(',')]
    per_label = (
        ('label_0', _floats(raw_label0)),
        ('label_1', _floats(raw_label1)),
        ('label_2', _floats(raw_label2)),
    )

    results = {}
    for position, name in enumerate(names):
        for label, values in per_label:
            results[(name, label)] = values[position]

    return results

# Long-format table: one record per (result row, subspace, label) triple.
# Entries for the 'joint_Z' subspace are left out.
detailed_results = [
    {
        'dataset': row['dataset'],
        'method': row['method'],
        'n_samples': row['n_samples'],
        'seed': row['seed'],
        'paired': row['paired'],
        'subspace': subspace,
        'label': label,
        'accuracy': accuracy,
        'joint_rank': row['joint_rank'],
        'individual_rank_X': row['individual_rank_X'],
        'individual_rank_Y': row['individual_rank_Y']
    }
    for _, row in combined_df.iterrows()
    for (subspace, label), accuracy in parse_metrics_by_subspace_and_label(row).items()
    if subspace != 'joint_Z'
]

detailed_df = pd.DataFrame(detailed_results)

# Persist the long-format table next to the other processed outputs
detailed_output = os.path.join(OUTPUT_DIR, "baselines_mm_parametric_detailed.csv")
detailed_df.to_csv(detailed_output, index=False)
print(f"\nSaved detailed results to {detailed_output}")

# Canonical display order of the methods (used for rank tables and heatmaps);
# only methods that actually occur in the detailed results are kept.
method_order = ['CCA', 'DIVAS', 'JIVE', 'AJIVE', 'PPD', 'SLIDE', 'ShIndICA', 'FiGURO']
_methods_present = set(detailed_df['method'].unique())
methods = [name for name in method_order if name in _methods_present]

# ----- Save rank summary tables per dataset -----
# One table per dataset with rows joint_rank, individual_rank_X,
# individual_rank_Y and one column per method.
#   * FiGURO cells hold "mean ± SEM" over its runs (just the mean for a single run).
#   * All other methods hold the median rank over their runs, rounded to an integer.
rank_tables_dir = os.path.join(OUTPUT_DIR, 'rank_tables')
os.makedirs(rank_tables_dir, exist_ok=True)

_rank_rows = ['joint_rank', 'individual_rank_X', 'individual_rank_Y']


def _format_mean_sem(values):
    """Format a Series of ranks as 'mean ± SEM' with one decimal each.

    Returns pd.NA for an empty Series and only the mean when there is a
    single value (no SEM can be computed from one observation).
    """
    if len(values) == 0:
        return pd.NA
    mean_val = values.mean()
    if len(values) > 1:
        return f"{mean_val:.1f} ± {values.sem():.1f}"
    return f"{mean_val:.1f}"


def _maybe_int(x):
    """Round *x* to the nearest int; return pd.NA if it is missing or not numeric."""
    try:
        if pd.isna(x):
            return pd.NA
        return int(round(float(x)))
    except (TypeError, ValueError, OverflowError):
        # Best effort: a value that cannot be converted is reported as missing.
        return pd.NA


all_datasets = sorted(combined_df['dataset'].unique())
for dataset in all_datasets:
    table = pd.DataFrame(index=_rank_rows, columns=methods)
    for method in methods:
        sel = combined_df[(combined_df['dataset'] == dataset) & (combined_df['method'] == method)]
        if sel.empty:
            # Method was not run on this dataset
            table.loc[:, method] = pd.NA
            continue

        for rank_name in _rank_rows:
            if method == 'FiGURO':
                table.at[rank_name, method] = _format_mean_sem(sel[rank_name].dropna())
            else:
                median_rank = sel[rank_name].dropna().median() if rank_name in sel.columns else pd.NA
                table.at[rank_name, method] = _maybe_int(median_rank)

    out_csv = os.path.join(rank_tables_dir, f"baselines_mm_parametric_ranks_{dataset}.csv")
    table.to_csv(out_csv)
    print(f"Saved rank table for dataset '{dataset}' -> {out_csv}")

# Stack the per-dataset rank tables (re-read from the CSVs written above) into
# one long-format CSV with columns: rank_name, method, rank, dataset.
combined_tables = []
for dataset in all_datasets:
    per_dataset_csv = os.path.join(rank_tables_dir, f"baselines_mm_parametric_ranks_{dataset}.csv")
    wide = pd.read_csv(per_dataset_csv, index_col=0)
    # wide (rank rows x method columns) -> long (one row per rank/method pair)
    long_form = (
        wide.reset_index()
        .melt(id_vars='index', var_name='method', value_name='rank')
        .rename(columns={'index': 'rank_name'})
    )
    long_form['dataset'] = dataset
    combined_tables.append(long_form)

if combined_tables:
    combined_ranks_df = pd.concat(combined_tables, ignore_index=True)
    combined_ranks_csv = os.path.join(OUTPUT_DIR, 'baselines_mm_parametric_ranks_alldatasets.csv')
    combined_ranks_df.to_csv(combined_ranks_csv, index=False)
    print(f"Saved combined rank table -> {combined_ranks_csv}")


# ----- Heatmap figures: one subplot per method -----
# Rows: subspaces (joint_X, joint_Y, individual_X, individual_Y)
# Columns: labels (label_0, label_1, label_2)
# Cell value: mean of the 'accuracy' column for that (subspace, label) pair,
# annotated as "mean ± SEM" when the method has more than one seed and the
# cell aggregates more than one value.
# Two kinds of figure are written: one pooled over all datasets and one per dataset.

n_methods = len(methods)

# Define consistent subspace / label ordering for every panel
subspace_order = ['joint_X', 'joint_Y', 'individual_X', 'individual_Y']
label_order = ['label_0', 'label_1', 'label_2']

# Fixed 2x4 panel grid (method_order lists 8 methods); 12.6 in is roughly 32 cm wide
_GRID_SHAPE = (2, 4)
_FIGSIZE = (12.6, 6.3)


def _new_method_grid():
    """Create the 2x4 figure and return (fig, flat axes array).

    Panels beyond the number of available methods are switched off so they
    do not show up as empty framed axes in the saved image.
    """
    fig, axes = plt.subplots(*_GRID_SHAPE, figsize=_FIGSIZE)
    axes = axes.flatten()  # Flatten to 1D array for easier indexing
    for spare_ax in axes[n_methods:]:
        spare_ax.axis('off')
    return fig, axes


def _summarize_method(method_data):
    """Return (mean table, annotation array) for one method's detailed rows.

    Both are laid out as subspace_order x label_order. The annotation array
    holds '' for missing cells, 'mean\\n±sem' when an SEM is shown and 'mean'
    otherwise.
    """
    n_unique_seeds = method_data['seed'].nunique()
    grouped = method_data.groupby(['subspace', 'label'])['accuracy']

    def _as_table(series, fill):
        # Pivot to subspace x label and force the shared row/column order
        return series.unstack(fill_value=fill).reindex(index=subspace_order, columns=label_order)

    heatmap_mean = _as_table(grouped.mean(), np.nan)
    heatmap_sem = _as_table(grouped.sem(), np.nan)
    heatmap_count = _as_table(grouped.count(), 0)

    annot_labels = np.empty(heatmap_mean.shape, dtype=object)
    for row_idx, row_name in enumerate(subspace_order):
        for col_idx, col_name in enumerate(label_order):
            mean_val = heatmap_mean.loc[row_name, col_name]
            sem_val = heatmap_sem.loc[row_name, col_name]
            count_val = heatmap_count.loc[row_name, col_name]

            if np.isnan(mean_val):
                annot_labels[row_idx, col_idx] = ''
            # Only show SEM if method has multiple seeds AND count > 1
            elif n_unique_seeds > 1 and count_val > 1 and not np.isnan(sem_val):
                annot_labels[row_idx, col_idx] = f'{mean_val:.2f}\n±{sem_val:.2f}'
            else:
                annot_labels[row_idx, col_idx] = f'{mean_val:.2f}'

    return heatmap_mean, annot_labels


def _draw_method_heatmap(ax, method, method_data):
    """Draw the annotated heatmap of *method_data* for *method* onto *ax*."""
    heatmap_mean, annot_labels = _summarize_method(method_data)

    # Cell annotations use an 8pt font; title, axis labels and ticks use 11pt
    sns.heatmap(heatmap_mean, annot=annot_labels, fmt='', cmap='coolwarm',
                vmin=0, vmax=1, ax=ax, cbar_kws={'label': 'Accuracy / R²'},
                linewidths=0.5, linecolor='gray', annot_kws={'fontsize': 8})

    ax.set_title(f'{method}', fontsize=11, fontweight='bold')
    ax.set_xlabel('Label Type', fontsize=11)
    ax.set_ylabel('Subspace', fontsize=11)
    ax.tick_params(axis='both', which='major', labelsize=11)

    ytick_texts = [t.get_text() for t in ax.get_yticklabels()]
    if method == 'FiGURO':
        # For FiGURO, joint_X and joint_Y both hold the single shared space,
        # so both rows are labelled plainly as 'joint'
        ytick_texts = ['joint' if t in ('joint_X', 'joint_Y') else t for t in ytick_texts]
    ax.set_yticklabels(ytick_texts, rotation=0)
    ax.set_xticklabels([t.get_text() for t in ax.get_xticklabels()], rotation=45, ha='right')


# Figure 1: values pooled over all datasets
fig, axes = _new_method_grid()
for ax, method in zip(axes, methods):
    _draw_method_heatmap(ax, method, detailed_df[detailed_df['method'] == method])

fig.tight_layout()
heatmap_output = os.path.join(PLOTS_DIR, "baselines_mm_parametric_heatmaps.png")
fig.savefig(heatmap_output, dpi=300, bbox_inches='tight')
print(f"\nSaved heatmap figure to {heatmap_output}")
plt.close(fig)

# Figures 2..n: the same layout restricted to one dataset each
datasets = sorted(detailed_df['dataset'].unique())

for dataset in datasets:
    fig, axes = _new_method_grid()

    for ax, method in zip(axes, methods):
        method_data = detailed_df[(detailed_df['method'] == method) &
                                  (detailed_df['dataset'] == dataset)]

        if len(method_data) == 0:
            # Method was not run on this dataset: show a placeholder panel
            ax.text(0.5, 0.5, 'No data', ha='center', va='center',
                    transform=ax.transAxes, fontsize=11)
            ax.set_title(f'{method}', fontsize=11, fontweight='bold')
            ax.axis('off')
            continue

        _draw_method_heatmap(ax, method, method_data)

    # No figure-level title is added; the dataset is identified by the filename
    fig.tight_layout()
    heatmap_output = os.path.join(PLOTS_DIR, f"baselines_mm_parametric_heatmaps_{dataset}.png")
    fig.savefig(heatmap_output, dpi=300, bbox_inches='tight')
    print(f"Saved heatmap for dataset {dataset}")
    plt.close(fig)

print("\nAnalysis complete!")
