
import os
import numpy as np
import pandas as pd


# Absolute path of the directory containing this module, so that the CSV below
# is located relative to this file rather than the current working directory.
current_file_dir = os.path.dirname(os.path.abspath(__file__))
# Price table, read once at import time; the first CSV column is used as the index.
# NOTE(review): the CSV layout is not visible from this file — presumably prices
# keyed by block height; confirm against "block_prices.csv".
block_prices = pd.read_csv(os.path.join(current_file_dir, "block_prices.csv"), index_col=0)


def common_preprocessing(df: pd.DataFrame) -> pd.DataFrame:
    """Log-transform and robustly rescale every column of *df*.

    Steps, applied column by column:

    1. ``log1p`` of every value.
    2. Subtract the column's 5th percentile and divide by the distance
       between its 95th and 5th percentiles (a distance of exactly 0 is
       replaced by 1 so constant columns do not divide by zero).
    3. Shift by ``-0.5`` so the 5th/95th percentiles map to -0.5/+0.5.
    4. Replace every remaining NaN with ``0.``.

    The input frame is not modified; a new frame with the same index and
    columns is returned.
    """
    logged = np.log1p(df)

    # Per-column percentile bounds used in place of the raw min/max.
    lower = logged.quantile(0.05, axis=0)
    upper = logged.quantile(0.95, axis=0)

    # Guard against a zero-width range (e.g. constant columns).
    spread = (upper - lower).replace(0., 1.)

    centered = logged.sub(lower, axis="columns")
    scaled = centered.div(spread, axis="columns") - 0.5

    return scaled.fillna(0.)


def preprocess_node_features(df: pd.DataFrame) -> np.ndarray:
    """Build the normalised node feature matrix from a per-node table.

    Every exact ``0.`` in *df* is first turned into NaN (on a copy; the
    caller's frame is left untouched). The six amount columns are divided by
    ``10 ** 8``, per-transaction averages are derived, and the amount
    features are duplicated after multiplication by a single price taken
    from ``block_prices``. The assembled table goes through
    ``common_preprocessing`` and is returned as a NumPy array.

    Output column order (20 columns): degree in/out, transaction count
    in/out, avg/min/max sent, avg/min/max received, the same six amounts
    multiplied by the price, then first in/out and last in/out transaction.

    NOTE(review): ``block_prices`` is a DataFrame and ``DataFrame.get`` looks
    up a *column* label, returning ``None`` when it is absent — confirm the
    CSV layout makes this lookup yield the intended price.
    NOTE(review): if every ``first_transaction_in`` is 0/missing the minimum
    is NaN and ``int(...)`` below raises ``ValueError`` — confirm callers
    never pass such a table.
    """
    amount_columns = ["total_sent", "min_sent", "max_sent",
                      "total_received", "min_received", "max_received"]

    nodes = df.replace(0., np.nan)
    # Rescale raw amounts (presumably satoshi -> BTC; confirm the unit).
    nodes[amount_columns] /= 10 ** 8

    # Average amount per transaction; counts below 1 are raised to 1.
    tx_out = nodes["total_transactions_out"].clip(lower=1.)
    tx_in = nodes["total_transactions_in"].clip(lower=1.)
    avg_sent = nodes["total_sent"].div(tx_out).rename("avg_sent")
    avg_received = nodes["total_received"].div(tx_in).rename("avg_received")

    # One price for the whole table: looked up with the smallest
    # `first_transaction_in`, rounded down to a multiple of 1000.
    earliest = nodes["first_transaction_in"].min()
    price_key = 1000 * int(earliest // 1000)
    price = block_prices.get(price_key)

    count_features = [nodes[name] for name in (
        "degree_in", "degree_out",
        "total_transactions_in", "total_transactions_out",
    )]
    amount_features = [
        avg_sent, nodes["min_sent"], nodes["max_sent"],
        avg_received, nodes["min_received"], nodes["max_received"],
    ]
    priced_features = [series * price for series in amount_features]
    time_features = [nodes[name] for name in (
        "first_transaction_in", "first_transaction_out",
        "last_transaction_in", "last_transaction_out",
    )]

    table = pd.concat(
        count_features + amount_features + priced_features + time_features,
        axis=1,
    )

    return common_preprocessing(table).values


def preprocess_edge_features(df: pd.DataFrame) -> np.ndarray:
    """Build the normalised edge feature matrix from a per-edge table.

    Every exact ``0.`` in *df* is first turned into NaN (on a copy; the
    caller's frame is left untouched). The three sent-amount columns are
    divided by ``10 ** 8`` and an average amount per unit of ``total`` is
    derived. The selected columns go through ``common_preprocessing`` and
    are returned as a NumPy array.

    Output column order (6 columns): total, avg_sent, min_sent, max_sent,
    reveal, last_seen.
    """
    edges = df.replace(0., np.nan)
    # Rescale raw amounts (presumably satoshi -> BTC; confirm the unit).
    edges[["total_sent", "min_sent", "max_sent"]] /= 10 ** 8

    # Average amount; a `total` below 1 is raised to 1 before dividing.
    divisor = edges["total"].clip(lower=1.)
    avg_sent = edges["total_sent"].div(divisor).rename("avg_sent")

    selected = [
        edges["total"],
        avg_sent,
        edges["min_sent"],
        edges["max_sent"],
        edges["reveal"],
        edges["last_seen"],
    ]
    table = pd.concat(selected, axis=1)

    return common_preprocessing(table).values