#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import json
import datetime
import random

import numpy as np
import pandas as pd
import torch
import ray

from itertools import product
from mmirt.utils.masks import create_balanced_mask
# from mogi.irt_onelayer_nostlict import Standard3PLIRT
from mmirt.mmirt_inner_product import Standard3PLIRT

# Datasets selectable from the command line; the single positional argument
# is a zero-based index into this list.
dataset_names = ["mmmu", "mathvista", "vqa", "seed"]

if len(sys.argv) < 2:
    print("Usage: python script.py <dataset_index>")
    sys.exit(1)

# Reject non-integer and out-of-range arguments with a readable message
# instead of an uncaught ValueError/IndexError traceback. Negative values are
# rejected as well, so a typo cannot silently pick a dataset from the end of
# the list.
try:
    num = int(sys.argv[1])
except ValueError:
    print(f"Invalid dataset index: {sys.argv[1]!r} (expected an integer)")
    print("Usage: python script.py <dataset_index>")
    sys.exit(1)

if not 0 <= num < len(dataset_names):
    valid_choices = ", ".join(
        f"{i}={name}" for i, name in enumerate(dataset_names)
    )
    print(f"Dataset index {num} is out of range; valid choices: {valid_choices}")
    sys.exit(1)

dataset_name = dataset_names[num]

# (normal, shuffle) response CSVs for each dataset, listed in the same order
# as `dataset_names` so that `num` selects the matching pair.
_dataset_files = [
    # mmmu
    ("mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv",
     "mmmu_normal_shuffle/shufffle_mmmu.csv"),
    # mathvista
    ("merged_mathvista/mathvista_combined.csv",
     "merged_mathvista/merged_mathvista_shuffle.csv"),
    # vqa
    ("vqa_shuffle_normal/vqa_merged.csv",
     "vqa_shuffle_normal/vqa_shuffle_filtered.csv"),
    # seed
    ("seed_normal_shuffled/seed_normal_all_clean.csv",
     "seed_normal_shuffled/seed_shuffled.csv"),
]

# Keep the individual per-dataset names available at module level.
(
    (file1_default, file2_default),
    (file1_math, file2_math),
    (file1_vqa, file2_vqa),
    (file1_seed, file2_seed),
) = _dataset_files

file1_ls = [normal_path for normal_path, _ in _dataset_files]
file2_ls = [shuffle_path for _, shuffle_path in _dataset_files]

file1, file2 = file1_ls[num], file2_ls[num]

print(
    f"Dataset: {dataset_name}",
    f"  normal file : {file1}",
    f"  shuffle file: {file2}",
    sep="\n",
)

# Load both response tables, using the first CSV column as the row index.
df1, df2 = (pd.read_csv(csv_path, index_col=0) for csv_path in (file1, file2))

# Join the two tables on their row index (join type left at the pandas
# default); columns that share a name get the suffixes below.
merged_df = df1.merge(
    df2,
    left_index=True,
    right_index=True,
    suffixes=("_normal", "_shuffle"),
)

# Response matrix plus its column labels (problems) and row labels (models).
full_resp = merged_df.to_numpy().astype(np.float32)
problems = merged_df.columns.to_list()
models = merged_df.index.to_list()

# Only theta_max is swept; the sweep further down reuses each value for
# difficulty_base_max and a_scale as well.
theta_max_list = [2, 4, 8, 16]

@ray.remote(num_cpus=0.5)
def eval_one_config(
    theta_max: float,
    difficulty_base_max: float,
    a_scale: float,
    difficulty_base_min: float = 0,
    seed: int = 0
) -> dict:
    """Fit one IRT model for a single hyperparameter setting (Ray task).

    Reads the module-level ``full_resp``, ``problems``, ``models`` and
    ``dataset_name``.

    Args:
        theta_max: Forwarded to ``Standard3PLIRT`` as ``theta_max``.
        difficulty_base_max: Forwarded as both ``difficulty_base_max`` and
            ``difficulty_other_max``.
        a_scale: Forwarded as ``a_scale``.
        difficulty_base_min: Accepted but not used anywhere in the body.
        seed: Seeds ``random``, NumPy and PyTorch, and is also handed to
            ``create_balanced_mask``.

    Returns:
        A dict holding the three hyperparameters, ``roc_auc`` (NaN when the
        metrics have no such key), the training mask as nested lists, the
        value returned by ``fit()`` under ``estimates`` and the result of
        ``compute_item_fisher_information`` under ``fisher_info``.
    """
    # Seed every RNG source used here with the same value.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    mask, _ = create_balanced_mask(
        response_data=full_resp,
        item_names=problems,
        test_percentage=0,
        val_percentage=0.1,
        seed=seed,
    )
    # Cells marked -1 are switched to 1.
    # NOTE(review): presumably -1 marks validation cells, which would fold
    # them back into training -- confirm against create_balanced_mask.
    train_mask = np.where(mask == -1, 1, mask)

    model_kwargs = dict(
        response_data=full_resp,
        train_mask=train_mask,
        student_names=models,
        test_names=problems,
        split_difficulty=False,
        split_ability=False,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        # The "other" difficulty bound reuses the base bound.
        difficulty_other_max=difficulty_base_max,
        a_scale=a_scale,
        use_guessing=False,
        enable_abs_clamp=False,
    )
    irt = Standard3PLIRT(**model_kwargs)
    # NOTE(review): whether fit() actually runs through the compiled wrapper
    # depends on how Standard3PLIRT is implemented -- confirm.
    irt = torch.compile(irt, backend="inductor")

    estimates = irt.fit()
    metrics = irt.evaluate_predictions()
    fisher_info = irt.compute_item_fisher_information(dataset_name)

    result = {
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "a_scale": a_scale,
        "roc_auc": metrics.get("roc_auc", float("nan")),
        "train_mask": train_mask.tolist(),
        "estimates": estimates,
        "fisher_info": fisher_info,
    }
    return result

ray.init(ignore_reinit_error=True)

# Launch one remote fit per theta_max value. The same value is reused for
# difficulty_base_max and a_scale, and every run uses seed 0.
tasks = [
    eval_one_config.remote(
        theta_max=tm,
        difficulty_base_max=tm,
        a_scale=tm,
        seed=0,
    )
    for tm in theta_max_list
]

# Block until every remote fit has finished.
results = ray.get(tasks)


def _roc_auc_rank(result):
    """Return a sort key for a run: its ROC AUC, with NaN ranked lowest.

    Every comparison against NaN is False, so a plain ``max`` keyed on the raw
    score would keep a NaN run whenever it happens to come first.
    """
    score = result["roc_auc"]
    return float("-inf") if np.isnan(score) else score


best = max(results, key=_roc_auc_rank)
print(f"Best ROC AUC: {best['roc_auc']:.4f}")
print(f"Params: theta_max={best['theta_max']}, difficulty_base_max={best['difficulty_base_max']}, a_scale={best['a_scale']}")

def convert_numpy(obj):
    """Recursively replace NumPy values with plain Python equivalents.

    Walks dicts, lists and tuples. NumPy scalars (including those used as
    dict keys) become Python scalars via ``.item()`` and NumPy arrays become
    (nested) lists via ``.tolist()``, so the result can be passed to
    ``json.dump``. Any other object is returned unchanged.

    Args:
        obj: Arbitrary, possibly nested, Python/NumPy value.

    Returns:
        A structure of the same shape with NumPy scalars and arrays
        converted; dicts, lists and tuples are rebuilt, not mutated.
    """
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_numpy(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy(value) for value in obj)
    if isinstance(obj, np.ndarray):
        as_list = obj.tolist()
        # tolist() already yields Python scalars for numeric dtypes; only
        # object arrays can still contain NumPy values that need a second pass.
        return convert_numpy(as_list) if obj.dtype == object else as_list
    if isinstance(obj, np.generic):
        return obj.item()
    return obj

# Strip NumPy scalar types from the winning run before serialising it.
# NOTE: best_estimates is computed here but not written anywhere in this script.
best_estimates = convert_numpy(best["estimates"])
fisher_converted = convert_numpy(best["fisher_info"])

# Top-level arrays are turned into lists so the dict can be JSON-encoded.
best_fisher = {}
for fisher_key, fisher_value in fisher_converted.items():
    if isinstance(fisher_value, np.ndarray):
        best_fisher[fisher_key] = fisher_value.tolist()
    else:
        best_fisher[fisher_key] = fisher_value

os.makedirs("results", exist_ok=True)

# One output file per dataset, written under results/.
output_path = f"results/{dataset_name}_best_fisher_info_mmirt_onedim.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(best_fisher, f, ensure_ascii=False, indent=4)

print("result is saved")

ray.shutdown()
