import os
import json
import argparse
import numpy as np

def _item_sort_key(item_id):
    """Order item ids by their trailing ``_<number>`` suffix.

    Ids with equal numeric suffixes are tie-broken by the full id string, so
    the ordering never depends on set iteration order. Ids whose suffix is not
    an integer sort after all numeric ones, by string.
    """
    suffix = item_id.split("_")[-1]
    try:
        return (0, int(suffix), item_id)
    except ValueError:
        return (1, 0, item_id)


def _load_abilities(path):
    """Read subject ids and abilities from the first JSON line of *path*.

    The line must contain an ``"ability"`` list and a ``"subject_ids"``
    mapping whose keys are integer positions (as strings). The mapping is
    sorted by that integer so ``ordered_subject_ids[i]`` pairs with
    ``ability[i]``.

    Returns:
        ``(ordered_subject_ids, theta_values)``. ``theta_values`` is a float
        array built from the full ``"ability"`` list.
    """
    with open(path, "r", encoding="utf-8") as f:
        theta_data = json.loads(f.readline())
    by_position = sorted(theta_data["subject_ids"].items(), key=lambda kv: int(kv[0]))
    ordered_subject_ids = [sid for _, sid in by_position]
    theta_values = np.array(theta_data["ability"], dtype=float)
    return ordered_subject_ids, theta_values


def _load_responses(path):
    """Load a JSONL response file into a dense subjects x items float matrix.

    Each non-blank line is an object with ``"subject_id"`` and a
    ``"responses"`` mapping of item id to score. Cells a subject did not
    answer are NaN. Columns follow ``_item_sort_key`` order.

    Returns:
        ``(subject_ids, item_ids, response_matrix)``.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a doubled trailing newline) instead of
        # letting json.loads("") raise.
        raw_data = [json.loads(ln) for ln in f if ln.strip()]

    subject_ids = [d["subject_id"] for d in raw_data]
    item_ids = sorted({iid for d in raw_data for iid in d["responses"]}, key=_item_sort_key)
    item_index = {iid: j for j, iid in enumerate(item_ids)}

    response_matrix = np.full((len(subject_ids), len(item_ids)), np.nan, dtype=float)
    for i, d in enumerate(raw_data):
        for iid, val in d["responses"].items():
            response_matrix[i, item_index[iid]] = float(val)
    return subject_ids, item_ids, response_matrix


def _generate_samples(n_synthetic, theta_low, theta_high, anchor_thetas, anchor_signs,
                      p_base, lambda_i, item_ids):
    """Draw ``n_synthetic`` synthetic 0/1 response records.

    For each sample, a target ability ``theta_b`` is drawn uniformly between
    ``theta_low`` and ``theta_high``, and the anchor with the closest ability
    is selected. Each item's base rate ``p_base`` is shifted by
    ``lambda_i * sigmoid(theta_b - theta_anchor) * sign``. Here ``sign`` is
    the anchor's ``2*r - 1`` (0 where the anchor has no response). The result
    is clipped to [0.001, 0.999] and sampled as a Bernoulli draw.

    Uses NumPy's global legacy RNG so results match ``np.random.seed``.
    """
    samples = []
    for k in range(n_synthetic):
        theta_b = np.random.uniform(theta_low, theta_high)
        nearest = np.abs(anchor_thetas - theta_b).argmin()
        delta_theta = theta_b - anchor_thetas[nearest]
        sigmoid = 1 / (1 + np.exp(-delta_theta))
        p_b = p_base + lambda_i * sigmoid * anchor_signs[nearest]
        p_b = np.clip(p_b, 0.001, 0.999)
        draws = np.random.binomial(1, p_b)
        response = {iid: int(draws[j]) for j, iid in enumerate(item_ids)}
        samples.append({"subject_id": f"sample_{k}", "responses": response})
    return samples


def main(args):
    """Generate synthetic response records and write them as JSONL.

    Each option is taken from ``args`` when set, otherwise from an
    environment variable, otherwise from a built-in default:

    - ``input_response`` / ``INPUT_RESPONSE``: JSONL response file (required).
    - ``input_theta`` / ``INPUT_THETA_JSONL``: ability file (required).
    - ``output_path`` / ``OUTPUT_PATH``: output file (``aug_sampling.jsonl``).
    - ``n_synthetic`` / ``N_SYNTHETIC``: number of samples (200).
    - ``alpha`` / ``ALPHA``: perturbation scale (0.1).
    - ``epsilon`` / ``EPSILON``: stabiliser in ``alpha / (d_i + epsilon)`` (1e-6).
    - ``seed`` / ``SEED``: seed for NumPy's global RNG (42).

    Only subjects that appear in both input files are used as anchors, so
    exactly ``n_synthetic`` records are written.

    Raises:
        FileNotFoundError: if either input path is empty or missing.
        ValueError: if no subject from the ability file has responses.
    """
    input_response = args.input_response or os.environ.get("INPUT_RESPONSE", "")
    input_theta = args.input_theta or os.environ.get("INPUT_THETA_JSONL", "")
    output_path = args.output_path or os.environ.get("OUTPUT_PATH", "aug_sampling.jsonl")
    n_synthetic = args.n_synthetic if args.n_synthetic is not None else int(os.environ.get("N_SYNTHETIC", "200"))
    alpha = args.alpha if args.alpha is not None else float(os.environ.get("ALPHA", "0.1"))
    epsilon = args.epsilon if args.epsilon is not None else float(os.environ.get("EPSILON", "1e-6"))
    seed = args.seed if args.seed is not None else int(os.environ.get("SEED", "42"))

    np.random.seed(seed)

    if not input_theta or not os.path.exists(input_theta):
        raise FileNotFoundError(f"Ability file not found: {input_theta}")
    if not input_response or not os.path.exists(input_response):
        raise FileNotFoundError(f"Response matrix file not found: {input_response}")

    print(f"Loading abilities from '{input_theta}'...")
    ordered_subject_ids, theta_values = _load_abilities(input_theta)
    if len(ordered_subject_ids) != len(theta_values):
        print(
            f"Warning: {len(theta_values)} abilities but {len(ordered_subject_ids)} subject ids; "
            "only the paired prefix is used."
        )
    print("Abilities loaded.")

    print(f"Loading response matrix from '{input_response}'...")
    subject_ids, item_ids, response_matrix = _load_responses(input_response)
    print("Response matrix loaded.")

    # Anchors are subjects that have both an ability and a response row.
    # Restricting the nearest-neighbour search to them guarantees every
    # iteration yields a sample (previously unmatched picks were skipped).
    row_of = {sid: i for i, sid in enumerate(subject_ids)}
    anchors = [
        (theta, row_of[sid])
        for sid, theta in zip(ordered_subject_ids, theta_values)
        if sid in row_of
    ]
    if not anchors:
        raise ValueError(
            f"No subject from '{input_theta}' has responses in '{input_response}'."
        )
    n_unmatched = min(len(ordered_subject_ids), len(theta_values)) - len(anchors)
    if n_unmatched:
        print(
            f"Warning: {n_unmatched} subject(s) from the ability file have no responses "
            "and will not be used as anchors."
        )
    anchor_thetas = np.array([theta for theta, _ in anchors], dtype=float)
    anchor_rows = [row for _, row in anchors]

    # Per-item quantities that do not change between samples.
    p_base = np.nanmean(response_matrix, axis=0)
    d_i = 1.0 - p_base
    lambda_i = alpha / (d_i + epsilon)
    # 2*r - 1 per cell; missing (NaN) responses contribute no shift.
    sign_matrix = np.where(np.isnan(response_matrix), 0.0, 2 * response_matrix - 1)
    anchor_signs = sign_matrix[anchor_rows]

    print(f"\nGenerating {n_synthetic} synthetic samples...")
    augmented = _generate_samples(
        n_synthetic,
        theta_values.min(),
        theta_values.max(),
        anchor_thetas,
        anchor_signs,
        p_base,
        lambda_i,
        item_ids,
    )

    with open(output_path, "w", encoding="utf-8") as f:
        for entry in augmented:
            f.write(json.dumps(entry) + "\n")

    print(f"\nSynthetic data saved to {output_path}")

if __name__ == "__main__":
    # Every option defaults to None so main() can fall back to the matching
    # environment variable and then to its built-in default.
    cli = argparse.ArgumentParser()
    cli_options = (
        ("--input-response", None),
        ("--input-theta", None),
        ("--output-path", None),
        ("--n-synthetic", int),
        ("--alpha", float),
        ("--epsilon", float),
        ("--seed", int),
    )
    for flag, converter in cli_options:
        cli.add_argument(flag, type=converter)
    main(cli.parse_args())
