#!/usr/bin/env python3
# p3_utils.py

import pandas as pd
import numpy as np
import json
import torch
import os
from tqdm import tqdm
from scipy.special import lambertw
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class DifficultyPredictor:
    """Score texts with a sequence-classification checkpoint loaded as a regressor.

    The model is loaded with ``num_labels=1`` and ``problem_type="regression"``,
    moved to ``device`` and switched to eval mode.
    """

    def __init__(self, model_path, device=None):
        """Load tokenizer and model from ``model_path``.

        Args:
            model_path: Path or identifier handed to ``from_pretrained``.
            device: Torch device string. Falls back to ``"cuda"`` when
                ``torch.cuda.is_available()`` and to ``"cpu"`` otherwise.

        Raises:
            Exception: Any loading error is printed and then re-raised unchanged.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing Predictor on {self.device}...")
        print(f"Model path: {model_path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                num_labels=1,
                problem_type="regression"
            )
            self.model.to(self.device)
            self.model.eval()
            print(f"✓ Predictor loaded successfully from: {model_path}")
        except Exception as e:
            print(f"Error loading predictor: {e}")
            # Bare ``raise`` keeps the original traceback intact.
            raise

    def predict_batch(self, texts, batch_size=32):
        """Return one regression score per text.

        Args:
            texts: A single string or any iterable of strings.
            batch_size: Number of texts tokenized and scored per forward pass.

        Returns:
            1-D numpy array with ``len(texts)`` predictions; an empty array when
            no texts are given.
        """
        # A bare string would otherwise be sliced character-wise below.
        if isinstance(texts, str):
            texts = [texts]
        else:
            # Materialise so slicing is positional (Series, tuples, generators).
            texts = list(texts)

        # np.concatenate rejects an empty list, so short-circuit here.
        if not texts:
            return np.empty(0, dtype=np.float32)

        all_preds = []
        with torch.no_grad():
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]
                inputs = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.device)
                outputs = self.model(**inputs)
                # Drop the trailing label axis so each batch contributes a 1-D array.
                logits = outputs.logits.squeeze(-1)
                all_preds.append(logits.cpu().numpy())
        return np.concatenate(all_preds)


class DataHandler:
    """Hold one sample pool per difficulty group and draw mixed scenarios from them."""

    def __init__(self, file_paths_dict):
        """Load and concatenate the parquet files of every group.

        Args:
            file_paths_dict: Mapping ``group name -> iterable of parquet paths``.
                Each group's files are concatenated into a single DataFrame.
                A ``predicted_length`` column filled with ``0.0`` is added when
                the data does not provide one.
        """
        self.pools = {}
        for group, paths in file_paths_dict.items():
            frames = [pd.read_parquet(p) for p in paths]
            pool = pd.concat(frames, ignore_index=True)
            if 'predicted_length' not in pool.columns:
                pool['predicted_length'] = 0.0
            self.pools[group] = pool
            print(f"Pool [{group}] loaded: {len(self.pools[group])} samples.")

    def sample_scenario(self, ratios, total_n, seed=42):
        """Draw a shuffled mix of samples from the easy / mid / hard pools.

        Args:
            ratios: Sequence whose positions 0, 1, 2 give the share of
                ``total_n`` taken from 'easy', 'mid' and 'hard' respectively.
            total_n: Target number of samples before per-group flooring.
            seed: Random state used for every per-group draw and the final shuffle.

        Returns:
            DataFrame of the sampled rows with an added ``source_group`` column,
            shuffled and re-indexed from 0.

        Raises:
            ValueError: If no group contributed any sample.
        """
        sampled_dfs = []
        groups = ['easy', 'mid', 'hard']
        for i, group in enumerate(groups):
            if group not in self.pools:
                continue
            # Floor with a small tolerance: plain int() turns 0.29 * 100
            # (= 28.999999999999996) into 28 and silently loses a sample.
            n_needed = int(ratios[i] * total_n + 1e-9)
            if n_needed == 0:
                continue
            pool = self.pools[group]
            # Only sample with replacement when the pool is too small.
            with_replacement = len(pool) < n_needed
            sampled = pool.sample(n=n_needed, replace=with_replacement, random_state=seed).copy()
            sampled['source_group'] = group
            sampled_dfs.append(sampled)
        if not sampled_dfs:
            raise ValueError("No samples were collected!")
        final_df = pd.concat(sampled_dfs, ignore_index=True)
        return final_df.sample(frac=1, random_state=seed).reset_index(drop=True)


# ==========================================
# [Core] DABA Lambert W Optimizer
# ==========================================
class DABAOptimizer:
    def __init__(self, alpha=2.0, beta=0.002, t_max=8192, epsilon=1e-6):
        """
        DABA: Difficulty-Aware Budget Allocation using Shadow Price Optimization.

        Args:
            alpha: Scale of the per-sample utility; a shadow price at or above
                it makes every sample be abandoned.
            beta: Decay rate of the utility in the tokens granted beyond the
                predicted length.
            t_max: Hard upper bound on a single allocation.
            epsilon: Relative tolerance on the budget used to stop the bisection early.
        """
        self.alpha = alpha
        self.beta = beta
        self.t_max = t_max
        self.epsilon = epsilon

    def _headroom(self, lam):
        """Return the surge ``(1 - W(lam * e / alpha)) / beta`` for shadow price ``lam``.

        The value does not depend on the sample, so callers compute it once per
        price. ``0.0`` is returned when the price reaches ``alpha`` (rational
        abandonment) or when the arithmetic fails.
        """
        # 1. Rational Abandonment
        if lam >= self.alpha:
            return 0.0

        # 2. Lambert W Solution
        z = (lam * np.exp(1)) / self.alpha
        try:
            w_val = np.real(lambertw(z))
            dt = (1.0 - w_val) / self.beta
        except (ArithmeticError, ValueError):
            # Only numeric failures mean "no allocation"; anything else propagates.
            return 0.0
        return dt

    def _calc_optimal_t(self, tau, lam):
        """Return the allocation for one sample of predicted length ``tau`` at price ``lam``.

        The unconstrained optimum is ``tau + dt`` with ``dt`` from ``_headroom``.
        The sample gets ``0.0`` when ``dt`` is not positive or when its utility
        ``alpha * dt * exp(-beta * dt)`` does not exceed the cost ``lam * (tau + dt)``.
        Otherwise the optimum is clipped to ``[0, t_max]``.
        """
        dt = self._headroom(lam)

        # 3. Solvency Check
        if dt <= 0:
            return 0.0
        t_unc = tau + dt
        phi = self.alpha * dt * np.exp(-self.beta * dt)
        cost = lam * t_unc
        if phi <= cost:
            return 0.0

        # 4. Hard Truncation
        return np.clip(t_unc, 0, self.t_max)

    def _calc_optimal_t_batch(self, taus, lam):
        """Vectorised ``_calc_optimal_t``: allocations for all ``taus`` at one price."""
        dt = self._headroom(lam)
        if dt <= 0:
            return np.zeros(taus.shape)
        t_unc = taus + dt
        phi = self.alpha * dt * np.exp(-self.beta * dt)
        allocs = np.clip(t_unc, 0, self.t_max)
        # Samples whose utility does not cover their cost are dropped entirely.
        allocs[phi <= lam * t_unc] = 0.0
        return allocs

    def solve(self, pred_lengths, total_budget):
        """Bisect on the shadow price until the allocations fit ``total_budget``.

        Args:
            pred_lengths: Predicted lengths, one per sample.
            total_budget: Total number of tokens to distribute.

        Returns:
            Integer numpy array of per-sample allocations (floored). It is the
            last allocation whose total did not exceed the budget, or the one
            that met the budget within ``epsilon``; all zeros if none was found.
        """
        taus = np.asarray(pred_lengths, dtype=float)
        n = len(taus)

        low = 0.0
        high = self.alpha + 1e-5
        best_allocations = np.zeros(n)

        for _ in range(40):
            mid_lam = (low + high) / 2

            # The Lambert-W term is shared by all samples, so evaluate it once
            # per price instead of once per sample.
            current_allocs = self._calc_optimal_t_batch(taus, mid_lam)
            total_usage = np.sum(current_allocs)

            if abs(total_usage - total_budget) < self.epsilon * total_budget:
                best_allocations = current_allocs
                break

            if total_usage > total_budget:
                # Over budget: raise the price.
                low = mid_lam
            else:
                # Feasible: remember it and try a lower price.
                high = mid_lam
                best_allocations = current_allocs

        return np.floor(best_allocations).astype(int)


class Allocators:
    """Token-budget allocation strategies.

    Every public method distributes ``total_budget`` tokens over a set of
    samples and returns an integer numpy array of per-sample budgets
    (``daba_auction`` additionally returns the number of admitted samples).
    """

    @staticmethod
    def _rebalance(allocs, amount, order, floor, ceiling):
        """Add (``amount > 0``) or remove (``amount < 0``) tokens round-robin, in place.

        Samples are visited in ``order``, one token per visit. A sample already
        at ``ceiling`` (when adding) or at ``floor`` (when removing) is skipped
        without consuming any of ``amount``. Returns the tokens that could not
        be placed because every sample hit its bound.
        """
        step = 1 if amount > 0 else -1
        left = abs(int(amount))
        order = np.asarray(order)
        while left > 0 and order.size > 0:
            current = allocs[order]
            room = (ceiling - current) if step > 0 else (current - floor)
            open_slots = room > 0
            active = order[open_slots]
            if active.size == 0:
                break
            # Run as many full passes as fit before someone hits its bound.
            passes = min(left // active.size, max(1, int(room[open_slots].min())))
            if passes > 0:
                allocs[active] += step * passes
                left -= passes * active.size
            else:
                # Fewer tokens than open samples: the first ones in order get one each.
                allocs[active[:left]] += step
                left = 0
        return left

    @staticmethod
    def uniform(n_samples, total_budget, min_tokens=32):
        """Give every sample the same budget.

        Each sample gets ``max(min_tokens, total_budget // n_samples)``; the
        division remainder is handed out one token each to the last samples.
        When ``min_tokens`` forces the total above ``total_budget`` the floor
        wins and the allocation is left over budget.

        Returns:
            Integer array of length ``n_samples`` (empty when ``n_samples <= 0``).
        """
        if n_samples <= 0:
            return np.zeros(0, dtype=int)
        avg = max(min_tokens, total_budget // n_samples)
        allocs = np.full(n_samples, avg, dtype=int)
        leftover = int(total_budget - np.sum(allocs))
        if leftover > 0:
            allocs[n_samples - min(leftover, n_samples):] += 1
        return allocs

    @staticmethod
    def oracle(oracle_lengths, total_budget, min_tokens=32):
        """Fill samples up to their true length, shortest first.

        Every sample starts at ``min_tokens``. If the budget cannot cover that
        floor, all samples start at 0 and the whole budget is spent greedily.
        The remaining budget then tops samples up to their (ceiled) length in
        ascending order of length until it runs out. NaN lengths are skipped.

        Returns:
            Integer array aligned with ``oracle_lengths``.
        """
        lengths = np.asarray(oracle_lengths, dtype=float)
        n = lengths.size
        allocs = np.full(n, min_tokens, dtype=int)
        remaining = total_budget - n * min_tokens
        if remaining < 0:
            remaining = total_budget
            allocs = np.zeros(n, dtype=int)
        # argsort yields integer positions; DataFrame.iterrows would upcast the
        # index column to float for float lengths and break array indexing.
        for idx in np.argsort(lengths, kind="stable"):
            if remaining <= 0:
                break
            length = lengths[idx]
            if np.isnan(length):
                continue
            shortfall = np.ceil(length) - allocs[idx]
            if shortfall > 0:
                grant = int(min(shortfall, remaining))
                allocs[idx] += grant
                remaining -= grant
        return allocs

    @staticmethod
    def pred_direct(pred_lengths, total_budget, min_tokens=32, max_tokens=4096):
        """Allocate proportionally to the predicted lengths.

        Predictions are floored at ``1e-6``, scaled so they sum to
        ``total_budget``, clipped to ``[min_tokens, max_tokens]`` and truncated
        to integers. The difference to the budget is then spread one token at a
        time: a surplus goes to the largest predictions first, a deficit is
        taken from the smallest first, never crossing the clip bounds.

        Returns:
            Integer array aligned with ``pred_lengths`` (empty for empty input).
        """
        preds = np.asarray(pred_lengths, dtype=float)
        n_samples = preds.size
        if n_samples == 0:
            return np.zeros(0, dtype=int)
        preds = np.maximum(preds, 1e-6)

        scale_factor = total_budget / np.sum(preds)
        final_allocs = np.clip(preds * scale_factor, min_tokens, max_tokens).astype(int)

        diff = int(total_budget - np.sum(final_allocs))
        if diff > 0:
            by_pred_desc = np.argsort(preds, kind="stable")[::-1]
            Allocators._rebalance(final_allocs, diff, by_pred_desc, min_tokens, max_tokens)
        elif diff < 0:
            by_pred_asc = np.argsort(preds, kind="stable")
            Allocators._rebalance(final_allocs, diff, by_pred_asc, min_tokens, max_tokens)
        return final_allocs

    @staticmethod
    def daba_auction(pred_lengths, total_budget, min_tokens_survivor=32, max_tokens=4096):
        """Admit the cheapest samples that fit the budget and split it among them.

        Samples are sorted by prediction and admitted while the running sum of
        ``1.1 * prediction`` stays within ``total_budget``; at least one sample
        is admitted when the budget is positive. Survivors share the budget
        proportionally to their prediction, clipped to
        ``[min_tokens_survivor, max_tokens]``. If clipping pushes the total over
        budget, the excess is removed one token at a time starting with the
        largest predictions (down to 0 if needed). Non-admitted samples get 0.

        Returns:
            Tuple ``(allocations, admitted_count)`` with ``admitted_count`` as float.
        """
        preds = np.asarray(pred_lengths, dtype=float)
        n_samples = preds.size
        by_cost = np.argsort(preds, kind="stable")
        cumulative_costs = np.cumsum(preds[by_cost] * 1.1)
        admitted_count = int(np.searchsorted(cumulative_costs, total_budget, side='right'))
        if admitted_count == 0 and total_budget > 0:
            admitted_count = 1

        # Global indices of the survivors, ascending by prediction.
        survivors = by_cost[:admitted_count]
        final_allocs = np.zeros(n_samples, dtype=int)
        if survivors.size > 0:
            survivor_preds = preds[survivors]
            sum_surv = np.sum(survivor_preds)
            if sum_surv > 0:
                raw = survivor_preds * (total_budget / sum_surv)
                final_allocs[survivors] = np.clip(raw, min_tokens_survivor, max_tokens).astype(int)

            excess = int(np.sum(final_allocs) - total_budget)
            if excess > 0:
                # Trim from the largest predictions first. The order is built
                # from global indices so the trimmed sample is really the one
                # the ranking refers to.
                Allocators._rebalance(final_allocs, -excess, survivors[::-1], 0, max_tokens)
        return final_allocs, float(survivors.size)

    @staticmethod
    def daba_heur(pred_lengths, total_budget, min_tokens=32, max_tokens=4096):
        """Median-cutoff heuristic.

        When the per-sample budget is below 80% of the mean prediction, only
        samples at or below the median prediction survive and share the budget
        proportionally (clipped to ``[min_tokens, max_tokens]``, not
        re-balanced); the others get 0. Otherwise this is ``pred_direct``.

        Returns:
            Integer array aligned with ``pred_lengths`` (empty for empty input).
        """
        preds = np.asarray(pred_lengths, dtype=float)
        n_samples = preds.size
        if n_samples == 0:
            return np.zeros(0, dtype=int)
        mu_d = np.mean(preds)
        bar_B = total_budget / n_samples
        if bar_B < 0.8 * mu_d:
            cutoff = np.median(preds)
            survivor_mask = preds <= cutoff
            if np.sum(survivor_mask) == 0:
                survivor_mask = np.ones(n_samples, dtype=bool)
            final_allocs = np.zeros(n_samples, dtype=int)
            survivor_preds = preds[survivor_mask]
            sum_s = np.sum(survivor_preds)
            if sum_s > 0:
                raw = survivor_preds * (total_budget / sum_s)
                final_allocs[survivor_mask] = np.clip(raw, min_tokens, max_tokens).astype(int)
            return final_allocs
        else:
            return Allocators.pred_direct(preds, total_budget, min_tokens, max_tokens)

    # ==========================================
    # [Updated] DABA Lambert: Surge-Filling Beta
    # ==========================================
    @staticmethod
    def daba_lambert(pred_lengths, total_budget, alpha=2.0, t_max=8192):
        """
        DABA Lambert W (Headroom-Aware Beta)

        Adjusts beta so that the 'Surge' (1/beta) fills the gap between Budget and Prediction.
        Prevents over-allocation in 'MostlyEasy' scenarios.

        The target surge is the larger of 20% of the mean prediction and the gap
        between the per-sample budget and the mean prediction. Allocations come
        from ``DABAOptimizer.solve`` with ``beta = 1 / target_surge``.

        Returns:
            Integer array aligned with ``pred_lengths``; all zeros when the
            target surge is not positive, empty for empty input.
        """
        preds = np.array(pred_lengths)
        n = len(preds)
        if n == 0:
            return np.zeros(0, dtype=int)

        avg_budget = total_budget / n
        mu_d = np.mean(preds)

        gap = avg_budget - mu_d

        target_surge = max(0.2 * mu_d, gap)
        # A non-positive surge would give an infinite or negative beta, for
        # which the optimizer allocates nothing; return that result directly.
        if not target_surge > 0:
            return np.zeros(n, dtype=int)
        adaptive_beta = 1.0 / target_surge
        optimizer = DABAOptimizer(alpha=alpha, beta=adaptive_beta, t_max=t_max)
        return optimizer.solve(preds, total_budget)
   
class Evaluator:
    """Decide, per sample, whether a token budget is enough for a correct answer."""

    def __init__(self, tokenizer_path, use_fast_mode=True):
        """Load the tokenizer and, outside fast mode, build the answer verifier.

        Args:
            tokenizer_path: Path or identifier handed to ``AutoTokenizer.from_pretrained``.
            use_fast_mode: When True, a truncated response always counts as
                wrong and no verifier is built.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.use_fast_mode = use_fast_mode
        if not use_fast_mode:
            # NOTE(review): math_metric, LatexExtractionConfig and
            # ExprExtractionConfig are not imported in the visible part of this
            # file (they look like the Math-Verify API); confirm an import
            # exists, otherwise this branch raises NameError.
            self.verify_func = math_metric(
                gold_extraction_target=(LatexExtractionConfig(), ExprExtractionConfig()),
                pred_extraction_target=(LatexExtractionConfig(), ExprExtractionConfig()),
                aggregation_function=max,
                precision=6
            )

    def evaluate_truncation(self, df, allocs, batch_size=100):
        """Score every row of ``df`` under its allocated budget.

        A non-positive budget is a failure. A budget of at least
        ``oracle_length`` yields the row's ``is_correct`` value (False when the
        column is missing). A smaller budget is a failure in fast mode and is
        otherwise checked by verifying the truncated response.

        Args:
            df: DataFrame with an ``oracle_length`` column, optionally
                ``is_correct``, and the columns ``_verify_truncated`` reads
                when fast mode is off.
            allocs: Per-row budgets, positionally aligned with ``df``.
            batch_size: Rows per progress-bar step.

        Returns:
            numpy array with one result per row.

        Raises:
            ValueError: If ``allocs`` has fewer entries than ``df`` has rows.
        """
        # Positional array: a pandas Series would be looked up by label below.
        allocs = np.asarray(allocs)
        n_samples = len(df)
        if len(allocs) < n_samples:
            raise ValueError(
                f"allocs has {len(allocs)} entries but df has {n_samples} rows"
            )

        results = []
        for start_idx in tqdm(range(0, n_samples, batch_size), desc="Evaluating allocations"):
            end_idx = min(start_idx + batch_size, n_samples)
            batch_df = df.iloc[start_idx:end_idx]
            batch_allocs = allocs[start_idx:end_idx]
            for budget, (_, row) in zip(batch_allocs, batch_df.iterrows()):
                # Negative budgets would slice from the end of the token list.
                if budget <= 0:
                    results.append(False)
                    continue
                if budget >= row['oracle_length']:
                    is_correct = row.get('is_correct', False)
                elif self.use_fast_mode:
                    is_correct = False
                else:
                    is_correct = self._verify_truncated(row, budget)
                results.append(is_correct)
        return np.array(results)

    def _verify_truncated(self, row, budget):
        """Return True when the response cut to ``budget`` tokens still verifies.

        ``row['full_token_ids']`` (a list or its JSON string) is truncated,
        decoded without special tokens and compared to ``row['gold_answer']``
        with ``self.verify_func``. Any error is printed and counts as False.
        """
        try:
            token_ids = row['full_token_ids']
            if isinstance(token_ids, str):
                token_ids = json.loads(token_ids)
            # int(): slicing rejects float budgets, which would otherwise be
            # swallowed below and silently scored as False.
            truncated_ids = token_ids[:int(budget)]
            truncated_text = self.tokenizer.decode(truncated_ids, skip_special_tokens=True)
            grade, _ = self.verify_func([row['gold_answer']], [truncated_text])
            return grade == 1
        except Exception as e:
            print(f"Verification error: {e}")
            return False