import json
import logging
import os
import warnings
from pathlib import Path

from dotenv import load_dotenv

from .openmodel import LoRAModelManager
from .utils.configs import (  # NOQA
    ApibenchDataConfig,
    CorporaDataConfig,
    EvalConfig,
    JoinDataConfig,
    MLLMDataConfig,
    # Olympus1DataConfig,
    # Olympus2DataConfig,
)
from .utils.eval_utility import compute_metrics
from .utils.parser import EvalParser
from .utils.prepareDataset import (  # NOQA
    get_prompt,
    get_retriever,
    gorilla_prompt,
    gorilla_prompt_with_retrieval,
    load_dataset_json,
)
from .utils.wandb import WandbLogger

# Resolve the directory containing this file and its parent, then load
# environment variables from the parent's ".env" file (main() later reads
# WANDB_API_KEY from the environment).
PACKAGE_ROOT = Path(__file__).resolve().parent
PROJECT_ROOT = PACKAGE_ROOT.parent
load_dotenv(PROJECT_ROOT / ".env")

# Redirect the Hugging Face / tokenizers / sentence-transformers caches to a
# ".hf_cache" directory one level above this file's directory, creating it if
# it does not exist yet.
# NOTE(review): these variables are set after the project imports above; if
# those imports already load transformers / huggingface_hub, the cache location
# may have been resolved before this point — confirm the ordering is effective.
cache_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.hf_cache"))
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["TOKENIZERS_CACHE"] = cache_dir
os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir


# Keep console output readable: only ERROR and above is emitted by the
# transformers, torch and peft loggers, and every Python warning is ignored
# process-wide.
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("peft").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


def llm_responses(
    model: LoRAModelManager, question_jsons: list, eval_config: EvalConfig
):
    """
    Generate LLM responses for a list of question JSONs.

    A prompt is built for every question (with retrieval when both
    ``eval_config.corpus`` and ``eval_config.retriever`` are set), the prompts
    are sent to the model in batches, and each raw generation is cleaned by
    removing the echoed prompt and everything from the first ``</s>`` marker on.

    Args:
        model (LoRAModelManager): The LoRA model manager instance.
        question_jsons (list): List of question JSON objects. Each one must
            provide the "instruction", "model_name" and "domain" keys, which
            are copied into the returned answers.
        eval_config (EvalConfig): Evaluation configuration.
    Returns:
        list: List of answer JSON objects with the keys "questions",
        "response", "ground_true" and "domain_ground_true", in the same order
        as ``question_jsons``. Empty when no questions are given.
    Raises:
        TypeError: If ``get_prompt`` returns something other than a string.
        ValueError: If the model returns a different number of responses than
            prompts, which would otherwise misalign answers and questions.
        KeyError: If a question lacks one of the required keys.

    """
    system_prompt = gorilla_prompt
    retriever = None
    if eval_config.corpus and eval_config.retriever:
        system_prompt = gorilla_prompt_with_retrieval
        retriever = get_retriever(eval_config)

    prompts = []
    for q_json in question_jsons:
        # Normalise surrounding whitespace and Windows line endings.
        instruction = q_json.get("instruction", "").strip().replace("\r\n", "\n")
        prompt = get_prompt(
            instruction, system_prompt=system_prompt, retriever=retriever
        )  # retriever can be None

        # Ensure prompt is a string
        if not isinstance(prompt, str):
            raise TypeError(
                f"get_prompt returned {type(prompt)}, expected str. Value: {prompt}"
            )

        prompts.append(prompt)

    print("Example prompt:")
    print(prompts[0] if prompts else "No prompts generated.")

    # Nothing to generate: skip the model call entirely.
    if not prompts:
        return []

    responses = list(
        model.generate_batch_safe(
            prompts,
            do_sample=eval_config.do_sample,
            temperature=eval_config.temperature,
            max_new_tokens=eval_config.max_new_tokens,
            top_p=eval_config.top_p,
            top_k=eval_config.top_k,
            batch_size=eval_config.eval_batch_size,
        )
    )

    # zip() would silently truncate to the shorter sequence and pair responses
    # with the wrong questions, so fail loudly on a count mismatch instead.
    if len(responses) != len(prompts):
        raise ValueError(
            f"Model returned {len(responses)} responses for {len(prompts)} prompts."
        )

    cleaned_responses = []
    for prompt, response in zip(prompts, responses):
        # Remove the prompt when the model echoes it at the start of its output.
        if response.startswith(prompt):
            response = response[len(prompt) :].strip()
        # Drop the end-of-sequence marker and anything generated after it.
        cleaned_responses.append(response.split("</s>")[0].strip())

    # Build the list of JSONL output records.
    return [
        {
            "questions": q_json["instruction"],
            "response": response,
            "ground_true": q_json["model_name"],
            "domain_ground_true": q_json["domain"],
        }
        for q_json, response in zip(question_jsons, cleaned_responses)
    ]


def get_dataset_config(experience_name: str):
    """Return a fresh dataset config instance for ``experience_name``.

    Supported names are "apibench", "mllm" and "join"; any other value raises
    ``ValueError``. The Olympus configs are currently disabled.
    """
    if experience_name == "apibench":
        return ApibenchDataConfig()

    if experience_name == "mllm":
        return MLLMDataConfig()

    if experience_name == "join":
        return JoinDataConfig()

    # No supported experience matched.
    raise ValueError(f"Unknown experience name: {experience_name}")


def _resolve_experience_names(eval_config):
    """Return the experiences to evaluate as a non-empty list of names."""
    # Support both a single experience and a list of experiences; the
    # sequence takes precedence when both are set.
    names = eval_config.experiences_sequence or eval_config.experience_name
    if isinstance(names, str):
        names = [names]
    if not names:
        raise ValueError(
            "No experience to evaluate: neither experiences_sequence nor "
            "experience_name is set in the evaluation config."
        )
    return list(names)


def _build_save_path(eval_config, first_exp_name):
    """Return the results directory derived from the evaluation config."""
    if eval_config.output_path:
        return f"results/{eval_config.output_path}"

    adapters = eval_config.lora_adapters
    if not adapters:
        raise ValueError(
            "Cannot derive a results directory: no LoRA adapters were given "
            "and output_path is not set."
        )

    if len(adapters) == 1:
        return f"results/{first_exp_name}/{adapters[0]}"

    # Several adapters: encode the merging strategy and the adapter names
    # (and, when given, the merge weights and density) in the path.
    save_path = (
        f"results/{first_exp_name}/{eval_config.lora_merging_strategy}/"
        + "_".join(adapter.replace("/", "-") for adapter in adapters)
    )
    if eval_config.weights:
        save_path += "/weights-" + "_".join(
            str(w).replace(".", "-") for w in eval_config.weights
        )
        save_path += f"_density-{eval_config.density}".replace(".", "-")
    return save_path


def main():
    """Evaluate the configured LoRA model on one or more experiences.

    Parses the evaluation config from the command line, loads the model with
    its adapters, generates answers for each experience's test set (and,
    when ``eval_on_train`` is set, its train set), computes metrics, optionally
    logs them to WandB, and writes one ``answers_<experience>.jsonl`` file per
    evaluated split plus a single ``metrics.json`` into the results directory.

    Raises:
        ValueError: If no experience is configured, an experience name is
            unknown, or the results directory cannot be derived from the
            config.
    """
    eval_config = EvalParser().parse_args()
    print(eval_config)

    # Resolve the config-only inputs first so that a bad configuration fails
    # before the model is loaded and evaluated, not when saving the results.
    # The save path uses the first experience name when there are several.
    experience_names = _resolve_experience_names(eval_config)
    save_path = _build_save_path(eval_config, experience_names[0])

    # Initialize WandB logger
    wandb_key = os.getenv("WANDB_API_KEY")
    if wandb_key:
        wandb_logger = WandbLogger(wandb_key, eval_config, mode="eval")
    else:
        wandb_logger = None
        print(
            "Warning: WANDB_API_KEY not found in environment variables. Skipping WandB logging."
        )

    lora_paths = [f"./cco/experiments/{adapter}" for adapter in eval_config.lora_adapters]
    model = LoRAModelManager(eval_config, lora_paths=lora_paths)

    # The configured corpus does not depend on the experience, so load it once
    # instead of on every loop iteration.
    shared_corpus = (
        load_dataset_json(CorporaDataConfig().get_corpus_path(eval_config.corpus))
        if eval_config.corpus
        else None
    )

    # Dictionary to store all metrics and answers
    all_metrics = {}
    all_answers = {}

    # Evaluate on each experience
    for exp_name in experience_names:
        print(f"\n{'=' * 80}")
        print(f"Evaluating on experience: {exp_name}")
        print(f"{'=' * 80}\n")

        dataset = get_dataset_config(exp_name)
        dataset_json = load_dataset_json(dataset.test_set)

        corpus = shared_corpus
        if not corpus:
            print(
                "No corpus provided for retrieval or evaluation. Using current experience."
            )
            corpus = dataset_json

        answers = llm_responses(model, dataset_json, eval_config)
        metrics = compute_metrics(answers, corpus=corpus)

        print(f"{exp_name} metrics:", metrics)

        # Store with experience name suffix
        for key, value in metrics.items():
            all_metrics[f"{key}_{exp_name}"] = value
        all_answers[exp_name] = answers

        # Log metrics to WandB
        if wandb_logger:
            wandb_logger.log({f"{exp_name}/{k}": v for k, v in metrics.items()})

        # Optionally evaluate on train set to assess overfitting
        if eval_config.eval_on_train:
            dataset_json_train = load_dataset_json(dataset.train_set)
            answers_train = llm_responses(model, dataset_json_train, eval_config)
            # NOTE(review): the test-set call above passes `corpus=` while this
            # one passes `dataset=`; confirm compute_metrics accepts both
            # keywords and that the difference is intended.
            train_metrics = compute_metrics(answers_train, dataset=dataset_json_train)

            print(f"{exp_name} train metrics:", train_metrics)

            # Store train metrics with suffix
            for key, value in train_metrics.items():
                all_metrics[f"{key}_{exp_name}_train"] = value
            all_answers[f"{exp_name}_train"] = answers_train

            if wandb_logger:
                wandb_logger.log(
                    {f"{exp_name}/train/{k}": v for k, v in train_metrics.items()}
                )

    os.makedirs(save_path, exist_ok=True)

    # Save all answers to separate files per experience
    for exp_key, answers in all_answers.items():
        answers_file = f"{save_path}/answers_{exp_key}.jsonl"
        with open(answers_file, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(line) + "\n" for line in answers)
        print(f"Saved answers for {exp_key} to {answers_file}")

    # Save all metrics to a single JSON file with experience suffixes
    all_metrics["Adapter Path"] = f"{eval_config.lora_adapters}"
    if (
        eval_config.lora_merging_strategy in ["ties", "dare_linear"]
        and len(eval_config.lora_adapters) > 1
    ):
        all_metrics["Merge Weights"] = eval_config.weights
        all_metrics["Merge Density"] = eval_config.density

    with open(f"{save_path}/metrics.json", "w", encoding="utf-8") as f:
        json.dump(all_metrics, f, indent=2)

    print(f"\nSaved all metrics to {save_path}/metrics.json")
    print(f"Evaluated experiences: {', '.join(experience_names)}")

    # Finish WandB logging
    if wandb_logger:
        wandb_logger.finish()


# Script entry point. The relative imports at the top of this file need a
# package context, so run it as a module (python -m <package>.<module>)
# instead of executing the file path directly.
if __name__ == "__main__":
    main()
