#!/usr/bin/env python3
"""
Phase 2: DABA Difficulty Predictor Training 

"""

import os
import math
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from scipy.stats import pearsonr
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)

# ================= Configuration =================
# Path handed to `from_pretrained` for both the tokenizer and the model.
MODEL_PATH = "./models/deberta-v3-base" 

# Alternative data/output pair, left commented out for switching runs by hand:
# DATA_PATH = "./data/cache/train_labeled_30B_16k.parquet"
# OUTPUT_DIR = "./models/predictor-30b"

# Active run: Parquet file read by `prepare_data`, and the directory that
# receives Trainer checkpoints and the validation scatter plot.
# NOTE(review): the data file name says "30B" while the output directory says
# "7b" — confirm this pairing is intended.
DATA_PATH = "./data/cache/train_labeled_30B_4096.parquet"
OUTPUT_DIR = "./models/predictor-7b"

# Destination of the final `trainer.save_model` call.
BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best_checkpoint")

MAX_LEN = 512         # Fixed token length of every dataset item (incl. CLS/SEP)
BATCH_SIZE = 8        # Per-device train and eval batch size; 8-16 suggested for the base model
LEARNING_RATE = 1e-5  # Lower learning rate for stability
EPOCHS = 10           # Upper bound; early stopping may end training sooner
SEED = 42             # Used for the train/validation split and the Trainer
# ================= 1. Dataset Class (Key Fix) =================
class TailTruncatedDataset(Dataset):
    """
    Regression dataset that keeps the *tail* of over-long texts.

    Each item is tokenized on access, trimmed to its last ``max_len - 2``
    tokens when it is too long, wrapped in the tokenizer's CLS/SEP ids and
    right-padded with the pad id to exactly ``max_len`` positions.

    Args:
        texts: Indexable sequence of input strings.
        labels: Indexable sequence of numeric targets aligned with ``texts``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and the ``cls_token_id``, ``sep_token_id``, ``pad_token_id``
            attributes.
        max_len: Output length of every item, including the two special
            tokens.

    Raises:
        ValueError: If ``max_len`` is smaller than 2, i.e. too small to hold
            the CLS and SEP tokens.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        if max_len < 2:
            raise ValueError(
                f"max_len must be at least 2 to fit CLS and SEP, got {max_len}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th sample as a dict of tensors.

        Returns:
            ``input_ids`` and ``attention_mask`` as long tensors of length
            ``max_len``, and ``labels`` as a float scalar tensor.
        """
        # 1. Tokenize without special tokens so truncation only sees content.
        token_ids = self.tokenizer.encode(self.texts[idx], add_special_tokens=False)

        # 2. Tail truncation: drop tokens from the front until the text fits.
        #    An explicit start index is used because `ids[-capacity:]` returns
        #    the whole list when capacity is 0.
        capacity = self.max_len - 2
        overflow = len(token_ids) - capacity
        if overflow > 0:
            token_ids = token_ids[overflow:]

        # 3. Add special tokens.
        token_ids = [
            self.tokenizer.cls_token_id,
            *token_ids,
            self.tokenizer.sep_token_id,
        ]

        # 4. Right-pad to max_len; the mask marks real tokens with 1.
        n_real = len(token_ids)
        n_pad = self.max_len - n_real
        input_ids = token_ids + [self.tokenizer.pad_token_id] * n_pad
        attention_mask = [1] * n_real + [0] * n_pad

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.float),
        }

# ================= 2. Utils =================
def compute_metrics(eval_pred):
    """
    Compute validation metrics for the log-length regressor.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of array-likes holding
            log-space values. Any shape is accepted; both are flattened.

    Returns:
        Dict with ``pearson`` (Pearson correlation in log space; 0.0 when it
        is undefined, i.e. fewer than two samples or a constant input) and
        ``mae_tokens`` (mean absolute error after mapping both sides back
        with ``exp``). Both values are plain Python floats.
    """
    preds, labels = eval_pred
    # Flatten explicitly: `.squeeze()` turns a single-sample batch into a 0-d
    # array, and (N, 1) labels would broadcast against (N,) predictions.
    # float64 keeps `exp` from overflowing for moderately large log values.
    preds = np.asarray(preds, dtype=np.float64).reshape(-1)
    labels = np.asarray(labels, dtype=np.float64).reshape(-1)

    # Pearson Correlation. It is undefined for fewer than two points or when
    # either side is constant; report 0.0 so best-model selection never has
    # to compare against NaN.
    pearson = 0.0
    if (
        preds.size >= 2
        and preds.max() != preds.min()
        and labels.max() != labels.min()
    ):
        value = float(pearsonr(preds, labels)[0])
        if np.isfinite(value):
            pearson = value

    # MAE (Revert to Token count)
    mae_tokens = float(np.mean(np.abs(np.exp(preds) - np.exp(labels))))

    return {"pearson": pearson, "mae_tokens": mae_tokens}

def prepare_data(df_path, cutoff_threshold=16000, min_length=20, min_samples=50):
    """
    Load labeled samples and build the log-length regression targets.

    Steps: keep only rows with ``is_correct == True`` (when that column
    exists), pick a length column as ``target``, drop rows at or above
    ``cutoff_threshold`` and below ``min_length``, then add
    ``label = log(target)``.

    Args:
        df_path: Path to a Parquet file, or an already loaded DataFrame
            (which is copied, never modified).
        cutoff_threshold: Rows whose target is >= this value are dropped as
            having hit the generation ceiling.
            NOTE(review): the default of 16000 matches the "16k" data file
            name; for the "4096" file a value near 4096 is presumably
            intended — confirm.
        min_length: Rows whose target is below this value are dropped. Must
            be positive because the label is ``log(target)``.
        min_samples: Minimum number of rows that must survive filtering.

    Returns:
        A new DataFrame with the surviving rows and the added ``target`` and
        ``label`` columns; the original index is preserved.

    Raises:
        KeyError: If none of ``oracle_length``, ``gen_length`` or
            ``response`` is present.
        ValueError: If ``min_length`` is not positive, or fewer than
            ``min_samples`` rows remain after filtering.
    """
    if min_length <= 0:
        raise ValueError(f"min_length must be positive, got {min_length}")

    if isinstance(df_path, pd.DataFrame):
        df = df_path.copy()
    else:
        print(f"Loading data from {df_path}...")
        df = pd.read_parquet(df_path)
    print(f"Original Count: {len(df)}")

    # 1. Keep only correct questions (Correct Only)
    if 'is_correct' in df.columns:
        df = df[df['is_correct'].eq(True)]
        print(f"Correct Samples: {len(df)}")

    # 2. Get length: prefer recorded lengths, else estimate from word count.
    if 'oracle_length' in df.columns:
        target = df['oracle_length']
    elif 'gen_length' in df.columns:
        target = df['gen_length']
    elif 'response' in df.columns:
        target = df['response'].apply(lambda x: len(x.split()) * 1.3)
    else:
        raise KeyError(
            "No length source found: expected one of 'oracle_length', "
            "'gen_length' or 'response' columns."
        )
    # `assign` returns a new frame, so the filtered view above is never
    # written to in place (no chained-assignment warning).
    df = df.assign(target=target)

    # 3. Drop samples that reached the generation ceiling, then the ones that
    #    are too short. Rows with a missing target fail both comparisons.
    at_ceiling = df['target'] >= cutoff_threshold
    print(f"Dropped {int(at_ceiling.sum())} samples near ceiling ({cutoff_threshold}+ tokens).")
    df = df[(df['target'] < cutoff_threshold) & (df['target'] >= min_length)]

    print(f"Final Valid Training Size: {len(df)}")

    if len(df) < min_samples:
        raise ValueError("Too few samples after filtering! Check cutoff_threshold or data source.")

    # 4. Log transform
    df = df.assign(label=np.log(df['target']))

    # Print stats to help judge
    print(f"Label Stats (Tokens): Mean={df['target'].mean():.1f}, Min={df['target'].min()}, Max={df['target'].max()}")

    return df

# ================= 3. Main =================
def main():
    """
    Run the full pipeline: load data, fine-tune, save, and plot.

    Reads the module-level configuration constants. The data are split 90/10
    into train/validation, the model is trained with per-epoch evaluation and
    early stopping on the ``pearson`` metric, the best model is saved to
    ``BEST_MODEL_DIR`` and a log-log scatter plot of true vs. predicted
    lengths is written to ``OUTPUT_DIR``.
    """
    print(">>> Initializing DABA Predictor Trainer (Adapted for 4096)...")

    # 1. Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    # 2. Prepare data
    df = prepare_data(DATA_PATH)
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=SEED)

    train_ds = TailTruncatedDataset(
        train_df['question'].tolist(),
        train_df['label'].tolist(),
        tokenizer, MAX_LEN
    )
    val_ds = TailTruncatedDataset(
        val_df['question'].tolist(),
        val_df['label'].tolist(),
        tokenizer, MAX_LEN
    )

    # 3. Load model (single output + regression loss)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_PATH, num_labels=1, problem_type="regression"
    )

    # 4. Config Trainer
    # NOTE(review): fp16=True presumably requires an accelerator that supports
    # mixed precision — confirm for the target machine.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="pearson",
        greater_is_better=True,
        logging_steps=50,
        save_total_limit=1,
        fp16=True,
        report_to="none",
        seed=SEED
    )

    # NOTE(review): newer transformers releases appear to prefer
    # `processing_class=` over `tokenizer=` — verify against the pinned version.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    # 5. Train
    print("\n>>> Start Training...")
    trainer.train()

    # 6. Save final
    print(f"\n>>> Saving Best Model to {BEST_MODEL_DIR}...")
    trainer.save_model(BEST_MODEL_DIR)

    # 7. Plot (Log Scale)
    print(">>> Generating Log-Scale Validation Plot...")
    # reshape(-1) instead of squeeze(): a one-sample validation set must stay 1-D.
    preds = np.asarray(trainer.predict(val_ds).predictions, dtype=np.float64).reshape(-1)
    labels = val_df['label'].to_numpy(dtype=np.float64)

    # Revert to real values for plotting (Exp)
    y_true = np.exp(labels)
    y_pred = np.exp(preds)

    # Keep the original 10..5000 window as a minimum, but widen the upper
    # bound when the data go beyond it so no point is silently clipped.
    axis_min = 10
    axis_max = 5000.0
    plotted = np.concatenate([y_true, y_pred])
    plotted = plotted[np.isfinite(plotted)]
    if plotted.size:
        axis_max = max(axis_max, float(plotted.max()) * 1.1)

    fig = plt.figure(figsize=(8, 8))

    plt.scatter(y_true, y_pred, alpha=0.5, s=15, c='blue', edgecolors='none')
    plt.plot([axis_min, axis_max], [axis_min, axis_max], 'r--', linewidth=2, label='Ideal')

    # Set Log axis
    plt.xscale('log')
    plt.yscale('log')

    # Set display range
    plt.xlim(axis_min, axis_max)
    plt.ylim(axis_min, axis_max)

    plt.xlabel('True Length (Log Scale)')
    plt.ylabel('Pred Length (Log Scale)')
    plt.title(f'Difficulty Prediction (Pearson: {pearsonr(preds, labels)[0]:.4f})')
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.legend()

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    save_plot_path = os.path.join(OUTPUT_DIR, "simple_scatter_log.png")
    plt.savefig(save_plot_path, dpi=150)
    plt.close(fig)  # release the figure once it is on disk
    print(f">>> Plot saved to {save_plot_path}")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()