import logging
import torch
import os
import lightning as L
from src.utils.data_plotting import plot_all_test_estimation, plot_test_result_table
from src.utils.data_processing import (
    save_test_for_visualisation,
    create_test_visualisation_jupyter,
)


class LogTensorboardCallback(L.Callback):
    """Callback that reports test results through a TensorBoard logger.

    At the end of the test epoch it:

    1. calls ``plot_test_result_table`` with the module's test losses and
       estimated/true batches,
    2. picks a reproducible random subset of sample indices,
    3. calls ``plot_all_test_estimation`` for that subset,
    4. calls ``save_test_for_visualisation`` and
       ``create_test_visualisation_jupyter`` with the directory one level
       above the logger's ``save_dir``.

    Args:
        logger: Object exposing a ``tensorboard`` attribute; that attribute is
            stored as ``self.logger`` and must provide ``save_dir``.
        num_cycles: Maximum number of randomly chosen samples to plot and
            save. Defaults to 10.
        seed: Seed for the private random generator used to choose the
            samples. Defaults to 33.
    """

    def __init__(self, logger, num_cycles: int = 10, seed: int = 33):
        super().__init__()
        self.logger = logger.tensorboard
        self.num_cycles = num_cycles
        self.seed = seed

    def on_test_epoch_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Plot, log and export the test results collected on ``pl_module``.

        Reads ``var_test_losses``, ``test_batch_est``, ``test_batch_true``,
        ``K_loss_test``, ``K_loss_test_val``, ``estimated_variables_dict`` and
        ``estimated_variables_list`` from ``pl_module``, and
        ``datamodule.data_test.Subs`` and ``default_root_dir`` from
        ``trainer``.
        """
        plot_test_result_table(
            pl_module.var_test_losses,
            pl_module.test_batch_est,
            pl_module.test_batch_true,
            pl_module.K_loss_test,
            variableDict=pl_module.estimated_variables_dict,
            variableList=pl_module.estimated_variables_list,
            logger=self.logger,
        )

        # Choose the random samples to plot on tensorboard and save.
        # A private generator is used instead of torch.manual_seed() so that
        # selecting plot indices does not reseed the global RNG for the rest
        # of the run; the selection stays reproducible for a given seed.
        generator = torch.Generator(device="cpu")
        generator.manual_seed(self.seed)
        # Any entry works for sizing: only the first dimension of the first
        # value in the dict is used as the number of available samples.
        first_tensor = next(iter(pl_module.test_batch_est.values()))
        indices = torch.randperm(
            first_tensor.shape[0], generator=generator, device="cpu"
        )[: self.num_cycles]

        # plotting the selected test data estimations to tensorboard
        logging.info("Plotting the test plots on tensorboard...")
        plot_all_test_estimation(
            trainer.datamodule.data_test.Subs,
            pl_module.test_batch_est,
            pl_module.test_batch_true,
            pl_module.K_loss_test,
            pl_module.K_loss_test_val,
            variableDict=pl_module.estimated_variables_dict,
            variableList=pl_module.estimated_variables_list,
            indices=indices,
            dir=trainer.default_root_dir,
            logger=self.logger,
        )

        # saving the test data for jupyter visualisation
        logging.info("Saving test data for visualisation...")
        # saving one directory above the logger save_dir
        save_dir = os.path.abspath(os.path.join(self.logger.save_dir, ".."))

        save_test_for_visualisation(
            pl_module.test_batch_est,
            pl_module.test_batch_true,
            indices,
            save_dir,
        )

        # creating jupyter notebook for test data visualisation
        create_test_visualisation_jupyter(save_dir)

        return super().on_test_epoch_end(trainer, pl_module)


# class LogWandbCallback(Callback):
#     def __init__(self, logger):
#         self.logger = logger.wandb1

#     def on_test_epoch_end(
#         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
#     ) -> None:
#         plot_test_result_table(
#             pl_module.var_test_losses,
#             pl_module.y_hat_test,
#             pl_module.y_val_test,
#             self.logger,
#         )

#         # choosing 10 random cycles for plotting on wandb and saving
#         numCycle = 10
#         torch.manual_seed(48)
#         indices = torch.randperm(pl_module.y_hat_test.size(0))[:numCycle]

#         # plotting the selected test data estimations to tensorboard
#         logging.log(level=logging.INFO, msg="Plotting the test plots on wandb...")
#         plot_all_test_estimation(
#             trainer.datamodule.data_test.Subs,
#             pl_module.y_hat_test,
#             pl_module.y_val_test,
#             pl_module.K_loss_test,
#             pl_module.K_loss_test_val,
#             indices=indices,
#             dir=trainer.default_root_dir,
#             logger=self.logger,
#         )

#         # saving the test data for jupyter visualisation
#         logging.log(level=logging.INFO, msg="Saving test data for visualisation...")
#         save_test_for_visualisation(
#             pl_module.y_hat_test,
#             pl_module.y_val_test,
#             pl_module.y_left_test,
#             pl_module.constants_test,
#             pl_module.hip_data_test,
#             pl_module.deriv_data_test,
#             pl_module.add_data_test,
#             indices,
#             self.logger.save_dir,
#         )

#         # creating jupyter notebook for test data visualisation
#         create_test_visualisation_jupyter(self.logger.save_dir)

#         return super().on_test_epoch_end(trainer, pl_module)
