import numpy as np

def confusion_matrix(y_true, y_pred):
    """Count how often each true class label was predicted as each class.

    Args:
        y_true: Sequence or array of non-negative integer class labels.
        y_pred: Sequence or array of predicted labels with the same number
            of elements as ``y_true``.

    Returns:
        Integer array ``cm`` of shape ``(num_classes, num_classes)`` where
        ``cm[t, p]`` is the number of samples whose true label is ``t`` and
        whose predicted label is ``p``. ``num_classes`` is the largest label
        found in either input plus one; empty inputs give a ``(0, 0)`` array.

    Raises:
        ValueError: If the inputs differ in length or contain a negative
            label.
    """
    true_idx = np.asarray(y_true).ravel()
    pred_idx = np.asarray(y_pred).ravel()

    if true_idx.size != pred_idx.size:
        raise ValueError(
            "y_true and y_pred must have the same number of elements, "
            f"got {true_idx.size} and {pred_idx.size}"
        )
    if true_idx.size == 0:
        return np.zeros((0, 0), dtype=int)

    # A negative index would wrap around and be counted under another class.
    if min(true_idx.min(), pred_idx.min()) < 0:
        raise ValueError("class labels must be non-negative integers")

    # Size the matrix from BOTH inputs: a class that is only ever predicted,
    # or a gap in the labels present in y_true, must still have a row/column.
    num_classes = int(max(true_idx.max(), pred_idx.max())) + 1
    cm = np.zeros((num_classes, num_classes), dtype=int)

    # Unbuffered add so repeated (true, pred) pairs each contribute a count.
    np.add.at(cm, (true_idx, pred_idx), 1)
    return cm

def classification_metrics(y_pred, y_true):
    """Compute accuracy and macro-averaged precision/specificity/recall/F1.

    Note the argument order: predictions first, targets second.

    Args:
        y_pred: Tensor-like object of per-class scores; the predicted class
            of each row is the index of its maximum along dimension 1.
            Must support ``.max(1)``, ``.detach()``, ``.cpu()`` and
            ``.numpy()`` (presumably a torch tensor of shape
            (samples, classes) -- confirm against callers).
        y_true: Tensor-like object of integer class labels supporting
            ``.detach().cpu().numpy()``.

    Returns:
        Tuple ``(accuracy, macro_precision, macro_specificity, macro_recall,
        macro_f1)``. ``accuracy`` is the fraction of correctly classified
        samples (confusion-matrix trace over its sum); the other four are
        unweighted means of the per-class values. A per-class ratio with a
        zero denominator counts as 0.0, and all five values are 0.0 when
        there are no samples.
    """
    # Index of the largest score in each row is the predicted class.
    pred_labels = y_pred.max(1)[1].detach().cpu().numpy()
    true_labels = y_true.detach().cpu().numpy()

    cm = confusion_matrix(true_labels, pred_labels)  # (num_classes, num_classes)
    total_samples = np.sum(cm)
    if total_samples == 0:
        # Nothing to score: avoid 0/0 and the mean of an empty list, and use
        # the same 0.0 convention applied to undefined per-class ratios.
        return 0.0, 0.0, 0.0, 0.0, 0.0

    def _ratio(numerator, denominator):
        # Element-wise division that yields 0.0 where the denominator is 0.
        result = np.zeros(denominator.shape, dtype=float)
        np.divide(numerator, denominator, out=result, where=denominator > 0)
        return result

    # One-vs-rest counts for every class at once.
    tp = np.diag(cm)                     # true positives
    fp = cm.sum(axis=0) - tp             # predicted as the class, but wrong
    fn = cm.sum(axis=1) - tp             # belong to the class, but missed
    tn = total_samples - (tp + fp + fn)  # everything else

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    specificity = _ratio(tn, tn + fp)
    f1 = _ratio(2 * precision * recall, precision + recall)

    accuracy = np.trace(cm) / total_samples

    return accuracy, precision.mean(), specificity.mean(), recall.mean(), f1.mean()


def macro_specificity(y_true, y_pred):
    """Return the unweighted mean of the per-class specificities.

    A confusion matrix is built from the two label collections. For each
    class ``k`` treated one-vs-rest:

    * true negatives  -- entries outside both row ``k`` and column ``k``;
    * false positives -- entries in column ``k`` outside row ``k``.

    The class specificity is ``tn / (tn + fp)``, or 0.0 when that
    denominator is zero.

    Args:
        y_true: True class labels, forwarded to ``confusion_matrix``.
        y_pred: Predicted class labels, forwarded to ``confusion_matrix``.

    Returns:
        ``np.mean`` of the per-class specificity values.
    """
    matrix = confusion_matrix(y_true, y_pred)
    class_ids = np.arange(matrix.shape[0])

    scores = []
    for k in class_ids:
        # Boolean mask selecting every class except the current one.
        others = class_ids != k
        true_neg = matrix[others][:, others].sum()
        false_pos = matrix[others, k].sum()

        denom = true_neg + false_pos
        if denom > 0:
            scores.append(true_neg / denom)
        else:
            scores.append(0.0)

    return np.mean(scores)

def macro_accuracy(y_true, y_pred):
    """Return the fraction of samples on the confusion-matrix diagonal.

    Despite the name, the value is the confusion-matrix trace divided by
    the sum of all its entries, i.e. correct predictions over total
    samples; no per-class averaging takes place.

    Args:
        y_true: True class labels, forwarded to ``confusion_matrix``.
        y_pred: Predicted class labels, forwarded to ``confusion_matrix``.

    Returns:
        Trace of the confusion matrix divided by the sum of its entries.
        The division is not guarded against a zero sample count.
    """
    matrix = confusion_matrix(y_true, y_pred)  # (num_classes, num_classes)

    hits = matrix.trace()
    n_samples = matrix.sum()
    return hits / n_samples
