from datasets import load_dataset
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import re
from src import utils

########################################
# 1. 数据加载与筛选
########################################
# Load the combined distillation dataset and keep only its "train" split.
data = load_dataset("combined_math_distill_dataset")["train"]

# Every column whose name ends in "_is_correct"; these are summed per example
# below (presumably one correctness flag per model -- confirm against the dataset).
is_correct_cols = [name for name in data.column_names if name.endswith("_is_correct")]

def calc_num_correct(example):
    """Store the sum of the example's ``*_is_correct`` columns as ``num_correct``.

    The columns summed are those listed in the module-level
    ``is_correct_cols``. The example is modified in place and returned, which
    is the shape ``Dataset.map`` is called with below.
    """
    example["num_correct"] = sum(example[col] for col in is_correct_cols)
    return example

# Tag every example with its `num_correct` count.
dataset = data.map(calc_num_correct)

# "Simple" questions have num_correct equal to the number of correctness
# columns; "hard" questions have num_correct between 3 and 5 inclusive.
total_models = len(is_correct_cols)
simple_dataset = dataset.filter(lambda row: row["num_correct"] == total_models)
hard_dataset = dataset.filter(lambda row: 3 <= row["num_correct"] <= 5)

# Fixed-seed shuffle, then keep the first 500 rows of each pool.
# NOTE(review): the `_200` suffix no longer matches the sample size of 500.
# NOTE(review): this assumes each pool holds at least 500 rows -- confirm how
# `select` behaves otherwise.
_sample_indices = range(500)
simple_200 = simple_dataset.shuffle(seed=42).select(_sample_indices)
hard_200 = hard_dataset.shuffle(seed=42).select(_sample_indices)

# Report the size of the full (pre-sampling) pools.
print(f"简单题数量: {len(simple_dataset)}")
print(f"难题数量: {len(hard_dataset)}")

# Drop every column the experiments do not read.
_kept_columns = ('id', 'question', 'correct_answer', 'num_correct')
simple_small = simple_200.remove_columns(
    [name for name in simple_200.column_names if name not in _kept_columns]
)
hard_small = hard_200.remove_columns(
    [name for name in hard_200.column_names if name not in _kept_columns]
)

# Plain Python lists of row dicts, consumed by prepare_experiment_data.
simple_200_list = simple_small.to_list()
hard_200_list = hard_small.to_list()

########################################
# 2. 准备实验数据
########################################
def prepare_experiment_data(simple_list, hard_list, batch_size=5, sample_size=50, seed=None):
    """Build one experiment per hard/simple mixing ratio.

    Every split of ``batch_size`` into ``(hard_count, simple_count)`` is
    covered, from ``(0, batch_size)`` to ``(batch_size, 0)``. For each split,
    ``sample_size`` groups are built: group ``i`` takes the i-th consecutive,
    non-overlapping run of ``hard_count`` questions from the front of
    ``hard_list`` plus the i-th run of ``simple_count`` questions from the
    front of ``simple_list``, and is then shuffled.

    Args:
        simple_list: Pool of simple questions.
        hard_list: Pool of hard questions.
        batch_size: Number of questions in each group.
        sample_size: Number of groups built for each ratio.
        seed: Optional seed for the in-group shuffle. ``None`` (the default)
            shuffles with the module-level ``random`` generator; any other
            value makes the output reproducible.

    Returns:
        A list of dicts, one per ratio, with keys ``'ratio'``
        (``(hard_count, simple_count)``), ``'batch_groups'`` (the groups) and
        ``'single_questions'`` (the same questions flattened in group order).

    Raises:
        ValueError: If either pool holds fewer than
            ``batch_size * sample_size`` questions.
    """
    # The all-hard and all-simple ratios each draw batch_size * sample_size
    # questions from a single pool. Slicing a shorter pool would silently
    # yield groups smaller than the ratio they are labelled with.
    required = batch_size * sample_size
    for pool_name, pool in (("hard_list", hard_list), ("simple_list", simple_list)):
        if len(pool) < required:
            raise ValueError(
                f"{pool_name} has {len(pool)} questions but batch_size={batch_size} "
                f"with sample_size={sample_size} needs at least {required}"
            )

    rng = random if seed is None else random.Random(seed)

    experiments = []
    for hard_count in range(batch_size + 1):
        simple_count = batch_size - hard_count

        batch_groups = []
        for i in range(sample_size):
            group = (
                hard_list[i * hard_count:(i + 1) * hard_count]
                + simple_list[i * simple_count:(i + 1) * simple_count]
            )
            # Mix hard and simple questions so position does not reveal difficulty.
            rng.shuffle(group)
            batch_groups.append(group)

        experiments.append({
            'ratio': (hard_count, simple_count),
            'batch_groups': batch_groups,
            'single_questions': [q for group in batch_groups for q in group],
        })
    return experiments

########################################
# 3. 模型 API 调用
########################################
def call_model(model, msg, answer_list, temperature=0.0, max_tokens=16384):
    """Send one prompt to the model and score the answers extracted from it.

    Args:
        model: Model identifier passed through to the chat-completion helper.
        msg: Prompt payload passed through to the chat-completion helper.
        answer_list: Reference answers, one per question in the prompt.
        temperature: Sampling temperature forwarded to the API helper.
        max_tokens: Token limit forwarded to the API helper.

    Returns:
        A dict with the raw text (``'ret'``), the response object
        (``'response'``), the token usage counters, ``'correct_num'`` (the
        number of extracted answers equal to their reference after
        lower-casing and stripping) and ``'total_num'``
        (``len(answer_list)``). ``'reasoning_tokens'`` is ``None`` when the
        usage object carries no ``completion_tokens_details``.
    """
    ret, response = utils.chat_completion_api_with_response(model, msg, temperature, max_tokens)
    usage = response.usage
    answers = utils.extract_batch_answers(ret, len(answer_list))

    # Pair answers positionally. zip stops at the shorter sequence, so an
    # extractor that returns more answers than expected cannot index past the
    # end of answer_list.
    correct_num = sum(
        predicted.lower().strip() == expected.lower().strip()
        for predicted, expected in zip(answers, answer_list)
    )

    # Dump the raw exchange whenever nothing matched, to help diagnose
    # extraction or formatting failures.
    if correct_num == 0:
        print(f"answer_list:{answer_list}")
        print(f"answers:{answers}")
        print(f"ret:{ret}")

    details = usage.completion_tokens_details
    return {
        'ret': ret,
        'response': response,
        'prompt_tokens': usage.prompt_tokens,
        'completion_tokens': usage.completion_tokens,
        'total_tokens': usage.total_tokens,
        'reasoning_tokens': details.reasoning_tokens if details else None,
        'correct_num': correct_num,
        'total_num': len(answer_list),
    }

########################################
# 4. 内层并发
########################################
def run_in_parallel(items, worker_func, max_workers=20, desc="", position=0):
    """Apply ``worker_func`` to every item on a thread pool, with a progress bar.

    Results are collected in completion order, not submission order. If a
    worker raises, its slot holds ``{'error': str(exception)}`` instead of a
    result, so one failure does not abort the remaining items.

    Args:
        items: Inputs, each passed as the only argument to ``worker_func``.
        worker_func: Callable executed once per item.
        max_workers: Size of the thread pool.
        desc: Label shown on the tqdm progress bar.
        position: tqdm ``position`` argument for the bar.

    Returns:
        A list with one entry per item: the worker's return value or an
        error dict.
    """
    outcomes = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(worker_func, entry) for entry in items]
        progress = tqdm(as_completed(pending), total=len(pending),
                        desc=desc, position=position, leave=True)
        for finished in progress:
            try:
                outcome = finished.result()
            except Exception as exc:
                outcome = {'error': str(exc)}
            outcomes.append(outcome)
    return outcomes

########################################
# 5. 单题 / 批量模式任务
########################################
from src import prompt
def run_mode_task(model, exp_idx, mode_name, items, position):
    """Run every item of one experiment in either single or batch mode.

    In ``"single"`` mode each item is one question dict and is sent as a
    one-question prompt. In any other mode each item is a group (list) of
    question dicts sent together in one prompt.

    Args:
        model: Model identifier forwarded to ``call_model``.
        exp_idx: Zero-based experiment index, shown one-based in the bar label.
        mode_name: ``"single"`` for per-question calls; anything else is
            treated as batch mode.
        items: Question dicts (single mode) or lists of question dicts.
        position: Progress-bar position forwarded to ``run_in_parallel``.

    Returns:
        The list of per-item results produced by ``run_in_parallel``.
    """
    is_single = mode_name == "single"

    def worker(item):
        # Treat a lone question as a group of one so both modes share a path.
        group = [item] if is_single else item
        questions = [entry['question'] for entry in group]
        expected = [entry['correct_answer'] for entry in group]
        return call_model(model, prompt.batch_chat_prompt(questions), expected)

    return run_in_parallel(
        items,
        worker_func=worker,
        max_workers=10,
        desc=f"Exp{exp_idx+1}-{mode_name}",
        position=position,
    )


########################################
# 7. 统计（含准确率）
########################################
def summarize_results(batch_results, single_results, batch_items, single_items):
    """Aggregate token usage and accuracy for one experiment.

    Args:
        batch_results: Per-call result dicts from batch mode.
        single_results: Per-call result dicts from single mode.
        batch_items: Unused; kept so existing callers keep working.
        single_items: Unused; kept so existing callers keep working.

    Returns:
        A dict with ``batch_total_tokens``, ``single_total_tokens``,
        ``saving_ratio`` (``1 - batch / single``), ``batch_avg_tokens`` and
        ``single_avg_tokens`` (total divided by the number of result entries,
        error entries included), ``batch_accuracy`` and ``single_accuracy``.
        Any ratio whose denominator is zero is reported as ``0`` instead of
        raising ``ZeroDivisionError``.
    """
    def sum_tokens(res_list):
        # Entries without 'total_tokens' (e.g. error records) contribute 0.
        return sum(r.get('total_tokens', 0) for r in res_list)

    def calc_accuracy(results):
        correct_count = sum(r.get('correct_num', 0) for r in results)
        total_count = sum(r.get('total_num', 0) for r in results)
        return correct_count / total_count if total_count > 0 else 0

    batch_total = sum_tokens(batch_results)
    single_total = sum_tokens(single_results)

    stats = {}
    stats['batch_total_tokens'] = batch_total
    stats['single_total_tokens'] = single_total
    # With no single-mode tokens there is no baseline, so report 0 rather
    # than dividing by zero.
    stats['saving_ratio'] = 1 - batch_total / single_total if single_total else 0
    stats['batch_avg_tokens'] = batch_total / len(batch_results) if batch_results else 0
    stats['single_avg_tokens'] = single_total / len(single_results) if single_results else 0

    stats['batch_accuracy'] = calc_accuracy(batch_results)
    stats['single_accuracy'] = calc_accuracy(single_results)

    return stats

########################################
# 8. 运行所有实验 (外层任务)
########################################
def run_all_experiments(model, experiments):
    """Run the single and batch variants of every experiment and print stats.

    Each experiment yields two tasks, ``'single'`` and ``'batch'``. All tasks
    run concurrently on an outer thread pool, and each task fans out again
    inside ``run_mode_task``. Once every task has finished, one summary is
    printed per experiment.

    Args:
        model: Model identifier forwarded to ``run_mode_task``.
        experiments: Experiment dicts with ``'ratio'``, ``'batch_groups'``
            and ``'single_questions'`` keys.

    Returns:
        None. Results are reported with ``print``.
    """
    tasks = []
    for idx, exp in enumerate(experiments):
        for mode, key in (('single', 'single_questions'), ('batch', 'batch_groups')):
            # 'pos' gives each task its own progress-bar position.
            tasks.append({'exp_idx': idx, 'mode': mode, 'items': exp[key], 'pos': len(tasks)})

    if not tasks:
        # ThreadPoolExecutor raises ValueError for max_workers=0.
        return

    results_map = {}
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        future_map = {
            executor.submit(run_mode_task, model, t['exp_idx'], t['mode'], t['items'], t['pos']): t
            for t in tasks
        }
        for fut in as_completed(future_map):
            t = future_map[fut]
            try:
                res_list = fut.result()
            except Exception as e:
                res_list = []
                print(f"[Error] {t['exp_idx']+1}-{t['mode']}: {e}")
            results_map[(t['exp_idx'], t['mode'])] = res_list

    # Print token usage and accuracy for each experiment.
    for idx, exp in enumerate(experiments):
        batch_results = results_map.get((idx, 'batch'), [])
        single_results = results_map.get((idx, 'single'), [])
        if not batch_results or not single_results:
            # A failed or empty task leaves nothing to compare. Skip it so one
            # bad experiment cannot abort the summaries that follow.
            print(f"[Skip] {idx+1}: missing batch or single results, ratio={exp['ratio']}")
            continue
        stats = summarize_results(
            batch_results, single_results,
            batch_items=[q for group in exp['batch_groups'] for q in group],  # flattened groups
            single_items=exp['single_questions']
        )
        print(f"=== 实验 {idx+1} === ratio={exp['ratio']}\n{stats}")

########################################
# 9. 主入口
########################################
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Model Distillation Pipeline")
    # `required=True` made the default unreachable, so the flag is now
    # optional and falls back to 5.
    parser.add_argument("--batch_size", type=int, default=5,
                        help="Number of questions packed into each batched prompt (default: 5)")
    args = parser.parse_args()
    print(args)
    experiments = prepare_experiment_data(
        simple_list=simple_200_list,
        hard_list=hard_200_list,
        batch_size=args.batch_size,
    )
    model = 'deepseek-r1-250120'
    run_all_experiments(model, experiments)