
import os
import sys
import datetime
import pytz
import numpy as np
import pandas as pd
import ray
import mlflow
from scipy.stats import spearmanr
from mmirt.utils.masks import (
    create_validation_mask_for_items,
    create_test_mask_for_groups,
)

# Choose the IRT implementation and its split configuration from the CLI
# flags.  Flags are tested in a fixed priority order, so when several are
# given the first match below wins; with no model flag the "bsplit_3pl"
# default is used.  Each branch imports its own ``Standard3PLIRT`` before
# setting the configuration values.
if "--mmmirt" in sys.argv:
    from mmirt.irt_onelayer_asplit_val import Standard3PLIRT
    irt_name = "mmmirt"
    split_difficulty = split_ability = split_a = True
elif "--noasplit" in sys.argv:
    # Recognised flag without a backing model: fail loudly instead of
    # falling through to another variant.
    raise NotImplementedError("noasplit 3PL is not implemented yet.")
elif "--nosplit" in sys.argv:
    from mmirt.irt_onelayer_nostlict import Standard3PLIRT
    irt_name = "nosplit_3pl"
    split_difficulty = split_ability = split_a = False
elif "--mmirt" in sys.argv:
    from mmirt.mmirt_inner_product import Standard3PLIRT
    irt_name = "mmirt"
    split_difficulty = split_ability = split_a = True
else:
    from mmirt.irt_onelayer_asplit_val import Standard3PLIRT
    irt_name = "bsplit_3pl"
    split_difficulty = split_ability = split_a = True

# Dataset presets keyed by CLI flag, listed in priority order (the first flag
# found on the command line wins).  Each preset is
# (dataset_name, data_dir, file1_path, file2_path).
_DATASET_PRESETS = (
    ("--mathvista", (
        "mathvista",
        "data_mathvista",
        "merged_mathvista/mathvista_combined.csv",
        "merged_mathvista/merged_mathvista_shuffle.csv",
    )),
    ("--vqa", (
        "vqa",
        "data_vqa",
        "vqa_shuffle_normal/vqa_merged.csv",
        "vqa_shuffle_normal/vqa_shuffle_filtered.csv",
    )),
    ("--seedbench", (
        "seed",
        "data_seed",
        "seed_normal_shuffled/seed_normal_all_clean.csv",
        "seed_normal_shuffled/seed_shuffled.csv",
    )),
)

# Used when no dataset flag is present on the command line.
_DEFAULT_DATASET_PRESET = (
    "mmmu",
    "data_mmmu_0220",
    "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv",
    "mmmu_normal_shuffle/shufffle_mmmu.csv",
)

dataset_name, data_dir, file1_path, file2_path = next(
    (preset for flag, preset in _DATASET_PRESETS if flag in sys.argv),
    _DEFAULT_DATASET_PRESET,
)

# Build the run names.  The MLflow experiment name carries a month/day and
# hour/minute stamp taken from the local clock; the CSV base name has no
# timestamp, so reruns with the same flags reuse the same file name.
now = datetime.datetime.now()
mmdd, hhmm = f"{now:%m%d}", f"{now:%H%M}"
csv_name = f"mix_{dataset_name}_{irt_name}"
experiment_name = f"{csv_name}_{mmdd}_{hhmm}"

print("Loading data...")
# Both CSVs are read with their first column as the row index; rows are the
# "students" and columns are the items (see `students` / `items1` / `items2`
# below).
df1 = pd.read_csv(file1_path, index_col=0)
df2 = pd.read_csv(file2_path, index_col=0)

# Label-based selection with `.loc[list_of_labels]` returns every matching row
# for each requested label, so a repeated label would silently multiply rows
# during the alignment below.  Reject that case up front.
if not (df1.index.is_unique and df2.index.is_unique):
    raise ValueError("df1 and df2 must not contain duplicated students (row indices).")

if set(df1.index) != set(df2.index):
    raise ValueError("df1 and df2 do not have the same students (row indices).")

# Put both tables in the same sorted row order so that row i of `resp1` and
# row i of `resp2` refer to the same student.
common_idx = sorted(df1.index)
df1 = df1.loc[common_idx]
df2 = df2.loc[common_idx]

# Baseline ranking of the rows by their mean score over all df1 columns.
# The series and its index are named explicitly so the resulting columns are
# always "model_id" and "mean_score", whether or not the CSV's index column
# had a header (renaming the default "index" column would miss a named index).
mean_scores = df1.mean(axis=1)
base_rank_df = (
    mean_scores
    .rename("mean_score")
    .rename_axis("model_id")
    .sort_values(ascending=False)
    .reset_index()
)
# Rank 1 is the highest mean score; ties share the lowest rank of their group.
base_rank_df["rank"] = (
    base_rank_df["mean_score"]
    .rank(ascending=False, method="min")
    .astype(int)
)

# Plain arrays / lists handed to the Ray tasks.
resp1 = df1.values
resp2 = df2.values
students = df1.index.tolist()
items1 = df1.columns.tolist()
items2 = df2.columns.tolist()

print("Initializing Ray...")
# Start (or connect to) Ray before any remote task is submitted.
# NOTE(review): ignore_reinit_error=True presumably makes a repeated
# ray.init() in an already-initialised process a no-op instead of an error —
# confirm against the Ray documentation.
ray.init(ignore_reinit_error=True)

# Candidate values for the three scale hyper-parameters.  With "--nosplit"
# only the single value 2 is used; otherwise each list is 2, 4, 8, 16.
# (For a quick smoke run, shrink the full grid to [2].)
_grid_values = [2] if "--nosplit" in sys.argv else [2, 4, 8, 16]

# Each list is its own copy so mutating one never affects the others.
theta_max_list = list(_grid_values)
difficulty_base_max_list = list(_grid_values)
a_scale_list = list(_grid_values)

@ray.remote(num_cpus=1)
def run_combo(mix_ratio, seed,
              theta_max, difficulty_base_max, a_scale,
              resp1, resp2, students, items1, items2):
    """Fit one IRT model on a hybrid item set and return its AUC metrics.

    A random subset of the columns of ``resp2`` (``floor(mix_ratio * n)`` of
    them) is appended to all columns of ``resp1``.  Validation and test masks
    are drawn on the combined matrix, then the module-level
    ``Standard3PLIRT`` is fitted and evaluated.  The module-level
    ``split_difficulty`` and ``split_ability`` flags are forwarded to the
    model.

    Args:
        mix_ratio: Fraction of ``resp2`` columns to mix in; must lie in
            [0, 1] because the columns are sampled without replacement.
        seed: Seed for the column sampling, the validation-item sampling and
            the test-mask helper.
        theta_max: Forwarded to the model as ``theta_max``.
        difficulty_base_max: Forwarded to the model as both
            ``difficulty_base_max`` and ``difficulty_other_max``.
        a_scale: Forwarded to the model as ``a_scale``.
        resp1: 2-D response array whose columns are labelled by ``items1``.
        resp2: 2-D response array with the same rows, columns labelled by
            ``items2``.
        students: Row labels shared by ``resp1`` and ``resp2``.
        items1: Column labels of ``resp1``.
        items2: Column labels of ``resp2``.

    Returns:
        dict with the inputs ``mix_ratio``, ``seed``, ``theta_max``,
        ``difficulty_base_max``, ``a_scale`` and the metrics ``val_auc``,
        ``test_auc``, ``test_shuffle_auc``, ``test_normal_auc``.
    """
    rng = np.random.RandomState(seed)

    # Pick the resp2 columns to mix in and build the hybrid response matrix
    # together with its matching list of item names.
    k = int(np.floor(resp2.shape[1] * mix_ratio))
    idx = rng.choice(resp2.shape[1], size=k, replace=False)
    hybrid = np.concatenate([resp1, resp2[:, idx]], axis=1)
    item_names = items1 + [items2[i] for i in idx]

    # Sample the validation items.  Sampling without replacement raises when
    # asked for more items than exist, so the usual 100 is capped at the
    # number of available items.
    total_items = len(item_names)
    val_indices = rng.choice(
        total_items, size=min(100, total_items), replace=False
    )
    val_item_names = [item_names[i] for i in val_indices]
    mask, _ = create_validation_mask_for_items(
        hybrid, item_names,
        val_item_names=val_item_names,
        val_percentage=0.2
    )
    # Entries equal to 1 are relabelled -1 while the test mask is drawn and
    # restored to 1 afterwards.
    # NOTE(review): presumably this keeps create_test_mask_for_groups from
    # reusing those cells — confirm against the helper's contract.
    mask[mask == 1] = -1

    mask, _ = create_test_mask_for_groups(
        mask,
        response_data=hybrid,
        test_percentage=0.1,
        item_names=item_names,
        seed=seed
    )
    mask[mask == -1] = 1

    # NOTE(review): device="cuda" is requested while the task only reserves
    # one CPU from Ray — confirm GPU visibility/placement is as intended.
    model = Standard3PLIRT(
        response_data=hybrid,
        student_names=students,
        test_names=item_names,
        train_mask=mask,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=1e-3,
        batch_size=256,
        max_epochs=5000,
        device="cuda",
        eps=1e-3,
        embedding_dim=32,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        a_scale=a_scale,
        theta_init=0.1,
        use_guessing=False
    )
    model.fit()
    # NOTE(review): both metric sets come from argument-less
    # evaluate_predictions() calls, so "val_auc" and "test_auc" are produced
    # by identical calls — confirm whether a split-specific evaluation was
    # intended.  Two calls are kept to match the previous call sequence.
    val_metrics = model.evaluate_predictions()
    test_metrics = model.evaluate_predictions()

    return {
        "mix_ratio": mix_ratio,
        "seed": seed,
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "a_scale": a_scale,
        "val_auc": val_metrics["roc_auc"],
        "test_auc": test_metrics["roc_auc"],
        "test_shuffle_auc": test_metrics["roc_auc_shuffle"],
        "test_normal_auc": test_metrics["roc_auc_normal"],
    }


print("Launching tasks...")
# Store the shared, read-only inputs in Ray's object store once and pass the
# references.  Passing the same arrays by value would serialise them again
# for every submitted task; Ray resolves top-level ObjectRef arguments to
# their values before the task body runs.
shared_refs = [
    ray.put(obj) for obj in (resp1, resp2, students, items1, items2)
]

tasks = []
# Mix ratios 0.00, 0.05, ..., 1.00, each with seeds 0..6.
# NOTE: only theta_max_list is iterated; the same value is passed as
# theta_max, difficulty_base_max and a_scale, so difficulty_base_max_list and
# a_scale_list are not searched independently here.
for mr in np.arange(0, 101, 5) / 100:
    for seed in range(7):
        for tm in theta_max_list:
            tasks.append(
                run_combo.remote(
                    mr, seed,
                    tm, tm, tm,
                    *shared_refs
                )
            )

print(f"Total tasks: {len(tasks)}")
# Block until every task has finished, then tabulate one row per task.
results = ray.get(tasks)
df = pd.DataFrame(results)

print("Selecting best settings...")
# For each (mix_ratio, seed) keep the single run with the highest val_auc:
# sort so that run comes first within its group, then drop every later row of
# the group.  GroupBy.first() is avoided because it takes the first non-null
# value per column, which can combine values from different runs into one
# row when a metric is missing.
best_df = (
    df
    .sort_values(["mix_ratio", "seed", "val_auc"], ascending=[True, True, False])
    .drop_duplicates(subset=["mix_ratio", "seed"], keep="first")
    .reset_index(drop=True)
)
out_df = best_df[[
    "mix_ratio", "seed",
    "theta_max", "difficulty_base_max", "a_scale",
    "val_auc", "test_auc", "test_shuffle_auc", "test_normal_auc"
]]

# Write the per-(mix_ratio, seed) winners to result/<csv_name>_auc_best.csv,
# creating the output directory if needed.
os.makedirs("result", exist_ok=True)
csv_path = f"result/{csv_name}_auc_best.csv"
out_df.to_csv(csv_path, index=False)

# Record the run configuration and the result CSV in MLflow.  The experiment
# and the run share the same timestamped name.
mlflow.set_experiment(experiment_name)

# (parameter key, value) pairs, logged in this order.
logged_params = (
    ("dataset", dataset_name),
    ("irt", irt_name),
    ("grid_theta_max", theta_max_list),
    ("grid_diff_base_max", difficulty_base_max_list),
    ("grid_a_scale", a_scale_list),
)

with mlflow.start_run(run_name=experiment_name):
    for param_key, param_value in logged_params:
        mlflow.log_param(param_key, param_value)
    mlflow.log_artifact(csv_path)

print("Experiment finished.")
