"""metrics script to generate a comprehensive classification report
with metrics, visualizations, and system information.
"""
__author__ = 'XYZ'


import json
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import metrics
from sklearn.metrics import (
  confusion_matrix,
  precision_recall_curve,
  precision_recall_fscore_support,
  accuracy_score,
)

try:
  import torch
  import torch.nn.functional as F
  import torch.multiprocessing as mp
  # Share tensors between worker processes through the file system rather than
  # file descriptors (see PyTorch multiprocessing "sharing strategies").
  mp.set_sharing_strategy('file_system')
except ImportError:
  # torch is treated as optional at import time; the functions below that use
  # `torch` / `F` will fail with NameError if it is missing.
  print('torch is not installed')


from ..core._log_ import logger
log = logger(__file__)

from ..core import encoders


def calculate_fp_fn(cm, predictions=None, labels=None):
  """
  Calculate per-class TP, FP, FN and TN counts and the sample indices behind them.

  Args:
    cm: Square confusion matrix with ground truth on the rows and predictions
      on the columns (the layout of ``sklearn.metrics.confusion_matrix``).
    predictions: Optional 1-D sequence of predicted class ids.
    labels: Optional 1-D sequence of ground-truth class ids, aligned
      element-wise with ``predictions``.

  Returns:
    tp, fp, fn, tn: 1-D integer arrays of class-wise counts.
    fp_indices, fn_indices, tp_indices, tn_indices: dicts mapping every class
      id in ``range(len(cm))`` to a list of sample indices. A misclassified
      sample is listed under its predicted class in ``fp_indices`` and under
      its true class in ``fn_indices``. If ``predictions`` or ``labels`` is
      omitted, every list is empty and ``tn`` is derived from ``cm`` alone.

  Raises:
    ValueError: if ``predictions`` and ``labels`` differ in shape.
    KeyError: if a prediction or label is not in ``range(len(cm))``.
  """
  cm = np.asarray(cm)
  num_classes = len(cm)

  tp = np.diag(cm)          # labelled c and predicted c
  fp = cm.sum(axis=0) - tp  # column c minus diagonal: predicted c, labelled otherwise
  fn = cm.sum(axis=1) - tp  # row c minus diagonal: labelled c, predicted otherwise

  fp_indices = {cls: [] for cls in range(num_classes)}
  fn_indices = {cls: [] for cls in range(num_classes)}
  tp_indices = {cls: [] for cls in range(num_classes)}
  tn_indices = {cls: [] for cls in range(num_classes)}

  if predictions is None or labels is None:
    # A TN for class c is any sample outside both row c and column c.
    tn = cm.sum() - (tp + fp + fn)
    return tp, fp, fn, tn, fp_indices, fn_indices, tp_indices, tn_indices

  pred_arr = np.asarray(predictions)
  label_arr = np.asarray(labels)
  if pred_arr.shape != label_arr.shape:
    raise ValueError(
      f"predictions and labels must have the same shape, got {pred_arr.shape} and {label_arr.shape}"
    )

  for idx, (pred, true_label) in enumerate(zip(pred_arr.tolist(), label_arr.tolist())):
    if pred == true_label:
      tp_indices[true_label].append(idx)
    else:
      # One misclassification is an FP for the predicted class and an FN for the true class.
      fp_indices[pred].append(idx)
      fn_indices[true_label].append(idx)

  # TN for class c: the sample is neither labelled c nor predicted as c.
  # Vectorized per class to avoid an O(C * N) pure-Python double loop.
  for cls in range(num_classes):
    tn_indices[cls] = np.flatnonzero((label_arr != cls) & (pred_arr != cls)).tolist()
  tn = np.array([len(tn_indices[cls]) for cls in range(num_classes)], dtype=int)

  return tp, fp, fn, tn, fp_indices, fn_indices, tp_indices, tn_indices


def save_confusion_matrix_and_metrics(logits, labels, class_names, output_dir):
  """
  Save an extended confusion-matrix plot, an FP/FN text report and a metrics JSON.

  The plot shows row-normalized scores and raw counts with an extra "FN" row
  and "FP" column, with Predicted on the y-axis and True Label on the x-axis
  (visualization only). All metrics are computed on the original confusion
  matrix (rows: true label, columns: predicted label).

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing directory receiving
      "confusion_matrix_with_fp_fn_swapped_axes.png", "fp_fn_report.txt" and
      "classification_metrics.json".

  Returns:
    dict with "num_classes", "labels", "class_names", "aggregated" (accuracy
    plus support-weighted precision/recall/F1), "perclass" (lists indexed by
    class id), "counts" (summed TP/FP/FN/TN), "total_samples" and
    "total_predictions".
  """
  class_ids = list(range(len(class_names)))

  ## Compute confusion matrix (original format: rows = true, columns = predicted)
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=class_ids)

  ## Per-class TP/FP/FN/TN. calculate_fp_fn needs the raw predictions and
  ## labels; the per-sample index maps it also returns are not used here.
  tp, fp, fn, tn, _, _, _, _ = calculate_fp_fn(cm, predictions, labels)

  ## For visualization: append an FN column and an FP row, then transpose so
  ## predictions run down the y-axis and true labels along the x-axis.
  extended_cm = np.zeros((cm.shape[0] + 1, cm.shape[1] + 1))
  extended_cm[:-1, :-1] = cm  # Original confusion matrix
  extended_cm[:-1, -1] = fn   # FN column: one entry per true class
  extended_cm[-1, :-1] = fp   # FP row: one entry per predicted class
  extended_cm_flipped = extended_cm.T

  ## Row-normalize for colouring; the epsilon avoids 0/0 on empty rows.
  normalized_cm = extended_cm_flipped / (extended_cm_flipped.sum(axis=1, keepdims=True) + 1e-6)

  ## After the transpose the appended column holds FP counts and the appended row FN counts.
  extended_class_names = list(class_names) + ["FP"]
  extended_pred_names = list(class_names) + ["FN"]

  ## Annotate every cell with "score\n(count)"
  annot = np.empty_like(extended_cm_flipped, dtype=object)
  for i in range(extended_cm_flipped.shape[0]):
    for j in range(extended_cm_flipped.shape[1]):
      annot[i, j] = f"{normalized_cm[i, j]:.2f}\n({int(extended_cm_flipped[i, j])})"

  ## Plot and save; always release the figure, even if saving fails.
  cm_path = os.path.join(output_dir, "confusion_matrix_with_fp_fn_swapped_axes.png")
  fig = plt.figure(figsize=(14, 12))
  try:
    sns.heatmap(normalized_cm, annot=annot, fmt='', cmap="Blues",
                xticklabels=extended_class_names, yticklabels=extended_pred_names,
                cbar_kws={'label': 'Score'})
    plt.xlabel("True Label")
    plt.ylabel("Predicted Label")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)
  log.info(f"Confusion matrix saved to {cm_path}")

  ## Save FP/FN details as a text file
  fp_fn_report_path = os.path.join(output_dir, "fp_fn_report.txt")
  with open(fp_fn_report_path, 'w') as f:
    for idx, class_name in enumerate(class_names):
      f.write(f"Class: {class_name}, FP: {fp[idx]}, FN: {fn[idx]}\n")
  log.info(f"FP and FN report saved to {fp_fn_report_path}")

  ## Calculate total samples and predictions
  total_samples = labels.shape[0]
  total_predictions = predictions.shape[0]

  ## Aggregated metrics
  accuracy = round(accuracy_score(labels, predictions), 4)
  precision, recall, f1score, support = precision_recall_fscore_support(
    labels,
    predictions,
    labels=class_ids,  # force full coverage
    average="weighted",
    zero_division=0
  )

  ## sklearn returns support=None when `average` is set; fall back to the sample count.
  support_sum = support.sum() if support is not None else len(labels)
  aggregated_metrics = {
    "accuracy": accuracy,
    "precision": round(float(precision), 4),
    "recall": round(float(recall), 4),
    "f1score": round(float(f1score), 4),
    "support": int(support_sum),
  }

  ## Per-class metrics
  per_class_precision, per_class_recall, per_class_f1score, per_class_support = precision_recall_fscore_support(
    labels,
    predictions,
    labels=class_ids,  # force full coverage
    average=None,
    zero_division=0
  )

  per_class_metrics = {
    "precision": [],
    "recall": [],
    "f1score": [],
    "support": [],
    "accuracy": [],
    "true_positive": [],
    "false_positive": [],
    "false_negative": [],
    "true_negative": []
  }

  for i in class_ids:
    per_class_metrics["precision"].append(round(per_class_precision[i], 4))
    per_class_metrics["recall"].append(round(per_class_recall[i], 4))
    per_class_metrics["f1score"].append(round(per_class_f1score[i], 4))
    per_class_metrics["support"].append(round(per_class_support[i], 4) if per_class_support is not None else 0)
    # NOTE: tp / (tp + fp + fn) is the per-class Jaccard index (IoU), not
    # (tp + tn) / N; it is kept under the "accuracy" key for compatibility.
    per_class_metrics["accuracy"].append(round(tp[i] / (tp[i] + fp[i] + fn[i] + 1e-6), 4))
    per_class_metrics["true_positive"].append(round(tp[i], 4))
    per_class_metrics["false_positive"].append(round(fp[i], 4))
    per_class_metrics["false_negative"].append(round(fn[i], 4))
    per_class_metrics["true_negative"].append(round(tn[i], 4))

  ## Overall counts
  counts = {
    "true_positive": int(tp.sum()),
    "false_positive": int(fp.sum()),
    "false_negative": int(fn.sum()),
    "true_negative": int(tn.sum()),
  }

  ## Combine all metrics into a dictionary
  metrics_data = {
    "num_classes": len(class_names),
    "labels": class_ids,
    "class_names": class_names,
    "aggregated": aggregated_metrics,
    "perclass": per_class_metrics,
    "counts": counts,
    "total_samples": total_samples,
    "total_predictions": total_predictions,
  }

  ## Save metrics to a JSON file. The return type of encoders.numpy_to_json is
  ## not visible here: write a string as-is, otherwise serialise it with json.
  metrics_path = os.path.join(output_dir, "classification_metrics.json")
  serialized_metrics_data = encoders.numpy_to_json(metrics_data)
  with open(metrics_path, "w") as f:
    if isinstance(serialized_metrics_data, str):
      f.write(serialized_metrics_data)
    else:
      json.dump(serialized_metrics_data, f, indent=4)
  log.info(f"Metrics saved to {metrics_path}")

  return metrics_data


def save_confusion_matrix_with_fp_fn_included_swapped_axes(logits, labels, class_names, output_dir):
  """
  Plot an extended confusion matrix with Predicted on the y-axis and True Label on the x-axis.

  Each cell shows the row-normalized score and the raw count. After the
  transpose, an extra "FP" column holds per-predicted-class false positives
  and an extra "FN" row holds per-true-class false negatives. Also writes a
  per-class FP/FN text report.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing directory receiving
      "confusion_matrix_with_fp_fn_swapped_axes.png" and "fp_fn_report_swapped_axes.txt".
  """
  # Compute confusion matrix (rows = true, columns = predicted)
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=range(len(class_names)))

  # calculate_fp_fn needs the raw predictions and labels; only FP/FN are used here.
  _, fp, fn, _, _, _, _, _ = calculate_fp_fn(cm, predictions, labels)

  # Extend confusion matrix with an FN column and an FP row
  extended_cm = np.zeros((cm.shape[0] + 1, cm.shape[1] + 1))
  extended_cm[:-1, :-1] = cm  # Original confusion matrix
  extended_cm[:-1, -1] = fn   # FN column: one entry per true class
  extended_cm[-1, :-1] = fp   # FP row: one entry per predicted class

  # Transpose to swap axes: rows become predictions, columns true labels
  extended_cm = extended_cm.T

  # Row-normalize; the epsilon avoids 0/0 on empty rows
  normalized_cm = extended_cm / (extended_cm.sum(axis=1, keepdims=True) + 1e-6)

  # After the transpose the extra column is FP and the extra row is FN
  extended_class_names = list(class_names) + ["FP"]  # x-axis (true labels)
  extended_pred_names = list(class_names) + ["FN"]   # y-axis (predicted labels)

  # Annotate every cell with "score\n(count)"
  annot = np.empty_like(extended_cm, dtype=object)
  for i in range(extended_cm.shape[0]):
    for j in range(extended_cm.shape[1]):
      annot[i, j] = f"{normalized_cm[i, j]:.2f}\n({int(extended_cm[i, j])})"

  # Plot and save; always release the figure, even if saving fails
  cm_path = os.path.join(output_dir, "confusion_matrix_with_fp_fn_swapped_axes.png")
  fig = plt.figure(figsize=(14, 12))
  try:
    sns.heatmap(normalized_cm, annot=annot, fmt='', cmap="Blues",
                xticklabels=extended_class_names, yticklabels=extended_pred_names,
                cbar_kws={'label': 'Score'})
    plt.xlabel("True Label")
    plt.ylabel("Predicted Label")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)

  # Write per-class FP and FN counts for further analysis
  fp_fn_report = os.path.join(output_dir, "fp_fn_report_swapped_axes.txt")
  with open(fp_fn_report, 'w') as f:
    for idx, class_name in enumerate(class_names):
      f.write(f"Class: {class_name}, FP: {fp[idx]}, FN: {fn[idx]}\n")
  log.info(f"Confusion matrix with FP and FN included (swapped axes) saved to {cm_path}")
  log.info(f"FP and FN report saved to {fp_fn_report}")


def save_confusion_matrix_with_fp_fn_imagelist(logits, labels, class_names, output_dir, dataset):
  """
  Export per-sample TP/FP/FN/TN image lists, plot an extended confusion matrix
  and return classification metrics.

  Writes "fp.csv", "fn.csv", "tp.csv" and "tn.csv" (columns: pred_labelid,
  pred_class_name, gt_labelid, gt_class_name, filepath; a file is skipped with
  a warning when it would be empty) and "confusion_matrix_with_fp_fn_gradient.png"
  to ``output_dir``. No metrics JSON is written.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing output directory.
    dataset: Object exposing ``imgs`` and ``labels`` (indexed by sample) and an
      ``index_to_label`` mapping with ``.get``.
      NOTE(review): sample i of ``logits``/``labels`` is assumed to correspond
      to ``dataset.imgs[i]`` / ``dataset.labels[i]`` (unshuffled order) —
      confirm at the call site.

  Returns:
    dict with "num_classes", "labels", "class_names", "aggregated",
    "perclass", "counts", "total_samples" and "total_predictions".
  """
  class_ids = list(range(len(class_names)))

  ## Confusion matrix: rows = true label, columns = predicted label
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=class_ids)

  ## TP/FP/FN/TN counts and the sample indices behind them
  tp, fp, fn, tn, fp_indices, fn_indices, tp_indices, tn_indices = calculate_fp_fn(cm, predictions, labels)

  image_paths = dataset.imgs

  def _make_row(idx, pred_labelid, pred_class_name):
    """Build one CSV row for sample ``idx``; ground truth comes from ``dataset``."""
    gt_labelid = dataset.labels[idx]
    return {
      "pred_labelid": pred_labelid,
      "pred_class_name": pred_class_name,
      "gt_labelid": gt_labelid,
      "gt_class_name": dataset.index_to_label.get(gt_labelid, "Unknown"),
      "filepath": image_paths[idx],
    }

  ## FP rows, listed under the predicted class. The guard drops samples whose
  ## dataset label equals the prediction (only possible if ``dataset`` and
  ## ``labels`` disagree), so only genuine mispredictions are exported.
  fp_data = []
  for class_id, indices in fp_indices.items():
    for idx in indices:
      if class_id != dataset.labels[idx]:
        fp_data.append(_make_row(idx, class_id, class_names[class_id]))

  ## FN rows, listed under the true class. The prediction columns are left
  ## blank on purpose; the same misclassified sample is also listed in
  ## fp_indices under its predicted class.
  fn_data = [_make_row(idx, "", "") for indices in fn_indices.values() for idx in indices]

  ## TP rows: prediction equals ground truth
  tp_data = [
    _make_row(idx, class_id, class_names[class_id])
    for class_id, indices in tp_indices.items()
    for idx in indices
  ]

  ## TN rows: one per (class, sample) pair where the sample is neither labelled
  ## nor predicted as that class. NOTE(review): rows do not record which class
  ## they are a negative for, so the same sample repeats across classes.
  tn_data = [_make_row(idx, "", "") for indices in tn_indices.values() for idx in indices]

  ## Debugging: log counts before saving CSVs
  log.info(f"FP count: {len(fp_data)}, FN count: {len(fn_data)}, TP count: {len(tp_data)}, TN count: {len(tn_data)}")

  ## Save each non-empty list to CSV
  for stem, rows, description in (
    ("fp", fp_data, "False Positives"),
    ("fn", fn_data, "False Negatives"),
    ("tp", tp_data, "True Positives"),
    ("tn", tn_data, "True Negatives"),
  ):
    if rows:
      pd.DataFrame(rows).to_csv(os.path.join(output_dir, f"{stem}.csv"), index=False)
    else:
      log.warning(f"No {description} detected. '{stem}.csv' was not generated.")

  ## Extended matrix: FN column (per true class) and FP row (per predicted class)
  extended_cm = np.zeros((cm.shape[0] + 1, cm.shape[1] + 1))
  extended_cm[:-1, :-1] = cm  # Original confusion matrix
  extended_cm[:-1, -1] = fn   # FN column
  extended_cm[-1, :-1] = fp   # FP row

  ## Row-normalize (each true-class row includes its FN entry in the total);
  ## the epsilon avoids 0/0 on empty rows.
  normalized_cm = extended_cm / (extended_cm.sum(axis=1, keepdims=True) + 1e-6)

  ## Tick labels follow the data layout: the extra row is FP, the extra column is FN.
  row_names = list(class_names) + ["FP"]  # y-axis: true labels
  col_names = list(class_names) + ["FN"]  # x-axis: predicted labels

  ## Annotate every cell with "score\n(count)"
  annot = np.empty_like(extended_cm, dtype=object)
  for i in range(extended_cm.shape[0]):
    for j in range(extended_cm.shape[1]):
      annot[i, j] = f"{normalized_cm[i, j]:.2f}\n({int(extended_cm[i, j])})"

  ## Plot and save; always release the figure, even if saving fails
  cm_path = os.path.join(output_dir, "confusion_matrix_with_fp_fn_gradient.png")
  fig = plt.figure(figsize=(14, 12))
  try:
    sns.heatmap(normalized_cm, annot=annot, fmt='', cmap="Blues",
                xticklabels=col_names, yticklabels=row_names, cbar_kws={'label': 'Score'})
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)

  ## Calculate total samples and predictions
  total_samples = labels.shape[0]
  total_predictions = predictions.shape[0]

  ## Aggregated metrics
  accuracy = round(accuracy_score(labels, predictions), 4)
  precision, recall, f1score, support = precision_recall_fscore_support(
    labels,
    predictions,
    labels=class_ids,  # force full coverage
    average="weighted",
    zero_division=0
  )

  ## sklearn returns support=None when `average` is set; fall back to the sample count.
  support_sum = support.sum() if support is not None else len(labels)
  aggregated_metrics = {
    "accuracy": accuracy,
    "precision": round(float(precision), 4),
    "recall": round(float(recall), 4),
    "f1score": round(float(f1score), 4),
    "support": int(support_sum),
  }

  ## Per-class metrics
  per_class_precision, per_class_recall, per_class_f1score, per_class_support = precision_recall_fscore_support(
    labels,
    predictions,
    labels=class_ids,  # force full coverage
    average=None,
    zero_division=0
  )

  per_class_metrics = {
    "precision": [],
    "recall": [],
    "f1score": [],
    "support": [],
    "accuracy": [],
    "true_positive": [],
    "false_positive": [],
    "false_negative": [],
    "true_negative": []
  }

  for i in class_ids:
    per_class_metrics["precision"].append(round(per_class_precision[i], 4))
    per_class_metrics["recall"].append(round(per_class_recall[i], 4))
    per_class_metrics["f1score"].append(round(per_class_f1score[i], 4))
    per_class_metrics["support"].append(round(per_class_support[i], 4) if per_class_support is not None else 0)
    # NOTE: tp / (tp + fp + fn) is the per-class Jaccard index (IoU), not
    # (tp + tn) / N; it is kept under the "accuracy" key for compatibility.
    per_class_metrics["accuracy"].append(round(tp[i] / (tp[i] + fp[i] + fn[i] + 1e-6), 4))
    per_class_metrics["true_positive"].append(round(tp[i], 4))
    per_class_metrics["false_positive"].append(round(fp[i], 4))
    per_class_metrics["false_negative"].append(round(fn[i], 4))
    per_class_metrics["true_negative"].append(round(tn[i], 4))

  ## Overall counts
  counts = {
    "true_positive": int(tp.sum()),
    "false_positive": int(fp.sum()),
    "false_negative": int(fn.sum()),
    "true_negative": int(tn.sum()),
  }

  ## Combine all metrics into a dictionary
  metrics_data = {
    "num_classes": len(class_names),
    "labels": class_ids,
    "class_names": class_names,
    "aggregated": aggregated_metrics,
    "perclass": per_class_metrics,
    "counts": counts,
    "total_samples": total_samples,
    "total_predictions": total_predictions,
  }

  return metrics_data


def save_confusion_matrix_with_fp_fn_included(logits, labels, class_names, output_dir):
  """
  Save an extended confusion-matrix plot, an FP/FN text report and a metrics JSON.

  The plot has true labels on the y-axis and predictions on the x-axis, with
  an extra "FN" column (per true class) and an extra "FP" row (per predicted
  class); each cell shows the row-normalized score and the raw count.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing directory receiving
      "confusion_matrix_with_fp_fn_gradient.png", "fp_fn_report.txt" and
      "classification_metrics.json".

  Returns:
    dict with "num_classes", "labels", "class_names", "aggregated",
    "perclass", "counts", "total_samples" and "total_predictions".
  """
  class_ids = list(range(len(class_names)))

  ## Compute confusion matrix (rows = true, columns = predicted)
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=class_ids)

  ## Per-class TP/FP/FN/TN (index maps are not needed here)
  tp, fp, fn, tn, _, _, _, _ = calculate_fp_fn(cm, predictions, labels)

  ## Extend confusion matrix with an FN column and an FP row
  extended_cm = np.zeros((cm.shape[0] + 1, cm.shape[1] + 1))
  extended_cm[:-1, :-1] = cm  ## Original confusion matrix
  extended_cm[:-1, -1] = fn   ## FN column: one entry per true class
  extended_cm[-1, :-1] = fp   ## FP row: one entry per predicted class

  ## Row-normalize; the epsilon avoids 0/0 on empty rows
  normalized_cm = extended_cm / (extended_cm.sum(axis=1, keepdims=True) + 1e-6)

  ## Tick labels follow the data layout: the extra row is FP, the extra column is FN.
  row_names = list(class_names) + ["FP"]  ## y-axis: true labels
  col_names = list(class_names) + ["FN"]  ## x-axis: predicted labels

  ## Annotate every cell with "score\n(count)"
  annot = np.empty_like(extended_cm, dtype=object)
  for i in range(extended_cm.shape[0]):
    for j in range(extended_cm.shape[1]):
      annot[i, j] = f"{normalized_cm[i, j]:.2f}\n({int(extended_cm[i, j])})"

  ## Plot and save; always release the figure, even if saving fails
  cm_path = os.path.join(output_dir, "confusion_matrix_with_fp_fn_gradient.png")
  fig = plt.figure(figsize=(14, 12))
  try:
    sns.heatmap(normalized_cm, annot=annot, fmt='', cmap="Blues",
                xticklabels=col_names, yticklabels=row_names, cbar_kws={'label': 'Score'})
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix with Scores, Counts, FP, FN")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)

  ## Write per-class FP and FN counts for further analysis
  fp_fn_report = os.path.join(output_dir, "fp_fn_report.txt")
  with open(fp_fn_report, 'w') as f:
    for idx, class_name in enumerate(class_names):
      f.write(f"Class: {class_name}, FP: {fp[idx]}, FN: {fn[idx]}\n")
  log.info(f"Confusion matrix with FP and FN included saved to {cm_path}")
  log.info(f"FP and FN report saved to {fp_fn_report}")

  ## Calculate total samples and predictions
  total_samples = labels.shape[0]
  total_predictions = predictions.shape[0]

  ## Aggregated metrics
  accuracy = round(accuracy_score(labels, predictions), 4)
  precision, recall, f1score, support = precision_recall_fscore_support(
    labels,
    predictions,
    labels=class_ids,  # force full coverage
    average="weighted",
    zero_division=0
  )

  ## sklearn returns support=None when `average` is set; fall back to the sample count.
  support_sum = support.sum() if support is not None else len(labels)
  aggregated_metrics = {
    "accuracy": accuracy,
    "precision": round(float(precision), 4),
    "recall": round(float(recall), 4),
    "f1score": round(float(f1score), 4),
    "support": int(support_sum),
  }

  ## Per-class metrics
  per_class_precision, per_class_recall, per_class_f1score, per_class_support = precision_recall_fscore_support(
    labels,
    predictions,
    labels=class_ids,  # force full coverage
    average=None,
    zero_division=0
  )

  per_class_metrics = {
    "precision": [],
    "recall": [],
    "f1score": [],
    "support": [],
    "accuracy": [],
    "true_positive": [],
    "false_positive": [],
    "false_negative": [],
    "true_negative": []
  }

  for i in class_ids:
    per_class_metrics["precision"].append(round(per_class_precision[i], 4))
    per_class_metrics["recall"].append(round(per_class_recall[i], 4))
    per_class_metrics["f1score"].append(round(per_class_f1score[i], 4))
    per_class_metrics["support"].append(round(per_class_support[i], 4) if per_class_support is not None else 0)
    # NOTE: tp / (tp + fp + fn) is the per-class Jaccard index (IoU), not
    # (tp + tn) / N; it is kept under the "accuracy" key for compatibility.
    per_class_metrics["accuracy"].append(round(tp[i] / (tp[i] + fp[i] + fn[i] + 1e-6), 4))
    per_class_metrics["true_positive"].append(round(tp[i], 4))
    per_class_metrics["false_positive"].append(round(fp[i], 4))
    per_class_metrics["false_negative"].append(round(fn[i], 4))
    per_class_metrics["true_negative"].append(round(tn[i], 4))

  ## Overall counts
  counts = {
    "true_positive": int(tp.sum()),
    "false_positive": int(fp.sum()),
    "false_negative": int(fn.sum()),
    "true_negative": int(tn.sum()),
  }

  ## Combine all metrics into a dictionary
  metrics_data = {
    "num_classes": len(class_names),
    "labels": class_ids,
    "class_names": class_names,
    "aggregated": aggregated_metrics,
    "perclass": per_class_metrics,
    "counts": counts,
    "total_samples": total_samples,
    "total_predictions": total_predictions,
  }

  ## Save metrics to a JSON file. The return type of encoders.numpy_to_json is
  ## not visible here: write a string as-is, otherwise serialise it with json.
  metrics_path = os.path.join(output_dir, "classification_metrics.json")
  serialized_metrics_data = encoders.numpy_to_json(metrics_data)
  with open(metrics_path, "w") as f:
    if isinstance(serialized_metrics_data, str):
      f.write(serialized_metrics_data)
    else:
      json.dump(serialized_metrics_data, f, indent=4)
  log.info(f"Metrics saved to {metrics_path}")

  return metrics_data


def save_confusion_matrix_with_fp_fn_styled(logits, labels, class_names, output_dir):
  """
  Plot a row-normalized confusion matrix with styled score/count text and FP/FN margins.

  Each cell shows the normalized score (bold, black) above its raw count
  (smaller, gray). Per-class FN counts are written in the right margin next
  to each true-label row, and per-class FP counts in the bottom margin under
  each predicted-label column. Also writes a per-class FP/FN text report.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing directory receiving
      "confusion_matrix_with_fp_fn_styled.png" and "fp_fn_report_0.txt".
  """
  # Compute confusion matrix (rows = true, columns = predicted)
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=range(len(class_names)))

  # Only the FP/FN counts are used here
  _, fp, fn, _, _, _, _, _ = calculate_fp_fn(cm, predictions, labels)

  # Row-normalize; the epsilon avoids 0/0 for classes with no samples
  cm_normalized = cm.astype('float') / (cm.sum(axis=1, keepdims=True) + 1e-6)

  cm_path = os.path.join(output_dir, "confusion_matrix_with_fp_fn_styled.png")
  fig = plt.figure(figsize=(12, 10))
  try:
    ax = sns.heatmap(cm_normalized, annot=False, fmt='', cmap="Blues",
                     xticklabels=class_names, yticklabels=class_names,
                     cbar_kws={'label': 'Score'})

    # Draw the score and, one line below it, the count in each cell
    for i in range(cm.shape[0]):
      for j in range(cm.shape[1]):
        ax.text(j + 0.5, i + 0.5, f"{cm_normalized[i, j]:.2f}",
                ha="center", va="center", fontsize=10, color="black", weight="bold")
        ax.text(j + 0.5, i + 0.5, f"\n({cm[i, j]})",
                ha="center", va="center", fontsize=8, color="gray")

    # FN[k] belongs to true-label row k (right margin); FP[k] belongs to
    # predicted-label column k (bottom margin).
    for k in range(len(class_names)):
      ax.text(cm.shape[1] + 0.5, k + 0.5, f"FN: {fn[k]}", va='center', ha='left', fontsize=10, color='red')
      ax.text(k + 0.5, cm.shape[0] + 0.5, f"FP: {fp[k]}", va='center', ha='center', fontsize=10, color='red')

    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix with Scores, Counts, FP, FN")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)

  # Write per-class FP and FN counts for further analysis
  fp_fn_report = os.path.join(output_dir, "fp_fn_report_0.txt")
  with open(fp_fn_report, 'w') as f:
    for idx, class_name in enumerate(class_names):
      f.write(f"Class: {class_name}, FP: {fp[idx]}, FN: {fn[idx]}\n")
  log.info(f"Confusion matrix with styled counts saved to {cm_path}")
  log.info(f"FP and FN report saved to {fp_fn_report}")


def save_confusion_matrix_with_fp_fn(logits, labels, class_names, output_dir):
  """
  Plot a row-normalized confusion matrix annotated with counts, scores and FP/FN margins.

  Diagonal cells are prefixed with "TP:". Per-class FN counts are written in
  the right margin next to each true-label row, and per-class FP counts in
  the bottom margin under each predicted-label column. Also writes a
  per-class FP/FN text report.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing directory receiving
      "confusion_matrix_with_fp_fn.png" and "fp_fn_report_1.txt".
  """
  # Compute confusion matrix (rows = true, columns = predicted)
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=range(len(class_names)))

  # Only the FP/FN counts are used here
  _, fp, fn, _, _, _, _, _ = calculate_fp_fn(cm, predictions, labels)

  # Row-normalize; the epsilon avoids 0/0 for classes with no samples
  cm_normalized = cm.astype('float') / (cm.sum(axis=1, keepdims=True) + 1e-6)

  # "count\nscore" per cell, with a "TP:" prefix on the diagonal
  annot = np.empty_like(cm, dtype=object)
  for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
      prefix = "TP: " if i == j else ""
      annot[i, j] = f"{prefix}{cm[i, j]}\n{cm_normalized[i, j]:.2f}"

  cm_path = os.path.join(output_dir, "confusion_matrix_with_fp_fn.png")
  fig, ax = plt.subplots(figsize=(12, 10))
  try:
    sns.heatmap(cm_normalized, annot=annot, fmt='', cmap="Blues", ax=ax,
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Score'})

    # FN[k] (row sum minus diagonal) belongs to true-label row k: right margin.
    # FP[k] (column sum minus diagonal) belongs to predicted-label column k: bottom margin.
    for k in range(len(class_names)):
      ax.text(cm.shape[1] + 0.5, k + 0.5, f"FN: {fn[k]}", va='center', ha='left', fontsize=10, color='red')
      ax.text(k + 0.5, cm.shape[0] + 0.5, f"FP: {fp[k]}", va='center', ha='center', fontsize=10, color='red')

    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix with Scores, FP, FN")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)

  # Write per-class FP and FN counts for further analysis
  fp_fn_report = os.path.join(output_dir, "fp_fn_report_1.txt")
  with open(fp_fn_report, 'w') as f:
    for idx, class_name in enumerate(class_names):
      f.write(f"Class: {class_name}, FP: {fp[idx]}, FN: {fn[idx]}\n")
  log.info(f"Confusion matrix with FP and FN saved to {cm_path}")
  log.info(f"FP and FN report saved to {fp_fn_report}")


def save_confusion_matrix_with_scores(logits, labels, class_names, output_dir):
  """
  Plot a row-normalized confusion matrix annotated with scores only, plus an FP/FN report.

  Cells show the row-normalized score with two decimals; raw counts are not
  drawn. Per-class FP/FN counts are written to a separate text report.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Tensor of ground-truth class ids, one per sample.
    class_names: Sequence of display names indexed by class id.
    output_dir: Existing directory receiving
      "confusion_matrix_scores.png" and "fp_fn_report_2.txt".
  """
  ## Compute confusion matrix (rows = true, columns = predicted)
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  labels = labels.cpu().numpy()
  cm = confusion_matrix(labels, predictions, labels=range(len(class_names)))

  ## Only the FP/FN counts are used here (for the text report)
  _, fp, fn, _, _, _, _, _ = calculate_fp_fn(cm, predictions, labels)

  ## Row-normalize; the epsilon avoids 0/0 for classes with no samples
  cm_normalized = cm.astype('float') / (cm.sum(axis=1, keepdims=True) + 1e-6)

  ## Score-only annotations
  annot = np.empty_like(cm, dtype=object)
  for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
      annot[i, j] = f"{cm_normalized[i, j]:.2f}"

  ## Plot and save; always release the figure, even if saving fails
  cm_path = os.path.join(output_dir, "confusion_matrix_scores.png")
  fig = plt.figure(figsize=(12, 10))
  try:
    sns.heatmap(cm_normalized, annot=annot, fmt='', cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix with Scores")
    plt.xticks(rotation=45, ha='right')
    ## Adjust layout to prevent label clipping
    plt.tight_layout()
    plt.savefig(cm_path, bbox_inches="tight")
  finally:
    plt.close(fig)

  ## Write per-class FP and FN counts for further analysis
  fp_fn_report = os.path.join(output_dir, "fp_fn_report_2.txt")
  with open(fp_fn_report, 'w') as f:
    for idx, class_name in enumerate(class_names):
      f.write(f"Class: {class_name}, FP: {fp[idx]}, FN: {fn[idx]}\n")
  log.info(f"Confusion matrix with scores saved to {cm_path}")
  log.info(f"FP and FN report saved to {fp_fn_report}")


def calculate_metrics(logits, labels, index_to_label):
  """
  Compute the confusion matrix, the sklearn classification report and per-class TP/FP/FN/TN.

  Args:
    logits: Tensor of per-class scores; the arg-max over dim 1 is the prediction.
    labels: Ground-truth class ids (tensor, array or list), one per sample.
    index_to_label: Mapping from class id ``0 .. len(index_to_label) - 1`` to class name.

  Returns:
    cm: Confusion matrix (rows: true, columns: predicted) over all classes.
    classification_report: ``sklearn.metrics.classification_report`` output
      as a dict, keyed by class name.
    tp_fp_fn: dict with "TP", "FP", "FN" and "TN" arrays, one entry per class.
  """
  class_ids = list(range(len(index_to_label)))

  # Move to CPU numpy: sklearn cannot read tensors that live on an accelerator.
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  if torch.is_tensor(labels):
    labels = labels.cpu().numpy()

  ## Map indices to actual class names
  label_names = [index_to_label[idx] for idx in class_ids]
  cm = metrics.confusion_matrix(labels, predictions, labels=class_ids)
  # Pass `labels` explicitly so the report covers every class; without it
  # sklearn raises when a class is absent from both labels and predictions
  # (target_names would be longer than the observed label set).
  classification_report = metrics.classification_report(
    labels, predictions, labels=class_ids, target_names=label_names,
    output_dict=True, zero_division=0,
  )

  tp = cm.diagonal()
  fp = cm.sum(axis=0) - tp
  fn = cm.sum(axis=1) - tp
  tp_fp_fn = {
    "TP": tp,
    "FP": fp,
    "FN": fn,
    # Samples in neither row c nor column c of the matrix
    "TN": cm.sum() - (tp + fp + fn),
  }

  return cm, classification_report, tp_fp_fn


def save_confusion_matrix(cm, class_names, output_dir):
  """
  Render ``cm`` as an annotated heatmap and save it as "confusion_matrix.png".

  Args:
    cm: Confusion matrix with true labels on the rows and predictions on the
      columns; cells are formatted with ``fmt='d'``, so values must be integers.
    class_names: Tick labels used for both axes.
    output_dir: Existing directory receiving the image.

  Returns:
    Path of the saved PNG, as a string.
  """
  fig, ax = plt.subplots(figsize=(12, 10))
  sns.heatmap(cm, ax=ax, annot=True, fmt='d', cmap="Blues",
              xticklabels=class_names, yticklabels=class_names)

  # Slant the x tick labels so long class names do not collide; keep y upright.
  plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
  plt.setp(ax.get_yticklabels(), rotation=0)
  ax.set_xlabel("Predicted Label")
  ax.set_ylabel("True Label")
  ax.set_title("Confusion Matrix")

  # Leave room around the axes for the tick labels.
  fig.subplots_adjust(left=0.15, right=0.85, top=0.92, bottom=0.15)

  image_path = os.path.join(output_dir, "confusion_matrix.png")
  fig.savefig(image_path, bbox_inches="tight")
  plt.close(fig)
  log.info(f"Confusion matrix saved to {image_path}")
  return str(image_path)


def plot_curves(logits, labels, class_names, output_dir):
  """
  Plot one-vs-rest Precision-Recall curves for every class on a single figure.

  Args:
    logits: Tensor of per-class scores; softmax over dim 1 gives the scores
      fed to the curves.
    labels: Ground-truth class ids (tensor, array or list), one per sample.
    class_names: Display names indexed by class id; also used as legend entries.
    output_dir: Existing directory receiving "aggregated_pr_curve.png".
  """
  # detach + cpu so tensors that require grad or live on a GPU can be
  # converted to numpy.
  probabilities = F.softmax(logits.detach(), dim=1).cpu().numpy()
  true_ids = labels.detach().cpu().numpy() if torch.is_tensor(labels) else np.asarray(labels)

  pr_curve_path = os.path.join(output_dir, "aggregated_pr_curve.png")
  fig = plt.figure(figsize=(12, 8))
  try:
    for i, class_name in enumerate(class_names):
      # One-vs-rest binary target for class i
      precision, recall, _ = precision_recall_curve((true_ids == i).astype(int), probabilities[:, i])
      plt.plot(recall, precision, label=class_name)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Aggregated Precision-Recall Curve')

    ## Move legend to the top right outside the plot area to prevent overlap
    plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1), ncol=1, fontsize='small')
    plt.savefig(pr_curve_path, bbox_inches="tight")
  finally:
    plt.close(fig)
  log.info(f"Aggregated PR curve saved to {pr_curve_path}")


def plot_and_save_curves(logits, labels, class_names, output_dir):
  """
  Save one PNG per class overlaying its one-vs-rest PR and ROC curves.

  Both curves share the axes: x is recall (PR) or FPR (ROC), y is precision
  (PR) or TPR (ROC).

  Args:
    logits: Tensor of per-class scores; softmax over dim 1 gives the scores
      fed to the curves.
    labels: Ground-truth class ids (tensor, array or list), one per sample.
    class_names: Display names indexed by class id; used in titles, legends
      and file names ("<class_name>_pr_roc_curves.png").
    output_dir: Existing directory receiving the images.
  """
  # detach + cpu so tensors that require grad or live on a GPU can be
  # converted to numpy.
  probabilities = F.softmax(logits.detach(), dim=1).cpu().numpy()
  true_ids = labels.detach().cpu().numpy() if torch.is_tensor(labels) else np.asarray(labels)

  for i, class_name in enumerate(class_names):
    binary_target = (true_ids == i).astype(int)  # one-vs-rest target for class i
    precision, recall, _ = metrics.precision_recall_curve(binary_target, probabilities[:, i])
    fpr, tpr, _ = metrics.roc_curve(binary_target, probabilities[:, i])

    # A path separator in a class name would otherwise point savefig at a
    # non-existent sub-directory.
    safe_name = str(class_name).replace(os.sep, "_").replace("/", "_")
    curve_path = os.path.join(output_dir, f"{safe_name}_pr_roc_curves.png")

    fig = plt.figure()
    try:
      plt.plot(recall, precision, label=f'{class_name} PR Curve')
      plt.plot(fpr, tpr, label=f'{class_name} ROC Curve')
      plt.xlabel("Recall / FPR")
      plt.ylabel("Precision / TPR")
      plt.legend()
      plt.title(f'{class_name} PR and ROC Curves')
      plt.savefig(curve_path)
    finally:
      plt.close(fig)
    log.info(f"{class_name} PR and ROC curves saved to {curve_path}")
