import numpy as np
from typing import Any, Literal, TypeVar, Union
from pathlib import Path
from scipy.special import expit
import scipy.stats
from sklearn.model_selection import train_test_split

# Project root: the parent of the directory that contains this module.
PROJECT_DIR = Path(__file__).parent.parent

# Standard project sub-directories (only referenced here, never created).
CACHE_DIR, DATA_DIR, EXP_DIR = (PROJECT_DIR / sub for sub in ('cache', 'data', 'exp'))


# Lower-cased array names used as top-level keys by load_data / load_data_fair.
# 'z' is produced only by load_data_fair (from Z.npy).
DataKey = Literal['x_num', 'x_bin', 'x_cat', 'x_meta', 'y', 'z']


def get_path(path: str | Path) -> Path:
    path = str(path)
    if path.startswith(":"):
        path = PROJECT_DIR / path[1:]
    return Path(path).absolute().resolve()

def _load_split_arrays(
    path: Path, keys: list[str], split: str
) -> dict[str, dict[str, np.ndarray]]:
    """Load ``<key>.npy`` arrays from *path* and slice each into train/val/test.

    Each present array is indexed with ``split-<split>/{train,val,test}_idx.npy``.
    Missing array files are reported with a print and skipped. The three index
    files are read once per call (when the first array is found) and reused for
    every key, rather than re-read for each key.

    Returns:
        ``{key.lower(): {'train': ..., 'val': ..., 'test': ...}}`` for every key
        whose ``.npy`` file exists.
    """
    data = {}
    split_idx = None
    for key in keys:
        arr_path = path / f'{key}.npy'
        if not arr_path.exists():
            print(f'No {key}')
            continue

        # Memory-mapped so that indexing below only materializes the selected
        # rows (assumes the *_idx.npy files hold row indices — TODO confirm).
        arr = np.load(arr_path, allow_pickle=False, mmap_mode='r')
        if split_idx is None:
            # Loaded lazily: the split directory is only touched if at least
            # one array exists, as in the original per-key loading.
            split_idx = {
                part: np.load(path / f'split-{split}/{part}_idx.npy')
                for part in ['train', 'val', 'test']
            }
        data[key.lower()] = {part: arr[idx] for part, idx in split_idx.items()}

    return data


def load_data(path, split: str = "default") -> dict[DataKey, dict[str, np.ndarray]]:
    """Load the feature/label arrays of a dataset directory, split into parts.

    Args:
        path: dataset directory; a leading ':' makes it relative to the project
            root (see ``get_path``).
        split: name of the ``split-<split>`` sub-directory holding the
            ``train_idx.npy`` / ``val_idx.npy`` / ``test_idx.npy`` index files.

    Returns:
        Mapping from 'x_num', 'x_bin', 'x_cat', 'x_meta', 'y' (only those whose
        ``.npy`` file exists) to ``{'train', 'val', 'test'}`` arrays.
    """
    return _load_split_arrays(
        get_path(path), ['X_num', 'X_bin', 'X_cat', 'X_meta', 'Y'], split
    )


def load_data_fair(path, split: str = "default") -> dict[DataKey, dict[str, np.ndarray]]:
    """Like ``load_data``, but additionally loads ``Z.npy`` under the key 'z'.

    ``Z`` is presumably the protected/sensitive attribute, given the function
    name — confirm with the data-preparation code.

    Args:
        path: dataset directory; a leading ':' makes it relative to the project
            root (see ``get_path``).
        split: name of the ``split-<split>`` sub-directory holding the index files.

    Returns:
        Mapping from 'x_num', 'x_bin', 'x_cat', 'x_meta', 'y', 'z' (only those
        whose ``.npy`` file exists) to ``{'train', 'val', 'test'}`` arrays.
    """
    return _load_split_arrays(
        get_path(path), ['X_num', 'X_bin', 'X_cat', 'X_meta', 'Y', 'Z'], split
    )


def get_dataset(dataset_name: str):
    """Load ``DATA_DIR/<dataset_name>`` as a plain train/test pair.

    Only the numeric features ('x_num') and labels ('y') are used. The
    validation split is folded into the test split, with validation rows first.

    Returns:
        (X_train, y_train, X_test, y_test)
    """
    splits = load_data(DATA_DIR / dataset_name)
    features, targets = splits["x_num"], splits["y"]
    X_train, y_train = features["train"], targets["train"]
    # Fold validation into test: val rows come first, then the test rows.
    X_test = np.concatenate([features["val"], features["test"]])
    y_test = np.concatenate([targets["val"], targets["test"]])
    return X_train, y_train, X_test, y_test


def load_data_synthetic(n_samples, n_dimensions, n_directions, ratio, var, alpha):
    """Generate a synthetic binary dataset with a perturbed logistic link.

    ``w`` has ones on its first ``n_directions`` coordinates (slice semantics)
    and is unit-normalized; ``w_perp`` is a unit vector orthogonal to it.
    Features are drawn from the zero-mean Gaussian built by ``get_mean_cov``.
    Labels are Bernoulli(F), where F = expit(X @ w) + delta(X, w, w_perp, alpha).

    Args:
        n_samples: number of rows to draw.
        n_dimensions: feature dimension (needs >= 2 so that w_perp exists).
        n_directions: how many leading coordinates of w are set to 1.
        ratio, var: forwarded to ``get_mean_cov`` (variance ratio*var along w,
            var along every other principal axis).
        alpha: steepness of the perturbation along w_perp (forwarded to ``delta``).

    Returns:
        X of shape (n_samples, n_dimensions), y of 0/1 labels, H = expit(X @ w),
        F = H + delta (the probability y was drawn from), and w.

    Raises:
        ValueError: if ``n_directions`` selects no coordinate of w.

    NOTE: X is drawn with a fixed seed (random_state=0), but y comes from
    NumPy's global RNG, so labels are reproducible only if the caller seeds
    ``np.random``.
    """
    w = np.zeros(n_dimensions)
    w[:n_directions] = 1
    if not w.any():
        # Previously this silently became an all-NaN w and failed inside scipy.
        raise ValueError("n_directions selects no coordinates; w would be all zero")
    w /= np.linalg.norm(w)

    w_perp = create_orthonormal_vector(w)
    mean, cov, _, _ = get_mean_cov(w, w_perp, ratio, var)
    X_model = scipy.stats.multivariate_normal(mean=mean, cov=cov)

    # rvs squeezes singleton axes (a single sample would come back 1-D), so
    # force the documented 2-D shape. size=(n, 1) is kept so the draws match.
    X = X_model.rvs(size=(n_samples, 1), random_state=0).reshape(n_samples, n_dimensions)
    H = expit(np.dot(X, w))
    F = H + delta(X, w, w_perp, alpha)
    y = np.random.binomial(1, F)
    return X, y, H, F, w

def get_mean_cov(w, w_perp, ratio, var):
    """Build a zero mean and a covariance whose principal axes start at w, w_perp.

    The basis P starts from columns [w, w_perp, e_2, ..., e_{d-1}] and is
    orthonormalized with Gram-Schmidt. So the first column is w normalized, and
    the second is w_perp made orthogonal to w and normalized. D is diagonal,
    with ratio*var on the first axis and var on all others.

    Returns:
        (mean, cov, P, D): mean = zeros like w, cov = P @ D @ P.T, plus P and D.
    """
    d = w.shape[0]
    P = np.eye(d)
    P[:, 0] = w
    P[:, 1] = w_perp
    P = gram_schmidt_orthonormalization(P)
    D = np.eye(d)
    D[0, 0] = ratio
    D *= var
    # P is orthonormal, so P^-1 == P^T: no explicit (less accurate) inverse.
    # Symmetrize so cov is exactly symmetric; it is PSD when ratio, var >= 0.
    cov = P @ D @ P.T
    cov = (cov + cov.T) / 2
    mean = np.zeros_like(w)
    return mean, cov, P, D



def delta(X, w, w_perp, alpha):
    """Signed perturbation added to the logistic probability p = expit(X @ w).

    The signed factor 2*expit(alpha * X @ w_perp) - 1 lies in [-1, 1] and sets
    the direction. It is scaled by min(p, 1 - p), so p + delta stays in [0, 1].

    Returns:
        Array broadcast from X @ w (one value per row for 2-D X).
    """
    p = expit(np.dot(X, w))
    # 2*expit(z) - 1 == tanh(z / 2): odd in z, bounded by [-1, 1].
    direction = 2 * expit(alpha * np.dot(X, w_perp)) - 1
    return direction * np.minimum(1 - p, p)

def create_orthonormal_vector(x):
    """Return a unit vector orthogonal to the given one.

    Let m be the first nonzero coordinate of x and n = (m + 1) mod len(x). The
    result has (y[m], y[n]) = (-x[n], x[m]) and zeros elsewhere, normalized.
    This rotates x's (m, n) components by 90 degrees, so y . x == 0.

    Args:
        x: 1-D array-like (lists and integer arrays are accepted).

    Raises:
        ValueError: if x is all zeros (per np.allclose) or has fewer than 2 entries.
    """
    # Convert lists so that `x != 0` below is elementwise, not a single bool.
    x = np.asarray(x)
    if np.allclose(x, np.zeros_like(x)):
        raise ValueError("x is null")
    if len(x) < 2:
        raise ValueError("x must be at least 2 dimensional to find orthonormal vector")
    # Integer input would make the in-place normalization fail; promote it.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    y = np.zeros(x.shape, dtype=dtype)
    m = int(np.argmax(x != 0))  # first nonzero coordinate
    n = (m + 1) % len(x)
    y[n] = x[m]
    y[m] = -x[n]
    y /= np.linalg.norm(y)
    return y

def gram_schmidt_orthonormalization(P):
    """Turn the columns of P (a basis of R^d) into an orthonormal basis.

    Uses modified Gram-Schmidt. Each column is orthogonalized one step at a
    time against the already-normalized previous columns, using the partially
    updated vector. This is numerically more stable than the classical form.
    Column i of the result spans the same space as columns 0..i of P
    (R = Q^H P is upper triangular).

    The input is not modified. Integer input is promoted to float64. Linear
    dependence is not detected: an exactly-zero residual column yields NaNs.
    """
    Q = np.array(P, copy=True)
    if not np.issubdtype(Q.dtype, np.inexact):
        # Integer arrays cannot hold the fractional in-place updates below.
        Q = Q.astype(np.float64)
    for i in range(Q.shape[1]):
        for j in range(i):
            # Q[:, j] already has unit norm, so no division by its norm is needed.
            Q[:, i] -= np.vdot(Q[:, j], Q[:, i]) * Q[:, j]
        Q[:, i] /= np.linalg.norm(Q[:, i])
    return Q

def projection(x, y):
    """Orthogonal projection of x onto the line spanned by y: (<y, x> / <y, y>) * y.

    ``np.vdot`` conjugates its first argument, so y goes first. That order is
    required for complex vectors and gives identical results for real ones.
    ``<y, y>`` is computed directly rather than as norm(y)**2, avoiding a
    sqrt/square round trip. A zero y produces NaN entries (with a RuntimeWarning).
    """
    return np.vdot(y, x) / np.vdot(y, y).real * y
