import os
import sys
import datetime
import pytz
import random
import argparse

import numpy as np
import pandas as pd
import torch

from typing import List, Dict, Any
from scipy.stats import spearmanr
import ray
import mlflow
from mmirt.utils.contami_judge import is_contaminated
from mmirt.utils.masks import create_validation_mask_for_items
from itertools import product

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by the parser. The ``--2pl`` and
        ``--3pl`` switches are stored under the keys ``"2pl"`` and ``"3pl"``,
        which are not valid identifiers, so they have to be read through
        ``vars(namespace)`` or ``getattr``.
    """
    cli = argparse.ArgumentParser(description="CAT Experiment")

    # Integer-valued options.
    cli.add_argument("--N", type=int, default=900, help="random groups")
    cli.add_argument("--seed", type=int, default=None, help="seed")

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ("--mathvista", "use mathvista"),
        ("--vqa", "use vqa"),
        ("--cuda0", "CUDA_VISIBLE_DEVICES=0"),
        ("--cuda1", "CUDA_VISIBLE_DEVICES=1"),
        ("--cuda2", "CUDA_VISIBLE_DEVICES=2"),
        ("--cuda3", "CUDA_VISIBLE_DEVICES=3"),
        ("--2pl", "2pl model"),
        ("--3pl", "3PL model"),
        ("--filter_no_prefix", "'<no_…>' exclude prefix filter"),
        ("--asplit", "mmirt"),
        ("--nosplit", "irt"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    return cli.parse_args()

# The CLI is parsed once at import time; the rest of the module reads `args`.
args = parse_args()

# Seed the Python, NumPy and PyTorch generators only when --seed was given.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

# MMMU is the default dataset. --mathvista is checked before --vqa, so it wins
# when both switches are passed.
dataset_name, file1_path, file2_path = (
    "mmmu",
    "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv",
    "mmmu_normal_shuffle/shufffle_mmmu.csv",
)
if args.mathvista:
    print("MathVista Dataset used")
    dataset_name, file1_path, file2_path = (
        "mathvista",
        "merged_mathvista/mathvista_combined.csv",
        "merged_mathvista/merged_mathvista_shuffle.csv",
    )
elif args.vqa:
    print("VQAAT Dataset used")
    dataset_name, file1_path, file2_path = (
        "vqa",
        "vqa_shuffle_normal/vqa_merged.csv",
        "vqa_shuffle_normal/vqa_shuffle_filtered.csv",
    )

# Both tables use their first CSV column as the row index.
df1 = pd.read_csv(file1_path, index_col=0)
df2 = pd.read_csv(file2_path, index_col=0)

# Marker prefixes that may precede a column name; checked in this order.
prefixes = ["<no_image>", "<no_question>", "<no_info>"]


def extract_group(col_name: str) -> str:
    """Return ``col_name`` with its leading marker prefix removed.

    Only the first entry of ``prefixes`` that ``col_name`` starts with is
    stripped, and it is stripped once. Names that start with none of the
    prefixes are returned unchanged.
    """
    matched = next((tag for tag in prefixes if col_name.startswith(tag)), None)
    if matched is None:
        return col_name
    return col_name[len(matched):]

# Bucket every df2 column under its base name (the column name with any marker
# prefix stripped), so a base column and its prefixed variants share one group.
groups: Dict[str, List[str]] = {}
for col in df2.columns:
    groups.setdefault(extract_group(col), []).append(col)

all_groups = list(groups)

# random.sample() rejects an out-of-range size with a generic message; fail
# early with one that names the CLI option and the number of available groups.
if not 0 <= args.N <= len(all_groups):
    raise ValueError(
        f"--N must be between 0 and {len(all_groups)} "
        f"(the number of column groups in {file2_path}), got {args.N}"
    )

# The subset of groups is drawn with a fixed seed of 42 regardless of --seed.
# A private generator is used so that the global `random` state, which was
# seeded from --seed above, is not overwritten; Random(42).sample() yields the
# same draw as random.seed(42) followed by random.sample().
selected_groups = random.Random(42).sample(all_groups, args.N)

# Flatten the chosen groups back into df2 column names, group by group.
selected_columns: List[str] = [
    col for grp in selected_groups for col in groups[grp]
]

df2_subset = df2[selected_columns]
# Join on the row index; pd.merge defaults to an inner join, so only rows
# present in both tables survive.
merged_df = pd.merge(df1, df2_subset, left_index=True, right_index=True)
# df1 alone is kept as the reference table the CAT stage scores against.
gt_df = df1.copy()
print(f"merged_df shape: {merged_df.shape}")

# Response matrix plus its column (problem) and row (model) labels.
full_resp = merged_df.values.astype(np.float32)
problems = merged_df.columns.tolist()
models = merged_df.index.tolist()

# Expose exactly one GPU to this process: the lowest-numbered --cudaN switch
# that was passed, or GPU 0 when none was.
visible_device = "0"
for device_id in ("0", "1", "2", "3"):
    if getattr(args, f"cuda{device_id}"):
        visible_device = device_id
        break
os.environ["CUDA_VISIBLE_DEVICES"] = visible_device

# "--3pl" is stored under a key that is not a valid identifier, so it is read
# from the namespace dict. NOTE(review): --2pl is parsed but never read; 2PL is
# simply what you get when --3pl is absent.
use_guessing = bool(vars(args).get("3pl"))
pl = 3 if use_guessing else 2

filter_no_prefix = args.filter_no_prefix
filter_suffix = "_filterNoPrefix" if filter_no_prefix else ""


# NOTE(review): --asplit / --nosplit are parsed but not consulted here; this
# script always runs the "nosplit" configuration.
split_difficulty = False
split_ability = False
irt_name = "nosplit"
from mmirt.irt_onelayer_nostlict import Standard3PLIRT

# Local wall-clock timestamp, used later in output file and experiment names.
now = datetime.datetime.now()
mmdd = now.strftime("%m%d")
hhmm = now.strftime("%H%M")
print(f"Date: {mmdd}, Time: {hhmm}")

# A second float32 copy of the response matrix, with the row/column labels.
response_data = merged_df.values.astype(np.float32)
models = merged_df.index.tolist()
problems = merged_df.columns.tolist()

# Hyperparameter grid. Only a single point per axis is active; the wider lists
# are left commented out.
# theta_max_list           = [2,4,8,16]
# difficulty_base_max_list = [2,4,8,16]
# a_scale_list             = [2,4,8,16]
theta_max_list = [2]
difficulty_base_max_list = [2]
a_scale_list = [2]

@ray.remote(num_cpus=0.5)
def eval_one_config(deletion_model: str,
                    theta_max: float,
                    difficulty_base_max: float,
                    a_scale: float,
                    seed: int = 0) -> Dict[str, Any]:
    """Fit one hyperparameter configuration with ``deletion_model`` held out.

    A validation mask is built over the columns belonging to up to 100 randomly
    chosen df1 column groups, the held-out model's whole row of that mask is
    set to -1, and a ``Standard3PLIRT`` model is fitted on the result.

    Args:
        deletion_model: Row label (an entry of ``models``) to hold out.
        theta_max: Forwarded to ``Standard3PLIRT`` as ``theta_max``.
        difficulty_base_max: Forwarded as both ``difficulty_base_max`` and
            ``difficulty_other_max``.
        a_scale: Forwarded to ``Standard3PLIRT`` as ``a_scale``.
        seed: Seeds ``random`` and ``numpy.random`` and is passed to
            ``create_validation_mask_for_items`` as ``seed_init``.

    Returns:
        A dict echoing ``deletion_model`` and the three hyperparameters, plus
        ``"roc_auc"`` and the ``"train_mask"`` array used for fitting.

    Raises:
        ValueError: If ``deletion_model`` is not in ``models``.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Sort before sampling: the iteration order of a set of strings differs
    # between interpreter processes (string hash randomisation), so sampling
    # from an unsorted set is not reproducible across Ray workers even with a
    # fixed seed.
    unique_groups = sorted({extract_group(col) for col in df1.columns})

    # Never ask for more groups than exist; random.sample() would raise.
    n_val_groups = min(100, len(unique_groups))
    val_groups = set(random.sample(unique_groups, n_val_groups))
    val_items = [col for col in merged_df.columns if extract_group(col) in val_groups]

    train_mask, _ = create_validation_mask_for_items(
        response_data=full_resp,
        item_names=problems,
        val_item_names=val_items,
        val_percentage=0.2,
        seed_init=seed
    )
    # Hold out every response of the deleted model.
    # Presumably -1 marks an entry as excluded from training (confirm in mmirt).
    idx = models.index(deletion_model)
    train_mask[idx, :] = -1
    irt = Standard3PLIRT(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=1e-2,
        batch_size=512,
        max_epochs=5,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        a_scale=a_scale,
        enable_abs_clamp=False,
        # NOTE(review): guessing is always off here, even when --3pl is set;
        # the CAT stage passes the module-level `use_guessing` instead.
        use_guessing=False,
    )
    import torch
    irt = torch.compile(irt, backend="inductor")  # PyTorch 2.0+
    irt.fit()
    # NOTE(review): placeholder score. Every configuration reports roc_auc == 1,
    # so the caller's max() just keeps the first configuration per model.
    metrics = {'roc_auc': 1}

    return {
        "deletion_model": deletion_model,
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "a_scale": a_scale,
        "roc_auc": metrics["roc_auc"],
        "train_mask": train_mask
    }

@ray.remote(num_cpus=1)
def run_deletion_cat_experiment_milestones(deletion_model: str,
                                           milestone_percents: List[float],
                                           theta_max: float,
                                           difficulty_base_max: float,
                                           a_scale: float,
                                           initial_train_mask: np.ndarray) -> List[Dict]:
    """Run a KLI-driven adaptive test for one held-out model and log milestones.

    A ``Standard3PLIRT`` model is fitted with ``deletion_model``'s row of the
    training mask set to -1 and every other entry set to 1. Column groups (a
    base column plus its prefixed variants) are then administered one per step:
    the unadministered group with the largest summed ``compute_item_kli`` score
    is chosen, its columns are switched to 1 in the held-out row of the mask,
    and ``update_single_theta`` is called with all columns administered so far.
    Whenever the step count equals a milestone, a record is produced.

    Args:
        deletion_model: Row label (an entry of ``models``) to hold out.
        milestone_percents: Fractions of the total group count; each becomes
            the step ``max(1, floor(total * p))`` at which a record is taken.
        theta_max: Forwarded to ``Standard3PLIRT`` as ``theta_max``.
        difficulty_base_max: Forwarded as both ``difficulty_base_max`` and
            ``difficulty_other_max``.
        a_scale: Forwarded to ``Standard3PLIRT`` as ``a_scale``.
        initial_train_mask: Accepted for interface compatibility but not used;
            the mask is rebuilt from scratch (all ones, held-out row -1).

    Returns:
        One dict per milestone reached, holding the step and group/question
        counts, the Spearman correlation between per-model means on the
        administered columns and per-model means on ``gt_df``, the absolute
        difference of those two means for ``deletion_model``, the evaluated
        column names, and the share of them flagged by ``is_contaminated``.
        An empty list when ``deletion_model`` is unknown or no milestone was
        requested.
    """
    try:
        del_idx = models.index(deletion_model)
    except ValueError:
        print(f"Error: {deletion_model} not found.")
        return []

    # Group columns by base name, remembering both the column names and their
    # positions in `problems` so the mask can be updated without rescanning.
    problem_groups: Dict[str, List[str]] = {}
    group_columns: Dict[str, List[int]] = {}
    for col_idx, name in enumerate(problems):
        base = extract_group(name)
        problem_groups.setdefault(base, []).append(name)
        group_columns.setdefault(base, []).append(col_idx)

    all_units = list(problem_groups)
    total = len(all_units)
    steps = [max(1, int(np.floor(total * p))) for p in milestone_percents]
    if not steps:
        # Without milestones nothing can be recorded (and max() of an empty
        # list would raise), so skip the fit entirely.
        print(f"[{deletion_model}] no milestones requested; skipping.")
        return []
    milestone_steps = set(steps)
    last_step = max(steps)

    full_resp = merged_df.values.astype(np.float32)
    train_mask = np.ones_like(full_resp)
    train_mask[del_idx, :] = -1

    model_obj = Standard3PLIRT(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=1e-2,
        batch_size=512,
        max_epochs=5,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_base_max,
        a_scale=a_scale,
        theta_init=0.2,
        use_guessing=use_guessing,
    )
    print(f"[{deletion_model}] (MM)IRT learning start")
    model_obj.fit()

    administered = set()
    admin_list: List[str] = []
    records = []
    step = 0

    while step < last_step:
        print(f"[{deletion_model}] {step} / {total} ({len(administered)})")
        # Scores are recomputed every step because update_single_theta() is
        # called between steps.
        kli = model_obj.compute_item_kli(deletion_model)
        # A group's score is the sum over its columns; columns missing from
        # `kli` contribute 0.
        scores = {
            unit: sum(kli.get(p, 0.0) for p in problem_groups[unit])
            for unit in all_units
            if unit not in administered
        }

        if not scores:
            # Every group has been administered already.
            break

        chosen = max(scores, key=scores.get)
        administered.add(chosen)
        admin_list.extend(problem_groups[chosen])
        step += 1

        # Reveal the chosen group's columns in the held-out row. Earlier groups
        # were revealed in previous iterations and are never reset.
        train_mask[del_idx, group_columns[chosen]] = 1

        model_obj.update_single_theta(
            model_name=deletion_model,
            problem_names=admin_list,
            lr=1e-4,
            max_epochs=10,
            patience=50,
        )

        if step in milestone_steps:
            if filter_no_prefix:
                # Drop every "<no_...>" column from both sides of the comparison.
                eval_items = [p for p in admin_list if not p.startswith("<no_")]
                overall_gt_col = [p for p in gt_df.columns if not p.startswith("<no_")]
            else:
                eval_items = list(admin_list)
                overall_gt_col = list(gt_df.columns)

            subset_means = merged_df[eval_items].mean(axis=1)
            overall_gt = gt_df[overall_gt_col].mean(axis=1)
            corr, _ = spearmanr(subset_means.rank(), overall_gt.rank())
            diff = subset_means.loc[deletion_model] - overall_gt.loc[deletion_model]
            acc_err = abs(diff)

            # max(1, ...) guards the division when the filter leaves no columns.
            shuffle_ratio = sum(is_contaminated(p) for p in eval_items) / max(1, len(eval_items))

            records.append({
                "deletion_model": deletion_model,
                "milestone_percent": step / total,
                "step": step,
                "n_administered_groups": len(administered),
                "n_administered_questions": len(eval_items),
                "corr_mixed_gt": corr,
                "acc_error": acc_err,
                "problem" : eval_items,
                "shuffle_ratio": shuffle_ratio,
            })
            print(f"[{deletion_model}] Milestone {step}: corr={corr:.4f}, err={acc_err:.4f}, shuffle_ratio={shuffle_ratio:.4f}")

    return records

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    try:
        # 1) Grid Search: one task per (held-out model, hyperparameter combo).
        gs_tasks = [
            eval_one_config.remote(m, tm, dbm, a, seed=args.seed or 0)
            for m in models
            for tm, dbm, a in product(theta_max_list,
                                      difficulty_base_max_list,
                                      a_scale_list)
        ]
        gs_results = ray.get(gs_tasks)

        print("=== Grid Search Results ===")
        # Keep the highest-scoring result per model; on ties the earliest
        # submitted configuration wins (strict comparison).
        best_params: Dict[str, Dict] = {}
        for res in gs_results:
            held_out = res["deletion_model"]
            if held_out not in best_params or res["roc_auc"] > best_params[held_out]["roc_auc"]:
                best_params[held_out] = res

        # The per-task train_mask is an ndarray and would be written to CSV as
        # its string repr, so it is left out of the summary table.
        # NOTE(review): the file name says "asplit" although irt_name is
        # "nosplit"; left unchanged in case something depends on the path.
        gs_results_df = pd.DataFrame(gs_results).drop(columns=["train_mask"], errors="ignore")
        gs_results_df.to_csv(f"grid_search_results_asplit_{dataset_name}_{args.N}.csv", index=False)
        print("=== Best Params per Model ===")
        for m, p in best_params.items():
            print(f"{m}: θmax={p['theta_max']}, diff_base={p['difficulty_base_max']}, "
                  f"a_scale={p['a_scale']} (roc_auc={p['roc_auc']:.4f})")

        # 2) CAT stage.
        # NOTE(review): range(1, 1) is empty, so no milestones are configured.
        # The CAT tasks cannot produce records without milestones (and used to
        # fail on max() of an empty list), so they are not launched in that
        # case. This looks like a typo for a non-empty range; confirm the
        # intended percentages.
        milestone_percents = [i/100.0 for i in range(1,1)]
        if milestone_percents:
            cat_tasks = [
                run_deletion_cat_experiment_milestones.remote(
                    m, milestone_percents,
                    best_params[m]["theta_max"],
                    best_params[m]["difficulty_base_max"],
                    best_params[m]["a_scale"],
                    best_params[m]["train_mask"]
                ) for m in models
            ]
        else:
            print("No milestones configured; skipping the CAT stage.")
            cat_tasks = []
        cat_results_nested = ray.get(cat_tasks)
    finally:
        # Shut Ray down even when a task raised.
        ray.shutdown()

    # Flatten the per-model record lists into one table and save it.
    flat = [rec for sub in cat_results_nested for rec in sub]
    results_df = pd.DataFrame(flat)
    out_dir = "result"
    os.makedirs(out_dir, exist_ok=True)
    fname = (f"{out_dir}/CAT_reduce_model_shuffle_kli_{args.N}_"
             f"{dataset_name}_{irt_name}_{mmdd}_{hhmm}{filter_suffix}.csv")
    results_df.to_csv(fname, index=False)
    print(f"Results are saves to '{fname}'")

    # Log the run configuration and the result file to MLflow; the experiment
    # and the run share the same timestamped name.
    exp_name = (f"CAT_reduce_model_shuffle{args.N}_{dataset_name}_"
                f"{mmdd}_{hhmm}_{pl}pl{filter_suffix}")
    mlflow.set_experiment(exp_name)
    with mlflow.start_run(run_name=exp_name):
        mlflow.log_params({
            "dataset_name": dataset_name,
            "pl": pl,
            "use_guessing": use_guessing,
            "split_difficulty": split_difficulty,
            "split_ability": split_ability,
            "filter_no_prefix": filter_no_prefix,
            "N": args.N,
            "seed": args.seed,
        })
        mlflow.log_artifact(fname)