import numpy as np
import csv
import time
from datetime import datetime
import os
import sys
import random
from scipy import stats

sys.path.append('./review_modes')
from direct_review import DirectReviewConference
from two_stage_review import TwoStageReviewConference
from promising_review import PromisingReviewConference
from geometric_mean_utils import geometric_mean_safe

# Metric keys read from every review mode's result dict. The order here also
# fixes the order of the per-mode columns in the CSV.
_METRIC_NAMES = (
    'accuracy', 'f1_score', 'precision', 'recall',
    'bp_raw_error', 'calibration_error', 'bp_f1_loss', 'bp_f1_score_50pct',
    'kl_divergence', 'js_divergence',
)

# Result-row / CSV column prefixes for the three review modes, in column order.
_MODE_PREFIXES = ('direct', 'twostage', 'promising')

# (label, metric key) pairs shown in the per-paper_num console summary.
_CONSOLE_METRICS = (
    ('Accuracy', 'accuracy'),
    ('F1', 'f1_score'),
    ('BP Raw Error', 'bp_raw_error'),
    ('KL Divergence', 'kl_divergence'),
)


def _seed_rngs(seed):
    """Seed NumPy's global RNG and the stdlib ``random`` RNG with ``seed``."""
    np.random.seed(seed)
    random.seed(seed)


def _record_trial(samples, outcome):
    """Append ``outcome[metric]`` to ``samples[metric]`` for every tracked metric."""
    for metric, bucket in samples.items():
        bucket.append(outcome[metric])


def _summarize_metrics(prefix, samples):
    """Build ``{prefix}_{metric}_mean`` / ``{prefix}_{metric}_std`` entries.

    ``kl_divergence`` is averaged with ``geometric_mean_safe``; every other
    metric uses ``np.mean``. The ``_std`` entry is always ``np.std`` of the
    raw per-trial values, including for ``kl_divergence``.
    """
    summary = {}
    for metric, values in samples.items():
        if metric == 'kl_divergence':
            center = geometric_mean_safe(values)
        else:
            center = np.mean(values)
        summary[f'{prefix}_{metric}_mean'] = center
        summary[f'{prefix}_{metric}_std'] = np.std(values)
    return summary


def _print_mode_summary(title, prefix, row):
    """Print the headline metrics of one review mode as ``mean ± std``."""
    print(f"{title}:")
    for label, metric in _CONSOLE_METRICS:
        mean = row[f'{prefix}_{metric}_mean']
        std = row[f'{prefix}_{metric}_std']
        print(f"  {label}: {mean:.4f} ± {std:.4f}")


def run_paper_num_lambda_per_paper_experiment():
    """Compare three review modes across conference sizes at constant λ per paper.

    For each ``paper_num`` in ``[50, 100, 200, 400, 1000, 5000]`` (with
    ``reviewer_num = paper_num``), runs 5 trials each of the direct,
    two-stage ambiguous and two-stage promising review simulations, then
    aggregates the metrics in ``_METRIC_NAMES`` into one result row.

    Environment variables:
        TARGET_LAMBDA: λ per paper, held constant for all sizes (default 8).
        ETA_ORIG: ``eta`` passed to both two-stage modes (default 0.4);
            stage-1 λ is computed as ``lambda_per_paper * (1 - eta)``.

    Side effects:
        Prints progress to stdout and writes all result rows to
        ``paper_num_lambda_per_paper_experiment_results_<timestamp>.csv`` in
        the current working directory.

    Returns:
        tuple: ``(output_file, results)`` where ``output_file`` is the CSV
        file name and ``results`` is the list of per-``paper_num`` row dicts.

    Raises:
        ValueError: if the derived stage-1 λ is negative (i.e. ``eta > 1``
            for a positive λ per paper).
    """
    # Fixed parameters
    lambda_per_paper = float(os.getenv('TARGET_LAMBDA', '8'))
    reviewer_quality = 0.7

    # Variable parameters
    paper_nums = [50, 100, 200, 400, 1000, 5000]
    num_trials = 5

    # Two-stage configuration: lambda1 = T * (1 - eta) is passed explicitly.
    # Neither value depends on paper_num, so read and validate them once, up
    # front, instead of on every loop iteration.
    eta_opt = float(os.getenv('ETA_ORIG', '0.4'))
    lambda1_opt = lambda_per_paper * (1 - eta_opt)
    if lambda1_opt < 0:
        raise ValueError(f"lambda1 computed negative: {lambda1_opt} (eta={eta_opt}, T={lambda_per_paper})")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"paper_num_lambda_per_paper_experiment_results_{timestamp}.csv"

    banner = '=' * 60
    results = []

    for paper_num in paper_nums:
        reviewer_num = paper_num  # Reviewer number equals paper number

        # lambda_per_paper = (lambda_per_reviewer * reviewer_num) / paper_num,
        # and reviewer_num == paper_num, so the two rates coincide.
        lambda_per_reviewer = lambda_per_paper
        total_reviews = lambda_per_paper * paper_num

        print(f"\n{banner}")
        print(f"Testing paper_num = {paper_num}, reviewer_num = {reviewer_num} (equal)")
        print(f"Target λ per paper = {lambda_per_paper}")
        print(f"Calculated λ per reviewer = {lambda_per_reviewer}")
        print(f"Total reviews = {total_reviews}")
        print(f"{banner}")
        print(f"Using two-stage strategy: original | eta={eta_opt}, lambda1={lambda1_opt:.3f} (lambda2 will be derived as T)")

        # Per-trial samples: samples[mode_prefix][metric] -> list of values
        samples = {
            prefix: {metric: [] for metric in _METRIC_NAMES}
            for prefix in _MODE_PREFIXES
        }

        # Keyword arguments shared by both two-stage simulations.
        two_stage_kwargs = dict(
            paper_num=paper_num,
            reviewer_num=reviewer_num,
            reviewer_quality=reviewer_quality,
            eta=eta_opt,
            lambda_per_paper=lambda_per_paper,
            stage1_reviewer_ratio=0.5,  # original note: irrelevant with robust allocator
            lambda1=lambda1_opt,  # explicit stage-1 lambda (original note: lambda2 derived inside)
        )

        for trial in range(num_trials):
            print(f"\nTrial {trial + 1}/{num_trials} for paper_num={paper_num}")

            # Every mode is re-seeded with the same value so the three modes
            # start each trial from identical RNG state.
            trial_seed = 42 + trial * 1000 + paper_num * 10

            print("Running Direct Review...")
            _seed_rngs(trial_seed)
            direct_conf = DirectReviewConference(
                paper_num=paper_num,
                reviewer_num=reviewer_num,
                reviewer_quality=reviewer_quality,
                lambda_per_paper=lambda_per_paper,
            )
            _record_trial(samples['direct'], direct_conf.run_direct_review())

            print("Running Two-Stage Ambiguous Review (original strategy)...")
            _seed_rngs(trial_seed)
            ambiguous_conf = TwoStageReviewConference(**two_stage_kwargs)
            _record_trial(samples['twostage'], ambiguous_conf.run_two_stage_review())

            print("Running Two-Stage Promising Review (original strategy)...")
            _seed_rngs(trial_seed)
            promising_conf = PromisingReviewConference(**two_stage_kwargs)
            _record_trial(samples['promising'], promising_conf.run_promising_review())

        # Build comprehensive result row with all metrics
        result_row = {
            'paper_num': paper_num,
            'reviewer_num': reviewer_num,
            'lambda_per_paper': lambda_per_paper,
            'lambda_per_reviewer': lambda_per_reviewer,
            'total_reviews': total_reviews,
        }
        for prefix in _MODE_PREFIXES:
            result_row.update(_summarize_metrics(prefix, samples[prefix]))

        results.append(result_row)

        print(f"\nResults for paper_num={paper_num} (reviewers={reviewer_num}):")
        _print_mode_summary("Direct Review", 'direct', result_row)
        _print_mode_summary("Two-Stage Ambiguous", 'twostage', result_row)
        _print_mode_summary("Two-Stage Promising", 'promising', result_row)

    # Write results to CSV: base columns, then mean/std per metric per mode.
    fieldnames = ['paper_num', 'reviewer_num', 'lambda_per_paper', 'lambda_per_reviewer', 'total_reviews']
    for prefix in _MODE_PREFIXES:
        for metric in _METRIC_NAMES:
            fieldnames.extend([f'{prefix}_{metric}_mean', f'{prefix}_{metric}_std'])

    with open(output_file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print(f"\n{banner}")
    print(f"Experiment complete! Results saved to {output_file}")
    print(f"{banner}")

    return output_file, results

if __name__ == "__main__":
    # Script entry point: run the experiment, then print a compact table of
    # F1 and KL divergence (mean±std) for each review mode and paper_num.
    output_file, results = run_paper_num_lambda_per_paper_experiment()

    # Report the configuration that was actually run by reading it back from
    # the result rows, so the message cannot drift from the experiment
    # settings (the previous hard-coded text named a different λ and a
    # different paper_num list than the experiment uses).
    lambda_used = results[0]['lambda_per_paper']
    paper_nums_used = [r['paper_num'] for r in results]

    print(f"\nSummary Table - Key Metrics (λ per paper = {lambda_used}):")
    print(f"{'Papers':<8} {'Revs':<6} {'λ/paper':<8} {'λ/rev':<8} {'Direct F1':<13} {'Direct KLD':<13} {'Ambig F1':<13} {'Ambig KLD':<13} {'Prom F1':<13} {'Prom KLD':<13}")
    print("-" * 130)
    for r in results:
        print(f"{r['paper_num']:<8} {r['reviewer_num']:<6} "
              f"{r['lambda_per_paper']:<8} {r['lambda_per_reviewer']:<8} "
              f"{r['direct_f1_score_mean']:.3f}±{r['direct_f1_score_std']:.3f}   "
              f"{r['direct_kl_divergence_mean']:.4f}±{r['direct_kl_divergence_std']:.4f}  "
              f"{r['twostage_f1_score_mean']:.3f}±{r['twostage_f1_score_std']:.3f}   "
              f"{r['twostage_kl_divergence_mean']:.4f}±{r['twostage_kl_divergence_std']:.4f}  "
              f"{r['promising_f1_score_mean']:.3f}±{r['promising_f1_score_std']:.3f}   "
              f"{r['promising_kl_divergence_mean']:.4f}±{r['promising_kl_divergence_std']:.4f}")

    print(f"\nDetailed results saved to: {output_file}")
    # Calibration Error is written to the CSV as well, so it is listed here.
    print(f"CSV contains all metrics: Accuracy, F1, Precision, Recall, BP Raw Error, Calibration Error, BP F1 Loss, BP F1 Score (50%), KL Divergence, JS Divergence")
    print(f"Configuration: λ per paper = {lambda_used} (constant), paper_num = reviewer_num, paper_nums = {paper_nums_used}")