﻿"""
Enhanced Eta Experiment: Testing multiple lambda allocation strategies

This experiment tests how different lambda allocation strategies combined with
various eta values affect the performance of two-stage review methods.

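Strategies tested (T = TARGET_LAMBDA env var, default 8):
1. Original: lambda1 = T*(1 - eta), lambda2 = T
2. Constant Stage2: lambda1 = T - 6*eta, lambda2 = 6
3. Moderate Stage2: lambda1 = T - 8*eta, lambda2 = 8
4. Low Stage2: lambda1 = T - 3*eta, lambda2 = 3
5. Fixed Stage1: lambda1 = 4 (or 3), lambda2 = (T - lambda1)/eta
6. Fractional Stage2: lambda2 = f*T, lambda1 = T - eta*lambda2, for each f in
   FRAC_GRID (default "0.25,0.5,0.75"), plus f = LAMBDA2_FRAC (default 0.5)

Parameters:
- Paper number: 100 (fixed)
- Reviewer number: 100 (fixed, consistent with lambda experiment)
- Reviewer quality: 0.7 (uniform and known)
- Eta values: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  (QUICK=1 or --quick: [0.3, 0.5, 0.7], fewer trials, subset of strategies)
- All strategies maintain: lambda1 + eta*lambda2 = T (constant total lambda per paper)
- Configurations with lambda1 <= 0, lambda2 <= 0, or lambda2 > LAMBDA2_MAX
  (default 18) are skipped

Note: High Stage2 strategy removed as it produces invalid configurations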
"""

import numpy as np
import csv
import time
from datetime import datetime
import os
import sys
import random
from scipy import stats

sys.path.append('./review_modes')
from two_stage_review import TwoStageReviewConference
from promising_review import PromisingReviewConference
from direct_review import DirectReviewConference
from geometric_mean_utils import geometric_mean_safe

def get_lambda_strategies():
    """
    Define multiple lambda allocation strategies with constant total lambda constraint.

    Every strategy maps eta (fraction of papers sent to stage 2) to a pair
    (lambda1, lambda2) with lambda1 + eta*lambda2 = T, where T is read from the
    TARGET_LAMBDA env var (parsed as float, default 8).

    Environment variables:
        TARGET_LAMBDA: total per-paper review budget T (default 8).
        FRAC_GRID: comma-separated fractions f; adds one strategy per f with
            lambda2 = f*T (default "0.25,0.5,0.75"). If it cannot be parsed, a
            warning is printed and [0.5] is used instead.
        LAMBDA2_FRAC: fraction for the legacy 'constant_stage2_frac' key
            (default 0.5). If it cannot be parsed, a warning is printed and the
            key is omitted.

    Returns:
        dict mapping strategy key -> {'name', 'description', 'lambda_func'}, where
        lambda_func(eta) returns (lambda1, lambda2). The fixed_stage1_* strategies
        divide by eta, so they require eta != 0.
    """
    T = float(os.getenv('TARGET_LAMBDA', '8'))

    def _fraction_strategy(frac):
        # frac is a parameter of this helper, so each returned lambda closes over its
        # own value (no late-binding surprise across loop iterations).
        return {
            'name': f'Const Stage2 Fraction (lambda2={frac}*T)',
            'description': f'lambda2 = {frac}*T, lambda1 = T - eta*(lambda2)',
            'lambda_func': lambda eta: (T - (frac * T) * eta, frac * T),
        }

    strategies = {
        'original': {
            'name': 'Original (decreasing lambda1)',
            'description': 'lambda1 = T*(1-eta), lambda2 = T',
            'lambda_func': lambda eta: (T * (1 - eta), T)
        },
        'constant_stage2': {
            'name': 'Constant Stage2',
            'description': 'lambda2 = 6, lambda1 = T - 6*eta',
            'lambda_func': lambda eta: (T - 6*eta, 6)
        },
        'moderate_stage2': {
            'name': 'Moderate Stage2',
            'description': 'lambda2 = 8, lambda1 = T - 8*eta',
            'lambda_func': lambda eta: (T - 8*eta, 8)
        },
        'low_stage2': {
            'name': 'Low Stage2',
            'description': 'lambda2 = 3, lambda1 = T - 3*eta',
            'lambda_func': lambda eta: (T - 3*eta, 3)
        },
        'fixed_stage1_4': {
            'name': 'Recommended A: lambda1=4',
            'description': 'lambda1 = 4, lambda2 = (T-4)/eta',
            'lambda_func': lambda eta: (4, (T - 4)/eta)
        },
        'fixed_stage1_3': {
            'name': 'Recommended B: lambda1=3',
            'description': 'lambda1 = 3, lambda2 = (T-3)/eta',
            'lambda_func': lambda eta: (3, (T - 3)/eta)
        },
    }

    # Fractional Stage2 family (lambda2 = frac * T), default FRAC_GRID="0.25,0.5,0.75"
    frac_grid_env = os.getenv('FRAC_GRID', '0.25,0.5,0.75')
    try:
        fracs = [float(x) for x in frac_grid_env.split(',') if x.strip()]
    except ValueError:
        # float() on a str only raises ValueError. Keep the best-effort fallback,
        # but make it visible instead of silently changing the experiment grid.
        print(f"WARNING: could not parse FRAC_GRID={frac_grid_env!r}; using [0.5]")
        fracs = [0.5]
    for frac in fracs:
        key = f"constant_stage2_frac_{int(round(frac*100)):02d}"
        strategies[key] = _fraction_strategy(frac)

    # Backward single-fraction key for compatibility (grid keys always carry a suffix,
    # so this check only guards against future key collisions).
    if 'constant_stage2_frac' not in strategies:
        legacy_env = os.getenv('LAMBDA2_FRAC', '0.5')
        try:
            legacy_frac = float(legacy_env)
        except ValueError:
            print(f"WARNING: could not parse LAMBDA2_FRAC={legacy_env!r}; "
                  f"skipping 'constant_stage2_frac'")
        else:
            strategies['constant_stage2_frac'] = _fraction_strategy(legacy_frac)
    return strategies

# Metric names reported by every review mode; the order fixes the CSV column order.
_METRIC_NAMES = (
    'accuracy', 'f1_score', 'precision', 'recall',
    'bp_raw_error', 'calibration_error', 'bp_f1_loss', 'bp_f1_score_50pct',
    'kl_divergence', 'js_divergence',
)

def _env_number(name, default):
    """Read a numeric env var; return an int when the value is integral, else a float.

    Parsing as float keeps non-integer settings (e.g. TARGET_LAMBDA=8.5) from crashing,
    while integral values keep the int type this script passed around before.
    """
    value = float(os.getenv(name, default))
    return int(value) if value.is_integer() else value

def _stable_strategy_offset(strategy_key):
    """Return a per-strategy seed offset in [0, 1000) that is identical across processes.

    The built-in hash() of a str is salted per interpreter (PYTHONHASHSEED), so seeding
    with it made runs irreproducible; CRC32 is deterministic.
    """
    import zlib
    return zlib.crc32(strategy_key.encode('utf-8')) % 1000

def _new_metric_store():
    """Return an empty {metric_name: [per-trial values]} accumulator."""
    return {metric: [] for metric in _METRIC_NAMES}

def _add_summary_stats(row, prefix, store):
    """Add '<prefix>_<metric>_mean' / '_std' columns to *row* for every metric in *store*.

    kl_divergence is averaged with geometric_mean_safe; all other metrics use the
    arithmetic mean. The std is np.std (population, ddof=0) of the raw trial values.
    """
    for metric, values in store.items():
        if metric == 'kl_divergence':
            row[f'{prefix}_{metric}_mean'] = geometric_mean_safe(values)
        else:
            row[f'{prefix}_{metric}_mean'] = np.mean(values)
        row[f'{prefix}_{metric}_std'] = np.std(values)

def run_eta_experiment():
    """
    Run experiment with different eta values and multiple lambda strategies.

    A Direct review baseline is run once, because it does not depend on eta or
    strategy, and its summary is copied into every row. Then, for each strategy
    from get_lambda_strategies() and each eta, num_trials trials of the two-stage
    Ambiguous and Promising modes are run. Both modes use the same seed within a
    trial. Every metric is summarized as mean/std.

    Environment / CLI:
        TARGET_LAMBDA: total per-paper budget (default 8).
        LAMBDA2_MAX: configurations with lambda2 above this are skipped (default 18).
        QUICK=1 or --quick: 3 etas, 5 trials, 3 strategies.

    Returns:
        (output_file, results): the CSV filename, written to the current working
        directory, and the list of per-(strategy, eta) result dicts.
    """
    # Fixed parameters
    paper_num = 100
    reviewer_num = 100  # Changed from 1000 to match lambda experiment setting
    reviewer_quality = 0.7
    target_total_lambda = _env_number('TARGET_LAMBDA', '8')
    # Upper bound on stage-2 intensity to keep workloads feasible (read once, not per eta)
    lambda2_max = _env_number('LAMBDA2_MAX', '18')

    # Variable parameters
    eta_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    num_trials = 20
    # Quick mode for sanity checks
    quick_mode = ('--quick' in sys.argv) or (os.getenv('QUICK', '0') == '1')
    if quick_mode:
        print('Running in QUICK mode: fewer etas, fewer trials, subset of strategies')
        eta_values = [0.3, 0.5, 0.7]
        num_trials = 5

    # Get all lambda strategies
    strategies = get_lambda_strategies()
    if quick_mode:
        selected = ['fixed_stage1_4', 'fixed_stage1_3', 'original']
        strategies = {k: v for k, v in strategies.items() if k in selected}
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"eta_multi_strategy_results_{timestamp}.csv"

    results = []

    # Run direct review baseline once (it's independent of eta and strategy)
    print(f"\n{'='*80}")
    print(f"RUNNING DIRECT REVIEW BASELINE (eta-independent)")
    print(f"{'='*80}")

    direct_metrics = _new_metric_store()
    for trial in range(num_trials):
        print(f"Direct baseline trial {trial + 1}/{num_trials}")
        trial_seed = 42 + trial * 1000
        np.random.seed(trial_seed)
        random.seed(trial_seed)

        exp_direct = DirectReviewConference(
            paper_num=paper_num,
            reviewer_num=reviewer_num,
            reviewer_quality=reviewer_quality,
            lambda_per_paper=target_total_lambda
        )
        results_direct = exp_direct.run_direct_review()
        for metric in direct_metrics:
            direct_metrics[metric].append(results_direct[metric])

    # The direct summary is identical for every row, so compute it once.
    direct_summary = {}
    _add_summary_stats(direct_summary, 'direct', direct_metrics)

    # Now test each strategy-eta combination
    for strategy_key, strategy_info in strategies.items():
        print(f"\n{'='*80}")
        print(f"TESTING STRATEGY: {strategy_info['name']}")
        print(f"Description: {strategy_info['description']}")
        print(f"{'='*80}")

        strategy_offset = _stable_strategy_offset(strategy_key)

        for eta in eta_values:
            # Calculate lambda values using strategy function
            lambda1, lambda2 = strategy_info['lambda_func'](eta)
            if lambda2 > lambda2_max:
                print(f"Skipping eta={eta:.1f} for {strategy_key}: lambda2={lambda2:.1f} exceeds cap {lambda2_max}")
                continue

            # Skip invalid configurations
            if lambda1 <= 0 or lambda2 <= 0:
                print(f"Skipping eta={eta:.1f} for {strategy_key}: invalid lambda values (lambda1={lambda1:.1f}, lambda2={lambda2:.1f})")
                continue

            # Verify total lambda constraint
            total_lambda = lambda1 + eta * lambda2
            if abs(total_lambda - target_total_lambda) > 0.001:
                print(f"WARNING: Total lambda = {total_lambda:.3f} != {target_total_lambda} for strategy {strategy_key}, eta={eta}")

            stage1_tasks = int(paper_num * lambda1)
            stage2_tasks = int(paper_num * eta * lambda2)

            print(f"\n{'='*60}")
            print(f"Testing eta = {eta:.1f} ({eta*100:.0f}% papers to stage 2)")
            print(f"Strategy: {strategy_key} - lambda1 = {lambda1:.1f}, lambda2 = {lambda2:.1f}, total lambda = {total_lambda:.1f}")
            print(f"Papers in stage 2: {int(paper_num * eta)}")
            print(f"Stage 1 tasks: {stage1_tasks}, Stage 2 tasks: {stage2_tasks}")
            print(f"{'='*60}")

            # Storage for two-stage metrics
            twostage_metrics = _new_metric_store()
            promising_metrics = _new_metric_store()

            for trial in range(num_trials):
                print(f"\nTrial {trial + 1}/{num_trials} for eta={eta:.1f}, strategy={strategy_key}")

                # Same seed for both two-stage modes within a trial -> paired comparison.
                # round() rather than int() so float noise (e.g. 0.57*100 = 56.99...)
                # cannot shift the seed.
                trial_seed = 42 + trial * 1000 + round(eta * 100) * 12 + strategy_offset

                # NOTE(review): lambda2 is not passed to the conference classes; presumably
                # they derive it from lambda_per_paper, lambda1 and eta - confirm.
                print("Running Two-Stage Ambiguous Review...")
                np.random.seed(trial_seed)
                random.seed(trial_seed)
                exp_ambiguous = TwoStageReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality,
                    eta=eta,
                    lambda_per_paper=target_total_lambda,
                    stage1_reviewer_ratio=1.0,
                    lambda1=lambda1,
                )
                results_ambiguous = exp_ambiguous.run_two_stage_review()
                for metric in twostage_metrics:
                    twostage_metrics[metric].append(results_ambiguous[metric])

                print("Running Two-Stage Promising Review...")
                np.random.seed(trial_seed)
                random.seed(trial_seed)
                exp_promising = PromisingReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality,
                    eta=eta,
                    lambda_per_paper=target_total_lambda,
                    stage1_reviewer_ratio=1.0,
                    lambda1=lambda1,
                )
                results_promising = exp_promising.run_promising_review()
                for metric in promising_metrics:
                    promising_metrics[metric].append(results_promising[metric])

            # Build comprehensive result row with all metrics
            result_row = {
                'strategy': strategy_key,
                'strategy_name': strategy_info['name'],
                'eta': eta,
                'eta_percent': eta * 100,
                'paper_num': paper_num,
                'reviewer_num': reviewer_num,
                'target_total_lambda': target_total_lambda,
                'lambda1': lambda1,
                'lambda2': lambda2,
                'actual_total_lambda': total_lambda,
                'papers_in_stage2': int(paper_num * eta),
                'stage1_tasks': stage1_tasks,
                'stage2_tasks': stage2_tasks,
                'total_tasks': stage1_tasks + stage2_tasks,
            }
            result_row.update(direct_summary)
            _add_summary_stats(result_row, 'ambiguous', twostage_metrics)
            _add_summary_stats(result_row, 'promising', promising_metrics)

            results.append(result_row)

            # Print results for this strategy-eta combination
            print(f"\nResults for {strategy_key}, eta={eta:.1f}:")
            print(f"  Ambiguous: Acc={result_row['ambiguous_accuracy_mean']:.4f} +/- {result_row['ambiguous_accuracy_std']:.4f}")
            print(f"  Promising: Acc={result_row['promising_accuracy_mean']:.4f} +/- {result_row['promising_accuracy_std']:.4f}")
            print(f"  Direct:    Acc={result_row['direct_accuracy_mean']:.4f} +/- {result_row['direct_accuracy_std']:.4f}")

    # Write results to CSV
    print(f"\n{'='*80}")
    print(f"EXPERIMENT COMPLETE - SAVING RESULTS")
    print(f"{'='*80}")

    fieldnames = ['strategy', 'strategy_name', 'eta', 'eta_percent', 'paper_num', 'reviewer_num',
                  'target_total_lambda', 'lambda1', 'lambda2', 'actual_total_lambda',
                  'papers_in_stage2', 'stage1_tasks', 'stage2_tasks', 'total_tasks']
    for mode in ('direct', 'ambiguous', 'promising'):
        for metric in _METRIC_NAMES:
            fieldnames.extend([f'{mode}_{metric}_mean', f'{mode}_{metric}_std'])

    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print(f"Multi-strategy eta experiment complete!")
    print(f"Results saved to: {output_file}")
    print(f"Total configurations tested: {len(results)}")
    print(f"Strategies: {list(strategies.keys())}")
    print(f"Eta values: {eta_values}")

    return output_file, results

def create_multi_strategy_analysis_summary(results):
    """
    Print an analysis summary for multi-strategy eta experiment results.

    Args:
        results: list of per-(strategy, eta) row dicts as produced by
            run_eta_experiment(). Reads 'strategy', 'strategy_name', 'eta',
            'lambda1', 'lambda2', '{ambiguous,promising}_accuracy_{mean,std}'
            and 'direct_accuracy_mean'.

    Returns:
        None; output goes to stdout. If *results* is empty (e.g. every
        configuration was skipped), prints a notice and returns early instead
        of failing on max() / results[0].
    """
    print(f"\n{'='*80}")
    print("MULTI-STRATEGY ETA EXPERIMENT ANALYSIS SUMMARY")
    print(f"{'='*80}")

    if not results:
        print("\nNo results to summarize (no configuration was run).")
        return

    # Group results by strategy, preserving first-seen order
    strategies = {}
    for r in results:
        strategies.setdefault(r['strategy'], []).append(r)

    print(f"\nStrategies tested: {len(strategies)}")
    for strategy, strategy_results in strategies.items():
        print(f"  {strategy}: {len(strategy_results)} eta values")

    # Find best performance across all strategies
    print(f"\n{'='*60}")
    print("BEST PERFORMANCE ANALYSIS")
    print(f"{'='*60}")

    best_accuracy = max(results, key=lambda x: x['ambiguous_accuracy_mean'])
    best_promising = max(results, key=lambda x: x['promising_accuracy_mean'])

    print(f"\nBest Ambiguous Accuracy:")
    print(f"  Strategy: {best_accuracy['strategy']} ({best_accuracy['strategy_name']})")
    print(f"  eta = {best_accuracy['eta']:.1f}, lambda1 = {best_accuracy['lambda1']:.1f}, lambda2 = {best_accuracy['lambda2']:.1f}")
    print(f"  Accuracy: {best_accuracy['ambiguous_accuracy_mean']:.4f} +/- {best_accuracy['ambiguous_accuracy_std']:.4f}")

    print(f"\nBest Promising Accuracy:")
    print(f"  Strategy: {best_promising['strategy']} ({best_promising['strategy_name']})")
    print(f"  eta = {best_promising['eta']:.1f}, lambda1 = {best_promising['lambda1']:.1f}, lambda2 = {best_promising['lambda2']:.1f}")
    print(f"  Accuracy: {best_promising['promising_accuracy_mean']:.4f} +/- {best_promising['promising_accuracy_std']:.4f}")

    # Strategy-wise comparison
    print(f"\n{'='*60}")
    print("STRATEGY COMPARISON")
    print(f"{'='*60}")

    for strategy, strategy_results in strategies.items():
        best_ambig = max(strategy_results, key=lambda x: x['ambiguous_accuracy_mean'])
        best_prom = max(strategy_results, key=lambda x: x['promising_accuracy_mean'])

        print(f"\nStrategy: {strategy}")
        print(f"  Best Ambiguous: eta={best_ambig['eta']:.1f}, acc={best_ambig['ambiguous_accuracy_mean']:.4f}")
        print(f"  Best Promising: eta={best_prom['eta']:.1f}, acc={best_prom['promising_accuracy_mean']:.4f}")
        print(f"  Lambda config: lambda1={best_ambig['lambda1']:.1f}-{best_prom['lambda1']:.1f}, lambda2={best_ambig['lambda2']:.1f}-{best_prom['lambda2']:.1f}")

    # Direct baseline is copied into every row, so the first row is representative.
    direct_baseline = results[0]['direct_accuracy_mean']
    print(f"\nDirect Review Baseline: {direct_baseline:.4f}")
    print(f"Best improvements:")
    for label, best_row, key in (
        ('Ambiguous', best_accuracy, 'ambiguous_accuracy_mean'),
        ('Promising', best_promising, 'promising_accuracy_mean'),
    ):
        gain = best_row[key] - direct_baseline
        # Signed formatting: a hard-coded '+' prefix printed '+-0.0123' for regressions.
        relative = f"{gain / direct_baseline * 100:+.2f}%" if direct_baseline else "n/a"
        print(f"  {label}: {gain:+.4f} ({relative})")

if __name__ == "__main__":
    output_file, results = run_eta_experiment()

    # Create analysis summary
    create_multi_strategy_analysis_summary(results)

    # Report what was actually run. Quick mode and the lambda2 cap can drop
    # strategies or etas, so derive the list from the results instead of
    # hard-coding a configuration that drifts from run_eta_experiment.
    tested_strategies = list(dict.fromkeys(r['strategy'] for r in results))
    tested_etas = sorted({r['eta'] for r in results})

    print(f"\n{'='*80}")
    print("EXPERIMENT COMPLETE")
    print(f"{'='*80}")
    print(f"Detailed results saved to: {output_file}")
    print(f"CSV contains all metrics for 3 methods across multiple lambda strategies")
    print(f"Methods: Direct (baseline), Ambiguous, Promising")
    print(f"Lambda strategies: {tested_strategies}")
    print(f"Metrics: Accuracy, F1, Precision, Recall, BP Raw Error, Calibration Error, BP F1 Loss, BP F1 Score (50%), KL Divergence, JS Divergence")
    if results:
        first = results[0]
        print(f"Configuration: {first['paper_num']} papers, {first['reviewer_num']} reviewers, "
              f"total lambda={first['target_total_lambda']} (constant), eta={tested_etas}")






