"""
Analysis script for additional loss term experiments.

This script reads results from experiments with different additional loss terms:
- L1 regularization (040_multimodal_param_sim_testl1.py)
- Orthogonal loss (040_multimodal_param_sim_ortholoss.py)
- L2 norm regularization (040_multimodal_param_sim_noshared.py)

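For each experiment type, dataset version, and loss-term weight, it computes:
- Mean ± SEM of final ranks (shared, modality 1, modality 2, and total)
- Mean ± SEM of predictability metrics (classification accuracy and label 1/2
  prediction from each subspace)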
"""

import os
import sys
import pandas as pd
import numpy as np
import glob
import re
from pathlib import Path

# Project root: two directory levels above the directory containing this script.
# The path is made absolute *before* walking up. On Python < 3.9 a script run by
# bare name has a relative ``__file__`` (e.g. "script.py"), whose ``.parent``
# chain stops at "." and would resolve to the current working directory instead.
project_root = Path(__file__).absolute().parent.parent.parent

# Put the project root on sys.path so project-local modules (such as
# calculate_mean_sem, imported next) can be found when this file is run directly.
sys.path.append(str(project_root))

from calculate_mean_sem import calculate_mean_sem


def parse_final_ranks(ranks_str):
    """Parse a ``final_ranks`` value into ``(shared, mod1, mod2)`` integer ranks.

    Accepts comma-separated values (``"3, 2, 1"``), optionally wrapped in list or
    tuple brackets (``"[3, 2, 1]"``, ``"(3, 2, 1)"``). If the value contains no
    comma, whitespace is used as the separator instead (``"[3 2 1]"``). Only the
    first three entries are used.

    Args:
        ranks_str: Raw cell value; NaN/None means missing.

    Returns:
        Tuple ``(shared_rank, mod1_rank, mod2_rank)`` of ints, or
        ``(None, None, None)`` if the value is missing or cannot be parsed.
    """
    if pd.isna(ranks_str):
        return None, None, None

    # Brackets must be removed before splitting, otherwise int("[3") fails.
    cleaned = str(ranks_str).strip().strip('[]()').strip()
    parts = cleaned.split(',') if ',' in cleaned else cleaned.split()
    if len(parts) < 3:
        return None, None, None

    try:
        return tuple(int(part.strip()) for part in parts[:3])
    except ValueError:
        return None, None, None


def parse_three_values(values_str):
    """Parse a value holding three numbers into a tuple of three floats.

    Accepts comma-separated values (``"0.5, 0.2, 0.9"``), optionally wrapped in
    list or tuple brackets (``"[0.5, 0.2, 0.9]"``). If the value contains no
    comma, whitespace is used as the separator instead (``"[0.5 0.2 0.9]"``).
    Only the first three entries are used.

    Args:
        values_str: Raw cell value; NaN/None means missing.

    Returns:
        Tuple of three floats, or ``(None, None, None)`` if the value is missing
        or cannot be parsed.
    """
    if pd.isna(values_str):
        return None, None, None

    # Brackets must be removed before splitting, otherwise float("[0.5") fails.
    cleaned = str(values_str).strip().strip('[]()').strip()
    parts = cleaned.split(',') if ',' in cleaned else cleaned.split()
    if len(parts) < 3:
        return None, None, None

    try:
        return tuple(float(part.strip()) for part in parts[:3])
    except ValueError:
        return None, None, None


def extract_metadata_from_filename(filename):
    """Pull run metadata out of a result-file name.

    Recognised tokens (each optional) and the keys they populate:
    ``rseed-<int>`` -> ``seed``, ``_v-<text>`` -> ``data_version``,
    ``_n-<int>`` -> ``n_samples``, ``paired-True|False`` -> ``paired`` (bool),
    ``l2norm-<number>`` followed by ``w`` or ``.csv`` -> ``l2norm`` (float).

    Returns:
        Dict containing only the keys whose token was found in ``filename``.
    """
    # (metadata key, regex with one capture group, converter for that group)
    simple_fields = (
        ('seed', r'rseed-(\d+)', int),
        ('data_version', r'_v-([^_\.]+)', str),
        ('n_samples', r'_n-(\d+)', int),
        ('paired', r'paired-(True|False)', lambda text: text == 'True'),
    )

    metadata = {}
    for key, pattern, convert in simple_fields:
        found = re.search(pattern, filename)
        if found is not None:
            metadata[key] = convert(found.group(1))

    # L2-norm weight used by the noshared experiments; the capture can end with
    # a stray '.', so fall back to parsing it with trailing dots removed.
    weight_found = re.search(r'l2norm-([0-9.]+?)(?:w|\.csv)', filename)
    if weight_found is not None:
        raw_weight = weight_found.group(1)
        try:
            parsed_weight = float(raw_weight)
        except ValueError:
            parsed_weight = float(raw_weight.rstrip('.'))
        metadata['l2norm'] = parsed_weight

    return metadata


def _result_from_row(experiment, metadata, row, weight, paired):
    """Build one combined-results record from a final log row.

    Args:
        experiment: Label stored in the ``experiment`` column.
        metadata: Dict returned by ``extract_metadata_from_filename`` for the file.
        row: Row (pandas Series) holding the final logged values; read via ``.get``.
        weight: Loss-term weight to record for this row.
        paired: Value stored in the ``paired`` column.

    Returns:
        Dict with the combined-results columns, or ``None`` if the row's
        ``final_ranks`` could not be parsed.
    """
    shared, mod1, mod2 = parse_final_ranks(row.get('final_ranks', ''))
    if shared is None:
        return None

    # Predictability metrics: one value per subspace (shared, mod1, mod2)
    acc_shared, acc_mod1, acc_mod2 = parse_three_values(row.get('classification_accuracy', ''))
    pred1_shared, pred1_mod1, pred1_mod2 = parse_three_values(row.get('label_1_pred', ''))
    pred2_shared, pred2_mod1, pred2_mod2 = parse_three_values(row.get('label_2_pred', ''))

    return {
        'experiment': experiment,
        'data_version': metadata.get('data_version', 'unknown'),
        'seed': metadata.get('seed', -1),
        'n_samples': metadata.get('n_samples', -1),
        'paired': paired,
        'weight': weight,
        'config': row.get('config', -1),
        'shared_rank': shared,
        'mod1_rank': mod1,
        'mod2_rank': mod2,
        'total_rank': shared + mod1 + mod2,
        'r_square_threshold': row.get('r_square_threshold', np.nan),
        # Classification accuracy from each subspace
        'acc_shared': acc_shared,
        'acc_mod1': acc_mod1,
        'acc_mod2': acc_mod2,
        # Label 1 prediction from each subspace
        'pred1_shared': pred1_shared,
        'pred1_mod1': pred1_mod1,
        'pred1_mod2': pred1_mod2,
        # Label 2 prediction from each subspace
        'pred2_shared': pred2_shared,
        'pred2_mod1': pred2_mod1,
        'pred2_mod2': pred2_mod2,
    }


def _load_last_rows_per_config(pattern, weight_col, experiment):
    """Load the final result of every (config, weight) pair from CSVs matching ``pattern``.

    Files without a ``weight_col`` or ``final_ranks`` column are skipped. Files
    that fail to load or process are reported and skipped, so one bad file does
    not abort the analysis.

    Returns:
        DataFrame of records built by ``_result_from_row``. It is empty, with no
        columns, when nothing could be parsed.
    """
    group_keys = ['config', weight_col]
    all_data = []
    for file in sorted(glob.glob(pattern)):
        metadata = extract_metadata_from_filename(file)
        try:
            df = pd.read_csv(file)
            if len(df) == 0 or weight_col not in df.columns or 'final_ranks' not in df.columns:
                continue

            # Keep the last *row* logged for each (config, weight) pair.
            # GroupBy.last() would instead take the last non-null value of each
            # column independently, splicing values from different rows together
            # whenever the final row contains NaNs. Rows with a missing key are
            # dropped, matching groupby's default behaviour.
            final_rows = (
                df.dropna(subset=group_keys)
                .drop_duplicates(subset=group_keys, keep='last')
                .sort_values(group_keys)
            )

            for _, row in final_rows.iterrows():
                result = _result_from_row(
                    experiment,
                    metadata,
                    row,
                    weight=row.get(weight_col, np.nan),
                    paired=metadata.get('paired', False),
                )
                if result is not None:
                    all_data.append(result)
        except Exception as e:
            print(f"Error reading {file}: {e}")

    return pd.DataFrame(all_data)


def load_l1_results():
    """Load results from L1 regularization experiments.

    Reads ``03_results/reports/larrp_mm-parametric-sim6-l1*.csv`` (relative to
    the current working directory). It keeps the last logged row for each
    (config, l1_weight) pair.
    """
    return _load_last_rows_per_config(
        "03_results/reports/larrp_mm-parametric-sim6-l1*.csv", 'l1_weight', 'L1'
    )


def load_ortho_results():
    """Load results from orthogonal loss experiments.

    Reads ``03_results/reports/larrp_mm-parametric-sim6-ortho*.csv`` (relative
    to the current working directory). It keeps the last logged row for each
    (config, ortho_weight) pair.
    """
    return _load_last_rows_per_config(
        "03_results/reports/larrp_mm-parametric-sim6-ortho*.csv", 'ortho_weight', 'Ortho'
    )


def load_l2norm_results():
    """Load results from L2 norm regularization experiments (noshared).

    Reads ``03_results/reports/larrp_mm-parametric-sim5*.csv`` (relative to the
    current working directory). The weight comes from the ``l2norm-`` filename
    token, not from a column, and the last row of each file is taken as its
    final result.
    """
    pattern = "03_results/reports/larrp_mm-parametric-sim5*.csv"
    all_data = []
    for file in sorted(glob.glob(pattern)):
        metadata = extract_metadata_from_filename(file)
        try:
            df = pd.read_csv(file)
            if len(df) == 0:
                continue
            result = _result_from_row(
                'L2norm',
                metadata,
                df.iloc[-1],
                weight=metadata.get('l2norm', np.nan),
                paired=False,  # noshared experiments don't have paired parameter
            )
            if result is not None:
                all_data.append(result)
        except Exception as e:
            print(f"Error reading {file}: {e}")

    return pd.DataFrame(all_data)


def _mean_sem_stats(name, series, decimals):
    """Return ``{name}_mean``, ``{name}_sem`` and ``{name}_mean_sem`` for ``series``.

    NaNs are ignored. SEM uses the sample standard deviation (ddof=1) divided
    by sqrt(n). It is reported as 0 for a single value. An empty or all-NaN
    series gives NaN statistics and the string "N/A".
    """
    values = series.dropna().values
    if len(values) == 0:
        return {
            f'{name}_mean': np.nan,
            f'{name}_sem': np.nan,
            f'{name}_mean_sem': "N/A",
        }

    mean = np.mean(values)
    sem = np.std(values, ddof=1) / np.sqrt(len(values)) if len(values) > 1 else 0
    return {
        f'{name}_mean': mean,
        f'{name}_sem': sem,
        f'{name}_mean_sem': f"{mean:.{decimals}f} ± {sem:.{decimals}f}",
    }


def compute_summary_statistics(df, group_cols=None):
    """
    Compute mean ± SEM for final ranks and predictability metrics grouped by specified columns.

    Args:
        df: DataFrame with one row per run. It must contain a ``seed`` column
            and every column in ``group_cols``.
        group_cols: Column name or list of column names to group by. Defaults
            to ``['experiment', 'data_version', 'weight']``. Rows with NaN in
            a grouping column are dropped by ``groupby``.

    Returns:
        DataFrame with one row per group: the group columns, ``n_runs``,
        ``n_seeds`` and, for every rank/metric column present in ``df``, its
        ``_mean``, ``_sem`` and formatted ``_mean_sem`` string (ranks use 2
        decimals, predictability metrics 4). Empty DataFrame if ``df`` is empty.
    """
    # None default avoids a shared mutable default list.
    if group_cols is None:
        group_cols = ['experiment', 'data_version', 'weight']
    elif isinstance(group_cols, str):
        group_cols = [group_cols]
    else:
        group_cols = list(group_cols)

    if len(df) == 0:
        return pd.DataFrame()

    # (column, decimals used in the formatted "mean ± sem" string)
    stat_columns = [
        (col, 2) for col in ('shared_rank', 'mod1_rank', 'mod2_rank', 'total_rank')
    ] + [
        (col, 4) for col in (
            'acc_shared', 'acc_mod1', 'acc_mod2',
            'pred1_shared', 'pred1_mod1', 'pred1_mod2',
            'pred2_shared', 'pred2_mod1', 'pred2_mod2',
        )
    ]

    summary_data = []
    for group_keys, group_df in df.groupby(group_cols):
        # Depending on the pandas version, a single grouping column yields
        # scalar keys rather than 1-tuples.
        if not isinstance(group_keys, tuple):
            group_keys = (group_keys,)

        result = dict(zip(group_cols, group_keys))
        result['n_runs'] = len(group_df)
        result['n_seeds'] = group_df['seed'].nunique()

        for col, decimals in stat_columns:
            if col in group_df.columns:
                result.update(_mean_sem_stats(col, group_df[col], decimals))

        summary_data.append(result)

    return pd.DataFrame(summary_data)


def _print_banner(title):
    """Print ``title`` framed by 80-character ``=`` rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _report_loaded(df, label):
    """Print how many results were loaded for ``label`` and, if any, their versions/weights.

    The loaders return a column-less DataFrame when nothing was found, so the
    column listings are printed only for a non-empty frame.
    """
    print(f"   Loaded {len(df)} {label} experiment results")
    if len(df) == 0:
        return
    print(f"   Data versions: {sorted(df['data_version'].unique())}")
    print(f"   Weight values: {sorted(df['weight'].unique())}")


def _iter_version_summaries(summary):
    """Yield ``(data_version, rows)`` from ``summary`` for each experiment type.

    Experiment types are visited in the order L1, Ortho, L2norm. Before the
    data versions of a type that has rows, its header banner is printed; types
    without rows are skipped.
    """
    for exp_type in ('L1', 'Ortho', 'L2norm'):
        exp_summary = summary[summary['experiment'] == exp_type]
        if len(exp_summary) == 0:
            continue
        _print_banner(f"{exp_type} Regularization")
        for data_version in sorted(exp_summary['data_version'].unique()):
            yield data_version, exp_summary[exp_summary['data_version'] == data_version]


def main():
    """Main analysis function: load, summarize, print tables and write CSV outputs.

    Outputs, written under ``03_results/processed`` (relative to the current
    working directory):
    - ``mm_param_extralosses_summary.csv``: summary across all experiments
    - ``mm_param_<experiment>_summary.csv``: per-experiment summary, for each
      experiment type that has results
    - ``mm_param_extralosses_raw.csv``: the combined per-run records

    Returns early, writing nothing, when there are no results to summarize.
    """
    print("=" * 80)
    print("Analysis of Additional Loss Term Experiments")
    print("=" * 80)

    # Load all results
    print("\n1. Loading L1 regularization results...")
    l1_df = load_l1_results()
    _report_loaded(l1_df, "L1")

    print("\n2. Loading orthogonal loss results...")
    ortho_df = load_ortho_results()
    _report_loaded(ortho_df, "Ortho")

    print("\n3. Loading L2 norm regularization results...")
    l2_df = load_l2norm_results()
    _report_loaded(l2_df, "L2norm")

    experiment_frames = [('L1', l1_df), ('Ortho', ortho_df), ('L2norm', l2_df)]

    # Combine all results. Empty (column-less) frames contribute nothing, and
    # pd.concat raises on an empty list, so concatenate only non-empty frames.
    non_empty = [exp_df for _, exp_df in experiment_frames if len(exp_df) > 0]
    all_results = pd.concat(non_empty, ignore_index=True) if non_empty else pd.DataFrame()
    print(f"\n4. Total combined results: {len(all_results)}")

    # Summary by experiment, data_version, and weight
    print("\n5. Computing summary statistics...")
    group_cols = ['experiment', 'data_version', 'weight']
    summary = compute_summary_statistics(all_results, group_cols=group_cols)
    if len(summary) == 0:
        # An empty summary has no columns; sort_values below would raise KeyError.
        print("\nNo results to summarize; no output files written.")
        return
    summary = summary.sort_values(group_cols)

    # Save summary
    output_dir = "03_results/processed"
    os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(output_dir, "mm_param_extralosses_summary.csv")
    summary.to_csv(output_file, index=False)
    print(f"\n6. Summary statistics saved to: {output_file}")

    # Final ranks per experiment type and data version
    _print_banner("SUMMARY STATISTICS - FINAL RANKS")
    for data_version, rows in _iter_version_summaries(summary):
        print(f"\nData Version: {data_version}")
        print("-" * 80)
        print(f"{'Weight':<12} {'Shared Rank':<20} {'Mod1 Rank':<20} {'Mod2 Rank':<20} {'Total Rank':<20} {'N Runs':<8}")
        print("-" * 80)
        for _, row in rows.iterrows():
            print(f"{row['weight']:<12.4f} "
                  f"{row['shared_rank_mean_sem']:<20} "
                  f"{row['mod1_rank_mean_sem']:<20} "
                  f"{row['mod2_rank_mean_sem']:<20} "
                  f"{row['total_rank_mean_sem']:<20} "
                  f"{row['n_runs']:<8}")

    # Classification accuracy per subspace
    _print_banner("SUMMARY STATISTICS - CLASSIFICATION ACCURACY")
    for data_version, rows in _iter_version_summaries(summary):
        print(f"\nData Version: {data_version}")
        print("-" * 80)
        print(f"{'Weight':<12} {'Acc Shared':<22} {'Acc Mod1':<22} {'Acc Mod2':<22} {'N Runs':<8}")
        print("-" * 80)
        for _, row in rows.iterrows():
            acc_shared = row.get('acc_shared_mean_sem', 'N/A')
            acc_mod1 = row.get('acc_mod1_mean_sem', 'N/A')
            acc_mod2 = row.get('acc_mod2_mean_sem', 'N/A')
            print(f"{row['weight']:<12.4f} "
                  f"{acc_shared:<22} "
                  f"{acc_mod1:<22} "
                  f"{acc_mod2:<22} "
                  f"{row['n_runs']:<8}")

    # Label 1 / label 2 predictability per subspace
    _print_banner("SUMMARY STATISTICS - LABEL PREDICTABILITY (R²)")
    subspaces = ('shared', 'mod1', 'mod2')
    for data_version, rows in _iter_version_summaries(summary):
        print(f"\nData Version: {data_version}")
        for label in (1, 2):
            headers = [f"Pred{label} {part}" for part in ('Shared', 'Mod1', 'Mod2')]
            print(f"\n  Label {label} Prediction:")
            print("  " + "-" * 76)
            print(f"  {'Weight':<12} {headers[0]:<22} {headers[1]:<22} {headers[2]:<22}")
            print("  " + "-" * 76)
            for _, row in rows.iterrows():
                cells = [row.get(f"pred{label}_{part}_mean_sem", 'N/A') for part in subspaces]
                print(f"  {row['weight']:<12.4f} "
                      f"{cells[0]:<22} "
                      f"{cells[1]:<22} "
                      f"{cells[2]:<22}")

    # Also create experiment-specific summaries
    for exp_type, exp_df in experiment_frames:
        if len(exp_df) == 0:
            continue
        exp_summary = compute_summary_statistics(exp_df, group_cols=['data_version', 'weight'])
        exp_output = os.path.join(output_dir, f"mm_param_{exp_type.lower()}_summary.csv")
        exp_summary.to_csv(exp_output, index=False)
        print(f"\n{exp_type}-specific summary saved to: {exp_output}")

    # Save raw combined data
    raw_output = os.path.join(output_dir, "mm_param_extralosses_raw.csv")
    all_results.to_csv(raw_output, index=False)
    print(f"\nRaw combined data saved to: {raw_output}")

    _print_banner("Analysis complete!")


if __name__ == "__main__":
    main()
