import json
import os
from collections.abc import Callable, Iterable
from datetime import date, datetime
from typing import Any, Literal

import torch

import jailbreakbench as jbb
from jailbreakbench.artifact import AttackArtifact, AttackParameters, JailbreakInfo
from jailbreakbench.classifier import Llama3JailbreakJudge
from jailbreakbench.config import API_KEYS, LOGS_PATH, MODEL_NAMES, NUM_BEHAVIORS, Model
from jailbreakbench.defenses import DEFENSES
from jailbreakbench.llm.litellm import LLMLiteLLM
from jailbreakbench.llm.llm_output import LLMOutput
from jailbreakbench.llm.llm_wrapper import LLM

# Relative directory used by this module for its output files: evaluation.json,
# submission.json and the per-model artifact files are written here, and
# evaluation.json is read back from here when a submission is created.
SUBMISSIONS_PATH = "submissions"


def check_file_paths(model_name: str) -> str:
    """Return the CSV submission path for a model, refusing to reuse an existing file.

    The model name is lower-cased before it is validated against MODEL_NAMES.

    Args:
        model_name: Name of the model (case-insensitive).

    Returns:
        The path ``<SUBMISSIONS_PATH>/<model_name>.csv``.

    Raises:
        ValueError: If the model name is unknown, or if the CSV file already exists.
    """
    model_name = model_name.lower()
    if model_name not in MODEL_NAMES:
        raise ValueError(f"Invalid model name: {model_name}")

    file_path = os.path.join(SUBMISSIONS_PATH, f"{model_name}.csv")
    if not os.path.exists(file_path):
        return file_path
    raise ValueError(f"{model_name.capitalize()} submission file already exists: {file_path}")


def check_eval_logs_empty(submitted_models: Iterable[str]):
    """Ensure the eval log directory of every submitted model exists and is empty.

    Missing directories are created. Directories that already contain entries
    are collected and reported together in a single error.

    Args:
        submitted_models: Names of the models about to be evaluated.

    Raises:
        ValueError: If at least one eval log directory is not empty.
    """
    nonempty_logs = []
    for model in submitted_models:
        log_evals_path = os.path.join(LOGS_PATH, "eval", model)
        if not os.path.exists(log_evals_path):
            os.makedirs(log_evals_path, exist_ok=True)
            continue
        if os.listdir(log_evals_path):
            nonempty_logs.append(log_evals_path)

    if not nonempty_logs:
        return
    raise ValueError(f"Warning, {nonempty_logs} are not empty. Please delete them before running the evaluation.")


def evaluate_prompts(
    prompts_dict: dict[str, dict[str, str]],
    llm_provider: Literal["vllm"] | Literal["litellm"] = "litellm",
    defense: str | None = None,
) -> dict[str, Any]:
    """Query the final responses for each prompt and classify them as jailbroken or not.

    The result is also written to ``<SUBMISSIONS_PATH>/evaluation.json``.

    Args:
        prompts_dict: A nested dictionary, model_name -> behavior -> prompt.
            Every model must provide exactly NUM_BEHAVIORS entries.
        llm_provider: The LLM provider to use, either "vllm" or "litellm".
        defense: Optional name of a defense to apply when querying; must be one
            of DEFENSES when given.

    Returns:
        A dictionary with one entry per model, model_name -> behavior ->
        {"prompt", "response", "jailbroken"}, plus a "details" entry holding
        the provider, the evaluation date and the defense.

    Raises:
        ValueError: If a model does not have NUM_BEHAVIORS prompts, if eval logs
            are not empty, or if ``llm_provider`` / ``defense`` is invalid.
        ImportError: If ``llm_provider`` is "vllm" and vllm cannot be imported.
    """
    submitted_models = set(prompts_dict.keys())
    if not all(len(v) == NUM_BEHAVIORS for v in prompts_dict.values()):
        raise ValueError("Number of keys in prompts is not equal to NUM_BEHAVIORS.")

    # Check that eval logs are empty to avoid overwriting logs
    check_eval_logs_empty(submitted_models)

    llm_constructor: Callable[[str], LLM]
    if llm_provider == "vllm":
        try:
            from jailbreakbench.llm.vllm import LLMvLLM
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("vllm is not installed. Install it with `pip install vllm`.") from err
        llm_constructor = lambda model_name: LLMvLLM(model_name=model_name)
    elif llm_provider == "litellm":
        llm_constructor = lambda model_name: LLMLiteLLM(
            model_name=model_name, api_key=os.environ[API_KEYS[Model(model_name)]]
        )
    else:
        raise ValueError(f"Invalid llm_provider: {llm_provider}. Should be one of 'vllm' or 'litellm'.")

    if defense is not None and defense not in DEFENSES:
        raise ValueError(f"Invalid defense: {defense}. Should be one of {DEFENSES}")

    evaluation: dict[str, Any] = {}
    for model_name, behaviors_dict in prompts_dict.items():
        # Positional call, matching the declared Callable[[str], LLM] type.
        llm = llm_constructor(model_name)
        evaluation[model_name] = generate_full_model_dict(
            model_name=model_name, llm=llm, behavior_prompt_dict=behaviors_dict, defense=defense
        )

        # Delete previous LLM; otherwise, we risk going OOM
        # when using vLLM since multiple models will be loaded
        del llm
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    evaluation["details"] = {"llm_provider": llm_provider, "evaluation_date": datetime.now().date(), "defense": defense}

    os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
    evaluation_path = os.path.join(SUBMISSIONS_PATH, "evaluation.json")
    with open(evaluation_path, "w") as f:
        json.dump(evaluation, f, indent=4, cls=CustomArtifactEncoder)
    # Report only once the file has been closed.
    print(f"Evaluation results saved at {evaluation_path}.")
    return evaluation


def generate_full_model_dict(
    model_name: str, llm: LLM, behavior_prompt_dict: dict[str, str], defense: str | None
) -> dict:
    """Query one model with all submitted prompts, log and classify the responses.

    Behaviors whose prompt is ``None`` are not queried; they are reported with
    ``None`` prompt/response and ``jailbroken=False``.

    Args:
        model_name: Name of the model, used in progress messages.
        llm: The LLM wrapper used to generate and log responses.
        behavior_prompt_dict: Mapping behavior -> prompt (or ``None`` when no
            prompt was submitted for that behavior).
        defense: Optional defense name forwarded to the LLM.

    Returns:
        A mapping behavior -> {"prompt", "response", "jailbroken"}.

    Raises:
        KeyError: If the ``TOGETHER_API_KEY`` environment variable is not set.
    """
    # Read the judge credentials up front: a missing key should fail before any
    # queries are issued and before eval logs are written, since leftover eval
    # logs would block a re-run.
    judge_api_key = os.environ["TOGETHER_API_KEY"]

    none_behaviors = [behavior for behavior, prompt in behavior_prompt_dict.items() if prompt is None]
    submitted = [(behavior, prompt) for behavior, prompt in behavior_prompt_dict.items() if prompt is not None]
    valid_behaviors = [behavior for behavior, _ in submitted]
    valid_prompts = [prompt for _, prompt in submitted]

    print(f"Querying responses for {model_name}.")
    outputs = llm._get_responses(valid_prompts, max_new_tokens=None, defense=defense)
    responses = outputs.responses

    # Log to evals, one single-response entry per behavior
    for i, (behavior, prompt) in enumerate(submitted):
        cur_output = LLMOutput(
            responses=[responses[i]],
            prompt_tokens=[outputs.prompt_tokens[i]],
            completion_tokens=[outputs.completion_tokens[i]],
        )
        llm.log_responses(
            behavior=behavior,
            prompts=[prompt],
            llm_output=cur_output,
            phase="eval",
        )

    print(f"Classifying responses for {model_name}.")
    classifier = Llama3JailbreakJudge(judge_api_key)

    classifications = classifier(valid_prompts, responses)
    print(
        f"Number of jailbreaks {model_name}: {sum(classifications)}/{NUM_BEHAVIORS}, (Total of {len(valid_behaviors)} submitted prompts.)"
    )

    model_dict = {}
    for i, (behavior, prompt) in enumerate(submitted):
        model_dict[behavior] = {
            "prompt": prompt,
            "response": responses[i],
            "jailbroken": classifications[i],
        }
    for behavior in none_behaviors:
        model_dict[behavior] = {
            "prompt": None,
            "response": None,
            "jailbroken": False,
        }

    return model_dict


def create_submission(
    method_name: str,
    attack_type: str,
    method_params: dict[str, Any] | None = None,
    default_values: dict[str, dict[str, int]] | None = None,
) -> dict:
    """Create the submission file and per-model artifacts from a saved evaluation.

    Reads ``<SUBMISSIONS_PATH>/evaluation.json``, aggregates the "test" and
    "eval" logs under LOGS_PATH for the evaluated models, computes summary
    statistics, writes ``<SUBMISSIONS_PATH>/submission.json`` and finally
    writes one artifact file per model.

    Args:
        method_name: Name of the attack method, stored in the artifacts.
        attack_type: Type of the attack, stored in the artifacts.
        method_params: Optional method parameters, stored in the artifacts.
        default_values: Optional mapping model_name -> per-behavior overrides
            for the log-derived counts. Defaults to no overrides.

    Returns:
        The submission dictionary with the keys "summaries", "test_logs" and
        "eval_logs".
    """
    # Avoid a shared mutable default argument.
    if default_values is None:
        default_values = {}

    os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
    evaluation = _load_json(os.path.join(SUBMISSIONS_PATH, "evaluation.json"))
    eval_details = evaluation["details"]
    # Everything except the "details" entry is a per-model evaluation.
    model_evals = {model: model_eval for model, model_eval in evaluation.items() if model != "details"}
    submitted_models = list(model_evals)

    test_logs = _aggregate_logs(os.path.join(LOGS_PATH, "test"), submitted_models)
    eval_logs = _aggregate_logs(os.path.join(LOGS_PATH, "eval"), submitted_models)

    submission: dict[str, Any] = {
        "summaries": _compute_summary_stats(model_evals, test_logs, default_values=default_values),
        "test_logs": test_logs,
        "eval_logs": eval_logs,
    }

    submission_file = os.path.join(SUBMISSIONS_PATH, "submission.json")
    with open(submission_file, "w") as f:
        json.dump(submission, f, cls=CustomArtifactEncoder, indent=4)
    print(f"Submission file created at {submission_file}.")

    _create_artifacts(submission["summaries"], method_name, attack_type, method_params, eval_details)
    return submission


def _create_artifacts(model_summaries, method_name, attack_type, method_params, eval_details):
    """Write one attack artifact JSON file per model into SUBMISSIONS_PATH.

    Args:
        model_summaries: Mapping model_name -> {"model_summary", "jailbreaks"}.
        method_name: Name of the attack method.
        attack_type: Type of the attack.
        method_params: Method parameters stored alongside the results.
        eval_details: Mapping with the keys "defense", "evaluation_date" and
            "llm_provider".
    """
    for model_name, summary in model_summaries.items():
        stats = summary["model_summary"]
        attack_params = AttackParameters(
            method=method_name,
            model=model_name,
            attack_type=attack_type,
            defense=eval_details["defense"],
            attack_success_rate=stats["attack_success_rate"],
            total_number_of_jailbreaks=stats["total_number_of_jailbreaks"],
            number_of_submitted_prompts=stats["number_of_submitted_prompts"],
            total_number_of_queries=stats["total_number_of_queries"],
            total_prompt_tokens=stats["total_prompt_tokens"],
            total_response_tokens=stats["total_response_tokens"],
            evaluation_date=eval_details["evaluation_date"],
            evaluation_llm_provider=eval_details["llm_provider"],
            method_parameters=method_params,
        )
        artifact = AttackArtifact(parameters=attack_params, jailbreaks=summary["jailbreaks"])

        file_path = os.path.join(SUBMISSIONS_PATH, f"{model_name}.json")
        with open(file_path, "w") as out_file:
            json.dump(artifact, out_file, cls=CustomArtifactEncoder, indent=4)
            print(f"Jailbreak artifact for {model_name} created at {file_path}.")


def _aggregate_logs(path: str, submitted_models: list[str]) -> dict[str, Any]:
    """Collect the per-behavior JSON logs of the submitted models found under ``path``.

    Args:
        path: Directory containing one sub-directory per model.
        submitted_models: Model names whose logs should be collected.

    Returns:
        A mapping model_name -> behavior -> parsed JSON log, where the behavior
        is the log file name without its ".json" extension. Models without a
        log directory are omitted (a warning is printed for each of them).
    """
    aggregate: dict[str, Any] = {}
    # A missing log root is treated like an empty one.
    entries = os.listdir(path) if os.path.isdir(path) else []
    for entry in entries:
        if entry not in submitted_models:
            continue
        model_log_path = os.path.join(path, entry)
        per_behavior = {}
        for file_name in os.listdir(model_log_path):
            stem, suffix = os.path.splitext(file_name)
            if suffix != ".json":
                continue
            per_behavior[stem] = _load_json(os.path.join(model_log_path, file_name))
        aggregate[entry] = per_behavior

    for model_name in submitted_models:
        if model_name in aggregate:
            continue
        print(f"Warning: No {path} logs found for model {model_name}.")
    return aggregate


def _compute_summary_stats(evaluation: dict, test_logs: dict, default_values: dict[str, dict[str, int]]):
    """Compute the summary for every evaluated model.

    Args:
        evaluation: Mapping model_name -> per-behavior evaluation results.
        test_logs: Mapping model_name -> test logs; models without an entry
            are summarized with empty logs.
        default_values: Mapping model_name -> default values; models without
            an entry get an empty mapping.

    Returns:
        A mapping model_name -> summary as produced by _compute_model_summary.
    """
    return {
        model_name: _compute_model_summary(
            model_name,
            model_eval,
            test_logs.get(model_name, {}),
            default_values.get(model_name, {}),
        )
        for model_name, model_eval in evaluation.items()
    }


def _compute_model_summary(
    model_name: str, model_eval: dict, log_dict: dict, model_default_values: dict[str, int] | None = None
) -> dict:
    """Build the per-behavior jailbreak records and aggregate statistics for one model.

    Args:
        model_name: Name of the model, used in error and warning messages.
        model_eval: Mapping behavior -> {"prompt", "response", "jailbroken"}.
        log_dict: Test logs for this model, behavior -> sequence of
            (timestamp, query_logs) pairs where every query log has the keys
            "prompt", "prompt_tokens" and "response_tokens". May be empty, in
            which case no log-derived counts are computed.
        model_default_values: Optional per-behavior overrides for the keys
            "prompt_tokens", "response_tokens" and "number_of_queries"; when
            present they replace the counts derived from the logs.

    Returns:
        A dictionary with "model_summary" (aggregate statistics) and
        "jailbreaks" (one JailbreakInfo per benchmark behavior).

    Raises:
        ValueError: If a benchmark behavior is missing from ``model_eval``.
    """
    behaviors_df = jbb.read_dataset().as_dataframe()
    overrides = model_default_values if model_default_values is not None else {}

    total_num_jailbreaks = 0
    num_submitted_prompts = 0
    total_prompt_tokens = 0
    total_response_tokens = 0
    total_queries_used = 0

    output: dict[str, Any] = {}
    jailbreaks_list = []
    # Iterate through each behavior in the benchmark
    for index, (_, row_data) in enumerate(behaviors_df.iterrows()):
        goal: str = row_data["Goal"]  # type: ignore
        behavior: str = row_data["Behavior"]  # type: ignore
        category: str = row_data["Category"]  # type: ignore

        if behavior not in model_eval:
            raise ValueError(f"Behavior {behavior} not found in evaluation for model {model_name}.")

        eval_behavior_dict = model_eval[behavior]
        jailbreak_prompt = eval_behavior_dict["prompt"]
        response = eval_behavior_dict["response"]
        classification = eval_behavior_dict["jailbroken"]

        num_submitted_prompts += jailbreak_prompt is not None
        total_num_jailbreaks += classification

        prompt_tokens = 0
        response_tokens = 0
        queries = 0
        first_jailbreak_query = None
        # Parse logs for prompt tokens, response tokens, and queries to jailbreak
        if log_dict:
            # A behavior without any test logs is treated as having no queries
            # instead of failing with a bare KeyError.
            for _timestamp, query_logs in log_dict.get(behavior, []):
                for query in query_logs:
                    prompt_tokens += query["prompt_tokens"]
                    response_tokens += query["response_tokens"]
                    queries += 1
                    # Record only the FIRST query that used the jailbreak prompt;
                    # later repeats of the same prompt must not overwrite it.
                    if first_jailbreak_query is None and classification and jailbreak_prompt == query["prompt"]:
                        first_jailbreak_query = queries
            if first_jailbreak_query is None and classification:
                print(
                    f"Warning: Jailbreak prompt for model {model_name} for behavior {behavior} was not found in the test logs. Prompt: {str(jailbreak_prompt)[:50]}..."
                )

        # Explicit default values take precedence over log-derived counts.
        prompt_tokens = overrides.get("prompt_tokens", prompt_tokens)
        response_tokens = overrides.get("response_tokens", response_tokens)
        queries = overrides.get("number_of_queries", queries)

        total_prompt_tokens += prompt_tokens
        total_queries_used += queries
        total_response_tokens += response_tokens

        jailbreaks_list.append(
            JailbreakInfo(
                index=index,
                goal=goal,
                category=category,
                behavior=behavior,
                prompt=jailbreak_prompt,
                response=response,
                queries_to_jailbreak=map_zero(first_jailbreak_query),
                number_of_queries=map_zero(queries),
                prompt_tokens=map_zero(prompt_tokens),
                response_tokens=map_zero(response_tokens),
                jailbroken=classification,
            )
        )

    output["model_summary"] = {
        "total_number_of_jailbreaks": total_num_jailbreaks,
        "attack_success_rate": total_num_jailbreaks / NUM_BEHAVIORS,
        "number_of_submitted_prompts": num_submitted_prompts,
        "total_prompt_tokens": map_zero(total_prompt_tokens),
        "total_response_tokens": map_zero(total_response_tokens),
        "total_number_of_queries": map_zero(total_queries_used),
    }
    output["jailbreaks"] = jailbreaks_list
    return output


def map_zero(x):
    """Return None for the plain integer 0 and every other value unchanged.

    The exact-type check is deliberate: ``False`` and ``0.0`` compare equal to
    0 but are not of type ``int``, so they are passed through as they are.
    """
    return None if (x == 0 and type(x) is int) else x


def map_zero_to_none(d: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``d`` in which plain integer zeros are replaced by None.

    The value stored under the key "Index" is never replaced, and only values
    whose exact type is ``int`` are affected (``False`` and ``0.0`` are kept).
    """
    cleaned: dict[str, Any] = {}
    for key, value in d.items():
        if value == 0 and key != "Index" and type(value) is int:
            cleaned[key] = None
        else:
            cleaned[key] = value
    return cleaned


def _load_json(path: str) -> Any:
    """Read the file at ``path`` and return its parsed JSON content."""
    with open(path) as handle:
        parsed = json.load(handle)
    return parsed


class CustomArtifactEncoder(json.JSONEncoder):
    """JSON encoder that also handles artifact objects and dates.

    AttackParameters, JailbreakInfo and AttackArtifact instances are encoded
    through their ``model_dump()`` method; ``datetime`` and ``date`` objects
    are encoded as ISO 8601 strings. Anything else is delegated to the base
    encoder.
    """

    def default(self, o: Any) -> Any:
        if isinstance(o, (AttackParameters, JailbreakInfo, AttackArtifact)):
            return o.model_dump()
        if isinstance(o, (datetime, date)):
            return o.isoformat()
        return super().default(o)
