import pandas as pd

# Accuracy table: per-model accuracy (in percent) for each question category.
# Every category list is positional: its i-th value belongs to the i-th entry
# of 'model', so all lists must be exactly as long as 'model' (pd.DataFrame
# rejects columns of unequal length).
acc_data = {
    'model': [
        'gpt4o', 'Qwen2.5-7B-Instruct', 'gemma-3-4b-it', 'Kimi-K2-Instruct', 'Llama-3.1-8B-Instruct',
        'GPT5', 'gpt-3.5-turbo-1106', 'claude-3.7-sonnet-thinking', 'o3mini', 'deepseek-r1',
        'qwen2.5-72b-instruct', 'gemini-2.5-pro', 'gemma-3-27b-it', 'Qwen3-32B', 'claude-3.5-sonnet-20241022',
        'QwQ-32B', 'llama3.3-70B-instruct', 'doubao-1-5-thinking-pro-250415', 'gemini-2.5-flash'
    ],
    'coding': [51.85, 57.41, 49.07, 62.04, 48.15, 62.96, 49.07, 66.67, 71.30, 94.44, 69.44, 73.15, 75.00, 70.37, 73.14,
               70.37, 57.41, 59.26, 75.00],
    'knowledge': [52.80, 56.07, 54.67, 61.68, 57.48, 68.69, 50.47, 72.43, 67.76, 69.63, 69.63, 67.44, 68.22, 71.03,
                  73.83, 70.56, 65.42, 65.89, 72.90],
    'math': [52.10, 49.58, 49.58, 47.90, 46.22, 61.34, 58.82, 64.71, 56.30, 59.66, 70.59, 80.00, 64.71, 73.95, 67.23,
             70.59, 55.46, 46.22, 75.63],
    'reasoning': [50.93, 60.19, 58.33, 57.41, 59.26, 73.15, 49.07, 64.81, 75.00, 67.74, 67.59, 67.74, 75.00, 72.22,
                  70.37, 68.52, 63.89, 58.33, 74.07],
    'roleplay': [48.00, 52.00, 56.00, 60.00, 52.00, 64.00, 48.00, 66.00, 58.00, 63.33, 61.33, 63.33, 64.00, 65.33,
                 71.33, 66.67, 60.67, 64.00, 67.33],
    'writing': [59.33, 64.00, 64.00, 64.00, 65.33, 66.00, 68.00, 68.67, 72.00, 72.67, 73.33, 73.33, 73.33, 73.33, 73.33,
                74.00, 74.67, 76.67, 78.00]
}
# Index rows by model name so this table can be aligned with the cost and time
# tables by label; their model lists are in a different order than this one.
acc_df = pd.DataFrame(acc_data).set_index('model')

# Cost table: per-model cost for each question category. print_model_details
# labels this value "tokens" -- presumably tokens consumed per question; TODO confirm.
# Each category list is positional against THIS dict's 'model' list, whose
# order differs from acc_data's.
cost_data = {
    'model': [
        'claude-3.5-sonnet-20241022', 'gemma-3-4b-it', 'gemma-3-27b-it', 'gpt-3.5-turbo-1106', 'gpt4o',
        'Llama-3.1-8B-Instruct', 'llama3.3-70B-instruct', 'Qwen2.5-7B-Instruct', 'qwen2.5-72b-instruct',
        'Qwen3-32B', 'claude-3.7-sonnet-thinking', 'deepseek-r1', 'doubao-1-5-thinking-pro-250415',
        'gemini-2.5-flash', 'gemini-2.5-pro', 'GPT5', 'Kimi-K2-Instruct', 'o3mini', 'QwQ-32B'
    ],
    'coding': [2520.45, 2271.06, 2178.64, 1924.02, 2236.95, 2018.6, 2235.45, 1964.84, 2231.88, 4136.56, 2586.09,
               4147.55, 3994.2, 6515.9, 6152.75, 4743.45, 2170.54, 2882.16, 3876.1],
    'knowledge': [1869.08, 1497.63, 1459.24, 1380.85, 1602.74, 1471.86, 1627.34, 1447.01, 1574.63, 2198.83, 1989.56,
                  2165.88, 2172.2, 2252.4, 3733.51, 2825.25, 1482.87, 2270.49, 2241.9],
    'math': [1899.55, 1446.65, 1456.71, 1304.49, 1567.42, 1429.1, 1615.29, 1425.35, 1576.98, 2345.76, 2010.46, 1921.35,
             1875.6, 3023, 4259.7, 2653.29, 1476.49, 2091.78, 3276.8],
    'reasoning': [2059.75, 1660.85, 1634.71, 1526.75, 1741.54, 1666.42, 1829.9, 1669.84, 1743.16, 2135.05, 2135.94,
                  2591.9, 2740, 9785.3, 5950.26, 5078.11, 1742.73, 3027.99, 4091.9],
    'roleplay': [1667.7, 1324.47, 1289.43, 1162.74, 1450.11, 1349.83, 1480.17, 1157.02, 1459.97, 1611.97, 1694.29,
                 1455.3, 1699.6, 2276.9, 2291.8, 2429.46, 1381.48, 1864.23, 1794.1],
    'writing': [2129.83, 1780.95, 1696.92, 1549.57, 1846.72, 1734.13, 1892.77, 1588.21, 1896.48, 2071.97, 2083.89,
                1786.8, 2334, 3105.6, 2781.87, 3121.2, 1781.11, 2299.67, 2490.8]
}
# Index by model name so rows align with the other tables by label, not position.
cost_df = pd.DataFrame(cost_data).set_index('model')

# Time table: per-model response time for each question category.
# print_model_details labels this value in seconds -- presumably an average per
# question; TODO confirm. Each category list is positional against THIS dict's
# 'model' list (same model order as cost_data, different from acc_data).
time_data = {
    'model': [
        'claude-3.5-sonnet-20241022', 'gemma-3-4b-it', 'gemma-3-27b-it', 'gpt-3.5-turbo-1106', 'gpt4o',
        'Llama-3.1-8B-Instruct', 'llama3.3-70B-instruct', 'Qwen2.5-7B-Instruct', 'qwen2.5-72b-instruct',
        'Qwen3-32B', 'claude-3.7-sonnet-thinking', 'deepseek-r1', 'doubao-1-5-thinking-pro-250415',
        'gemini-2.5-flash', 'gemini-2.5-pro', 'GPT5', 'Kimi-K2-Instruct', 'o3mini', 'QwQ-32B'
    ],
    'coding': [14.02, 3.89, 8.25, 2.24, 7.56, 4.86, 17.54, 3.75, 20.39, 13.83, 18.21, 37.41, 58.19, 20.39, 40.57, 44.83,
               9.58, 8.91, 45.4],
    'knowledge': [13.94, 2.92, 6.43, 2.62, 6.11, 6.3, 14.3, 4.82, 18.58, 25.86, 18.37, 14.65, 24.82, 6.18, 24.13, 23.16,
                  6.22, 8.93, 33.7],
    'math': [13.58, 3.06, 5.6, 2.46, 6.48, 5.67, 14.56, 3.83, 13.59, 28.28, 15.28, 14.22, 18.89, 8.87, 25.47, 19.49, 6,
             7.85, 52.69],
    'reasoning': [11.42, 3.44, 6.99, 2.09, 8.93, 7.59, 15.88, 5.3, 13.5, 17.46, 17.29, 32.09, 29.45, 34.98, 40.03,
                  77.15, 7.19, 12.29, 58.32],
    'roleplay': [13.19, 4.13, 8.28, 2.29, 7.91, 6.61, 19.27, 3.42, 17.24, 19.75, 16.49, 17.6, 25.15, 8.18, 14.61, 35,
                 7.73, 7.23, 22.7],
    'writing': [12.93, 4.4, 7.41, 2.53, 6.48, 7.8, 19.04, 3.83, 16.21, 19.34, 22.74, 10.5, 27.02, 9.97, 16.74, 49.14,
                8.04, 7.93, 27.21]
}
# Index by model name so rows align with the other tables by label, not position.
time_df = pd.DataFrame(time_data).set_index('model')

# Restrict all three tables to the models present in every table (intersection),
# in a single shared row order.
# The order is taken from acc_df rather than from iterating a set. Set iteration
# order of strings changes between interpreter runs (hash randomization), and
# Series.idxmax() breaks ties by row order. The set version therefore made tied
# selections nondeterministic; for example, two 'reasoning' accuracies are tied
# at 75.00.
common_models = [
    model for model in acc_df.index
    if model in cost_df.index and model in time_df.index
]
acc_df = acc_df.loc[common_models]
cost_df = cost_df.loc[common_models]
time_df = time_df.loc[common_models]


def _min_max_normalize(values):
    """Linearly rescale a Series onto [0, 1] (min -> 0, max -> 1), keeping its index.

    If every entry is equal there is no range to divide by, so all entries
    map to 0 instead.
    """
    lo, hi = values.min(), values.max()
    if hi > lo:
        return (values - lo) / (hi - lo)
    return pd.Series(0, index=values.index)


def select_model(question_type, mode, *, acc_table=None, time_table=None, cost_table=None):
    """
    Select the best model for a specific question type under the given scoring mode.

    Parameters:
    - question_type: str, a column of the accuracy table; for the built-in data one of
      'coding', 'knowledge', 'math', 'reasoning', 'roleplay', 'writing'
    - mode: int, scoring rule (the highest score wins):
        1 -> accuracy
        2 -> accuracy - time
        3 -> accuracy - min-max-normalized time - min-max-normalized cost
    - acc_table, time_table, cost_table: optional DataFrames (rows indexed by model
      name, one column per question type) that replace the module-level acc_df,
      time_df and cost_df respectively; any left as None uses the module table.

    Returns:
    - dict with keys 'model', 'accuracy', 'time', 'cost', 'question_type', 'mode'

    Raises:
    - ValueError: if question_type is not a column of the accuracy table, or if
      mode is not 1, 2 or 3.

    Notes:
    - Ties go to the first maximal row of the score Series (Series.idxmax).
    - NOTE(review): accuracy is in percentage points (0-100), while each mode-3
      penalty lies in [0, 1]. Time and cost can therefore move a mode-3 score by
      at most 2 points. Mode 2 subtracts raw seconds from percentage points.
      Confirm these weightings are intended.
    """
    acc_table = acc_df if acc_table is None else acc_table
    time_table = time_df if time_table is None else time_table
    cost_table = cost_df if cost_table is None else cost_table

    if question_type not in acc_table.columns:
        valid = ", ".join(map(str, acc_table.columns))
        raise ValueError(f"Invalid question type. Must be one of: {valid}")

    if mode not in (1, 2, 3):
        raise ValueError("Invalid mode. Must be 1, 2, or 3")

    # One Series per metric, indexed by model name.
    acc = acc_table[question_type]
    elapsed = time_table[question_type]
    cost = cost_table[question_type]

    if mode == 1:
        score = acc
    elif mode == 2:
        # pandas aligns the two Series by model name before subtracting.
        score = acc - elapsed
    else:
        # Normalized penalties: 0 for the fastest/cheapest model, 1 for the slowest/priciest.
        score = acc - _min_max_normalize(elapsed) - _min_max_normalize(cost)

    model_name = score.idxmax()

    return {
        'model': model_name,
        'accuracy': acc[model_name],
        'time': elapsed[model_name],
        'cost': cost[model_name],
        'question_type': question_type,
        'mode': mode,
    }


def print_model_details(result):
    """Print a select_model() result dict as a framed, human-readable report.

    Reads the keys 'question_type', 'mode', 'model', 'accuracy', 'time' and
    'cost'; accuracy, time and cost are shown with two decimals.
    """
    rule = "=" * 60
    report_lines = [
        "",
        rule,
        f"问题类型: {result['question_type']} | 模式: {result['mode']}",
        rule,
        f"选择模型: {result['model']}",
        f"准确率: {result['accuracy']:.2f}%",
        f"时间: {result['time']:.2f} 秒",
        f"成本: {result['cost']:.2f} tokens",
        rule,
    ]
    # The leading "" entry reproduces the blank line before the top rule.
    print("\n".join(report_lines))

if __name__ == "__main__":
    # Run every question type through every selection mode and print the picks.
    question_types = ['coding', 'knowledge', 'math', 'reasoning', 'roleplay', 'writing']

    for q_type in question_types:
        for mode in [1, 2, 3]:
            result = select_model(q_type, mode)
            print_model_details(result)