import pandas as pd
import numpy as np
import os
import torch
import json
import argparse
import joblib
from datetime import datetime

from .utils import get_folders
from models.data_utils import get_inference_dataloader
from models.architectures.PCA.training import dataloader_to_numpy
from models import DeepCAE, StackedCAE, JointVAE, ConvAE, StandardAE, TransformerAE, PCA
from benchmarking.benchmarking import get_configs


# Default locations of the processed datasets and of the written embeddings
# (relative to the working directory).
DATA_PATH = "artifacts/data"
EMBEDDINGS_PATH = "artifacts/embeddings"

# Model classes processed when no --model-name is passed on the command line.
# TODO: track which models still need a run with the right checkpoints (the
# lists the earlier TODO notes referred to were never filled in).
model_classes = [
    ConvAE,
    PCA,
    DeepCAE,
    StackedCAE,
    StandardAE,
    TransformerAE,
    JointVAE,
]


def perform_model_inference(
    model_class,
    dataset_name: str,
    config: dict = None,
    *,
    batch_size: int = 64,
    dim_reduction: float = 0.505,
    data_path: str = None,
) -> None:
    """
    Perform inference for a certain model and dataset and store the resulting embeddings and their labels.

    The embeddings, together with the dataset's target column(s), are written to
    ``EMBEDDINGS_PATH/<model class name>/<dataset_name>/embeddings.csv``.

    Args:
        model_class: ``PCA`` (restored from a ``.pkl`` file via joblib) or one of the
            torch model classes (restored from a ``.pth`` state dict).
        dataset_name: Name of the dataset folder below ``data_path``; it must contain
            ``processed.csv`` and ``conf.json`` (with a ``"target"`` entry).
        config: Benchmarking config for this model/dataset; the keys read are
            ``hidden_spec``, ``channel_spec`` and ``latent_spec``. ``None`` is treated
            as ``{}``. The config is never mutated.
        batch_size: Batch size of the inference dataloader.
        dim_reduction: Size of the last hidden / latent layer as a fraction of the
            input dimension (rounded to an int).
        data_path: Root folder of the datasets; ``None`` means ``DATA_PATH``.

    Raises:
        FileNotFoundError: If no checkpoint directory or checkpoint file is found.
        ValueError: If the number of embeddings differs from the number of data rows.
    """
    print(
        f"Now working on {model_class.__name__} for {dataset_name} with config: \n{config}"
    )
    if config is None:
        config = {}
    if data_path is None:
        data_path = DATA_PATH

    dataset_dir = os.path.join(data_path, dataset_name)
    processed_csv = os.path.join(dataset_dir, "processed.csv")
    dataloader = get_inference_dataloader(
        batch_size=batch_size,
        path_to_data=processed_csv,
    )

    checkpoint_dir = _select_checkpoint_dir(
        os.path.join("artifacts", "benchmarking", model_class.__name__, dataset_name)
    )
    print("The checkpoint base path: ", checkpoint_dir)

    if model_class is PCA:
        pkl_path = _find_checkpoint_file(checkpoint_dir, ".pkl", model_class, dataset_name)
        # joblib.load unpickles the file: only point this at checkpoints you produced yourself.
        model = joblib.load(pkl_path)
        embeddings = model.transform(dataloader_to_numpy(dataloader))
    else:
        pth_path = _find_checkpoint_file(checkpoint_dir, ".pth", model_class, dataset_name)
        # The model's input width is read off dimension 1 of the first batch.
        input_dim = next(iter(dataloader)).size()[1]
        model_kwargs = _build_model_kwargs(
            config, input_dim, dim_reduction, is_joint_vae=model_class is JointVAE
        )
        model = model_class(**model_kwargs)
        model.load_state_dict(
            torch.load(pth_path, map_location=torch.device("cpu"))
        )
        embeddings = _embed_with_torch_model(model, model_class, dataloader)

    # The dataset config (not the model config) names the label column(s).
    with open(os.path.join(dataset_dir, "conf.json"), "r", encoding="utf-8") as conf_file:
        dataset_conf = json.load(conf_file)
    data_df = pd.read_csv(processed_csv)
    embeddings_df = _embeddings_with_labels(embeddings, data_df, dataset_conf["target"])

    embedding_base_path = os.path.join(EMBEDDINGS_PATH, model_class.__name__, dataset_name)
    os.makedirs(embedding_base_path, exist_ok=True)
    embeddings_df.to_csv(os.path.join(embedding_base_path, "embeddings.csv"), index=False)


def _select_checkpoint_dir(checkpoint_base_path: str) -> str:
    """Return the run directory below ``checkpoint_base_path`` that holds the checkpoint.

    Entries whose own name contains "old" are skipped; presumably they hold
    checkpoints from a different/earlier experiment. Only the entry name is checked,
    so model or dataset names that happen to contain "old" (e.g. "HoldOut") are
    not rejected. Plain files are ignored. Candidates are sorted so the choice is
    deterministic; when several remain, the first one is used and a warning printed.

    Raises:
        FileNotFoundError: If no suitable directory exists.
    """
    candidates = sorted(
        entry
        for entry in os.listdir(checkpoint_base_path)
        if "old" not in entry
        and os.path.isdir(os.path.join(checkpoint_base_path, entry))
    )
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint directory without 'old' in its name found in {checkpoint_base_path}"
        )
    if len(candidates) > 1:
        print(
            f"Warning: several checkpoint directories in {checkpoint_base_path} "
            f"({candidates}); using {candidates[0]!r}."
        )
    return os.path.join(checkpoint_base_path, candidates[0])


def _find_checkpoint_file(checkpoint_dir: str, suffix: str, model_class, dataset_name: str) -> str:
    """Return the path of the alphabetically first file in ``checkpoint_dir`` ending with ``suffix``.

    Raises:
        FileNotFoundError: If no such file exists.
    """
    matches = sorted(name for name in os.listdir(checkpoint_dir) if name.endswith(suffix))
    if not matches:
        raise FileNotFoundError(
            f"No model Checkpoint found for dataset {dataset_name} and model {model_class}"
        )
    return os.path.join(checkpoint_dir, matches[0])


def _build_model_kwargs(config: dict, input_dim: int, dim_reduction: float, is_joint_vae: bool) -> dict:
    """Translate a benchmarking config into constructor kwargs for a torch model.

    The last hidden layer (or ``hidden_dim`` / JointVAE continuous latent size when
    no ``hidden_spec`` is configured) is forced to ``round(input_dim * dim_reduction)``.
    ``config`` and the lists inside it are not mutated.
    """
    bottleneck = round(input_dim * dim_reduction)
    model_kwargs = {"input_dim": input_dim}

    # Copy so the caller's config (which may be reused) is left untouched.
    hidden_spec = list(config.get("hidden_spec") or [])
    if hidden_spec:
        hidden_spec[-1] = bottleneck
        if len(hidden_spec) == 2:
            # For the MultiLayer experiments comparing DeepCAE and StackedCAE:
            # the first layer sits halfway between input and bottleneck.
            hidden_spec[-2] = round((input_dim + bottleneck) / 2)
        model_kwargs["hidden_spec"] = hidden_spec
    else:
        # If there is no hidden spec, it has to be JointVAE and
        # there is a hidden_dim parameter instead.
        model_kwargs["hidden_dim"] = bottleneck

    if channel_spec := config.get("channel_spec", None):
        model_kwargs["channel_spec"] = channel_spec

    if latent_spec := config.get("latent_spec", None):
        model_kwargs["latent_spec"] = latent_spec
    elif is_joint_vae:
        # Per default, no discrete variables; as many continuous latent
        # features as hidden_dim.
        model_kwargs["latent_spec"] = {"cont": bottleneck}

    return model_kwargs


def _embed_with_torch_model(model, model_class, dataloader) -> np.ndarray:
    """Run ``model`` in eval mode over ``dataloader`` and return the stacked embeddings."""
    model.eval()
    chunks = []

    # No gradient calculation is needed during inference.
    with torch.no_grad():
        for batch in dataloader:
            if model_class is TransformerAE:
                # Trailing singleton axis; presumably one token per feature — TODO confirm.
                batch = batch.unsqueeze(-1)
            if model_class is JointVAE or model_class is ConvAE:
                # Singleton axis at position 1, presumably the channel axis.
                batch = batch.unsqueeze(1)

            # The models return (reconstruction, embeddings).
            _, embeddings_batch = model(batch)

            # DeepCAE/StackedCAE also return the intermediate embeddings of hidden layers.
            if model_class is DeepCAE or model_class is StackedCAE:
                embeddings_batch = embeddings_batch[-1]

            # JointVAE returns the latent distribution parameters; take the mean.
            if model_class is JointVAE:
                embeddings_batch = embeddings_batch["cont"][0]

            chunks.append(embeddings_batch.cpu().numpy())

    return np.concatenate(chunks, axis=0)


def _embeddings_with_labels(embeddings, data_df: pd.DataFrame, target_names) -> pd.DataFrame:
    """Wrap ``embeddings`` in a DataFrame with ``emb_dim_<i>`` columns plus the target column(s).

    Labels are attached by index, which assumes the embeddings follow the row order
    of ``data_df`` (i.e. the inference dataloader neither shuffles nor drops rows)
    — TODO confirm in ``get_inference_dataloader``.

    Raises:
        ValueError: If the row counts differ.
    """
    embeddings_df = pd.DataFrame(embeddings)
    embeddings_df.columns = [f"emb_dim_{i + 1}" for i in range(embeddings_df.shape[1])]
    if len(embeddings_df) != len(data_df):
        # Index alignment would otherwise silently drop labels or fill them with NaN.
        raise ValueError(
            f"Got {len(embeddings_df)} embeddings but {len(data_df)} data rows."
        )
    embeddings_df[target_names] = data_df[target_names]
    return embeddings_df


def main(argv=None) -> None:
    """Run embedding inference for every selected (model, dataset) pair.

    Without ``--dataset-name`` all dataset folders under ``--data_path`` are used
    (MNIST excluded); without ``--model-name`` every class in ``model_classes`` is used.

    Args:
        argv: Arguments to parse; ``None`` parses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # "--data-path" is an alias matching the other dash-separated options;
    # the parsed value is still available as ``args.data_path``.
    parser.add_argument("--data_path", "--data-path", type=str, default="artifacts/data")
    parser.add_argument("--dataset-name", nargs="+", help="Dataset name, e.g. TeaRetail", required=False)
    parser.add_argument("--model-name", nargs="+", help="Model name, e.g. DeepCAE", required=False)

    args = parser.parse_args(argv)

    start_time = datetime.now()

    # NOTE: --data_path only drives dataset discovery here; it is not forwarded
    # to perform_model_inference, which reads data from its own default location.
    if args.dataset_name:
        dataset_names = list(args.dataset_name)
    else:
        dataset_names = list(get_folders(args.data_path))
    # MNIST needs special handling and is therefore excluded from this run.
    dataset_names = [name for name in dataset_names if name != "MNIST"]
    print(f"Dataset names: {dataset_names}")

    # Resolve model names through an explicit lookup instead of eval()-ing CLI input.
    known_models = {cls.__name__: cls for cls in model_classes}
    if args.model_name:
        model_names = list(args.model_name)
        unknown = [name for name in model_names if name not in known_models]
        if unknown:
            parser.error(f"Unknown model name(s) {unknown}; choose from {sorted(known_models)}")
        selected_classes = [known_models[name] for name in model_names]
    else:
        selected_classes = list(model_classes)
        model_names = [cls.__name__ for cls in selected_classes]
    print(f"Model names: {model_names}")
    model_configs = get_configs(model_names, dataset_names)

    print('Now starting to perform inference.')
    for model_class in selected_classes:
        for dataset_name in dataset_names:
            config = model_configs[dataset_name].get(model_class.__name__, None)
            perform_model_inference(model_class, dataset_name, config)

    runtime = datetime.now() - start_time
    print(f"Finished inference after a duration of {runtime}!")


if __name__ == "__main__":
    main()
