import json
import os
from typing import Tuple

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler


def load_dataset(dataset_path) -> dict:
    """Load a dataset stored as ``.npy`` arrays plus a JSON config file.

    Args:
        dataset_path: Directory containing ``X.npy``, ``Y.npy``,
            ``labels.npy`` and ``config.json``.

    Returns:
        A dict with keys ``"X"``, ``"Y"`` and ``"labels"`` (the arrays loaded
        with ``np.load``) and ``"config"`` (the parsed contents of
        ``config.json``).

    Raises:
        FileNotFoundError: If any of the four expected files is missing.
    """
    X = np.load(os.path.join(dataset_path, "X.npy"))
    Y = np.load(os.path.join(dataset_path, "Y.npy"))
    labels = np.load(os.path.join(dataset_path, "labels.npy"))

    # JSON text is UTF-8 by specification; pin the encoding so parsing does
    # not depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(os.path.join(dataset_path, "config.json"), "r", encoding="utf-8") as f:
        config = json.load(f)

    return {"X": X, "Y": Y, "labels": labels, "config": config}


def dataset_name_to_ids(dataset_name: str) -> Tuple[int, int]:
    """Extract the two integer ids embedded in an underscore-separated name.

    The name is split on ``"_"``; the second and third fields are parsed as
    integers, and any further fields are ignored. For example,
    ``"data_3_7"`` yields ``(3, 7)``.

    Raises:
        IndexError: If the name has fewer than three ``"_"``-separated fields.
        ValueError: If either of those fields is not an integer literal.
    """
    fields = dataset_name.split("_")
    first_id = int(fields[1])
    second_id = int(fields[2])
    return first_id, second_id


def scale_data(X, Y):
    """Standardize features and targets and compute per-feature value ranges.

    Separate ``StandardScaler`` instances are fitted to ``X`` and ``Y``.
    A 1-D ``Y`` is treated as a single target column. ``StandardScaler``
    rejects 1-D input, so it is reshaped to ``(n_samples, 1)`` before fitting.

    Args:
        X: 2-D array-like of features (``StandardScaler`` requires 2-D input).
        Y: 1-D or 2-D array-like of targets.

    Returns:
        A 7-tuple ``(X_tensor, Y_tensor, X_scaled, Y_scaled, data_limits,
        scaler_x, scaler_y)``:

        - ``X_tensor`` / ``Y_tensor``: the scaled data as ``torch.float64``
          tensors; ``Y_tensor`` is always 2-D.
        - ``X_scaled`` / ``Y_scaled``: the scaled arrays as returned by
          ``fit_transform``.
        - ``data_limits``: ``float64`` tensor of shape ``(n_features, 2)``;
          column 0 holds the minimum and column 1 the maximum of each scaled
          feature.
        - ``scaler_x`` / ``scaler_y``: the fitted scalers, for
          ``inverse_transform``.
    """
    if np.ndim(Y) == 1:
        # StandardScaler.fit_transform only accepts 2-D input; without this,
        # a 1-D target raised before reaching the tensor conversion below.
        Y = np.asarray(Y).reshape(-1, 1)

    scaler_x = StandardScaler()
    scaler_y = StandardScaler()
    X_scaled = scaler_x.fit_transform(X)
    Y_scaled = scaler_y.fit_transform(Y)

    # Convert to tensors
    X_tensor = torch.tensor(X_scaled, dtype=torch.float64)
    Y_tensor = torch.tensor(Y_scaled, dtype=torch.float64)

    # Per-feature [min, max] of the scaled data. This equals quantiles 0.0
    # and 1.0, but runs as one vectorized reduction per column and avoids
    # torch.quantile's input-size limit. NaNs propagate in both forms.
    data_limits = torch.stack(
        (X_tensor.min(dim=0).values, X_tensor.max(dim=0).values), dim=1
    )
    return X_tensor, Y_tensor, X_scaled, Y_scaled, data_limits, scaler_x, scaler_y
