"""
Deep Model Assisted Statistical Inference Experiment (ICML Rebuttal Version)
Supports: AG News, SST-2 text classification tasks
Modifications: Introduced Validation Splitting and Safe Selection mechanisms
"""

import os
import sys
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from datasets import load_dataset
from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import copy # For saving best model parameters

# Import defined modules
from src.utils import set_global_seed, CSVLogger, accuracy
from src.models import LinearHead, ConcatHead, WeightedEnsemble, ResidualModel

# ================= Configuration Section =================
# Global experiment settings, read by the functions in this file via CONFIG[...] lookups.
# NOTE(review): "num_classes" and "class_names" are not derived from "dataset_name";
# they presumably have to be edited together when switching tasks — confirm.
CONFIG = {
    # Task configuration
    "dataset_name": "ag_news",  # or "sst2"
    "num_classes": 4,           # ag_news=4, sst2=2
    # Names inserted into the classification prompt and matched (case-insensitively)
    # against the generated text; list position is the class index.
    "class_names": ["World", "Sports", "Business", "Sci/Tech"], 
    
    # Experiment variables (Few-shot Settings)
    # Each value n is one run; its first n cached training samples are split into train/val.
    "labeled_sizes": [16, 32, 64, 128, 256], 
    "seed": 42,
    
    # Model configuration
    "model_id": "Qwen/Qwen3-8B",
    "cache_dir": "./modelscope_cache", 
    "max_length": 512,  # tokenizer truncation length for features and prompts
    
    # Training configuration (applies to the lightweight heads, not the frozen LLM)
    "batch_size": 16,
    "lr": 1e-3,
    "epochs": 50,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}
# =========================================

def load_qwen_model():
    """
    Fetch (or locate) the configured checkpoint and return it frozen in eval mode.

    The checkpoint is resolved through ModelScope's snapshot_download; if that
    call raises, the path <cache_dir>/<model_id> is tried directly instead.

    Returns:
        (tokenizer, model): the tokenizer is configured for left padding and is
        given the EOS token as pad token when it has none; the model is loaded
        in bfloat16 with device_map="auto" and all parameters have
        requires_grad disabled.
    """
    print(f"[Init] Downloading/Loading Qwen3-8B from {CONFIG['cache_dir']}...")
    try:
        model_dir = snapshot_download(CONFIG['model_id'], cache_dir=CONFIG['cache_dir'])
    except Exception as e:
        # Best-effort fallback: assume the weights already sit in the cache directory.
        print(f"[Warning] Snapshot download failed: {e}. Trying local path directly.")
        model_dir = os.path.join(CONFIG['cache_dir'], CONFIG['model_id'])
    print(f"[Init] Model Path: {model_dir}")

    # Tokenizer: make sure a pad token exists, then pad on the left.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Model: loaded once, never trained.
    load_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    return tokenizer, model

def get_dataset(name, split, n_samples=None, seed=42):
    """
    Load one split of a dataset, optionally down-sampled to n_samples rows.

    When n_samples is falsy (None or 0) or not smaller than the split size, the
    split is returned as loaded, without shuffling. Otherwise the split is
    shuffled with the given seed and its first n_samples rows are kept.
    """
    ds = load_dataset(name, split=split)
    if not n_samples or n_samples >= len(ds):
        return ds
    shuffled = ds.shuffle(seed=seed)
    return shuffled.select(range(n_samples))

def generate_bb_logits(model, tokenizer, texts, class_names, device):
    """
    Use the LLM as a black-box classifier and convert its answers to pseudo-logits.

    Each text is wrapped in a classification prompt, up to 10 new tokens are
    generated greedily, and the generated continuation is searched
    (case-insensitively) for the class names.

    Args:
        model: Causal LM exposing generate().
        tokenizer: Tokenizer matching the model.
        texts: Sequence of raw input strings.
        class_names: Class names in label-index order.
        device: Device the tokenized prompts are moved to.

    Returns:
        Float tensor of shape (len(texts), len(class_names)) on the CPU. The row
        for a text holds CONFIDENCE_SCORE at the index of the class that is
        mentioned first in the generated text and 0 elsewhere; rows stay all-zero
        when no class name is found.
    """
    CONFIDENCE_SCORE = 5.0

    # Nothing to classify: return an empty (0, C) tensor instead of calling the model.
    if len(texts) == 0:
        return torch.zeros(0, len(class_names))

    class_str = ", ".join(class_names)
    prompts = [
        (
            f"Classify the article into one of: {class_str}.\n"
            f"Article: {t}\n"
            f"Category:"
        )
        for t in texts
    ]

    # NOTE(review): truncation=True can cut the trailing "Category:" cue for
    # articles longer than max_length tokens if the tokenizer truncates on the
    # right (presumably its default) — confirm for long-document datasets.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=CONFIG['max_length']).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    # Drop the prompt tokens so only the newly generated continuation is decoded.
    prompt_len = inputs.input_ids.shape[1]
    decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

    lowered_names = [c.lower() for c in class_names]
    batch_logits = torch.zeros(len(decoded), len(class_names))

    for row, text in enumerate(decoded):
        text = text.lower().strip()
        # Choose the class mentioned EARLIEST in the generated text. Taking the
        # first hit in class-list order would mislabel continuations such as
        # "Business ... World" as "World" just because "World" is listed first.
        best_pos = None
        matched = -1
        for i, c in enumerate(lowered_names):
            pos = text.find(c)
            if pos != -1 and (best_pos is None or pos < best_pos):
                best_pos = pos
                matched = i
        if matched != -1:
            batch_logits[row, matched] = CONFIDENCE_SCORE

    return batch_logits

def extract_features_and_cache(model, tokenizer, dataset, desc="Caching"):
    """
    Run the frozen LLM once over a dataset and cache everything the heads need.

    For every batch of 8 samples this stores
      * the last-layer hidden state of the last non-padding token of the raw text,
      * the black-box pseudo-logits from generate_bb_logits,
      * the labels.

    Args:
        model: Frozen causal LM.
        tokenizer: Tokenizer matching the model.
        dataset: Dataset whose batches expose 'text' and 'label' entries.
        desc: Name used in the progress message only.

    Returns:
        TensorDataset of (hidden_state, bb_logits, label) CPU tensors, in dataset order.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    cache = {"hs": [], "bb": [], "y": []}
    batch_size = 8
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    print(f"[Process] Starting feature extraction for {desc}...")
    with torch.no_grad():
        for batch in tqdm(dataloader):
            texts = batch['text']
            labels = batch['label']

            inputs_feat = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=CONFIG['max_length']).to(CONFIG['device'])
            out_feat = model(**inputs_feat, output_hidden_states=True)

            hs = out_feat.hidden_states[-1]

            # Locate the last attended token from the mask itself. The tokenizer is
            # configured with padding_side="left", so `attention_mask.sum(1) - 1`
            # (valid only for right padding) would select a pad/early position for
            # every sequence shorter than the batch maximum. Multiplying the mask by
            # the position index and taking argmax yields the right-most attended
            # position for either padding side.
            mask = inputs_feat.attention_mask.to(hs.device)
            positions = torch.arange(mask.size(1), device=hs.device)
            last_idx = (mask * positions).argmax(dim=1)
            rows = torch.arange(hs.size(0), device=hs.device)
            hs_pooled = hs[rows, last_idx]

            bb_logits = generate_bb_logits(model, tokenizer, texts, CONFIG['class_names'], CONFIG['device'])

            # Cast to float32 on the CPU so the cache does not hold accelerator memory.
            cache["hs"].append(hs_pooled.cpu().float())
            cache["bb"].append(bb_logits.cpu().float())
            cache["y"].append(labels.cpu())

    if not cache["y"]:
        raise ValueError(f"Cannot cache features for {desc}: dataset is empty.")

    full_ds = TensorDataset(
        torch.cat(cache["hs"]),
        torch.cat(cache["bb"]),
        torch.cat(cache["y"])
    )
    return full_ds

# ==============================================================================
# Modification 1: The training function takes separate Train and Val datasets and returns the trained models, each model's best validation accuracy, and the black-box validation accuracy
# ==============================================================================
def train_and_validate(train_ds, val_ds, input_dim, num_classes):
    """
    Fit the four heads side by side and keep, per head, the epoch with the best
    validation accuracy.

    Returns:
        models: Dictionary of heads, each restored to its best-validation weights
        best_val_accs: Best validation accuracy reached by each head (for Safe Selection)
        bb_val_acc: Accuracy of the raw black-box logits on the validation set
    """
    device = CONFIG['device']
    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True)

    # The validation split is small, so pull it as one batch and keep it on the device.
    val_loader = DataLoader(val_ds, batch_size=len(val_ds), shuffle=False)
    val_hs, val_bb, val_y = (t.to(device) for t in next(iter(val_loader)))
    val_targets = val_y.cpu().numpy()

    # Baseline: how well do the black-box logits alone do on the validation split?
    bb_val_acc = accuracy(val_bb.argmax(1).cpu().numpy(), val_targets)

    # Heads are built in a fixed order so that random initialisation is reproducible.
    head_classes = (
        ("scratch", LinearHead),
        ("weighted", WeightedEnsemble),
        ("concat", ConcatHead),
        ("residual", ResidualModel),
    )
    models = {name: cls(input_dim, num_classes).to(device) for name, cls in head_classes}

    optimizers = {
        name: AdamW(head.parameters(), lr=CONFIG['lr'], weight_decay=1e-2)
        for name, head in models.items()
    }
    loss_fn = nn.CrossEntropyLoss()

    # Per-head checkpoint bookkeeping; -1.0 guarantees the first epoch is stored.
    best_states = {name: copy.deepcopy(head.state_dict()) for name, head in models.items()}
    best_val_accs = dict.fromkeys(models, -1.0)

    def forward(name, head, hs, bb):
        # Only the "scratch" head ignores the black-box logits.
        if name == "scratch":
            return head(hs)
        return head(hs, bb)

    for _ in range(CONFIG['epochs']):
        # --- one pass over the training split, updating every head on each batch ---
        for head in models.values():
            head.train()
        for batch in train_loader:
            hs, bb, y = (t.to(device) for t in batch)
            for name, head in models.items():
                optimizer = optimizers[name]
                optimizer.zero_grad()
                batch_loss = loss_fn(forward(name, head, hs, bb), y)
                batch_loss.backward()
                optimizer.step()

        # --- validation: snapshot a head whenever it strictly improves ---
        with torch.no_grad():
            for name, head in models.items():
                head.eval()
                preds = forward(name, head, val_hs, val_bb).argmax(1)
                acc = accuracy(preds.cpu().numpy(), val_targets)
                if acc > best_val_accs[name]:
                    best_val_accs[name] = acc
                    best_states[name] = copy.deepcopy(head.state_dict())

    # Roll every head back to its best checkpoint before returning.
    for name, head in models.items():
        head.load_state_dict(best_states[name])

    return models, best_val_accs, bb_val_acc

# ==============================================================================
# Modification 2: Evaluation function adds Safe Selection logic
# ==============================================================================
def evaluate_safe(models, test_ds, val_accs, bb_val_acc, val_ds=None):
    """
    Evaluate all models on the test set and perform Safe Selection for Residual.
    Also reports a Val-Tuned Weighted Baseline (alpha * BB + (1 - alpha) * Scratch).

    Args:
        models: Dictionary with the heads "scratch", "weighted", "concat", "residual".
        test_ds: Dataset yielding (hidden_state, bb_logits, label) triples.
        val_accs: Per-model validation accuracy; only val_accs["residual"] is read.
        bb_val_acc: Black-box accuracy on the validation set.
        val_ds: Optional validation dataset used to grid-search alpha. When it is
            None or empty, alpha stays 0.0 and "weighted_val_tuned" equals the
            scratch head.

    Returns:
        Dictionary with the test accuracy of "bb_only", "scratch", "weighted",
        "concat", "residual_raw", "residual_safe" and "weighted_val_tuned", plus
        "fallback_rate" (1.0 if Safe Selection fell back to the black box, else 0.0).

    Raises:
        ValueError: If test_ds yields no samples.
    """
    device = CONFIG['device']
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    metric_names = ["bb_only", "scratch", "weighted", "concat", "residual_raw", "residual_safe", "weighted_val_tuned"]
    # Accumulate accuracy * batch_size so that a short final batch is not
    # over-weighted when averaging.
    # Assumes accuracy() returns the mean over the batch it is given — TODO confirm.
    weighted_sums = {k: 0.0 for k in metric_names}
    total_samples = 0

    # Safe Selection decision (dataset level): if Residual performs worse than
    # BB on the validation set, report BB instead. Ties keep Residual.
    should_fallback = val_accs["residual"] < bb_val_acc

    for m in models.values():
        m.eval()

    # === Search for best Alpha on Val set (Fair Weighted Baseline) ===
    best_alpha = 0.0
    if val_ds is not None and len(val_ds) > 0:
        # Load entire validation set as a single batch
        val_loader_full = DataLoader(val_ds, batch_size=len(val_ds), shuffle=False)
        val_hs, val_bb, val_y = [t.to(device) for t in next(iter(val_loader_full))]
        val_y_np = val_y.cpu().numpy()

        best_val_score = -1.0
        with torch.no_grad():
            scratch_logits = models["scratch"](val_hs)
            # Grid search; on ties the smallest alpha (first in the list) wins.
            for alpha in [0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]:
                # Simple linear weighting of logits: alpha * BB + (1-alpha) * Scratch
                mixed_logits = alpha * val_bb + (1 - alpha) * scratch_logits
                acc = accuracy(mixed_logits.argmax(1).cpu().numpy(), val_y_np)
                if acc > best_val_score:
                    best_val_score = acc
                    best_alpha = alpha
        print(f"    [Weighted] Best Val-Tuned Alpha: {best_alpha}")

    def batch_accuracy(logits, targets_np):
        # Accuracy of argmax predictions against numpy targets for one batch.
        return accuracy(logits.argmax(1).cpu().numpy(), targets_np)

    with torch.no_grad():
        for hs, bb, y in test_loader:
            hs, bb, y = hs.to(device), bb.to(device), y.to(device)
            n_batch = y.size(0)
            y_np = y.cpu().numpy()
            batch_scores = {}

            # 1. BB Only
            batch_scores["bb_only"] = batch_accuracy(bb, y_np)

            # 2. Baselines (scratch logits are reused for the val-tuned mix below)
            scratch_logits_test = models["scratch"](hs)
            batch_scores["scratch"] = batch_accuracy(scratch_logits_test, y_np)
            for name in ["weighted", "concat"]:
                batch_scores[name] = batch_accuracy(models[name](hs, bb), y_np)

            # 3. Residual Raw (Always use Residual prediction)
            batch_scores["residual_raw"] = batch_accuracy(models["residual"](hs, bb), y_np)

            # 4. Residual Safe (Based on validation set decision)
            batch_scores["residual_safe"] = (
                batch_scores["bb_only"] if should_fallback else batch_scores["residual_raw"]
            )

            # 5. Weighted Val-Tuned: mix BB and Scratch logits with the tuned alpha.
            # This score was previously computed but never recorded, which left its
            # result list empty and made the final averaging divide by zero.
            w_logits = best_alpha * bb + (1 - best_alpha) * scratch_logits_test
            batch_scores["weighted_val_tuned"] = batch_accuracy(w_logits, y_np)

            for k in metric_names:
                weighted_sums[k] += batch_scores[k] * n_batch
            total_samples += n_batch

    if total_samples == 0:
        raise ValueError("evaluate_safe received an empty test dataset.")

    # Sample-weighted average over all test batches
    final_metrics = {k: weighted_sums[k] / total_samples for k in metric_names}

    # Record fallback rate (0.0 or 1.0, as this is Dataset level decision)
    final_metrics["fallback_rate"] = 1.0 if should_fallback else 0.0

    return final_metrics

def main():
    """
    Run the full few-shot experiment.

    Steps: seed everything, load the frozen LLM, cache features for the first
    500 test samples and for a training pool of max(labeled_sizes) samples,
    then for every labeled size n split the first n pooled samples into
    train/validation, train the heads, evaluate them on the cached test set
    and append the scores to the CSV log.
    """
    set_global_seed(CONFIG['seed'])
    dataset_name = CONFIG['dataset_name']

    # 1. Prepare model
    tokenizer, model = load_qwen_model()

    # 2. Prepare test set cache
    print("\n=== Phase 1: Caching Test Set ===")
    # For demonstration speed only the first 500 test samples are used;
    # drop the select() call for a full run.
    raw_test_ds = get_dataset(dataset_name, "test").select(range(500))
    test_ds_cached = extract_features_and_cache(model, tokenizer, raw_test_ds, desc="Test Set")

    # 3. Result logger
    logger = CSVLogger("outputs", f"qwen_{CONFIG['dataset_name']}_safe_results.csv")

    # 4. One shared training pool, large enough for the biggest labeled size
    print("\n=== Phase 2: Caching Training Pool ===")
    pool_size = max(CONFIG['labeled_sizes'])
    raw_train_pool = get_dataset(dataset_name, "train", n_samples=pool_size, seed=CONFIG['seed'])
    train_pool_cached = extract_features_and_cache(model, tokenizer, raw_train_pool, desc="Train Pool")

    first_features = train_pool_cached[0][0]
    input_dim = first_features.shape[0]
    print(f"[Info] Feature Dimension: {input_dim}")

    print("\n=== Phase 3: Running Experiments (With Safe Selection) ===")
    for n in CONFIG['labeled_sizes']:
        print(f"\n>>> n = {n}")

        # The first n pooled samples form this run's labeled budget.
        dataset_n = torch.utils.data.Subset(train_pool_cached, list(range(n)))

        # Train/Val split: half for validation when n <= 32, otherwise 20%,
        # and never fewer than one validation sample.
        val_fraction = 0.5 if n <= 32 else 0.2
        val_size = max(int(n * val_fraction), 1)
        train_size = n - val_size

        split_rng = torch.Generator().manual_seed(CONFIG['seed'])
        train_ds, val_ds = random_split(dataset_n, [train_size, val_size], generator=split_rng)

        print(f"    Data Split: Train={len(train_ds)}, Val={len(val_ds)}")

        # Train the heads and collect validation metrics
        models, val_accs, bb_val_acc = train_and_validate(train_ds, val_ds, input_dim, CONFIG['num_classes'])

        print(f"    [Validation] BB: {bb_val_acc:.4f} | Residual: {val_accs['residual']:.4f}")

        # Evaluate on the cached test set (with Safe Selection)
        scores = evaluate_safe(models, test_ds_cached, val_accs, bb_val_acc, val_ds=val_ds)

        # Report
        print(f"--- Results (n={n}) ---")
        print(f"BB Only:       {scores['bb_only']:.4f}")
        print(f"Residual Raw:  {scores['residual_raw']:.4f}")
        print(f"Residual Safe: {scores['residual_safe']:.4f}  <-- MAIN RESULT")
        print(f"Fallback Rate: {scores['fallback_rate']:.1f}")

        # Persist after every size so partial results survive an interrupted run
        logger.log({"n": n, **scores})
        logger.save()

# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
