import os
import glob
import torch


def save_checkpoint(args, logger, model, optimizer, epoch, checkpoint_directory):
    """Save model weights, optimizer state(s) and training histories for ``epoch``.

    The file is written to ``checkpoint_directory/checkpoint_<epoch:04d>.pth``.
    What is stored besides ``epoch``, ``model_state_dict`` and ``args`` depends
    on the training mode:

    * ``args.supervised``: ``optimizer_state_dict``, ``loss_history`` and
      ``evaluation_history``.
    * ``args.model_name == "SDMI"``: ``optimizer`` is unpacked as an
      ``(optimizer_E, optimizer_M)`` pair; both optimizer states plus the E/M
      loss and linear-probing histories are stored.
    * otherwise: ``optimizer_state_dict``, ``loss_history`` and
      ``linear_probing_history``.

    Args:
        args: Namespace; ``supervised`` and ``model_name`` select the mode and
            ``vars(args)`` is stored under ``"args"``.
        logger: Logger used to report the saved path.
        model: Module providing ``state_dict()`` and the history attributes.
        optimizer: A single optimizer, or an ``(optimizer_E, optimizer_M)`` pair
            in SDMI mode.
        epoch: Epoch number stored in the file and used in its name.
        checkpoint_directory: Existing directory to write into.

    Raises:
        Whatever ``torch.save`` / ``os.replace`` raise. In that case no partial
        ``checkpoint_*.pth`` file is left behind.
    """
    checkpoint_dict = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        # NOTE(review): load_checkpoint reads this file with weights_only=True,
        # which only unpickles primitive containers and tensors. Confirm every
        # value in vars(args) is such a type (e.g. no pathlib.Path).
        "args": vars(args),
    }

    if args.supervised:
        checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict()
        checkpoint_dict["loss_history"] = model.loss_history
        checkpoint_dict["evaluation_history"] = model.evaluation_history

    elif args.model_name == "SDMI":
        optimizer_E, optimizer_M = optimizer
        checkpoint_dict["optimizer_E_state_dict"] = optimizer_E.state_dict()
        checkpoint_dict["optimizer_M_state_dict"] = optimizer_M.state_dict()
        checkpoint_dict["E_loss_history"] = model.E_loss_history
        checkpoint_dict["M_loss_history"] = model.M_loss_history
        checkpoint_dict["E_linear_probing_history"] = model.E_linear_probing_history
        checkpoint_dict["M_linear_probing_history"] = model.M_linear_probing_history

    else:
        checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict()
        checkpoint_dict["loss_history"] = model.loss_history
        checkpoint_dict["linear_probing_history"] = model.linear_probing_history

    checkpoint_filename = f"checkpoint_{epoch:04d}.pth"
    checkpoint_path = os.path.join(checkpoint_directory, checkpoint_filename)

    # Write under a temporary name, then atomically rename into place. This way
    # an interrupted or failed save never leaves a truncated checkpoint_*.pth
    # that load_checkpoint would later pick as the latest one. The ".tmp"
    # suffix keeps the partial file out of the "checkpoint_*.pth" glob.
    tmp_path = f"{checkpoint_path}.{os.getpid()}.tmp"
    try:
        torch.save(checkpoint_dict, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    logger.info(f"Checkpoint saved at: {checkpoint_path}.")


def _checkpoint_epoch_from_path(path):
    """Return the epoch encoded in a ``checkpoint_<epoch>.pth`` filename, or ``None``.

    ``None`` is returned when the text between ``checkpoint_`` and ``.pth`` is
    not an integer (e.g. ``checkpoint_best.pth``). Such files are therefore
    never mistaken for the most recent training checkpoint.
    """
    filename = os.path.basename(path)
    epoch_text = filename[len("checkpoint_"):-len(".pth")]
    try:
        return int(epoch_text)
    except ValueError:
        return None


def load_checkpoint(args, logger, model, optimizer, checkpoint_directory, device):
    """Restore training state from a checkpoint written by ``save_checkpoint``.

    The checkpoint is ``args.resume`` when that is set. Otherwise it is the file
    in ``checkpoint_directory`` whose ``checkpoint_<epoch>.pth`` name encodes
    the highest epoch. The epoch is taken from the filename rather than the
    file's ctime. On POSIX, ctime is the inode-change time, which copying,
    restoring or chmod-ing checkpoints can reorder.

    Args:
        args: Namespace; reads ``resume``, ``supervised`` and ``model_name``.
        logger: Logger used for progress messages.
        model: Module whose ``state_dict`` and history attributes are restored.
        optimizer: A single optimizer, or an ``(optimizer_E, optimizer_M)`` pair
            when not supervised and ``args.model_name == "SDMI"``.
        checkpoint_directory: Directory searched when ``args.resume`` is unset.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The epoch stored in the loaded checkpoint. Returns ``-1`` when there is
        nothing to resume: no checkpoint was found, or the checkpoint's epoch
        is 0.

    Raises:
        FileNotFoundError: if ``args.resume`` names a missing file.
        KeyError: if the checkpoint lacks an entry required by the current mode.
    """
    if args.resume:
        checkpoint_path = args.resume
        logger.info(f"Explicit resume checkpoint provided: {checkpoint_path}")
    else:
        numbered_checkpoints = []
        for path in glob.glob(os.path.join(checkpoint_directory, "checkpoint_*.pth")):
            file_epoch = _checkpoint_epoch_from_path(path)
            if file_epoch is not None:
                numbered_checkpoints.append((file_epoch, path))

        if not numbered_checkpoints:
            logger.info("No checkpoint found. Training will start from scratch.")
            return -1

        _, checkpoint_path = max(numbered_checkpoints)

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    epoch = checkpoint["epoch"]

    # NOTE(review): an epoch-0 checkpoint is treated as "nothing to resume" and
    # its weights are NOT loaded. Presumably it holds the untrained initial
    # state; confirm.
    if epoch == 0:
        return -1

    model.load_state_dict(checkpoint["model_state_dict"])

    if args.supervised:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        model.loss_history = checkpoint["loss_history"]
        model.evaluation_history = checkpoint["evaluation_history"]

    elif args.model_name == "SDMI":
        optimizer_E, optimizer_M = optimizer
        optimizer_E.load_state_dict(checkpoint["optimizer_E_state_dict"])
        optimizer_M.load_state_dict(checkpoint["optimizer_M_state_dict"])
        model.E_loss_history = checkpoint["E_loss_history"]
        model.M_loss_history = checkpoint["M_loss_history"]
        model.E_linear_probing_history = checkpoint["E_linear_probing_history"]
        model.M_linear_probing_history = checkpoint["M_linear_probing_history"]

    else:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        model.loss_history = checkpoint["loss_history"]
        model.linear_probing_history = checkpoint["linear_probing_history"]

    logger.info(f"Checkpoint loaded from {checkpoint_path}.")

    return epoch


def _extract_submodule_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` whose keys start with ``prefix``, prefix removed.

    Only the leading ``prefix`` is stripped. ``str.replace`` would also delete
    later occurrences inside nested parameter names (e.g. a child module named
    ``xonline_encoder``), producing keys the encoder cannot load.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def save_evaluation_checkpoint(args, logger, model, checkpoint_directory):
    """Save only the encoder weights needed for evaluation.

    Writes ``checkpoint_directory/evaluation_checkpoint.pth``. For
    ``args.model_name == "SDMI"`` the file holds ``E_encoder_state_dict`` and
    ``M_encoder_state_dict``. Otherwise it holds
    ``online_encoder_state_dict``. Each is the sub-dict of
    ``model.state_dict()`` under the matching ``<name>.`` prefix, with the
    prefix stripped, so it can be loaded directly into a standalone encoder.

    Args:
        args: Namespace; only ``model_name`` is read.
        logger: Logger for the saved path and for empty-extraction warnings.
        model: Module whose ``state_dict()`` contains the encoder parameters.
        checkpoint_directory: Existing directory to write into.
    """
    # Compute the state dict once instead of once per extracted encoder.
    state_dict = model.state_dict()

    if args.model_name == "SDMI":
        checkpoint_dict = {
            "E_encoder_state_dict": _extract_submodule_state_dict(state_dict, "E_encoder."),
            "M_encoder_state_dict": _extract_submodule_state_dict(state_dict, "M_encoder."),
        }
    else:
        checkpoint_dict = {
            "online_encoder_state_dict": _extract_submodule_state_dict(state_dict, "online_encoder."),
        }

    # An empty extraction means no key matched the expected prefix, e.g. a
    # wrapped model whose keys start with "module.". Surface it now rather than
    # at load time.
    for entry_name, encoder_state in checkpoint_dict.items():
        if not encoder_state:
            logger.warning(
                f"Evaluation checkpoint entry '{entry_name}' is empty: "
                "no model.state_dict() key matched its encoder prefix."
            )

    checkpoint_filename = "evaluation_checkpoint.pth"
    checkpoint_path = os.path.join(checkpoint_directory, checkpoint_filename)
    torch.save(checkpoint_dict, checkpoint_path)
    logger.info(f"Evaluation checkpoint saved at: {checkpoint_path}.")


def load_evaluation_checkpoint(args, logger, encoder, checkpoint_directory, device):
    """Load encoder weights from ``checkpoint_directory/evaluation_checkpoint.pth``.

    For ``args.model_name == "SDMI"``, ``encoder[0]`` receives
    ``E_encoder_state_dict`` and ``encoder[1]`` receives
    ``M_encoder_state_dict``. For every other model, ``encoder[0]`` receives
    ``online_encoder_state_dict``. Loading uses each module's
    ``load_state_dict`` with its default (strict) behaviour.

    Args:
        args: Namespace; only ``model_name`` is read.
        logger: Logger for the confirmation message.
        encoder: Indexable collection of modules (two for SDMI, at least one
            otherwise).
        checkpoint_directory: Directory containing ``evaluation_checkpoint.pth``.
        device: ``map_location`` passed to ``torch.load``.
    """
    saved = torch.load(
        os.path.join(checkpoint_directory, "evaluation_checkpoint.pth"),
        map_location=device,
        weights_only=True,
    )

    # Single-encoder models: only the online encoder was saved.
    if args.model_name != "SDMI":
        encoder[0].load_state_dict(saved["online_encoder_state_dict"])
        logger.info("Encoder loaded from evaluation checkpoint.")
        return

    encoder[0].load_state_dict(saved["E_encoder_state_dict"])
    encoder[1].load_state_dict(saved["M_encoder_state_dict"])
    logger.info("E and M encoders loaded from evaluation checkpoint.")
