import os
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch


class Logger:
    """
    Logger class for recording and logging the training history.

    Metrics are kept in memory (one list per metric, appended per logged
    epoch), written to ``run.log`` inside a per-run folder under ``logs/``,
    and can be exported as an ``.npz`` archive, a loss/L2RE plot, model
    predictions and the model ``state_dict``.

    Call :meth:`setup` before any ``log_*`` / ``save_*`` method and
    :meth:`close` when the run is finished.
    """

    # Value stored in the metric histories for epochs whose sub-interval has
    # no reference data (see ``log_train_empty_sub_interval``).
    _NO_REFERENCE = -1

    def __init__(self) -> None:
        self.epoch_number = []
        self.loss_history = []
        self.mae_history = []
        self.mse_history = []
        self.l1re_history = []
        self.l2re_history = []
        # Populated by setup(); defined here so close() is safe before setup().
        self.config = None
        self.log_path = None
        self.fileIO = None

    def _path(self, filename: str) -> str:
        """Return the path of ``filename`` inside this run's log folder."""
        return os.path.join(self.log_path, filename)

    def _write(self, text: str) -> None:
        """Append ``text`` to the run log and flush it immediately."""
        self.fileIO.write(text)
        # Flush every record so the log is complete even if the run crashes.
        self.fileIO.flush()

    def setup(self, config) -> None:
        """
        Setup the logger.

        Creates ``logs/<date-time>_<Problem Name>_<Round Number>_<Total Rounds>_<pid>/``
        and opens ``run.log`` inside it for writing.

        Args:
            config: mapping that must contain the keys ``'Problem Name'``,
                ``'Round Number'`` and ``'Total Rounds'``; it is also what
                ``log_config`` writes out.
        """
        self.config = config

        # The process id keeps folders unique when several runs start
        # within the same second.
        stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        folder_name = (
            f"{stamp}_{config['Problem Name']}_{config['Round Number']}"
            f"_{config['Total Rounds']}_{os.getpid()}"
        )
        log_path = os.path.join("logs", folder_name)
        os.makedirs(log_path, exist_ok=True)

        # Do not leak the handle of a previous setup() call.
        self.close()
        # Trailing "/" kept so callers doing `log_path + name` keep working.
        self.log_path = log_path + "/"
        self.fileIO = open(self._path("run.log"), "w", encoding="utf-8")

    def log_config(self) -> None:
        """Write every ``key: value`` pair of the configuration to the log."""
        lines = ["Configuration:\n"]
        lines.extend(f"{key}: {value}\n" for key, value in self.config.items())
        lines.append("\n")
        self._write("".join(lines))

    def log_precond_time(self, time_elapsed) -> None:
        """Log the time (seconds) spent computing the preconditioner."""
        self._write(f"Computing preconditioner costs: {time_elapsed:.2f}s\n")

    def log_precond_res(self, cond_error, after_cond) -> None:
        """
        Log the preconditioner quality.

        Args:
            cond_error: L2 relative error between P^{-1}b and A^{-1}b.
            after_cond: condition number after preconditioning.
        """
        # The first literal is not an f-string, so "{-1}" is written verbatim.
        self._write("L2 relative error between P^{-1}b and A^{-1}b"
                    f": {cond_error:.2e}\n")
        self._write("Condition number after preconditioning"
                    f": {after_cond:.2e}\n")

    def log_original_cond(self, cond) -> None:
        """Log the condition number before preconditioning."""
        self._write(f"Original condition number: {cond:.2e}\n")

    def log_train_start(self) -> None:
        """Log the header of the training history section."""
        self._write("Training History:\n")

    def log_sub_interval_start(self, sub_interval) -> None:
        """Log the start of sub-time interval ``sub_interval``."""
        self._write(f"Sub-Time Interval {sub_interval}:\n")

    def log_sub_interval_end(self, time_elapsed) -> None:
        """Log the time (seconds) spent on the current sub-time interval."""
        self._write(f"Sub-Time Interval costs: {time_elapsed:.2f}s\n")

    def log_train(self, epoch, loss_val, mae,
        mse, l1re, l2re, time_elapsed) -> None:
        """
        Record one epoch's metrics in memory and write them to the log.
        """
        self.epoch_number.append(epoch)
        self.loss_history.append(loss_val)
        self.mae_history.append(mae)
        self.mse_history.append(mse)
        self.l1re_history.append(l1re)
        self.l2re_history.append(l2re)
        self._write(
            "[Epoch {}] Loss {:.2e} MAE {:.2e} MSE {:.2e} L1RE {:.2e} L2RE {:.2e}\n"
            .format(epoch, loss_val, mae, mse, l1re, l2re)
            + "Time elapsed: {:.2f}s\n".format(time_elapsed))

    def log_train_empty_sub_interval(self, epoch,
        loss_val, time_elapsed) -> None:
        """
        Record one epoch of a sub-interval that has no reference data.

        The error metrics are stored as ``-1`` placeholders so every history
        list stays aligned with ``epoch_number``.
        """
        self.epoch_number.append(epoch)
        self.loss_history.append(loss_val)
        self.mae_history.append(self._NO_REFERENCE)
        self.mse_history.append(self._NO_REFERENCE)
        self.l1re_history.append(self._NO_REFERENCE)
        self.l2re_history.append(self._NO_REFERENCE)
        self._write(
            "[Epoch {}] Loss {:.2e} [No Reference Data]\n".format(epoch, loss_val)
            + "Time elapsed: {:.2f}s\n".format(time_elapsed))

    def log_sub_interval_final_result(self, mae, mse, l1re, l2re) -> None:
        """Log the final error metrics of a sub-time interval."""
        self._write("Final result: MAE {:.2e} MSE {:.2e} L1RE {:.2e} L2RE {:.2e}\n"
            .format(mae, mse, l1re, l2re))

    def log_train_end(self, time_elapsed) -> None:
        """Log the total training time (seconds)."""
        self._write("Training costs: {:.2f}s\n".format(time_elapsed))

    def log_newton_update(self) -> None:
        """Log that the Newton system is being updated."""
        self._write("Updating the Newton system...\n")

    def log_precond_update(self) -> None:
        """Log that the preconditioner is being updated."""
        self._write("Updating the preconditioner...\n")

    def log_info(self, info) -> None:
        """Log a free-form informational message."""
        self._write(f"{info}\n")

    def log_warning(self, warning) -> None:
        """Log a message prefixed with ``[Warning]``."""
        self._write(f"[Warning] {warning}\n")

    def log_error(self, error) -> None:
        """Log a message prefixed with ``[Error]``."""
        self._write(f"[Error] {error}\n")

    def close(self) -> None:
        """Close the log file; safe to call repeatedly or before setup()."""
        file_io = getattr(self, "fileIO", None)
        if file_io is not None and not file_io.closed:
            file_io.close()

    def save_history_data(self) -> None:
        """Save all metric histories to ``history_data.npz``."""
        np.savez(self._path('history_data.npz'),
            epoch_number=self.epoch_number,
            loss_history=self.loss_history,
            mae_history=self.mae_history,
            mse_history=self.mse_history,
            l1re_history=self.l1re_history,
            l2re_history=self.l2re_history)

    def save_prediction(self, test_X, test_Y, model) -> None:
        """
        Evaluate ``model`` on ``test_X`` and save inputs, targets and
        predictions to ``prediction.npz``.

        The forward pass runs in eval mode without gradient tracking; the
        model's previous train/eval mode is restored afterwards.
        """
        was_training = model.training
        model.eval()
        try:
            X = torch.as_tensor(test_X,
                dtype=torch.float32,
                device=next(model.parameters()).device)
            with torch.no_grad():
                pred_Y = model(X).detach().cpu().numpy()
        finally:
            model.train(was_training)
        np.savez(self._path('prediction.npz'),
            test_X=test_X,
            test_Y=test_Y,
            pred_Y=pred_Y)

    def save_sub_interval_prediction(self, test_Y, pred_Y, sub_interval=None) -> None:
        """
        Save targets and predictions of a sub-interval.

        Args:
            test_Y: reference values.
            pred_Y: predicted values.
            sub_interval: optional sub-interval id. When given, the file is
                ``prediction_<sub_interval>.npz`` so earlier sub-intervals are
                not overwritten; otherwise ``prediction.npz`` (legacy name).
        """
        if sub_interval is None:
            filename = 'prediction.npz'
        else:
            filename = f'prediction_{sub_interval}.npz'
        np.savez(self._path(filename),
            test_Y=test_Y,
            pred_Y=pred_Y)

    def save_history_plot(self) -> None:
        """
        Save ``history_plot.png``: loss (left, blue) and L2RE (right, red)
        against epoch number, both on log scales.
        """
        epochs = np.asarray(self.epoch_number)
        # The -1 "no reference data" placeholders cannot be drawn on a log
        # axis; NaN makes matplotlib leave a gap there instead.
        l2re = np.asarray(self.l2re_history, dtype=float)
        l2re = np.where(l2re > 0, l2re, np.nan)

        fig, ax1 = plt.subplots()
        ax1.set_xlabel('Epoch Number')
        ax1.set_ylabel('Loss')
        ax1.set_yscale('log')
        ax1.plot(epochs, self.loss_history, 'b-')
        ax1.tick_params(axis='y', labelcolor='b')
        ax2 = ax1.twinx()
        ax2.set_ylabel('L2RE')
        ax2.set_yscale('log')
        ax2.plot(epochs, l2re, 'r-')
        ax2.tick_params(axis='y', labelcolor='r')
        fig.tight_layout()
        fig.savefig(self._path('history_plot.png'))
        plt.close(fig)

    def save_model(self, model) -> None:
        """Save the model's ``state_dict`` to ``model.pt``."""
        torch.save(model.state_dict(), self._path('model.pt'))
