import os
import sys
import datetime
import pytz
import random
import argparse

import numpy as np
import pandas as pd
import torch

from typing import List, Dict, Any
from scipy.stats import spearmanr
import ray
import mlflow
from mmirt.utils.contami_judge import is_contaminated
from mmirt.utils.masks import create_validation_mask_for_items, create_balanced_mask_with_fixed_test
from itertools import product

def parse_args():
    """Define the command-line interface of the CAT experiment and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options. ``--2pl`` / ``--3pl`` are
        stored under the digit-leading keys ``"2pl"`` / ``"3pl"``, so they can
        only be read with ``getattr(ns, "3pl")`` (``ns.3pl`` is a syntax error).
    """
    cli = argparse.ArgumentParser(description="CAT kli")
    cli.add_argument("--N", type=int, default=900, help="groups")
    cli.add_argument("--seed", type=int, default=None, help="seed")

    # Every boolean switch, in the order it should appear in --help.
    switches = [
        # dataset selection
        ("--mathvista", "use mathvista"),
        ("--vqa", "use vqaat"),
        ("--seedbench", "use seedbench"),
        # GPU pinning: --cuda0 .. --cuda3
        *[(f"--cuda{gpu}", f"CUDA_VISIBLE_DEVICES={gpu}") for gpu in range(4)],
        # IRT parameterisation
        ("--2pl", "2PL"),
        ("--3pl", "3PL"),
        ("--filter_no_prefix", "'<no_…>'"),
        ("--fisher", "Switch to Fisher(only single dimensional mmirt)"),
        ("--asplit", "MMIRT (split_difficulty=True, split_ability=True)"),
        ("--nosplit", "IRT (split_difficulty=False, split_ability=False)"),
    ]
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    cli.add_argument("--train_percentage", type=float, default=1.0,
                     help="Train pairs (0.0–1.0)")
    return cli.parse_args()

args = parse_args()

# When --seed is given, seed every RNG the script relies on (torch, NumPy, stdlib).
if args.seed is not None:
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

# Choose the benchmark. MMMU is the default; if several dataset flags are given,
# the first one in the order mathvista > vqa > seedbench wins.
# file1_path / file2_path are loaded into df1 / df2 below.
if args.mathvista:
    print("MathVista dataset is used")
    dataset_name = "mathvista"
    file1_path, file2_path = (
        "merged_mathvista/mathvista_combined.csv",
        "merged_mathvista/merged_mathvista_shuffle.csv",
    )
elif args.vqa:
    print("VQAAT dataset is used")
    dataset_name = "vqa"
    file1_path, file2_path = (
        "vqa_shuffle_normal/vqa_merged.csv",
        "vqa_shuffle_normal/vqa_shuffle_filtered.csv",
    )
elif args.seedbench:
    print("Seed bench is used")
    dataset_name = "seed"
    file1_path, file2_path = (
        "seed_normal_shuffled/seed_normal_all_clean.csv",
        "seed_normal_shuffled/seed_shuffled.csv",
    )
else:
    dataset_name = "mmmu"
    file1_path, file2_path = (
        "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv",
        "mmmu_normal_shuffle/shufffle_mmmu.csv",
    )

# Both CSVs use their first column as the row index; rows are later joined on it.
df1, df2 = (pd.read_csv(path, index_col=0) for path in (file1_path, file2_path))

prefixes = ["<no_image>", "<no_question>", "<no_info>"]


def extract_group(col_name: str) -> str:
    """Return the base question name of a column by dropping a known ``<no_…>`` prefix.

    Only the first matching prefix from ``prefixes`` is removed. Names without
    any such prefix are returned unchanged.
    """
    hit = next((pfx for pfx in prefixes if col_name.startswith(pfx)), None)
    if hit is None:
        return col_name
    return col_name[len(hit):]

# Bucket df2's columns by base question: "<no_image>Q", "<no_question>Q",
# "<no_info>Q" and "Q" all land under the key "Q".
groups: Dict[str, List[str]] = {}
for col in df2.columns:
    groups.setdefault(extract_group(col), []).append(col)

all_groups = list(groups.keys())

# Draw args.N groups with a private RNG fixed at 42, so the subset is identical
# across runs. Unlike `random.seed(42); random.sample(...)`, this does not clobber
# the module-level `random` state that --seed configured. Random(42).sample
# yields exactly the same selection as the reseed-then-sample form.
if not 0 <= args.N <= len(all_groups):
    raise ValueError(
        f"--N must be between 0 and {len(all_groups)} "
        f"(number of question groups), got {args.N}"
    )
selected_groups = random.Random(42).sample(all_groups, args.N)

selected_columns: List[str] = [col for grp in selected_groups for col in groups[grp]]

df2_subset = df2[selected_columns]
# Default inner join on the row index: only rows present in both df1 and df2 survive.
merged_df = pd.merge(df1, df2_subset, left_index=True, right_index=True)
gt_df = df1.copy()
print(f"merged_df shape: {merged_df.shape}")
full_resp = merged_df.values.astype(np.float32)
problems = merged_df.columns.tolist()
models = merged_df.index.tolist()

# Pin the process to a single GPU: the first of --cuda0..--cuda3 that is set wins;
# with none set, GPU "0" is used.
os.environ["CUDA_VISIBLE_DEVICES"] = next(
    (
        str(gpu)
        for gpu, requested in enumerate((args.cuda0, args.cuda1, args.cuda2, args.cuda3))
        if requested
    ),
    "0",
)

# --3pl enables the guessing parameter (3PL); without it a 2PL model is fit.
# argparse stores "--3pl" under the dest "3pl", which is not a valid attribute
# name, hence getattr. The --2pl flag is accepted but never read.
use_guessing = bool(getattr(args, "3pl", False))
pl = 3 if use_guessing else 2

filter_no_prefix = args.filter_no_prefix
# Appended to output/experiment names when "<no_…>" items are filtered out.
filter_suffix = "_filterNoPrefix" if filter_no_prefix else ""

# Select the IRT implementation; the first matching flag wins.
# NOTE(review): the CAT loop switches to Fisher information whenever --fisher is
# set, even if --asplit/--nosplit picked the model class here. Confirm those
# classes provide compute_item_fisher_information.
if args.asplit:
    split_difficulty = True
    split_ability = True
    irt_name = "asplit"
    from mmirt.irt_onelayer_asplit_val import Standard3PLIRT
elif args.nosplit:
    split_difficulty = False
    split_ability = False
    irt_name = "nosplit"
    from mmirt.irt_onelayer_nostlict import Standard3PLIRT
elif args.fisher:
    from mmirt.mmirt_inner_product import Standard3PLIRT
    split_difficulty = True
    split_ability = True
    irt_name = "inner_mmirt"
else:
    # --fisher is also a valid selector, so list it in the message.
    raise ValueError("select IRT model: --asplit, --nosplit or --fisher")

# Local wall-clock time, used to stamp output file and experiment names.
now = datetime.datetime.now()
mmdd, hhmm = now.strftime("%m%d"), now.strftime("%H%M")
print(f"Date: {mmdd}, Time: {hhmm}")

# Re-derived from merged_df; same contents as full_resp / models / problems above.
response_data = merged_df.values.astype(np.float32)
models = merged_df.index.tolist()
problems = merged_df.columns.tolist()

# Candidate values for the IRT bounds/scales considered by the grid search.
theta_max_list = [2, 4, 8, 16]
difficulty_base_max_list = [2, 4, 8, 16]
a_scale_list = [2, 4, 8, 16]

@ray.remote(num_cpus=0.5)
def eval_one_config(deletion_model: str,
                    theta_max: float,
                    difficulty_base_max: float,
                    a_scale: float,
                    seed: int = 0) -> Dict[str, Any]:
    """Fit one IRT configuration with ``deletion_model``'s row hidden and score it.

    The steps are:
    1. Build a validation split over 100 randomly chosen question groups.
    2. Rebalance the remaining training entries per ``--train_percentage``.
    3. Exclude the whole row of ``deletion_model`` from training.

    Args:
        deletion_model: Row label in ``models`` whose responses are withheld.
        theta_max: Passed to ``Standard3PLIRT``.
        difficulty_base_max: Passed to ``Standard3PLIRT`` as both
            ``difficulty_base_max`` and ``difficulty_other_max``.
        a_scale: Passed to ``Standard3PLIRT``.
        seed: Seeds Python, NumPy and torch RNGs and both mask helpers.

    Returns:
        Dict with the input hyper-parameters, ``roc_auc`` taken from
        ``irt.evaluate_validation()``, and the ``train_mask`` that was used.

    Raises:
        ValueError: if ``deletion_model`` is not in ``models``, or fewer than
            100 question groups exist (raised by ``random.sample``).
    """
    random.seed(seed)
    np.random.seed(seed)
    # The IRT model is torch-based, so seed torch too for reproducible fits
    # (mirrors the module-level --seed handling).
    torch.manual_seed(seed)

    # Sort before sampling. Iteration order of a set of str follows str hashes,
    # which are randomized per process (PYTHONHASHSEED). Without sorting, each
    # Ray worker could pick a different validation set for the same seed, and
    # grid configs would be compared on different data.
    unique_groups = sorted({extract_group(col) for col in df1.columns})
    val_groups = set(random.sample(unique_groups, 100))
    val_items = [col for col in merged_df.columns if extract_group(col) in val_groups]

    train_mask, _ = create_validation_mask_for_items(
        response_data=full_resp,
        item_names=problems,
        val_item_names=val_items,
        val_percentage=0.2,
        seed_init=seed
    )

    # Turn current training entries (1) into -1 so the balancing helper can
    # re-select them according to --train_percentage.
    train_mask[train_mask == 1] = -1
    train_mask, _ = create_balanced_mask_with_fixed_test(
        mask=train_mask,
        response_data=full_resp,
        train_percentage=args.train_percentage,
        item_names=problems,
        seed=seed
    )
    # Hide the held-out model's entire row from training.
    idx = models.index(deletion_model)
    train_mask[idx, :] = -1

    irt = Standard3PLIRT(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        a_scale=a_scale,
        use_guessing=use_guessing,
    )
    # NOTE(review): torch.compile returns a wrapper around the module; attribute
    # calls like .fit() may be forwarded to the uncompiled original. Verify that
    # this speeds anything up.
    irt = torch.compile(irt, backend="inductor")  # PyTorch 2.0+
    irt.fit()
    metrics = irt.evaluate_validation()  # {'roc_auc': …, …}

    return {
        "deletion_model": deletion_model,
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "a_scale": a_scale,
        "roc_auc": metrics["roc_auc"],
        "train_mask": train_mask
    }

@ray.remote(num_cpus=1)
def run_deletion_cat_experiment_milestones(deletion_model: str,
                                           milestone_percents: List[float],
                                           theta_max: float,
                                           difficulty_base_max: float,
                                           a_scale: float,
                                           initial_train_mask: np.ndarray) -> List[Dict]:
    """Simulate adaptive testing for one held-out model and record milestone metrics.

    The IRT model is first fit with every response of ``deletion_model`` masked
    out. Question groups (a base question plus its ``<no_…>`` variants) are then
    administered one per step, greedily picking the group with the largest
    summed item KLI (Fisher information with ``--fisher``). After each pick,
    ``update_single_theta`` re-estimates the model's ability from the items
    administered so far.

    Args:
        deletion_model: Row label in ``models`` to run the simulation for.
        milestone_percents: Fractions of the total number of groups at which
            metrics are recorded.
        theta_max: Passed to ``Standard3PLIRT``.
        difficulty_base_max: Passed to ``Standard3PLIRT`` as both
            ``difficulty_base_max`` and ``difficulty_other_max``.
        a_scale: Passed to ``Standard3PLIRT``.
        initial_train_mask: Currently unused. The training mask is rebuilt here
            from ``--train_percentage`` with seed 0. Kept for call compatibility.

    Returns:
        One record dict per milestone step reached. The list is empty if
        ``deletion_model`` is unknown or ``milestone_percents`` is empty.
    """
    try:
        del_idx = models.index(deletion_model)
    except ValueError:
        print(f"Error: {deletion_model} not found.")
        return []
    if not milestone_percents:
        return []

    full_resp = merged_df.values.astype(np.float32)

    # Start with every entry at -1 and let the helper choose training entries.
    # The fixed seed 0 gives every deletion model the same base split.
    train_mask = np.full_like(full_resp, -1, dtype=np.float32)
    train_mask, _ = create_balanced_mask_with_fixed_test(
        mask=train_mask,
        response_data=full_resp,
        train_percentage=args.train_percentage,
        item_names=problems,
        seed=0
    )
    # Hide the held-out model completely. CAT re-reveals its items group by group.
    train_mask[del_idx, :] = -1

    # Group questions by base name and remember each group's column indices,
    # so mask updates only touch the newly administered group.
    problem_groups: Dict[str, List[str]] = {}
    group_indices: Dict[str, List[int]] = {}
    for col_idx, p in enumerate(problems):
        base = extract_group(p)
        problem_groups.setdefault(base, []).append(p)
        group_indices.setdefault(base, []).append(col_idx)

    all_units = list(problem_groups.keys())
    total = len(all_units)
    steps = [max(1, int(np.floor(total * p))) for p in milestone_percents]
    milestone_steps = set(steps)
    last_step = max(steps)

    model_obj = Standard3PLIRT(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        a_scale=a_scale,
        theta_init=0.2,
        use_guessing=use_guessing,
    )
    print(f"[{deletion_model}] IRT learning start")
    model_obj.fit()

    administered = set()
    admin_list: List[str] = []
    records = []
    step = 0

    while step < last_step:
        print(f"[{deletion_model}] {step} / {total} ({len(administered)})")
        if args.fisher:
            kli = model_obj.compute_item_fisher_information(deletion_model)
        else:
            kli = model_obj.compute_item_kli(deletion_model)

        # Score each remaining group by the summed information of its items.
        scores = {
            unit: sum(kli.get(p, 0.0) for p in problem_groups[unit])
            for unit in all_units
            if unit not in administered
        }
        if not scores:
            break

        chosen = max(scores, key=scores.get)
        administered.add(chosen)
        admin_list.extend(problem_groups[chosen])
        step += 1

        # Reveal the chosen group's entries for the held-out model. Previously
        # administered groups were already set to 1 in earlier iterations.
        train_mask[del_idx, group_indices[chosen]] = 1

        model_obj.update_single_theta(
            model_name=deletion_model,
            problem_names=admin_list,
            lr=1e-4,
            max_epochs=1000,
            patience=50,
        )

        if step in milestone_steps:
            if filter_no_prefix:
                eval_items = [p for p in admin_list if not p.startswith("<no_")]
                overall_gt_col = [p for p in gt_df.columns if not p.startswith("<no_")]
            else:
                eval_items = list(admin_list)
                overall_gt_col = list(gt_df.columns)
            subset_means = merged_df[eval_items].mean(axis=1)
            # Align ground truth to merged_df's rows. spearmanr pairs values by
            # position, and the inner merge can drop gt_df (df1) rows absent from df2.
            overall_gt = gt_df[overall_gt_col].mean(axis=1).reindex(subset_means.index)
            corr, _ = spearmanr(subset_means.rank(), overall_gt.rank())
            diff = subset_means.loc[deletion_model] - overall_gt.loc[deletion_model]
            acc_err = abs(diff)

            shuffle_ratio = sum(is_contaminated(p) for p in eval_items) / max(1, len(eval_items))

            records.append({
                "deletion_model": deletion_model,
                "milestone_percent": step / total,
                "step": step,
                "n_administered_groups": len(administered),
                "n_administered_questions": len(eval_items),
                "corr_mixed_gt": corr,
                "acc_error": acc_err,
                "problem": eval_items,
                "shuffle_ratio": shuffle_ratio,
            })
            print(f"[{deletion_model}] Milestone {step}: corr={corr:.4f}, err={acc_err:.4f}, shuffle_ratio={shuffle_ratio:.4f}")

    return records

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    try:
        # 1) Grid search.
        # NOTE(review): only the diagonal theta_max == difficulty_base_max ==
        # a_scale is searched (len(theta_max_list) configs per model);
        # difficulty_base_max_list / a_scale_list are unused. Confirm whether the
        # full product was intended.
        gs_tasks = [
            eval_one_config.remote(m, tm, tm, tm, seed=args.seed or 0)
            for m in models
            for tm in theta_max_list
        ]
        gs_results = ray.get(gs_tasks)

        print("=== Grid Search Results ===")
        # Keep, per model, the configuration with the highest validation ROC AUC.
        best_params: Dict[str, Dict] = {}
        for m in models:
            subset = [r for r in gs_results if r["deletion_model"] == m]
            best_params[m] = max(subset, key=lambda x: x["roc_auc"])
        # train_mask holds whole numpy arrays; in a CSV they would only become
        # truncated reprs, so leave that column out of the saved table.
        gs_results_df = pd.DataFrame(gs_results).drop(columns=["train_mask"], errors="ignore")
        gs_results_df.to_csv(f"grid_search_results_{irt_name}_{dataset_name}_{args.N}.csv", index=False)
        print("=== Best Params per Model ===")
        for m, p in best_params.items():
            print(f"{m}: θmax={p['theta_max']}, diff_base={p['difficulty_base_max']}, "
                  f"a_scale={p['a_scale']} (roc_auc={p['roc_auc']:.4f})")

        # 2) CAT simulation with each model's best configuration, recording
        # milestones at 1%, 2%, …, 50% of the question groups.
        milestone_percents = [i/100.0 for i in range(1, 51)]
        cat_tasks = [
            run_deletion_cat_experiment_milestones.remote(
                m, milestone_percents,
                best_params[m]["theta_max"],
                best_params[m]["difficulty_base_max"],
                best_params[m]["a_scale"],
                best_params[m]["train_mask"]
            ) for m in models
        ]
        cat_results_nested = ray.get(cat_tasks)
    finally:
        # Release Ray workers even if a task raised.
        ray.shutdown()

    flat = [rec for sub in cat_results_nested for rec in sub]
    results_df = pd.DataFrame(flat)
    out_dir = "result"
    os.makedirs(out_dir, exist_ok=True)
    if args.fisher:
        fname = (f"{out_dir}/CAT_reduce_model_shuffle_fisher_{args.N}_"
                 f"{dataset_name}_{irt_name}_{mmdd}_{hhmm}{filter_suffix}_train{args.train_percentage}.csv")
    else:
        fname = (f"{out_dir}/CAT_reduce_model_shuffle_kli_{args.N}_"
                 f"{dataset_name}_{irt_name}_{mmdd}_{hhmm}{filter_suffix}.csv")
    results_df.to_csv(fname, index=False)
    print(f"Results are saved to '{fname}'")

    exp_name = (f"CAT_reduce_model_shuffle{args.N}_{dataset_name}_"
                f"{mmdd}_{hhmm}_{pl}pl{filter_suffix}")
    mlflow.set_experiment(exp_name)
    with mlflow.start_run(run_name=exp_name):
        mlflow.log_params({
            "dataset_name": dataset_name,
            "pl": pl,
            "use_guessing": use_guessing,
            "split_difficulty": split_difficulty,
            "split_ability": split_ability,
            "filter_no_prefix": filter_no_prefix,
            "N": args.N,
            "seed": args.seed,
            # Also record the settings that change results/filenames.
            "irt_name": irt_name,
            "train_percentage": args.train_percentage,
            "fisher": args.fisher,
        })
        mlflow.log_artifact(fname)