import sys
import os
# Put the repository root (the parent of this script's directory) on sys.path
# so that the local `SLICER` package imported further below can be found.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(project_root)


# NOTE(review): redundant re-import of `os` (already imported above); harmless.
import os
# Expose only physical GPU 2 (GPUs enumerated in PCI bus order) to this
# process. These are set before `import torch` below so CUDA picks them up;
# inside the process the selected GPU is visible as device 0.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
import random
import numpy as np
import json
import re
import ast
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from SLICER.SLICER_opti import *            
from SLICER.SLICER_config import SCConfig    

# 1. Seed every random number generator in play (Python, NumPy, PyTorch)
#    so that repeated runs are reproducible.
SEED = 42
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(SEED)
del _seeder  # do not leave the loop variable behind as a module global
if torch.cuda.is_available():
    # Also seed the generators of every visible CUDA device.
    torch.cuda.manual_seed_all(SEED)

# 2. Load the DeepSeekMath-7B-Instruct tokenizer and model.
model_name = "deepseek-ai/deepseek-math-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": 0},          # whole model on logical device 0 (physical GPU 2 via CUDA_VISIBLE_DEVICES)
    torch_dtype=torch.float16,   # half precision to reduce GPU memory
    trust_remote_code=True,
)
model.eval()  # inference only: switch the module out of training mode

# 3. Load the test split of the GSM8K math word-problem dataset.
gsm8k = load_dataset("gsm8k", "main")["test"]

# 4. Extract the final answer from model output (expecting \boxed{})

# A number such as "42", "-3.5" or "1,234,567.89". Thousands separators are
# only accepted in well-formed groups of three, so a list like "1,2,3" still
# splits into separate numbers.
_NUMBER_RE = re.compile(r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?")
_BOXED_RE = re.compile(r"\\boxed\s*")


def _boxed_contents(text):
    r"""Return the contents of every ``\boxed`` expression in *text*, in order.

    Supports ``\boxed{...}`` with nested braces (e.g. ``\boxed{\frac{1}{2}}``),
    an unterminated ``\boxed{...`` left by a truncated generation (the rest of
    the text is taken), and the brace-less form ``\boxed 42``.
    """
    contents = []
    for match in _BOXED_RE.finditer(text):
        start = match.end()
        if start < len(text) and text[start] == "{":
            depth = 0
            for pos in range(start, len(text)):
                if text[pos] == "{":
                    depth += 1
                elif text[pos] == "}":
                    depth -= 1
                    if depth == 0:
                        contents.append(text[start + 1:pos])
                        break
            else:
                # No matching closing brace: generation was cut off.
                contents.append(text[start + 1:])
        else:
            # Brace-less form: take the next whitespace-delimited token.
            token = text[start:].split(None, 1)
            if token:
                contents.append(token[0])
    return contents


def extract_answer_from_text(text):
    r"""Extract the final numeric answer from a model response.

    The last ``\boxed`` expression that contains a number wins. Within it the
    first number is taken, so ``\boxed{\$1,000}`` and ``\boxed{x = 18}`` both
    work. If no ``\boxed`` holds a number, the last number anywhere in *text*
    is used instead. Thousands separators are removed from the result.

    Args:
        text: decoded model output (None or "" yields None).

    Returns:
        The answer as a string such as "1000" or "-2.5", or None when *text*
        contains no number.
    """
    if not text:
        return None
    # Models sometimes box intermediate values; the final box is the answer.
    for content in reversed(_boxed_contents(text)):
        match = _NUMBER_RE.search(content)
        if match:
            return match.group(0).replace(",", "")
    numbers = _NUMBER_RE.findall(text)
    return numbers[-1].replace(",", "") if numbers else None

# 5. Compare predicted and ground-truth answers numerically

# Strips everything that cannot be part of a plain decimal or scientific-notation
# number (commas, "$", "%", spaces, letters other than e/E, ...).
_NON_NUMERIC_RE = re.compile(r"[^\d.\-+eE]")


def _to_float(value):
    """Parse *value* as a float after stripping non-numeric characters.

    Returns None when *value* is None or nothing parseable remains.
    """
    if value is None:
        return None
    cleaned = _NON_NUMERIC_RE.sub("", str(value))
    try:
        # float() (unlike ast.literal_eval) accepts leading zeros such as "007".
        return float(cleaned)
    except ValueError:
        return None


def math_equal(pred, gt, tol=1e-4):
    """Compare two answers numerically, for GSM8K grading.

    Both values are converted to str, stripped of non-numeric characters
    (so "1,000" and "$18" parse), and parsed as floats.

    Args:
        pred: predicted answer (str, number, or None).
        gt: ground-truth answer (str or number).
        tol: absolute tolerance; answers are equal when they differ by less.

    Returns:
        True if both parse and differ by less than *tol*, else False.
    """
    pred_val = _to_float(pred)
    gt_val = _to_float(gt)
    if pred_val is None or gt_val is None:
        return False
    return abs(pred_val - gt_val) < tol

# 6. SLICER-enabled text generation for a single problem

def _sample_next_token(logits, temperature, top_p):
    """Choose the next token id from last-position logits.

    Args:
        logits: tensor of shape (batch, vocab) for the last position.
        temperature: 0 (or below) selects greedy argmax decoding; otherwise the
            logits are divided by it before the softmax.
        top_p: nucleus-sampling threshold in (0, 1]. None or >= 1 disables the
            filter; otherwise sampling is restricted to the smallest set of
            most likely tokens whose cumulative probability reaches ``top_p``.

    Returns:
        A (batch, 1) tensor of token ids.
    """
    if temperature <= 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    # Softmax in float32: logits may be float16 (the model is loaded in half
    # precision), which loses precision over a large vocabulary.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    if top_p is None or top_p >= 1.0:
        return torch.multinomial(probs, num_samples=1)
    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    # Drop a token once the probability mass strictly ahead of it already
    # exceeds top_p; the most likely token is therefore always kept.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(mass_before > top_p, 0.0)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_ids, -1, choice)


def custom_generate_text(
    model,
    tokenizer,
    sc_config,
    input_ids: torch.Tensor,
    w_bar: int = 200,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.95,
    device: torch.device = None,
):
    """Generate a completion in which exactly one decoding step uses SLICER.

    Decoding proceeds in three phases:
      (1) ordinary KV-cached decoding while the *total* sequence length
          (prompt + generated tokens) is below ``w_bar``;
      (2) one full forward pass over the whole sequence with ``sc_config``
          passed to the model, producing exactly one token;
      (3) ordinary KV-cached decoding, restarting from an empty cache, until
          EOS or ``max_new_tokens`` new tokens have been produced.
    If the prompt is already ``w_bar`` tokens or longer, phase (1) is skipped.
    Generation stops as soon as EOS is produced in any phase.

    Args:
        model: causal LM whose forward accepts an ``sc_config`` keyword
            (presumably provided by the SLICER patch imported at the top of
            the file; confirm).
        tokenizer: supplies ``eos_token_id`` and decodes the output.
        sc_config: SLICER configuration forwarded to the model in phase (2).
        input_ids: prompt token ids of shape (1, prompt_len). Only batch size
            1 is supported because token ids are compared via ``.item()``.
        w_bar: total sequence length at which the SLICER step is applied.
        max_new_tokens: cap on generated tokens, the SLICER step included.
        temperature: 0 for greedy decoding, otherwise sampling temperature.
        top_p: nucleus-sampling threshold, used when ``temperature > 0``.
        device: device for ``input_ids``; defaults to ``model.device``.

    Returns:
        ``(text, out_slicer, applied)``: the decoded new tokens (prompt
        excluded, special tokens skipped), the raw model output of the SLICER
        step (None if it did not run), and whether the SLICER step ran.
    """
    if device is None:
        device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
    generated = input_ids.clone()
    input_len = input_ids.shape[1]
    eos_id = tokenizer.eos_token_id
    new_tokens = 0

    def decode_slice():
        # Only decode the tokens produced after the prompt.
        return tokenizer.decode(generated[0, input_len:], skip_special_tokens=True)

    def cached_step(past_key_values):
        # One ordinary decoding step. With an empty cache the whole sequence is
        # (re-)encoded; afterwards only the newest token is fed.
        cur = generated if past_key_values is None else generated[:, -1:]
        with torch.no_grad():
            out = model(
                input_ids=cur,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
                output_hidden_states=False,
            )
        nxt = _sample_next_token(out.logits[:, -1, :], temperature, top_p)
        return nxt, out.past_key_values

    # (1) Standard decoding until the sequence reaches w_bar tokens.
    past_key_values = None
    while generated.size(1) < w_bar and new_tokens < max_new_tokens:
        nxt, past_key_values = cached_step(past_key_values)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        if nxt.item() == eos_id:
            return decode_slice(), None, False
    # The phase-(1) cache is never reused; release it before the
    # full-sequence SLICER pass to lower peak memory.
    past_key_values = None

    if new_tokens >= max_new_tokens:
        return decode_slice(), None, False

    # (2) Single SLICER step over the full sequence, without a cache.
    with torch.no_grad():
        out_slicer = model(
            input_ids=generated,
            past_key_values=None,
            use_cache=False,
            return_dict=True,
            output_hidden_states=True,
            sc_config=sc_config,
        )
    nxt = _sample_next_token(out_slicer.logits[:, -1, :], temperature, top_p)
    generated = torch.cat([generated, nxt], dim=-1)
    new_tokens += 1
    if nxt.item() == eos_id:
        return decode_slice(), out_slicer, True

    # (3) Resume standard decoding from an empty cache (the SLICER pass kept
    # none), so its first step re-encodes the whole sequence without SLICER.
    while new_tokens < max_new_tokens:
        nxt, past_key_values = cached_step(past_key_values)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        if nxt.item() == eos_id:
            break

    return decode_slice(), out_slicer, True

# 7. Main evaluation loop on the GSM8K test set
def evaluate_gsm8k_slicer(
    data,
    save_prefix="deepseek_math_50_slicer",
    *,
    sc_config=None,
    w_bar=200,
    max_new_tokens=512,
):
    """Run SLICER-enabled generation on GSM8K problems and score the answers.

    Uses the module-level ``model`` and ``tokenizer``. Each problem is decoded
    greedily (temperature 0) with one SLICER step at total sequence length
    ``w_bar``.

    Args:
        data: sized iterable of GSM8K records, each a mapping with "question"
            and "answer" keys. The reference answer is the text after the
            last "####" in "answer".
        save_prefix: prefix for the two output files:
            ``{save_prefix}_results.jsonl`` (one JSON record per problem) and
            ``{save_prefix}_summary.json`` (total, correct, accuracy).
        sc_config: SLICER configuration. Defaults to
            ``SCConfig(split_layer=20, s=0.9, lambd=0.0, delta=0.0,
            Q=[8,8,8], Q_n=[8,8,8])``, the setting this script was written with.
        w_bar: total sequence length at which the SLICER step is applied.
        max_new_tokens: generation budget per problem.

    Returns:
        ``(results, acc)``: the per-problem records and the overall accuracy
        as a float (0.0 when *data* is empty).
    """
    if sc_config is None:
        sc_config = SCConfig(split_layer=20, s=0.9, lambd=0.0, delta=0.0, Q=[8,8,8], Q_n=[8,8,8])
    results = []
    correct = 0

    for item in tqdm(data):
        q = item["question"]
        a = item["answer"].split("####")[-1].strip()

        # Construct the model prompt (DeepSeekMath formatting)
        prompt = f"English question: {q}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        pred_text, _, _ = custom_generate_text(
            model=model,
            tokenizer=tokenizer,
            sc_config=sc_config,
            input_ids=input_ids,
            w_bar=w_bar,
            max_new_tokens=max_new_tokens,
            temperature=0.0,
            top_p=1.0,
            device=model.device,
        )

        pred = extract_answer_from_text(pred_text)
        is_correct = math_equal(pred, a)

        results.append({
            "question": q,
            "ground_truth": a,
            "prediction": pred,
            "correct": is_correct,
            "response": pred_text
        })

        correct += int(is_correct)

    total = len(data)
    # Guard against an empty dataset instead of raising ZeroDivisionError.
    acc = correct / total if total else 0.0
    print(f"\n🎯 Accuracy on {total} examples: {acc*100:.2f}%")

    # Save detailed results for each problem
    with open(f"{save_prefix}_results.jsonl", "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    # Save summary statistics (accuracy etc.)
    with open(f"{save_prefix}_summary.json", "w", encoding="utf-8") as f:
        json.dump({"total": total, "correct": correct, "accuracy": acc}, f, indent=2)

    return results, acc

# 8. Script entry point: score the full GSM8K test split when run directly.
if __name__ == "__main__":
    results, acc = evaluate_gsm8k_slicer(data=gsm8k)
