import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))

import json
from datetime import datetime

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from eval import evaluate_harmbench, evaluate_tinyMMLU


def load_test_model():
    """Load the Llama-2 chat checkpoint and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in float16 with
        ``device_map="auto"``; the tokenizer always has a pad token set
        (the EOS token is reused when none is defined).
    """
    print("Loading model and tokenizer...")
    checkpoint = "meta-llama/Llama-2-7b-chat-hf"

    tok = AutoTokenizer.from_pretrained(checkpoint)
    # Fall back to EOS for padding when the tokenizer ships without a pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    load_kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
    llm = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)

    return llm, tok


def save_results(results, test_name):
    """Write test results to a timestamped JSON file under ``tests/results``.

    The file is named ``{test_name}_{YYYYmmdd_HHMMSS}.json``; two calls with
    the same ``test_name`` within the same second write to the same file.

    Args:
        results: JSON-serializable object to dump.
        test_name: Prefix used for the output file name.

    Raises:
        TypeError: If ``results`` contains values ``json.dump`` cannot encode.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = Path("tests/results")
    # parents=True so the call also works when "tests/" itself does not exist
    # yet in the current working directory.
    output_path.mkdir(parents=True, exist_ok=True)

    result_file = output_path / f"{test_name}_{timestamp}.json"
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


def test_harmbench_eval():
    """Test HarmBench evaluation with 10 behaviors.

    Samples 10 distinct ``BehaviorID`` values (fixed seed) from the HarmBench
    CSV, writes every row belonging to those behaviors to a temporary CSV,
    runs ``evaluate_harmbench`` on it, saves the results and prints the
    overall attack success rate.

    Returns:
        The object returned by ``evaluate_harmbench``; its ``"overall_asr"``
        entry is read for the printed summary.
    """
    model, tokenizer = load_test_model()

    # Load and sample 10 random behaviors (seeded, so the subset is stable)
    harmbench_path = "data/harmbench_behaviors_all_no_copyright.csv"
    harmbench = pd.read_csv(harmbench_path)
    sampled_behaviors = harmbench.drop_duplicates(subset=["BehaviorID"]).sample(
        n=10, random_state=42
    )

    # Create temporary CSV with only sampled behaviors; all rows sharing a
    # sampled BehaviorID are kept, not just the de-duplicated ones.
    test_data = harmbench[harmbench["BehaviorID"].isin(sampled_behaviors["BehaviorID"])]
    temp_csv_path = "tests/temp_harmbench_test.csv"
    test_data.to_csv(temp_csv_path, index=False)

    try:
        print("Starting HarmBench evaluation (10 behaviors)...")
        results = evaluate_harmbench(
            model=model, tokenizer=tokenizer, harmbench_path=temp_csv_path
        )
    finally:
        # Clean up the temporary file even when the evaluation raises, so a
        # failed run does not leave a stale CSV behind.
        Path(temp_csv_path).unlink()

    save_results(results, "harmbench_eval")
    print(f"\nHarmBench ASR: {results['overall_asr']:.2%}")

    return results


def test_tinymmlu_eval():
    """Run the TinyMMLU evaluation on the test model and print its accuracy.

    Returns:
        The object returned by ``evaluate_tinyMMLU``; the accuracy shown is
        taken from ``results["results"]["tinyMMLU"]["acc"]``.
    """
    llm, tok = load_test_model()

    print("Starting TinyMMLU evaluation...")
    results = evaluate_tinyMMLU(llm, tok)

    accuracy = results["results"]["tinyMMLU"]["acc"]
    print(f"\nTinyMMLU Accuracy: {accuracy:.2%}")

    return results


if __name__ == "__main__":
    # Script entry point: run the two evaluations one after the other.
    # Each test function loads its own model and tokenizer via
    # load_test_model(), so the tests do not share state.
    print("Running HarmBench test...")
    harmbench_results = test_harmbench_eval()

    print("\nRunning TinyMMLU test...")
    tinymmlu_results = test_tinymmlu_eval()
