#!/usr/bin/env python3
"""
Create minimal preprocessed data for adult so that run_exp_multi_method.sh can
run training without the full parent-repo data prep.

Reads clean .npy files from <data_root>/adult/ (must exist). <data_root> defaults
to the directory containing this script and can be overridden with --data_root.
Applies MCAR missingness (p, seed) and writes to
<data_root>/adult_semi_xy/<pattern>_p<p>_<seed>/.

Usage (from AugMask_share):
  python data/create_minimal_preprocessed.py
  python data/create_minimal_preprocessed.py --pattern NU2 --p 0.9 --seed 0

Requires data/adult/ with .npy files and info.json (e.g. from parent repo data prep).
"""

import argparse
import json
import os
import shutil
import numpy as np


def get_mcar_mask(shape, p, seed):
    """Draw a reproducible boolean mask whose cells are True with probability ``p``.

    One uniform sample in [0, 1) is drawn per cell from a generator seeded
    with ``seed``; a cell is True when its sample falls below ``p``. Equal
    arguments therefore always yield the same mask.

    Args:
        shape: Shape of the mask to generate.
        p: Per-cell probability of a True entry.
        seed: Seed handed to ``np.random.default_rng``.

    Returns:
        Boolean array of the requested shape.
    """
    generator = np.random.default_rng(seed)
    uniform_draws = generator.random(shape)
    return uniform_draws < p


def _split_num_cat(X, obs_mask, n_num, num_dtype):
    """Split corrupted features and their observation mask into numeric / categorical parts."""
    X_num = X[:, :n_num].astype(num_dtype)
    X_cat = X[:, n_num:].copy()
    mask_num = obs_mask[:, :n_num].astype(num_dtype)
    mask_cat = obs_mask[:, n_num:]
    # Re-apply the missing markers after the dtype conversion / copy so the
    # saved arrays use NaN (numeric) and the string "nan" (categorical).
    X_num[mask_num == 0] = np.nan
    X_cat[mask_cat == 0] = "nan"
    return X_num, X_cat, mask_num, mask_cat


def corrupt_and_save(data_root, dataname="adult", mode="both", pattern="NU2", p=0.9, seed=0):
    """Apply MCAR missingness to a clean dataset and save the corrupted copy.

    Loads ``X_num_{train,test}.npy``, ``X_cat_{train,test}.npy``,
    ``y_{train,test}.npy`` and ``info.json`` from ``<data_root>/<dataname>/``,
    masks every cell of the stacked ``[X_num | X_cat | y]`` matrix
    independently with probability ``p``, and writes the corrupted arrays and
    their observation masks (1.0 = observed, 0.0 = missing) to
    ``<data_root>/<dataname>_semi_xy/<pattern>_p<p>_<seed>/``.

    The target arrays must be 2-D (one column per target) because they are
    concatenated to the feature matrix along axis 1.

    Args:
        data_root: Directory containing the ``<dataname>/`` clean-data folder.
        dataname: Name of the dataset sub-directory.
        mode: Accepted for interface compatibility; not used by this function.
        pattern: Label used only in the output directory name. The
            missingness applied is always MCAR over features and target.
        p: Per-cell probability of being marked missing.
        seed: Seed for the mask generator.

    Returns:
        Path of the directory the preprocessed files were written to.

    Raises:
        FileNotFoundError: If ``<data_root>/<dataname>/`` does not exist.
    """
    clean_dir = os.path.join(data_root, dataname)
    if not os.path.isdir(clean_dir):
        raise FileNotFoundError(
            f"Clean data dir not found: {clean_dir}. "
            "Create it first (e.g. from parent repo data prep)."
        )

    X_num_train = np.load(os.path.join(clean_dir, "X_num_train.npy"))
    X_num_test = np.load(os.path.join(clean_dir, "X_num_test.npy"))
    X_cat_train = np.load(os.path.join(clean_dir, "X_cat_train.npy"), allow_pickle=True)
    X_cat_test = np.load(os.path.join(clean_dir, "X_cat_test.npy"), allow_pickle=True)
    y_train = np.load(os.path.join(clean_dir, "y_train.npy"), allow_pickle=True)
    y_test = np.load(os.path.join(clean_dir, "y_test.npy"), allow_pickle=True)
    with open(os.path.join(clean_dir, "info.json"), encoding="utf-8") as f:
        info = json.load(f)
    task_type = info["task_type"]
    is_regression = task_type == "regression"

    # Missing numeric cells are stored as NaN, which integer dtypes cannot
    # represent (the cast/assignment raises ValueError). Keep floating dtypes
    # as they are and promote integer inputs to float64.
    num_dtype = X_num_train.dtype
    if np.issubdtype(num_dtype, np.integer):
        num_dtype = np.dtype(np.float64)

    train_X = np.concatenate([X_num_train, X_cat_train], 1)
    test_X = np.concatenate([X_num_test, X_cat_test], 1)
    clean_X = np.concatenate([train_X, test_X], 0)
    clean_y = np.concatenate([y_train, y_test], 0)
    clean_y = clean_y.astype(float) if is_regression else clean_y.astype(str)
    full = np.concatenate([clean_X, clean_y], 1)

    n_feat = clean_X.shape[1]
    n_num = X_num_train.shape[1]
    n_train_X = train_X.shape[0]
    n_train_y = y_train.shape[0]

    # MCAR over the whole matrix, target columns included.
    # NOTE(review): the original comment says this mirrors prepare_data's
    # NU1 / simple "both" mode; that cannot be confirmed from this file.
    miss_mask = get_mcar_mask(full.shape, p, seed)
    full[miss_mask] = np.nan
    obs_mask = (~miss_mask).astype(float)

    corrupt_X, corrupt_y = full[:, :n_feat], full[:, n_feat:]
    mask_x, mask_y = obs_mask[:, :n_feat], obs_mask[:, n_feat:]

    X_num_train_, X_cat_train_, mask_num_train_, mask_cat_train_ = _split_num_cat(
        corrupt_X[:n_train_X], mask_x[:n_train_X], n_num, num_dtype
    )
    X_num_test_, X_cat_test_, mask_num_test_, mask_cat_test_ = _split_num_cat(
        corrupt_X[n_train_X:], mask_x[n_train_X:], n_num, num_dtype
    )

    train_y_ = corrupt_y[:n_train_y]
    test_y_ = corrupt_y[n_train_y:]
    train_mask_y_ = mask_y[:n_train_y]
    test_mask_y_ = mask_y[n_train_y:]
    if is_regression:
        train_y_[train_mask_y_ == 0] = np.nan
        test_y_[test_mask_y_ == 0] = np.nan
        train_y_ = train_y_.astype(np.float32)
        test_y_ = test_y_.astype(np.float32)
    else:
        # Non-regression targets are kept as strings; missing ones become "nan".
        train_y_[train_mask_y_ == 0] = "nan"
        test_y_[test_mask_y_ == 0] = "nan"

    # The target's mask is prepended to the mask of the feature group it is
    # stored with: categorical for classification tasks, numeric otherwise.
    if task_type in ("binclass", "multiclass", "classification"):
        mask_cat_train_ = np.concatenate([train_mask_y_, mask_cat_train_], 1)
        mask_cat_test_ = np.concatenate([test_mask_y_, mask_cat_test_], 1)
    else:
        mask_num_train_ = np.concatenate([train_mask_y_, mask_num_train_], 1)
        mask_num_test_ = np.concatenate([test_mask_y_, mask_num_test_], 1)

    save_dir = os.path.join(data_root, f"{dataname}_semi_xy")
    exp_dir = os.path.join(save_dir, f"{pattern}_p{p}_{seed}")
    os.makedirs(exp_dir, exist_ok=True)
    shutil.copyfile(os.path.join(clean_dir, "info.json"), os.path.join(exp_dir, "info.json"))

    outputs = {
        "X_num_train.npy": X_num_train_,
        "X_num_test.npy": X_num_test_,
        "X_cat_train.npy": X_cat_train_,
        "X_cat_test.npy": X_cat_test_,
        "y_train.npy": train_y_,
        "y_test.npy": test_y_,
        "mask_num_train.npy": mask_num_train_,
        "mask_num_test.npy": mask_num_test_,
        "mask_cat_train.npy": mask_cat_train_,
        "mask_cat_test.npy": mask_cat_test_,
        "y_mask_train.npy": train_mask_y_,
        "y_mask_test.npy": test_mask_y_,
    }
    for filename, array in outputs.items():
        np.save(os.path.join(exp_dir, filename), array)
    print(f"Saved preprocessed data to {exp_dir}")
    return exp_dir


def main():
    """Parse the command-line options and generate one corrupted dataset copy.

    When ``--data_root`` is not given, the directory containing this script
    is used as the data root. ``mode`` is always passed as ``"both"``.
    """
    cli = argparse.ArgumentParser(description="Create minimal preprocessed adult data.")
    cli.add_argument(
        "--data_root",
        type=str,
        default=None,
        help="Root data dir (default: parent of this script = AugMask_share/data)",
    )
    # Remaining options share the same shape: flag, value type, default.
    simple_options = (
        ("--dataname", str, "adult"),
        ("--pattern", str, "NU2"),
        ("--p", float, 0.9),
        ("--seed", int, 0),
    )
    for flag, value_type, fallback in simple_options:
        cli.add_argument(flag, type=value_type, default=fallback)
    opts = cli.parse_args()

    root = opts.data_root
    if root is None:
        root = os.path.dirname(os.path.abspath(__file__))
    corrupt_and_save(
        root,
        dataname=opts.dataname,
        mode="both",
        pattern=opts.pattern,
        p=opts.p,
        seed=opts.seed,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
