import os
import json
import torch
import argparse
import sys
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

# Extend the import search path before the project-local import below
# (presumably this directory is the ShadowKV checkout that provides
# `models.llama` — TODO confirm).
# NOTE(review): the absolute path is machine-specific; the script will fail at
# import time on hosts where it does not exist.
sys.path.append("/fs/nexus-scratch/hjae/ShadowKV")
from models.llama import LlamaForCausalLM

def parse_args():
    """Parse the command-line options for the evaluation run.

    Returns:
        argparse.Namespace with model_path, token_budget, compression_enabled,
        compression_threshold, compression_ratio, window_size, max_samples,
        output_dir, resume and save_interval.
    """
    # One entry per option, in the order they should appear in --help.
    option_specs = [
        ("--model_path", {"type": str, "default": "meta-llama/Meta-Llama-3.1-8B-Instruct"}),
        ("--token_budget", {"type": int, "default": 1024}),
        ("--compression_enabled", {"action": "store_true"}),
        ("--compression_threshold", {"type": int, "default": 128}),
        ("--compression_ratio", {"type": float, "default": 0.5}),
        ("--window_size", {"type": int, "default": 512}),
        ("--max_samples", {"type": int, "default": 100}),
        ("--output_dir", {"type": str, "default": "results"}),
        ("--resume", {"action": "store_true", "help": "Resume from existing results file"}),
        ("--save_interval", {"type": int, "default": 20, "help": "Save results every N samples"}),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_model_and_tokenizer(args):
    """Load the tokenizer and the project's Llama model, then call shadowkv_init.

    Args:
        args: parsed options; reads model_path, window_size, token_budget,
            compression_enabled, compression_ratio and compression_threshold.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_path = args.model_path
    print(f"Loading model from {model_path}...")

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model_config = AutoConfig.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        config=model_config,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    # When compression is not enabled a ratio of 1.0 is passed instead of the
    # user-supplied value.
    if args.compression_enabled:
        ratio = args.compression_ratio
    else:
        ratio = 1.0

    model.shadowkv_init(
        window_size=args.window_size,
        max_tokens=args.token_budget,
        compress_ratio=ratio,
        compress_threshold=args.compression_threshold,
    )
    return model, tokenizer

def format_prompt(question):
    """Wrap ``question`` in the step-by-step math prompt used for generation."""
    intro = "Below is a math problem. Please solve it step by step."
    cue = "Let's solve this step by step:"
    return f"{intro}\n\nProblem: {question}\n\n{cue}"

def extract_answer(response):
    """Return the last number that appears in ``response`` as a float.

    Numbers are located with a regular expression rather than by splitting on
    whitespace, so values that carry thousands separators ("1,234"), a
    currency sign ("$18"), trailing punctuation ("72.") or a leading minus
    ("-3"), or that are glued to other characters ("<<48/2=24>>24"), are
    recognised instead of being skipped in favour of an earlier number.

    Args:
        response: text to scan (model output or a reference answer).

    Returns:
        The last number found, as a float, or None when ``response`` is not a
        string or contains no number.
    """
    # Local import: the file's top-level imports do not include ``re``.
    import re

    if not isinstance(response, str):
        return None

    # A minus counts as a sign only when it does not directly follow a word
    # character or ")", so "10-3" yields 10 and 3 rather than 10 and -3.
    sign = r"(?:(?<![\w)])-)?"
    number_pattern = (
        sign
        + r"(?:[0-9]{1,3}(?:,[0-9]{3})+(?![0-9])|[0-9]+)"  # 1,234,567 or plain digits
        + r"(?:\.[0-9]+)?"                                  # optional fraction
        + r"|" + sign + r"\.[0-9]+"                         # bare fraction such as .5
    )

    matches = re.findall(number_pattern, response)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))

def load_existing_results(output_file):
    """Read previously saved results from ``output_file``.

    Returns:
        A ``(results, args)`` tuple taken from the "results" and "args" keys
        of the saved JSON. ``([], {})`` is returned when the file does not
        exist or cannot be loaded (the error is printed, not raised).
    """
    if not os.path.exists(output_file):
        return [], {}

    try:
        with open(output_file, "r") as handle:
            saved = json.load(handle)
        print(f"Found existing results file: {output_file}")
        prior_results = saved.get('results', [])
        print(f"Loaded {len(prior_results)} existing results")
        return prior_results, saved.get('args', {})
    except Exception as err:
        # Best effort: an unreadable file is treated like a missing one.
        print(f"Error loading existing results: {err}")
        return [], {}

def save_results(output_file, args, results, accuracy, force_save=False):
    """Write the run arguments, accuracy and per-sample results as JSON.

    The payload is first written to ``<output_file>.tmp`` and then moved over
    ``output_file`` with ``os.replace``. Opening the destination directly in
    "w" mode would truncate the previous checkpoint before serialization has
    succeeded, so an interruption or serialization error would destroy the
    results that resuming relies on.

    Args:
        output_file: destination path of the JSON file.
        args: namespace whose attributes (``vars(args)``) are stored under "args".
        results: list of per-sample result dicts, stored under "results".
        accuracy: accuracy value, stored under "accuracy".
        force_save: when true, print a confirmation line after a successful save.

    Errors are printed rather than raised, so a failed save never aborts the
    caller.
    """
    tmp_file = f"{output_file}.tmp"
    try:
        payload = {
            "args": vars(args),
            "accuracy": accuracy,
            "results": results,
        }
        with open(tmp_file, "w") as f:
            json.dump(payload, f, indent=2)
        # Swap the complete file into place only once it is fully written.
        os.replace(tmp_file, output_file)
        if force_save:
            print(f"Results saved to {output_file} (accuracy: {accuracy:.2%})")
    except Exception as e:
        print(f"Error saving results: {e}")
        # Do not leave a partial temp file behind; the old output is untouched.
        try:
            os.remove(tmp_file)
        except OSError:
            pass

def evaluate_gsm8k(model, tokenizer, dataset, args, existing_results=None, output_file=None):
    """Run the model over GSM8K samples and score the extracted answers.

    Samples whose question text already appears in ``existing_results`` are
    skipped, so a previous run can be continued. Results are checkpointed to
    ``output_file`` every ``args.save_interval`` newly processed samples.

    Args:
        model: object exposing ``device`` and ``generate``.
        tokenizer: callable producing model inputs; also provides ``decode``.
        dataset: iterable of mappings with "question" and "answer" keys.
        args: namespace providing ``output_dir`` and ``save_interval``.
        existing_results: result dicts from an earlier run; when non-empty the
            list is extended in place.
        output_file: checkpoint path. Defaults to the
            "gsm8k_results_shadowkvllammw.json" file in ``args.output_dir``,
            the same file ``main`` loads when resuming, so that periodic
            checkpoints are actually picked up by a later ``--resume`` run.

    Returns:
        ``(results, accuracy)`` where accuracy is correct / total over all
        results, including the pre-existing ones (0.0 when there are none).
    """
    results = existing_results if existing_results else []
    correct = sum(1 for r in results if r.get('is_correct', False))
    total = len(results)

    processed_questions = {r['question'] for r in results}

    remaining_dataset = [sample for sample in dataset if sample["question"] not in processed_questions]

    if not remaining_dataset:
        print("All samples have been processed already!")
        return results, correct/total if total > 0 else 0.0

    print(f"Resuming evaluation: {len(results)} samples already processed, {len(remaining_dataset)} remaining")
    print(f"Saving results every {args.save_interval} samples")

    if output_file is None:
        # Must match the file main() reads and writes; a different name here
        # would make the periodic checkpoints invisible to --resume.
        output_file = os.path.join(args.output_dir, "gsm8k_results_shadowkvllammw.json")

    for i, sample in enumerate(tqdm(remaining_dataset)):
        question = sample["question"]
        answer = sample["answer"]

        prompt = format_prompt(question)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # NOTE(review): temperature/top_p are passed without do_sample —
        # confirm whether this model's generate() actually samples.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=2048,
                temperature=0.7,
                top_p=0.9,
                use_cache=True
            )

        # NOTE(review): outputs[0] presumably still contains the prompt
        # tokens, so the decoded response may include the question — confirm.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        predicted_answer = extract_answer(response)
        correct_answer = extract_answer(answer)

        # A failed extraction (None) must never count as a match, even when
        # the reference answer also failed to parse.
        is_correct = predicted_answer is not None and predicted_answer == correct_answer
        if is_correct:
            correct += 1
        total += 1

        results.append({
            "question": question,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct,
            "response": response
        })

        print(f"\nAccuracy so far: {correct/total:.2%}")

        # A non-positive interval disables checkpointing instead of raising
        # ZeroDivisionError on the modulo.
        if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
            current_accuracy = correct / total
            save_results(output_file, args, results, current_accuracy, force_save=True)

    return results, correct/total

def main():
    """Entry point: parse options, load model and data, evaluate, save results.

    If a results file with saved samples already exists and --resume was not
    given, a hint is printed and the function returns without loading the model.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    output_file = os.path.join(args.output_dir, "gsm8k_results_shadowkvllammw.json")

    # Only the saved results are used; the saved args are ignored.
    existing_results, _ = load_existing_results(output_file)
    if not existing_results:
        print("No existing results found, starting fresh")
    else:
        print(f"Found existing results with {len(existing_results)} samples")
        if not args.resume:
            print("Use --resume flag to continue from existing results, or delete the file to start fresh")
            return

    model, tokenizer = load_model_and_tokenizer(args)

    print("Loading GSM8K dataset...")
    test_set = load_dataset("gsm8k", "main")["test"]

    if args.max_samples is not None:
        # Fixed seed so that repeated (and resumed) runs pick the same subset.
        sample_count = min(args.max_samples, len(test_set))
        test_set = test_set.shuffle(seed=42).select(range(sample_count))
        print(f"Using {len(test_set)} samples for evaluation (randomly selected with seed=42)")

    print("Starting evaluation...")
    results, accuracy = evaluate_gsm8k(model, tokenizer, test_set, args, existing_results)

    save_results(output_file, args, results, accuracy, force_save=True)

    print(f"\nFinal accuracy: {accuracy:.2%}")
    print(f"Results saved to {output_file}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()