""" Utils Functions """
import os
import torch
import matplotlib.pyplot as plt
import numpy as np

# Apply matplotlib's 'ggplot' style at import time. This updates matplotlib's
# global rcParams, so it also affects figures created elsewhere in the process
# after this module is imported.
plt.style.use('ggplot')

def write_log(log_values, model_name, log_dir="", log_type='loss', type_write='a'):
    """Write one comma-separated row of values to a per-model log file.

    The row is written to ``<log_dir>/<model_name>_<log_type>.txt`` followed
    by a newline.

    Args:
        log_values: Iterable of values forming one row. Each value is
            converted with ``str()`` before joining, so numbers may be passed
            directly.
        model_name: Prefix of the log file name.
        log_dir: Directory to write into; created (with parents) if missing.
            An empty string means the current working directory.
        log_type: Suffix of the log file name (e.g. ``'loss'``).
        type_write: Mode passed to ``open``: ``'a'`` appends (default),
            ``'w'`` overwrites.
    """
    # os.makedirs('') raises FileNotFoundError, so only create a real directory;
    # exist_ok avoids the race between an exists() check and makedirs().
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"{model_name}_{log_type}.txt")
    with open(log_path, type_write, encoding="utf-8") as f:
        f.write(','.join(map(str, log_values)) + "\n")

# Code reference: https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/
class SaveBestModel:
    """
    Save training checkpoints while keeping track of the best model so far.

    Every call writes the current state to ``latest_model.pth`` in
    ``save_dir``. If the current epoch's validation loss is strictly lower
    than the lowest loss seen so far, the same checkpoint is also written to
    ``best_model.pth``.

    Each checkpoint is a dict with keys ``'epoch'`` (1-based),
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss_func'``.
    """
    def __init__(self, save_dir, best_valid_loss=float('inf')):
        """
        Args:
            save_dir: Directory where checkpoints are written.
            best_valid_loss: Initial best loss. ``best_model.pth`` is written
                only for losses strictly below the current best.
        """
        self.save_dir = save_dir
        self.best_valid_loss = best_valid_loss

    def __call__(
        self, epoch_idx, current_valid_loss, model, optimizer, criterion
    ):
        """
        Save the latest checkpoint and, if the loss improved, the best one.

        Args:
            epoch_idx: 0-based index of the finished epoch; stored as
                ``epoch_idx + 1``.
            current_valid_loss: Validation loss for this epoch.
            model: Model whose ``state_dict()`` is saved.
            optimizer: Optimizer whose ``state_dict()`` is saved.
            criterion: Loss function, stored as-is under ``'loss_func'``.
        """
        epoch = epoch_idx + 1

        # Make sure the target directory exists so torch.save does not fail on
        # a fresh run ('' means the current directory and needs no creation).
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)

        best_save_path = os.path.join(self.save_dir, 'best_model.pth')
        latest_save_path = os.path.join(self.save_dir, 'latest_model.pth')

        # Build the checkpoint once and reuse it for both files.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss_func': criterion,
        }
        torch.save(checkpoint, latest_save_path)

        # Strict '<': equal losses (and, for floats, NaN) never replace the best.
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch {epoch}\n")
            torch.save(checkpoint, best_save_path)

def save_accuracy_plots(save_dir, train_accu, valid_accu):
    """
    Plot training and validation accuracy curves and save them to disk.

    The image is written to ``<save_dir>/accuracy.png`` (the directory is not
    created here). The figure is closed afterwards, so repeated calls, e.g.
    once per epoch, do not accumulate open figures.

    Args:
        save_dir: Directory to write the image into.
        train_accu: Sequence of training accuracies, plotted against their
            index (x-axis labelled 'Epochs').
        valid_accu: Sequence of validation accuracies, plotted likewise.
    """
    save_path = os.path.join(save_dir, 'accuracy.png')

    fig = plt.figure(figsize=(10, 7))
    try:
        plt.plot(
            train_accu, color='orange', linestyle='-',
            label='train accuracy'
        )
        plt.plot(
            valid_accu, color='red', linestyle='-',
            label='validation accuracy'
        )
        plt.xlabel('Epochs')
        plt.ylabel('accuracy')
        plt.legend()
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

def save_loss_plots(save_dir, train_loss, valid_loss):
    """
    Plot training and validation loss curves and save them to disk.

    The image is written to ``<save_dir>/loss.png`` (the directory is not
    created here). The figure is closed afterwards, so repeated calls, e.g.
    once per epoch, do not accumulate open figures.

    Args:
        save_dir: Directory to write the image into.
        train_loss: Sequence of training losses, plotted against their index
            (x-axis labelled 'Epochs').
        valid_loss: Sequence of validation losses, plotted likewise.
    """
    save_path = os.path.join(save_dir, 'loss.png')

    fig = plt.figure(figsize=(10, 7))
    try:
        plt.plot(
            train_loss, color='orange', linestyle='-',
            label='train loss'
        )
        plt.plot(
            valid_loss, color='red', linestyle='-',
            label='validation loss'
        )
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

def save_pretrain_loss_plots(save_dir, pretrain_loss):
    """
    Plot the pre-training loss curve and save it to disk.

    The image is written to ``<save_dir>/loss.png`` (the directory is not
    created here). The figure is closed afterwards, so repeated calls do not
    accumulate open figures.

    NOTE(review): save_loss_plots writes to the same 'loss.png' file name, so
    calling both with the same save_dir overwrites one plot with the other.
    Presumably pre-training uses its own directory; confirm with callers.

    Args:
        save_dir: Directory to write the image into.
        pretrain_loss: Sequence of pre-training losses, plotted against their
            index (x-axis labelled 'Epochs').
    """
    save_path = os.path.join(save_dir, 'loss.png')

    fig = plt.figure(figsize=(10, 7))
    try:
        plt.plot(
            pretrain_loss, color='orange', linestyle='-',
            label='pretrain loss'
        )
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    
def save_lr_plots(save_dir, lr_rate):
    """
    Plot the learning-rate schedule and save it to disk.

    The image is written to ``<save_dir>/lr_rate_decay.png`` (the directory is
    not created here). The figure is closed afterwards, so repeated calls do
    not accumulate open figures.

    Args:
        save_dir: Directory to write the image into.
        lr_rate: Sequence of learning-rate values, plotted against their index
            (x-axis labelled 'Epochs').
    """
    save_path = os.path.join(save_dir, 'lr_rate_decay.png')

    fig = plt.figure(figsize=(10, 7))
    try:
        # The legend previously read 'train loss', which mislabelled this curve.
        plt.plot(
            lr_rate, color='blue', linestyle='-',
            label='learning rate'
        )
        plt.xlabel('Epochs')
        plt.ylabel('LR Rate')
        plt.legend()
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

def save_cdf_plots(save_dir, error):
    """
    Plot the empirical CDF of location errors and save it to disk.

    NaN entries are dropped before building both the curve and the 50% / 90%
    percentile markers, so the curve reaches 1.0 and agrees with the markers.
    The image is written to ``<save_dir>/cdf_error_curve.png`` (the directory
    is not created here). The figure is closed afterwards.

    Args:
        save_dir: Directory to write the image into.
        error: Array-like of error values (labelled in metres on the plot).
            Lists are accepted; multi-dimensional input is flattened.

    Raises:
        ValueError: If ``error`` contains no non-NaN values.
    """
    save_path = os.path.join(save_dir, 'cdf_error_curve.png')

    # Convert to a flat float ndarray: boolean-mask indexing below needs an
    # ndarray, and a flat array keeps the curve and the percentiles consistent.
    errors = np.asarray(error, dtype=float).ravel()
    errors_clean = errors[~np.isnan(errors)]
    if errors_clean.size == 0:
        raise ValueError("save_cdf_plots: 'error' contains no non-NaN values")

    sorted_error = np.sort(errors_clean)
    # Empirical CDF: the i-th smallest of n values (1-based) has probability i/n,
    # so the last point is exactly 1.0.
    n = sorted_error.size
    cdf = np.arange(1, n + 1) / n

    # Median (50th percentile) and tail (90th percentile) errors.
    median_error = np.percentile(errors_clean, 50)
    tail_error = np.percentile(errors_clean, 90)

    fig = plt.figure(figsize=(6, 4))
    try:
        plt.plot(sorted_error, cdf, label='CDF of Error', color='blue')
        plt.axvline(median_error, color='r', linestyle='--', label=f'50% = {median_error:.2f} m')
        plt.axvline(tail_error, color='g', linestyle='--', label=f'90% = {tail_error:.2f} m')
        plt.xlabel("location_error (m)")
        plt.ylabel("Cumulative probability")
        plt.title("CDF")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(save_path)  # save the CDF curve image
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)