import os
import datetime
import argparse
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from scipy.stats import kendalltau
import ray
import random
from mmirt.utils.contami_judge import is_contaminated

def parse_args():
    """Build the command-line interface for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with:
          * ``N`` (int, default 900): number of question groups to sample.
          * ``dataset`` (str, default "mmmu"): one of mathvista/vqa/mmmu/seedbench.
          * ``filter_no_prefix`` (bool flag): when set, drop ``<no_*>`` columns.
    """
    cli = argparse.ArgumentParser(description="Random experiment")

    cli.add_argument("--N", default=900, type=int, help="groups")
    cli.add_argument(
        "--dataset",
        default="mmmu",
        choices=["mathvista", "vqa", "mmmu", "seedbench"],
        help="mathvista or vqa or mmmu or seedbench",
    )
    cli.add_argument("--filter_no_prefix", help="`<no_*>`", action="store_true")

    parsed = cli.parse_args()
    return parsed

# Parse the CLI once at import time; later code reads these module globals.
args = parse_args()
N = args.N
dataset_name = args.dataset

# (normal CSV, shuffled CSV) per dataset, relative to the working directory.
# Unknown names fall back to the mmmu pair, as the original default did
# (argparse `choices` already restricts the values).
_paths_by_dataset = {
    "mmmu": ("mmmu_normal_shuffle/df_cleaned_renamed_with_metrics_index.csv",
             "mmmu_normal_shuffle/shufffle_mmmu.csv"),
    "mathvista": ("merged_mathvista/mathvista_combined.csv",
                  "merged_mathvista/merged_mathvista_shuffle.csv"),
    "vqa": ("vqa_shuffle_normal/vqa_merged.csv",
            "vqa_shuffle_normal/vqa_shuffle_filtered.csv"),
    "seedbench": ("seed_normal_shuffled/seed_normal_all_clean.csv",
                  "seed_normal_shuffled/seed_shuffled.csv"),
}
file1_path, file2_path = _paths_by_dataset.get(dataset_name, _paths_by_dataset["mmmu"])

# First CSV column becomes the row index (used downstream as the model id).
df1 = pd.read_csv(file1_path, index_col=0)
df2 = pd.read_csv(file2_path, index_col=0)

# Column-name markers for variant columns (they appear to denote inputs with
# image / question / info removed -- NOTE(review): confirm with data owners).
prefixes = ["<no_image>", "<no_question>", "<no_info>"]
if args.filter_no_prefix:
    def has_no_prefix(col):
        """Return True when *col* begins with any marker in ``prefixes``."""
        return col.startswith(tuple(prefixes))

    # Keep only unmarked columns in both tables.
    df1 = df1.loc[:, [name for name in df1.columns if not has_no_prefix(name)]]
    df2 = df2.loc[:, [name for name in df2.columns if not has_no_prefix(name)]]

def extract_group(col_name):
    """Return *col_name* without its leading ``<no_*>`` marker.

    The first entry of the module-level ``prefixes`` list that *col_name*
    starts with is stripped. Names carrying no marker come back unchanged.
    The result is used as the group key that ties a column to its variants.
    """
    matched = next((marker for marker in prefixes if col_name.startswith(marker)), None)
    return col_name if matched is None else col_name[len(matched):]

# Bucket df2's columns by group key: "<no_image>q1", "<no_info>q1" and "q1"
# all land under "q1". Insertion order follows df2's column order.
groups = {}
for column_name in df2.columns:
    key = extract_group(column_name)
    if key not in groups:
        groups[key] = []
    groups[key].append(column_name)

all_groups = list(groups.keys())
if N > len(all_groups):
    raise ValueError(f"N={N} > {len(all_groups)}")

# Draw N groups without replacement and keep every column of each drawn group.
# NOTE(review): the global `random` state is not seeded before this draw, so
# the chosen groups differ from run to run.
selected_groups = random.sample(all_groups, N)
selected_columns = [c for g in selected_groups for c in groups[g]]

# Inner join on the row index: only rows present in both tables survive.
df2_subset = df2[selected_columns]
merged_df = pd.merge(df1, df2_subset, left_index=True, right_index=True)
gt_df = df1.copy()

# Ground-truth ranking: mean over df1's columns per row, highest first;
# tied scores share the smallest rank (method='min').
gt_mean_scores = gt_df.mean(axis=1)
ranked_models = gt_mean_scores.sort_values(ascending=False)
ranked_df = ranked_models.reset_index()
ranked_df.columns = ['model_id', 'score']
ranked_df['rank'] = (
    ranked_df['score'].rank(ascending=False, method='min').astype(int)
)
ranked_df = ranked_df.sort_values('rank')

def get_base_name(problem_name):
    """Map a column name to its base question name.

    Strips the first matching marker from the module-level ``prefixes`` list
    (same rule as ``extract_group``). Unmarked names are returned as-is.
    """
    for marker in prefixes:
        if problem_name.startswith(marker):
            base = problem_name[len(marker):]
            break
    else:
        base = problem_name
    return base

# Sample-size sweep: integer percentages from x up to (not including) y in
# steps of r, expressed as fractions -> 0.01, 0.02, ..., 0.94.
x, y, r = 1, 95, 1
train_percentages = np.arange(x, y, r, dtype=float) / 100

# ignore_reinit_error=True: a repeated ray.init() in this process is not an error.
ray.init(ignore_reinit_error=True)

@ray.remote(num_cpus=1)
def run_random_experiment(deletion_model, train_percentage, seed):
    """Evaluate one random subset of question groups against the ground truth.

    Columns of the module-level ``merged_df`` are grouped by
    ``get_base_name``. A fraction ``train_percentage`` of those groups is
    drawn after ``random.seed(seed)``, and every column of a drawn group is
    kept. Rows are ranked by their mean over the kept columns (highest first,
    ties -> smallest rank). That ranking is compared with ``ranked_df``, the
    ranking by mean over ``gt_df``.

    Args:
        deletion_model: row label in ``merged_df`` / ``gt_df``. Its subset
            mean is compared with its full ``gt_df`` mean (``acc_error``).
        train_percentage: fraction of question groups to draw.
        seed: seed for Python's ``random`` module inside this worker.

    Returns:
        dict with the run parameters plus Spearman / Kendall correlations,
        footrule distance, Jaccard@{5,10}, top-{1,5,10} overlap,
        ``acc_error`` and the fraction of drawn columns for which
        ``is_contaminated`` is truthy (``shuffle_ratio``).
    """
    # Unique base names in sorted order. This is the same key order that
    # ``groupby(get_base_name, axis=1)`` produced (sort=True), so a seed
    # draws the same groups as before. It avoids the deprecated axis=1
    # groupby and the per-group means that were never used.
    items = sorted({get_base_name(col) for col in merged_df.columns})

    # Epsilon guards float truncation, e.g. 100 * 0.29 == 28.999999999999996.
    # Draw at least one group so the metrics below never see an empty subset
    # (which crashed on NaN ranks / division by zero).
    final_count = int(len(items) * train_percentage + 1e-9)
    final_count = min(max(final_count, 1), len(items))
    random.seed(seed)
    random_extracted = set(random.sample(items, final_count))

    extracted_columns = [col for col in merged_df.columns
                         if get_base_name(col) in random_extracted]

    # Ranking induced by the sampled subset.
    random_acc_scores = merged_df[extracted_columns].mean(axis=1)
    random_rank_df = random_acc_scores.sort_values(ascending=False).reset_index()
    random_rank_df.columns = ['model_id', 'random_score']
    random_rank_df['random_rank'] = random_rank_df['random_score'] \
        .rank(ascending=False, method='min').astype(int)

    merged_rank = pd.merge(
        random_rank_df[['model_id', 'random_rank']],
        ranked_df[['model_id', 'rank']],
        on='model_id',
    )
    corr, p_val = spearmanr(merged_rank['random_rank'], merged_rank['rank'])
    kendall_corr, kendall_pval = kendalltau(merged_rank['random_rank'], merged_rank['rank'])

    # sqrt of one squared difference == absolute error of this model's mean.
    rmse = np.sqrt(
        (merged_df.loc[deletion_model, extracted_columns].mean()
         - gt_df.loc[deletion_model].mean()) ** 2
    )
    shuffle_ratio = sum(is_contaminated(p) for p in extracted_columns) / len(extracted_columns)

    # Spearman's footrule: total absolute rank displacement.
    footrule_dist = (merged_rank['random_rank'] - merged_rank['rank']).abs().sum()

    def jaccard_at_k(df, k):
        """Jaccard similarity of the predicted and true top-k model sets."""
        top_pred = set(df.nsmallest(k, 'random_rank')['model_id'])
        top_true = set(df.nsmallest(k, 'rank')['model_id'])
        return len(top_pred & top_true) / len(top_pred | top_true)

    def top_k_accuracy(df, k):
        """Share of the predicted top-k that is also in the true top-k."""
        top_pred = set(df.nsmallest(k, 'random_rank')['model_id'])
        top_true = set(df.nsmallest(k, 'rank')['model_id'])
        return len(top_pred & top_true) / len(top_pred)

    jaccard5 = jaccard_at_k(merged_rank, 5)
    jaccard10 = jaccard_at_k(merged_rank, 10)

    return {
        "seed":                     seed,
        "deletion_model":           deletion_model,
        "milestone_percent":        train_percentage,
        "n_administered_questions": len(extracted_columns),
        "corr_mixed_gt":            corr,
        "p_val":                    p_val,
        "acc_error":                rmse,
        "shuffle_ratio":            shuffle_ratio,
        "footrule_dist":            footrule_dist,
        "jaccard5":                 jaccard5,
        "jaccard10":                jaccard10,
        "top1_acc":                 top_k_accuracy(merged_rank, 1),
        "top5_acc":                 top_k_accuracy(merged_rank, 5),
        "top10_acc":                top_k_accuracy(merged_rank, 10),
        "kendall_corr":             kendall_corr,
        "kendall_pval":             kendall_pval,
    }


# One Ray task per (seed, model, train percentage). Seeds are 0..10 and
# models are merged_df's row labels. Nesting order matches the results rows.
tasks = [
    run_random_experiment.remote(model, tp, seed)
    for seed in range(11)
    for model in merged_df.index
    for tp in train_percentages
]

results = ray.get(tasks)
print("Experiment finished")
results_df = pd.DataFrame(results)
suffix = datetime.datetime.now().strftime('%m%d_%H%M')
out_dir = "result"
os.makedirs(out_dir, exist_ok=True)
# Reusing the same quote type inside an f-string is only legal on
# Python >= 3.12 (PEP 701). Compute the tag first so older interpreters parse.
subset_tag = "filtered" if args.filter_no_prefix else "all"
fname = f"{out_dir}/random_{N}_{dataset_name}_{subset_tag}_{suffix}.csv"
results_df.to_csv(fname, index=False)
print(f"Results are finished and saved to {fname}")
