import numpy as np
import csv
import time
from datetime import datetime
import os
import sys
import random
from scipy import stats

sys.path.append('./review_modes')
from direct_review import DirectReviewConference
from two_stage_review import TwoStageReviewConference
from promising_review import PromisingReviewConference
from geometric_mean_utils import geometric_mean_safe

# Metrics read from every review mode's results dict, in CSV column order.
_METRIC_NAMES = (
    'accuracy', 'f1_score', 'precision', 'recall',
    'bp_raw_error', 'calibration_error', 'bp_f1_loss', 'bp_f1_score_50pct',
    'kl_divergence', 'js_divergence',
)
# Result-row / CSV key prefix of each review mode, in column order.
_MODES = ('direct', 'twostage', 'promising')
# Console label for each mode prefix in the per-lambda summary.
_MODE_LABELS = (
    ("Direct Review", 'direct'),
    ("Two-Stage Ambiguous", 'twostage'),
    ("Two-Stage Promising", 'promising'),
)
# (console title, metric) pairs echoed in the per-lambda summary.
_PRINTED_METRICS = (
    ("Accuracy", 'accuracy'),
    ("F1", 'f1_score'),
    ("BP Raw Error", 'bp_raw_error'),
    ("KL Divergence", 'kl_divergence'),
)


def _stage1_lambda(total, eta):
    """Return the stage-1 budget ``total * (1 - eta)``; raise ValueError if it is negative."""
    lambda1 = total * (1.0 - eta)
    if lambda1 < 0:
        raise ValueError(f'Invalid lambda1 computed: {lambda1} (eta={eta}, T={total})')
    return lambda1


def _seed_all(seed):
    """Seed NumPy's global RNG and Python's ``random`` module with the same value."""
    np.random.seed(seed)
    random.seed(seed)


def _collect(bucket, trial_results):
    """Append this trial's value of every tracked metric to its list in *bucket*."""
    for metric, values in bucket.items():
        values.append(trial_results[metric])


def _summarize(prefix, metrics, row):
    """Add ``<prefix>_<metric>_mean`` / ``_std`` entries to *row* for each metric.

    KL divergence is aggregated with ``geometric_mean_safe``; every other mean
    is ``np.mean``. Every std is ``np.std`` (default ddof=0) of the raw values.
    NOTE(review): for KL divergence this pairs a geometric mean with an
    arithmetic std, so "mean ± std" mixes two different scales.
    """
    for metric, values in metrics.items():
        if metric == 'kl_divergence':
            mean = geometric_mean_safe(values)
        else:
            mean = np.mean(values)
        row[f'{prefix}_{metric}_mean'] = mean
        row[f'{prefix}_{metric}_std'] = np.std(values)


def _print_mode_summary(label, prefix, row):
    """Print the headline metrics (mean ± std) of one review mode."""
    print(f"{label}:")
    for title, metric in _PRINTED_METRICS:
        mean = row[f'{prefix}_{metric}_mean']
        std = row[f'{prefix}_{metric}_std']
        print(f"  {title}: {mean:.4f} ± {std:.4f}")


def run_lambda_experiment(*, paper_num=100, reviewer_num=100, reviewer_quality=0.85,
                          lambdas=None, num_trials=20):
    """Sweep the per-paper task budget (lambda) and compare three review modes.

    For every lambda, each trial runs three modes in order: Direct Review,
    Two-Stage Ambiguous Review, and Two-Stage Promising Review. Both RNGs are
    re-seeded with the same per-trial seed before each mode, so all three
    start from identical random state.

    The per-metric mean and std across trials are printed. They are also
    appended, one row per lambda, to a timestamped CSV in the current working
    directory; each row is flushed as soon as it is complete.

    The two-stage modes use ``eta`` from the ``ETA_ORIG`` environment variable
    (default 0.4) and an explicit stage-1 budget ``lambda1 = lambda * (1 - eta)``.
    The stage-2 budget is derived inside the conference classes.

    Args:
        paper_num: Passed unchanged to every conference constructor.
        reviewer_num: Passed unchanged to every conference constructor.
        reviewer_quality: Passed unchanged to every conference constructor.
        lambdas: Per-paper task budgets to sweep. Defaults to the integers
            5..18. Presumably these must be ints, because they feed the
            NumPy seed; confirm before passing floats.
        num_trials: Number of trials per lambda.

    Returns:
        ``(output_file, results)``: the CSV path, and a list of per-lambda
        dicts with key ``lambda`` plus ``<mode>_<metric>_mean`` / ``_std``.

    Raises:
        ValueError: If ``ETA_ORIG`` is not a float, or if it yields a
            negative lambda1 for any requested lambda. Both are checked
            before any simulation runs.
    """
    if lambdas is None:
        lambdas = list(range(5, 19))  # tasks per paper: consecutive integers 5..18

    # eta is fixed for the whole sweep: parse it once and validate every stage-1
    # budget up front, instead of failing midway through a long run.
    eta_opt = float(os.getenv('ETA_ORIG', '0.4'))
    schedule = [(lambda_val, _stage1_lambda(lambda_val, eta_opt)) for lambda_val in lambdas]

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"lambda_experiment_results_{timestamp}.csv"
    fieldnames = ['lambda'] + [
        f'{mode}_{metric}_{stat}'
        for mode in _MODES
        for metric in _METRIC_NAMES
        for stat in ('mean', 'std')
    ]

    results = []
    with open(output_file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for lambda_val, lambda1_opt in schedule:
            print(f"\n{'='*60}")
            print(f"Testing lambda = {lambda_val}")
            print(f"{'='*60}")
            # Two-stage strategy (original): lambda1 = T*(1-eta); lambda2 is derived by the constraint.
            print(f'Using two-stage strategy: original | eta={eta_opt}, lambda1={lambda1_opt:.3f} (lambda2 will be derived as T)')

            # Raw per-trial values, keyed as mode -> metric -> list.
            collected = {mode: {metric: [] for metric in _METRIC_NAMES} for mode in _MODES}

            for trial in range(num_trials):
                print(f"\nTrial {trial + 1}/{num_trials} for lambda={lambda_val}")

                # One seed per (trial, lambda), shared by all three modes so they
                # see the same papers and initial conditions. Seeds are distinct
                # for every pair as long as 0 <= lambda < 100.
                trial_seed = 42 + trial * 1000 + lambda_val * 10

                print("Running Direct Review...")
                _seed_all(trial_seed)
                direct = DirectReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality,
                    lambda_per_paper=lambda_val
                )
                _collect(collected['direct'], direct.run_direct_review())

                print("Running Two-Stage Ambiguous Review (original strategy)...")
                _seed_all(trial_seed)
                twostage = TwoStageReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality,
                    eta=eta_opt,
                    lambda_per_paper=lambda_val,
                    stage1_reviewer_ratio=0.5,  # original note: irrelevant with the robust allocator
                    lambda1=lambda1_opt  # explicit stage-1 budget; lambda2 is derived inside via the constraint
                )
                _collect(collected['twostage'], twostage.run_two_stage_review())

                print("Running Two-Stage Promising Review (original strategy)...")
                _seed_all(trial_seed)
                promising = PromisingReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality,
                    eta=eta_opt,
                    lambda_per_paper=lambda_val,
                    stage1_reviewer_ratio=0.5,  # original note: irrelevant with the robust allocator
                    lambda1=lambda1_opt  # explicit stage-1 budget; lambda2 is derived inside via the constraint
                )
                _collect(collected['promising'], promising.run_promising_review())

            result_row = {'lambda': lambda_val}
            for mode in _MODES:
                _summarize(mode, collected[mode], result_row)
            results.append(result_row)

            # Persist each finished lambda immediately so a later crash loses nothing.
            writer.writerow(result_row)
            csvfile.flush()

            print(f"\nResults for lambda={lambda_val}:")
            for label, mode in _MODE_LABELS:
                _print_mode_summary(label, mode, result_row)

    print(f"\n{'='*60}")
    print(f"Experiment complete! Results saved to {output_file}")
    print(f"{'='*60}")

    return output_file, results

def format_summary_table(results):
    """Return the key-metric summary table as a list of text lines.

    The first line is the header and the second is a dash separator of the
    same width; each following line is one result row. Each cell is
    ``mean±std``: F1 to 3 decimals, KL divergence to 4. Cells are padded to a
    fixed width so every value lines up under its column header.

    Args:
        results: Row dicts holding ``lambda`` and the
            ``<mode>_{f1_score,kl_divergence}_{mean,std}`` keys for modes
            ``direct``, ``twostage`` and ``promising``.
    """
    cell_width = 15  # wide enough for "0.1234±0.0056"
    columns = (
        ('Direct F1', 'direct', 'f1_score', 3),
        ('Direct KLD', 'direct', 'kl_divergence', 4),
        ('Ambig F1', 'twostage', 'f1_score', 3),
        ('Ambig KLD', 'twostage', 'kl_divergence', 4),
        ('Prom F1', 'promising', 'f1_score', 3),
        ('Prom KLD', 'promising', 'kl_divergence', 4),
    )

    def join(first, cells):
        # Same padding for header and data rows keeps the columns aligned.
        padded = [f"{first:<8}"] + [f"{cell:<{cell_width}}" for cell in cells]
        return " ".join(padded).rstrip()

    header = join('Lambda', [title for title, _, _, _ in columns])
    lines = [header, "-" * len(header)]
    for row in results:
        cells = []
        for _, mode, metric, digits in columns:
            mean = row[f'{mode}_{metric}_mean']
            std = row[f'{mode}_{metric}_std']
            cells.append(f"{mean:.{digits}f}±{std:.{digits}f}")
        lines.append(join(row['lambda'], cells))
    return lines


if __name__ == "__main__":
    output_file, results = run_lambda_experiment()

    print("\nSummary Table - Key Metrics:")
    for line in format_summary_table(results):
        print(line)

    print(f"\nDetailed results saved to: {output_file}")
    print("CSV contains all metrics: Accuracy, F1, Precision, Recall, BP Raw Error, Calibration Error, BP F1 Loss, BP F1 Score (50%), KL Divergence, JS Divergence")