#!/usr/bin/env python3
# p3_run_scenarios.py

import json
import numpy as np
import pandas as pd
import sys
import os
import pickle
import traceback

# Put this script's own directory first on the import path so the sibling
# module p3_utils resolves no matter which working directory we run from.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path[0:0] = [current_dir]

from p3_utils import DataHandler, Allocators, Evaluator, DifficultyPredictor

# --- CONFIGURATION ---

DETAILED_DATA_PATH = './data/generations/test_ood_oracle_30B_16k.parquet'
TOKENIZER_PATH = "./models/Qwen3-30B-A3B-Instruct-2507"
PREDICTOR_MODEL_PATH = "./models/predictor-30b/best_checkpoint"
OUTPUT_JSON = "./results/exp2_multi_scenario_metrics_30b_16k.json"
ALLOCATIONS_PKL = "./results/exp2_allocations_30b_16k.pkl"

TARGET_SAMPLES = 500
MAX_TOKEN_LENGTH = 16384

SCENARIOS = {
    "Balanced":   {'ratios': [0.25, 0.25, 0.25, 0.25]},
    "MostlyEasy": {'ratios': [0.7, 0.2, 0.07, 0.03]},
    "MostlyHard": {'ratios': [0.1, 0.2, 0.35, 0.35]},
    "UShaped":    {'ratios': [0.4, 0.1, 0.1, 0.4]} 
}
BUDGETS = [256, 512, 1024, 1536]

class SimpleDataHandler:
    """Load a parquet file of generations and draw difficulty-stratified samples.

    Difficulty is defined by binning the ``oracle_length`` column into
    half-open intervals ``[lo, hi)``; bin ``i`` is difficulty level ``i``.
    """

    # Default bin edges over ``oracle_length``: four levels
    # [0, 512), [512, 1024), [1024, 1536), [1536, inf).
    DEFAULT_BIN_EDGES = (0, 512, 1024, 1536, float('inf'))

    def __init__(self, data_path):
        """Read ``data_path`` (a parquet file) into ``self.df``.

        Raises:
            FileNotFoundError: if ``data_path`` does not exist.
        """
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file not found: {data_path}")
        self.df = pd.read_parquet(data_path)
        print(f"Loaded data: {len(self.df)} samples")

    def sample_scenario_by_difficulty(self, difficulty_ratios, total_n, seed=42, bin_edges=None):
        """Draw roughly ``total_n`` rows whose difficulty mix follows ``difficulty_ratios``.

        Args:
            difficulty_ratios: fraction of ``total_n`` to take from each bin;
                entry ``i`` corresponds to bin ``i``.
            total_n: requested total number of rows.
            seed: ``random_state`` used for per-bin sampling and final shuffle.
            bin_edges: optional monotonically increasing edges over
                ``oracle_length``; defaults to ``DEFAULT_BIN_EDGES``.

        Returns:
            A shuffled DataFrame with a fresh ``RangeIndex``. Bins whose pool
            is smaller than the request are sampled with replacement; empty
            bins are skipped, so the result can hold fewer than ``total_n``
            rows.

        Raises:
            ValueError: if no bin contributed any rows.
        """
        edges = list(self.DEFAULT_BIN_EDGES if bin_edges is None else bin_edges)
        n_bins = len(edges) - 1
        # Bin labels must number exactly one fewer than the edges. Bins are
        # computed as a local Series so self.df is never mutated (the old
        # temporary column could leak if an exception fired mid-sampling).
        length_bin = pd.cut(
            self.df['oracle_length'],
            bins=edges,
            labels=list(range(n_bins)),
            right=False,
            include_lowest=True,
        )

        sampled_dfs = []
        for level, ratio in enumerate(difficulty_ratios):
            # round() instead of int(): int(0.29 * 100) == 28 because of
            # floating-point error in the product.
            n_needed = int(round(ratio * total_n))
            if n_needed <= 0:
                continue
            pool_df = self.df[length_bin == level]
            if pool_df.empty:
                print(f"Warning: difficulty bin {level} is empty; "
                      f"skipping {n_needed} requested samples")
                continue
            # Only oversample (with replacement) when the pool is too small.
            replace = len(pool_df) < n_needed
            sampled_dfs.append(pool_df.sample(n=n_needed, replace=replace, random_state=seed))

        if not sampled_dfs:
            raise ValueError("No samples collected!")
        final_df = pd.concat(sampled_dfs, ignore_index=True)
        # Shuffle so bins are interleaved rather than grouped by difficulty.
        return final_df.sample(frac=1, random_state=seed).reset_index(drop=True)

def _evaluate_budget(sc_name, current_df, preds, oracles, avg_budget, evaluator):
    """Allocate ``avg_budget`` tokens per sample with every strategy and score each.

    Returns:
        ``(budget_metrics, viz_entry, row_data)``: per-strategy accuracy as a
        fraction (for the JSON output), prediction/allocation arrays (for the
        pickle), and one summary-table row with accuracy in percent.
    """
    # Size the budget by the rows actually sampled, not TARGET_SAMPLES: the
    # sampler returns fewer rows when a difficulty bin is empty, and the
    # uniform allocation must have exactly one entry per row.
    n_samples = len(current_df)
    total_budget = avg_budget * n_samples

    strategies = {}
    strategies['uniform'] = Allocators.uniform(n_samples, total_budget)
    strategies['oracle'] = Allocators.oracle(oracles, total_budget)
    strategies['pred_direct'] = Allocators.pred_direct(preds, total_budget, min_tokens=0, max_tokens=MAX_TOKEN_LENGTH)
    strategies['daba_heur'] = Allocators.daba_heur(preds, total_budget, min_tokens=0, max_tokens=MAX_TOKEN_LENGTH)
    strategies['daba_auction'], admitted_cnt = Allocators.daba_auction(preds, total_budget, min_tokens_survivor=0)
    strategies['daba_lambert'] = Allocators.daba_lambert(preds, total_budget, alpha=2.0, t_max=MAX_TOKEN_LENGTH)

    viz_entry = {
        "preds": preds,
        "true_lengths": oracles,
        "allocations": {k: v.copy() for k, v in strategies.items()},
    }

    row_data = {"Scenario": sc_name, "Budget": avg_budget, "Admitted_Auc": admitted_cnt}
    budget_metrics = {}
    for strat_name, allocs in strategies.items():
        correct_mask = evaluator.evaluate_truncation(current_df, allocs)
        acc = float(np.mean(correct_mask))
        budget_metrics[strat_name] = {"global_acc": acc}
        row_data[strat_name] = acc * 100
    return budget_metrics, viz_entry, row_data


def _run_scenario(sc_name, difficulty_ratios, data_handler, evaluator, predictor):
    """Sample one scenario, predict lengths once, and evaluate every budget.

    Returns:
        ``(sc_results, sc_viz_data, rows)`` keyed/ordered by ``BUDGETS``.
    """
    current_df = data_handler.sample_scenario_by_difficulty(difficulty_ratios, TARGET_SAMPLES)
    prompts = current_df['question'].tolist()
    pred_logs = predictor.predict_batch(prompts, batch_size=32)
    # NOTE(review): predict_batch output is exponentiated, so it is presumably
    # log-length; confirm against DifficultyPredictor.
    preds = np.exp(pred_logs)
    oracles = current_df['oracle_length'].values

    sc_results = {}
    sc_viz_data = {}
    rows = []
    for avg_budget in BUDGETS:
        budget_metrics, viz_entry, row_data = _evaluate_budget(
            sc_name, current_df, preds, oracles, avg_budget, evaluator)
        sc_viz_data[avg_budget] = viz_entry
        sc_results[str(avg_budget)] = budget_metrics
        rows.append(row_data)
    return sc_results, sc_viz_data, rows


def _save_outputs(final_results, all_viz_data):
    """Write metrics as JSON and allocation arrays as a pickle, creating dirs."""
    for path in (OUTPUT_JSON, ALLOCATIONS_PKL):
        out_dir = os.path.dirname(path)
        # dirname is '' for a bare filename; makedirs('') would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    with open(OUTPUT_JSON, 'w') as f:
        json.dump(final_results, f, indent=2)
    with open(ALLOCATIONS_PKL, 'wb') as f:
        pickle.dump(all_viz_data, f)


def _print_summary(table_rows):
    """Print the accuracy-percentage table, one row per (scenario, budget)."""
    print("\n" + "=" * 130)
    print("FINAL PERFORMANCE SUMMARY (Accuracy %)")
    print("=" * 130)
    if not table_rows:
        print("No scenario produced results.")
    else:
        df_results = pd.DataFrame(table_rows)
        cols = ["Scenario", "Budget", "uniform", "pred_direct", "daba_heur", "daba_auction", "daba_lambert", "oracle"]
        cols = [c for c in cols if c in df_results.columns]
        print(df_results[cols].to_string(index=False, float_format="%.2f"))
    print("=" * 130)


def main():
    """Evaluate every allocation strategy across all SCENARIOS and BUDGETS.

    Writes per-scenario metrics to OUTPUT_JSON and raw predictions and
    allocations to ALLOCATIONS_PKL, then prints a summary table. A failing
    scenario is reported and skipped; the remaining ones still run.

    Returns:
        0 if every scenario succeeded, 1 if initialization failed (nothing is
        written) or any scenario raised.
    """
    print("=" * 100)
    print("Phase 3b: Multi-Scenario Budget Evaluation (Surge-Filling Beta)")
    print("=" * 100)

    try:
        data_handler = SimpleDataHandler(DETAILED_DATA_PATH)
        evaluator = Evaluator(TOKENIZER_PATH, use_fast_mode=True)
        predictor = DifficultyPredictor(PREDICTOR_MODEL_PATH)
    except Exception as e:
        print(f"Init Error: {e}")
        traceback.print_exc()
        return 1

    final_results = {}
    table_rows = []
    all_viz_data = {}
    failed = []

    for sc_name, config in SCENARIOS.items():
        print(f"\n>>> Running Scenario: {sc_name} <<<")
        try:
            sc_results, sc_viz_data, rows = _run_scenario(
                sc_name, config['ratios'], data_handler, evaluator, predictor)
        except Exception as e:
            # Best-effort: record the failure and continue with other scenarios.
            print(f"!!! Error in {sc_name}: {e}")
            traceback.print_exc()
            failed.append(sc_name)
            continue
        final_results[sc_name] = sc_results
        all_viz_data[sc_name] = sc_viz_data
        table_rows.extend(rows)

    _save_outputs(final_results, all_viz_data)
    _print_summary(table_rows)

    if failed:
        print(f"Failed scenarios: {', '.join(failed)}")
        return 1
    return 0

if __name__ == "__main__":
    # Propagate main()'s status so shell pipelines and job schedulers see
    # failures. sys.exit(None) exits with status 0.
    sys.exit(main())