"""
Eta Mixture Experiment: Unknown per-reviewer quality, known prior (spammer-hammer)

Goal: For fixed paper_num, reviewer_num, and total lambda per paper (T),
      explore eta and two-stage allocation strategies to find robust settings
      when reviewer qualities are sampled from a known prior mixture
      (spammer-hammer: 50% at 0.9, 50% at 0.5), but identities are unknown to BP
      which only receives the prior.

Defaults:
- paper_num = 100
- reviewer_num = 100
- reviewer quality prior (q_exp=0.9, frac_exp=0.5, q_base=0.5)
- TARGET_LAMBDA (T) = 8
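- strategies: original, constant_stage2, low_stage2,
  constant_stage2_frac_XX (one per FRAC_GRID entry, default 0.25/0.5/0.75),
  constant_stage2_frac (fraction from LAMBDA2_FRAC, default 0.5);
  filter with STRATEGIES (comma-separated keys)
- eta grid: [0.1, 0.2, ..., 0.9] (override with ETA_GRID)
- trials: 20 (override with TRIALS; QUICK=1 or --quick uses eta in [0.4, 0.6] and 5 trials)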

Outputs:
- CSV: eta_mixture_strategy_results_YYYYMMDD_HHMMSS.csv
- Summary printed with best combinations
"""

import os
import sys
import csv
import time
import random
from datetime import datetime
import numpy as np

sys.path.append('./review_modes')
from direct_review import DirectReviewConference
from two_stage_review import TwoStageReviewConference
from promising_review import PromisingReviewConference
from geometric_mean_utils import geometric_mean_safe


def get_lambda_strategies(T):
    """Build the table of two-stage lambda allocation strategies for budget ``T``.

    Every strategy exposes a ``lambda_func`` that maps ``eta`` to a pair
    ``(lambda1, lambda2)`` satisfying ``lambda1 + eta * lambda2 == T`` (up to
    floating-point rounding).

    Environment variables read:
        FRAC_GRID: comma-separated fractions (default ``"0.25,0.5,0.75"``).
            Each fraction ``f`` adds a ``constant_stage2_frac_XX`` strategy
            with ``lambda2 = f * T`` (``XX`` is ``f`` as a rounded percent).
            If any entry is not a finite number, the whole grid falls back
            to ``[0.5]``.
        LAMBDA2_FRAC: fraction for the backward-compatible
            ``constant_stage2_frac`` entry (default ``0.5``). The entry is
            skipped when the value is not a finite number.
        STRATEGIES: optional comma-separated list of strategy keys to keep.

    Args:
        T: total lambda per paper.

    Returns:
        dict mapping strategy key -> ``{'name': str, 'lambda_func': callable}``
        where ``lambda_func(eta)`` returns ``(lambda1, lambda2)``.
    """
    import math  # local import so the block does not depend on new file-level imports

    def _fraction_strategy(frac):
        # Bind `frac` as a default argument so each lambda keeps its own value.
        return {
            'name': f'Const Stage2 Fraction (lambda2={frac}*T)',
            'lambda_func': (lambda eta, _frac=frac: (T - (_frac * T) * eta, _frac * T)),
        }

    # Base fixed strategies
    strategies = {
        'original': {
            'name': 'Original (lambda1=T*(1-eta), lambda2=T)',
            'lambda_func': lambda eta: (T * (1 - eta), T),
        },
        'constant_stage2': {
            'name': 'Constant Stage2 (lambda2=6)',
            'lambda_func': lambda eta: (T - 6 * eta, 6),
        },
        'low_stage2': {
            'name': 'Low Stage2 (lambda2=3)',
            'lambda_func': lambda eta: (T - 3 * eta, 3),
        },
    }

    # Fractional Stage2 family: default FRAC_GRID="0.25,0.5,0.75"
    frac_grid_env = os.getenv('FRAC_GRID', '0.25,0.5,0.75')
    try:
        fracs = [float(x) for x in frac_grid_env.split(',') if x.strip()]
        # "nan"/"inf" parse as floats but cannot be turned into a key below;
        # treat them like any other malformed entry.
        if not all(math.isfinite(f) for f in fracs):
            raise ValueError('non-finite fraction in FRAC_GRID')
    except ValueError:
        print(
            f"Warning: invalid FRAC_GRID={frac_grid_env!r}; falling back to [0.5]",
            file=sys.stderr,
        )
        fracs = [0.5]
    for f in fracs:
        key = f"constant_stage2_frac_{int(round(f*100)):02d}"
        strategies[key] = _fraction_strategy(f)

    # Backward compat single-fraction entry controlled by LAMBDA2_FRAC
    if 'constant_stage2_frac' not in strategies:
        frac_single_env = os.getenv('LAMBDA2_FRAC', '0.5')
        try:
            frac_single = float(frac_single_env)
            if not math.isfinite(frac_single):
                raise ValueError('non-finite LAMBDA2_FRAC')
        except ValueError:
            # Best effort: the entry is optional, so report and carry on.
            print(
                f"Warning: invalid LAMBDA2_FRAC={frac_single_env!r}; "
                "skipping 'constant_stage2_frac'",
                file=sys.stderr,
            )
        else:
            strategies['constant_stage2_frac'] = _fraction_strategy(frac_single)

    # Optional filter via env STRATEGIES (comma-separated keys)
    only = os.getenv('STRATEGIES')
    if only:
        keep = {x.strip() for x in only.split(',') if x.strip()}
        strategies = {k: v for k, v in strategies.items() if k in keep}
    return strategies



# Metric keys read from every conference result dict and written to the CSV.
_METRIC_NAMES = (
    'accuracy', 'f1_score', 'precision', 'recall', 'bp_raw_error',
    'calibration_error', 'bp_f1_loss', 'bp_f1_score_50pct',
    'kl_divergence', 'js_divergence',
)


def _stable_trial_seed(strat_key, eta, trial):
    """Return an RNG seed for (strategy, eta, trial) that is the same in every process."""
    import zlib  # local import: zlib is not in the file-level import block

    # Built-in hash() of a str is salted per interpreter process, so it cannot
    # be used for reproducible seeding; crc32 is deterministic.
    strat_offset = zlib.crc32(strat_key.encode('utf-8')) % 1000
    # round() avoids float truncation (e.g. 0.57 * 1000 == 569.99...).
    return 4242 + trial * 101 + int(round(eta * 1000)) + strat_offset


def _aggregate_metrics(prefix, metrics_by_name):
    """Collapse per-trial metric lists into ``{prefix}_{metric}_mean`` / ``_std`` entries."""
    agg = {}
    for m, v in metrics_by_name.items():
        # kl_divergence is averaged with geometric_mean_safe; everything else arithmetically.
        agg[f'{prefix}_{m}_mean'] = (
            geometric_mean_safe(v) if m == 'kl_divergence' else float(np.mean(v))
        )
        agg[f'{prefix}_{m}_std'] = float(np.std(v))
    return agg


def run_eta_mixture_experiment():
    """Run the direct baseline plus every (strategy, eta) two-stage configuration.

    Configuration is read from the environment: PAPER_NUM, REVIEWER_NUM,
    TARGET_LAMBDA, Q_EXP, FRAC_EXP, Q_BASE, ETA_GRID, TRIALS and QUICK
    (``--quick`` on the command line is equivalent to ``QUICK=1``). Quick mode
    overrides the grid with ``eta in [0.4, 0.6]`` and 5 trials.

    For each configuration the "ambiguous" (TwoStageReviewConference) and
    "promising" (PromisingReviewConference) modes are run with the same seed
    per trial, and per-metric mean/std are collected. One CSV row is written
    per (strategy, eta) to ``eta_mixture_strategy_results_<timestamp>.csv`` in
    the current working directory, and a summary is printed.

    Returns:
        tuple ``(out_csv, results)``: the CSV filename and the list of row dicts.
    """
    # Fixed defaults (overridable through the environment)
    paper_num = int(os.getenv('PAPER_NUM', '100'))
    reviewer_num = int(os.getenv('REVIEWER_NUM', '100'))
    T = float(os.getenv('TARGET_LAMBDA', '8'))  # lambda per paper

    # Reviewer prior (spammer-hammer)
    q_exp = float(os.getenv('Q_EXP', '0.9'))
    frac_exp = float(os.getenv('FRAC_EXP', '0.5'))
    q_base = float(os.getenv('Q_BASE', '0.5'))
    reviewer_prior_tuple = (q_exp, frac_exp, q_base)
    # For logging only (average quality)
    reviewer_quality_avg = q_exp * frac_exp + q_base * (1 - frac_exp)

    # Grid; blank entries (e.g. a trailing comma) are skipped, as for FRAC_GRID
    eta_grid_env = os.getenv('ETA_GRID', '0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9')
    eta_values = [float(x) for x in eta_grid_env.split(',') if x.strip()]
    num_trials = int(os.getenv('TRIALS', '20'))

    quick = ('--quick' in sys.argv) or (os.getenv('QUICK', '0') == '1')
    if quick:
        eta_values = [0.4, 0.6]
        num_trials = 5

    strategies = get_lambda_strategies(T)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_csv = f'eta_mixture_strategy_results_{timestamp}.csv'

    # Direct baseline (same T, mixture prior, eta-independent)
    direct_metrics = {k: [] for k in _METRIC_NAMES}

    print('='*80)
    print('RUNNING DIRECT REVIEW BASELINE (mixture prior)')
    print('='*80)
    for trial in range(num_trials):
        seed = 2025 + trial * 997
        np.random.seed(seed)
        random.seed(seed)
        exp = DirectReviewConference(
            paper_num=paper_num,
            reviewer_num=reviewer_num,
            reviewer_quality=reviewer_quality_avg,
            lambda_per_paper=T,
            reviewer_prior_tuple=reviewer_prior_tuple,
        )
        res = exp.run_direct_review()
        for k in direct_metrics:
            direct_metrics[k].append(res[k])

    # Precompute baseline aggregates (copied into every output row)
    direct_agg = _aggregate_metrics('direct', direct_metrics)

    # Iterate strategies x eta
    results = []
    for strat_key, strat in strategies.items():
        for eta in eta_values:
            lam1, lam2 = strat['lambda_func'](eta)
            if lam1 < 0:
                # Skip invalid
                print(f"Skip invalid: {strat_key} eta={eta} -> lambda1={lam1:.3f}")
                continue

            print('\n' + '='*60)
            print(f"Strategy={strat_key} | eta={eta} | T={T} | lambda1={lam1:.3f} | lambda2={lam2:.3f}")
            print('='*60)

            amb_metrics = {k: [] for k in _METRIC_NAMES}
            pro_metrics = {k: [] for k in _METRIC_NAMES}

            # NOTE(review): lam2 is only logged; the conferences receive
            # lambda_per_paper and lambda1 and presumably derive the stage-2
            # load from those -- confirm against the review_modes classes.
            for trial in range(num_trials):
                seed = _stable_trial_seed(strat_key, eta, trial)
                np.random.seed(seed)
                random.seed(seed)

                # Ambiguous
                exp2 = TwoStageReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality_avg,
                    eta=eta,
                    lambda_per_paper=T,
                    stage1_reviewer_ratio=0.5,
                    lambda1=lam1,
                    reviewer_prior_tuple=reviewer_prior_tuple,
                )
                res2 = exp2.run_two_stage_review()
                for k in amb_metrics:
                    amb_metrics[k].append(res2[k])

                # Promising: re-seed so both modes start from the same RNG state
                np.random.seed(seed)
                random.seed(seed)
                exp3 = PromisingReviewConference(
                    paper_num=paper_num,
                    reviewer_num=reviewer_num,
                    reviewer_quality=reviewer_quality_avg,
                    eta=eta,
                    lambda_per_paper=T,
                    stage1_reviewer_ratio=0.5,
                    lambda1=lam1,
                    reviewer_prior_tuple=reviewer_prior_tuple,
                )
                res3 = exp3.run_promising_review()
                for k in pro_metrics:
                    pro_metrics[k].append(res3[k])

            # Aggregate row
            row = {
                'strategy': strat_key,
                'strategy_name': strat['name'],
                'eta': eta,
                'eta_percent': eta*100.0,
                'paper_num': paper_num,
                'reviewer_num': reviewer_num,
                'target_total_lambda': T,
                'lambda1': lam1,
                'lambda2': lam2,
                'actual_total_lambda': T,
                'papers_in_stage2': int(paper_num * eta),
                'stage1_tasks': int(paper_num * lam1),
                'stage2_tasks': int(paper_num * eta * lam2),
                'total_tasks': int(paper_num * lam1 + paper_num * eta * lam2),
            }
            row.update(direct_agg)
            row.update(_aggregate_metrics('ambiguous', amb_metrics))
            row.update(_aggregate_metrics('promising', pro_metrics))

            results.append(row)

    # Save CSV: descriptive columns first, then mean/std per mode and metric
    fields = ['strategy','strategy_name','eta','eta_percent','paper_num','reviewer_num','target_total_lambda','lambda1','lambda2','actual_total_lambda','papers_in_stage2','stage1_tasks','stage2_tasks','total_tasks']
    for mode in ['direct','ambiguous','promising']:
        for m in _METRIC_NAMES:
            fields.extend([f'{mode}_{m}_mean', f'{mode}_{m}_std'])

    with open(out_csv, 'w', newline='', encoding='utf-8') as fh:
        w = csv.DictWriter(fh, fieldnames=fields)
        w.writeheader()
        w.writerows(results)

    # Summary
    print('\n' + '='*80)
    print('MIXTURE ETA EXPERIMENT SUMMARY')
    print('='*80)
    if results:
        # "Balanced" maximises the worse of the two modes' mean accuracy.
        best_bal = max(results, key=lambda x: min(x['ambiguous_accuracy_mean'], x['promising_accuracy_mean']))
        best_amb = max(results, key=lambda x: x['ambiguous_accuracy_mean'])
        best_pro = max(results, key=lambda x: x['promising_accuracy_mean'])
        print(f"Best Balanced: {best_bal['strategy']} eta={best_bal['eta']:.2f} l1={best_bal['lambda1']:.2f} l2={best_bal['lambda2']:.2f} | amb={best_bal['ambiguous_accuracy_mean']:.4f} prom={best_bal['promising_accuracy_mean']:.4f}")
        print(f"Best Ambiguous: {best_amb['strategy']} eta={best_amb['eta']:.2f} acc={best_amb['ambiguous_accuracy_mean']:.4f}")
        print(f"Best Promising: {best_pro['strategy']} eta={best_pro['eta']:.2f} acc={best_pro['promising_accuracy_mean']:.4f}")
    print(f"Direct baseline (avg): acc={direct_agg['direct_accuracy_mean']:.4f} ± {direct_agg['direct_accuracy_std']:.4f}")
    print(f"Results saved: {out_csv}")
    print(f"Config: papers={paper_num}, reviewers={reviewer_num}, T={T}, prior=(q_exp={q_exp}, frac={frac_exp}, q_base={q_base})")

    return out_csv, results


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_eta_mixture_experiment()
