import filelock
from filelock import SoftFileLock
from stuned.utility.helpers_for_main import prepare_wrapper_for_experiment
from stuned.utility.logger import try_to_log_in_csv, try_to_log_in_wandb
from stuned.utility.utils import AttrDict

filelock.FileLock = SoftFileLock

import functools
import gc
import logging
import os

import lm_eval
import pandas as pd
import torch
from lm_eval.models.huggingface import HFLM
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.harmbench_utils import (
    calculate_asr,
    compute_results_judge,
    load_harmbench_judge,
)

# Module-level logger. This module attaches no handlers itself; main() calls
# logging.basicConfig when the script runs as an experiment.
logger = logging.getLogger(__name__)

# Create the ".cache/lm_eval" directory at import time, before any evaluation.
# NOTE(review): nothing in this file passes this path to lm_eval explicitly —
# confirm which component relies on it.
os.makedirs(".cache/lm_eval", exist_ok=True)


def clear_gpu_memory(func):
    """
    Decorator that releases GPU memory after the wrapped function finishes.

    Cleanup runs whether ``func`` returns normally or raises (including
    ``BaseException`` subclasses such as ``KeyboardInterrupt``). Any exception
    propagates unchanged, with its original traceback.

    Garbage collection runs *before* ``torch.cuda.empty_cache()``. Objects
    created inside ``func`` may hold tensors through reference cycles, and
    ``empty_cache`` can only hand cached blocks back to the driver once
    nothing references them anymore.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return wrapper


@clear_gpu_memory
def evaluate_tinyMMLU(model, tokenizer, batch_size=1):
    """
    Run the zero-shot ``tinyMMLU`` task through lm-eval-harness.

    Args:
        model: Causal LM to wrap in lm_eval's ``HFLM`` adapter.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Batch size forwarded to ``HFLM``.

    Returns:
        The dict produced by ``lm_eval.simple_evaluate``.
    """
    wrapped_lm = HFLM(model, tokenizer=tokenizer, batch_size=batch_size)
    return lm_eval.simple_evaluate(
        model=wrapped_lm,
        tasks=["tinyMMLU"],
        num_fewshot=0,
    )


@clear_gpu_memory
def evaluate_tinyGSM8k(model, tokenizer, batch_size=1):
    """
    Run the zero-shot ``tinyGSM8k`` task through lm-eval-harness.

    Args:
        model: Causal LM to wrap in lm_eval's ``HFLM`` adapter.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Batch size forwarded to ``HFLM``.

    Returns:
        The dict produced by ``lm_eval.simple_evaluate``.
    """
    task_names = ["tinyGSM8k"]
    wrapped_lm = HFLM(model, tokenizer=tokenizer, batch_size=batch_size)
    return lm_eval.simple_evaluate(model=wrapped_lm, tasks=task_names, num_fewshot=0)


@clear_gpu_memory
def evaluate_IFEval(model, tokenizer, batch_size=1):
    """
    Run the zero-shot ``ifeval`` task through lm-eval-harness.

    Args:
        model: Causal LM to wrap in lm_eval's ``HFLM`` adapter.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Batch size forwarded to ``HFLM``.

    Returns:
        The dict produced by ``lm_eval.simple_evaluate``.
    """
    wrapped_lm = HFLM(model, tokenizer=tokenizer, batch_size=batch_size)
    evaluation = lm_eval.simple_evaluate(
        model=wrapped_lm,
        tasks=["ifeval"],
        num_fewshot=0,
    )
    return evaluation


@clear_gpu_memory
def evaluate_tinyArc(model, tokenizer, batch_size=1):
    """
    Run the zero-shot ``tinyArc`` task through lm-eval-harness.

    Args:
        model: Causal LM to wrap in lm_eval's ``HFLM`` adapter.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Batch size forwarded to ``HFLM``.

    Returns:
        The dict produced by ``lm_eval.simple_evaluate``.
    """
    task_names = ["tinyArc"]
    return lm_eval.simple_evaluate(
        model=HFLM(model, tokenizer=tokenizer, batch_size=batch_size),
        tasks=task_names,
        num_fewshot=0,
    )


def _csv_text(value):
    """Return ``value``, or ``""`` when pandas parsed the CSV cell as missing."""
    # pandas turns empty cells into float NaN. NaN is truthy, serialises to a
    # non-standard ``NaN`` JSON token, and has no string methods.
    # NOTE(review): presumably compute_results_judge treats Tags/ContextString
    # as strings — confirm in utils.harmbench_utils.
    return "" if pd.isna(value) else value


def _judge_decision(label):
    """Map a judge label to "harmful" (1), "safe" (0) or "invalid" (anything else)."""
    if label == 1:
        return "harmful"
    if label == 0:
        return "safe"
    return "invalid"


def _generate_harmbench_completions(
    model, tokenizer, evaluation_data, batch_size, max_new_tokens
):
    """Greedily generate one response per HarmBench behavior, batch by batch."""
    total = len(evaluation_data)
    # generate() needs a concrete pad id; fall back to EOS when none is defined.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    completions = []

    # Batched generation with a decoder-only model requires left padding:
    # with right padding, shorter prompts would be continued after their pad
    # tokens. Force it for generation and restore the caller's setting after.
    saved_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for start in range(0, total, batch_size):
            batch_data = evaluation_data.iloc[start : start + batch_size]
            conversations = [
                [{"role": "user", "content": behavior}]
                for behavior in batch_data["Behavior"].tolist()
            ]
            # return_dict=True also yields the attention mask. Without it,
            # generate() has to guess which positions are padding, which it
            # cannot do when pad == eos.
            encoded = tokenizer.apply_chat_template(
                conversations,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                padding=True,
                return_dict=True,
            ).to(model.device)
            # All rows share the padded prompt length, and generate() returns
            # prompt + continuation, so the new tokens start at this column.
            prompt_length = encoded["input_ids"].shape[1]

            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            for row_offset, (prompt_id, row) in enumerate(batch_data.iterrows()):
                response = tokenizer.decode(
                    generated_ids[row_offset][prompt_length:],
                    skip_special_tokens=True,
                )
                logger.info(f"Response: {response} | Behavior: {row['Behavior']}")
                completions.append(
                    {
                        "prompt_id": prompt_id,
                        "behavior_id": row["BehaviorID"],
                        "prompt": row["Behavior"],
                        "response": response,
                        "tags": _csv_text(row.get("Tags", "")),
                        "context_string": _csv_text(row.get("ContextString", "")),
                    }
                )

            done = min(start + batch_size, total)
            logger.info(f"Generated completion {done}/{total}")
    finally:
        tokenizer.padding_side = saved_padding_side

    return completions


def _judge_harmbench_completions(completions, judge, judge_tokenizer):
    """
    Score every completion with the HarmBench judge.

    Returns:
        (evaluated_results, all_queries): the raw judge result per completion,
        and one per-query record (prompt, response, judge label/decision).
    """
    evaluated_results = []
    all_queries = []
    for index, completion in enumerate(completions):
        behavior_dict = {
            "BehaviorID": completion["behavior_id"],
            "Behavior": completion["prompt"],
            "Tags": completion["tags"],
            "ContextString": completion["context_string"],
        }
        evaluated_result = compute_results_judge(
            behavior_dict, completion, judge, judge_tokenizer
        )
        evaluated_results.append(evaluated_result)

        label = evaluated_result["label"]
        all_queries.append(
            {
                "behavior_id": completion["behavior_id"],
                "prompt": completion["prompt"],
                "response": completion["response"],
                "context_string": completion["context_string"],
                "tags": completion["tags"],
                "judge_label": label,
                "judge_decision": _judge_decision(label),
            }
        )

        if index % 5 == 0:
            logger.info(f"Evaluated {index + 1}/{len(completions)}")

    return evaluated_results, all_queries


@clear_gpu_memory
def evaluate_harmbench(
    model,
    tokenizer,
    harmbench_path,
    num_behaviors=100,
    batch_size=1,
    max_new_tokens=512,
):
    """
    Measure the attack success rate (ASR) of ``model`` on HarmBench behaviors.

    The first ``num_behaviors`` rows of the CSV are sent as single-turn chat
    prompts. Responses are generated greedily and scored by the HarmBench
    judge. A judge label of 1 counts as a successful attack (harmful
    response), 0 as unsuccessful (safe), and anything else as invalid.

    Args:
        model: Causal LM under evaluation.
        tokenizer: Its tokenizer. It must have a chat template and a pad
            token. Left padding is applied temporarily during generation.
        harmbench_path: CSV with at least ``Behavior`` and ``BehaviorID``
            columns. ``Tags``/``ContextString`` are optional.
        num_behaviors: Number of leading rows to evaluate.
        batch_size: Prompts per ``generate`` call.
        max_new_tokens: Generation budget per response.

    Returns:
        dict with ``overall_asr``, ``detailed_results`` (raw judge outputs),
        ``successful_queries`` / ``unsuccessful_queries`` / ``all_queries``
        (per-query records) and ``evaluation_summary`` (counts).
    """
    harmbench = pd.read_csv(harmbench_path)
    judge, judge_tokenizer = load_harmbench_judge()

    evaluation_data = harmbench.head(num_behaviors)
    logger.info(f"\nHarmBench Evaluation: Processing {len(evaluation_data)} samples")
    completions = _generate_harmbench_completions(
        model, tokenizer, evaluation_data, batch_size, max_new_tokens
    )

    logger.info("\nEvaluating completions with judge...")
    evaluated_results, all_queries = _judge_harmbench_completions(
        completions, judge, judge_tokenizer
    )
    successful_queries = [q for q in all_queries if q["judge_label"] == 1]
    unsuccessful_queries = [q for q in all_queries if q["judge_label"] == 0]

    overall_asr = calculate_asr(evaluated_results)

    return {
        "overall_asr": overall_asr,
        "detailed_results": evaluated_results,
        "successful_queries": successful_queries,
        "unsuccessful_queries": unsuccessful_queries,
        "all_queries": all_queries,
        "evaluation_summary": {
            "overall_asr": overall_asr,
            "total_queries": len(evaluated_results),
            "successful_queries": len(successful_queries),
            "unsuccessful_queries": len(unsuccessful_queries),
            "invalid_queries": len(evaluated_results)
            - len(successful_queries)
            - len(unsuccessful_queries),
        },
    }


def evaluate_model(
    model,
    tokenizer,
    harmbench_path,
    eval_type=None,
    num_behaviors=20,
    batch_size=1,
):
    """
    Run the requested benchmarks on ``model`` and gather their headline metrics.

    Args:
        model: Model under evaluation (base or adapter-wrapped).
        tokenizer: Tokenizer matching ``model``.
        harmbench_path: CSV of HarmBench behaviors (only used for "harmbench").
        eval_type: Collection of benchmark names to run. Recognised names are
            "tinyMMLU", "tinyGSM8k", "tinyArc", "harmbench" and "IFEval";
            other names are logged and ignored. ``None`` (the default) means
            ("harmbench", "tinyMMLU", "IFEval").
        num_behaviors: Number of HarmBench behaviors to evaluate.
        batch_size: Batch size used by every benchmark.

    Returns:
        dict: flat metrics (e.g. ``tinyMMLU_accuracy``, ``harmbench_asr``),
        per-benchmark ``*_details`` dicts, the HarmBench query lists and
        summary, and an ``IFEval`` dict of four accuracies. Only requested
        benchmarks contribute keys.
    """
    # None rather than a list literal: a mutable default is shared across calls.
    if eval_type is None:
        eval_type = ("harmbench", "tinyMMLU", "IFEval")

    known_types = {"tinyMMLU", "tinyGSM8k", "tinyArc", "harmbench", "IFEval"}
    # A plain string keeps its original substring-membership semantics; any
    # other iterable is materialised so it can be scanned more than once.
    if not isinstance(eval_type, str):
        eval_type = list(eval_type)
        unknown = [name for name in eval_type if name not in known_types]
        if unknown:
            logger.warning(f"Ignoring unrecognised eval_type entries: {unknown}")

    results = {}

    if "tinyMMLU" in eval_type:
        mmlu = evaluate_tinyMMLU(model, tokenizer, batch_size)["results"]["tinyMMLU"]
        results["tinyMMLU_accuracy"] = float(mmlu["acc_norm,none"])
        results["tinyMMLU_details"] = mmlu
        logger.info(f"TinyMmlu Accuracy: {results['tinyMMLU_accuracy']}")

    if "tinyGSM8k" in eval_type:
        gsm8k = evaluate_tinyGSM8k(model, tokenizer, batch_size)["results"]["tinyGSM8k"]
        results["tinyGSM8k_accuracy_strict"] = float(gsm8k["exact_match,strict-match"])
        results["tinyGSM8k_accuracy_flexible"] = float(
            gsm8k["exact_match,flexible-extract"]
        )
        results["tinyGSM8k_details"] = gsm8k
        logger.info(
            f"TinyGSM8k Strict Accuracy: {results['tinyGSM8k_accuracy_strict']}, Flexible Accuracy: {results['tinyGSM8k_accuracy_flexible']}"
        )

    if "tinyArc" in eval_type:
        arc = evaluate_tinyArc(model, tokenizer, batch_size)["results"]["tinyArc"]
        results["tinyArc_accuracy"] = float(arc["acc_norm,none"])
        # Kept for parity with the other tiny* benchmarks.
        results["tinyArc_details"] = arc
        logger.info(f"TinyArc Accuracy: {results['tinyArc_accuracy']}")

    if "harmbench" in eval_type:
        harmbench_results = evaluate_harmbench(
            model,
            tokenizer,
            harmbench_path,
            num_behaviors,
            batch_size,
        )
        results["harmbench_asr"] = harmbench_results["overall_asr"]
        results["harmbench_unsuccessful"] = harmbench_results["unsuccessful_queries"]
        results["harmbench_successful"] = harmbench_results["successful_queries"]
        results["harmbench_all_queries"] = harmbench_results["all_queries"]
        results["harmbench_evaluation_summary"] = harmbench_results[
            "evaluation_summary"
        ]
        logger.info(f"HarmBench ASR: {results['harmbench_asr']}")
        logger.info(f"Successful queries: {len(results['harmbench_successful'])}")
        logger.info(f"Unsuccessful queries: {len(results['harmbench_unsuccessful'])}")

    if "IFEval" in eval_type:
        ifeval = evaluate_IFEval(model, tokenizer, batch_size)["results"]["ifeval"]
        results["IFEval"] = {
            metric: ifeval[f"{metric},none"]
            for metric in (
                "prompt_level_strict_acc",
                "inst_level_strict_acc",
                "prompt_level_loose_acc",
                "inst_level_loose_acc",
            )
        }
        results["IFEval_details"] = ifeval
        logger.info(f"IFEval: {results['IFEval']}")

    return results


def load_trained_model(model_name, device_map, adapter_path):
    """
    Build the base causal LM in bfloat16 and attach a trained PEFT adapter.

    Args:
        model_name (str): Hugging Face id or local path of the base model.
        device_map (str): ``device_map`` forwarded to ``from_pretrained``.
        adapter_path (str): Location of the trained adapter weights.

    Returns:
        PeftModel: the base model wrapped with the loaded adapter.
    """
    dtype = torch.bfloat16
    backbone = AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device_map, torch_dtype=dtype
    )
    logger.info(f"Loading adapter from {adapter_path}")
    # The returned PeftModel keeps its own reference to ``backbone``; the
    # local name simply goes out of scope on return.
    return PeftModel.from_pretrained(backbone, adapter_path, torch_dtype=dtype)


def check_config_for_demo_experiment(config, config_path, logger):
    """
    Config-check hook handed to ``prepare_wrapper_for_experiment``.

    Currently accepts every config without inspecting it. Earlier drafts
    asserted on ``"initialization_type"`` and ``"image"`` keys; add real checks
    here if this experiment needs them.

    Args:
        config: Experiment config (unused).
        config_path: Path the config came from (unused).
        logger: Logger supplied by the wrapper (unused).
    """
    return None


def run_experiment():
    """Wrap ``main`` with stuned's experiment wrapper (using
    ``check_config_for_demo_experiment`` as its config check) and invoke it."""
    wrap = prepare_wrapper_for_experiment(check_config_for_demo_experiment)
    wrapped_main = wrap(main)
    wrapped_main()


def _load_eval_tokenizer(adapter_path, base_model_name, logger):
    """Load the tokenizer saved with the adapter, falling back to the base model's."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        logger.info(f"Loaded tokenizer from adapter path: {adapter_path}")
    except Exception as err:  # best-effort: any load failure uses the base tokenizer
        logger.info(
            f"Could not load tokenizer from {adapter_path} ({err!r}); "
            "falling back to the base model tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        logger.info(f"Loaded default tokenizer for model: {base_model_name}")
    # Applied regardless of the source: batched generation needs a pad token
    # and left padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return tokenizer


def _log_comparison_summary(eval_type, base_results, adapter_results, logger):
    """
    Log every requested metric as ``<base> → <adapter>``.

    A skipped run (``None`` or empty results) shows "Not Evaluated" for scores
    and 0 for HarmBench query counts.
    """

    def metric(results, *keys):
        # Follow ``keys`` into the nested results, or return the placeholder.
        if not results:
            return "Not Evaluated"
        value = results
        for key in keys:
            value = value[key]
        return value

    def query_count(results, key):
        return len(results[key]) if results else 0

    def log_pair(label, *keys):
        before = metric(base_results, *keys)
        after = metric(adapter_results, *keys)
        logger.info(f"{label}: {before} → {after}")

    logger.info("\nEvaluation Summary (Base Model → Model with Adapter):")
    if "tinyMMLU" in eval_type:
        log_pair("TinyMMLU Accuracy", "tinyMMLU_accuracy")
    if "harmbench" in eval_type:
        log_pair("HarmBench ASR", "harmbench_asr")
        for label, key in (
            ("Successful queries", "harmbench_successful"),
            ("Unsuccessful queries", "harmbench_unsuccessful"),
        ):
            before = query_count(base_results, key)
            after = query_count(adapter_results, key)
            logger.info(f"{label}: {before} → {after}")
    if "IFEval" in eval_type:
        logger.info("IFEval:")
        log_pair("  - Prompt Level Strict", "IFEval", "prompt_level_strict_acc")
        log_pair("  - Inst Level Strict", "IFEval", "inst_level_strict_acc")
    if "tinyGSM8k" in eval_type:
        log_pair("TinyGSM8k Strict Accuracy", "tinyGSM8k_accuracy_strict")
        log_pair("TinyGSM8k Flexible Accuracy", "tinyGSM8k_accuracy_flexible")
    if "tinyArc" in eval_type:
        log_pair("TinyArc Accuracy", "tinyArc_accuracy")


def _append_results_to_file(
    model_name, adapter_path, base_results, adapter_results, logger
):
    """
    Append this run to ``evaluation_results/<model>_comparison.json``.

    Earlier runs stored in the file are kept under ``"evaluations"``. The new
    content is fully serialised before the file is touched, then swapped in
    via a temporary file, so a serialisation error cannot wipe the history.
    """
    import json
    from datetime import datetime
    from pathlib import Path

    results_dir = Path("evaluation_results")
    results_dir.mkdir(exist_ok=True)
    results_file = results_dir / f"{model_name.split('/')[-1].lower()}_comparison.json"

    existing_results = {}
    if results_file.exists():
        try:
            with open(results_file, "r") as f:
                existing_results = json.load(f)
        except json.JSONDecodeError:
            logger.warning(
                f"Could not read existing results from {results_file}, starting fresh"
            )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    new_results = {
        "model_name": model_name,
        "adapter_path": adapter_path,
        "evaluations": existing_results.get("evaluations", [])
        + [
            {
                "timestamp": timestamp,
                "base_model_results": base_results,
                "adapter_model_results": adapter_results,
            }
        ],
    }

    # Serialise first: if this raises, the existing file is left untouched.
    payload = json.dumps(new_results, indent=2)
    tmp_file = results_file.with_name(results_file.name + ".tmp")
    with open(tmp_file, "w") as f:
        f.write(payload)
    tmp_file.replace(results_file)

    logger.info(f"\nDetailed results saved to: {results_file}")


def main(experiment_config, logger, processes_to_kill_before_exiting):
    """
    Evaluate the base model and/or the model with its trained adapter.

    Reads ``config/training_config.yaml`` and ``config/model_config.yaml``,
    runs ``evaluate_model`` for each enabled variant, logs a side-by-side
    summary, and appends the raw results to
    ``evaluation_results/<model>_comparison.json``.

    Args:
        experiment_config: Mapping wrapped in ``AttrDict``. Reads
            ``model_name``, ``adapter_path`` (None → derived from the training
            config), ``batch_size``, ``eval_type``, ``do_base_eval`` and
            ``do_unlocked_eval``.
        logger: Logger supplied by the experiment wrapper.
        processes_to_kill_before_exiting: Supplied by the wrapper; unused.

    Raises:
        ValueError: if ``model_name`` is not a key of ``model_config.yaml``.
    """
    from pathlib import Path

    from utils.config import load_yaml

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    args = AttrDict(experiment_config)
    # NOTE(review): these override any value in experiment_config; they mirror
    # the defaults of the argparse CLI this function replaced.
    args.num_behaviors = 100
    args.config_dir = "config"

    config_dir = Path(args.config_dir)
    training_config = load_yaml(config_dir / "training_config.yaml")
    model_config = load_yaml(config_dir / "model_config.yaml")

    if args.model_name not in model_config:
        available_models = list(model_config.keys())
        raise ValueError(
            f"Model '{args.model_name}' not found in config. Available models: {available_models}"
        )
    model_settings = model_config[args.model_name]
    base_model_name = model_settings["name"]

    # The adapter path is resolved first because the tokenizer may live there.
    if args.adapter_path is None:
        adapter_path = os.path.join(
            training_config["output"]["base_path"],
            base_model_name.split("/")[-1].lower(),
            "lora_weights",
        )
    else:
        adapter_path = args.adapter_path

    tokenizer = _load_eval_tokenizer(adapter_path, base_model_name, logger)

    base_results = None
    if args.do_base_eval:
        logger.info("\nEvaluating base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        base_results = evaluate_model(
            base_model,
            tokenizer,
            harmbench_path=training_config["data"]["harmbench"],
            num_behaviors=args.num_behaviors,
            batch_size=args.batch_size,
            eval_type=args.eval_type,
        )
        # Free the base model before loading the adapter model. Collect garbage
        # first so the CUDA caching allocator has unreferenced blocks to release.
        del base_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    adapter_results = None
    if args.do_unlocked_eval:
        logger.info("\nEvaluating model with adapter...")
        adapter_model = load_trained_model(
            model_name=base_model_name,
            device_map="auto",
            adapter_path=adapter_path,
        )
        adapter_results = evaluate_model(
            adapter_model,
            tokenizer,
            harmbench_path=training_config["data"]["harmbench"],
            num_behaviors=args.num_behaviors,
            batch_size=args.batch_size,
            eval_type=args.eval_type,
        )

    _log_comparison_summary(args.eval_type, base_results, adapter_results, logger)
    _append_results_to_file(
        base_model_name, adapter_path, base_results, adapter_results, logger
    )


# Script entry point: run the experiment through the stuned wrapper.
if __name__ == "__main__":
    run_experiment()
