# -*- coding: utf-8 -*-

import hashlib
import json
import re
from typing import Dict, Optional, Tuple

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

def check_bf16_support(use_bf16_arg):
    """Decide whether bfloat16 training can be enabled.

    Args:
        use_bf16_arg: The user's request for BF16; any falsy value disables it.

    Returns:
        ``False`` unless CUDA is available *and* BF16 was requested, in which
        case the result of ``torch.cuda.is_bf16_supported()`` is returned.
    """
    # Guard clause: without a CUDA device (or without an explicit request)
    # there is nothing to probe.
    if not (torch.cuda.is_available() and use_bf16_arg):
        return False
    return torch.cuda.is_bf16_supported()

def checksum(text: str) -> str:
    """Return the hex-encoded SHA-512 digest of ``text`` (encoded as UTF-8)."""
    hasher = hashlib.sha512()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()

def extract_templates(tokenizer) -> Tuple[str, str, str]:
    r"""Recover the text a chat template places before each role's message.

    Three unique sentinel strings (SHA-512 hex digests produced by
    ``checksum``) are rendered as system/user/assistant messages via
    ``tokenizer.apply_chat_template(messages, tokenize=False)``.
    For each sentinel, this returns group 1 of the first match of the
    pattern ``\n(.*?)\s*<sentinel>`` in the rendered text.

    Because ``.`` does not match newlines while ``\s*`` does, the capture
    starts right after a newline and stops before any whitespace that
    precedes the sentinel. It is an empty string when the sentinel
    directly follows the leftmost newline from which a match is possible.
    NOTE(review): the search starts at the *leftmost* viable newline, so
    for some templates the capture can include earlier message content.
    Verify the output for the tokenizer in use.

    Args:
        tokenizer: Object exposing ``apply_chat_template(messages,
            tokenize=False)`` that returns a string (presumably a Hugging
            Face tokenizer).

    Returns:
        ``(system_template, user_template, assistant_template)``.

    Raises:
        ValueError: If a sentinel cannot be matched in the rendered text,
            e.g. it is not preceded by any newline.
    """
    roles = ("system", "user", "assistant")
    sentinels = [checksum(f"{role} prompt") for role in roles]
    messages = [
        {"role": role, "content": sentinel}
        for role, sentinel in zip(roles, sentinels)
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)

    def extract_template(role: str, sentinel: str) -> str:
        # re.escape keeps the sentinel literal; hex digests contain no
        # metacharacters today, but this stays safe if the scheme changes.
        match = re.search(rf"\n(.*?)\s*{re.escape(sentinel)}", text)
        if match is None:
            # An explicit error instead of `assert`: asserts are stripped
            # under `python -O`, which would surface as an AttributeError.
            raise ValueError(
                f"Could not locate the {role!r} message in the rendered chat template."
            )
        return match.group(1)

    return tuple(
        extract_template(role, sentinel) for role, sentinel in zip(roles, sentinels)
    )

def create_model_path(args):
    """Build a run/repo name that encodes the model and key hyper-parameters.

    The name is the last ``/``-separated component of ``args.model`` followed by
    ``bs``, ``lr`` (scientific notation, 0 decimals), ``collator``, ``gas``,
    ``ms``, ``lrank`` and ``ws`` settings, all joined with underscores.
    Every ``=`` in the result is replaced by ``_``, because Hugging Face repo
    ids reject ``=`` (HFValidationError).

    Args:
        args: Namespace-like object providing ``model``, ``batch_size``,
            ``learning_rate``, ``collator``, ``gradient_accumulation_steps``,
            ``max_steps``, ``lora_rank`` and ``warmup_steps``.

    Returns:
        The sanitized name string.
    """
    # Keep only the final path component (e.g. "org/name" -> "name"); turn any
    # backslashes into "_" so the result stays a single path segment.
    model_name = args.model.rpartition("/")[2].replace("\\", "_")

    settings = (
        ("bs", args.batch_size),
        ("lr", f"{args.learning_rate:.0e}"),
        ("collator", args.collator),
        ("gas", args.gradient_accumulation_steps),
        ("ms", args.max_steps),
        ("lrank", args.lora_rank),
        ("ws", args.warmup_steps),
    )
    segments = [model_name]
    segments.extend(f"{key}={value}" for key, value in settings)

    # '=' is not allowed in Hugging Face repo ids, wherever it appears.
    return "_".join(segments).replace("=", "_")


def get_best_results(repo_id: str) -> Optional[Dict[str, float]]:
    """Fetch and parse ``results.json`` from a Hugging Face model repository.

    Args:
        repo_id: Hub id of the model repository to read from.

    Returns:
        The decoded JSON content of ``results.json``, or ``None`` when the
        repository or the file does not exist. The ``Dict[str, float]`` element
        type is what callers presumably expect; the content is not validated.

    Raises:
        json.JSONDecodeError: If ``results.json`` is not valid JSON.
        Other errors raised by ``hf_hub_download`` (e.g. network failures)
        propagate unchanged.
    """
    try:
        results_path = hf_hub_download(
            repo_id=repo_id, filename="results.json", repo_type="model"
        )
    except (EntryNotFoundError, RepositoryNotFoundError):
        # Deliberate best-effort: a missing repo or file just means that no
        # previous results have been recorded yet.
        return None

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(results_path, "r", encoding="utf-8") as f:
        return json.load(f)
