import sys
import os
import random
import json
from itertools import product

import numpy as np
import pandas as pd
import torch
import ray

from mmirt.utils.masks import create_balanced_mask
from mmirt.irt_onelayer_asplit_val import Standard3PLIRT

# Input CSVs, one pair per benchmark. For every pair, the "file1" table is
# merged below as the `_normal` side and the "file2" table as the `_shuffle`
# side.
file1_default = "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv"
# NOTE(review): "shufffle" (three f's) is kept exactly as written — confirm it
# matches the file name on disk.
file2_default = "mmmu_normal_shuffle/shufffle_mmmu.csv"
file1_math    = "merged_mathvista/mathvista_combined.csv"
file2_math    = "merged_mathvista/merged_mathvista_shuffle.csv"
file1_vqa     = "vqa_shuffle_normal/vqa_merged.csv"
file2_vqa     = "vqa_shuffle_normal/vqa_shuffle_filtered.csv"
file1_seed = "seed_normal_shuffled/seed_normal_all_clean.csv"
file2_seed = "seed_normal_shuffled/seed_shuffled.csv"

# Parallel lists: position i in each list describes the same benchmark.
file1_ls = [file1_default, file1_math, file1_vqa, file1_seed]
file2_ls = [file2_default, file2_math, file2_vqa, file2_seed]
dataset_names = ["mmmu", "mathvista", "vqa", "seed"]

# The benchmark is selected by a zero-based index given as the first
# command-line argument. Validate it up front: a missing or non-numeric
# argument would otherwise end in a bare traceback, and a negative value would
# silently pick a benchmark counted from the end of the lists.
_usage = (
    f"usage: {sys.argv[0]} <dataset index: "
    + ", ".join(f"{i}={name}" for i, name in enumerate(dataset_names))
    + ">"
)
if len(sys.argv) < 2:
    sys.exit(_usage)
try:
    num = int(sys.argv[1])
except ValueError:
    sys.exit(f"invalid dataset index {sys.argv[1]!r}\n{_usage}")
if not 0 <= num < len(dataset_names):
    sys.exit(f"dataset index {num} is out of range\n{_usage}")

dataset_name = dataset_names[num]
file1 = file1_ls[num]
file2 = file2_ls[num]

# The first CSV column becomes the row index of each table.
df1 = pd.read_csv(file1, index_col=0)
df2 = pd.read_csv(file2, index_col=0)
# Join the two tables on their row index (pandas' default inner join, so only
# rows present in both survive). Column names that occur in both tables get
# the `_normal` / `_shuffle` suffix.
merged_df = pd.merge(df1, df2, left_index=True, right_index=True,
                     suffixes=("_normal","_shuffle"))

# Response matrix as float32: one row per entry of `models` (the merged row
# index) and one column per entry of `problems` (the merged column names).
full_resp = merged_df.values.astype(np.float32)
problems  = merged_df.columns.tolist()
models    = merged_df.index.tolist()


# Candidate hyperparameter values. Only `theta_max_list` is iterated in the
# sweep visible in this script; the other three lists are defined but not
# referenced there.
theta_max_list           = [2, 4, 8, 16]
difficulty_base_min_list = [-2, -8]
difficulty_base_max_list = [2, 4, 8, 16]
a_scale_list             = [2, 4, 8, 16]

@ray.remote(num_cpus=0.5)
def eval_auc(theta_max: float,
             difficulty_base_max: float,
             a_scale: float,
             difficulty_base_min: float = 0,
             seed: int = 0) -> dict:
    """Fit one IRT model for a hyperparameter setting and score it on validation.

    Declared as a Ray remote function that reserves half a CPU per task. It
    reads the module-level ``full_resp``, ``problems`` and ``models``.

    Args:
        theta_max: Forwarded to ``Standard3PLIRT`` as ``theta_max``.
        difficulty_base_max: Forwarded as both ``difficulty_base_max`` and
            ``difficulty_other_max``.
        a_scale: Forwarded as ``a_scale``.
        difficulty_base_min: Forwarded as both ``difficulty_base_min`` and
            ``difficulty_other_min``.
        seed: Seeds ``random``, NumPy and torch, and is also passed to
            ``create_balanced_mask``.

    Returns:
        A dict echoing the four hyperparameters, plus ``"roc_auc"`` (taken
        from ``evaluate_validation()``; NaN when that result has no
        ``"roc_auc"`` entry) and ``"estimates"`` (whatever ``fit()`` returns).
    """
    # Seed the three RNG sources in a fixed order before anything random runs.
    for reseed in (random.seed, np.random.seed, torch.manual_seed):
        reseed(seed)

    # No test split is requested; 10% is requested for validation. The second
    # return value of the helper is not used here.
    split_mask, _ = create_balanced_mask(
        response_data=full_resp,
        test_percentage=0,
        item_names=problems,
        seed=seed,
        val_percentage=0.1,
    )
    # Every -1 entry of the mask is rewritten to 1.
    # NOTE(review): the meaning of -1 is defined by create_balanced_mask —
    # confirm this folds the intended cells into the training set.
    train_mask = np.where(split_mask == -1, 1, split_mask)

    data_kwargs = dict(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
    )
    training_kwargs = dict(
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
    )
    # The "other" difficulty bounds reuse the "base" bounds.
    range_kwargs = dict(
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        difficulty_base_min=difficulty_base_min,
        difficulty_other_min=difficulty_base_min,
        a_scale=a_scale,
    )
    model = Standard3PLIRT(
        split_difficulty=True,
        split_ability=True,
        use_guessing=False,
        **data_kwargs,
        **training_kwargs,
        **range_kwargs,
    )
    # NOTE(review): fit() and evaluate_validation() are called on the object
    # returned by torch.compile — confirm Standard3PLIRT is something
    # torch.compile can wrap and that the wrapper exposes both methods.
    model = torch.compile(model, backend="inductor")

    estimates = model.fit()
    val_metrics = model.evaluate_validation()

    summary = {
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "difficulty_base_min": difficulty_base_min,
        "a_scale": a_scale,
    }
    summary["roc_auc"] = val_metrics.get("roc_auc", float('nan'))
    summary["estimates"] = estimates
    return summary

# Launch one remote fit per candidate value and wait for all of them.
ray.init()
try:
    # Each candidate `tm` is used for all four hyperparameters at once.
    # NOTE(review): this makes difficulty_base_min equal to
    # difficulty_base_max (both `tm`), and difficulty_base_min_list /
    # difficulty_base_max_list / a_scale_list are never consulted — confirm
    # this is intended rather than a leftover from a full grid search.
    tasks = [
        eval_auc.remote(
            theta_max=tm,
            difficulty_base_max=tm,
            a_scale=tm,
            difficulty_base_min=tm,
            seed=0,
        )
        for tm in theta_max_list
    ]
    results = ray.get(tasks)
finally:
    # Release the Ray runtime even when a task raised.
    ray.shutdown()


def _auc_or_neg_inf(result: dict) -> float:
    """Return the run's ROC-AUC as a sort key, ranking NaN below every score."""
    auc = float(result["roc_auc"])
    return float("-inf") if np.isnan(auc) else auc


# eval_auc reports NaN when no ROC-AUC is available. Every comparison against
# NaN is False, so a plain max() would keep a NaN run that happens to come
# first even when later runs have valid scores.
best = max(results, key=_auc_or_neg_inf)
best_estimates = best["estimates"]

def make_mask_matrix(names: list) -> np.ndarray:
    """Build a two-column 0/1 mask from tags embedded in item names.

    Each name is checked against the tags below in this order, and the first
    tag found decides the row:

        "<no_info>"     -> [0, 0]
        "<no_image>"    -> [1, 0]
        "<no_question>" -> [0, 1]
        (no tag)        -> [1, 1]

    In this script the first column scales the text terms and the second
    column the image terms of the item parameters.

    Args:
        names: Item names (strings) to classify.

    Returns:
        A float32 array of shape ``(len(names), 2)``. An empty ``names`` gives
        shape ``(0, 2)``, so the result is always two-dimensional.
    """
    # Ordered by precedence: an earlier tag wins when a name carries several.
    tag_rows = (
        ("<no_info>", (0.0, 0.0)),
        ("<no_image>", (1.0, 0.0)),
        ("<no_question>", (0.0, 1.0)),
    )
    untagged = (1.0, 1.0)
    rows = [
        next((row for tag, row in tag_rows if tag in name), untagged)
        for name in names
    ]
    # reshape keeps the column axis when `rows` is empty; np.array([]) alone
    # would collapse to shape (0,).
    return np.array(rows, dtype=np.float32).reshape(-1, 2)

# Row i holds the (text, image) mask for problems[i].
mask_matrix = make_mask_matrix(problems)
# Maps a shortened key of each "difficulty_full" entry to its position: the
# text after the last "<no_" (the whole name when that marker is absent), cut
# at the first "]". Not referenced again in the visible part of this script.
param_idx  = {name.split("<no_")[ -1 ].split("]")[0]:
              i for i, name in enumerate(best_estimates["difficulty_full"].keys())}

# Per-item parameter tables from the best run, each indexed by problem name.
a_base_dict    = best_estimates["discrimination_base"]
a_text_dict    = best_estimates["discrimination_text"]
a_image_dict   = best_estimates["discrimination_image"]
a_synergy_dict = best_estimates["discrimination_synergy"]
b_base_dict    = best_estimates["difficulty_base"]
b_text_dict    = best_estimates["difficulty_text"]
b_image_dict   = best_estimates["difficulty_image"]
b_synergy_dict = best_estimates["difficulty_synergy"]
# Guessing parameters are fetched but do not enter the information matrix
# below; the fit in eval_auc is run with use_guessing=False.
c_dict         = best_estimates.get("guessing", {})
theta_dict     = best_estimates["theta"]

# Everything that depends only on the item is computed once per problem here
# instead of once per (model, problem) pair. Each entry holds the masked
# slopes (base, text, image, synergy), the masked difficulty offset, and the
# outer product g g^T of the slope vector.
item_terms = []
for prob, (r_t, r_i) in zip(problems, mask_matrix):
    a_b = a_base_dict[prob]
    a_t = r_t*a_text_dict[prob]
    a_i = r_i*a_image_dict[prob]
    a_s = r_t*r_i*a_synergy_dict[prob]
    offset = (
        r_i*b_image_dict[prob]
      + r_t*b_text_dict[prob]
      + b_base_dict[prob]
      + r_t*r_i*b_synergy_dict[prob]
    )
    g = np.array([a_b, a_t, a_i, a_s], dtype=np.float32)
    item_terms.append((a_b, a_t, a_i, a_s, offset, np.outer(g, g)))

# fisher[model][problem] is the 4x4 matrix p*(1-p) * g g^T as nested lists,
# with ability components ordered (base, text, image, synergy).
fisher = {}
for m in models:
    theta_vec = np.array([
        theta_dict[m]["theta_base"],
        theta_dict[m]["theta_text"],
        theta_dict[m]["theta_image"],
        theta_dict[m]["theta_synergy"],
    ], dtype=np.float32)

    per_item = {}
    for prob, (a_b, a_t, a_i, a_s, offset, outer_gg) in zip(problems, item_terms):
        # Logit of a correct response; the terms are summed in a fixed order
        # (base, image, text, synergy) so results are reproducible.
        linear = (
            a_b*theta_vec[0]
          + a_i*theta_vec[2]
          + a_t*theta_vec[1]
          + a_s*theta_vec[3]
          - offset
        )
        p = 1 / (1 + np.exp(-linear))
        # Keep p away from 0 and 1 so p*(1-p) never collapses to exactly zero.
        p = np.clip(p, 1e-3, 1-1e-3)

        per_item[prob] = (p*(1-p) * outer_gg).tolist()
    fisher[m] = per_item

output_file = f"{dataset_name}_best_fisher_mmirt.json"
with open(output_file, "w", encoding="utf-8") as fp:
    json.dump({"fisher": fisher}, fp, ensure_ascii=False, indent=2)

print(f"result is saved to {output_file}")
