"""
Multi-interval runtime benchmark for MCR methods.

This script runs benchmarks at multiple search intervals (default: [32, 16, 8];
override with --intervals) and saves results incrementally after each interval
completes.

Features:
- Runs intervals in descending order (default: 32 → 16 → 8)
- Saves results after each interval (safe for long experiments)
- Can resume from partial results
- 5 independent runs per configuration
- Default SNR = 20dB
"""

import copy
import time
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
import json

from ebgmcr import RandomComponentMixtureSynthesizer

from evaluate_synthesized_benchmarks import (
    DefaultDatasetArgument,
    NMF_baseline,
    SparseNMF_baseline,
    BayesNMF_baseline,
    ICA_baseline,
    MCR_ALS_baseline
)


def _time_call(func, *args):
    """Return the elapsed seconds of a single ``func(*args)`` call.

    The call's return value is discarded; only its duration matters here.
    """
    # perf_counter() is monotonic and high-resolution, so the measurement
    # cannot be skewed by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    func(*args)
    return time.perf_counter() - start


def search_baselines(data, scanning_components=None):
    """Run all baseline methods for each component count and measure runtime.

    Args:
        data: Input mixed data, passed unchanged to every baseline.
        scanning_components: Iterable of component numbers to try.
            ``None`` (the default) means "no component counts".

    Returns:
        Dictionary mapping each method name ('NMF', 'Sparse-NMF',
        'Bayes-NMF', 'ICA', 'MCR-ALS') to a ``{count: seconds}`` dictionary.
    """
    # Avoid a mutable default argument: a fresh empty sequence per call.
    if scanning_components is None:
        scanning_components = ()

    runtime = {
        'NMF': {},
        'Sparse-NMF': {},
        'Bayes-NMF': {},
        'ICA': {},
        'MCR-ALS': {}
    }

    for count in scanning_components:
        print(f"    Testing component count: {count}")

        # Methods are timed one after another, in a fixed order, on the
        # same data so the per-count numbers are directly comparable.
        runtime['NMF'][count] = _time_call(NMF_baseline, data, count)
        runtime['Sparse-NMF'][count] = _time_call(SparseNMF_baseline, data, count)
        runtime['Bayes-NMF'][count] = _time_call(BayesNMF_baseline, data, count)
        runtime['ICA'][count] = _time_call(ICA_baseline, data, count)
        runtime['MCR-ALS'][count] = _time_call(MCR_ALS_baseline, data, count)

    return runtime


def run_single_interval(
    interval,
    component_number,
    datafold,
    signal_to_noise_ratio,
    repeat_time
):
    """Benchmark every baseline at one search interval over several runs.

    Each run synthesizes a fresh dataset, times all baselines at every
    scanned component count, and sums those timings into one cumulative
    runtime per method. The per-run totals are then summarized.

    Args:
        interval: Search interval (e.g., 4, 8, 16, 32)
        component_number: Maximum number of components (N)
        datafold: Dataset size multiplier (M = N * datafold)
        signal_to_noise_ratio: SNR in dB
        repeat_time: Number of independent runs

    Returns:
        Dictionary mapping each method name to a dictionary with the
        'mean', 'std', 'min' and 'max' of its cumulative runtime (seconds)
        across the runs.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"INTERVAL = {interval}")
    print(f"{banner}")

    # Counts scanned in every run: interval, 2*interval, ... up to the first
    # multiple of interval that reaches component_number.
    scanning_components = list(range(interval, component_number + interval, interval))
    print(f"Component counts to test: {scanning_components}")
    print(f"Total tests per method: {len(scanning_components)}")

    # One cumulative runtime (sum over all scanned counts) per run and method.
    run_totals = {
        name: []
        for name in ('NMF', 'Sparse-NMF', 'Bayes-NMF', 'ICA', 'MCR-ALS')
    }
    sample_count = component_number * datafold

    for run_number in range(1, repeat_time + 1):
        print(f"\n  Run {run_number}/{repeat_time}")
        print(f"  {'-'*60}")

        # Build a fresh synthetic dataset for this run from a private copy
        # of the default settings.
        settings = copy.deepcopy(DefaultDatasetArgument)
        settings['component_number'] = component_number
        # NOTE(review): the key spelling ('nosie') is kept exactly as found;
        # presumably it matches the synthesizer's keyword - confirm before
        # renaming.
        settings['signal_to_nosie_ratio'] = signal_to_noise_ratio
        synthesizer = RandomComponentMixtureSynthesizer(**settings)
        mixtures = synthesizer(sample_count)

        timings = search_baselines(mixtures.numpy(), scanning_components)

        for method, per_count in timings.items():
            run_totals[method].append(
                sum(per_count[count] for count in scanning_components)
            )

    return {
        method: {
            'mean': np.mean(totals),
            'std': np.std(totals),
            'min': np.min(totals),
            'max': np.max(totals)
        }
        for method, totals in run_totals.items()
    }


def load_existing_results(summary_csv):
    """Load existing results from summary CSV if it exists.

    The CSV stores runtimes in hours (the ``*_hr`` columns), while the
    in-memory results used by the rest of this script are in seconds and
    are divided by 3600 whenever they are saved or printed. Loaded values
    are therefore converted back to seconds, so resuming does not shrink
    earlier results by another factor of 3600.

    Args:
        summary_csv: Path to summary CSV file

    Returns:
        Dictionary keyed by ``(interval, datafold)``. Each value holds
        'interval', 'datafold' and, per method, a ``{'mean', 'std'}``
        dictionary in seconds. Empty dict if the file doesn't exist.
    """
    summary_csv = Path(summary_csv)
    if not summary_csv.exists():
        return {}

    # In-memory method name -> column prefix used in the CSV.
    column_prefixes = {
        'NMF': 'NMF',
        'Sparse-NMF': 'SparseNMF',
        'Bayes-NMF': 'BayesNMF',
        'ICA': 'ICA',
        'MCR-ALS': 'MCRALS'
    }
    seconds_per_hour = 3600

    df = pd.read_csv(summary_csv)
    results = {}

    for _, row in df.iterrows():
        interval = int(row['interval'])
        datafold = int(row['datafold'])

        entry = {'interval': interval, 'datafold': datafold}
        for method, prefix in column_prefixes.items():
            entry[method] = {
                'mean': float(row[f'{prefix}_mean_hr']) * seconds_per_hour,
                'std': float(row[f'{prefix}_std_hr']) * seconds_per_hour
            }
        results[(interval, datafold)] = entry

    return results


def save_interval_results(
    all_results,
    summary_csv,
    component_number,
    snr
):
    """Save all results to summary CSV.

    Runtimes are held in seconds in ``all_results`` and written in hours.
    The file is written to a temporary sibling first and then renamed over
    the target, so an interruption during the write cannot leave a
    truncated summary in place of the previous checkpoint.

    Args:
        all_results: Dictionary keyed by ``(interval, datafold)`` whose
            values hold, per method, a dictionary with 'mean' and 'std'
            in seconds
        summary_csv: Path to save summary CSV
        component_number: Maximum component number (used for the
            'data_size' column)
        snr: Signal-to-noise ratio (currently unused; kept for interface
            compatibility)
    """
    # In-memory method name -> column prefix used in the CSV.
    column_prefixes = (
        ('NMF', 'NMF'),
        ('Sparse-NMF', 'SparseNMF'),
        ('Bayes-NMF', 'BayesNMF'),
        ('ICA', 'ICA'),
        ('MCR-ALS', 'MCRALS'),
    )
    seconds_per_hour = 3600

    rows = []
    for interval, datafold in sorted(all_results):
        result = all_results[(interval, datafold)]
        row = {
            'interval': interval,
            'datafold': datafold,
            'data_size': f'{datafold}N ({component_number * datafold} samples)'
        }
        for method, prefix in column_prefixes:
            row[f'{prefix}_mean_hr'] = result[method]['mean'] / seconds_per_hour
            row[f'{prefix}_std_hr'] = result[method]['std'] / seconds_per_hour
        rows.append(row)

    # Write-then-rename keeps the previous checkpoint intact until the new
    # one is completely on disk.
    target = Path(summary_csv)
    staging = target.with_name(target.name + '.tmp')
    pd.DataFrame(rows).to_csv(staging, index=False)
    staging.replace(target)

    print(f"\n{'='*70}")
    print(f"✓ Results saved to: {summary_csv}")
    print(f"{'='*70}\n")


def print_current_results(all_results):
    """Print current results in a formatted table.

    One table is printed per datafold, with one row per interval and one
    ``mean±std`` cell (in hours) per method. Cells are padded to the same
    width as the header columns so the table stays aligned.

    Args:
        all_results: Dictionary keyed by ``(interval, datafold)`` whose
            values hold, per method, a dictionary with 'mean' and 'std'
            in seconds
    """
    methods = ('NMF', 'Sparse-NMF', 'Bayes-NMF', 'ICA', 'MCR-ALS')
    column_width = 12
    seconds_per_hour = 3600
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print("\n" + heavy_rule)
    print("CURRENT RESULTS (hours)")
    print(heavy_rule)

    # Group by datafold
    for datafold in sorted({key[1] for key in all_results}):
        print(f"\nData size = {datafold}N")
        print(light_rule)
        print(
            f"{'Interval':<10} "
            + " ".join(f"{name:<{column_width}}" for name in methods)
        )
        print(light_rule)

        for interval, fold in sorted(all_results):
            if fold != datafold:
                continue
            result = all_results[(interval, fold)]
            cells = []
            for method in methods:
                mean_hr = result[method]['mean'] / seconds_per_hour
                std_hr = result[method]['std'] / seconds_per_hour
                cells.append(f"{mean_hr:.2f}±{std_hr:.2f}")
            # Pad every cell to the header's column width.
            print(
                f"{interval:<10} "
                + " ".join(f"{cell:<{column_width}}" for cell in cells)
            )

    print(heavy_rule + "\n")


def main():
    """Parse CLI options and run the multi-interval benchmark.

    For every (datafold, interval) combination the benchmark is run, the
    result is stored, and the summary CSV is rewritten, so partial progress
    survives an interruption. With ``--resume``, combinations already present
    in the summary CSV are skipped.
    """
    parser = argparse.ArgumentParser(
        description='Multi-interval benchmark for MCR methods',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Data generation parameters
    parser.add_argument(
        '--component-number',
        type=int,
        default=256,
        help='Maximum number of components (N)'
    )
    parser.add_argument(
        '--dataflows',
        type=int,
        nargs='+',
        default=[4, 8],
        help='Dataset size multipliers (M = N * datafold)'
    )
    parser.add_argument(
        '--snr',
        type=float,
        default=20.0,
        help='Signal-to-noise ratio in dB'
    )

    # Search parameters
    parser.add_argument(
        '--intervals',
        type=int,
        nargs='+',
        default=[32, 16, 8],
        help='Search intervals to test (will be sorted descending)'
    )
    parser.add_argument(
        '--repeat-time',
        type=int,
        default=5,
        help='Number of independent runs per configuration'
    )

    # Output parameters
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./benchmark_results',
        help='Directory to save results'
    )
    parser.add_argument(
        '--resume',
        action='store_true',
        help='Resume from existing results if available'
    )

    args = parser.parse_args()

    # Descending order: a larger interval scans fewer component counts, so
    # the cheapest configurations finish (and are saved) first. Duplicates
    # are dropped so a repeated value is not counted as extra work.
    intervals = sorted(set(args.intervals), reverse=True)
    dataflows = sorted(set(args.dataflows))

    # Setup output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    summary_csv = output_dir / f'multi_interval_summary_N{args.component_number}_snr{args.snr}dB.csv'

    # Print configuration
    print("\n" + "="*70)
    print("MULTI-INTERVAL BENCHMARK CONFIGURATION")
    print("="*70)
    print(f"Component number (N):     {args.component_number}")
    print(f"Data size multipliers:    {dataflows}")
    print(f"Signal-to-noise ratio:    {args.snr} dB")
    print(f"Search intervals:         {intervals} (descending order)")
    print(f"Repeats per config:       {args.repeat_time}")
    print(f"Output directory:         {args.output_dir}")
    print(f"Summary CSV:              {summary_csv.name}")
    print("="*70)

    # Load existing results if resuming
    all_results = {}
    if args.resume:
        all_results = load_existing_results(summary_csv)
        if all_results:
            print(f"\n✓ Loaded {len(all_results)} existing result(s)")
            print_current_results(all_results)

    # Run experiments
    planned = [
        (interval, datafold)
        for datafold in dataflows
        for interval in intervals
    ]
    total_configs = len(planned)
    # Only loaded results that belong to the current plan count as completed;
    # the summary CSV may also hold configurations not requested this time.
    completed_configs = sum(1 for key in planned if key in all_results)
    # Configurations actually run in this session; the ETA is based on
    # these, because loaded results contributed nothing to the elapsed time.
    session_completed = 0

    print(f"\nTotal configurations: {total_configs}")
    print(f"Completed: {completed_configs}")
    print(f"Remaining: {total_configs - completed_configs}")
    print("\n" + "="*70)

    start_time = datetime.now()

    for datafold in dataflows:
        print(f"\n{'#'*70}")
        print(f"# DATA SIZE = {datafold}N ({args.component_number * datafold} mixed samples)")
        print(f"{'#'*70}")

        for interval in intervals:
            key = (interval, datafold)

            # Skip if already completed
            if key in all_results:
                print(f"\n✓ Skipping interval={interval}, datafold={datafold} (already completed)")
                continue

            # Run benchmark
            print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Starting interval={interval}, datafold={datafold}")

            cumulative_stats = run_single_interval(
                interval=interval,
                component_number=args.component_number,
                datafold=datafold,
                signal_to_noise_ratio=args.snr,
                repeat_time=args.repeat_time
            )

            # Store results
            all_results[key] = cumulative_stats

            # Save incrementally
            save_interval_results(
                all_results=all_results,
                summary_csv=summary_csv,
                component_number=args.component_number,
                snr=args.snr
            )

            # Print current progress
            completed_configs += 1
            session_completed += 1
            elapsed = datetime.now() - start_time
            avg_time_per_config = elapsed / session_completed
            remaining_configs = total_configs - completed_configs
            estimated_remaining = avg_time_per_config * remaining_configs

            print(f"\n{'='*70}")
            print(f"PROGRESS: {completed_configs}/{total_configs} configurations complete")
            print(f"Elapsed time: {elapsed}")
            print(f"Estimated remaining: {estimated_remaining}")
            print(f"{'='*70}")

            # Show current results
            print_current_results(all_results)

    # Final summary
    total_time = datetime.now() - start_time
    print("\n" + "="*70)
    print("✓ ALL BENCHMARKS COMPLETE!")
    print("="*70)
    print(f"Total time: {total_time}")
    print(f"Results saved to: {summary_csv}")
    print("="*70)

    # Print final results table
    print_current_results(all_results)

# Run the benchmark only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
