import torch
import os
import random
import numpy as np
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from datasets import load_dataset
from typing import List, Dict, Any, Optional
import pickle
import pathlib

def read_lrp_file(file_path="./data/lrp_masks.pkl"):
    """Load the pickled LRP sample collection stored at ``file_path``.

    This is a best-effort loader: problems are reported with ``print``
    instead of being raised.

    Args:
        file_path: Location of the pickle file (string or path-like).

    Returns:
        The unpickled object, or an empty list when the file does not exist
        or its content cannot be unpickled / measured with ``len``.
    """
    lrp_path = pathlib.Path(file_path)

    if not lrp_path.exists():
        print(f"File {lrp_path} does not exist")
        return []

    with open(lrp_path, 'rb') as handle:
        try:
            # The file is expected to hold the samples_data list directly.
            loaded = pickle.load(handle)
            print(f"Successfully read {len(loaded)} LRP and activation samples")
        except Exception as err:
            # Any failure while reading degrades to "no samples".
            print(f"Read error: {err}")
            return []
        else:
            return loaded

def read_mixed_lrp_file(seed=58, sample_config=None, shuffle=True, prefix="xxx/project/DISP/"):
    """Load several pickled LRP sample lists and return a seeded sub-sample mix.

    Args:
        seed: Seed for the private ``random.Random`` instances used for
            sampling and shuffling; the global RNG state is not touched.
        sample_config: Mapping ``name -> (relative_pickle_path, sample_size)``.
            When ``None`` the built-in default (WikiText only) is used.
        shuffle: Whether to shuffle the concatenated samples.
        prefix: Root directory that every relative path is joined onto.

    Returns:
        A list with ``sample_size`` items from each configured source, in
        configuration order unless ``shuffle`` is set.

    Raises:
        ValueError: If a source holds fewer items than requested.
        OSError: If a configured pickle file cannot be opened.
    """
    if sample_config is None:
        sample_config = {
            #"alpaca": ("alpaca/lrp_train_ppl.pkl", 625),
            # "c4": ("c4/lrp_train_ppl.pkl", 1000),
            "wikitext": ("wikitext/lrp_train_ppl.pkl", 1403),
            #"pajama": ("pajama/lrp_train_ppl.pkl", 1000),
            # "commonsense_qa": ("commonsense_qa/lrp_train_ppl.pkl", 500),
            # "arc-c": ("arc-c/lrp_train_ppl.pkl", 1000), # 1119
            # "arc-e": ("arc-e/lrp_train_ppl.pkl", 2251), # 2251
            # "piqa": ("piqa/lrp_train_ppl.pkl", 9000), # 16000
            # "winogrande": ("winogrande/lrp_train_ppl.pkl", 200), # ?
            # "hellaswag": ("hellaswag/lrp_train_ppl.pkl", 9000), # 10000/39905
        }

    all_samples = []

    for name, (rel_path, sample_size) in sample_config.items():
        full_path = os.path.join(prefix, rel_path)
        # NOTE: pickle executes arbitrary code on load; only point this at
        # files produced by a trusted pipeline.
        with open(full_path, 'rb') as f:
            data = pickle.load(f)
        print(f"{name}: {len(data)}")

        # Fail with a message naming the offending source instead of the
        # generic error random.sample would produce.
        if sample_size > len(data):
            raise ValueError(
                f"{name}: requested {sample_size} samples but {full_path} "
                f"only contains {len(data)}"
            )

        # A fresh Random(seed) per source keeps each source's sample
        # independent of which other sources are configured.
        sampled_data = random.Random(seed).sample(data, sample_size)
        all_samples.extend(sampled_data)
        print(f"{name} loaded and sampled {sample_size} items")

    if shuffle:
        random.Random(seed).shuffle(all_samples)

    print(f"Successfully read {len(all_samples)} LRP and activation samples")
    return all_samples



# ----------------------------- Alpaca -----------------------------
def build_alpaca_ids(tokenizer, split: str = "train", add_special_tokens: bool = True):
    """Tokenize the first 10,000 Alpaca ``text`` entries as one token stream.

    The entries are joined with blank lines and tokenized in a single call.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        split: Split of ``tatsu-lab/alpaca`` to load.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        The tokenizer's ``input_ids`` with dimension 0 squeezed out.
    """
    alpaca = load_dataset("tatsu-lab/alpaca", split=split)
    corpus = "\n\n".join(alpaca["text"][:10000])
    encoded = tokenizer(
        corpus,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )
    return encoded.input_ids.squeeze(0)

# ----------------------------- C4 -----------------------------
def build_c4_ids(tokenizer, split: str = "train", add_special_tokens: bool = True):
    """Tokenize a local C4 shard as one token stream.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        split: ``"train"`` reads the first 5,000 documents of the train
            shard; any other value reads the whole validation shard.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        The tokenizer's ``input_ids`` with dimension 0 squeezed out.
    """
    if split == "train":
        ds = load_dataset(
            "json",
            data_files="xxx/cache/huggingface/datasets/c4-train.00000-of-01024.json.gz",
            split="train",
        )
        texts = ds["text"][:5000]
    else:
        # A bare string in data_files is registered under the "train" split,
        # so the validation shard must be mapped to "validation" explicitly
        # for split="validation" to resolve.
        ds = load_dataset(
            "json",
            data_files={
                "validation": "xxx/cache/huggingface/datasets/c4-validation.00000-of-00008.json.gz"
            },
            split="validation",
        )
        texts = ds["text"]

    full_text = "\n\n".join(texts)
    return tokenizer(full_text,
                     return_tensors="pt",
                     add_special_tokens=add_special_tokens).input_ids.squeeze(0)


# ----------------------------- WikiText -----------------------------
def build_wikitext_ids(tokenizer, split: str = "test", add_special_tokens: bool = True):
    """Tokenize an entire WikiText-2 (raw) split as one token stream.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        split: Split of ``wikitext-2-raw-v1`` to load.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        The tokenizer's ``input_ids`` with the batch dimension squeezed out,
        which makes later slicing into fixed-length windows easier.
    """
    wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    corpus = "\n\n".join(wikitext["text"])
    encoded = tokenizer(
        corpus,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )
    # Drop the batch dimension so the result can be sliced as a flat stream.
    return encoded.input_ids.squeeze(0)

# ----------------------------- PTB -----------------------------
def build_ptb_ids(tokenizer, split: str = "train", add_special_tokens: bool = True):
    """Tokenize a Penn Treebank split as one token stream.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        split: Split of ``ptb_text_only`` / ``penn_treebank`` to load.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        The tokenizer's ``input_ids`` with dimension 0 squeezed out.
    """
    ptb = load_dataset("ptb_text_only", "penn_treebank", split=split)

    # Use the "sentence" column when the dataset has one, else "text".
    if "sentence" in ptb.column_names:
        column = "sentence"
    else:
        column = "text"

    corpus = "\n\n".join(ptb[column])
    encoded = tokenizer(
        corpus,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )
    return encoded.input_ids.squeeze(0)

# ----------------------------- Red Pajama -----------------------------
def build_redpajama_ids(tokenizer, split: str = "train", add_special_tokens: bool = True):
    """Tokenize the first 2,000 rows of a local SlimPajama parquet shard.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        split: Accepted for signature parity with the other builders; it is
            not used by this function.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        The tokenizer's ``input_ids`` with dimension 0 squeezed out.
    """
    # Source shard: xxx/cache/huggingface/datasets/slimpajama-001.parquet
    frame = pd.read_parquet("xxx/cache/huggingface/datasets/slimpajama-001.parquet")
    documents = frame["text"][:2000]
    corpus = "\n\n".join(documents)
    encoded = tokenizer(
        corpus,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )
    return encoded.input_ids.squeeze(0)

# ----------------------------- Mixed -----------------------------
# mix from wikitext, c4, alpaca
def build_mixed_ids(tokenizer, split: str = "train", add_special_tokens: bool = True):
    """Tokenize a mix of WikiText-2, C4 and Alpaca as one token stream.

    The first 200 rows of each corpus are concatenated in that order, with
    every document separated by a blank line.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        split: Split used for WikiText-2 and Alpaca. The C4 part always
            reads the local train shard.
        add_special_tokens: Forwarded to the tokenizer.

    Returns:
        The tokenizer's ``input_ids`` with dimension 0 squeezed out.
    """
    ds_wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    ds_c4 = load_dataset("json", data_files="xxx/cache/huggingface/datasets/c4-train.00000-of-01024.json.gz", split="train")
    ds_alpaca = load_dataset("tatsu-lab/alpaca", split=split)

    # Take 200 samples from each corpus.
    ds_wikitext = ds_wikitext.select(range(200))
    ds_c4 = ds_c4.select(range(200))
    ds_alpaca = ds_alpaca.select(range(200))

    # Join all documents with one separator. Joining each corpus on its own
    # and concatenating the results would glue the last document of one
    # corpus directly onto the first document of the next.
    documents = [*ds_wikitext["text"], *ds_c4["text"], *ds_alpaca["text"]]
    full_text = "\n\n".join(documents)
    return tokenizer(full_text,
                     return_tensors="pt",
                     add_special_tokens=add_special_tokens).input_ids.squeeze(0)
    

def sample_wikitext_sequences(
        input_ids: torch.Tensor,
        seqlen: int = 2048,
        n: int = 32,
        random_sample: bool = True,
        start_idx: int = 0):
    """Cut a flat token stream into windows of ``seqlen`` and pick ``n`` of them.

    ``input_ids`` only has to be tokenized once; each call here is cheap.

    Args:
        input_ids: Token stream, indexed along dimension 0.
        seqlen: Window length; trailing tokens that do not fill a window
            are dropped.
        n: Number of windows to return, or ``None`` for every window.
        random_sample: Draw window indices with ``torch.randint`` (with
            replacement) instead of taking consecutive windows.
        start_idx: First window index for the consecutive mode; indices
            wrap around the number of windows.

    Returns:
        Tensor of shape ``[n, seqlen]`` (all windows when ``n`` is ``None``).

    Raises:
        ValueError: If the stream is shorter than one window.
    """
    corpus_len = input_ids.size(0)
    window_count = corpus_len // seqlen
    if window_count == 0:
        raise ValueError(f"seqlen {seqlen} > corpus ({corpus_len} tokens)")

    # Non-overlapping windows: size == step == seqlen.
    windows = input_ids.unfold(0, seqlen, seqlen)
    if n is None:
        return windows

    if not random_sample:
        # Consecutive windows; the modulo wraps past the last window.
        picks = torch.arange(start_idx, start_idx + n) % window_count
        return windows[picks]

    return windows[torch.randint(window_count, (n,))]

def calculate_perplexity(model, input_ids, limit_length=2048, device="cuda", autocast_dtype=torch.float16):
    """Return the summed next-token negative log-likelihood of ``input_ids``.

    Despite the name, the returned value is the total NLL (mean
    cross-entropy times the number of predicted tokens), not
    ``exp(mean loss)``; turning it into a perplexity is left to the caller.

    Args:
        model: Callable returning an object with a ``logits`` attribute; it
            is switched to eval mode.
        input_ids: Token ids indexed as ``[batch, sequence]``.
        limit_length: Longer inputs are truncated to this many tokens, with
            a printed warning.
        device: Device the inputs are moved to.
        autocast_dtype: dtype handed to ``torch.cuda.amp.autocast``.

    Returns:
        Scalar float tensor holding the summed NLL.
    """
    model.eval()

    tokens = input_ids.to(device)

    seq_len = tokens.shape[1]
    if seq_len > limit_length:
        print(f"Warning: Input length {seq_len} is too long, truncating to {limit_length} tokens")
        tokens = tokens[:, :limit_length]

    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=autocast_dtype):
        logits = model(tokens).logits

        # Position t predicts token t + 1.
        predictions = logits[:, :-1, :].contiguous()
        targets = tokens[:, 1:]

        mean_loss = F.cross_entropy(
            predictions.reshape(-1, predictions.size(-1)),
            targets.reshape(-1),
        )
        # Undo the mean reduction to obtain the total NLL.
        nll = mean_loss.float() * targets.numel()

    return nll

def calculate_perplexity_with_label(model, input_ids, label_pos=None, device='cuda'):
    """Return the mean next-token cross-entropy of the first sequence.

    Despite the name, the result is a mean loss, not ``exp(loss)``.

    Args:
        model: Callable returning logits directly or an object with a
            ``logits`` attribute; it is switched to eval mode.
        input_ids: Token ids indexed as ``[batch, sequence]``. Only batch
            element 0 contributes to the result. The tensor is used as
            given and is not moved to another device.
        label_pos: Optional int (or tensor convertible with ``int``). When
            given, only losses from shifted index ``label_pos - 1`` onwards
            are averaged. If that index lies beyond the sequence, all
            positions are averaged instead.
        device: Not used by this function; kept for interface
            compatibility.

    Returns:
        Scalar tensor holding the mean loss.
    """
    model.eval()

    # Only probe bf16 support when CUDA is present: the probe inspects the
    # current CUDA device and can fail on CUDA-less hosts in some PyTorch
    # versions.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    data_type = torch.bfloat16 if use_bf16 else torch.float16

    # Autocast keeps the forward pass in one consistent reduced precision.
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
        outputs = model(input_ids)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        # Position t predicts token t + 1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)),
                          shift_labels.reshape(-1))
        losses = losses.view(shift_labels.size())

        first_sequence = losses[0]
        if label_pos is not None:
            # int() handles both Python ints and 0-dim tensors.
            start_pos = max(0, int(label_pos) - 1)
            if start_pos < first_sequence.size(0):
                return first_sequence[start_pos:].mean()

        # No label position, or one past the end: average everything.
        return first_sequence.mean()

def load_mc_dataset(dataset_name, split="test"):
    """Load a multiple-choice benchmark by short name.

    Known names are translated to a hub path, an optional config name and
    the split that is actually loaded. For several benchmarks a request for
    ``"test"`` loads ``"validation"`` and every other request loads
    ``"train"``. Unknown names are passed to ``load_dataset`` unchanged.

    Args:
        dataset_name: Benchmark name, matched case-insensitively.
        split: Requested split.

    Returns:
        The dataset returned by ``load_dataset``.
    """
    print(f"Loading {dataset_name} dataset ({split} split)...")

    key = dataset_name.lower()

    # Benchmarks in this group are evaluated on their validation split.
    held_out = "validation" if split == "test" else "train"

    # name -> (hub path, config name or None, split to load)
    registry = {
        "winogrande": ("winogrande", "winogrande_debiased", held_out),
        "arc-e": ("ai2_arc", "ARC-Easy", split),
        "arc-c": ("ai2_arc", "ARC-Challenge", split),
        "hellaswag": ("Rowan/hellaswag", None, held_out),
        "rowan/hellaswag": ("Rowan/hellaswag", None, held_out),
        "piqa": ("ybisk/piqa", None, held_out),
        "ybisk/piqa": ("ybisk/piqa", None, held_out),
        "commonsense_qa": ("commonsense_qa", None, split),
        "boolq": ("boolq", None, held_out),
        "obqa": ("openbookqa", "main", split),
        "openbookqa": ("openbookqa", "main", split),
    }

    entry = registry.get(key)
    if entry is None:
        dataset = load_dataset(dataset_name, split=split)
    else:
        hub_path, subset, resolved_split = entry
        if subset:
            # Remote code is only trusted for WinoGrande.
            dataset = load_dataset(hub_path, subset, split=resolved_split,
                                   trust_remote_code=(key == "winogrande"))
        else:
            dataset = load_dataset(hub_path, split=resolved_split)

    print(f"Loaded {dataset_name}: {len(dataset)} examples")
    return dataset

def format_mc_example(example, dataset_name):
    """Convert one raw benchmark row into the evaluation record format.

    The dataset is chosen by substring match on the lower-cased
    ``dataset_name``, tested in this order: winogrande, arc, hellaswag,
    piqa, commonsense_qa, boolq, obqa / openbookqa.

    Args:
        example: Raw row (mapping) from the dataset.
        dataset_name: Name used to select the formatting rule.

    Returns:
        For WinoGrande, a dict with ``context_prefix``, ``target_suffix``,
        ``options`` (each with a leading space) and a 0-based ``label``.
        For all other datasets, a dict with ``contexts`` (one full text per
        option), ``options``, ``label``, ``original_text`` and ``question``.

    Raises:
        ValueError: If the dataset is not supported, or (from the underlying
            lookups) if a WinoGrande sentence has no ``_`` or an answer key
            is not among the choice labels.
    """
    name = dataset_name.lower()

    def as_record(stem, options, label, joiner="\nAnswer: "):
        # Shared layout: one scored text per option plus the bare stem.
        return {
            "contexts": [f"{stem}{joiner}{option}" for option in options],
            "options": options,
            "label": label,
            "original_text": stem,
            "question": stem,
        }

    def from_choices(stem_key):
        # Rows carrying a {"text": [...], "label": [...]} choices mapping
        # together with an answerKey.
        stem = example[stem_key]
        choices = example["choices"]
        option_texts = choices["text"]
        answer_index = choices["label"].index(example["answerKey"])
        return as_record(stem, option_texts, answer_index)

    if "winogrande" in name:
        sentence = example["sentence"]
        blank_at = sentence.index("_")
        return {
            # Text before the blank, kept verbatim.
            "context_prefix": sentence[:blank_at],
            # Text after the blank, normalised to one leading space.
            "target_suffix": " " + sentence[blank_at + 1:].strip(),
            "options": [" " + example["option1"],
                        " " + example["option2"]],
            "label": int(example["answer"]) - 1,
        }

    if "arc" in name:
        return from_choices("question")

    if "hellaswag" in name:
        # The row may expose its context as "context", "ctx", or only an
        # "activity_label".
        try:
            stem = example["context"]
        except KeyError:
            stem = example.get("ctx", example.get("activity_label", ""))
        endings = example["endings"]
        # Endings continue the context directly, without an "Answer:" cue.
        return as_record(stem, endings, int(example["label"]), joiner=" ")

    if "piqa" in name:
        goal = example["goal"]
        solutions = [example["sol1"], example["sol2"]]
        return as_record(goal, solutions, int(example["label"]))

    if "commonsense_qa" in name:
        return from_choices("question")

    if "boolq" in name:
        passage = example["passage"]
        question = example["question"]
        answer = example["answer"]  # True or False
        stem = f"{passage}\n\nQuestion: {question}"
        # Option order makes the label equal to the truth value.
        return as_record(stem, ["False", "True"], 1 if answer else 0)

    if "obqa" in name or "openbookqa" in name:
        return from_choices("question_stem")

    raise ValueError(f"Unsupported dataset: {dataset_name}")

def get_mc_context_for_features(formatted_example, mc_dataset_name):
    """Extract the context text of a formatted multiple-choice example.

    Args:
        formatted_example: Record in the layout produced by
            ``format_mc_example``.
        mc_dataset_name: Dataset name. WinoGrande is detected by substring,
            case-insensitively, like everywhere else in this module.

    Returns:
        For WinoGrande, the text before the blank (``context_prefix``).
        Otherwise the first key present among ``question``,
        ``original_text`` and finally ``contexts[0]``.

    Raises:
        KeyError: If none of the expected keys is present.
    """
    if "winogrande" in mc_dataset_name.lower():
        # WinoGrande uses the part before the blank.
        return formatted_example["context_prefix"]

    # Resolve the fallbacks lazily so that a missing "contexts" entry is
    # only an error when it is actually needed.
    for key in ("question", "original_text"):
        if key in formatted_example:
            return formatted_example[key]
    return formatted_example["contexts"][0]

# use acc_norm in https://blog.eleuther.ai/multiple-choice-normalization/
def calculate_sequence_log_prob(input_ids, logits, start_idx=0, tokenizer=None, return_byte_length=False):
    """Sum the next-token log-probabilities of a sequence under ``logits``.

    Args:
        input_ids: Token ids indexed as ``[batch, sequence]``.
        logits: Model outputs indexed as ``[batch, sequence, vocab]``.
        start_idx: Index in ``input_ids`` of the first token to score.
            Values ``<= 0`` score every predicted token.
        tokenizer: Needed only for the byte length; must provide ``decode``.
        return_byte_length: Also return the UTF-8 byte length of the decoded
            scored tokens of batch element 0. Has no effect when
            ``tokenizer`` is ``None``.

    Returns:
        ``(log_prob, num_tokens)`` or, when a byte length is requested and a
        tokenizer is given, ``(log_prob, num_tokens, byte_length)``.
        ``log_prob`` is a Python float summed over all scored entries, and
        ``num_tokens`` is the scored length along the sequence dimension.
    """
    # Position t predicts token t + 1.
    predictor_logits = logits[:, :-1, :].contiguous()
    targets = input_ids[:, 1:].contiguous()

    per_token = (
        F.log_softmax(predictor_logits, dim=-1)
        .gather(-1, targets.unsqueeze(-1))
        .squeeze(-1)
    )

    # Token start_idx of input_ids sits at shifted index start_idx - 1.
    if start_idx > 0:
        first = start_idx - 1
        per_token = per_token[:, first:]
        scored_tokens = targets[:, first:]
    else:
        scored_tokens = targets

    total = per_token.sum()
    token_count = per_token.shape[1]

    wants_bytes = return_byte_length and tokenizer is not None
    if not wants_bytes:
        return total.item(), token_count

    decoded = tokenizer.decode(scored_tokens[0], skip_special_tokens=True)
    return total.item(), token_count, len(decoded.encode('utf-8'))


def format_mc_prompt_with_ans(formatted_example, method="per_input", include_task_instruction=True):
    """Build a lettered multiple-choice prompt and its gold answer.

    Args:
        formatted_example: Record from ``format_mc_example`` with a
            ``dataset_name`` entry added. WinoGrande records need
            ``context_prefix``, ``target_suffix``, ``options`` and
            ``label``; all others need ``question``, ``options`` and
            ``label``.
        method: Not used by this function; kept for interface
            compatibility.
        include_task_instruction: Put the dataset's task instruction and an
            empty line before the prompt. Datasets without a known
            instruction get no header.

    Returns:
        ``(full_context, answer_text, context_only)``: the prompt ending in
        ``"Answer:"`` is ``context_only``, ``answer_text`` is a space plus
        the gold option letter, and ``full_context`` is both concatenated.
    """
    dataset_key = formatted_example.get("dataset_name", "").lower()

    # 1. Task instruction (optional), keyed by a substring of the dataset name.
    task_instructions = {
        "winogrande": "You are an expert in English language and commonsense reasoning. Fill in the blank with the most appropriate option.",
        "arc": "You are an expert in science. Select the correct answer to the science question.",
        "hellaswag": "You are an expert in commonsense reasoning. Select the most plausible continuation.",
        "piqa": "You are an expert in practical problem-solving. Select the best solution.",
        "commonsense_qa": "You are an expert in commonsense reasoning. Select the most reasonable answer.",
        "boolq": "You are an expert in reading comprehension. Answer the following question with True or False based on the given passage.",
        "obqa": "You are an expert in science and reasoning. Select the correct answer to the question."
    }

    # 2. Build the prompt body and find the instruction key.
    if "winogrande" in dataset_key:
        ctx_pref = formatted_example["context_prefix"]
        tgt_suf = formatted_example["target_suffix"]
        options = formatted_example["options"]  # [" opt1", " opt2"]
        label = formatted_example["label"]

        task_key = "winogrande"
        body_lines = [
            f"Sentence: {ctx_pref} ___ {tgt_suf}",
            f"Options: A){options[0]} B){options[1]}",
            "Answer:",
        ]
    else:
        # Other datasets (ARC, HellaSwag, PIQA, CommonsenseQA, etc.)
        context = formatted_example["question"]
        options = formatted_example["options"]
        label = formatted_example["label"]

        task_key = next((key for key in task_instructions if key in dataset_key), None)
        # "openbookqa" is accepted as a dataset name elsewhere in this
        # module but does not contain the substring "obqa"; map it to the
        # same instruction.
        if task_key is None and "openbookqa" in dataset_key:
            task_key = "obqa"

        body_lines = [f"Question: {context}", "Options:"]
        body_lines.extend(
            f"{chr(65 + i)}) {option}" for i, option in enumerate(options)
        )
        body_lines.append("Answer:")

    # The empty element yields a blank line between instruction and prompt.
    if include_task_instruction and task_key:
        header_lines = [task_instructions[task_key], ""]
    else:
        header_lines = []

    context_only = "\n".join(header_lines + body_lines)
    answer_text = f" {chr(65 + label)}"  # " A", " B", " C", etc.
    full_context = context_only + answer_text

    return full_context, answer_text, context_only


def get_prompt_and_answer_position(formatted_example, tokenizer, method="per_input", include_task_instruction=True):
    """Build the prompt for an example and locate where its answer starts.

    Args:
        formatted_example: Record accepted by ``format_mc_prompt_with_ans``.
        tokenizer: Callable tokenizer.
        method: ``"per_task"`` returns only the text before the first blank
            line of the prompt; any other value uses the complete prompt.
        include_task_instruction: Forwarded to ``format_mc_prompt_with_ans``.

    Returns:
        ``(text, input_ids, label_pos)``. For ``"per_task"`` the last two
        entries are ``None``. Otherwise ``text`` is the prompt including the
        answer, ``input_ids`` its tokenization (truncated to 2048 tokens)
        and ``label_pos`` the token count of the prompt without the answer.
    """
    if method == "per_task":
        _, _, context_only = format_mc_prompt_with_ans(
            formatted_example,
            method="per_task",
            include_task_instruction=include_task_instruction,
        )
        # Text before the first blank line. This is the task instruction
        # only when the prompt was built with one.
        instruction_part = context_only.split("\n\n")[0]
        return instruction_part, None, None

    # Per-input: use the complete prompt.
    full_context, answer_text, context_only = format_mc_prompt_with_ans(
        formatted_example,
        method="per_input",
        include_task_instruction=include_task_instruction,
    )

    # Answer position = number of prompt tokens, counted without special
    # tokens.
    # NOTE(review): the full prompt below is tokenized with the tokenizer's
    # default special-token setting; if that prepends a token, label_pos is
    # relative to the sequence without it - confirm against the consumer.
    prompt_token_ids = tokenizer(context_only, add_special_tokens=False)['input_ids']
    label_pos = len(prompt_token_ids)

    encoded = tokenizer(
        full_context,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    return full_context, encoded.input_ids, label_pos


def evaluate_mc_example(model, tokenizer, formatted_example, device="cuda", max_length=2048):
    """Score every option of one multiple-choice example and pick the best.

    WinoGrande: each option is inserted into the sentence and the
    log-probability of the shared suffix is compared. The normalized score
    divides by the number of scored tokens.

    Other datasets: each ``contexts`` entry is scored as a whole for the raw
    prediction. For the normalized prediction only the tokens after the
    question are scored, divided by the UTF-8 byte length of their decoded
    text.

    Args:
        model: Causal LM whose output exposes ``logits``.
        tokenizer: Tokenizer matching the model.
        formatted_example: Record from ``format_mc_example`` with
            ``dataset_name`` added.
        device: Device the token ids are moved to.
        max_length: Not used; inputs are scored without truncation.

    Returns:
        Dict with ``prediction``, ``normalized_prediction`` (ints),
        ``is_correct``, ``is_correct_normalized``, ``log_probs``,
        ``normalized_log_probs`` (one float per option) and ``label``.
    """
    model.eval()
    dataset_name = formatted_example.get("dataset_name", "")

    if "winogrande" in dataset_name.lower():
        ctx_pref = formatted_example["context_prefix"]
        tgt_suf = formatted_example["target_suffix"]
        options = formatted_example["options"]          # [" opt1", " opt2"]
        label = formatted_example["label"]

        log_ps, tok_cnts = [], []

        for opt in options:
            full_ctx = ctx_pref + opt                  # First half + candidate word
            ids_full = tokenizer(full_ctx + tgt_suf,
                                 add_special_tokens=False,
                                 return_tensors="pt").input_ids.to(device)
            # Scoring starts at the first token after the filled-in context.
            ctx_len = len(tokenizer(full_ctx,
                                    add_special_tokens=False).input_ids)

            with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
                logits = model(ids_full).logits

            lp, cnt = calculate_sequence_log_prob(ids_full, logits,
                                                  start_idx=ctx_len)
            log_ps.append(lp)       # log p(target | context)
            tok_cnts.append(cnt)    # target token count

        # Convert to numpy for element-wise operations.
        log_ps = np.array(log_ps, dtype=np.float32)
        tok_cnts = np.array(tok_cnts, dtype=np.float32)
        norm_lp = log_ps / tok_cnts          # acc_norm uses token-length normalization

        prediction = int(log_ps.argmax())
        normalized_prediction = int(norm_lp.argmax())

        # Expose the per-option scores, as the other branch does.
        log_probs = log_ps.tolist()
        normalized_log_probs = norm_lp.tolist()

    else:
        contexts = formatted_example["contexts"]
        label = formatted_example["label"]

        # The question is the same for every option, so its token length
        # (which marks where the answer part starts) is computed once.
        question_length = len(
            tokenizer(formatted_example["question"], add_special_tokens=True).input_ids
        )

        log_probs = []
        normalized_log_probs = []

        with torch.no_grad():
            for context in contexts:
                inputs = tokenizer(context, return_tensors="pt")
                input_ids = inputs.input_ids.to(device)
                attention_mask = inputs.attention_mask.to(device)

                with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
                    logits = model(input_ids, attention_mask=attention_mask).logits

                    # Full-sequence log prob drives the raw prediction.
                    full_log_prob, _ = calculate_sequence_log_prob(input_ids, logits)

                    # Answer-only log prob and byte length drive acc_norm.
                    answer_log_prob, _, answer_byte_length = calculate_sequence_log_prob(
                        input_ids, logits, start_idx=question_length,
                        tokenizer=tokenizer, return_byte_length=True
                    )

                log_probs.append(full_log_prob)

                # Normalize by byte length; an empty answer is left as is.
                if answer_byte_length > 0:
                    normalized_log_probs.append(answer_log_prob / answer_byte_length)
                else:
                    normalized_log_probs.append(answer_log_prob)

        # Plain ints keep the result type the same in both branches.
        prediction = int(np.argmax(log_probs))
        normalized_prediction = int(np.argmax(normalized_log_probs))

    is_correct = (prediction == label)
    is_correct_normalized = (normalized_prediction == label)

    return {
        "prediction": prediction,
        "normalized_prediction": normalized_prediction,
        "is_correct": is_correct,
        "is_correct_normalized": is_correct_normalized,
        "log_probs": log_probs,
        "normalized_log_probs": normalized_log_probs,
        "label": label
    }

def evaluate_mc_dataset(model, tokenizer, dataset_name, device="cuda", num_examples=None, split="test"):
    """Evaluate a model on a multiple-choice benchmark, one example at a time.

    Args:
        model: Causal LM passed through to ``evaluate_mc_example``.
        tokenizer: Tokenizer matching the model.
        dataset_name: Benchmark name understood by ``load_mc_dataset`` and
            ``format_mc_example``.
        device: Device used for evaluation.
        num_examples: Optional cap; only the first ``num_examples`` rows are
            evaluated.
        split: Split requested from ``load_mc_dataset``.

    Returns:
        Dict with ``dataset``, ``num_examples``, ``acc``, ``acc_norm``,
        ``correct`` and ``correct_normalized``, plus a ``note`` for
        WinoGrande. Both accuracies are ``0.0`` when nothing was evaluated.
    """
    dataset = load_mc_dataset(dataset_name, split)

    # Limit the number of examples if requested.
    if num_examples is not None:
        dataset = dataset.select(range(min(num_examples, len(dataset))))

    correct = 0
    correct_normalized = 0
    total = 0

    for example in dataset:
        formatted_example = format_mc_example(example, dataset_name)
        # The evaluator selects its scoring branch from this name.
        formatted_example["dataset_name"] = dataset_name

        results = evaluate_mc_example(model, tokenizer, formatted_example, device)

        if results["is_correct"]:
            correct += 1
        if results["is_correct_normalized"]:
            correct_normalized += 1
        total += 1

        # Print running accuracy every 10 examples.
        if total % 10 == 0:
            acc = correct / total
            acc_norm = correct_normalized / total
            print(f"Progress: {total}/{len(dataset)} - acc: {acc:.4f}, acc_norm: {acc_norm:.4f}")

    # An empty split or num_examples=0 leaves total at 0; report zero
    # accuracy instead of dividing by zero.
    acc = correct / total if total else 0.0
    acc_norm = correct_normalized / total if total else 0.0

    results = {
        "dataset": dataset_name,
        "num_examples": total,
        "acc": acc,
        "acc_norm": acc_norm,
        "correct": correct,
        "correct_normalized": correct_normalized
    }

    # evaluate_mc_example divides WinoGrande scores by the scored token
    # count, so the note says token-length rather than byte-length.
    if "winogrande" in dataset_name.lower():
        results["note"] = "WinoGrande uses token-length normalization for acc_norm"

    return results

def evaluate_mc_examples_batch(
    model, 
    tokenizer, 
    formatted_examples: List[Dict[str, Any]], 
    device: str = "cuda", 
    max_length: int = 2048,
    batch_size: int = 8
) -> List[Dict[str, Any]]:
    """
    Evaluate formatted multi-choice samples in consecutive chunks.

    Batched counterpart of evaluate_mc_example: the samples are cut into
    slices of ``batch_size`` and each slice is handed to
    ``_evaluate_mc_batch_core``.

    Args:
        model: Language model (switched to eval mode)
        tokenizer: Tokenizer
        formatted_examples: List of formatted samples
        device: Device
        max_length: Maximum sequence length, forwarded to the core function
        batch_size: Number of samples per chunk

    Returns:
        List[Dict]: Evaluation results, one per sample, in input order
    """
    model.eval()

    collected: List[Dict[str, Any]] = []
    sample_count = len(formatted_examples)

    for offset in range(0, sample_count, batch_size):
        # Slicing clamps at the end of the list, so the last chunk may be
        # shorter than batch_size.
        chunk = formatted_examples[offset:offset + batch_size]
        collected.extend(
            _evaluate_mc_batch_core(model, tokenizer, chunk, device, max_length)
        )

    return collected


def _evaluate_mc_batch_core(
    model, 
    tokenizer, 
    batch_examples: List[Dict[str, Any]], 
    device: str,
    max_length: int
) -> List[Dict[str, Any]]:
    """
    Evaluate one batch by routing samples to the matching batch evaluator.

    Samples whose ``dataset_name`` contains "winogrande" go to
    ``_evaluate_winogrande_batch``; all others go to
    ``_evaluate_other_datasets_batch``. Results are written back to each
    sample's original position in the batch.
    """
    # Partition the batch, remembering each sample's original position.
    wino_bucket: List[tuple] = []
    other_bucket: List[tuple] = []
    for position, sample in enumerate(batch_examples):
        is_wino = "winogrande" in sample.get("dataset_name", "").lower()
        bucket = wino_bucket if is_wino else other_bucket
        bucket.append((position, sample))

    ordered: List[Any] = [None] * len(batch_examples)

    # WinoGrande samples are evaluated first, then the rest.
    if wino_bucket:
        wino_outputs = _evaluate_winogrande_batch(
            model, tokenizer, wino_bucket, device
        )
        for (position, _), output in zip(wino_bucket, wino_outputs):
            ordered[position] = output

    if other_bucket:
        other_outputs = _evaluate_other_datasets_batch(
            model, tokenizer, other_bucket, device, max_length
        )
        for (position, _), output in zip(other_bucket, other_outputs):
            ordered[position] = output

    return ordered


def _evaluate_winogrande_batch(
    model, 
    tokenizer, 
    indexed_examples: List[tuple], 
    device: str
) -> List[Dict[str, Any]]:
    """
    Batch evaluate WinoGrande samples, strictly following the original logic.

    For every example, each candidate in "options" is inserted between
    "context_prefix" and "target_suffix". All resulting sentences are
    right-padded into one batch and scored with a single forward pass, then
    the per-option scores are regrouped per example.

    Args:
        model: Called as model(input_ids, attention_mask=...); its output must
            expose a .logits attribute
        tokenizer: Called on plain strings; its pad_token_id (or, if that is
            None, eos_token_id) is used as the padding value
        indexed_examples: (orig_idx, example) pairs; each example provides
            "context_prefix", "target_suffix", "options" and "label"
        device: Device the batch tensors are moved to

    Returns:
        List[Dict]: One _compute_winogrande_result dict per example that has at
        least one option, in input order. Empty when there is nothing to score.
    """
    # Collect all texts to be processed
    all_full_contexts = []
    all_context_lengths = []
    row_owner = []  # (index into indexed_examples, label) for each batch row

    for batch_idx, (_, example) in enumerate(indexed_examples):
        ctx_pref = example["context_prefix"]
        tgt_suf = example["target_suffix"]
        label = example["label"]

        # Build full context for each option
        for opt in example["options"]:
            full_ctx = ctx_pref + opt  # First half + candidate word

            # Context length in tokens; passed as start_idx when scoring the target
            ctx_len = len(tokenizer(full_ctx, add_special_tokens=False).input_ids)

            all_full_contexts.append(full_ctx + tgt_suf)  # Complete sentence
            all_context_lengths.append(ctx_len)
            row_owner.append((batch_idx, label))

    # Nothing to score: avoid torch.stack failing on an empty list
    if not all_full_contexts:
        return []

    # Tokenizers without a pad token report pad_token_id = None, which torch.full
    # rejects. Padded positions are masked out and sliced away before scoring,
    # so the concrete fill value should not influence the scores.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    if pad_token_id is None:
        pad_token_id = 0

    # Tokenize all texts
    all_input_ids = []
    max_len = 0
    for context in all_full_contexts:
        tokens = tokenizer(context, add_special_tokens=False, return_tensors="pt")
        input_ids = tokens.input_ids.squeeze(0)  # Remove batch dimension
        all_input_ids.append(input_ids)
        max_len = max(max_len, len(input_ids))

    # Manually right-pad to the same length (this allows controlling padding effects)
    padded_input_ids = []
    attention_masks = []
    for input_ids in all_input_ids:
        attention_mask = torch.ones(len(input_ids), dtype=torch.long)

        pad_length = max_len - len(input_ids)
        if pad_length > 0:
            padding = torch.full((pad_length,), pad_token_id, dtype=input_ids.dtype)
            input_ids = torch.cat([input_ids, padding])
            attention_mask = torch.cat([attention_mask, torch.zeros(pad_length, dtype=torch.long)])

        padded_input_ids.append(input_ids)
        attention_masks.append(attention_mask)

    # Convert to batch tensor
    batch_input_ids = torch.stack(padded_input_ids).to(device)
    batch_attention_masks = torch.stack(attention_masks).to(device)

    # Batch forward pass - using attention mask
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model(batch_input_ids, attention_mask=batch_attention_masks)
        all_logits = outputs.logits

    # Calculate log probability for each option
    option_log_probs = []
    option_token_counts = []

    for i, ctx_len in enumerate(all_context_lengths):
        logits = all_logits[i:i+1]  # [1, seq_len, vocab_size]
        input_ids = batch_input_ids[i:i+1]  # [1, seq_len]
        attention_mask = batch_attention_masks[i:i+1]  # [1, seq_len]

        # Key modification: only compute log probability for non-padding parts
        actual_length = attention_mask.sum().item()
        actual_input_ids = input_ids[:, :actual_length]
        actual_logits = logits[:, :actual_length]

        # Calculate log probability of the target sequence
        log_prob, token_count = calculate_sequence_log_prob(
            actual_input_ids, actual_logits, start_idx=ctx_len
        )

        option_log_probs.append(log_prob)
        option_token_counts.append(token_count)

    # Reorganize results by sample; dict insertion order keeps the input order
    grouped = {}
    for (batch_idx, label), log_prob, token_count in zip(
        row_owner, option_log_probs, option_token_counts
    ):
        bucket = grouped.setdefault(batch_idx, ([], [], label))
        bucket[0].append(log_prob)
        bucket[1].append(token_count)

    return [
        _compute_winogrande_result(sample_log_probs, sample_token_counts, sample_label)
        for sample_log_probs, sample_token_counts, sample_label in grouped.values()
    ]


def _evaluate_other_datasets_batch(
    model, 
    tokenizer, 
    indexed_examples: List[tuple], 
    device: str,
    max_length: int
) -> List[Dict[str, Any]]:
    """
    Batch evaluate other dataset samples, strictly following the original logic.

    Every entry of an example's "contexts" is tokenized (truncated to
    max_length), right-padded into one batch and scored with a single forward
    pass. Two scores are kept per context: the log probability of the whole
    sequence, and the log probability of the part starting at the question's
    token length, divided by that part's byte length when it is positive.

    Args:
        model: Called as model(input_ids, attention_mask=...); its output must
            expose a .logits attribute
        tokenizer: Called on plain strings; its pad_token_id (or, if that is
            None, eos_token_id) is used as the padding value
        indexed_examples: (orig_idx, example) pairs; each example provides
            "contexts", "label" and "question"
        device: Device the batch tensors are moved to
        max_length: Truncation length used when tokenizing each context

    Returns:
        List[Dict]: One _compute_other_dataset_result dict per example that has
        at least one context, in input order. Empty when there is nothing to score.
    """
    # Collect all contexts to be processed
    all_contexts = []
    question_lengths = []
    row_owner = []  # (index into indexed_examples, label) for each batch row

    for batch_idx, (_, example) in enumerate(indexed_examples):
        contexts = list(example["contexts"])
        label = example["label"]
        question = example["question"]

        if not contexts:
            continue

        # The question is shared by all contexts of an example: tokenize it once.
        # NOTE(review): presumably each context starts with the question, so this
        # length marks where the answer begins - confirm against format_mc_example.
        question_length = len(tokenizer(question, add_special_tokens=True).input_ids)

        for context in contexts:
            all_contexts.append(context)
            question_lengths.append(question_length)
            row_owner.append((batch_idx, label))

    # Nothing to score: avoid torch.stack failing on an empty list
    if not all_contexts:
        return []

    # Tokenizers without a pad token report pad_token_id = None, which torch.full
    # rejects. Padded positions are masked out and sliced away before scoring,
    # so the concrete fill value should not influence the scores.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    if pad_token_id is None:
        pad_token_id = 0

    # Tokenize all contexts
    all_context_inputs = []
    max_context_len = 0
    for context in all_contexts:
        tokens = tokenizer(context, return_tensors="pt", truncation=True, max_length=max_length)
        input_ids = tokens.input_ids.squeeze(0)
        attention_mask = tokens.attention_mask.squeeze(0)
        all_context_inputs.append((input_ids, attention_mask))
        max_context_len = max(max_context_len, len(input_ids))

    # Manually right-pad
    padded_input_ids = []
    padded_attention_masks = []
    for input_ids, attention_mask in all_context_inputs:
        pad_length = max_context_len - len(input_ids)
        if pad_length > 0:
            padding = torch.full((pad_length,), pad_token_id, dtype=input_ids.dtype)
            input_ids = torch.cat([input_ids, padding])
            attention_mask = torch.cat(
                [attention_mask, torch.zeros(pad_length, dtype=attention_mask.dtype)]
            )

        padded_input_ids.append(input_ids)
        padded_attention_masks.append(attention_mask)

    batch_input_ids = torch.stack(padded_input_ids).to(device)
    batch_attention_masks = torch.stack(padded_attention_masks).to(device)

    # Batch forward pass
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model(batch_input_ids, attention_mask=batch_attention_masks)
        all_logits = outputs.logits

    # Calculate log probability using correct question length
    context_log_probs = []
    context_normalized_log_probs = []

    for i, question_length in enumerate(question_lengths):
        logits = all_logits[i:i+1]  # [1, seq_len, vocab_size]
        input_ids = batch_input_ids[i:i+1]  # [1, seq_len]
        attention_mask = batch_attention_masks[i:i+1]  # [1, seq_len]

        # Key modification: only compute non-padding parts
        actual_length = attention_mask.sum().item()
        actual_input_ids = input_ids[:, :actual_length]
        actual_logits = logits[:, :actual_length]

        # Calculate log probability of the full sequence (for acc)
        full_log_prob, _ = calculate_sequence_log_prob(actual_input_ids, actual_logits)
        context_log_probs.append(full_log_prob)

        # Score only the part after the question (for acc_norm)
        answer_log_prob, answer_tokens, answer_byte_length = calculate_sequence_log_prob(
            actual_input_ids, actual_logits, start_idx=question_length,
            tokenizer=tokenizer, return_byte_length=True
        )

        # Normalize by byte length; keep the raw value when the length is zero
        if answer_byte_length > 0:
            normalized_log_prob = answer_log_prob / answer_byte_length
        else:
            normalized_log_prob = answer_log_prob

        context_normalized_log_probs.append(normalized_log_prob)

    # Reorganize results by sample; dict insertion order keeps the input order
    grouped = {}
    for (batch_idx, label), log_prob, normalized_log_prob in zip(
        row_owner, context_log_probs, context_normalized_log_probs
    ):
        bucket = grouped.setdefault(batch_idx, ([], [], label))
        bucket[0].append(log_prob)
        bucket[1].append(normalized_log_prob)

    return [
        _compute_other_dataset_result(sample_log_probs, sample_normalized, sample_label)
        for sample_log_probs, sample_normalized, sample_label in grouped.values()
    ]


def _compute_winogrande_result(log_probs: List[float], token_counts: List[int], label: int) -> Dict[str, Any]:
    """
    Calculate WinoGrande sample results, strictly following the original logic.

    Args:
        log_probs: Log probability of the target sequence for each option
        token_counts: Number of scored target tokens for each option
        label: Index of the correct option

    Returns:
        Dict with the raw and token-length-normalized predictions (argmax
        indices), their correctness against label, both score lists, and label.

    Raises:
        ValueError: If log_probs is empty (argmax of an empty array).
    """
    # Convert to numpy array
    log_ps = np.array(log_probs, dtype=np.float32)
    tok_cnts = np.array(token_counts, dtype=np.float32)

    # acc_norm uses token-length normalization. An option with no scored tokens
    # keeps its raw log probability instead of producing NaN/inf through a
    # division by zero (same fallback as the byte-length normalization used
    # for the other datasets).
    norm_lp = np.divide(log_ps, tok_cnts, out=log_ps.copy(), where=tok_cnts > 0)

    # Predictions
    prediction = int(log_ps.argmax())
    normalized_prediction = int(norm_lp.argmax())

    # Correctness check
    is_correct = (prediction == label)
    is_correct_normalized = (normalized_prediction == label)

    return {
        "prediction": prediction,
        "normalized_prediction": normalized_prediction,
        "is_correct": is_correct,
        "is_correct_normalized": is_correct_normalized,
        "log_probs": log_ps.tolist(),
        "normalized_log_probs": norm_lp.tolist(),
        "label": label
    }


def _compute_other_dataset_result(log_probs: List[float], normalized_log_probs: List[float], label: int) -> Dict[str, Any]:
    """
    Turn per-option scores of one non-WinoGrande sample into a result record.

    Args:
        log_probs: Full-sequence log probability for each option
        normalized_log_probs: Normalized answer log probability for each option
        label: Index of the correct option

    Returns:
        Dict with the argmax index of each score list, whether each index
        equals label, both score lists (as float32 values), and label.
    """
    raw_scores = np.array(log_probs, dtype=np.float32)
    normed_scores = np.array(normalized_log_probs, dtype=np.float32)

    # The highest-scoring option is the prediction for each metric
    best_raw = int(np.argmax(raw_scores))
    best_normed = int(np.argmax(normed_scores))

    record = {
        "prediction": best_raw,
        "normalized_prediction": best_normed,
        "is_correct": best_raw == label,
        "is_correct_normalized": best_normed == label,
        "log_probs": raw_scores.tolist(),
        "normalized_log_probs": normed_scores.tolist(),
        "label": label,
    }
    return record


def evaluate_mc_dataset_batch(
    model, 
    tokenizer, 
    dataset_name: str, 
    device: str = "cuda", 
    num_examples: Optional[int] = None,
    split: str = "test",
    batch_size: int = 8
) -> Dict[str, Any]:
    """
    Batch version of dataset evaluation function, strictly following the original evaluate_mc_dataset logic

    Loads the dataset with load_mc_dataset, formats every example with
    format_mc_example, scores them through evaluate_mc_examples_batch and
    aggregates accuracy (acc) and normalized accuracy (acc_norm).

    Args:
        model: Language model
        tokenizer: Tokenizer
        dataset_name: Dataset name
        device: Device
        num_examples: Limit on number of examples to evaluate
        split: Dataset split
        batch_size: Batch size

    Returns:
        Dict: Evaluation results with keys "dataset", "num_examples", "acc",
        "acc_norm", "correct", "correct_normalized" and, for WinoGrande, "note".
        When no example was evaluated, "acc" and "acc_norm" are 0.0.
    """
    # Load dataset
    dataset = load_mc_dataset(dataset_name, split)

    # Limit number of examples
    if num_examples is not None:
        dataset = dataset.select(range(min(num_examples, len(dataset))))

    # Format examples
    formatted_examples = []
    for example in dataset:
        formatted_example = format_mc_example(example, dataset_name)
        formatted_example["dataset_name"] = dataset_name  # Add dataset_name for evaluation
        formatted_examples.append(formatted_example)

    print(f"Starting batch evaluation for {dataset_name} ({len(formatted_examples)} samples, batch_size={batch_size})")

    # Batch evaluation
    all_results = evaluate_mc_examples_batch(
        model, tokenizer, formatted_examples, device, 
        batch_size=batch_size
    )

    # Calculate overall metrics
    correct = 0
    correct_normalized = 0
    total = 0

    for sample_result in all_results:
        if sample_result["is_correct"]:
            correct += 1
        if sample_result["is_correct_normalized"]:
            correct_normalized += 1
        total += 1

        # Print running accuracy every 10 samples (computed after scoring has finished)
        if total % 10 == 0:
            acc = correct / total
            acc_norm = correct_normalized / total
            print(f"Progress: {total}/{len(all_results)} - acc: {acc:.4f}, acc_norm: {acc_norm:.4f}")

    # Calculate final metrics; an empty evaluation reports 0.0 instead of
    # raising ZeroDivisionError
    final_acc = correct / total if total else 0.0
    final_acc_norm = correct_normalized / total if total else 0.0

    results_summary = {
        "dataset": dataset_name,
        "num_examples": total,
        "acc": final_acc,
        "acc_norm": final_acc_norm,
        "correct": correct,
        "correct_normalized": correct_normalized
    }

    # Add note for WinoGrande
    if "winogrande" in dataset_name.lower():
        results_summary["note"] = "WinoGrande uses token-length normalization for acc_norm"

    print(f"\nBatch evaluation completed - {dataset_name}")
    print(f"ACC: {final_acc:.4f} ({correct}/{total})")
    print(f"ACC_NORM: {final_acc_norm:.4f} ({correct_normalized}/{total})")

    return results_summary