import json
import os.path
import pandas as pd
import numpy as np

# Static metadata for every model that may be ranked.
#
# Keys are shortened model names, i.e. the result-column names after
# import_metabench() strips the "MetaBench_" and "-5-shot-CoT" substrings;
# import_metabench() keeps only the columns whose shortened name is a key here.
# Each value holds:
#   "size":        parameter-count bucket, one of "S", "M", "L" (see below)
#   "open_source": boolean flag stored for the model
#
# NOTE(review): for entries whose name carries no parameter count (for example
# the gpt4o / gemini entries) the bucket cannot be checked from this file.
model_info = {
    # Size buckets by parameter count: S = under 10B, M = 10B to 50B, L = over 50B
    "Deepseek_Chat_V3": {
        "size": "L",
        "open_source": True
    },
    "gemini_1.5_pro_002": {
        "size": "L",
        "open_source": False
    },
    "gpt4o": {
        "size": "L",
        "open_source": False
    },
    "gpt4o_mini": {
        "size": "L",
        "open_source": False
    },
    "gemini_1.5_flash_002": {
        "size": "L",
        "open_source": False
    },
    "gemma-2-27b": {
        "size": "M",      # 27B => 10B~50B
        "open_source": True
    },
    "Yi-1.5-34B": {
        "size": "M",      # 34B => 10B~50B
        "open_source": True
    },
    "Qwen2.5-7B-Instruct": {
        "size": "S",      # 7B => <10B
        "open_source": True
    },
    "Phi-3-medium-128k": {
        "size": "M",      
        "open_source": True
    },
    "Phi-3-small-128k": {
        "size": "S",      
        "open_source": True
    },
    "Qwen2-7B": {
        "size": "S",      # 7B => <10B
        "open_source": True
    },
    "DeepSeek-Coder-V2-Lite-Instruct": {
        "size": "M",
        "open_source": True
    },
    "Phi-3.5-Mini": {
        "size": "S",      
        "open_source": True
    },
    "Llama-3.1-8B": {
        "size": "S",      # 8B => <10B
        "open_source": True
    },
    "Qwen2.5-3B-Instruct": {
        "size": "S",      # 3B => <10B
        "open_source": True
    },
    "Qwen2-Math-7B": {
        "size": "S",
        "open_source": True
    },
    "Qwen2-Math-1.5B": {
        "size": "S",
        "open_source": True
    },
    "Yi-1.5-9B": {
        "size": "S",      # 9B => <10B
        "open_source": True
    },
    "Qwen2.5-1.5B-Instruct": {
        "size": "S",
        "open_source": True
    },
    "gemma-2-9b": {
        "size": "S",      # 9B => <10B
        "open_source": True
    },
    "Yi-1.5-6B": {
        "size": "S",      # 6B => <10B
        "open_source": True
    },
    "Mistral-7B-Instruct-v0.3": {
        "size": "S",      # 7B => <10B
        "open_source": True
    },
    "DeepSeek-V2-Lite-Chat": {
        "size": "M",
        "open_source": True
    },
    "Qwen2-0.5B": {
        "size": "S",      # 0.5B => <10B
        "open_source": True
    },
    "Qwen2-1.5B": {
        "size": "S",      # 1.5B => <10B
        "open_source": True
    },
    "Qwen2.5-0.5B-Instruct": {
        "size": "S",
        "open_source": True
    },
    "Llama-3.2-3B-Instruct": {
        "size": "S",      # 3B => <10B
        "open_source": True
    },
    "gemma-2-2b": {
        "size": "S",      # 2B => <10B
        "open_source": True
    },
    "Llama-3.2-1B-Instruct": {
        "size": "S",      # 1B => <10B
        "open_source": True
    },
    'Qwen2.5-Math-1.5B-Instruct': {
        "size": "S",
        "open_source": True
    },
}

def import_metabench():
    """Load the MetaBench result table and return it as a NumPy matrix.

    Reads the JSON-lines file ``MAIN_0125.json`` (path relative to the current
    working directory). Every column other than the three metadata columns is
    treated as one model's per-question result: the string ``"null"`` and
    missing values are turned into 0, and the column is renamed by dropping
    the ``MetaBench_`` and ``-5-shot-CoT`` substrings. Only columns whose
    shortened name is a key of ``model_info`` are kept.

    Returns:
        A ``(data, names)`` tuple: ``data`` is the NumPy array built from the
        kept columns (one row per record in the file, one column per kept
        model) and ``names`` is the list of shortened model names in column
        order.
    """
    source_file = "MAIN_0125.json"
    meta_fields = ["metabench_id", "original_benchmark", "original_id"]
    excluded = []  # shortened names listed here would be dropped; empty today

    table = pd.read_json(source_file, lines=True)
    raw_names = [col for col in table.columns if col not in meta_fields]

    # "null" strings and genuinely missing values are both scored as 0.
    numeric = table[raw_names].replace("null", np.nan).astype(float)
    table[raw_names] = numeric.fillna(0)

    # Shorten the result-column names so they can be looked up in model_info.
    renames = {}
    for col in raw_names:
        renames[col] = col.replace("MetaBench_", "").replace("-5-shot-CoT", "")
    table = table.rename(columns=renames)

    kept = []
    for col in table.columns:
        if col in meta_fields:
            continue
        if col in model_info and col not in excluded:
            kept.append(col)

    return table[kept].to_numpy(), kept

def compute_stationary_distributions_bipartite(
        A_QM,  # shape=(Q, M).  A_QM[q,m] = 1 if question q is correct by model m
        damping_factor=0.85,
        max_iter=1000,
        tol=1e-9
):
    """Score questions and models with a damped walk on the question/model graph.

    Two row-stochastic transition matrices are built from ``A_QM``:

    * model -> question: each model spreads its weight evenly over the
      questions it answered incorrectly (``1 - A_QM``, transposed);
    * question -> model: each question spreads its weight evenly over the
      models that answered it correctly (``A_QM``).

    Starting from uniform vectors, the two weight vectors are updated
    alternately, each step mixing the propagated weight (factor
    ``damping_factor``) with a uniform term (factor ``1 - damping_factor``),
    until both change by less than ``tol`` in L1 norm or ``max_iter`` updates
    have run. A one-line status message is printed in either case.

    Args:
        A_QM: 2-D array-like of shape (Q, M); entry ``[q, m]`` is 1 if model
            ``m`` answered question ``q`` correctly and 0 otherwise. Anything
            ``np.asarray(..., dtype=float)`` accepts (nested lists, integer or
            boolean arrays) is allowed.
        damping_factor: weight of the propagated term in each update.
        max_iter: maximum number of update rounds.
        tol: L1 convergence threshold applied to both vectors.

    Returns:
        ``(pi_Q, pi_M)``: 1-D arrays of length Q and M, each rescaled so that
        its largest entry equals 100.

    Raises:
        ValueError: if ``A_QM`` is not a non-empty 2-D array, if some question
            was answered correctly by no model or by every model, or if some
            model answered every question correctly.
    """
    # Coerce once so lists and bool/int arrays behave like float arrays.
    A_QM = np.asarray(A_QM, dtype=float)
    if A_QM.ndim != 2:
        raise ValueError("A_QM must be a 2-D (questions x models) array.")
    Q, M = A_QM.shape
    if Q == 0 or M == 0:
        raise ValueError("A_QM must contain at least one question and one model.")

    # Question -> model transitions: normalise each question row of A_QM.
    correct_per_question = A_QM.sum(axis=1, keepdims=True)  # shape=(Q,1)
    if (correct_per_question == 0).any():
        raise ValueError("Some question was never correctly answered.")
    if (correct_per_question == M).any():
        raise ValueError("Some question was correctly answered by every model.")
    PQM = A_QM / correct_per_question

    # Model -> question transitions: normalise each model row of the failures.
    Fail_MQ = (1.0 - A_QM).T
    fails_per_model = Fail_MQ.sum(axis=1, keepdims=True)  # shape=(M,1)
    if (fails_per_model == 0).any():
        raise ValueError("Some model never answered incorrectly.")
    PMQ = Fail_MQ / fails_per_model  # row-based => sum_{q} PMQ[m,q]=1

    alpha = damping_factor
    # The uniform terms do not change between rounds, so compute them once.
    uniform_Q = (1 - alpha) / Q
    uniform_M = (1 - alpha) / M

    pi_Q = np.full(Q, 1.0 / Q)
    pi_M = np.full(M, 1.0 / M)

    for it in range(max_iter):
        pi_Q_new = alpha * (PMQ.T @ pi_M) + uniform_Q
        # The model update deliberately uses the freshly updated question vector.
        pi_M_new = alpha * (PQM.T @ pi_Q_new) + uniform_M

        delta_Q = np.linalg.norm(pi_Q_new - pi_Q, 1)
        delta_M = np.linalg.norm(pi_M_new - pi_M, 1)
        pi_Q, pi_M = pi_Q_new, pi_M_new
        if delta_Q < tol and delta_M < tol:
            print(f"Converged after {it + 1} iterations.")
            break
    else:
        print("Warning: max_iter reached without full convergence.")

    # Rescale so the top-ranked question/model scores exactly 100.
    pi_Q_normalized = (pi_Q / pi_Q.max()) * 100
    pi_M_normalized = (pi_M / pi_M.max()) * 100
    return pi_Q_normalized, pi_M_normalized



def filter_questions(data, assign_scores=True):
    """Split off the questions that every model got right or every model got wrong.

    Args:
        data: 2-D array with one row per question and one column per model;
            a row counts as "all correct" when every entry equals 1 and as
            "all incorrect" when every entry equals 0.
        assign_scores: when true, also return placeholder score arrays for the
            two removed groups.

    Returns:
        A 5-tuple ``(all_correct_questions, all_incorrect_questions,
        filtered_data, correct_scores, incorrect_scores)``:

        * the first two are arrays of row indices into ``data``;
        * ``filtered_data`` holds the remaining rows, in their original order;
        * ``correct_scores`` is an array of zeros (one per all-correct row)
          and ``incorrect_scores`` an array of NaN (one per all-incorrect
          row), or both are ``None`` when ``assign_scores`` is false.
    """
    all_correct_mask = np.all(data == 1, axis=1)
    all_incorrect_mask = np.all(data == 0, axis=1)
    all_correct_questions = np.where(all_correct_mask)[0]
    all_incorrect_questions = np.where(all_incorrect_mask)[0]

    # Keep every row that is in neither group; boolean indexing preserves order.
    filtered_data = data[~(all_correct_mask | all_incorrect_mask)]

    correct_scores = None
    incorrect_scores = None
    if assign_scores:
        correct_scores = np.zeros_like(all_correct_questions, dtype=float)
        # NaN marks "no score assigned yet" for the all-incorrect questions.
        incorrect_scores = np.full_like(all_incorrect_questions, np.nan, dtype=float)
    return all_correct_questions, all_incorrect_questions, filtered_data, correct_scores, incorrect_scores


def store_pi_json(data, model_name):
    """Rank questions and models and write both rankings to JSON files.

    The questions that every model answered identically are removed with
    ``filter_questions``, the remaining rows are scored with
    ``compute_stationary_distributions_bipartite``, and the removed questions
    are given fixed scores outside the 0..100 range of the computed ones:
    -20 for all-correct questions and 120 for all-incorrect questions.

    Side effects:
        * prints one line per question and per model;
        * writes ``RankLLM_question_difficulty.json`` (questions, hardest
          first) and ``RankLLM_Model_competency.json`` (models, best first)
          into the current working directory.

    Args:
        data: 2-D question-by-model result array (rows are questions).
        model_name: sequence of model names, one per column of ``data``.
    """
    all_correct_questions, all_incorrect_questions, filtered_data, _, _ = filter_questions(
        data)
    pi_Q, pi_M = compute_stationary_distributions_bipartite(filtered_data)

    max_score = 120   # fixed score for questions no model answered correctly
    min_score = -20   # fixed score for questions every model answered correctly

    # pi_Q has one entry per *remaining* question, whereas the all-correct /
    # all-incorrect indices refer to rows of `data`. Scatter pi_Q back into a
    # full-length array so every score sits at its original row index.
    num_questions = data.shape[0]
    was_scored = np.ones(num_questions, dtype=bool)
    was_scored[all_correct_questions] = False
    was_scored[all_incorrect_questions] = False

    extended_scores = np.empty(num_questions, dtype=float)
    extended_scores[was_scored] = pi_Q  # remaining rows keep ascending order
    extended_scores[all_correct_questions] = min_score
    extended_scores[all_incorrect_questions] = max_score

    # Stable sort keeps tied questions in row order, so output is reproducible.
    question_difficulty = []
    for question_id in np.argsort(-extended_scores, kind="stable"):
        score = extended_scores[question_id]
        print(f"Question ID: {question_id}, Difficulty: {score:.2f}")
        question_difficulty.append({
            # NOTE(review): this is the row position in `data`; it equals the
            # dataset's metabench_id only if rows are stored in id order --
            # confirm against the loader.
            "metabench_id": int(question_id),  # Convert to native Python int
            "difficulty": f"{score:.6f}",
        })

    question_output_path = "RankLLM_question_difficulty.json"
    with open(question_output_path, "w", encoding="utf-8") as f:
        json.dump(question_difficulty, f, indent=2)

    model_performance = []
    for model_id, performance in sorted(zip(model_name, pi_M), key=lambda x: x[1], reverse=True):
        # Strip benchmark-specific prefix/suffixes in case full names are passed.
        model_id = model_id.split("-5-shot-CoT")[0].split("MetaBench_")[-1].split(".processed")[0]
        print(f"Model ID: {model_id}, Performance: {performance:.2f}")
        model_performance.append({
            "model_id": model_id,
            "performance": f"{performance:.4f}"
        })

    model_output_path = "RankLLM_Model_competency.json"
    with open(model_output_path, "w", encoding="utf-8") as f:
        json.dump(model_performance, f, indent=2)





if __name__ == '__main__':
    # Load the result matrix and write both rankings. store_pi_json() filters
    # the questions and runs the ranking itself, so doing either here as well
    # would only repeat the work and discard the result.
    data, model_name = import_metabench()
    store_pi_json(data, model_name)
