import os
import sys
import datetime
import pytz
import random
import argparse

import numpy as np
import pandas as pd
import torch

from typing import List, Dict, Any
from scipy.stats import spearmanr
import ray
import mlflow
from mmirt.utils.contami_judge import is_contaminated
from mmirt.utils.masks import create_validation_mask_for_items, create_balanced_mask_with_fixed_test
from itertools import product


def parse_args(argv=None):
    """Parse the command-line options of the CAT experiment.

    Args:
        argv: Arguments to parse.  ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is how the script calls it; passing a list
            lets the parser be exercised without touching ``sys.argv``.

    Returns:
        argparse.Namespace.  Options whose name starts with a digit
        (``--2pl``, ``--3pl``) are stored under the attribute names ``"2pl"``
        and ``"3pl"`` and must be read with ``getattr``/``vars``.
    """
    parser = argparse.ArgumentParser(description="CAT")
    parser.add_argument("--N", type=int, default=900,
                        help="group num")
    parser.add_argument("--seed", type=int, default=None,
                        help="seed")
    # Dataset selectors (default dataset is MMMU).
    parser.add_argument("--mathvista", action="store_true", help="use mathvista dataset")
    parser.add_argument("--seedbench", action="store_true", help="use seedbench dataset")
    parser.add_argument("--vqa", action="store_true", help="use vqaat dataset")
    # GPU selectors.
    parser.add_argument("--cuda0", action="store_true", help="CUDA_VISIBLE_DEVICES=0")
    parser.add_argument("--cuda1", action="store_true", help="CUDA_VISIBLE_DEVICES=1")
    parser.add_argument("--cuda2", action="store_true", help="CUDA_VISIBLE_DEVICES=2")
    parser.add_argument("--cuda3", action="store_true", help="CUDA_VISIBLE_DEVICES=3")
    # Response model.
    parser.add_argument("--2pl", action="store_true", help="2PL model")
    parser.add_argument("--3pl", action="store_true", help="3PL model")
    parser.add_argument("--filter_no_prefix", action="store_true",
                        help="'<no_…>'")
    # IRT variants.
    parser.add_argument("--asplit", action="store_true",
                        help="MMIRT (split_difficulty=True, split_ability=True)")
    parser.add_argument("--mirt", action="store_true",
                        help="multidimensional IRT (MultiDim3PLIRT, n_dims=4)")
    parser.add_argument("--nosplit", action="store_true",
                        help="IRT (split_difficulty=False, split_ability=False)")
    parser.add_argument("--train_percentage", type=float, default=1.0,
                        help="Train data (0.0–1.0)")
    return parser.parse_args(argv)

args = parse_args()

# Seed every RNG the script uses when --seed is given.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

# Default dataset is MMMU; the first dataset flag that is set (checked in the
# order mathvista, vqa, seedbench) overrides it.
dataset_name = "mmmu"
file1_path = "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv"
# NOTE(review): "shufffle" (three f's) -- confirm it matches the file on disk.
file2_path = "mmmu_normal_shuffle/shufffle_mmmu.csv"
if args.mathvista:
    print("MathVista Dataset is used")
    dataset_name = "mathvista"
    file1_path = "merged_mathvista/mathvista_combined.csv"
    file2_path = "merged_mathvista/merged_mathvista_shuffle.csv"
elif args.vqa:
    print("VQAAT Dataset is used")
    dataset_name = "vqa"
    file1_path = "vqa_shuffle_normal/vqa_merged.csv"
    file2_path = "vqa_shuffle_normal/vqa_shuffle_filtered.csv"
elif args.seedbench:
    print("Seed bench is used")
    dataset_name = "seed"
    file1_path = "seed_normal_shuffled/seed_normal_all_clean.csv"
    file2_path = "seed_normal_shuffled/seed_shuffled.csv"

# Each CSV is read exactly once (the original loaded both files twice).
df1 = pd.read_csv(file1_path, index_col=0)
df2 = pd.read_csv(file2_path, index_col=0)

# Column-name tags marking ablated variants of an item; stripping the tag
# yields the item's group name.
prefixes = ["<no_image>", "<no_question>", "<no_info>"]

def extract_group(col_name: str) -> str:
    """Return ``col_name`` with its leading tag from ``prefixes`` removed.

    The first entry of the module-level ``prefixes`` list that ``col_name``
    starts with is stripped; a name carrying none of them is returned as is.
    """
    matched = next((tag for tag in prefixes if col_name.startswith(tag)), None)
    if matched is None:
        return col_name
    return col_name[len(matched):]

# Group df2's columns by the group name extract_group() assigns to them
# (the column with any prefixes tag stripped).
groups: Dict[str, List[str]] = {}
for col in df2.columns:
    key = extract_group(col)
    groups.setdefault(key, []).append(col)

all_groups = list(groups.keys())

# Fixed seed (independent of --seed) so every run draws the same N groups.
random.seed(42)
selected_groups = random.sample(all_groups, args.N)

selected_columns: List[str] = []
for grp in selected_groups:
    selected_columns.extend(groups[grp])

df2_subset = df2[selected_columns]
# pd.merge defaults to an inner join, so only index labels present in both
# files are kept.
merged_df = pd.merge(df1, df2_subset, left_index=True, right_index=True)
gt_df = df1.copy()
print(f"merged_df shape: {merged_df.shape}")
full_resp = merged_df.values.astype(np.float32)
problems = merged_df.columns.tolist()
models = merged_df.index.tolist()

# Expose a single GPU: the first --cudaN flag given wins, otherwise "0".
os.environ["CUDA_VISIBLE_DEVICES"] = next(
    (dev for dev in ("0", "1", "2", "3") if getattr(args, f"cuda{dev}")),
    "0",
)

# "--3pl" is stored under the attribute name "3pl" (not a valid identifier).
use_guessing = bool(getattr(args, "3pl"))
pl = 3 if use_guessing else 2

filter_no_prefix = args.filter_no_prefix
filter_suffix = "_filterNoPrefix" if filter_no_prefix else ""

# Pick the IRT implementation; the class is always bound to the name
# Standard3PLIRT so the code below can stay model-agnostic.
if args.asplit:
    split_difficulty, split_ability, irt_name = True, True, "asplit"
    from mmirt.irt_onelayer_asplit_val import Standard3PLIRT
elif args.nosplit:
    split_difficulty, split_ability, irt_name = False, False, "nosplit"
    from mmirt.irt_onelayer_nostlict import Standard3PLIRT
elif args.mirt:
    split_difficulty, split_ability, irt_name = True, True, "mirt"
    from mmirt.mirt import MultiDim3PLIRT as Standard3PLIRT
else:
    raise ValueError("Select IRT model: --asplit, --nosplit or --mirt")

now = datetime.datetime.now()
mmdd = now.strftime("%m%d")
hhmm = now.strftime("%H%M")
print(f"Date: {mmdd}, Time: {hhmm}")

# Independent float32 copy of the response matrix, kept under its historical
# name; models/problems were already computed above and are not recomputed.
response_data = full_resp.copy()

# Scales tried in the grid search (the original had the same list in both
# branches of an if/else).
shared_scale_list = [2, 4, 8, 16]

@ray.remote(num_cpus=0.5)
def eval_one_config(deletion_model: str,
                    theta_max: float,
                    difficulty_base_max: float,
                    a_scale: float,
                    seed: int = 0) -> Dict[str, Any]:
    """Fit one hyper-parameter configuration and score it on validation items.

    Up to 100 item groups (group names of ``df1``'s columns) are sampled.
    Every ``merged_df`` column belonging to one of them becomes a validation
    item.  The training mask is built with the ``mmirt.utils.masks`` helpers,
    the row of ``deletion_model`` is masked out entirely, and the IRT model is
    fitted and evaluated.

    Args:
        deletion_model: Row name of ``merged_df`` whose row is masked out.
        theta_max, difficulty_base_max, a_scale: Hyper-parameters forwarded to
            the split/non-split models; the ``--mirt`` model does not get them.
        seed: Seeds ``random``/``numpy`` and is passed to both mask helpers.

    Returns:
        Dict with the configuration, the validation ``roc_auc`` reported by
        ``evaluate_validation()``, and the ``train_mask`` used.

    Raises:
        ValueError: if ``deletion_model`` is not in ``models``.
    """
    random.seed(seed)
    np.random.seed(seed)

    # sorted(): iterating a set of str follows hash order, which differs from
    # process to process (hash randomization), so sampling from list(set(...))
    # picked different validation groups for the same seed.
    unique_groups = sorted({extract_group(col) for col in df1.columns})
    # Never request more groups than exist (random.sample would raise).
    val_groups = set(random.sample(unique_groups, min(100, len(unique_groups))))
    val_items = [col for col in merged_df.columns
                 if extract_group(col) in val_groups]

    train_mask, _ = create_validation_mask_for_items(
        response_data=full_resp,
        item_names=problems,
        val_item_names=val_items,
        val_percentage=0.2,
        seed_init=seed
    )
    train_mask[train_mask == 1] = -1
    train_mask, _ = create_balanced_mask_with_fixed_test(
        mask=train_mask,
        response_data=full_resp,
        train_percentage=args.train_percentage,
        item_names=problems,
        seed=seed
    )
    # Set the held-out model's entire row to -1.
    train_mask[models.index(deletion_model), :] = -1

    common_kwargs = dict(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        lr=1e-2,
        batch_size=512,
        device="cpu",
        eps=1e-3,
        use_guessing=use_guessing,
    )
    if args.mirt:
        # The MIRT model receives none of the grid parameters, so every grid
        # point is the same fit.  NOTE(review): it also gets max_epochs=5
        # (5000 elsewhere) -- presumably for speed; confirm this is intended.
        irt = Standard3PLIRT(**common_kwargs, n_dims=4, max_epochs=5)
    else:
        irt = Standard3PLIRT(
            **common_kwargs,
            max_epochs=5000,
            split_difficulty=split_difficulty,
            split_ability=split_ability,
            embedding_dim=1024,
            theta_max=theta_max,
            difficulty_base_max=difficulty_base_max,
            difficulty_other_max=difficulty_base_max,
            a_scale=a_scale,
        )
    # NOTE(review): for an nn.Module, torch.compile returns a wrapper that only
    # compiles forward(); fit()/evaluate_validation() looked up through it run
    # the original, uncompiled methods -- confirm this buys anything.
    irt = torch.compile(irt, backend="inductor")  # PyTorch 2.0+

    irt.fit()
    metrics = irt.evaluate_validation()  # {'roc_auc': …, …}

    return {
        "deletion_model": deletion_model,
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "a_scale": a_scale,
        "roc_auc": metrics["roc_auc"],
        "train_mask": train_mask
    }

@ray.remote(num_cpus=1)
def run_deletion_cat_experiment_milestones(
        deletion_model: str,
        milestone_percents: List[float],
        theta_max: float,
        difficulty_base_max: float,
        a_scale: float
    ) -> List[Dict]:
    """Simulate adaptive testing for one model whose row is masked out.

    The IRT model is fitted with the whole ``deletion_model`` row set to -1 in
    the training mask.  Item groups (items sharing an ``extract_group`` name)
    are then administered one per step.  Each step picks the not-yet-
    administered group with the highest score computed from
    ``compute_item_fisher_information(deletion_model)``:

    * with split ability or ``--mirt``: the determinant of the summed
      per-item matrices;
    * otherwise: the plain sum.

    After each pick, ``update_single_theta`` is called on all administered
    items.

    Args:
        deletion_model: Row name of ``merged_df`` to mask out and test.
        milestone_percents: Fractions of the item groups at which a record is
            emitted (step ``max(1, floor(total * p))`` for each ``p``).
        theta_max, difficulty_base_max, a_scale: Hyper-parameters for the
            split/non-split models; not passed to the ``--mirt`` model.

    Returns:
        One dict per milestone reached.  An empty list is returned if
        ``deletion_model`` is unknown or ``milestone_percents`` is empty.

    Raises:
        RuntimeError: if multi-dimensional scoring is needed but the Fisher
            information mapping is empty.
    """
    try:
        del_idx = models.index(deletion_model)
    except ValueError:
        print(f"Error: {deletion_model} not found.")
        return []

    full_resp = merged_df.values.astype(np.float32)

    # Start from an all -1 mask and let the helper choose the training cells;
    # built once (the original built it twice and discarded the first).
    train_mask = np.full_like(full_resp, -1, dtype=np.float32)
    train_mask, _ = create_balanced_mask_with_fixed_test(
        mask=train_mask,
        response_data=full_resp,
        train_percentage=args.train_percentage,
        item_names=problems,
        seed=0
    )
    train_mask[del_idx, :] = -1

    print(f"[{deletion_model}] (MM)IRT learning started idx={del_idx}")

    # Group name -> its column names, and -> its column indices in `problems`
    # (the latter lets each step unmask only the newly administered group).
    problem_groups: Dict[str, List[str]] = {}
    group_indices: Dict[str, List[int]] = {}
    for col_idx, p in enumerate(problems):
        base = extract_group(p)
        problem_groups.setdefault(base, []).append(p)
        group_indices.setdefault(base, []).append(col_idx)

    all_units = list(problem_groups.keys())
    total = len(all_units)
    steps = [max(1, int(np.floor(total * p))) for p in milestone_percents]
    if not steps:
        # Nothing to record; skip the expensive fit (max([]) used to crash).
        return []
    milestone_steps = set(steps)
    last_step = max(steps)

    common_kwargs = dict(
        response_data=full_resp,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        use_guessing=use_guessing,
    )
    if args.mirt:
        model_obj = Standard3PLIRT(**common_kwargs, n_dims=4)
    else:
        model_obj = Standard3PLIRT(
            **common_kwargs,
            split_difficulty=split_difficulty,
            split_ability=split_ability,
            embedding_dim=1024,
            theta_max=theta_max,
            difficulty_base_max=difficulty_base_max,
            difficulty_other_max=difficulty_base_max,
            a_scale=a_scale,
        )
    print(f"[{deletion_model}] IRT learning started ")
    model_obj.fit()

    # Ground truth does not change between milestones: compute it once.
    if filter_no_prefix:
        overall_gt_col = [p for p in gt_df.columns if not p.startswith("<no_")]
    else:
        overall_gt_col = list(gt_df.columns)
    overall_gt = gt_df[overall_gt_col].mean(axis=1)
    overall_gt_rank = overall_gt.rank()
    use_det = split_ability or args.mirt

    administered = set()
    admin_list: List[str] = []
    records = []
    step = 0

    while step < last_step:
        print(f"[{deletion_model}] {step} / {total} ({len(administered)})")
        fisher = model_obj.compute_item_fisher_information(deletion_model)

        remaining = [u for u in all_units if u not in administered]
        if not remaining:
            break

        if use_det:
            if not fisher:
                raise RuntimeError(
                    f"[{deletion_model}] empty Fisher information")
            # Items missing from `fisher` contribute a zero matrix shaped like
            # the first entry.  builtin sum() uses `+`, not `+=`, so sharing
            # one zero array is safe.
            zero_info = np.zeros(fisher[next(iter(fisher))].shape)
            scores = {
                u: np.linalg.det(sum(fisher.get(p, zero_info)
                                     for p in problem_groups[u]))
                for u in remaining
            }
        else:
            scores = {
                u: sum(fisher.get(p, 0.0) for p in problem_groups[u])
                for u in remaining
            }

        chosen = max(scores, key=scores.get)
        administered.add(chosen)
        admin_list.extend(problem_groups[chosen])
        step += 1

        # Set the newly administered group's cells in the held-out row to 1;
        # earlier groups were already set in previous steps.
        train_mask[del_idx, group_indices[chosen]] = 1

        model_obj.update_single_theta(
            model_name=deletion_model,
            problem_names=admin_list,
            lr=1e-4,
            max_epochs=1000,
            patience=50,
        )

        if step in milestone_steps:
            if filter_no_prefix:
                eval_items = [p for p in admin_list if not p.startswith("<no_")]
            else:
                eval_items = list(admin_list)
            subset_means = merged_df[eval_items].mean(axis=1)
            corr, _ = spearmanr(subset_means.rank(), overall_gt_rank)
            diff = subset_means.loc[deletion_model] - overall_gt.loc[deletion_model]
            acc_err = abs(diff)

            shuffle_ratio = (sum(is_contaminated(p) for p in eval_items)
                             / max(1, len(eval_items)))

            records.append({
                "deletion_model": deletion_model,
                "milestone_percent": step / total,
                "step": step,
                "n_administered_groups": len(administered),
                "n_administered_questions": len(eval_items),
                "corr_mixed_gt": corr,
                "acc_error": acc_err,
                "problem": eval_items,
                "shuffle_ratio": shuffle_ratio,
            })
            print(f"[{deletion_model}] Milestone {step}: corr={corr:.4f}, err={acc_err:.4f}, shuffle_ratio={shuffle_ratio:.4f}")

    return records

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)

    # --- Grid search: one task per (held-out model, shared scale); the same
    # scale is used for theta_max, difficulty_base_max and a_scale.
    gs_tasks = [
        eval_one_config.remote(m, scale, scale, scale, seed=args.seed or 0)
        for m in models
        for scale in shared_scale_list
    ]
    gs_results = ray.get(gs_tasks)

    gs_results_df = pd.DataFrame(gs_results)
    train_tag = (f"_train{args.train_percentage}"
                 if args.train_percentage < 1.0 else "")
    grid_csv = (f"grid_search_results_{irt_name}_{dataset_name}_{args.N}"
                f"{train_tag}_equal_q.csv")
    gs_results_df.to_csv(grid_csv, index=False)
    print(f"Grid search results saved to {grid_csv}")

    def _auc_key(rec: Dict[str, Any]) -> float:
        # Rank NaN AUCs below every real score: NaN compares False, so max()
        # would otherwise keep a NaN that happens to come first.
        auc = float(rec["roc_auc"])
        return -np.inf if np.isnan(auc) else auc

    print("=== Best Params per Model ===")
    best_params: Dict[str, Dict] = {}
    for m in models:
        subset = [r for r in gs_results if r["deletion_model"] == m]
        best = max(subset, key=_auc_key)
        best_params[m] = best
        print(f"{m}: θmax={best['theta_max']}, diff_base={best['difficulty_base_max']}, "
              f"a_scale={best['a_scale']} (roc_auc={best['roc_auc']:.4f})")

    # --- CAT simulation with each model's best configuration; milestones at
    # 1%, 2%, ..., 50% of the item groups.
    milestone_percents = [i / 100.0 for i in range(1, 51)]
    cat_tasks = [
        run_deletion_cat_experiment_milestones.remote(
            m,
            milestone_percents,
            best_params[m]["theta_max"],
            best_params[m]["difficulty_base_max"],
            best_params[m]["a_scale"]
        )
        for m in models
    ]
    cat_results_nested = ray.get(cat_tasks)
    ray.shutdown()

    flat = [rec for sub in cat_results_nested for rec in sub]
    results_df = pd.DataFrame(flat)
    out_dir = "result"
    os.makedirs(out_dir, exist_ok=True)
    fname = (f"{out_dir}/CAT_reduce_model_shuffle_{args.N}_"
             f"{dataset_name}_{irt_name}_{mmdd}_{hhmm}{filter_suffix}.csv")
    results_df.to_csv(fname, index=False)
    print(f"Results are saved to '{fname}'")

    exp_name = (f"CAT_reduce_model_shuffle{args.N}_{dataset_name}_"
                f"{mmdd}_{hhmm}_{pl}pl{filter_suffix}")
    mlflow.set_experiment(exp_name)
    with mlflow.start_run(run_name=exp_name):
        mlflow.log_params({
            "dataset_name": dataset_name,
            "pl": pl,
            "use_guessing": use_guessing,
            "split_difficulty": split_difficulty,
            "split_ability": split_ability,
            "filter_no_prefix": filter_no_prefix,
            "N": args.N,
            "seed": args.seed,
        })
        mlflow.log_artifact(fname)