"""tensorboard logger"""
__author__ = 'XYZ'


import os
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter


class TensorBoardLogger:
  """Write training/validation/test metrics to TensorBoard and track the best values.

  The logger wraps a ``SummaryWriter`` and, on every metric call, also writes
  the best value seen so far (and the epoch it was seen at) under ``Best/...``.

  Attributes:
    log_dir: Directory the event files are written to.
    writer: The underlying ``SummaryWriter``.
    step: Batch-level counter used as default step by ``gradients``/``embedding``.
    epoch: Epoch-level counter used as default step by the metric methods.
    best_{train,val,test}_accuracy / best_{train,val,test}_epoch:
      Highest accuracy logged per phase and the epoch it was logged at.
    best_topk_accuracy / best_topk_epoch: ``{phase: {k: value}}`` mappings.
    best_loss / best_loss_epoch: ``{phase: value}`` mappings for the lowest loss.

  The instance can be used as a context manager; the writer is closed on exit.
  """

  # Phases whose best-value bookkeeping is created up front.
  _PHASES = ('Train', 'Val', 'Test')

  def __init__(self, log_dir=None):
    """Create the log directory and open a ``SummaryWriter`` on it.

    Args:
      log_dir: Target directory. When ``None``, ``logs/run_<timestamp>`` is
        used, with the timestamp formatted as ``%Y%m%d-%H%M%S``.
    """
    if log_dir is None:
      timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
      log_dir = f"logs/run_{timestamp}"

    self.log_dir = log_dir
    os.makedirs(log_dir, exist_ok=True)

    self.writer = SummaryWriter(log_dir=log_dir)
    self.step = 0
    self.epoch = 0

    # Best accuracy per phase (kept as individual attributes for callers
    # that read them directly).
    self.best_test_accuracy = 0.0
    self.best_test_epoch = 0
    self.best_val_accuracy = 0.0
    self.best_val_epoch = 0
    self.best_train_accuracy = 0.0
    self.best_train_epoch = 0

    # Best top-k accuracy per phase, keyed by k.
    self.best_topk_accuracy = {phase: {} for phase in self._PHASES}
    self.best_topk_epoch = {phase: {} for phase in self._PHASES}

    # Lowest loss per phase; +inf so the first logged loss always wins.
    self.best_loss = {phase: float('inf') for phase in self._PHASES}
    self.best_loss_epoch = {phase: 0 for phase in self._PHASES}

  def __enter__(self):
    """Return the logger itself so it can be used in a ``with`` statement."""
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Close the writer on leaving the ``with`` block; exceptions propagate."""
    self.close()

  def _log_phase_metrics(self, phase, loss, accuracy, epoch):
    """Log loss/accuracy for ``phase`` and update that phase's best values.

    Args:
      phase: One of ``'Train'``, ``'Val'``, ``'Test'``; used both as the tag
        prefix and (lower-cased) to locate the ``best_<phase>_*`` attributes.
      loss: Scalar loss for the epoch.
      accuracy: Scalar accuracy for the epoch.
      epoch: Step value for the scalars; ``None`` falls back to ``self.epoch``.
    """
    epoch = epoch if epoch is not None else self.epoch

    # Convert to plain floats before tracking: storing the caller's object as
    # the "best" value would keep it alive (for a tensor, together with any
    # autograd graph attached to it).
    loss = float(loss)
    accuracy = float(accuracy)

    self.writer.add_scalar(f"{phase}/Loss", loss, epoch)
    self.writer.add_scalar(f"{phase}/Accuracy", accuracy, epoch)

    # Best accuracy (higher is better).
    accuracy_attr = f"best_{phase.lower()}_accuracy"
    epoch_attr = f"best_{phase.lower()}_epoch"
    if accuracy > getattr(self, accuracy_attr):
      setattr(self, accuracy_attr, accuracy)
      setattr(self, epoch_attr, epoch)

    self.writer.add_scalar(f"Best/{phase}_Accuracy", getattr(self, accuracy_attr), epoch)
    self.writer.add_scalar(f"Best/{phase}_Epoch", getattr(self, epoch_attr), epoch)

    # Best loss (lower is better).
    if loss < self.best_loss[phase]:
      self.best_loss[phase] = loss
      self.best_loss_epoch[phase] = epoch

    self.writer.add_scalar(f"Best/{phase}_Loss", self.best_loss[phase], epoch)
    self.writer.add_scalar(f"Best/{phase}_Loss_Epoch", self.best_loss_epoch[phase], epoch)

  def train_metrics(self, loss, accuracy, epoch=None):
    """
    Logs training loss and accuracy for an epoch.
    Also tracks and logs the best training accuracy / lowest training loss
    and the epochs at which they occurred.
    """
    self._log_phase_metrics('Train', loss, accuracy, epoch)

  def val_metrics(self, loss, accuracy, epoch=None):
    """
    Logs validation loss and accuracy for an epoch.
    Also tracks and logs the best validation accuracy / lowest validation loss
    and the epochs at which they occurred.
    """
    self._log_phase_metrics('Val', loss, accuracy, epoch)

  def test_metrics(self, loss, accuracy, epoch=None):
    """
    Logs test loss and accuracy for an epoch.
    Also tracks and logs the best test accuracy / lowest test loss
    and the epochs at which they occurred.
    """
    self._log_phase_metrics('Test', loss, accuracy, epoch)

  def topk_metrics(self, phase, topk_accuracy: dict, epoch=None):
    """
    Logs Top-K accuracy for a given phase and keeps track of the best per K.

    Args:
      phase: Tag prefix, normally 'Train', 'Val' or 'Test'. Any other phase
        name gets its own best-value bookkeeping on first use.
      topk_accuracy: Mapping ``{k: accuracy}``.
      epoch: Step value for the scalars; ``None`` falls back to ``self.epoch``.
    """
    epoch = epoch if epoch is not None else self.epoch

    # setdefault so a phase outside the predefined three does not raise KeyError.
    best_accuracy = self.best_topk_accuracy.setdefault(phase, {})
    best_epoch = self.best_topk_epoch.setdefault(phase, {})

    for k, acc in topk_accuracy.items():
      acc = float(acc)  # plain float, same reasoning as in _log_phase_metrics

      ## Log current top-k
      self.writer.add_scalar(f"{phase}/Top{k}_Accuracy", acc, epoch)

      ## Compare with previous best (first value for a given k always wins)
      if k not in best_accuracy or acc > best_accuracy[k]:
        best_accuracy[k] = acc
        best_epoch[k] = epoch

      ## Log best top-k accuracy and epoch
      self.writer.add_scalar(f"Best/{phase}_Top{k}_Accuracy", best_accuracy[k], epoch)
      self.writer.add_scalar(f"Best/{phase}_Top{k}_Epoch", best_epoch[k], epoch)

  def gradients(self, model, step=None):
    """Logs a histogram of the gradient of every parameter that has one.

    Args:
      model: Object exposing ``named_parameters()``.
      step: Step value for the histograms; ``None`` falls back to ``self.step``.
    """
    step = step if step is not None else self.step
    for name, param in model.named_parameters():
      # Parameters without a gradient are skipped.
      if param.grad is not None:
        self.writer.add_histogram(f"{name}_grad", param.grad, step)

  def embedding(self, features, labels, metadata=None, tag="Embedding", step=None):
    """Logs embeddings for visualization.

    ``SummaryWriter.add_embedding`` takes two different per-point inputs:
    ``metadata`` (a list of labels) and ``label_img`` (an image tensor of
    shape (N, C, H, W)). ``labels`` is routed to whichever one it matches.

    Args:
      features: Embedding matrix, forwarded unchanged.
      labels: Either a 4-D image batch (forwarded as ``label_img``) or
        per-point labels (used as ``metadata`` when ``metadata`` is None).
        May be ``None``.
      metadata: Explicit metadata; takes precedence over non-image ``labels``.
      tag: Name of the embedding.
      step: Global step; ``None`` falls back to ``self.step``.
    """
    step = step if step is not None else self.step

    label_img = None
    if labels is not None:
      if getattr(labels, "ndim", None) == 4:
        # Image batch: this is what add_embedding's label_img is meant for.
        label_img = labels
      elif metadata is None:
        # Plain labels belong in metadata; passing them as label_img fails.
        metadata = labels.tolist() if hasattr(labels, "tolist") else list(labels)

    self.writer.add_embedding(
      features, metadata=metadata, label_img=label_img, global_step=step, tag=tag
    )

  def increment_step(self):
    """Increments the step counter (useful for batch logging)."""
    self.step += 1

  def increment_epoch(self):
    """Increments the epoch counter (useful for epoch logging)."""
    self.epoch += 1

  def close(self):
    """Closes the TensorBoard writer."""
    self.writer.close()
