import torch
import joblib
import os
from datetime import date


MODEL_BASE_DIR = "model_saves"
# Today's date as an ISO-8601 string (YYYY-MM-DD), evaluated once at import time.
CURRENT_DATE = date.today().isoformat()
# Default per-day output directory: "model_saves/<YYYY-MM-DD>".
CURRENT_DATE_DIR = os.path.join(MODEL_BASE_DIR, CURRENT_DATE)


class EarlyStopper:
    """Track validation loss to decide when to stop training and when a model is the best so far.

    Two notions of "improvement" are tracked separately:

    * ``min_validation_loss`` is the reference loss for the patience counter.
      It is only replaced when a new loss beats it by the relative margin
      given by ``min_delta``.
    * ``best_validation_loss`` is the plain lowest loss ever passed to
      :meth:`early_stop`. :meth:`is_best` uses it, so a loss that is worse than
      an earlier (sub-threshold) improvement is never reported as the best.

    Args:
        patience: Number of consecutive non-improving calls to
            :meth:`early_stop` after which it returns ``True``.
        min_delta: Multiplicative improvement factor, not an absolute delta.
            With the default ``0.995``, a positive loss must drop below 99.5%
            of the reference loss to reset the patience counter.
    """

    def __init__(self, patience: int = 20, min_delta: float = 0.995):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        # Reference loss for the patience counter; moves only on a significant improvement.
        self.min_validation_loss = float("inf")
        # Lowest loss ever seen by early_stop; used by is_best.
        self.best_validation_loss = float("inf")

    def _improvement_threshold(self) -> float:
        """Return the value a new loss must be strictly below to count as a significant improvement."""
        reference = self.min_validation_loss
        if reference >= 0:
            # Original rule; on the first call inf * min_delta stays inf, so any finite loss improves.
            return reference * self.min_delta
        # For a negative reference, multiplying by min_delta (< 1) would move the
        # threshold *above* the reference and accept a worse loss. Apply the same
        # relative margin on the other side instead.
        return reference * (2 - self.min_delta)

    def early_stop(self, validation_loss: float) -> bool:
        """Record ``validation_loss`` and report whether training should stop.

        The patience counter is reset on a significant improvement (see
        ``min_delta``); otherwise it is incremented.

        Returns:
            ``True`` once ``patience`` consecutive calls have passed without a
            significant improvement, ``False`` otherwise.
        """
        if validation_loss < self.best_validation_loss:
            self.best_validation_loss = validation_loss

        if validation_loss < self._improvement_threshold():
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience

    def is_best(self, validation_loss: float) -> bool:
        """Return ``True`` if ``validation_loss`` is lower than every loss seen so far.

        Call this *before* :meth:`early_stop` for the same loss. :meth:`early_stop`
        records the loss, so a later call would compare the loss against itself
        and return ``False``.
        """
        return bool(validation_loss < self.best_validation_loss)


def save_model(model, model_name, current_date_dir=CURRENT_DATE_DIR):
    """Save a model to ``current_date_dir`` under the first unused numbered file name.

    Files are named ``<model_name>_<n><suffix>``. ``n`` starts at 1 and is
    incremented until no file of that name exists, so earlier saves are never
    overwritten.

    Args:
        model: Model to save. Objects whose class is named ``PCA`` are dumped
            with ``joblib`` to a ``.pkl`` file. Anything else has its
            ``state_dict()`` written with ``torch.save`` to a ``.pth`` file.
        model_name: Base name for the artifact file.
        current_date_dir: Target directory; created if it does not exist.
    """
    os.makedirs(current_date_dir, exist_ok=True)

    is_pca = model.__class__.__name__ == "PCA"
    # The suffix carries its own dot, so the file name must not add another.
    file_suffix = ".pkl" if is_pca else ".pth"

    # Probe <name>_1, <name>_2, ... until a free file name is found.
    counter = 1
    model_path = os.path.join(current_date_dir, f"{model_name}_{counter}{file_suffix}")
    while os.path.isfile(model_path):
        counter += 1
        model_path = os.path.join(current_date_dir, f"{model_name}_{counter}{file_suffix}")

    # Finally save the artifact.
    if is_pca:
        joblib.dump(model, model_path)
    else:
        torch.save(model.state_dict(), model_path)


def save_checkpoint(
    model,
    experiment,
    device_id,
    epoch=None,
    base_dir="artifacts/benchmarking",
    best: bool = False,
) -> None:
    """Save a training checkpoint and/or the current best model.

    Files are written to ``<base_dir>/<model_name>/<dataset>/id<device_id><pid>/``.
    ``model_name`` and ``dataset`` are read from ``experiment``. Because the
    process id is part of the path, runs in different processes write to
    different directories.

    Args:
        model: Model to save. Objects whose class is named ``PCA`` are dumped
            with ``joblib`` (``.pkl``). Anything else has its ``state_dict()``
            saved with ``torch.save`` (``.pth``).
        experiment: Object exposing ``model_name`` and ``dataset`` attributes.
        device_id: Identifier prefixed to the process id in the directory name.
        epoch: If not ``None`` (epoch ``0`` included), also write
            ``checkpoint_<model_name>_<dataset>_<epoch>.<suffix>``.
        base_dir: Root directory for all checkpoints.
        best: If ``True``, write (overwriting) ``best_<model_name>_<dataset>.<suffix>``.
    """
    model_name = experiment.model_name
    dataset_name = experiment.dataset
    unique_id = str(device_id) + str(os.getpid())
    checkpoint_dir = f"{base_dir}/{model_name}/{dataset_name}/id{unique_id}/"
    # Create the directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    is_pca = model.__class__.__name__ == "PCA"
    file_suffix = "pkl" if is_pca else "pth"

    def _dump(filename):
        """Write ``model`` into ``checkpoint_dir`` using the format matching ``file_suffix``."""
        path = os.path.join(checkpoint_dir, filename)
        if is_pca:
            joblib.dump(model, path)
        else:
            torch.save(model.state_dict(), path)

    if best:
        _dump(f"best_{model_name}_{dataset_name}.{file_suffix}")
    # Compare against None explicitly: epoch 0 is a real epoch and must not be skipped.
    if epoch is not None:
        _dump(f"checkpoint_{model_name}_{dataset_name}_{epoch}.{file_suffix}")
