import argparse
from src.data import create_parasitic_dataset, load_yeast, load_K562_scGPT, load_RPE1, load_cell_painting, load_K562, load_dataset
from src.embeddings import create_pca_embeddings_with_noise
from src.mlp import train_mlp_model, train_and_evaluate_mlp_attention
from src.imputation import evaluate_imputations
import torch
import random
import numpy as np
from src.utils import save_results_to_pickle


def set_global_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator, PyTorch's
    default generator and the current CUDA device's generator, then puts
    cuDNN into deterministic mode with auto-tuning (benchmarking) disabled.

    Args:
        seed: Integer seed handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,             # Python's random module
        np.random.seed,          # NumPy global RNG
        torch.manual_seed,       # PyTorch default generator
        torch.cuda.manual_seed,  # generator of the current CUDA device
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def resolve_device(device_arg: str) -> torch.device:
    """Translate the ``--device`` command-line string into a ``torch.device``.

    Args:
        device_arg: Either the literal ``"cpu"`` or a string of digits naming
            a CUDA device index (e.g. ``"0"``).

    Returns:
        ``torch.device("cpu")`` for ``"cpu"``; ``torch.device("cuda:<index>")``
        when CUDA is available and the index is below the visible device
        count; otherwise the CPU device, after printing a warning.

    Raises:
        ValueError: If ``device_arg`` is neither ``"cpu"`` nor all digits.
    """
    cpu_device = torch.device("cpu")

    if device_arg == "cpu":
        return cpu_device

    if not device_arg.isdigit():
        raise ValueError(f"Invalid device argument: {device_arg}. Use 'cpu' or a GPU index like '0'.")

    index = int(device_arg)
    gpu_usable = torch.cuda.is_available() and index < torch.cuda.device_count()
    if gpu_usable:
        return torch.device(f"cuda:{index}")

    # Requested GPU is missing: degrade gracefully instead of aborting the run.
    print(f"[WARNING] Requested GPU cuda:{index} not available. Falling back to CPU.")
    return cpu_device

# Experiment entry point; called from the `__main__` guard at the bottom of this file with the parsed CLI arguments.
def main(args):
    """Run one experiment: load data, perturb X, fit/evaluate, print, save.

    Steps, in order:
      1. Seed all RNGs from ``args.seed``.
      2. Load train/val/test splits for ``args.dataset``.
      3. Add Gaussian noise to each X split, scaled by
         ``args.noise_scale_embeddings`` (treated as 0.0 when omitted).
      4. Convert all arrays to float32 tensors on the requested device.
      5. Run the branch selected by ``args.method`` ('mlp', 'none' or
         'attention') to obtain a ``losses`` dict.
      6. Print a metrics table and persist everything through
         ``save_results_to_pickle``.

    Args:
        args: Namespace of parsed command-line options. Attributes read here:
            seed, dataset, num_test, x_columns, use_random_split, target_indx,
            pert_magnitude, noise_scale_embeddings, device, method,
            mlp_hidden_units, lr, epochs, k_extra.
    """
    GLOBAL_SEED = args.seed

    set_global_seed(GLOBAL_SEED)

    # Outputs that only some dataset/method combinations produce. They stay
    # None otherwise and are handed to save_results_to_pickle as such.
    W_true_test = None
    W_true_val = None
    X_train_loss = None
    X_test_loss = None
    X_val_loss = None
    W_ridge = None
    W_sparse = None
    W_lasso = None
    X_test_approx = None

    if args.dataset in ["yeast", "K562", "K562_scGPT", "cell_painting"]:
        # Only cell_painting requests a fixed number of samples.
        num_samples = 1000 if args.dataset == "cell_painting" else None
        X_train, Y_train, X_val, Y_val, X_test, Y_test = load_dataset(
            args.dataset,
            num_test=args.num_test,
            columns=args.x_columns,
            val=True,
            seed=args.seed,
            use_random_split=args.use_random_split,
            num_samples=num_samples,
            target_ind=args.target_indx,
        )
    elif args.dataset == 'RPE1_binary' or args.dataset == 'RPE1_rank':
        X_train, Y_train, X_val, Y_val, X_test, Y_test = load_RPE1(
            num_test=args.num_test,
            columns=args.x_columns,
            val=True,
            dataset=args.dataset,
            use_random_split=args.use_random_split,
        )
    else:
        # NOTE(review): only pert_magnitude is forwarded. The seed and the other
        # synthetic CLI options (--num_train, --num_test, --relevant_fraction,
        # --sparsity, --y_noise_std, --SNR, --x_columns) are not passed, so the
        # generator presumably falls back to its own defaults -- confirm.
        X_train, Y_train, X_val, Y_val, X_test, Y_test, W_true_val, W_true_test = create_parasitic_dataset(
            pert_magnitude=args.pert_magnitude
        )

    # --- Noise injection on X -------------------------------------------
    # --noise_scale_embeddings has no argparse default, so it is None when the
    # flag is omitted; multiplying X_std by None would raise TypeError.
    noise_scale = 0.0 if args.noise_scale_embeddings is None else args.noise_scale_embeddings

    # The tensor conversion below tolerates a missing validation split, so the
    # statistics and the noise step skip absent splits as well.
    splits = [split for split in (X_train, X_val, X_test) if split is not None]
    stacked = np.concatenate(splits)
    X_mean = np.mean(stacked, axis=0)
    X_std = np.std(stacked, axis=0)
    for split in splits:
        # NOTE(review): the noise is centred on X_mean (not 0), so even a zero
        # scale shifts X by its column means; and every split re-seeds with the
        # same seed, so the splits receive the same random stream. Both are
        # kept as-is to leave existing results unchanged -- confirm intent.
        noise = np.random.RandomState(args.seed).normal(
            X_mean, scale=X_std * noise_scale, size=split.shape
        )
        split += noise  # in-place: updates the array X_train/X_val/X_test refers to

    # The PCA embedding step is disabled; the (noisy) features double as embeddings.
    E_train, E_val, E_test = X_train, X_val, X_test
    device = resolve_device(args.device)
    print(f"[INFO] Using device: {device}")

    def to_tensor(array):
        # numpy array -> float32 torch tensor on the selected device
        return torch.from_numpy(array).float().to(device)

    X_train, Y_train = to_tensor(X_train), to_tensor(Y_train)
    X_test, Y_test = to_tensor(X_test), to_tensor(Y_test)
    E_train, E_test = to_tensor(E_train), to_tensor(E_test)
    if X_val is not None:
        X_val = to_tensor(X_val)
        Y_val = to_tensor(Y_val)
        E_val = to_tensor(E_val)
    if W_true_test is not None:
        # Ground-truth weights are only set by the synthetic branch above.
        W_true_val = to_tensor(W_true_val)
        W_true_test = to_tensor(W_true_test)

    if args.method == 'mlp':
        X_test_approx, X_train_loss, X_test_loss, X_val_loss = train_mlp_model(
            E_train=E_train,
            X_train=X_train,
            E_test=E_test,
            X_test=X_test,
            mlp_hidden_units=args.mlp_hidden_units,
            E_val=E_val,
            X_val=X_val,
            lr=args.lr,
            epochs=args.epochs,
            seed=GLOBAL_SEED,
            device=device
        )
    elif args.method == 'none':
        # No imputation model: evaluate on the test features directly.
        X_test_approx = X_test

    if args.method == 'mlp' or args.method == 'none':
        # Shared regularisation grid; each consumer gets its own list copy.
        reg_grid = (1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100)
        losses, W_ridge, W_sparse, W_lasso = evaluate_imputations(
            X_train=X_train,
            Y_train=Y_train,
            X_val=X_val,
            Y_val=Y_val,
            X_test_approx=X_test_approx,
            Y_test=Y_test,
            criterion=torch.nn.MSELoss(),
            lambda_reg_grid=list(reg_grid),
            ridge_param_grid={"lambda_": list(reg_grid)},   # for yeast 0.01,
            ols_ridge_param_grid={"alpha": list(reg_grid)},
            lasso_param_grid={"lam": list(reg_grid), "lr": [1e-6, 1e-4, 1e-2, 1e-1], "max_iter": [2000, 5000]},
            attention_param_grid={"lambda_reg": [1e-6], "lr": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3], "hidden_dim": [32, 64, 128]},
            W_true_val=W_true_val,
            W_true_test=W_true_test
        )

    if args.method == 'attention':
        losses = train_and_evaluate_mlp_attention(
            E_train,
            X_train,
            Y_train,
            E_test,
            Y_test,
            args.mlp_hidden_units,
            E_val,
            Y_val,
            lambda_reg=1e-6,
            lr=args.lr,
            epochs=args.epochs,
            device=device
        )

    # `losses` is indexed below as {method_name: {'mse': ..., 'pearson_corr': ...}}.
    print(losses)
    print("\nEvaluation Metrics (formatted to 4 decimal places):")
    print("=" * 60)
    print(f"{'Method':<20}{'MSE':<20}{'Pearson Correlation':<20}")
    print("-" * 60)
    for method, metrics in losses.items():
        mse = metrics['mse']
        pearson_corr = metrics['pearson_corr']
        print(f"{method:<20}{mse:<20.4f}{pearson_corr:<20.4f}")
    print("=" * 60)

    # Save losses, W estimates, W_true, X_test, X_test_approx.
    save_results_to_pickle(
        global_seed=GLOBAL_SEED,
        method=args.method,
        hidden_units=args.mlp_hidden_units,
        dataset=args.dataset,
        perturbation=args.pert_magnitude,
        noise_scale_embeddings=noise_scale,  # effective value (0.0 when the flag was omitted)
        k_extra=args.k_extra,
        num_test=args.num_test,
        n_components=args.x_columns,  # NOTE(review): receives x_columns, not args.n_components -- confirm
        epochs=args.epochs,
        lr=args.lr,
        use_random_split=args.use_random_split,
        losses=losses,
        W_ridge=W_ridge,
        W_sparse=W_sparse,
        W_lasso=W_lasso,
        W_true_val=W_true_val,
        W_true_test=W_true_test,
        X_test=X_test,
        X_test_approx=X_test_approx,
        X_train_loss=X_train_loss,
        X_test_loss=X_test_loss,
        X_val_loss=X_val_loss,
        target_indx=args.target_indx
    )


def build_arg_parser():
    """Build the command-line parser for a single experiment run.

    ``--seed``, ``--dataset`` and ``--method`` are required. Options declared
    without a ``default`` are ``None`` on the parsed namespace when omitted.

    Returns:
        argparse.ArgumentParser: The configured parser (nothing is parsed here).
    """
    parser = argparse.ArgumentParser(description="Train and evaluate a model with embeddings and datasets.")

    # Set seed
    parser.add_argument("--seed", type=int, required=True, help="Seed for reproducibility.")

    # Set device
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use: 'cpu' or GPU index (e.g., 0, 1, 2)"
    )

    # Dataset selection; the help text is generated from the choices so the
    # two cannot drift apart.
    dataset_choices = ['yeast', 'RPE1_binary', 'RPE1_rank', 'K562_scGPT', 'K562', 'synthetic', 'cell_painting']
    parser.add_argument("--dataset", type=str, choices=dataset_choices, required=True,
                        help=f"Dataset to use; one of: {', '.join(dataset_choices)}.")

    # Default parameters
    parser.add_argument("--num_train", type=int, default=1000, help="Number of training samples. (synthetic only)")
    parser.add_argument("--num_test", type=int, default=50, help="Number of test samples.")

    # Synthetic-specific dataset parameters
    parser.add_argument("--y_noise_std", type=float, default=1, help="Standard deviation of noise in Y_train.")  # Default is 1
    parser.add_argument("--relevant_fraction", type=float, default=0.01, help="Fraction of relevant training samples (synthetic only).")
    parser.add_argument("--SNR", type=float, default=2, help="Signal-to-noise ratio (synthetic only).")
    parser.add_argument("--sparsity", type=float, default=0.9, help="Sparsity level of the dataset (synthetic only).")
    parser.add_argument("--pert_magnitude", type=float, default=1, help="Perturbation magnitude for dataset (synthetic only).")

    # Embeddings arguments
    parser.add_argument("--n_components", type=float, help="Number of PCA components or variance retained.")
    parser.add_argument("--k_extra", type=int, help="Number of noise-only dimensions to add to embeddings.")
    parser.add_argument("--noise_scale_embeddings", type=float, help="Scale of noise added to embeddings.")
    parser.add_argument("--x_columns", type=int, help="Number of columns to use from dataset.")  # 100 for synthetic and 10 for real yeast

    # MLP / NN arguments
    parser.add_argument(
        "--method",
        type=str,
        choices=['mlp', 'attention', 'none'],  # Include 'none' as an explicit option
        required=True,
        help="Method to use for training: 'mlp', 'attention', or 'none'."
    )
    parser.add_argument("--mlp_hidden_units", type=int, help="Number of hidden units in the MLP.")
    parser.add_argument("--lr", type=float, help="Learning rate for training.")
    parser.add_argument("--epochs", type=int, help="Number of epochs for training.")

    parser.add_argument("--target_indx", type=int, default=1, help="Target index for cell painting dataset.")
    # Integer toggle (default 1) rather than a boolean flag.
    parser.add_argument("--use_random_split", type=int, default=1, help="Toggle to use random split or most OOD split.")

    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    main(args)
