# data_utils.py
import pandas as pd
import numpy as np
import math
import os
import time


def load_elec_data(file_path="data/electricity-normalized.csv", max_lag=2):
    """Read the electricity CSV and turn it into target / feature arrays.

    Processing steps, in order:

    1. Keep rows whose ``date`` lies in ``[0.2, 0.875]`` (inclusive).
    2. Up to two times, drop the leading run of rows whose ``transfer``
       equals the first row's ``transfer``.
    3. Fill gaps in the selected columns (forward, then backward, then 0).
    4. Build ``max_lag`` lagged copies of ``transfer`` (gaps back-filled).
    5. If more than ``max_lag`` rows remain, drop the first ``max_lag`` rows
       from the target and from every feature matrix.

    Args:
        file_path: Path of the CSV file. It must provide the columns
            ``date``, ``nswprice``, ``vicprice``, ``nswdemand``,
            ``vicdemand`` and ``transfer``.
        max_lag: Number of lagged ``transfer`` columns to build. Values
            ``<= 0`` produce an empty ``"lag"`` matrix and no row trimming.

    Returns:
        ``(y, X_dict, n_obs)`` where ``y`` is the ``transfer`` array,
        ``X_dict`` maps ``"raw"``, ``"lag"``, ``"all"`` and
        ``"cluster_features"`` to 2-D arrays with ``n_obs`` rows, or
        ``(None, None, 0)`` when no rows are left.
    """
    price_demand_cols = ["nswprice", "vicprice", "nswdemand", "vicdemand"]
    date_lo, date_hi = 0.2, 0.875

    frame = pd.read_csv(file_path)
    inside_window = (frame["date"] >= date_lo) & (frame["date"] <= date_hi)
    frame = frame[inside_window].reset_index(drop=True)

    # Two passes: each one strips the leading block of rows that share the
    # first row's transfer value.
    for _ in range(2):
        if frame.empty or len(frame["transfer"].unique()) == 1:
            break
        differs_from_first = frame["transfer"] != frame["transfer"].iloc[0]
        first_change = frame[differs_from_first].index.min()
        if pd.isna(first_change):
            frame = pd.DataFrame()
            break
        if not first_change > 0:
            break
        frame = frame.iloc[first_change:].reset_index(drop=True)

    if frame.empty:
        return None, None, 0

    selected = frame[price_demand_cols + ["transfer"]].ffill().bfill().fillna(0)

    targets = selected["transfer"].to_numpy()
    raw = selected[price_demand_cols].to_numpy()

    if max_lag > 0:
        lag_frame = pd.DataFrame(
            {f"lag_{k}": selected["transfer"].shift(k) for k in range(1, max_lag + 1)}
        )
        lag = lag_frame.bfill().ffill().fillna(0).to_numpy()
        # Cut everything to the shorter of the two blocks before stacking;
        # when the lengths already agree the slices cover the full arrays.
        common_rows = min(raw.shape[0], lag.shape[0])
        raw = raw[:common_rows]
        lag = lag[:common_rows]
        targets = targets[:common_rows]
        stacked = np.hstack((raw, lag))
    else:
        lag = np.empty((len(targets), 0))
        stacked = raw

    X_dict = {
        "raw": raw,
        "lag": lag,
        "all": stacked,
        "cluster_features": selected[["nswprice"]].to_numpy(),
    }

    # The first max_lag rows only carry back-filled lag values, so they are
    # dropped whenever enough rows remain afterwards.
    if max_lag > 0 and len(targets) > max_lag:
        targets = targets[max_lag:]
        X_dict = {
            name: (matrix[max_lag:, :] if matrix.shape[0] > max_lag else matrix)
            for name, matrix in X_dict.items()
        }

    n_obs = len(targets)
    if n_obs == 0:
        return None, None, 0

    # Last safety net: make every matrix agree with the target length.
    X_dict = {
        name: (matrix if matrix.shape[0] == n_obs else matrix[:n_obs])
        for name, matrix in X_dict.items()
    }

    return targets, X_dict, n_obs


def simulate_arma_data_with_lags(
    n_total, p_exog_features, burnin, max_lag, random_seed=1
):
    """Simulate an autoregressive / moving-average series with exogenous inputs.

    The series follows

        y[i] = 0.3 * y[i-1] + X[i] @ beta - 0.3 * err[i-1] + err[i]

    where ``X`` and ``err`` are standard-normal draws and ``beta`` has
    ``floor(0.10 * p_exog_features)`` non-zero coefficients drawn from
    ``N(0, 4)`` at shuffled positions (all remaining coefficients are zero).
    The first ``burnin`` samples are discarded, and lagged copies of the
    target are appended as extra features.

    Args:
        n_total: Number of time steps to simulate before burn-in removal.
        p_exog_features: Number of exogenous feature columns.
        burnin: Number of leading samples sliced off after simulation.
        max_lag: Number of lagged-target columns to build. When positive,
            rows without a complete lag history are dropped. Values ``<= 0``
            give an empty ``"lag"`` matrix.
        random_seed: Seed passed to ``np.random.seed``. Note that this
            reseeds NumPy's global random state.

    Returns:
        ``(Y, X_dict, n_obs)`` where ``X_dict`` maps ``"raw"``, ``"lag"``,
        ``"all"`` and ``"cluster_features"`` (first raw column) to 2-D arrays
        with ``n_obs`` rows, or ``(None, None, 0)`` when no usable sample
        remains (``n_total <= 0``, burn-in removes everything, or no row has
        a full lag history).
    """
    np.random.seed(random_seed)

    # Without at least one time step there is nothing to initialise:
    # indexing element 0 of an empty array would raise IndexError. Report
    # "no data" the same way the other empty cases below do.
    if n_total <= 0:
        return None, None, 0

    # NOTE: the order of the random draws below is kept fixed so that a given
    # seed keeps producing the same data.
    X_exog_full = np.random.randn(n_total, p_exog_features)

    num_nonzero = math.floor(p_exog_features * 0.10)
    beta = np.concatenate(
        [
            np.random.normal(0, 4, size=num_nonzero),
            np.zeros(p_exog_features - num_nonzero),
        ]
    )
    np.random.shuffle(beta)

    y_full = np.zeros(n_total)
    err_full = np.zeros(n_total)
    y_full[0], err_full[0] = np.random.randn(), np.random.randn()

    # The recursion depends on the previous step, so it is run sequentially.
    for i in range(1, n_total):
        err_full[i] = np.random.randn()
        y_full[i] = (
            0.3 * y_full[i - 1]
            + X_exog_full[i, :] @ beta
            - 0.3 * err_full[i - 1]
            + err_full[i]
        )

    Y_burned = y_full[burnin:]
    X_exog_burned = X_exog_full[burnin:, :]

    if len(Y_burned) == 0:
        return None, None, 0

    if max_lag > 0:
        target = pd.Series(Y_burned, name="target")
        lag_frame = pd.DataFrame(
            {f"y_lag_{k}": target.shift(k) for k in range(1, max_lag + 1)}
        )
        # Put features, lags and target side by side so that dropping the
        # rows with an incomplete lag history keeps all three aligned.
        aligned = pd.concat(
            [pd.DataFrame(X_exog_burned), lag_frame, target], axis=1
        ).dropna()

        if aligned.empty:
            return None, None, 0

        Y_final = aligned["target"].values
        raw = aligned.iloc[:, :p_exog_features].values
        lag = aligned.iloc[:, p_exog_features:-1].values
        stacked = np.hstack((raw, lag))
    else:
        Y_final = Y_burned
        raw = X_exog_burned
        lag = np.empty((len(Y_final), 0))
        stacked = raw

    n_obs = len(Y_final)
    cluster_features = raw[:, [0]] if raw.shape[1] > 0 else np.empty((n_obs, 0))

    X_dict = {
        "raw": raw,
        "lag": lag,
        "all": stacked,
        "cluster_features": cluster_features,
    }
    return Y_final, X_dict, n_obs


def load_and_preprocess_dataset(dataset_name, config):
    """Dispatch to the loader for ``dataset_name`` using settings from ``config``.

    Args:
        dataset_name: ``"elec"`` reads ``electricity-normalized.csv`` from
            ``config["data_dir"]`` (default ``"data/"``); ``"aram"`` runs the
            synthetic simulator with the ``aram_*`` settings and ``seed``
            (default: the current Unix time in whole seconds).
        config: Mapping of settings; ``max_lag`` defaults to 2.

    Returns:
        ``(Y, X_dict, n_obs)`` from the selected loader, or
        ``(None, None, 0)`` when the loader produced no data.

    Raises:
        ValueError: If ``dataset_name`` is neither ``"elec"`` nor ``"aram"``.
    """
    lag_count = config.get("max_lag", 2)

    if dataset_name == "elec":
        csv_path = os.path.join(
            config.get("data_dir", "data/"), "electricity-normalized.csv"
        )
        outcome = load_elec_data(csv_path, lag_count)
    elif dataset_name == "aram":
        outcome = simulate_arma_data_with_lags(
            config.get("aram_n_total", 15000),
            config.get("aram_p_exog_features", 20),
            config.get("aram_burnin", 500),
            lag_count,
            random_seed=config.get("seed", int(time.time())),
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    targets, features, count = outcome

    # Normalise every "nothing usable" outcome to the same sentinel triple.
    has_data = targets is not None and features is not None and count != 0
    if not has_data:
        return None, None, 0

    return targets, features, count


if __name__ == "__main__":
    # Manual smoke test: run both loaders once and print the resulting shapes.
    print("----- Testing ELEC2 data loader -----")
    elec_settings = dict(dataset="elec", data_dir="data/", max_lag=2, seed=42)
    elec_targets, elec_features, elec_count = load_and_preprocess_dataset(
        "elec", elec_settings
    )
    if elec_targets is not None and elec_count > 0:
        print(f"ELEC Y shape: {elec_targets.shape}, n_obs: {elec_count}")
        for feature_name, matrix in elec_features.items():
            print(f"  X_dict['{feature_name}'] shape: {matrix.shape}")

    print("\n----- Testing ARAM (Synthetic ARMA) data loader -----")
    # Small synthetic run so the check finishes quickly.
    aram_settings = dict(
        dataset="aram",
        max_lag=2,
        seed=123,
        aram_n_total=500,
        aram_p_exog_features=3,
        aram_burnin=50,
    )
    aram_targets, aram_features, aram_count = load_and_preprocess_dataset(
        "aram", aram_settings
    )
    if aram_targets is not None and aram_count > 0:
        print(f"ARAM Y shape: {aram_targets.shape}, n_obs: {aram_count}")
        for feature_name, matrix in aram_features.items():
            print(f"  X_dict['{feature_name}'] shape: {matrix.shape}")
