import hashlib
import json
import os
import time
from pathlib import Path
from huggingface_hub import HfApi

from src.generation_utils import is_adapter_model


def maybe_load_evaluation_report(report_output_path: str | None) -> dict:
    """
    Load the evaluation report from the given path.
    If the file does not exist, return an empty dictionary.

    Args:
        report_output_path: Path to load the evaluation report.

    Returns:
        Evaluation report.
    """
    if report_output_path is None:
        return {}
    if os.path.exists(report_output_path):
        retry_count = 3
        while retry_count > 0:
            try:
                with open(report_output_path, "r") as f:
                    evaluation_report = json.load(f)
                break
            except json.decoder.JSONDecodeError:
                retry_count -= 1
                time.sleep(1)
    else:
        evaluation_report = {}
    return evaluation_report


def save_evaluation_report(report_output_path: str, evaluation_report: dict) -> None:
    """
    Save the evaluation report to the given path.

    If a report already exists at the path, its entries are merged in:
    keys missing from ``evaluation_report`` are copied over from the
    existing file, while keys present in both keep the new value. The merge
    happens in place, so the caller's ``evaluation_report`` dict is updated.

    The report is written to a sibling temporary file and then moved over
    the destination, so the destination never holds a partially written
    report and a failed write leaves the previous report untouched.

    Args:
        report_output_path: Path to save the evaluation report.

        evaluation_report: Evaluation report.
    """
    # mkdir if necessary
    report_output_path = os.path.abspath(report_output_path)
    os.makedirs(os.path.dirname(report_output_path), exist_ok=True)
    if os.path.exists(report_output_path):
        existing_report = maybe_load_evaluation_report(report_output_path)
        for key, value in existing_report.items():
            # New values win; only back-fill keys the caller did not provide.
            evaluation_report.setdefault(key, value)

    # Same directory as the destination so os.replace stays on one filesystem;
    # the pid keeps concurrent processes from sharing a temporary file.
    tmp_path = f"{report_output_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(evaluation_report, f, indent=4)
        os.replace(tmp_path, report_output_path)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def maybe_merge_lora_weights_and_save(model_name_or_path: str) -> str | None:
    """
    Merge adapter weights into the base model and save the merged model.

    When ``is_adapter_model`` reports that ``model_name_or_path`` is an
    adapter, the base model named in its PEFT config is loaded, the adapter
    is merged into it, and the merged model plus tokenizer are saved under a
    ``.tmp/<hash>`` directory next to this file. The hash is derived from
    ``model_name_or_path``, so the same input maps to the same directory.

    Args:
        model_name_or_path: Model identifier or local path to check and merge.

    Returns:
        Path of the directory holding the merged model, or None when the
        model is not an adapter model.

    Raises:
        OSError: If copying the local tokenizer files fails.
    """
    if not is_adapter_model(model_name_or_path):
        return None

    import shutil

    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    peft_config = PeftConfig.from_pretrained(model_name_or_path)
    _tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
    )
    model = PeftModel.from_pretrained(
        base_model,
        model_name_or_path,
    )
    model = model.merge_and_unload()
    # NOTE(review): presumably set so that saving the generation config does
    # not reject sampling-only parameters -- confirm the intent.
    model.generation_config.do_sample = True
    # save to temporary directory.
    hash_name = int(hashlib.sha1(model_name_or_path.encode("utf-8")).hexdigest(), 16) % (10 ** 8)
    temp_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), ".tmp", str(hash_name))
    model.save_pretrained(temp_dir)
    _tokenizer.save_pretrained(temp_dir)
    del _tokenizer

    if os.path.exists(model_name_or_path):
        # For a local adapter directory, copy its "token*" entries over the
        # freshly saved tokenizer files. Done with shutil rather than a shell
        # command so paths containing spaces or shell metacharacters are
        # handled correctly and failures are not silently ignored.
        for token_path in Path(model_name_or_path).glob("token*"):
            if token_path.is_dir():
                shutil.copytree(
                    token_path,
                    os.path.join(temp_dir, token_path.name),
                    dirs_exist_ok=True,
                )
            else:
                shutil.copy(token_path, temp_dir)

    return temp_dir

def upload_results_to_hf(
        results_dict: dict,
        hf_dataset_name: str,
        hf_dataset_save_dir: str,
        task_name: str | None = None,
        primary_score=None,
        prepend_timestamp: bool = False,
    ) -> None:
    """
    Upload a results dictionary as a JSON file to a Hugging Face dataset repo.

    ``results_dict`` is modified in place: ``task_name`` is stored under
    ``"task_name"`` when given, and when ``primary_score`` is given the
    ``"metrics"`` entry is replaced by ``{"primary_score": primary_score}``.

    The file is uploaded to ``<hf_dataset_save_dir>/<task_name>.json``, with
    a ``YYYYmmdd-HHMMSS-`` prefix on the file name when ``prepend_timestamp``
    is set. The access token is read from the ``HF_TOKEN`` environment
    variable.

    Args:
        results_dict: JSON-serializable results to upload.

        hf_dataset_name: Repo id of the target dataset repository.

        hf_dataset_save_dir: Directory inside the repository to upload into.

        task_name: Task name, used in the results and in the file name.
            NOTE(review): when left as None the file is named "None.json".

        primary_score: Primary score to record under ``"metrics"``.

        prepend_timestamp: Whether to prefix the file name with a timestamp.
    """
    # A bunch of stuff for making the results file setup follow
    # the oe-eval standards
    if task_name is not None:
        results_dict["task_name"] = task_name
    if primary_score is not None:
        results_dict["metrics"] = {}
        results_dict["metrics"]["primary_score"] = primary_score
    if prepend_timestamp:
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        hf_dataset_save_path = f"{hf_dataset_save_dir}/{timestamp}-{task_name}.json"
    else:
        hf_dataset_save_path = f"{hf_dataset_save_dir}/{task_name}.json"
    # actual upload: hand the serialized bytes straight to the API rather
    # than staging them in a fixed "results.json" in the working directory,
    # which could clobber an unrelated file, collide between concurrent runs,
    # and be left behind when the upload failed.
    payload = json.dumps(results_dict).encode("utf-8")
    api = HfApi(token=os.getenv("HF_TOKEN", None))
    api.upload_file(
        path_or_fileobj=payload,
        path_in_repo=hf_dataset_save_path,
        repo_id=hf_dataset_name,
        repo_type="dataset",
    )
