import os
import pickle
from typing import Tuple

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler

from src.embeddings import create_pca_embeddings_with_noise

# Repository root: the directory one level above the folder holding this module,
# made absolute so the dataset paths below do not depend on the working directory.
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..")
)

# Per-dataset directories, each resolved against REPO_ROOT.
PATH_TO_YEAST = os.path.join(REPO_ROOT, "yeast")
PATH_TO_K562_scGPT = os.path.join(REPO_ROOT, "K562_scGPT")
PATH_TO_K562 = os.path.join(REPO_ROOT, "K562")
PATH_TO_RPE1_rank = os.path.join(REPO_ROOT, "RPE1_rank")
PATH_TO_RPE1_binary = os.path.join(REPO_ROOT, "RPE1_binary")
PATH_TO_CELL_PAINTING = os.path.join(REPO_ROOT, "cell_painting")


def generate_X_train_data(num_train=100, seed=None, SNR=100, sparsity=0.9):
    """
    Build a square random matrix with a non-positive diagonal and sparse
    Gaussian off-diagonal entries.

    Args:
        num_train: side length of the returned (num_train, num_train) matrix.
        seed: seed for a local ``np.random.default_rng``; global NumPy random
            state is not touched.
        SNR: off-diagonal candidates are drawn from N(0, (1 / SNR)^2).
        sparsity: an off-diagonal candidate is kept only where a U[0, 1) draw
            is strictly greater than this value.

    Returns:
        np.ndarray: matrix whose diagonal is -U[0, 1) (i.e. in (-1, 0]) and
        whose off-diagonal entries are the sparsified Gaussian noise.
    """
    generator = np.random.default_rng(seed)
    n = num_train

    # Draw order (diagonal, noise, mask) fixes the output for a given seed.
    diagonal = -1 * generator.uniform(0, 1, n)
    noise = generator.normal(0, 1. / SNR, (n, n))
    keep = generator.random((n, n)) > sparsity

    off_diagonal = np.where(keep, noise, 0.0)
    np.fill_diagonal(off_diagonal, 0)  # the diagonal comes solely from `diagonal`

    return np.diag(diagonal) + off_diagonal

def create_parasitic_dataset(
    n_train=5,          # number of training environments (2 is enough)
    n_val=10,           # validation envs
    n_test=10,          # test envs
    beta_c=1.0,         # invariant causal coefficient (y = beta_c * t)
    sigma_x=0.05,       # measurement noise std on X means
    sigma_y=0.005,      # tiny averaging noise on Y means
    a_train=None,       # list/array of length n_train with proxy ties a_e; if None -> [ +1, -1 ] when n_train=2, else sampled
    a_range=(-5.0, 5.0),  # sampling range for proxy ties (train when sampled, val/test in OOD mode)
    mode="LEC",         # "OOD" (new interventions) or "LEC" (linear combos of train envs)
    alpha=2.0,          # OOD scale for t in test/val (std multiplier) if mode=="OOD"
    lec_k=None,         # number of train envs used per combo (if None use all) when mode=="LEC"
    lec_weight_scale=(1.0, 3.0),  # magnitude range for LEC weights
    seed=None,
    return_meta=False,  # if True, also return the `meta` dict as a 9th element
):
    """
    Minimal DRIG-style PoC: 2 features (causal x_c and proxy x_p), explicit additive interventions on X only.
    Environment means only; outcome depends *only* on the causal feature mean.

    NOTE(review): a second ``create_parasitic_dataset(pert_magnitude=1)`` is
    defined later in this module and rebinds the name, so callers importing
    ``create_parasitic_dataset`` from this module get that later function.

    Generative model (per environment e):
        t_e ~ N(0, 1)                  # intervention strength (explicit delta on causal)
        a_e                            # env-specific tie from causal to proxy
        x*_c^e = t_e
        x*_p^e = a_e * t_e
        y^e = beta_c * t_e + eps_y
        observed means:
            x_c^e = x*_c^e + eps_x_c,  x_p^e = x*_p^e + eps_x_p  with eps_x_* ~ N(0, sigma_x^2)

    Test/Val:
        mode == "OOD":
            t_v ~ N(0, alpha^2), a_v ~ Uniform(a_range), build X* and y as above.
        mode == "LEC":
            choose weights W over training envs; X*_valtest = W @ X*_train (rowwise),
            y_valtest = W @ y_train; then add small measurement noise to X (and
            sigma_y noise to y).

    ``mode`` is matched case-insensitively.

    Returns
    -------
    X_train : (n_train, 2)  observed train means [x_c, x_p]
    y_train : (n_train,)
    X_val   : (n_val, 2)
    y_val   : (n_val,)
    X_test  : (n_test, 2)
    y_test  : (n_test,)
    W_val   : (n_val, n_train) or None   # only for mode=="LEC"
    W_test  : (n_test, n_train) or None  # only for mode=="LEC"
    meta    : dict with useful internals (t_train, a_train, noise-free X*, etc.);
              appended as a 9th element only when ``return_meta=True``.

    Raises
    ------
    ValueError
        If ``mode`` is not "OOD"/"LEC" or ``a_train`` does not have length ``n_train``.
    """
    mode_key = mode.upper()
    # Validate up front instead of after the training data has been drawn.
    if mode_key not in ("OOD", "LEC"):
        raise ValueError("mode must be 'OOD' or 'LEC'")

    rng = np.random.default_rng(seed)

    # --- TRAIN ENVIRONMENTS ---
    # proxy ties a_e for train
    if a_train is None:
        if n_train == 2:
            a_train = np.array([+1.0, -1.0])
        else:
            a_train = rng.uniform(a_range[0], a_range[1], size=n_train)
    else:
        a_train = np.asarray(a_train, dtype=float)
        # Explicit raise: `assert` is stripped when Python runs with -O.
        if len(a_train) != n_train:
            raise ValueError("a_train length must equal n_train")

    # explicit intervention strengths t_e (delta on causal mean)
    t_train = rng.normal(0.0, 1.0, size=n_train)

    # noise-free means X*_train = [[t_e, a_e * t_e]]
    Xstar_train = np.stack([t_train, a_train * t_train], axis=1)

    # observed means with small measurement noise on X only
    X_train = Xstar_train + rng.normal(0.0, sigma_x, size=Xstar_train.shape)

    # outcomes (invariant, purely causal)
    y_train = beta_c * t_train + rng.normal(0.0, sigma_y, size=n_train)

    # --- BUILD VAL/TEST ---
    total = n_val + n_test
    if mode_key == "OOD":
        # fresh interventions (stronger magnitude), proxy ties within range
        t_valtest = rng.normal(0.0, alpha, size=total)
        a_valtest = rng.uniform(a_range[0], a_range[1], size=total)

        Xstar_valtest = np.stack([t_valtest, a_valtest * t_valtest], axis=1)
        X_valtest = Xstar_valtest + rng.normal(0.0, sigma_x, size=Xstar_valtest.shape)
        y_valtest = beta_c * t_valtest + rng.normal(0.0, sigma_y, size=total)

        # No environment-combo weights in OOD mode
        W_val = None
        W_test = None
    else:  # "LEC": linear combinations of training environments
        k = n_train if lec_k is None else int(max(1, lec_k))
        n_pick = min(k, n_train)
        W = np.zeros((total, n_train))
        for i in range(total):
            idx = rng.choice(n_train, size=n_pick, replace=False)
            w = rng.normal(0.0, 1.0, size=idx.size)
            # scale magnitudes (allows extrapolation inside span)
            w *= rng.uniform(lec_weight_scale[0], lec_weight_scale[1])
            W[i, idx] = w

        Xstar_valtest = W @ Xstar_train
        X_valtest = Xstar_valtest + rng.normal(0.0, sigma_x, size=Xstar_valtest.shape)
        y_valtest = W @ y_train + rng.normal(0.0, sigma_y, size=total)

        W_val, W_test = W[:n_val], W[n_val:]

    # split val/test
    X_val, X_test = X_valtest[:n_val], X_valtest[n_val:]
    y_val, y_test = y_valtest[:n_val], y_valtest[n_val:]

    outputs = (X_train, y_train, X_val, y_val, X_test, y_test, W_val, W_test)
    if not return_meta:
        return outputs

    meta = {
        "beta_c": float(beta_c),
        "t_train": t_train,
        "a_train": a_train,
        "Xstar_train": Xstar_train,
        "Xstar_val": Xstar_valtest[:n_val],
        "Xstar_test": Xstar_valtest[n_val:],
        "mode": mode_key,
        "sigma_x": float(sigma_x),
        "sigma_y": float(sigma_y),
    }
    if mode_key == "OOD":
        meta.update({"t_val": t_valtest[:n_val], "t_test": t_valtest[n_val:],
                     "a_val": a_valtest[:n_val], "a_test": a_valtest[n_val:]})
    return outputs + (meta,)


def create_parasitic_dataset(pert_magnitude=1, *, seed=None, n_features=3, n_samples=50,
                             n_env=2, n_test=100, n_iter=20):
    """
    Simulate environment means from a random linear SEM and build evaluation
    points as perturbed linear combinations of them.

    NOTE(review): this definition rebinds ``create_parasitic_dataset`` at module
    level, so the earlier keyword-rich function of the same name in this module
    is not reachable by name after import.

    Procedure:
      1. ``B``: standard-normal (n_features, n_features) matrix with its upper
         triangle zeroed (lower triangle and diagonal kept).
      2. Each of ``n_env`` training environments shifts one distinct coordinate
         among the first ``n_features - 1``; the last coordinate is never shifted
         directly and is returned as the response. Starting from N(0, 1)
         samples, ``X <- B @ X + epsilon + delta`` is iterated ``n_iter`` times
         with one fixed N(0, 0.3^2) noise draw, dividing by the Frobenius norm
         after each step. The environment's row is the mean over samples.
      3. ``2 * n_test`` weight rows ``alpha``: N(0, 0.1^2) noise plus
         ``pert_magnitude`` on one randomly chosen environment. Evaluation
         points are ``alpha @ trainX``.
      4. Training and evaluation points are centred by the training-row mean.

    Args:
        pert_magnitude: weight added to the chosen environment in each alpha row.
        seed: ``None`` (default) draws from the global ``np.random`` state, as
            before. An int uses a private ``np.random.RandomState(seed)``, which
            reproduces the stream of ``np.random.seed(seed)``.
        n_features, n_samples, n_env, n_test, n_iter: simulation sizes; the
            defaults are the previously hard-coded values.

    Returns:
        8-tuple ``(X_train, y_train, X_val, y_val, X_test, y_test, W_val, W_test)``.
        Features are every column but the last, and the response is the last
        column. The "val" pair comes from the first ``n_test`` alpha rows and the
        "test" pair from the remaining ``n_test``. ``W_*`` are those alpha rows,
        each of shape ``(n_test, n_env)``.

    Raises:
        ValueError: if ``n_env > n_features - 1`` (not enough distinct shift targets).
    """
    if n_env > n_features - 1:
        raise ValueError(
            f"n_env={n_env} requires n_features >= n_env + 1, got n_features={n_features}"
        )

    # The legacy module-level functions and a RandomState expose the same methods.
    rng = np.random if seed is None else np.random.RandomState(seed)

    B = np.tril(rng.standard_normal((n_features, n_features)))

    shift_targets = rng.choice(n_features - 1, n_env, replace=False)
    deltas = np.eye(n_features)[:, shift_targets]

    trainX = np.zeros((n_env, n_features))
    for env in range(n_env):
        X = rng.normal(0, 1, (n_features, n_samples))
        epsilon = rng.normal(0, 0.3, (n_features, n_samples))
        delta = deltas[:, env].reshape(-1, 1)

        for _ in range(n_iter):
            X = B @ X + epsilon + delta
            X = X / np.linalg.norm(X)

        trainX[env] = X.mean(axis=1)

    n_eval = 2 * n_test
    alpha = rng.normal(0, 0.1, (n_eval, n_env))
    chosen_env = rng.choice(n_env, (n_eval, 1))[:, 0]
    alpha += pert_magnitude * np.eye(n_env)[chosen_env]

    testX = alpha @ trainX

    centre = trainX.mean(axis=0)
    trainX = trainX - centre
    testX = testX - centre

    val_block, test_block = testX[:n_test], testX[n_test:]
    return (trainX[:, :-1], trainX[:, -1],
            val_block[:, :-1], val_block[:, -1],
            test_block[:, :-1], test_block[:, -1],
            alpha[:n_test], alpha[n_test:])

def train_and_evaluate(X_train, Y_train, X_test, Y_test, weights=None):
    """
    Fit an intercept-free least-squares ``LinearRegression`` and score it on held-out data.

    Args:
        X_train, Y_train: training design matrix and targets.
        X_test, Y_test: evaluation design matrix and targets.
        weights: optional per-sample weights, forwarded to ``fit`` as ``sample_weight``.

    Returns:
        tuple: ``(mse, Y_pred)``, i.e. the mean squared error on the test set and
        the test-set predictions.
    """
    regressor = LinearRegression(fit_intercept=False).fit(
        X_train, Y_train, sample_weight=weights
    )
    predictions = regressor.predict(X_test)
    return mean_squared_error(Y_test, predictions), predictions


def _load_cached_ood_ranking(cache_file, n_samples):
    """Return the cached descending-MSE ranking, or None if there is no usable cache."""
    if not os.path.exists(cache_file):
        print("Cache not found. Running LOOCV to compute OOD sample indices...")
        return None

    print(f"Loading cached OOD sample indices from {cache_file}...")
    # NOTE: pickle.load can execute arbitrary code; the cache is assumed to be a
    # trusted file written by _loocv_ood_ranking below.
    with open(cache_file, "rb") as f:
        cache_data = pickle.load(f)
    worst_indices = np.asarray(cache_data["worst_indices"])

    # The cache key only encodes `columns`, so a cache built from a different
    # subsample (e.g. another num_samples) would index the wrong rows.
    if len(worst_indices) != n_samples:
        print(f"Cached ranking in {cache_file} covers {len(worst_indices)} samples "
              f"but X has {n_samples}; recomputing...")
        return None
    return worst_indices


def _loocv_ood_ranking(X, Y, cache_file):
    """Rank samples by leave-one-out ERM error (highest first), cache the result, and return it."""
    mse_scores = np.zeros(len(X))

    # Leave-One-Out Cross Validation (LOOCV)
    kf = KFold(n_splits=len(X), shuffle=False)
    for train_index, test_index in tqdm(kf.split(X), total=len(X), desc="Computing OOD scores"):
        # Train and evaluate with unweighted ERM
        mse_unweighted, _ = train_and_evaluate(
            X[train_index], Y[train_index], X[test_index], Y[test_index]
        )
        mse_scores[test_index] = mse_unweighted

    worst_indices = np.argsort(-mse_scores)  # Descending order of MSE

    with open(cache_file, "wb") as f:
        pickle.dump({"mse_scores": mse_scores, "worst_indices": worst_indices}, f)
    print(f"Saved OOD sample indices to {cache_file}")
    return worst_indices


def find_most_ood_samples(
        X: np.ndarray,
        Y: np.ndarray,
        num_test: int,
        val: bool,
        path: str,
        columns: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Identifies the most Out-of-Distribution (OOD) samples based on highest LOOCV MSE.
    Uses caching to avoid recomputation.

    The ``2 * num_test`` highest-MSE samples are interleaved: even ranks go to
    test and odd ranks to validation. When ``val`` is False, the odd-ranked
    candidates are returned to training instead, so train and test never overlap.

    Args:
        X (np.ndarray): Feature matrix of interventions (N_samples, N_features).
        Y (np.ndarray): Target matrix of interventions (N_samples, N_targets).
        num_test (int): Number of test samples.
        val (bool): Whether to create a validation set.
        path (str): Path to dataset folder for caching results.
        columns (int): Number of columns used in dataset (for cache naming).

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            - train_inds (np.ndarray): Indices for training set (descending MSE order).
            - val_inds (np.ndarray): Indices for validation set (empty if val=False).
            - test_inds (np.ndarray): Indices for test set.

    Raises:
        ValueError: if there are fewer than ``2 * num_test + 1`` samples.
    """
    cache_file = os.path.join(path, f"{columns}_most_OOD_samples.pkl")

    worst_indices = _load_cached_ood_ranking(cache_file, len(X))
    if worst_indices is None:
        worst_indices = _loocv_ood_ranking(X, Y, cache_file)

    # Ensure sufficient samples for splitting
    if len(worst_indices) < 2 * num_test + 1:
        raise ValueError("Not enough samples in `worst_indices` for the specified splits.")

    selected_indices = worst_indices[:2 * num_test]
    test_inds = selected_indices[::2]
    val_candidates = selected_indices[1::2]
    rest = worst_indices[2 * num_test:]

    if val:
        return rest, val_candidates, test_inds

    # Without a validation set the odd-ranked candidates rejoin training.
    # Slicing worst_indices[num_test:] here would put even-ranked test samples
    # into train as well.
    train_inds = np.concatenate([val_candidates, rest])
    return train_inds, np.array([], dtype=int), test_inds


def _read_features_pickle(directory):
    """Unpickle ``<directory>/features.pkl``.

    NOTE: pickle.load can execute arbitrary code; only use this with trusted,
    locally produced feature files.
    """
    with open(os.path.join(directory, "features.pkl"), "rb") as f:
        return pickle.load(f)


def _load_raw_dataset(dataset_name, target_ind=0):
    """Return the raw ``(X, Y)`` for ``dataset_name`` before subsampling, scaling or splitting."""
    if dataset_name == "cell_painting":
        # Resolved against the repo root, like every other dataset, rather than
        # relative to the current working directory.
        X = np.concatenate(
            [np.load(os.path.join(PATH_TO_CELL_PAINTING, f"X_ChemBERT_part_{i}.npy"), allow_pickle=True)
             for i in range(4)],
            axis=0,
        )
        Y = np.load(os.path.join(PATH_TO_CELL_PAINTING, "Y_ChemBERT_full.npy"), allow_pickle=True)  # Does not need to be normalized. Done internally.
        return X, Y[:, target_ind].squeeze()
    if dataset_name == "K562":
        data = _read_features_pickle(PATH_TO_K562)
        return data['X_intervention'].values, data['Y_intervention'].values.squeeze()
    if dataset_name == "K562_scGPT":
        data = _read_features_pickle(PATH_TO_K562_scGPT)
        return data['X'].values, data['Y'].values.squeeze()
    if dataset_name == "yeast":
        data = _read_features_pickle(PATH_TO_YEAST)
        return data['X_intervention'], data['Y_intervention']
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected one of "
        "'cell_painting', 'K562', 'K562_scGPT', 'yeast'."
    )


def load_dataset(dataset_name, num_test=50, columns=None, val=False, seed=1, use_random_split=1, num_samples=None, target_ind=0):
    """
    Load a dataset, robust-scale its features, and split it into train/val/test.

    Args:
        dataset_name: one of "cell_painting", "K562", "K562_scGPT", "yeast".
        num_test: number of test samples (and of validation samples).
        columns: ``None`` or ``-1`` keeps every feature. Otherwise, keep the
            ``columns`` features with the largest variance after scaling.
        val: whether to return a validation split.
        seed: seeds both the optional subsample and the random split.
        use_random_split: ``0`` selects the most-OOD split from
            ``find_most_ood_samples``; any other value gives a random permutation split.
        num_samples: if given and smaller than the dataset, draw that many rows
            without replacement first.
        target_ind: column of Y to use (cell_painting only).

    Returns:
        ``(X_train, Y_train, X_val, Y_val, X_test, Y_test)``. ``X_val`` and
        ``Y_val`` are ``None`` when ``val`` is False. All Y splits are centred by
        the training-set mean of Y.

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    X, Y = _load_raw_dataset(dataset_name, target_ind)

    # Optional reproducible subsample.
    if num_samples is not None and num_samples < X.shape[0]:
        indices = np.random.RandomState(seed).choice(range(X.shape[0]), num_samples, replace=False)
        X = X[indices]
        Y = Y[indices]

    # Robust per-feature scaling: subtract the median, divide by the IQR (the
    # epsilon guards constant features), then clip to [-10, 10].
    obs_median = np.median(X, axis=0).reshape(1, -1)
    obs_iqr = np.subtract(*np.percentile(X, [75, 25], axis=0)).reshape(1, -1)
    X = np.clip((X - obs_median) / (obs_iqr + 1e-6), -10, 10)

    if columns is None or columns == -1:
        # Keep every feature; the count also names the OOD cache file.
        columns = X.shape[1]
    else:
        # Keep the `columns` highest-variance features.
        inds = np.argsort(np.var(X, 0))
        X = X[:, inds[-columns:]]

    if use_random_split == 0:
        # Find most OOD samples
        train_inds, val_inds, test_inds = find_most_ood_samples(
            X, Y, num_test, val, os.path.join(REPO_ROOT, dataset_name), columns
        )
    else:
        rng = np.random.default_rng(seed)
        shuffled = rng.permutation(len(X))
        # NOTE(review): the val slice is reserved even when val=False, so those
        # samples end up in no split; confirm this is intended.
        test_inds = shuffled[:num_test]
        val_inds = shuffled[num_test:2 * num_test]
        train_inds = shuffled[2 * num_test:]

    # Create datasets
    X_train, Y_train = X[train_inds], Y[train_inds]
    X_test, Y_test = X[test_inds], Y[test_inds]
    X_val, Y_val = (X[val_inds], Y[val_inds]) if val else (None, None)

    # center response with the training mean only (no test leakage)
    Y_train_mean = Y_train.mean()
    Y_train = Y_train - Y_train_mean
    if val:
        Y_val = Y_val - Y_train_mean
    Y_test = Y_test - Y_train_mean

    return X_train, Y_train, X_val, Y_val, X_test, Y_test
