import os
import sys
import datetime
import pytz
import argparse
import random
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, kendalltau
import ray
import mlflow
from mmirt.utils.contami_judge import is_contaminated

def parse_args():
    """Parse the options declared by this script and ignore everything else.

    Unknown tokens are tolerated (``parse_known_args``) so that extra
    switches on the command line do not abort parsing.

    Returns:
        argparse.Namespace: carries ``mode`` ("group" or "question"),
        ``N`` (int, default 900) and ``seed`` (int, default 0).
    """
    # One (flag, add_argument keyword arguments) entry per supported option.
    option_table = (
        ("--mode", {"choices": ["group", "question"], "default": "group", "help": ""}),
        ("--N", {"type": int, "default": 900, "help": "group num"}),
        ("--seed", {"type": int, "default": 0, "help": "seed"}),
    )

    cli = argparse.ArgumentParser(description="Flash experiment")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    known, _unrecognised = cli.parse_known_args()
    return known

# Parse the command line once at import time and seed both RNGs from --seed.
args = parse_args()
MODE, N = args.mode, args.N
random.seed(args.seed)
np.random.seed(args.seed)

# Default dataset (MMMU); the switches below replace it. When several
# dataset switches are present the first one tested wins.
dataset_name = "mmmu"
file1_path = "mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv"
file2_path = "mmmu_normal_shuffle/shufffle_mmmu.csv"

if "--mathvista" in sys.argv:
    dataset_name, data_dir = "mathvista", "data_mathvista"
    file1_path = "merged_mathvista/mathvista_combined.csv"
    file2_path = "merged_mathvista/merged_mathvista_shuffle.csv"
elif "--vqa" in sys.argv:
    dataset_name, data_dir = "vqa", "data_vqa"
    file1_path = "vqa_shuffle_normal/vqa_merged.csv"
    file2_path = "vqa_shuffle_normal/vqa_shuffle_filtered.csv"
elif "--seedbench" in sys.argv:
    print("Seed bench is used")
    dataset_name = "seed"
    file1_path = "seed_normal_shuffled/seed_normal_all_clean.csv"
    file2_path = "seed_normal_shuffled/seed_shuffled.csv"

# Both tables use their first CSV column as the row index.
df1 = pd.read_csv(file1_path, index_col=0)
df2 = pd.read_csv(file2_path, index_col=0)

# Column-name markers that are stripped to recover the base group name.
prefixes = ["<no_image>", "<no_question>", "<no_info>"]
def extract_group(col_name: str) -> str:
    """Return ``col_name`` without its leading marker from ``prefixes``.

    Only the first matching marker (in ``prefixes`` order) is removed, and
    only when it sits at the very start of the name. Names that start with
    none of the markers are returned unchanged.
    """
    matched = next((tag for tag in prefixes if col_name.startswith(tag)), None)
    if matched is None:
        return col_name
    return col_name[len(matched):]

# Bucket df2's columns by base group name: a column and its
# "<no_...>"-prefixed variants end up in the same bucket.
groups: dict = {}
for col in df2.columns:
    grp = extract_group(col)
    groups.setdefault(grp, []).append(col)

all_groups = list(groups)

# Fail early with an actionable message instead of random.sample's generic
# "Sample larger than population or is negative".
if not 0 <= N <= len(all_groups):
    raise ValueError(
        f"--N must be between 0 and the number of available groups "
        f"({len(all_groups)}), got {N}"
    )

# Draw the subset from a private, fixed-seed generator. This yields the same
# selection as re-seeding the global generator with 42, but leaves the global
# `random` state (seeded from --seed earlier in the script) untouched.
selected_groups = random.Random(42).sample(all_groups, N)

# Flatten the chosen groups back into column names, keeping sample order.
selected_columns: list = [col for grp in selected_groups for col in groups[grp]]

df2_subset = df2[selected_columns]
# Join on the row index; pd.merge defaults to an inner join, so only rows
# present in both tables are kept.
merged_df = pd.merge(df1, df2_subset, left_index=True, right_index=True)

    

# Pick the visible GPU from --cuda0 .. --cuda3. The lowest-numbered switch
# present wins; device 0 is used when none is given.
cuda_device = '0'
for device_id in ('0', '1', '2', '3'):
    if f"--cuda{device_id}" in sys.argv:
        cuda_device = device_id
        break
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device

# 2PL is the default and takes precedence if both switches are passed;
# --3pl turns on the guessing parameter.
if "--3pl" in sys.argv and "--2pl" not in sys.argv:
    use_guessing, pl = True, 3
else:
    use_guessing, pl = False, 2

filter_no_prefix = "--filter_no_prefix" in sys.argv
filter_suffix = "_filterNoPrefix" if filter_no_prefix else ""

# Choose the IRT implementation. The experiment code instantiates the class
# through the name `Standard3PLIRT`, so both branches must bind that name.
if "--mmirt_inner" in sys.argv:
    irt_name = "mmirt_inner"
    split_difficulty = True
    split_ability = True
    from mmirt.mmirt_inner_product import Standard3PLMMIRT
    # Without this binding the --mmirt_inner path fails with a NameError when
    # the experiment constructs its model.
    # NOTE(review): assumes Standard3PLMMIRT accepts the same constructor
    # arguments and methods as Standard3PLIRT — confirm.
    Standard3PLIRT = Standard3PLMMIRT
elif "--irt" in sys.argv:
    irt_name = "nosplit"
    split_difficulty = False
    split_ability = False
    from mmirt.irt_onelayer_nostlict import Standard3PLIRT
else:
    # Raise instead of `assert False`: asserts are stripped under `python -O`,
    # which would let the script continue with no model selected.
    raise ValueError("Invalid IRT model name. Use --mmirt_inner or --irt.")

# Independent copy of the first table; used as the ground-truth reference.
gt_df = df1.copy()

# Local wall-clock timestamp, reused in output file and experiment names.
now = datetime.datetime.now()
mmdd = now.strftime('%m%d')
hhmm = now.strftime('%H%M')
print(f"Date: {mmdd}, Time: {hhmm}")

# Row labels, column labels and the float32 matrix of the merged table.
models = merged_df.index.tolist()
problems = merged_df.columns.tolist()
response_data = merged_df.values.astype(np.float32)

@ray.remote(num_cpus=1)
def run_deletion_cat_experiment_milestones(deletion_model: str, milestone_percents: list):
    """Simulate an adaptive test for one held-out model and record milestone metrics.

    The row of ``deletion_model`` is set to -1 in the training mask and an IRT
    model is fitted on ``merged_df``. Units are then administered one at a
    time: a unit is a prefix-stripped column group when ``MODE == "group"``
    and a single column otherwise. Each step picks the remaining unit with
    the largest summed Fisher information for ``deletion_model``, sets its
    columns to 1 in the mask and updates the model's theta on all
    administered columns. Whenever the step count hits a milestone, ranking
    and accuracy metrics are computed from the administered columns.

    Args:
        deletion_model: Row label of ``merged_df`` to hold out.
        milestone_percents: Fractions of the total unit count. Each becomes a
            step count ``max(1, floor(n_units * p))``.

    Returns:
        A list with one metrics dict per milestone reached, in step order.
        Empty when ``deletion_model`` is not a known row label or when no
        milestone is given.
    """
    full_response = merged_df.values.astype(np.float32)
    try:
        deletion_idx = models.index(deletion_model)
    except ValueError:
        print(f"Error: {deletion_model} not found.")
        return []

    # NOTE(review): -1 presumably tells the IRT class to exclude the entry
    # from training — confirm against the model implementation.
    train_mask = np.ones_like(full_response)
    train_mask[deletion_idx, :] = -1

    def get_base(p):
        """Strip the first matching marker in ``prefixes`` from a column name."""
        for pre in prefixes:
            if p.startswith(pre):
                return p[len(pre):]
        return p

    if MODE == "group":
        unit_of = get_base
        problem_groups = {}
        for p in problems:
            problem_groups.setdefault(get_base(p), []).append(p)
    else:  # MODE == "question"
        def unit_of(p):
            return p
        problem_groups = {p: [p] for p in problems}

    # Column positions of every unit, so the mask can be updated by index
    # instead of re-scanning all column names on every step.
    unit_columns = {}
    for idx, p in enumerate(problems):
        unit_columns.setdefault(unit_of(p), []).append(idx)

    all_units = list(problem_groups.keys())
    n_total = len(all_units)
    steps = [max(1, int(np.floor(n_total * p))) for p in milestone_percents]
    milestone_steps = set(steps)
    # default=0 makes an empty milestone list a no-op rather than a crash.
    last_step = max(milestone_steps, default=0)

    model_obj = Standard3PLIRT(
        response_data=full_response,
        student_names=models,
        test_names=problems,
        train_mask=train_mask,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=1e-2,
        batch_size=512,
        max_epochs=5000,
        device="cpu",
        eps=1e-3,
        embedding_dim=1024,
        theta_max=1.0,
        difficulty_base_max=3.0,
        difficulty_other_max=1.0,
        theta_init=0.2,
        use_guessing=use_guessing,
    )
    print(f"[{deletion_model}] IRT learning start")
    model_obj.fit()

    # The ground-truth side of every metric does not change between steps.
    overall_gt = gt_df.mean(axis=1)
    r_true = overall_gt.rank(ascending=False, method='min')

    def jaccard_at_k(pred_rank, true_rank, k):
        """Jaccard overlap of the k best-ranked labels of both rankings."""
        top_pred = set(pred_rank.nsmallest(k).index)
        top_true = set(true_rank.nsmallest(k).index)
        return len(top_pred & top_true) / len(top_pred | top_true)

    def top_k_accuracy(pred_rank, true_rank, k):
        """Share of the predicted top-k labels that are also in the true top-k."""
        top_pred = set(pred_rank.nsmallest(k).index)
        top_true = set(true_rank.nsmallest(k).index)
        return len(top_pred & top_true) / len(top_pred)

    # Insertion-ordered "set" of remaining units: ties in the Fisher score
    # then resolve to the earliest unit in column order, independent of the
    # interpreter's string-hash seed.
    candidate = dict.fromkeys(all_units)
    administered = set()
    administered_items = []
    administered_cols = []
    records = []
    step = 0

    while candidate and step < last_step:
        fisher = model_obj.compute_item_fisher_information(deletion_model)
        unit_scores = {
            unit: sum(fisher.get(item, 0.0) for item in problem_groups[unit])
            for unit in candidate
        }

        chosen = max(unit_scores, key=unit_scores.get)
        del candidate[chosen]
        administered.add(chosen)
        administered_items.extend(problem_groups[chosen])
        administered_cols.extend(unit_columns[chosen])
        step += 1

        # Set the held-out model's entries for all administered columns to 1.
        train_mask[deletion_idx, administered_cols] = 1

        model_obj.update_single_theta(
            model_name=deletion_model,
            problem_names=administered_items,
            lr=1e-4,
            max_epochs=1000,
            patience=50,
        )

        if step not in milestone_steps:
            continue

        # Optionally drop the "<no_...>" variants from the evaluation set.
        eval_items = [
            p for p in administered_items
            if not (filter_no_prefix and p.startswith("<no_"))
        ]

        subset_means = merged_df[eval_items].mean(axis=1)
        r_pred = subset_means.rank(ascending=False, method='min')
        corr, _ = spearmanr(r_pred, r_true)

        # Absolute gap between the subset score and the ground-truth score.
        diff = subset_means.loc[deletion_model] - overall_gt.loc[deletion_model]
        acc_error = abs(diff)

        shuffle_ratio = sum(is_contaminated(p) for p in eval_items) / max(1, len(eval_items))

        footrule_dist = (r_pred - r_true).abs().sum()

        jaccard5 = jaccard_at_k(r_pred, r_true, 5)
        jaccard10 = jaccard_at_k(r_pred, r_true, 10)

        # Computed once; both the statistic and the p-value are recorded.
        kendall_corr, kendall_pval = kendalltau(r_pred, r_true)

        records.append({
            "deletion_model":          deletion_model,
            "milestone_percent":       step / n_total,
            "step":                    step,
            "n_administered_groups":   len(administered) if MODE == "group" else 0,
            "n_administered_questions": len(eval_items),
            "corr_mixed_gt":           corr,
            "acc_error":               acc_error,
            "shuffle_ratio":           shuffle_ratio,
            "footrule_dist":           footrule_dist,
            "jaccard5":                jaccard5,
            "jaccard10":               jaccard10,
            "top1_acc":                top_k_accuracy(r_pred, r_true, 1),
            "top5_acc":                top_k_accuracy(r_pred, r_true, 5),
            "top10_acc":               top_k_accuracy(r_pred, r_true, 10),
            "kendall_corr":            kendall_corr,
            "kendall_pval":            kendall_pval,
        })

        print(
            f"[{deletion_model}] Milestone {step}: "
            f"corr={corr:.4f}, error={acc_error:.4f}, "
            f"shuffle_ratio={shuffle_ratio:.4f}, "
            f"footrule={footrule_dist:.4f}, "
            f"j5={jaccard5:.4f}, j10={jaccard10:.4f}"
        )

    print(f"[{deletion_model}] CAT ended : {step}")
    return records

if __name__ == '__main__':
    # Launch one remote CAT simulation per model, with milestones at
    # 1%, 2%, ..., 50% of the available units.
    ray.init(ignore_reinit_error=True)
    milestone_percents = [pct / 100 for pct in range(1, 51)]
    tasks = []
    for held_out in models:
        tasks.append(
            run_deletion_cat_experiment_milestones.remote(held_out, milestone_percents)
        )
    all_results = ray.get(tasks)
    ray.shutdown()

    # Concatenate the per-model record lists into one table.
    flat = []
    for per_model_records in all_results:
        flat.extend(per_model_records)
    df_res = pd.DataFrame(flat)

    out_path = f"result/CAT_reduce_model_shuffle_{N}_{dataset_name}_{irt_name}_{mmdd}_{hhmm}{filter_suffix}.csv"
    out_dir = os.path.dirname(out_path)
    os.makedirs(out_dir, exist_ok=True)
    df_res.to_csv(out_path, index=False)
    print(f"results are saved to '{out_path}'")

    # Record the run configuration and attach the CSV as an MLflow artifact.
    exp_name = f"CAT_reduce_model_shuffle{N}_{dataset_name}_{mmdd}_{hhmm}_{pl}pl{filter_suffix}"
    run_params = {
        "dataset_name": dataset_name,
        "pl": pl,
        "use_guessing": use_guessing,
        "N": N,
        "mode": MODE,
        "seed": args.seed,
        "split_difficulty": split_difficulty,
        "split_ability": split_ability,
        "filter_no_prefix": filter_no_prefix,
    }
    mlflow.set_experiment(exp_name)
    with mlflow.start_run(run_name=exp_name):
        mlflow.log_params(run_params)
        mlflow.log_artifact(out_path)
