import argparse
import os
import datetime
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import ray
import random
import sys

from mmirt.utils.contami_judge import is_contaminated
from mmirt.irt_onelayer_nostlict import Standard3PLIRT
import os
import sys
import datetime
import pytz
import random
import argparse

import numpy as np
import pandas as pd
import torch

from typing import List, Dict, Tuple, Any
from scipy.stats import spearmanr
import ray
import mlflow
from itertools import product

from mmirt.utils.contami_judge import is_contaminated
from mmirt.utils.masks import create_validation_mask_for_items, create_balanced_mask
import sys
# Input CSVs per dataset: file1_ls[i] and file2_ls[i] belong to dataset_names[i].
file1_default = "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv"
file2_default = "mmmu_normal_shuffle/shufffle_mmmu.csv"
file1_math = "merged_mathvista/mathvista_combined.csv"
file2_math = "merged_mathvista/merged_mathvista_shuffle.csv"
file1_vqa = "vqa_shuffle_normal/vqa_merged.csv"
file2_vqa = "vqa_shuffle_normal/vqa_shuffle_filtered.csv"
file1_ls = [file1_default, file1_math, file1_vqa]
file2_ls = [file2_default, file2_math, file2_vqa]
dataset_names = ["mmmu", "mathvista", "vqa"]

# The dataset is chosen by a positional index on the command line.
# argparse turns a missing or out-of-range index into a usage message
# instead of an IndexError / ValueError traceback. Unknown extra arguments
# are ignored, matching the old direct read of sys.argv[1].
_arg_parser = argparse.ArgumentParser(
    description="Fit a 3PL IRT model on one benchmark and save its estimates."
)
_arg_parser.add_argument(
    "dataset_index",
    type=int,
    choices=range(len(dataset_names)),
    help=", ".join(f"{i}={name}" for i, name in enumerate(dataset_names)),
)
num = _arg_parser.parse_known_args()[0].dataset_index
print(dataset_names[num])
dataset_name = dataset_names[num]
file1 = file1_ls[num]
file2 = file2_ls[num]


# df1's index becomes the list of models and its columns the list of
# problems; its values form the response matrix used by eval_one_config.
df1 = pd.read_csv(file1, index_col=0)
# NOTE(review): df2 (the shuffled variant) is loaded but unused while the
# merge below stays disabled.
df2 = pd.read_csv(file2, index_col=0)
# merged_df = pd.merge(
#     df1,
#     df2,
#     left_index=True,
#     right_index=True,
#     suffixes=('_normal', '_shuffle')
# )
merged_df = df1
full_resp = merged_df.values.astype(np.float32)
problems = merged_df.columns.tolist()
models = merged_df.index.tolist()


# Hyper-parameter grid. Each list currently holds a single value, so the
# grid search evaluates exactly one configuration.
theta_max_list = [2]
difficulty_base_min_list = [-2]
difficulty_base_max_list = [2]
a_scale_list = [2]
# ----------------------------

@ray.remote(num_cpus=0.5)
def eval_one_config(
    theta_max: float,
    difficulty_base_max: float,
    a_scale: float,
    difficulty_base_min: float = 0,
    seed: int = 0,
) -> Dict[str, Any]:
    """Fit one Standard3PLIRT configuration and report its ROC-AUC.

    Runs as a Ray task and reads the module-level ``full_resp``,
    ``problems`` and ``models`` built from the selected dataset.

    Args:
        theta_max: forwarded to Standard3PLIRT as ``theta_max``.
        difficulty_base_max: forwarded as both ``difficulty_base_max`` and
            ``difficulty_other_max``.
        a_scale: only recorded in the result; the ``a_scale`` keyword is
            currently disabled in the model call.
        difficulty_base_min: only recorded in the result; the corresponding
            keywords are currently disabled in the model call.
        seed: seeds Python ``random``, NumPy, PyTorch and the mask split.

    Returns:
        Dict containing:
        - the hyper-parameters and ``seed``;
        - ``roc_auc`` taken from ``evaluate_predictions()``;
        - the ``train_mask`` that was used;
        - the fitted estimates, stored under the (historically misspelled)
          key ``"estimets"``.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Standard3PLIRT is built with a torch ``device``, so its parameter
    # initialisation presumably draws from torch's RNG. Seed it too, so one
    # ``seed`` gives reproducible fits.
    torch.manual_seed(seed)

    # No held-out test cells; requests val_percentage=0.1.
    train_mask, _ = create_balanced_mask(
        response_data=full_resp,
        test_percentage=0,
        item_names=problems,
        # student_percentage=0.1,
        # test_names=problems,
        # student_names=models,
        seed=seed,
        val_percentage=0.1,
    )
    # Every cell marked -1 is turned into 1, i.e. it is used for training.
    # NOTE(review): if -1 marks the validation split, the model also trains
    # on it and the reported ROC-AUC is not held-out. Confirm this is intended.
    train_mask = np.where(train_mask == -1, 1, train_mask)

    irt = Standard3PLIRT(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=False,
        split_ability=False,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        # Not forwarded at the moment; the values are only echoed in the result.
        # difficulty_base_min=difficulty_base_min,
        # difficulty_other_min=difficulty_base_min,
        # a_scale=a_scale,
        use_guessing=False,
        enable_abs_clamp=False,
    )
    # No torch.compile here. Only fit() and evaluate_predictions() are
    # called, and a compiled module forwards such attribute lookups to the
    # uncompiled original, so compiling only added overhead.
    estimates = irt.fit()
    metrics = irt.evaluate_predictions()

    return {
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "difficulty_base_min": difficulty_base_min,
        "a_scale": a_scale,
        "seed": seed,
        "roc_auc": metrics["roc_auc"],
        "train_mask": train_mask,
        # Misspelled key kept: downstream code reads best["estimets"].
        "estimets": estimates,
    }
# Launch one Ray task per combination of the hyper-parameter grids.
gs_tasks = [
    eval_one_config.remote(tm, dbm, a, difficulty_base_min=dbmin, seed=0)
    for tm, dbm, dbmin, a in product(
        theta_max_list,
        difficulty_base_max_list,
        difficulty_base_min_list,
        a_scale_list,
    )
]
gs_results = ray.get(gs_tasks)
if not gs_results:
    raise RuntimeError(
        "Grid search produced no results: a hyper-parameter grid list is empty."
    )

# Select the configuration with the highest ROC-AUC. NaN/None scores are
# ranked lowest: NaN compares False against everything, so a NaN in the
# first result would otherwise remain max()'s pick.
best = max(
    gs_results,
    key=lambda res: float("-inf") if pd.isna(res["roc_auc"]) else res["roc_auc"],
)

estimates = best["estimets"]

import json

def convert_dict(d):
    """Make ``d`` JSON-serialisable in place by converting numpy-like values.

    Conversion rules:
    - Nested dicts are processed recursively.
    - Lists and tuples are replaced by lists whose items are converted the
      same way.
    - NumPy scalars (any ``np.generic``) become Python scalars.
    - Array-like values exposing ``tolist()`` (NumPy arrays, torch tensors)
      become nested Python lists.
    - Everything else is left untouched.

    Args:
        d: dictionary to convert; it is modified in place.

    Returns:
        None; ``d`` is updated in place.
    """
    def _convert(value):
        if isinstance(value, dict):
            convert_dict(value)
            return value
        if isinstance(value, (list, tuple)):
            # json.dump also fails on numpy values nested inside sequences.
            return [_convert(item) for item in value]
        if isinstance(value, np.generic):
            return value.item()
        if hasattr(value, "tolist"):
            # np.ndarray, torch.Tensor, ...: json cannot serialise these directly.
            return value.tolist()
        return value

    # Reassigning existing keys does not change the dict's size, so
    # mutating during items() iteration is safe.
    for key, value in d.items():
        d[key] = _convert(value)
# Turn numpy values inside the estimates into plain Python types so that
# json can serialise them, then write one file per dataset into the
# current working directory.
convert_dict(estimates)
output_path = f"{dataset_name}_val_estimats_irt_clean.json"
with open(output_path, mode="w", encoding="utf-8") as json_out:
    json.dump(estimates, json_out, indent=4, ensure_ascii=False)

print("Json file saved successfully.")