import sys
import os
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(project_root)

import re
import json
import torch
from tqdm import tqdm
from datasets import load_dataset
#from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from SLICER.SLICER_opti import *            
from SLICER.SLICER_config import SCConfig 

# Restrict this process to a single GPU: devices are enumerated in PCI bus
# order and only the device with index "2" in that ordering is made visible.
# NOTE(review): these variables are assigned after `import torch`; presumably
# that still works because CUDA is only initialised on first use, but setting
# them before the torch import would be the safer order — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Utils for Data Preparation
def extract_answer_str(ans_str):
    """
    Pull the final answer out of a GSM8K solution string.

    Everything after the last "####" marker is treated as the answer. When
    the marker is absent the whole (stripped) string is returned.

    Args:
        ans_str (str): The raw answer string from the GSM8K dataset.

    Returns:
        str: The text following the last "####", without surrounding whitespace.
    """
    # rpartition yields ("", "", s) when the marker is missing, so `tail`
    # is the whole string in that case.
    _, _, tail = ans_str.strip().rpartition("####")
    return tail.strip()

def save_gsm8k_jsonl(filepath="gsm8k_test.jsonl"):
    """
    Dump the GSM8K test split to disk, one JSON object per line.

    Each line holds the original question and the final answer extracted
    from the dataset's solution text via `extract_answer_str`.

    Args:
        filepath (str): Output file path (overwritten if it exists).
    """
    test_split = load_dataset("gsm8k", "main")["test"]
    with open(filepath, "w", encoding="utf-8") as out_file:
        for example in test_split:
            record = {
                "question": example["question"],
                "answer": extract_answer_str(example["answer"]),
            }
            # ensure_ascii=False keeps non-ASCII characters readable in the file.
            serialized = json.dumps(record, ensure_ascii=False)
            out_file.write(serialized + "\n")
    print(f"Saved test set to: {filepath}")

# Model Loader
def load_llama3(model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
    """
    Build the tokenizer and causal-LM for the given checkpoint.

    The model weights are requested in float16 with `device_map="auto"`,
    and the model is switched to evaluation mode before being returned.

    Args:
        model_name (str): Model name or path passed to `from_pretrained`.

    Returns:
        tuple: `(tokenizer, model)`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    load_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.eval()

    return tokenizer, model

# Answer Parsing and Matching Utils
def extract_answer_from_text(text):
    """
    Extract the predicted numeric answer from model output text.

    Two answer formats are recognised, tried in this order:

    1. LaTeX style: ``\\boxed{<number>}``
    2. GSM8K style: ``#### <number>``

    For each format the first occurrence in `text` is used. The number may
    be negative, may carry a leading ``$`` and may use ``,`` as a thousands
    separator; separators are removed from the returned string so that
    "1,250" comes back as "1250". A sentence-ending period after the number
    is not treated as part of it.

    Args:
        text (str): Model output.

    Returns:
        str: The parsed number as a string, or "" when no answer is found.
    """
    # An integer/decimal with optional sign and thousands separators, or a
    # bare fractional such as ".5". Requiring a digit avoids matching a lone
    # "-" or "." as an answer.
    number = r"-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)"
    patterns = (
        r"\\boxed\{\s*\$?\s*(" + number + r")\.?\s*\}",
        r"####\s*\$?\s*(" + number + r")",
    )
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1).replace(",", "")
    return ""

def math_equal(a, b):
    """
    Compare two answers for numerical equivalence (float or fallback to string).

    Both values are first interpreted as floats; string inputs are stripped
    and have "," thousands separators removed before conversion, so
    "1,000" equals "1000". If either value cannot be converted, the two are
    compared as stripped strings instead.

    Args:
        a (str or float): Predicted answer.
        b (str or float): Ground truth answer.

    Returns:
        bool: Whether the answers are considered equal.
    """
    def _as_number(value):
        if isinstance(value, str):
            value = value.strip().replace(",", "")
        return float(value)

    try:
        return _as_number(a) == _as_number(b)
    except (TypeError, ValueError):
        # str() keeps the fallback safe when one side is a non-string
        # (e.g. a float prediction against a non-numeric ground truth).
        return str(a).strip() == str(b).strip()

# Few-shot Prompt for GSM8K
# Two worked exemplars in "Q: ... / A: ..." form; each answer ends with
# "#### <number>", the marker that extract_answer_from_text searches for.
# The "..." line between the exemplars is part of the string and is sent to
# the model verbatim.
# NOTE(review): that line looks like a placeholder for elided exemplars —
# confirm whether more worked examples were meant to be included.
few_shot_prompt = """Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are initially 3 cars. 2 more arrive, so 3 + 2 = 5. #### 5
...
Q: A tree has 8 apples, 4 fall off. How many apples remain?
A: 8 - 4 = 4 apples remaining. #### 4"""

# SLICER-Aware Text Generation
def _sample_next_token(logits, temperature, top_p):
    """Pick the next token id from last-position logits (greedy or nucleus sampling)."""
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Compute probabilities in float32: the model may emit half-precision
    # logits, which is too coarse for softmax + multinomial.
    probs = torch.nn.functional.softmax(logits.float() / temperature, dim=-1)

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # Drop every token whose higher-ranked predecessors already cover
        # top_p of the mass. The top-1 token has mass_before == 0 and is
        # therefore always kept.
        sorted_probs = sorted_probs.masked_fill(mass_before >= top_p, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return sorted_idx.gather(-1, choice)

    return torch.multinomial(probs, num_samples=1)


def custom_generate_text(
    model,
    tokenizer,
    sc_config,
    input_ids: torch.Tensor,
    w_bar: int = 200,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.95,
    device: torch.device = None,
):
    """
    Generate output from the model with a single SLICER forward pass in the middle.

    Generation runs in three phases:

    1. Ordinary cached decoding while the *total* sequence length
       (prompt + generated tokens) is below `w_bar`.
    2. One uncached forward pass over the whole sequence with
       `sc_config=sc_config` and `output_hidden_states=True`; its logits
       produce one token.
    3. Ordinary cached decoding (cache rebuilt from scratch) until EOS or
       `max_new_tokens`.

    Generation stops as soon as `tokenizer.eos_token_id` is sampled. The
    function reads the sampled id with `.item()`, so it handles a single
    sequence (batch size 1) only.

    NOTE(review): phase 2 passes `sc_config` as a keyword to `model(...)`;
    this presumably requires a model whose forward has been extended by the
    SLICER package — confirm.

    Args:
        model: The loaded LLM model.
        tokenizer: The tokenizer for the model.
        sc_config: SLICER configuration object, forwarded to the model in phase 2.
        input_ids (torch.Tensor): Input prompt tokens, shape (1, prompt_len).
        w_bar (int): Total sequence length (prompt included) at which the
            SLICER pass is applied. If the prompt is already this long,
            phase 1 is skipped.
        max_new_tokens (int): Maximum total new tokens across all phases.
        temperature (float): Sampling temperature; 0.0 selects greedy decoding.
        top_p (float): Nucleus sampling threshold. Values >= 1.0 (or None)
            disable nucleus filtering. Ignored when `temperature` is 0.0.
        device (torch.device): Device to move `input_ids` to; when None the
            tensor is left where it is.

    Returns:
        (generated_text, slicer_output, slicer_applied): the decoded new
        tokens, the model output of the SLICER pass (None if it never ran),
        and whether the SLICER pass ran.
    """
    if device is not None:
        input_ids = input_ids.to(device)
    generated = input_ids.clone()
    input_len = input_ids.shape[1]
    eos_id = tokenizer.eos_token_id
    applied = False
    out_slicer = None
    new_tokens = 0

    def decode_slice():
        # Only the newly generated part, not the prompt.
        gen_ids = generated[0, input_len:]
        return tokenizer.decode(gen_ids, skip_special_tokens=True)

    def cached_step(cache):
        # First call (no cache) feeds the whole sequence; afterwards only
        # the most recent token is needed.
        cur = generated if cache is None else generated[:, -1:]
        with torch.no_grad():
            out = model(
                input_ids=cur,
                past_key_values=cache,
                use_cache=True,
                return_dict=True,
                output_hidden_states=False,
            )
        nxt = _sample_next_token(out.logits[:, -1, :], temperature, top_p)
        return out.past_key_values, nxt

    # Phase 1: Standard autoregressive generation until the sequence reaches w_bar
    past_key_values = None
    while generated.size(1) < w_bar and new_tokens < max_new_tokens:
        past_key_values, nxt = cached_step(past_key_values)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        if nxt.item() == eos_id:
            return decode_slice(), None, False

    # Phase 2: Apply SLICER once at the boundary (full sequence, no KV cache)
    if new_tokens < max_new_tokens:
        with torch.no_grad():
            out = model(
                input_ids=generated,
                past_key_values=None,
                use_cache=False,
                return_dict=True,
                output_hidden_states=True,
                sc_config=sc_config
            )
        applied = True
        out_slicer = out
        nxt = _sample_next_token(out.logits[:, -1, :], temperature, top_p)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        if nxt.item() == eos_id:
            return decode_slice(), out_slicer, applied

    # Phase 3: Continue generation as usual after SLICER. The cache is
    # rebuilt from the full sequence because phase 2 ran without one.
    past_key_values = None
    while new_tokens < max_new_tokens:
        past_key_values, nxt = cached_step(past_key_values)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        if nxt.item() == eos_id:
            break

    return decode_slice(), out_slicer, applied

# Model Evaluation on GSM8K Data
def evaluate(
    jsonl_path,
    tokenizer,
    model,
    few_shot_prompt,
    results_path="llama3_gsm8k_results.jsonl",
    summary_path="llama3_gsm8k_summary.json",
):
    """
    Evaluate the model on a GSM8K-style QA dataset.

    For every question the few-shot prompt is prepended, a completion is
    generated with `custom_generate_text`, the numeric answer is parsed with
    `extract_answer_from_text` and compared to the ground truth with
    `math_equal`. Per-example results and an accuracy summary are written to
    disk.

    Args:
        jsonl_path (str): Path to input JSONL file with question/answer pairs.
            Blank lines are ignored.
        tokenizer: The model tokenizer.
        model: The model for generation.
        few_shot_prompt (str): Few-shot prompt string.
        results_path (str): Where to write one JSON line per evaluated example.
        summary_path (str): Where to write the total/correct/accuracy summary.

    Returns:
        (results, accuracy): List of result dicts and total accuracy. The
        accuracy is 0.0 when the input file contains no examples.
    """
    # SLICER configuration for this evaluation (fixed values; see SCConfig
    # for the meaning of each field).
    sc_config = SCConfig(
        split_layer=20,
        s=0.9,
        Q=[8, 8, 8],
        Q_n=[8, 8, 8],
        lambd=0.0,
        delta=0.0,
        use_ABSQ=True
    )

    with open(jsonl_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        data = [json.loads(line) for line in f if line.strip()]

    correct = 0
    results = []
    for item in tqdm(data, desc="Evaluating"):
        q = item["question"]
        gt = item["answer"]
        prompt = few_shot_prompt.strip() + "\n\nQ: " + q.strip() + "\nA:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        pred_text, _, _ = custom_generate_text(
            model, tokenizer, sc_config, inputs["input_ids"],
            w_bar=200, max_new_tokens=256, temperature=0.7, top_p=0.95,
            device=model.device
        )
        pred = extract_answer_from_text(pred_text)
        is_correct = bool(math_equal(pred, gt))
        results.append({
            "question": q,
            "ground_truth": gt,
            "prediction": pred,
            "correct": is_correct,
            "response": pred_text
        })
        correct += int(is_correct)

    total = len(results)
    # Guard against an empty dataset, which would otherwise divide by zero.
    acc = correct / total if total else 0.0
    print(f"\nAccuracy on {total} examples: {acc*100:.2f}%")

    with open(results_path, "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump({"total": total, "correct": correct, "accuracy": acc}, f, indent=2, ensure_ascii=False)

    return results, acc

# Script Entry Point
if __name__ == "__main__":
    # Step 1: materialise the GSM8K test split as a local JSONL file.
    save_gsm8k_jsonl("llama3_gsm8k.jsonl")

    # Step 2: bring up the tokenizer and model.
    tokenizer, model = load_llama3()

    # Step 3: run the benchmark over the file written in step 1.
    results, acc = evaluate("llama3_gsm8k.jsonl", tokenizer, model, few_shot_prompt)

    # Step 4: show at most three failures for a quick sanity check.
    print("\nSample incorrect predictions:")
    misses = [entry for entry in results if not entry["correct"]]
    for entry in misses[:3]:
        print(f"\nQ: {entry['question']}\nPred: {entry['prediction']} | GT: {entry['ground_truth']}")
        response = entry["response"]
        if response:
            final_line = response.splitlines()[-1]
            print(f"Last line of model response: {final_line}")
