"""
This module contains the Logger class which is used to log training information.
"""
import copy
import io
import os
from datetime import datetime
from typing import Optional

import numpy as np
import torch.nn as nn
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm

from src.utils.path_io import get_path_up_to
from src.utils.vis_training import plot_confusion_matrix, plot_continous_confussion_matrix

# Process-wide flag; Logger.log_model_path sets it to False after writing a model path.
first_name_logging: bool = True
# Base directory for TensorBoard runs (Logger joins "runs/..." onto it). Presumably this file's
# path truncated at its "src" component — confirm against src.utils.path_io.get_path_up_to.
ROOT_PATH = get_path_up_to(__file__, "src")


class Logger:
    """
    A class used to log training information.

    Logs training information to the console (via ``tqdm.write``) and to a TensorBoard log
    directory located at ``ROOT_PATH/runs/<time>_<name>_fold<fold>``. It records the start and end
    times of the training, per-epoch losses, scores, performance values and measures, the learning
    rate, the model architecture and confusion-matrix images. Losses and scores are also kept in
    memory so that the values of the optimal epoch can be summarised when the logger is closed.
    """

    # Maps the accepted ``set`` names of ``log_loss`` to the attribute storing the values.
    # The prefixed names are the original convention; the unprefixed aliases make the
    # documented default ('train') actually get stored instead of being silently dropped.
    _LOSS_ATTRIBUTES = {
        '1_train': 'train_loss',
        'train': 'train_loss',
        '2_validation': 'val_loss',
        'validation': 'val_loss',
        '3_test': 'test_loss',
        'test': 'test_loss',
    }

    def __init__(self, current_time_string: str, fold: int = 0, name: str = "") -> None:
        """
        Initializes the Logger.

        Creates a TensorBoard writer whose log directory is
        ``ROOT_PATH/runs/<current_time_string>_<name>_fold<fold>`` and resets all in-memory value
        stores. The training start time is initialised to None.

        Args:
            current_time_string (str): Timestamp string used as prefix of the run directory.
            fold (int): Fold index appended to the run directory name. Defaults to 0.
            name (str): Run name inserted into the directory name. Defaults to "".
        """
        self.start_time_str = current_time_string
        self.name = f"runs/{current_time_string}_{name}_fold{fold}"
        self._summary_writer = SummaryWriter(os.path.join(ROOT_PATH, self.name))
        self._training_start: Optional[datetime] = None
        self._sr_closed = False
        # Per-instance flag so every run (e.g. every fold) records the model path once.
        self._model_path_logged = False

        # Per-epoch histories: {epoch: value}
        self.train_loss = {}
        self.val_loss = {}
        self.test_loss = {}
        # {score: {class_label: {epoch: value}}}
        self.scores = {}
        # {name: {epoch: value}}
        self.performance_values = {}
        self.measure_values = {}
        # {tag: text (original object, not stringified)}
        self.text = {}
        self.optimal_epoch = 0

    def write_text(self, tag: str, text: str) -> None:
        """
        Writes a custom text to a custom tag in the TensorBoard log file.

        The text is stringified for TensorBoard; the original object is kept in ``self.text``.

        Args:
            tag (str): The tag for the text.
            text (str): The text to write.
        """
        self._summary_writer.add_text(tag, str(text))
        self.text[tag] = text

    def write_dict(self, data: dict, name: str = 'config') -> None:
        """
        Logs every first-level item of a dictionary as a separate text entry.

        Each entry is written under the tag ``<name>_<key>``. A deep copy of the value is stored
        in ``self.text`` so later mutation of ``data`` by the caller does not affect it.

        Args:
            data (dict): The dictionary to log (e.g. the training configuration).
            name (str): Prefix used for the tags. Defaults to 'config'.
        """
        data = copy.deepcopy(data)
        for key, value in data.items():
            tag = f'{name}_{key}'
            self.write_text(tag, str(value))
            # Overwrite the stringified value stored by write_text with the raw (copied) object.
            self.text[tag] = value

    def write_model(self, model: nn.Module) -> None:
        """
        Writes the model architecture formatted as text to the TensorBoard log file.

        Args:
            model (nn.Module): The model.
        """
        self._summary_writer.add_text("model/model_class", str(model))
        self.model = str(model)

    def write_training_start(self) -> None:
        """
        Logs the start time of the training.

        Writes a message to the console, sets the training start time to the current time and
        writes it to the TensorBoard log file.
        """
        tqdm.write("[LOGGER]: Training was started.")

        self._training_start = datetime.now()
        self._summary_writer.add_text(
            "training_duration/start_time", str(self._training_start))

    def write_training_end(self, reason: str) -> None:
        """
        Logs the end time of the training and the reason the training finished.

        Calculates the training duration (None if ``write_training_start`` was never called),
        writes a message to the console, writes end time, duration and finish reason to the
        TensorBoard log file and closes the TensorBoard writer.

        Args:
            reason (str): The reason the training finished.
        """
        training_end = datetime.now()

        if self._training_start is None:
            training_duration = None
        else:
            training_duration = training_end - self._training_start
        tqdm.write(
            f"[LOGGER]: Training finished with a runtime of {training_duration}. Finish reason: {reason}")

        self._summary_writer.add_text(
            "training_duration/end_time", str(training_end))
        self._summary_writer.add_text("training_duration/duration",
                                      str(training_duration))
        self._summary_writer.add_text("training_duration/reason", reason)

        tqdm.write("[LOGGER]: Closing logger.")
        self._summary_writer.close()

    def log_loss(self, value: float, epoch: int, set: str = 'train') -> None:
        """
        Logs the loss for an epoch.

        Writes the loss to the TensorBoard tag ``loss/<set>``. Values for the train, validation
        and test sets are also kept in memory; both the prefixed names ('1_train',
        '2_validation', '3_test') and the plain names ('train', 'validation', 'test') are
        recognised. Other set names are only written to TensorBoard.

        Args:
            value (float): The loss.
            epoch (int): The epoch number.
            set (str): The set for which the loss is logged. Defaults to 'train'.
        """
        self._summary_writer.add_scalar(f"loss/{set}", value, epoch)

        attribute = self._LOSS_ATTRIBUTES.get(set)
        if attribute is not None:
            getattr(self, attribute)[epoch] = value

    def log_test_score(self, value: float, epoch: int, score: str, class_label: str = '0_total') -> None:
        """
        Logs a score (e.g. accuracy) for an epoch.

        Writes the value to the TensorBoard tag ``<score>/<class_label>`` and stores it in
        ``self.scores[score][class_label][epoch]``.

        Args:
            value (float): The score value.
            epoch (int): The epoch number.
            score (str): Name of the score (e.g. 'accuracy').
            class_label (str): Class the score refers to. Defaults to '0_total'.
        """
        self._summary_writer.add_scalar(f"{score}/{class_label}", value, epoch)
        self.scores.setdefault(score, {}).setdefault(class_label, {})[epoch] = value

    def log_performance(self, value: float, epoch: int, name: str) -> None:
        """
        Logs a performance value (e.g. epoch time) for an epoch.

        Writes the value to the TensorBoard tag ``performance/<name>`` and stores it in
        ``self.performance_values[name][epoch]``.

        Args:
            value (float): The value to log.
            epoch (int): The epoch number.
            name (str): Name of the value (e.g. epoch_time).
        """
        self._summary_writer.add_scalar(f"performance/{name}", value, epoch)
        self.performance_values.setdefault(name, {})[epoch] = value

    def log_measure(self, value: float, epoch: int, name: str) -> None:
        """
        Logs a custom measure for an epoch.

        Writes the value to the TensorBoard tag ``measures/<name>`` and stores it in
        ``self.measure_values[name][epoch]``.

        Args:
            value (float): The value to log.
            epoch (int): The epoch number.
            name (str): Name of the measure.
        """
        self._summary_writer.add_scalar(f"measures/{name}", value, epoch)
        self.measure_values.setdefault(name, {})[epoch] = value

    def log_lr(self, lr: float, epoch: int) -> None:
        """
        Logs the learning rate for an epoch to the TensorBoard tag ``lr``.

        Args:
            lr (float): The learning rate.
            epoch (int): The epoch number.
        """
        self._summary_writer.add_scalar("lr", lr, epoch)

    def log_model_path(self, model_path: str) -> None:
        """
        Logs the path of the saved model.

        The path is written only on the first call for this logger instance, so repeated
        checkpoint saves do not create duplicate entries while every run (e.g. every fold)
        still records it.

        Args:
            model_path (str): Path of the saved model.

        Returns: None
        """
        global first_name_logging

        if not self._model_path_logged:
            # NOTE: the tag keeps its original spelling ("mode_path") to stay consistent with
            # runs written by earlier versions of this logger.
            self._summary_writer.add_text("model/mode_path", model_path)
            self._model_path_logged = True
            # Module flag kept in sync for any code that still inspects it.
            first_name_logging = False

    def set_optimal_epoch(self, epoch: int) -> None:
        """
        Sets the optimal epoch whose values are summarised on close.

        Args:
            epoch (int): The optimal epoch of the training.
        """
        self.optimal_epoch = epoch
        self._summary_writer.add_text(tag="optimal_epoch", text_string=str(epoch))

    def close_sr(self) -> None:
        """
        Writes the optimal-epoch summary to TensorBoard and closes the writer.

        For every score/class label that has a value at ``self.optimal_epoch`` a text entry
        ``99_score_<score>_<class_label>`` is written, followed by the train and validation losses
        of that epoch ("None" when not logged). Subsequent calls do nothing.
        """
        if self._sr_closed:
            return

        epoch = self.optimal_epoch
        for score, class_scores in self.scores.items():
            for class_label, epoch_scores in class_scores.items():
                if epoch in epoch_scores:
                    self._summary_writer.add_text(
                        f"99_score_{score}_{class_label}", str(epoch_scores[epoch]), epoch)
        self._summary_writer.add_text("loss_train", str(self.train_loss.get(epoch, None)), epoch)
        self._summary_writer.add_text("loss_val_loss", str(self.val_loss.get(epoch, None)), epoch)

        self._summary_writer.close()
        self._sr_closed = True

    def close(self) -> None:
        """
        Finalises the logger after training.

        Calls ``close_sr`` (optimal-epoch summary + closing the TensorBoard writer), stores the
        losses and scores of ``self.optimal_epoch`` in ``self.best_results`` (None for missing
        values) and then clears the loss and score histories. Performance values and measures
        are kept.
        """
        self.close_sr()

        epoch = self.optimal_epoch
        self.best_results = {
            "train_loss": self.train_loss.get(epoch, None),
            "val_loss": self.val_loss.get(epoch, None),
            "test_loss": self.test_loss.get(epoch, None),
            "scores": {
                score: {
                    class_label: epoch_scores.get(epoch, None)
                    for class_label, epoch_scores in class_scores.items()
                }
                for score, class_scores in self.scores.items()
            }
        }

        self.scores = {}
        self.train_loss = {}
        self.val_loss = {}
        self.test_loss = {}

    def save_confusion_matrix(self,
                              y_true: np.ndarray,
                              y_pred: np.ndarray,
                              labels: list[str],
                              epoch: int,
                              continuous: bool = False,
                              set: str = 'Test') -> None:
        """
        Plots a confusion matrix and saves it as an image to the TensorBoard log file.

        The image is written under the tag ``image/cm_<set>``. All matplotlib figures are closed
        afterwards, also when plotting or encoding fails.

        Args:
            y_true (np.ndarray): Targets of the model.
            y_pred (np.ndarray): Predictions of the model.
            labels (list[str]): List of class labels.
            epoch (int): Epoch in which the predictions were made.
            continuous (bool): Use the continuous confusion-matrix plot. Defaults to False.
            set (str): Name of the evaluated set, used in title and tag. Defaults to 'Test'.

        Returns: None
        """
        title = f'Confusion Matrix: {set} Set'
        cmap = 'Blues'
        plot_fn = plot_continous_confussion_matrix if continuous else plot_confusion_matrix

        try:
            fig = plot_fn(y_true, y_pred, labels, title, cmap)

            os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

            # Save the returned figure when possible; otherwise fall back to pyplot's current one.
            target = fig if hasattr(fig, "savefig") else plt
            with io.BytesIO() as buf:
                target.savefig(buf, format='png')
                buf.seek(0)
                with Image.open(buf) as image:
                    image_tensor = ToTensor()(image)
            self._summary_writer.add_image(f"image/cm_{set}", image_tensor, epoch)

            print(f"[Logger]: Chart for {set} set saved.")
        finally:
            plt.close('all')



