from typing import Optional, Tuple, List, Dict
import logging
import os
import json
import torch
import torch.nn.functional as F
import numpy as np

def load_id2label(path: str) -> Dict[str, str]:
    """
    Load an id2label mapping from a JSON file.

    JSON object keys are always strings, so the returned mapping is keyed by
    the string form of each class id (e.g. ``"0"``); index it with
    ``str(class_id)``. No label2id mapping is built and no class-count
    validation is performed here.

    Args:
        path (str): Path to the id2label.json file.

    Returns:
        Dict[str, str]: id2label mapping exactly as stored in the file
        (values are presumably label-name strings; not checked here).

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    # Read as UTF-8 explicitly: save_id2label writes with ensure_ascii=False,
    # so labels may contain non-ASCII characters that the platform default
    # encoding (e.g. cp1252 on Windows) cannot decode.
    with open(path, "r", encoding="utf-8") as f:
        id2label = json.load(f)

    if not isinstance(id2label, dict):
        raise ValueError(f"Expected dict in {path}, got {type(id2label)}")
    return id2label


def save_id2label(id2label: Dict[int, str], dir: str, overwrite: bool = True) -> None:
    """
    Save an id2label mapping to ``<dir>/id2label.json``.

    Behaviour when the file already exists:

    * If its contents equal ``id2label`` (after JSON normalisation, so int
      keys compare equal to the string keys JSON stores), nothing is written.
    * Otherwise, if ``overwrite`` is False, the existing file is left as-is.
    * Otherwise the existing file is moved to ``<dir>/id2label.bak.json``
      (replacing any previous backup) and the new mapping is written.

    An existing file that cannot be decoded as JSON is treated as different.

    Args:
        id2label: Mapping from class id to label name.
        dir: Directory in which ``id2label.json`` is written.
        overwrite: Whether to replace an existing file whose contents differ.
    """
    path = os.path.join(dir, "id2label.json")
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                existing = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logging.warning(f"Existing id2label at {path} is unreadable; treating it as different.")
            existing = None

        # JSON object keys are always strings. Round-trip the in-memory mapping
        # before comparing; otherwise int keys never match the loaded file and
        # an identical mapping would be backed up and rewritten on every call.
        normalized = json.loads(json.dumps(id2label))
        if existing == normalized:
            logging.info(f"Existing id2label at {path} is identical. Skipping save.")
            return
        if not overwrite:
            logging.info(f"Not overwriting existing id2label at {path}")
            return

        # Build the backup path from the directory rather than via str.replace
        # on the full path, which would also rewrite any ".json" in `dir`.
        backup_path = os.path.join(dir, "id2label.bak.json")
        logging.warning(f"Backing up existing id2label to {backup_path}")
        # os.replace overwrites an existing backup on every platform
        # (os.rename raises on Windows if the target exists).
        os.replace(path, backup_path)

    # UTF-8 is required because ensure_ascii=False may emit non-ASCII labels.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(id2label, f, indent=2, ensure_ascii=False)
    logging.info(f"Saved id2label to {path}")
    

def tag2multihot(tag_strings, label2id):
    """
    Convert ';'-separated tag strings into a multi-hot float32 tensor.

    Example:
        tag2multihot(['sand;rub', 'butterfly'],
                     {'sand': 0, 'rub': 1, 'butterfly': 2})
        -> tensor([[1., 1., 0.], [0., 0., 1.]])

    Whitespace around each tag is stripped and empty segments (from a
    trailing ';', doubled ';;', or an empty string) are ignored, so such a
    row may be all zeros.

    Args:
        tag_strings: Sequence of strings, one per sample.
        label2id: Mapping from tag name to class index; values may be ints or
            int-like strings. Its length sets the number of columns.

    Returns:
        torch.Tensor: float32 tensor of shape (len(tag_strings), len(label2id)).

    Raises:
        KeyError: If a tag is not present in ``label2id``.
    """
    num_classes = len(label2id)
    multihot = torch.zeros((len(tag_strings), num_classes), dtype=torch.float32)

    for i, tag_str in enumerate(tag_strings):
        for tag in tag_str.split(";"):
            tag = tag.strip()
            if not tag:
                # Skip empty segments instead of failing on a '' lookup.
                continue
            try:
                class_id = int(label2id[tag])
            except KeyError:
                raise KeyError(f"Unknown tag {tag!r} in {tag_str!r}") from None
            multihot[i, class_id] = 1.0
    return multihot

def compute_acc(logits, targets, prob_type='auto'):
    """
    Compute top-1 and top-5 accuracy (in percent) against one-/multi-hot targets.

    A prediction counts as correct when the target is non-zero at the
    predicted class (top-1), or at any of the k highest-scoring classes
    (top-5, where k = min(5, num_classes)).

    Args:
        logits: Tensor of shape (batch, num_classes).
        targets: Tensor of shape (batch, num_classes) with non-zero entries
            marking positive classes.
        prob_type: 'softmax', 'sigmoid', or 'auto'. 'auto' picks 'sigmoid'
            when any target row sums to more than 1 (multi-label), else
            'softmax'.

    Returns:
        Tuple[float, float]: (top1_acc, top5_acc), each in [0, 100].

    Raises:
        ValueError: If ``prob_type`` is not one of the supported values.
    """
    if prob_type == "auto":
        # Treat targets as multi-hot if any row sums to more than one.
        max_labels_per_sample = targets.sum(dim=1).max().item()
        prob_type = "sigmoid" if max_labels_per_sample > 1 else "softmax"

    if prob_type == "softmax":
        probs = F.softmax(logits, dim=1)
    elif prob_type == "sigmoid":
        probs = torch.sigmoid(logits)
    else:
        # Without this, `probs` would be unbound and fail with a confusing error.
        raise ValueError(
            f"prob_type must be 'auto', 'softmax' or 'sigmoid', got {prob_type!r}"
        )

    batch_size, num_classes = targets.shape

    # Top-1 accuracy. Indices are moved to targets' device so that logits and
    # targets may live on different devices.
    top1_preds = torch.argmax(probs, dim=1).to(targets.device)
    rows = torch.arange(batch_size, device=targets.device)
    top1_correct = targets[rows, top1_preds].float()
    top1_acc = top1_correct.mean().item() * 100

    # Top-5 accuracy; clamp k so models with fewer than 5 classes don't crash.
    k = min(5, num_classes)
    top5_preds = torch.topk(probs, k=k, dim=1).indices.to(targets.device)
    top5_correct = torch.gather(targets, 1, top5_preds)
    top5_acc = (top5_correct.sum(dim=1) > 0).float().mean().item() * 100

    return top1_acc, top5_acc

def compute_weighted_acc(preds, labels, id2label):
    """
    Compute exact-match accuracy and the mean mir_eval weighted key score.

    Args:
        preds: Predicted class ids (array-like), same shape as ``labels``.
        labels: Ground-truth class ids (array-like).
        id2label: Mapping from ``str(class_id)`` to a key name accepted by
            ``mir_eval.key.weighted_score`` (e.g. as returned by
            ``load_id2label``).

    Returns:
        Tuple[float, float]: ``(accuracy, weighted_score)``. ``accuracy`` is
        the fraction of exact matches in [0, 1] (not a percentage).
        ``weighted_score`` is the mean of
        ``mir_eval.key.weighted_score(reference, estimate)`` over samples.

    Raises:
        ValueError: If ``preds`` and ``labels`` have different shapes.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if preds.shape != labels.shape:
        raise ValueError(
            f"preds and labels must have the same shape, got {preds.shape} and {labels.shape}"
        )

    # Imported inside the function so the module can be used without mir_eval.
    import mir_eval

    correct = preds == labels
    accuracy = correct.astype(np.float32).mean()

    scores = [
        mir_eval.key.weighted_score(
            id2label[str(ref_key)], id2label[str(est_key)]
        )
        for ref_key, est_key in zip(labels, preds)
    ]

    return accuracy, np.mean(scores)