import argparse
import os
import datetime
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import ray
import random
import sys

from mmirt.utils.contami_judge import is_contaminated
# from mogi.irt_onelayer_asplit_val import Standard3PLIRT
from mmirt.mmirt_inner_product import Standard3PLIRT
import os
import sys
import datetime
import pytz
import random
import argparse

import numpy as np
import pandas as pd
import torch

from typing import List, Dict, Tuple, Any
from scipy.stats import spearmanr
import ray
import mlflow
from itertools import product

from mmirt.utils.contami_judge import is_contaminated
from mmirt.utils.masks import create_validation_mask_for_items, create_balanced_mask
import sys
# Input CSV paths per benchmark. The file1_* / file2_* pairs are merged below
# with the suffixes "_normal" and "_shuffle" respectively.
file1_default = "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv"
file1_math = "merged_mathvista/mathvista_combined.csv"
file1_vqa = "vqa_shuffle_normal/vqa_merged.csv"
file1_seed = "seed_normal_shuffled/seed_normal_all_clean.csv"

file2_default = "mmmu_normal_shuffle/shufffle_mmmu.csv"
file2_math = "merged_mathvista/merged_mathvista_shuffle.csv"
file2_vqa = "vqa_shuffle_normal/vqa_shuffle_filtered.csv"
file2_seed = "seed_normal_shuffled/seed_shuffled.csv"

# The three lists below are index-aligned: position i describes one benchmark.
dataset_names = ["mmmu", "mathvista", "vqa", "seed"]
file1_ls = [file1_default, file1_math, file1_vqa, file1_seed]
file2_ls = [file2_default, file2_math, file2_vqa, file2_seed]

# The first command-line argument is the integer index of the benchmark to run.
num = int(sys.argv[1])
dataset_name = dataset_names[num]
print(dataset_name)
file1, file2 = file1_ls[num], file2_ls[num]

# Accumulates one summary row per evaluated configuration.
results_list = []

# Both CSVs use their first column as the row index; rows are aligned on that
# index and overlapping column names are disambiguated by suffix.
df1 = pd.read_csv(file1, index_col=0)
df2 = pd.read_csv(file2, index_col=0)
merged_df = df1.merge(
    df2,
    left_index=True,
    right_index=True,
    suffixes=('_normal', '_shuffle'),
)
# merged_df = df1

# Row labels are passed to the IRT model as student names, column labels as
# test names, and the cell values as the response matrix.
models = merged_df.index.tolist()
problems = merged_df.columns.tolist()
full_resp = merged_df.values.astype(np.float32)

# Candidate hyperparameter values for the IRT model.
theta_max_list = [2, 4, 8, 16]
difficulty_base_min_list = [0, -2, -8]
difficulty_base_max_list = [2, 4, 8, 16]
a_scale_list = [2, 4, 8, 16]

@ray.remote(num_cpus=0.5)
def eval_one_config(deletion_model: str,
                    theta_max: float,
                    difficulty_base_max: float,
                    a_scale: float,
                    difficulty_base_min: float,
                    seed: int = 0) -> Dict[str, Any]:
    """Fit one IRT configuration and score it on a validation split.

    Runs as a Ray remote task. The response matrix (``full_resp``) and its
    row/column labels (``models`` / ``problems``) are read from module scope.

    Args:
        deletion_model: Label that is echoed back unchanged in the result so
            callers can group results; it does not influence the fit.
        theta_max: Forwarded to ``Standard3PLIRT`` as ``theta_max``.
        difficulty_base_max: Forwarded as both ``difficulty_base_max`` and
            ``difficulty_other_max``.
        a_scale: Forwarded to ``Standard3PLIRT`` as ``a_scale``.
        difficulty_base_min: Forwarded as both ``difficulty_base_min`` and
            ``difficulty_other_min``.
        seed: Seed applied to ``random``, NumPy, torch and the mask split.

    Returns:
        A dict holding the input hyperparameters, the validation ``roc_auc``,
        the ``train_mask`` that was used, and the ``estimates`` returned by
        ``fit()``.
    """
    import torch

    # Seed every RNG the fit may draw from. Previously only `random` and
    # NumPy were seeded, so torch-side randomness was not tied to `seed`.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # No test split is requested (test_percentage=0); 10% is held out for
    # validation.
    train_mask, _ = create_balanced_mask(
        response_data=full_resp,
        test_percentage=0,
        item_names=problems,
        # student_percentage=0.1,
        # test_names=problems,
        # student_names=models,
        seed=seed,
        val_percentage=0.1,
    )
    # Entries marked -1 are rewritten to 1; all other mask values are kept.
    # NOTE(review): presumably -1 marks the (unused) test split -- confirm
    # against create_balanced_mask.
    train_mask = np.where(train_mask == -1, 1, train_mask)

    irt = Standard3PLIRT(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=True,
        split_ability=True,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        difficulty_base_min=difficulty_base_min,
        difficulty_other_min=difficulty_base_min,
        a_scale=a_scale,
        use_guessing=False,
    )
    # NOTE(review): torch.compile returns a wrapper around the model; whether
    # fit()/evaluate_validation() actually run compiled code depends on how
    # Standard3PLIRT is implemented -- confirm.
    irt = torch.compile(irt, backend="inductor")
    estimates = irt.fit()
    metrics = irt.evaluate_validation()

    return {
        "deletion_model": deletion_model,
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "a_scale": a_scale,
        "difficulty_base_min": difficulty_base_min,
        "roc_auc": metrics["roc_auc"],
        "train_mask": train_mask,
        "estimates": estimates
    }
# Submit one remote fit per value in theta_max_list. The same value is reused
# for theta_max, difficulty_base_max and a_scale; difficulty_base_min is
# fixed at 0 and the seed at 0.
gs_tasks = []
for tm in theta_max_list:
    print(f"Running config: tm={tm}")
    pending = eval_one_config.remote(
        deletion_model="gemini",
        theta_max=tm,
        difficulty_base_max=tm,
        a_scale=tm,
        difficulty_base_min=0,
        seed=0,
    )
    gs_tasks.append(pending)

# Block until every submitted task has finished.
gs_results = ray.get(gs_tasks)

# Keep only the scalar fields for the CSV summary; the mask and the estimates
# stay in gs_results.
summary_fields = (
    "deletion_model",
    "theta_max",
    "difficulty_base_max",
    "difficulty_base_min",
    "a_scale",
    "roc_auc",
)
for r in gs_results:
    results_list.append({field: r[field] for field in summary_fields})

csv_path = f"json_{dataset_name}_grid_search_results.csv"
df_gs = pd.DataFrame(results_list)
df_gs.to_csv(csv_path, index=False, encoding="utf-8-sig")
print(f"Grid search results saved to {csv_path}")

# Select the configuration with the highest validation ROC-AUC for the
# "gemini" label and keep its fitted estimates for export.
m = "gemini"
best_params: Dict[str, Dict] = {}
subset = [res for res in gs_results if res["deletion_model"] == m]
best = max(subset, key=lambda res: res["roc_auc"])
best_params[m] = best
estimates = best["estimates"]

import json

def _to_builtin(value):
    """Return ``value`` with NumPy scalars/arrays replaced by Python builtins."""
    if isinstance(value, dict):
        convert_dict(value)
        return value
    if isinstance(value, np.generic):
        # Covers every NumPy scalar type (float16/32/64, all int widths,
        # bool_, ...), not just a hard-coded handful.
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, list):
        # Lists are updated in place, mirroring how nested dicts are handled.
        for idx, element in enumerate(value):
            value[idx] = _to_builtin(element)
        return value
    if isinstance(value, tuple):
        return tuple(_to_builtin(element) for element in value)
    return value


def convert_dict(d):
    """Recursively convert NumPy values in ``d`` to JSON-serializable builtins.

    The dictionary is modified in place: NumPy scalars become Python scalars,
    NumPy arrays become (nested) lists, and nested dicts, lists and tuples are
    traversed. Values of any other type are left untouched.

    Args:
        d: Dictionary to sanitize; mutated in place.

    Returns:
        None.
    """
    # Only existing keys are reassigned, so the dict's size never changes
    # while it is being iterated.
    for key, value in d.items():
        d[key] = _to_builtin(value)
# Convert NumPy scalars in the selected estimates to plain Python values so
# the json module can serialize them, then write them to disk.
convert_dict(estimates)
json_path = f"{dataset_name}_val_estimats_single_dim_normal.json"
with open(json_path, "w", encoding="utf-8") as file:
    json.dump(estimates, file, ensure_ascii=False, indent=4)

print("Json file saved.")
