import joblib
import numpy as np
import pandas as pd
import torch
import os

# save model and embeddings
def save_pickle(operator, embeddings, base_dir):
    """
    Persist a fitted operator and its embeddings under ``base_dir``.

    The operator (e.g. an RF-PHATE instance) is serialized with joblib to
    ``model_operator.pkl``. The embeddings are converted to a DataFrame and
    written to ``embeddings.csv`` without an index column. ``base_dir`` is
    created if it does not already exist.
    """
    os.makedirs(base_dir, exist_ok=True)

    joblib.dump(operator, os.path.join(base_dir, "model_operator.pkl"))

    frame = pd.DataFrame(embeddings)
    frame.to_csv(os.path.join(base_dir, "embeddings.csv"), index=False)


def save_autoencoder(model, embeddings, base_dir):
    """
    Persist an autoencoder's weights and its embeddings under ``base_dir``.

    Only ``model.state_dict()`` is stored (via ``torch.save``) in
    ``ae_checkpoint.ckpt``, so reloading requires an instance of the same
    architecture. The embeddings are written to ``embeddings.csv`` without an
    index column. ``base_dir`` is created if it does not already exist.
    """
    os.makedirs(base_dir, exist_ok=True)

    weights = model.state_dict()
    torch.save(weights, os.path.join(base_dir, "ae_checkpoint.ckpt"))

    frame = pd.DataFrame(embeddings)
    frame.to_csv(os.path.join(base_dir, "embeddings.csv"), index=False)


def save_model_and_embeddings(model, embeddings, base_dir, save_type):
    """
    Save a model and its embeddings to ``base_dir`` in the requested format.

    Args:
        model: The model to save. With ``save_type="pickle"`` the whole object
            is serialized with joblib (e.g. an RF-PHATE operator). With
            ``save_type="checkpoint"`` only ``model.state_dict()`` is saved
            with ``torch.save`` (e.g. an autoencoder).
        embeddings: Array-like embeddings, written to ``embeddings.csv``.
        base_dir: Directory to write into; created if it does not exist.
        save_type: Serialization format, either ``"pickle"`` or
            ``"checkpoint"``.

    Raises:
        ValueError: If ``save_type`` is not one of the supported values.
    """
    if save_type == "pickle":
        save_pickle(model, embeddings, base_dir)
    elif save_type == "checkpoint":
        save_autoencoder(model, embeddings, base_dir)
    else:
        raise ValueError(
            f"Unsupported save type: {save_type!r} "
            "(expected 'pickle' or 'checkpoint')"
        )



# load model and embeddings
def load_pickle(operator_path, embedding_path):
    """
    Read back a joblib-pickled operator and its embeddings CSV.

    Returns:
        A ``(operator, embeddings)`` tuple, where ``embeddings`` is the CSV's
        contents as a NumPy array (the header row is consumed as column names).
    """
    # NOTE(review): joblib.load unpickles arbitrary objects and can execute
    # code; only call this on files from a trusted source.
    restored_operator = joblib.load(operator_path)
    restored_embeddings = pd.read_csv(embedding_path).values
    return restored_operator, restored_embeddings

def load_autoencoder(model, ckpt_path, embedding_path, map_location="cpu"):
    """
    Load autoencoder weights into ``model`` and read the saved embeddings.

    Args:
        model: A ``torch.nn.Module`` with the same architecture as the one
            whose ``state_dict`` was saved; its parameters are overwritten
            in place.
        ckpt_path: Path to a file written by
            ``torch.save(model.state_dict(), ...)``.
        embedding_path: Path to the embeddings CSV (first row is the header).
        map_location: Passed to ``torch.load`` to control where the
            deserialized tensors are placed. Defaults to ``"cpu"`` so that
            checkpoints saved from a GPU model can be loaded on CPU-only
            machines. ``load_state_dict`` then copies the values onto
            whatever device ``model`` already lives on.

    Returns:
        A ``(model, embeddings)`` tuple, where ``model`` is the same object
        that was passed in and ``embeddings`` is a NumPy array.
    """
    # NOTE(review): on torch versions that support it, consider passing
    # weights_only=True here, because full unpickling of untrusted checkpoints
    # can execute code.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(state_dict)

    embeddings = pd.read_csv(embedding_path).values

    return model, embeddings

def load_model_and_embeddings(model, load_paths, save_type):
    """
    Load a model and its embeddings that were saved in the given format.

    Args:
        model: The model instance to load weights into. It is used only when
            ``save_type="checkpoint"``. With ``save_type="pickle"`` it is
            ignored, and the operator is reconstructed from the pickle.
        load_paths: Mapping of file paths. ``"embedding_path"`` is always
            required. ``"operator_path"`` is required for ``"pickle"`` and
            ``"ckpt_path"`` for ``"checkpoint"``.
        save_type: Serialization format, either ``"pickle"`` or
            ``"checkpoint"``.

    Returns:
        A ``(model, embeddings)`` tuple. For ``"pickle"`` the first element is
        the unpickled operator. For ``"checkpoint"`` it is ``model`` with the
        saved weights loaded.

    Raises:
        ValueError: If ``save_type`` is not one of the supported values.
        KeyError: If a path required for ``save_type`` is missing from
            ``load_paths``.
    """
    if save_type == "pickle":
        return load_pickle(load_paths['operator_path'], load_paths['embedding_path'])
    elif save_type == "checkpoint":
        return load_autoencoder(model, load_paths['ckpt_path'], load_paths['embedding_path'])
    else:
        raise ValueError(
            f"Unsupported save type: {save_type!r} "
            "(expected 'pickle' or 'checkpoint')"
        )
