"""Functions for preprocessing."""

import pandas as pd
import numpy as np
import time
import os

from glob import glob
from src.utils import load_raw_data, load_linked_ent_data, set_split, col_names_per_type
from configs.preprocess_configs import preprocess_configs
from configs.path_configs import path_configs


def serialize_llm(data, target_name=None, prefix=None, lowercase=False):
    """Serialize every row of a table into one natural-language string.

    Each non-missing cell becomes the clause ``"The <column> is <value>"``,
    where newlines are removed from the column name and underscores are
    replaced by spaces. The clauses of a row are joined with ``". "`` and
    terminated with ``"."``. Missing cells contribute no clause at all.

    Parameters
    ----------
    data : pandas.DataFrame
        Table whose cells are strings (missing values are allowed).
    target_name : str, optional
        If given, ``" What is the value of {target_name}?"`` is appended to
        every serialized row.
    prefix : str, optional
        If given, it is prepended to every serialized row, followed by a
        single space.
    lowercase : bool, default=False
        Lowercase the serialized row. This is applied before the target
        question and the prefix are added, so those keep their casing.

    Returns
    -------
    pandas.DataFrame
        One column named ``"serialized"``, indexed like ``data``.
    """

    # Build the per-cell clauses; a missing cell stays missing after the
    # string concatenation, which lets us skip it when joining below.
    labelled = data.copy()
    for col in data.columns:
        col_text = str(col).replace("\n", "").replace("_", " ")
        labelled[col] = f"The {col_text} is " + data[col]

    # Join only the non-missing clauses of each row. Cleaning up separators
    # afterwards with ``str.replace(" .", "")`` depends on the pandas default
    # for ``regex``, leaves a leading ". " when the first column is missing
    # and also alters cell values that themselves contain " .".
    sentences = [
        ". ".join(cell for cell in row if pd.notna(cell)) + "."
        for row in labelled.to_numpy(dtype=object)
    ]
    data_serialized = pd.Series(sentences, index=data.index, dtype=object)

    # Control for lowercase
    if lowercase:
        data_serialized = data_serialized.str.lower()

    # Control for target and prefix
    if target_name is not None:
        data_serialized += f" What is the value of {target_name}?"
    if prefix is not None:
        data_serialized = f"{prefix} " + data_serialized

    return pd.DataFrame({"serialized": data_serialized})


def serialize_kgt5(data, target_name=None, prefix=None, lowercase=False):
    """Serialize every row of a table into one string in the kgt5 format.

    Each non-missing cell becomes the clause ``"<column> | <value>"``, where
    newlines are removed from the column name and underscores are replaced
    by spaces. The clauses of a row are joined with ``". "`` and terminated
    with ``"."``. Missing cells contribute no clause at all.

    Parameters
    ----------
    data : pandas.DataFrame
        Table whose cells are strings (missing values are allowed).
    target_name : str, optional
        If given, ``"Predict: {target_name}."`` is appended to every
        serialized row.
    prefix : str, optional
        If given, it is prepended to every serialized row, followed by a
        single space.
    lowercase : bool, default=False
        Lowercase the serialized row. This is applied before the target
        suffix and the prefix are added, so those keep their casing.

    Returns
    -------
    pandas.DataFrame
        One column named ``"serialized"``, indexed like ``data``.
    """

    # Build the per-cell clauses; a missing cell stays missing after the
    # string concatenation, which lets us skip it when joining below.
    labelled = data.copy()
    for col in data.columns:
        col_text = str(col).replace("\n", "").replace("_", " ")
        labelled[col] = f"{col_text} | " + data[col]

    # Join only the non-missing clauses of each row. Cleaning up separators
    # afterwards with ``str.replace(" .", "")`` depends on the pandas default
    # for ``regex``, leaves a leading ". " when the first column is missing
    # and also alters cell values that themselves contain " .".
    sentences = [
        ". ".join(cell for cell in row if pd.notna(cell)) + "."
        for row in labelled.to_numpy(dtype=object)
    ]
    data_serialized = pd.Series(sentences, index=data.index, dtype=object)

    # Control for lowercase
    if lowercase:
        data_serialized = data_serialized.str.lower()

    # Control for target and prefix
    if target_name is not None:
        # NOTE(review): no space separates the row from "Predict:" — kept as
        # is so existing embeddings stay reproducible; confirm it is intended.
        data_serialized += f"Predict: {target_name}."
    if prefix is not None:
        data_serialized = f"{prefix} " + data_serialized

    return pd.DataFrame({"serialized": data_serialized})


def prepare_tarte(X_train, X_test=None):
    """Embed tables with the TARTE preprocessor + encoder pipeline.

    The pipeline is fitted on ``X_train`` and, when ``X_test`` is given,
    re-used to transform it.

    Parameters
    ----------
    X_train : table passed to ``Pipeline.fit_transform``.
    X_test : table passed to ``Pipeline.transform``, optional.

    Returns
    -------
    tuple
        ``(train_embeddings, test_embeddings)``; the second item is ``None``
        when no ``X_test`` was supplied.
    """

    from tarte_ai import TARTE_TablePreprocessor, TARTE_TableEncoder
    from sklearn.pipeline import Pipeline

    # The table preprocessor relies on the fasttext model configured in
    # ``path_configs``.
    table_prepper = TARTE_TablePreprocessor(
        fasttext_model_path=path_configs["fasttext_path"]
    )
    # ``layer_index`` selects which encoder layer the embeddings come from.
    table_encoder = TARTE_TableEncoder(layer_index=2)

    pipeline = Pipeline([("prep", table_prepper), ("tabenc", table_encoder)])

    train_embedded = pipeline.fit_transform(X_train)
    test_embedded = X_test if X_test is None else pipeline.transform(X_test)

    return train_embedded, test_embedded


def prepare_tabvec(X_train, X_test):
    """Vectorize tables with skrub's TableVectorizer + StringEncoder (TabVec).

    The vectorizer is fitted on ``X_train`` and then applied to ``X_test``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(train_features, test_features)``.
    """

    from skrub import StringEncoder, TableVectorizer

    # A cardinality threshold of 1 is combined with a seeded StringEncoder
    # as the high-cardinality encoder.
    vectorizer = TableVectorizer(
        high_cardinality=StringEncoder(random_state=1234),
        cardinality_threshold=1,
    )

    train_features = np.array(vectorizer.fit_transform(X_train))
    test_features = np.array(vectorizer.transform(X_test))

    return train_features, test_features


def _load_huggingface(model_name):
    """Load a locally cached Hugging Face model as a SentenceTransformer.

    Parameters
    ----------
    model_name : str
        Key into ``preprocess_configs``; the entry's ``"model_name"`` field
        is the Hugging Face repository id (``"org/name"``).

    Returns
    -------
    sentence_transformers.SentenceTransformer
        Model loaded from the first cached snapshot directory that contains
        a ``config.json``, placed on the ``"cuda"`` device.

    Raises
    ------
    AssertionError
        If the model directory is not present in the configured cache folder.
    IndexError
        If no snapshot with a ``config.json`` exists under that directory.
    """

    # Preliminary check
    cache_folder = path_configs["huggingface_cache_folder"]
    model_configs = preprocess_configs[model_name]
    # The directory name is built outside the f-string: re-using the
    # enclosing quote character inside a replacement field is a SyntaxError
    # on Python versions before 3.12.
    repo_dir_name = model_configs["model_name"].replace("/", "--")
    model_base_path = f"{cache_folder}/models--{repo_dir_name}"
    assert os.path.exists(
        model_base_path
    ), "The huggingface model is missing. Download the model before loading the model."

    from sentence_transformers import SentenceTransformer

    # Keep everything up to (and including) the trailing separator of the
    # snapshot directory that holds config.json.
    model_path = glob(f"{model_base_path}/snapshots/*/config.json")[0].split(
        "config.json"
    )[0]

    return SentenceTransformer(
        model_name_or_path=model_path,
        cache_folder=cache_folder,
        device="cuda",
    )


def extract_huggingface(
    model_name,
    X_train,
    X_test=None,
    n_components=None,
):
    """Embed every column of a string table with a Hugging Face model.

    Each column is cast to ``str`` and encoded on its own; the per-column
    embedding blocks are stacked horizontally.

    Parameters
    ----------
    model_name : str
        Key understood by ``_load_huggingface``.
    X_train : pandas.DataFrame
        Table of string columns to embed.
    X_test : pandas.DataFrame, optional
        Second table with the same columns. When omitted, the train
        embeddings are normalized by the encoder and no PCA is applied.
    n_components : int, optional
        When given together with ``X_test``, each column block is reduced
        with a PCA fitted on the train block, and both blocks are divided by
        the scaling factor computed on the train block before the PCA.

    Returns
    -------
    tuple
        ``(out_train, out_test)``; ``out_test`` is ``None`` when ``X_test``
        is ``None``.
    """

    has_test = X_test is not None

    # Train embeddings are normalized only when there is no test table.
    normalize_train = not has_test

    # Load model
    encoder = _load_huggingface(model_name)

    # Choose the batch size from the model and the longest cell in the data
    # (``pd.concat`` drops the ``None`` entry when there is no test table).
    combined = pd.concat([X_train, X_test], axis=0)
    longest_per_col = [combined[name].str.len().max() for name in combined.columns]
    batch_size = 32
    # Exceptions to llama-based models
    if model_name in ("llm-row_llama3", "llm-col_llama3"):
        batch_size = 16
        encoder.tokenizer.pad_token = encoder.tokenizer.eos_token
    if model_name == "tm_tabula":
        batch_size = 16
    if np.max(longest_per_col) > 1000:
        batch_size = 4

    def _encode(table, name, normalize):
        """Encode one column of ``table`` as strings."""
        return encoder.encode(
            np.array(table[name].astype(str)),
            convert_to_numpy=True,
            batch_size=batch_size,
            normalize_embeddings=normalize,
        )

    def _reduce(train_block, test_block):
        """PCA-reduce both blocks (fit on train) and block-scale them."""
        from sklearn.decomposition import PCA
        from skrub._scaling_factor import scaling_factor

        # The scale comes from the train block *before* the projection.
        scale = scaling_factor(train_block)
        reducer = PCA(n_components=n_components, random_state=1234)
        train_block = reducer.fit_transform(train_block)
        test_block = reducer.transform(test_block)
        train_block /= scale
        test_block /= scale
        return train_block, test_block

    train_blocks = []
    test_blocks = []
    for name in X_train.columns:
        train_block = _encode(X_train, name, normalize_train)
        if has_test:
            test_block = _encode(X_test, name, False)
            if n_components is not None:
                train_block, test_block = _reduce(train_block, test_block)
            test_blocks.append(test_block)
        train_blocks.append(train_block)

    out_train = np.hstack(train_blocks)
    out_test = np.hstack(test_blocks) if has_test else X_test

    return out_train, out_test


def extract_fasttext(
    X_train,
    X_test=None,
    n_components=None,
):
    """Embed every column of a table with fasttext sentence vectors.

    Each cell is converted with ``str`` and embedded on its own; the
    per-column embedding blocks are stacked horizontally.

    Parameters
    ----------
    X_train : pandas.DataFrame
        Table to embed.
    X_test : pandas.DataFrame, optional
        Second table with the same columns. When omitted, no PCA is applied.
    n_components : int, optional
        When given together with ``X_test``, each column block is reduced
        with a PCA fitted on the train block, and both blocks are divided by
        the scaling factor computed on the train block before the PCA.

    Returns
    -------
    tuple
        ``(out_train, out_test)``; ``out_test`` is ``None`` when ``X_test``
        is ``None``.
    """

    import fasttext

    model = fasttext.load_model(path_configs["fasttext_path"])
    has_test = X_test is not None

    def _embed(column):
        """Return the stacked sentence vectors of one column."""
        return np.array([model.get_sentence_vector(str(cell)) for cell in column])

    train_blocks = []
    test_blocks = []
    for name in X_train.columns:
        train_block = _embed(X_train[name])

        if has_test:
            test_block = _embed(X_test[name])

            if n_components is not None:
                from sklearn.decomposition import PCA
                from skrub._scaling_factor import scaling_factor

                # The scale comes from the train block *before* the PCA.
                scale = scaling_factor(train_block)
                reducer = PCA(n_components=n_components, random_state=1234)
                train_block = reducer.fit_transform(train_block)
                test_block = reducer.transform(test_block)
                # block-scale (skrub handling)
                train_block /= scale
                test_block /= scale

            test_blocks.append(test_block)

        train_blocks.append(train_block)

    out_train = np.hstack(train_blocks)
    out_test = np.hstack(test_blocks) if has_test else X_test

    return out_train, out_test


def preprocess_llm(
    X_train,
    X_test,
    target_name,
    embed_method,
):
    """Embed tables with a language model, row-wise or column-wise.

    The part of ``embed_method`` before the first ``"_"`` selects the mode:
    if it contains ``"row"`` every row is first serialized into a single
    sentence with ``serialize_llm``; if it contains ``"col"`` the columns are
    embedded separately and 30 PCA components are requested per column.
    The extractor (Hugging Face or fasttext) is chosen from
    ``preprocess_configs[embed_method]``; if neither applies the tables are
    returned without embedding.

    Returns
    -------
    tuple
        ``(X_train, X_test)`` after preprocessing; ``X_test`` may be ``None``.
    """

    # Preliminaries
    mode = embed_method.split("_")[0]
    method_configs = preprocess_configs[embed_method]
    prefix = method_configs["prefix"]
    n_components = 30 if "col" in mode else None

    # Row mode: collapse each row into one serialized sentence
    if "row" in mode:
        X_train = serialize_llm(X_train, target_name, prefix)
        if X_test is not None:
            X_test = serialize_llm(X_test, target_name, prefix)

    # Pick the embedding backend
    if method_configs["huggingface"]:
        return extract_huggingface(embed_method, X_train, X_test, n_components)
    if method_configs["model_name"] == "fasttext":
        return extract_fasttext(X_train, X_test, n_components)

    return X_train, X_test


def preprocess_kg(
    X_train,
    X_test,
    target_name,
    embed_method,
):
    """Embed tables with a knowledge-graph based model.

    The second ``"_"``-separated token of ``embed_method`` selects the model:
    ``kgt5`` and ``knowledge-card-wiki`` serialize each row (with
    ``serialize_kgt5`` and ``serialize_llm`` respectively) and embed it with
    ``extract_huggingface``; ``tarte`` delegates to ``prepare_tarte``. Any
    other token returns the tables unchanged.

    Returns
    -------
    tuple
        ``(X_train, X_test)`` after preprocessing; ``X_test`` may be ``None``.
    """

    # Preliminaries
    model_tag = embed_method.split("_")[1]

    # Resolve the row serializer, or short-circuit for non-serialized models
    if "kgt5" in model_tag:
        serializer = serialize_kgt5
    elif "knowledge-card-wiki" in model_tag:
        serializer = serialize_llm
    elif "tarte" in model_tag:
        X_train, X_test = prepare_tarte(X_train, X_test)
        return X_train, X_test
    else:
        return X_train, X_test

    # Serialize
    X_train = serializer(X_train, target_name)
    if X_test is not None:
        X_test = serializer(X_test, target_name)

    # Extract embeddings (no dimensionality reduction)
    X_train, X_test = extract_huggingface(embed_method, X_train, X_test, None)

    return X_train, X_test


def preprocess_tm(
    X_train,
    X_test,
    target_name,
    embed_method,
):
    """Embed tables with a "tm" method (currently only ``tabula``).

    When the second ``"_"``-separated token of ``embed_method`` contains
    ``"tabula"``, each row is serialized with ``serialize_llm`` and embedded
    with ``extract_huggingface``. Any other token returns the tables
    unchanged.

    Returns
    -------
    tuple
        ``(X_train, X_test)`` after preprocessing; ``X_test`` may be ``None``.
    """

    # Preliminaries
    model_tag = embed_method.split("_")[1]

    # Nothing to do for models other than tabula
    if "tabula" not in model_tag:
        return X_train, X_test

    # Serialize
    X_train = serialize_llm(X_train, target_name)
    if X_test is not None:
        X_test = serialize_llm(X_test, target_name)

    # Extract embeddings (no dimensionality reduction)
    X_train, X_test = extract_huggingface(embed_method, X_train, X_test, None)

    return X_train, X_test


def extract_total_embedding(data_name, embed_method):
    """Embed the categorical columns of a whole dataset (no train/test split).

    Parameters
    ----------
    data_name : str
        Dataset identifier passed to ``load_raw_data``.
    embed_method : str
        Embedding method; the part before the first ``"_"`` must contain
        ``"llm"``, ``"kg"`` or ``"tm"`` (checked in that order).

    Returns
    -------
    data_embed : pandas.DataFrame
        Embedding columns named ``X0, X1, ...`` followed by the target column.
    duration_preprocess : float
        Seconds spent on embedding and assembling the frame, rounded to four
        decimals.

    Raises
    ------
    ValueError
        If ``embed_method`` does not belong to a supported family.
    """

    # Load data, set preliminary
    data, data_config = load_raw_data(data_name)
    target_name = data_config["target"]
    _, cat_col, _ = col_names_per_type(data, target_name)
    data_cat = data[cat_col].copy()

    # Extract embeddings
    start_time = time.perf_counter()

    # Pick the preprocessing routine from the method family. An unknown
    # family is rejected explicitly; otherwise it would only surface later
    # as an UnboundLocalError on ``data_embed``.
    method_family = embed_method.split("_")[0]
    if "llm" in method_family:
        preprocess = preprocess_llm
    elif "kg" in method_family:
        preprocess = preprocess_kg
    elif "tm" in method_family:
        preprocess = preprocess_tm
    else:
        raise ValueError(
            f"Unsupported embed_method {embed_method!r}: the part before the "
            "first '_' must contain 'llm', 'kg' or 'tm'."
        )
    data_embed, _ = preprocess(data_cat, None, target_name, embed_method)

    # Clean the column names and concat with the target
    # NOTE(review): the concat aligns on the index, so this presumably relies
    # on ``data`` carrying a default RangeIndex — confirm in load_raw_data.
    data_embed = pd.DataFrame(data_embed)
    col_names = [f"X{i}" for i in range(data_embed.shape[1])]
    data_embed = data_embed.set_axis(col_names, axis="columns")
    data_embed = pd.concat([data_embed, data[target_name]], axis=1)

    end_time = time.perf_counter()
    duration_preprocess = round(end_time - start_time, 4)

    return data_embed, duration_preprocess


def preprocess_linked_data(data_name, embed_method):
    """Embed the entity columns of an entity-linked dataset.

    Every column except the target and ``"wikidata_id"`` is treated as an
    entity column, turned into a model-specific query string and embedded.
    For ``"carte_us_presidential"`` the ``"party"`` column is excluded from
    the entity columns and appended as one-hot float columns instead.

    Parameters
    ----------
    data_name : str
        Dataset identifier passed to ``load_linked_ent_data``.
    embed_method : str
        Key into ``preprocess_configs``; also selects the query format.

    Returns
    -------
    data_embed : pandas.DataFrame
        Embedding columns named ``X0, X1, ...``, the optional party dummies
        and the target column.
    duration_preprocess : float
        Seconds spent on embedding and assembling the frame, rounded to four
        decimals.

    Raises
    ------
    ValueError
        If ``embed_method`` matches none of the supported query formats.
    """

    # Load data and preliminaries
    data, data_config = load_linked_ent_data(data_name)
    target_name = data_config["target"]
    wiki_col = "wikidata_id"
    ent_cols = data.columns.tolist()
    ent_cols.remove(target_name)
    ent_cols.remove(wiki_col)
    df_party = None
    # exception
    if data_name == "carte_us_presidential":
        ent_cols.remove("party")
        df_party = pd.get_dummies(data["party"], prefix="party")
        df_party = df_party.astype(float)

    method_configs = preprocess_configs[embed_method]
    prefix = method_configs["prefix"]
    n_components = None

    start_time = time.perf_counter()

    # Prepare for different modes
    if "kg_kgt5" in embed_method:
        data_embed = data[ent_cols] + " | " + target_name
    elif (
        "llm-" in embed_method.split("_")[0]
        or "tm_tabula" in embed_method
        or "kg_knowledge-card-wiki" in embed_method
    ):
        # These models all share the same question-style query.
        data_embed = f"What is {target_name} of " + data[ent_cols] + "?"
    elif "kg_tarte" in embed_method:
        data_embed = data[ent_cols].copy()
    else:
        # Rejected explicitly; otherwise this would only surface below as an
        # UnboundLocalError on ``data_embed``.
        raise ValueError(
            f"Unsupported embed_method {embed_method!r} for linked data."
        )

    if prefix is not None:
        data_embed = f"{prefix} " + data_embed[ent_cols]

    if method_configs["huggingface"]:
        data_embed, _ = extract_huggingface(
            embed_method,
            data_embed,
            None,
            n_components,
        )
    elif method_configs["model_name"] == "fasttext":
        data_embed, _ = extract_fasttext(
            data_embed,
            None,
            n_components,
        )
    elif embed_method == "kg_tarte":
        data_embed, _ = prepare_tarte(data_embed, None)

    # Clean the column names and concat with the target
    # (``pd.concat`` drops ``df_party`` when it is None.)
    data_embed = pd.DataFrame(data_embed)
    col_names = [f"X{i}" for i in range(data_embed.shape[1])]
    data_embed = data_embed.set_axis(col_names, axis="columns")
    data_embed = pd.concat([data_embed, df_party, data[target_name]], axis=1)

    end_time = time.perf_counter()
    duration_preprocess = round(end_time - start_time, 4)

    return data_embed, duration_preprocess


# Preprocess (We want to cache this)
def preprocess_with_cache(data_name, embed_method, num_train, random_state):
    """Load a dataset, split it and preprocess the split with ``embed_method``.

    The part of ``embed_method`` before the first ``"_"`` selects the
    routine: ``"llm"`` -> ``preprocess_llm``, ``"kg"`` -> ``preprocess_kg``,
    ``"tm"`` -> ``prepare_tabvec`` (checked in that order). Any other value
    leaves the split unprocessed.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, duration_preprocess)`` where the
        duration is the preprocessing time in seconds, rounded to four
        decimals (loading and splitting are not timed).
    """

    # Load data, Set preliminary
    data, data_config = load_raw_data(data_name)
    X_train, X_test, y_train, y_test = set_split(
        data, data_config, num_train, random_state
    )
    target_name = data_config["target"]

    start_time = time.perf_counter()

    # preprocess depending on the preprocessing method
    method_family = embed_method.split("_")[0]
    if "llm" in method_family:
        X_train, X_test = preprocess_llm(X_train, X_test, target_name, embed_method)
    elif "kg" in method_family:
        X_train, X_test = preprocess_kg(X_train, X_test, target_name, embed_method)
    elif "tm" in method_family:
        # NOTE(review): this branch uses prepare_tabvec rather than
        # preprocess_tm — confirm that this is intended.
        X_train, X_test = prepare_tabvec(X_train, X_test)

    duration_preprocess = round(time.perf_counter() - start_time, 4)

    return X_train, X_test, y_train, y_test, duration_preprocess
