import datetime
import json
import os
from pathlib import Path
from typing import Annotated, TypeVar

import requests
from pydantic import BaseModel, ConfigDict, Field, ValidationError

from jailbreakbench.config import AttackType, Model, parse_model_name_attack_type

# Directory name for jailbreak artifacts. It is not referenced by any code in
# the visible part of this module — presumably used by other modules; verify
# before renaming or removing.
artifacts_dir = "jailbreak_artifacts"


class JailbreakInfo(BaseModel):
    """One entry of an attack artifact's ``jailbreaks`` list.

    No field declares a default, so under pydantic v2 every field must be
    present in the input; fields annotated ``... | None`` must still be
    supplied but may hold ``None``.
    """
    index: int
    goal: str
    behavior: str
    category: str
    # The nullable fields below may be None — presumably when no prompt was
    # produced / no measurement is available for this entry (assumption, not
    # established by this file).
    prompt: str | None
    response: str | None
    number_of_queries: int | None
    queries_to_jailbreak: int | None
    prompt_tokens: int | None
    response_tokens: int | None
    jailbroken: bool


class AttackParameters(BaseModel):
    """Run-level metadata of an attack, stored as ``AttackArtifact.parameters``.

    Fields annotated ``... | None`` have no default, so they must be present
    in the input but may hold ``None``.
    """
    # Keep the enum *values* (not the enum members) in ``model`` and
    # ``attack_type`` after validation.
    model_config = ConfigDict(use_enum_values=True)
    method: str
    model: Model
    attack_type: AttackType
    defense: str | None
    # Validated in strict mode (no lax coercion, e.g. from numeric strings)
    # and constrained to the closed interval [0, 1].
    attack_success_rate: Annotated[float, Field(strict=True, ge=0, le=1)]
    total_number_of_jailbreaks: int
    number_of_submitted_prompts: int
    total_number_of_queries: int | None
    total_prompt_tokens: int | None
    total_response_tokens: int | None
    # When omitted, defaults to the local date at the time the instance is created.
    evaluation_date: datetime.date = Field(default_factory=datetime.date.today)
    evaluation_llm_provider: str
    # Free-form, flat mapping of method-specific settings (scalar values only).
    method_parameters: dict[str, str | int | float]


class AttackArtifact(BaseModel):
    """A complete attack artifact: run-level parameters plus per-entry results.

    This is the model that ``read_artifact`` returns and that artifact JSON
    files are validated against.
    """
    parameters: AttackParameters
    jailbreaks: list[JailbreakInfo]


class _AttackInfo(BaseModel):
    """Schema of a method's ``attack-info.json`` index file.

    ``read_artifact`` consults it to check which models, and which attack
    types per model, are available for a method.
    """
    # Maps each model to the set of attack types listed for it.
    models: dict[Model, set[AttackType]]


# Pieces of the raw-content GitHub URL that artifacts are downloaded from.
# _make_attack_base_url() joins them as
#   {BASE_URL}/{REPO_NAME}/{BRANCH}/{ATTACK_ARTIFACTS_PATH}/<method_name>
BASE_URL = "https://raw.githubusercontent.com"
REPO_NAME = "JailbreakBench/artifacts"
BRANCH = "main"
ATTACK_ARTIFACTS_PATH = "attack-artifacts"
# File name of the per-method index (validated as _AttackInfo); also reused as
# the file name of its cached copy.
ATTACK_INFO_FILENAME = "attack-info.json"


def _make_attack_base_url(method_name: str) -> str:
    """Return the raw-GitHub URL of the directory holding ``method_name``'s artifacts."""
    branch_root = f"{BASE_URL}/{REPO_NAME}/{BRANCH}"
    return f"{branch_root}/{ATTACK_ARTIFACTS_PATH}/{method_name}"


def _make_artifact_url(method_name: str, model: Model, attack_type: AttackType) -> str:
    """Return the URL of the artifact JSON for a method, attack type and model.

    The layout is ``<method base URL>/<attack_type value>/<model value>.json``.
    """
    method_root = _make_attack_base_url(method_name)
    attack_type_dir = f"{method_root}/{attack_type.value}"
    return f"{attack_type_dir}/{model.value}.json"


def _make_attack_info_url(method_name: str) -> str:
    """Return the URL of the ``attack-info.json`` index file for ``method_name``."""
    method_root = _make_attack_base_url(method_name)
    return "/".join((method_root, ATTACK_INFO_FILENAME))


# Type variable for "some pydantic model class": lets the download/cache
# helpers declare that they return an instance of whichever model class they
# were given.
T = TypeVar("T", bound=BaseModel)


def _get_file_from_github(url: str, pydantic_model: type[T], timeout: float = 30.0) -> T:
    """Download a JSON document and validate it as ``pydantic_model``.

    Args:
        url: Address of the JSON file to fetch.
        pydantic_model: Model class the decoded JSON body is validated against.
        timeout: Seconds ``requests`` waits to establish the connection and
            between bytes received from the server before giving up.

    Returns:
        The validated ``pydantic_model`` instance.

    Raises:
        ValueError: If the server answers with HTTP 404.
        requests.HTTPError: For any other 4xx/5xx status (``raise_for_status``).
        requests.Timeout: If the server does not respond within ``timeout``.
        pydantic.ValidationError: If the body does not match ``pydantic_model``.
    """
    # Without an explicit timeout ``requests`` can block indefinitely on a
    # stalled connection, so always pass one.
    response = requests.get(url, timeout=timeout)
    if response.status_code == 404:
        raise ValueError(f"Error 404, file not found at URL: {url}.")
    response.raise_for_status()
    return pydantic_model.model_validate(response.json())


# Sub-directory, relative to the resolved cache root (see _resolve_cache_path),
# under which cached attack artifacts and attack-info files are stored.
JBB_ATTACK_CACHE_BASE_PATH = Path("jailbreakbench") / "attack-artifacts"


def _resolve_cache_path(custom_cache: Path | None) -> Path:
    if custom_cache is not None:
        return custom_cache
    if "JBB_CACHE" in os.environ:
        return Path(os.environ["JBB_CACHE"])
    return Path.home() / ".cache"


def _get_file_from_cache(sub_path: Path, pydantic_model: type[T], custom_cache_dir: Path | None = None) -> T:
    """Load a cached JSON file and validate it as ``pydantic_model``.

    Args:
        sub_path: Location of the file relative to the attack-artifact cache
            directory (``<cache root>/JBB_ATTACK_CACHE_BASE_PATH``).
        pydantic_model: Model class the decoded JSON is validated against.
        custom_cache_dir: Cache root to use instead of the default resolution
            performed by ``_resolve_cache_path``.

    Returns:
        The validated ``pydantic_model`` instance.

    Raises:
        FileNotFoundError: If the file is not in the cache.
        json.JSONDecodeError: If the file does not contain valid JSON.
        pydantic.ValidationError: If the JSON does not match ``pydantic_model``.
    """
    cache_base_path = _resolve_cache_path(custom_cache_dir)
    filepath = cache_base_path / JBB_ATTACK_CACHE_BASE_PATH / sub_path
    with filepath.open("r") as f:
        raw = json.load(f)
    # Validate here so the result really is the ``T`` the signature promises
    # rather than a plain dict. Callers that pass the result through
    # ``model_validate`` again are unaffected: by default pydantic returns an
    # existing instance of the model as-is instead of re-validating it.
    return pydantic_model.model_validate(raw)


def _cache_file(sub_path: Path, model: BaseModel, custom_cache_dir: Path | None = None) -> None:
    """Serialize ``model`` to JSON and store it in the cache.

    Args:
        sub_path: Destination relative to the attack-artifact cache directory
            (``<cache root>/JBB_ATTACK_CACHE_BASE_PATH``). Missing parent
            directories are created.
        model: The pydantic model instance to store.
        custom_cache_dir: Cache root to use instead of the default resolution
            performed by ``_resolve_cache_path``.
    """
    cache_base_path = _resolve_cache_path(custom_cache_dir)
    filepath = cache_base_path / JBB_ATTACK_CACHE_BASE_PATH / sub_path
    filepath.parent.mkdir(parents=True, exist_ok=True)
    # ``model_dump_json`` emits non-ASCII characters verbatim, which cannot
    # always be written with the platform's default text encoding. Re-encoding
    # through ``json.dumps`` (ensure_ascii defaults to True) yields the same
    # JSON document as pure ASCII, which reads back identically under any
    # ASCII-compatible encoding.
    payload = json.dumps(json.loads(model.model_dump_json()), separators=(",", ":"))
    # Write to a sibling temporary file and move it into place, so that an
    # interrupted write can never leave a truncated JSON file in the cache.
    tmp_path = filepath.with_name(f"{filepath.name}.{os.getpid()}.tmp")
    try:
        with tmp_path.open("w", encoding="ascii") as f:
            f.write(payload)
        os.replace(tmp_path, filepath)
    finally:
        # No-op after a successful replace; removes the leftover on failure.
        tmp_path.unlink(missing_ok=True)


def _get_or_download_file(
    url: str, cache_sub_path: Path, pydantic_model: type[T], custom_cache_dir: Path | None, force_download: bool
) -> T:
    """Return the file from the cache, downloading (and caching) it if needed.

    Args:
        url: Where to download the file from when the cache cannot be used.
        cache_sub_path: Location of the file relative to the attack-artifact
            cache directory.
        pydantic_model: Model class the file content is validated against.
        custom_cache_dir: Cache root overriding the default resolution.
        force_download: If true, ignore any cached copy, download the file
            and overwrite the cache entry.

    Returns:
        The validated ``pydantic_model`` instance.

    Raises:
        ValueError: If the download answers with HTTP 404 (raised by
            ``_get_file_from_github``).
    """
    if not force_download:
        try:
            return pydantic_model.model_validate(
                _get_file_from_cache(cache_sub_path, pydantic_model, custom_cache_dir)
            )
        except (FileNotFoundError, ValidationError, json.JSONDecodeError, UnicodeDecodeError):
            # A missing, schema-incompatible, truncated or undecodable cache
            # entry is a cache miss: fall through and download a fresh copy,
            # which also overwrites the bad entry.
            pass
    file = _get_file_from_github(url, pydantic_model)
    _cache_file(cache_sub_path, file, custom_cache_dir)
    return file


def _get_or_download_attack_info(method_name: str, custom_cache_dir: Path | None, force_download: bool) -> _AttackInfo:
    """Fetch the ``attack-info.json`` index of ``method_name`` (cache first, then GitHub).

    The cached copy lives at ``<method_name>/attack-info.json`` inside the
    attack-artifact cache directory.
    """
    return _get_or_download_file(
        _make_attack_info_url(method_name),
        Path(method_name) / ATTACK_INFO_FILENAME,
        _AttackInfo,
        custom_cache_dir,
        force_download,
    )


def _get_or_download_artifact(
    method_name: str, model: Model, attack_type: AttackType, custom_cache_dir: Path | None, force_download: bool
) -> AttackArtifact:
    """Fetch the artifact for a method, model and attack type (cache first, then GitHub).

    The cached copy lives at ``<method_name>/<attack_type value>/<model value>.json``
    inside the attack-artifact cache directory, mirroring the remote layout.
    """
    return _get_or_download_file(
        _make_artifact_url(method_name, model, attack_type),
        Path(method_name) / attack_type.value / f"{model.value}.json",
        AttackArtifact,
        custom_cache_dir,
        force_download,
    )


def read_artifact(
    method: str,
    model_name: str,
    attack_type: str | None = None,
    custom_cache_dir: Path | None = None,
    force_download: bool = False,
) -> AttackArtifact:
    """Load the attack artifact of ``method`` against ``model_name``.

    The method's ``attack-info.json`` index is consulted first to check that
    the requested model / attack type combination exists; the artifact itself
    is then read from the local cache or downloaded.

    Args:
        method: Name of the attack method (its directory name in the
            artifacts repository).
        model_name: Name of the target model; parsed by
            ``parse_model_name_attack_type``.
        attack_type: Attack type to load. May be omitted when the index lists
            exactly one attack type for the model.
        custom_cache_dir: Cache root overriding the default resolution.
        force_download: If true, bypass cached copies and download again.

    Returns:
        The validated ``AttackArtifact``.

    Raises:
        ValueError: If the model is not listed for the method, the requested
            attack type is not listed for the model, no attack type was given
            and the model has several (or none) listed, or a file is missing
            remotely (HTTP 404).
    """
    model, requested_attack_type = parse_model_name_attack_type(model_name, attack_type)
    attack_info = _get_or_download_attack_info(method, custom_cache_dir, force_download)
    if model not in attack_info.models:
        raise ValueError(f"Model {model.value} not found for method {method}.")
    available_attack_types = attack_info.models[model]
    if requested_attack_type is not None:
        if requested_attack_type not in available_attack_types:
            raise ValueError(f"Attack type {requested_attack_type.value} not found for model {model.value}.")
        attack_type_enum = requested_attack_type
    else:
        if len(available_attack_types) > 1:
            raise ValueError(
                f"Multiple attack types found for model {model.value}. Please specify one of the following: {available_attack_types}."
            )
        if not available_attack_types:
            # An empty set would otherwise surface as an opaque IndexError
            # from popping an empty list.
            raise ValueError(f"No attack types found for model {model.value} and method {method}.")
        # Exactly one attack type is listed: use it as the default.
        attack_type_enum = next(iter(available_attack_types))
    return _get_or_download_artifact(method, model, attack_type_enum, custom_cache_dir, force_download)
