import os
import json
import torch
import argparse
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
import re
import sys
import random
import numpy as np


def _seed_everything(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``; pin cuDNN to deterministic kernels."""
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seed_fn(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(42)


sys.path.append("/fs/nexus-scratch/hjae/ShadowKV")
from models.llama import LlamaForCausalLM 

def parse_args(argv=None):
    """Parse command-line options for the StrategyQA / ShadowKV evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a ShadowKV-enabled Llama model on StrategyQA."
    )
    parser.add_argument("--model_path", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
                        help="Hugging Face model id or local checkpoint path.")
    parser.add_argument("--token_budget", type=int, default=1024,
                        help="Forwarded to model.shadowkv_init as max_tokens.")
    parser.add_argument("--compression_enabled", action="store_true",
                        help="Enable KV compression; when off, compress_ratio=1.0 is passed to shadowkv_init.")
    parser.add_argument("--compression_threshold", type=int, default=128,
                        help="Forwarded to model.shadowkv_init as compress_threshold.")
    parser.add_argument("--compression_ratio", type=float, default=0.5,
                        help="Forwarded to model.shadowkv_init as compress_ratio when --compression_enabled is set.")
    parser.add_argument("--window_size", type=int, default=512,
                        help="Forwarded to model.shadowkv_init as window_size.")
    parser.add_argument("--max_samples", type=int, default=100,
                        help="Number of test examples to evaluate (randomly selected with seed 42).")
    parser.add_argument("--output_dir", type=str, default="results",
                        help="Directory that receives the results JSON checkpoint.")
    return parser.parse_args(argv)

def load_model_and_tokenizer(args):
    """Load the tokenizer and the ShadowKV-capable Llama model named by ``args.model_path``.

    The model is loaded in float16 with ``device_map="auto"``. If it exposes
    ``shadowkv_init``, that hook is called with the window size, token budget,
    compression ratio (1.0 when ``args.compression_enabled`` is false) and
    compression threshold taken from ``args``; otherwise a warning is printed.

    Returns:
        ``(model, tokenizer)``.
    """
    model_path = args.model_path
    print(f"Loading model from {model_path}...")
    # NOTE(review): use_fast=False requests the slow tokenizer class; confirm the
    # checkpoint at model_path actually ships the files that class needs.
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    config = AutoConfig.from_pretrained(model_path)

    print("Using LlamaForCausalLM with ShadowKV support")
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        config=config,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    # Guard clause: models without the ShadowKV hook are returned as-is.
    if not hasattr(model, "shadowkv_init"):
        print("Warning: ShadowKV is not available for this model")
        if args.compression_enabled:
            print("Compression will be disabled as ShadowKV is not supported")
        return model, tokenizer

    print("Initializing ShadowKV...")
    # A ratio of 1.0 is passed when compression is off (presumably "no compression";
    # confirm against models.llama.shadowkv_init).
    effective_ratio = args.compression_ratio if args.compression_enabled else 1.0
    model.shadowkv_init(
        window_size=args.window_size,
        max_tokens=args.token_budget,
        compress_ratio=effective_ratio,
        compress_threshold=args.compression_threshold,
    )
    print(f"ShadowKV initialized with compression_enabled={args.compression_enabled}")
    return model, tokenizer

def format_prompt(question, facts):
    """Build the zero-shot true/false prompt for one StrategyQA example.

    ``facts`` and ``question`` are interpolated with ``str()``-style formatting.
    NOTE(review): if ``facts`` is a list, it appears as a Python list literal;
    confirm that is intended.
    """
    prompt_lines = [
        "Given some facts and a related question, answer the question with true or false.",
        "",
        "Facts:",
        f"{facts}",
        "",
        "Question:",
        f"{question}",
        "",
        'Solve the problem step by step. Answer with true or false and wrap your final answer in "\\boxed{}".',
        "",
        "Answer:",
    ]
    return "\n".join(prompt_lines)

def extract_full_boxed_content(text):
    r"""Return the contents of every ``\boxed{...}`` in ``text``, in order.

    Braces are matched by nesting depth, so ``\boxed{\text{True}}`` yields
    ``\text{True}`` instead of being cut at the first ``}``. A ``\boxed{``
    whose braces never close (e.g. generation stopped mid-answer) is skipped,
    but scanning resumes just inside it so a complete ``\boxed{...}`` nested
    within is still found.
    """
    marker = "\\boxed{"
    contents = []
    search_from = 0
    while True:
        start = text.find(marker, search_from)
        if start == -1:
            break
        body_start = start + len(marker)
        depth = 1
        pos = body_start
        while pos < len(text) and depth > 0:
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            # pos is one past the closing brace.
            contents.append(text[body_start:pos - 1])
            search_from = pos
        else:
            search_from = body_start
    return contents


# LaTeX text wrappers that models commonly place inside \boxed{...}.
_LATEX_TEXT_RE = re.compile(r"\\(?:text|textbf|mathrm)\{([^{}]*)\}")

# Cue phrases tried in order; longer phrases come first so they win over "answer is".
_ANSWER_FLAGS = ('final answer is', 'correct answer is', 'answer should be', 'answer is', 'answer:')


def extract_answer(response):
    r"""Parse a model response into a lower-cased answer string.

    Strategies, in order:
    1. The whole stripped response is exactly "true"/"false" (any case).
    2. The last non-empty ``\boxed{...}``. If it contains ``\text{}``,
       ``\textbf{}`` or ``\mathrm{}`` wrappers, the text of the last wrapper is
       used. Empty ``\boxed{}`` groups are ignored, so an echoed prompt
       instruction is never taken as the answer.
    3. The first matching cue phrase (see ``_ANSWER_FLAGS``): the text after its
       last occurrence, cut at the first newline and the first period.

    Returns:
        The stripped, lower-cased answer, or "" when nothing matches.
    """
    response = response.strip()
    lowered = response.lower()

    if lowered in ("true", "false"):
        return lowered

    if "boxed{" in response:
        boxed = [content for content in extract_full_boxed_content(response) if content.strip()]
        if boxed:
            wrapped = _LATEX_TEXT_RE.findall(boxed[-1])
            chosen = wrapped[-1] if wrapped else boxed[-1]
            return chosen.strip().lower()

    for flag in _ANSWER_FLAGS:
        if flag in lowered:
            tail = lowered.split(flag)[-1].strip()
            return tail.split('\n')[0].split('.')[0].strip()

    return ""

def accuracy(predictions, answers):
    """Return the fraction of predictions that match their answers case-insensitively.

    Predictions and answers are paired positionally (``zip`` stops at the
    shorter one), but the denominator is always ``len(predictions)``.
    Returns 0.0 when ``predictions`` is empty.
    """
    n_predictions = len(predictions)
    if not n_predictions:
        return 0.0
    hits = sum(
        1 for pred, gold in zip(predictions, answers) if pred.lower() == gold.lower()
    )
    return hits / n_predictions

def load_partial_results(output_file):
    """Load a previously saved checkpoint so an interrupted evaluation can resume.

    Args:
        output_file: Path of the JSON checkpoint written by the evaluation loop.

    Returns:
        ``(results, correct, total)``: the saved per-sample result list and the
        running counters. Returns ``([], 0, 0)`` when the file does not exist,
        is not valid JSON (e.g. truncated by a crash mid-write), or does not
        hold a JSON object. A warning is printed in the latter two cases.
    """
    if not os.path.exists(output_file):
        return [], 0, 0
    try:
        with open(output_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as err:
        print(f"Warning: could not parse checkpoint {output_file} ({err}); starting from scratch.")
        return [], 0, 0
    if not isinstance(data, dict):
        print(f"Warning: checkpoint {output_file} is not a JSON object; starting from scratch.")
        return [], 0, 0
    return data.get("results", []), data.get("correct", 0), data.get("total", 0)


def _strip_prompt(sequence, prompt_ids):
    """Return only the generated continuation of a 1-D token sequence.

    Hugging Face ``generate`` for decoder-only models returns the prompt tokens
    followed by the new tokens. Decoding the whole sequence would feed the
    prompt, which itself contains "\\boxed{}" and "Answer:", into answer
    extraction. The prompt is removed only when ``sequence`` really starts with
    ``prompt_ids``, so a ``generate`` that already returns just the continuation
    is left untouched.

    Args:
        sequence: 1-D tensor of token ids produced for one example.
        prompt_ids: 1-D tensor of that example's prompt token ids.

    Returns:
        The suffix of ``sequence`` after the prompt, or ``sequence`` unchanged.
    """
    prompt_len = prompt_ids.shape[-1]
    if sequence.shape[-1] >= prompt_len and torch.equal(
        sequence[:prompt_len].to(prompt_ids.device), prompt_ids
    ):
        return sequence[prompt_len:]
    return sequence


def _save_checkpoint(output_file, args, results, correct, total):
    """Atomically write the running results to ``output_file``; return the accuracy.

    The JSON object has keys "args", "accuracy", "results", "correct" and
    "total", the last three being what ``load_partial_results`` reads back.
    """
    predictions = [r["predicted_answer"] for r in results]
    answers = [r["correct_answer"] for r in results]
    current_accuracy = accuracy(predictions, answers)
    payload = {
        "args": vars(args),
        "accuracy": current_accuracy,
        "results": results,
        "correct": correct,
        "total": total,
    }
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "w") as f:
        json.dump(payload, f, indent=2)
    # os.replace is atomic when both paths are on the same filesystem, so an
    # interruption mid-write cannot truncate the previous checkpoint.
    os.replace(tmp_file, output_file)
    return current_accuracy


def evaluate_strategy_qa(model, tokenizer, dataset, args, output_file):
    """Generate an answer for every row of ``dataset`` and score it against the gold label.

    Resumes from ``output_file`` if it exists: the first ``total`` rows of
    ``dataset`` are assumed to be the ones already scored, so ``dataset`` must
    be in the same order as in the run that wrote the checkpoint. Each row must
    provide "question", "facts" and "answer"; the gold label is
    ``str(answer).lower()``. Generation samples with temperature 0.1 and
    top_p 0.9, with ``max_new_tokens`` taken from ``args.max_new_tokens`` when
    present and 256 otherwise. Only the generated continuation (not the prompt)
    is decoded, parsed and stored as "response". A checkpoint is written every
    50 scored samples and after the last one.

    Returns:
        ``(results, final_accuracy)`` covering both resumed and new samples.
    """
    results, correct, total = load_partial_results(output_file)
    start_idx = total
    n_total = len(dataset)
    save_every = 50
    max_new_tokens = getattr(args, "max_new_tokens", 256)

    if start_idx < n_total:
        remaining = dataset.select(range(start_idx, n_total))
        for sample in tqdm(remaining, initial=start_idx, total=n_total):
            question = sample["question"]
            facts = sample["facts"]
            correct_answer = str(sample["answer"]).lower()

            prompt = format_prompt(question, facts)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=0.1,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            generated = _strip_prompt(outputs[0], inputs.input_ids[0])
            response = tokenizer.decode(generated, skip_special_tokens=True)
            predicted_answer = extract_answer(response)

            is_correct = predicted_answer == correct_answer
            if is_correct:
                correct += 1
            total += 1

            results.append({
                "question": question,
                "facts": facts,
                "correct_answer": correct_answer,
                "predicted_answer": predicted_answer,
                "is_correct": is_correct,
                "response": response,
            })

            if total % save_every == 0 or total == n_total:
                current_accuracy = _save_checkpoint(output_file, args, results, correct, total)
                print(f"[Checkpoint] Saved {total} results to {output_file} (accuracy: {current_accuracy:.2%})")

    predictions = [r["predicted_answer"] for r in results]
    answers = [r["correct_answer"] for r in results]
    final_accuracy = accuracy(predictions, answers)

    return results, final_accuracy

def main():
    """Parse arguments, load the model and the StrategyQA test split, then run the evaluation.

    When ``--max_samples`` is positive, that many examples are drawn from the
    test split after a seed-42 shuffle; zero or a negative value evaluates the
    whole split.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    model, tokenizer = load_model_and_tokenizer(args)

    print("Loading StrategyQA dataset...")
    dataset = load_dataset("wics/strategy-qa")
    test_set = dataset["test"]

    if args.max_samples is not None and args.max_samples > 0:
        test_set = test_set.shuffle(seed=42).select(range(min(args.max_samples, len(test_set))))
        print(f"Using {len(test_set)} samples for evaluation (randomly selected with seed=42)")
    else:
        print(f"Using all {len(test_set)} samples for evaluation")

    print("Starting evaluation...")
    # NOTE(review): this file name does not encode the run configuration
    # (compression settings, token budget, ...), so a later run with different
    # flags in the same --output_dir resumes from, and reports, the old
    # results. Use a separate --output_dir per configuration.
    output_file = os.path.join(args.output_dir, "strategyqa_results_shadowkv_deepseek_llama8b.json")
    # Named final_accuracy so the module-level accuracy() helper is not shadowed.
    _, final_accuracy = evaluate_strategy_qa(model, tokenizer, test_set, args, output_file)

    print(f"\nFinal accuracy: {final_accuracy:.2%}")
    print(f"Results saved to {output_file}")


if __name__ == "__main__":
    main()
