import json
import re
import random
import os

from prompt.omni_prompt import build_judge_prompt, build_long_judge_response
from utils import extract_omni_answer
from collections import Counter


####################################################################################
def build_short_omni_prompt(prompt, chosen, rejected):
    """Render the short, verdict-only pairwise judge prompt.

    Args:
        prompt: the user query, placed under ``[User Question]``.
        chosen: text placed in the "Assistant A" slot.
        rejected: text placed in the "Assistant B" slot.

    Despite the parameter names, the function only fills slots A and B in the
    given order; callers swap the last two arguments to put the preferred
    answer in slot B.

    Returns:
        str: the filled-in prompt.
    """
    template = (
        "You are a fair, professional, and neutral multimodal AI evaluator.\n"
        "You are tasked with evaluating two different multimodal responses (which may include Text, Image, Video, and Audio) generated for the same user query, and determining which one is better.\n"
        "\n"
        "    ### Final Output Format:\n"
        "Please directly output the final verdict in the following format:\n"
        "The final verdict is `[[A]]` or `[[B]]`\n"
        "\n"
        "\n"
        "###Input Format\n"
        "[User Question]\n"
        "{instruction}\n"
        "\n"
        "[The Start of Assistant A's Answer]\n"
        "{response_1}\n"
        "[The End of Assistant A's Answer]\n"
        "\n"
        "[The Start of Assistant B's Answer]\n"
        "{response_2}\n"
        "[The End of Assistant B's Answer]\n"
    )
    slots = {"instruction": prompt, "response_1": chosen, "response_2": rejected}
    return template.format_map(slots)

def clear_text(text: str) -> str:
    """Replace the ``<image>``, ``<video>`` and ``<audio>`` tags with bare words.

    Each tag such as ``<image>`` becomes ``image``. Tags are replaced in the
    order image, video, audio; other text is untouched.
    """
    for tag_name in ("image", "video", "audio"):
        text = text.replace(f"<{tag_name}>", tag_name)
    return text


def old_order(root="/home/export/base/ycsc_1/1/online1///GRM-Omni/data/0907_70k_criteria_meta_reward/"):
    """Yield per-sample criteria/judge bundles from the 0907 criteria data, with a ranking.

    Files read from ``root`` (JSONL, keyed by the ``suffix`` field):
      * ``refinment_judge.jsonl`` - judge texts per sample (``critique_list``
        when present, otherwise ``judge_list``);
      * ``criteria.jsonl`` - ``criteria_list`` per sample;
      * ``scoring_a.jsonl`` / ``scoring_b.jsonl`` - ``response`` lists of values
        compared per criterion index for responses A and B;
      * ``scoring_raw_a.jsonl`` / ``scoring_raw_b.jsonl`` - reference records,
        of which ``response[0]`` is used.

    Samples are driven by ``scoring_a.jsonl``. A sample is skipped when its suffix
    is missing from any of the other files.

    Args:
        root: directory containing the files above; defaults to the original
            hard-coded location.

    Yields:
        dict with keys ``item`` (the ``scoring_a`` record; downstream code reads its
        ``answer``/``prompt``/``chosen``/``rejected``), ``criteria``, ``judge`` and
        ``index`` (criterion indices sorted by descending score).
    """
    def _load_by_suffix(file_name):
        # Index one JSONL file by its 'suffix' field; later lines win on duplicates.
        with open(os.path.join(root, file_name)) as f:
            return {record['suffix']: record for record in map(json.loads, f)}

    suffix2judge = {
        suffix: record['critique_list'] if 'critique_list' in record else record['judge_list']
        for suffix, record in _load_by_suffix("refinment_judge.jsonl").items()
    }
    suffix2criteria = {
        suffix: record['criteria_list']
        for suffix, record in _load_by_suffix("criteria.jsonl").items()
    }
    suffix2b = _load_by_suffix("scoring_b.jsonl")
    suffix2rawa = _load_by_suffix("scoring_raw_a.jsonl")
    suffix2rawb = _load_by_suffix("scoring_raw_b.jsonl")
    companions = (suffix2b, suffix2rawa, suffix2rawb, suffix2judge, suffix2criteria)

    # Read eagerly so the file is not held open across yields.
    with open(os.path.join(root, "scoring_a.jsonl")) as f:
        scoring_a_lines = f.readlines()

    for line in scoring_a_lines:
        json_item_a = json.loads(line)
        suffix = json_item_a['suffix']
        # Skip samples that are not present in every companion file.
        if any(suffix not in table for table in companions):
            continue

        json_item_raw_a = suffix2rawa[suffix]
        json_item_raw_b = suffix2rawb[suffix]
        response_a_list = json_item_a['response']
        response_b_list = suffix2b[suffix]['response']
        judges = suffix2judge[suffix]

        # One score per criterion, where pred is the judge's verdict for that criterion:
        #   verdict A, a < b     -> a - b    (negative: the values contradict the verdict)
        #   verdict A, a >= b    -> b - raw_b
        #   other verdict, a > b -> b - a    (negative)
        #   otherwise            -> a - raw_a
        response_list = []
        for idx, judge in enumerate(judges):
            _, pred = extract_omni_answer(judge, json_item_a['answer'])
            score_a = response_a_list[idx]
            score_b = response_b_list[idx]
            if pred == "A":
                if score_a < score_b:
                    response_list.append(score_a - score_b)
                else:
                    response_list.append(score_b - json_item_raw_b["response"][0])
            else:
                if score_a > score_b:
                    response_list.append(score_b - score_a)
                else:
                    response_list.append(score_a - json_item_raw_a["response"][0])

        # Criterion indices, highest score first (stable for ties).
        index = sorted(range(len(response_list)), key=response_list.__getitem__, reverse=True)
        yield {
            "item": json_item_a,
            "criteria": suffix2criteria[suffix],
            "judge": judges,
            "index": index,
        }


def extract(text: str):
    """Pull the per-response judgements out of a judge output.

    Looks for the first ``<judge A>...</judge A>`` and the first
    ``<judge B>...</judge B>`` block (contents may span lines) and returns
    their stripped contents.

    Args:
        text: raw judge output.

    Returns:
        dict: ``{"A": <text for A>, "B": <text for B>}``.

    Raises:
        ValueError: if either block is missing. This is an explicit raise rather
            than ``assert``, so the check still runs under ``python -O``.
    """
    results = {}
    for label in ("A", "B"):
        match = re.search(rf"<judge {label}>(.*?)</judge {label}>", text, re.DOTALL)
        if match:
            results[label] = match.group(1).strip()

    missing = [label for label in ("A", "B") if label not in results]
    if missing:
        raise ValueError(f"judge output is missing <judge X> block(s) for: {missing}")
    return results

def handle_item2cot_train_data(handle_item:dict):
    """Turn one ranked sample (as yielded by ``old_order``) into up to two long-CoT SFT records.

    Processing steps:
      * Criteria and judge texts are reordered by ``handle_item['index']``.
      * Only judges for which ``extract_omni_answer`` reports ``correct`` are kept.
      * Leading ``"N. "`` numbering is stripped from the kept criteria.
      * Each kept judge text is split into its A/B parts with ``extract``.

    The first three kept pairs form the step-1 record, whose prompt is built
    with no candidate criteria. The next three form the step-2 record, whose
    prompt lists the step-1 criteria as candidates.

    Args:
        handle_item: dict with ``item`` (this function reads ``answer``,
            ``prompt``, ``chosen`` and ``rejected``), ``criteria``, ``judge``
            and ``index``.

    Yields:
        ``{"conversations": [user_turn, assistant_turn]}`` for step 1, then a
        second such record only when at least six judge-consistent criteria
        exist.

    Raises:
        AssertionError: if fewer than three consistent criteria are available
            for step 1, or another internal check fails. Errors from
            ``extract`` and the prompt builders propagate too. Because this is
            a generator, all of them surface from ``next()``.
    """
    ...
    json_item = handle_item['item']
    criteria = handle_item['criteria']
    judge = handle_item['judge']
    index = handle_item['index']

    # Label and query. answer == 0 means slot A (response_1) is preferred, 1 means slot B.
    answer = json_item['answer']
    query = json_item['prompt']


    # Reorder criteria and their judge texts by the ranking in `index`.
    ordered_criteria = []
    ordered_judge = []
    for order_idx in index:
        ordered_criteria.append(criteria[order_idx])
        ordered_judge.append(judge[order_idx])
    assert len(ordered_criteria) == len(index), "len(ordered_criteria) != len(index)"
    # assert criteria.index(ordered_criteria[0]) == index[0], "criteria.index(ordered_criteria[0]) != index[0]"

    
    filter_ordered_criteria = []
    filter_ordered_judge = []
    # Keep only criteria whose judge verdict agrees with the label (ranking order preserved).
    for idx, j in enumerate(ordered_judge):
        correct, pred = extract_omni_answer(response= j, answer= answer)
        if correct:
            # Sanity check: a "correct" verdict must name the preferred slot.
            if answer == 0: assert pred == 'A'
            if answer == 1: assert pred == 'B'
            filter_ordered_criteria.append(ordered_criteria[idx])
            filter_ordered_judge.append(ordered_judge[idx])
    
    assert len(filter_ordered_criteria) == len(filter_ordered_judge), "criteria and judge not equal, after filter"

    # Drop leading list numbering such as "3. " from each criterion.
    filter_ordered_criteria = [re.sub(r'^\d+\.\s*', '', t).strip() for t in filter_ordered_criteria]
    # The judge texts refer to slots A/B, so order the two responses to match `answer`.
    # NOTE(review): if answer is neither 0 nor 1, response_1/response_2 stay unbound and
    # an UnboundLocalError is raised when the step-1 prompt is built.
    if answer == 0:
        response_1 = json_item['chosen']
        response_2 = json_item['rejected']
    elif answer == 1:
        response_1 = json_item['rejected']
        response_2 = json_item['chosen']
    
    # Step 1: the top three consistent criteria; the prompt offers no candidate criteria.
    step1_criteria = filter_ordered_criteria[:3]
    step1_judge = filter_ordered_judge[:3]
    assert len(step1_criteria) == len(step1_judge) == 3, "step 1. don't equal to 3"
    step1_makeup_criteria = []
    for c, j in zip(step1_criteria, step1_judge):
        # Split the judge text into its <judge A> / <judge B> parts.
        judge_pair = extract(j)
        
        step1_makeup_criteria.append(
            {
                "criterion": c,
                "judge_A": judge_pair['A'],
                "judge_B": judge_pair['B']
            }
        )
    
    assert len(step1_makeup_criteria) == len(step1_criteria) == len(step1_judge), "step 1 all not equal."
    step1_prompt = build_judge_prompt(query= query, response_1= response_1, response_2= response_2, candidate_criteria= [])
    step1_response = build_long_judge_response(criteria= step1_makeup_criteria, answer= answer)
    step1_data = dict(
        conversations=[
            {
                "role": "user",
                "content": clear_text(step1_prompt),
            },
            {
                "role": "assistant",
                "content": clear_text(step1_response)
            }
        ]
    )
    yield step1_data

    # Step 2: the next three consistent criteria; the prompt lists the step-1 criteria
    # as candidates. Finish without a second record if fewer than three remain.
    step2_criteria = filter_ordered_criteria[3:6]
    step2_judge = filter_ordered_judge[3:6]
    if len(step2_criteria) < 3: return None
    assert len(step2_criteria) == len(step2_judge) == 3, f"step 2. don't equal to 3. len(step2_criteria)=[{len(step2_criteria)}]"
    step2_makeup_criteria = []
    for c, j in zip(step2_criteria, step2_judge):
        judge_pair = extract(j)
        
        step2_makeup_criteria.append(
            {
                "criterion": c,
                "judge_A": judge_pair['A'],
                "judge_B": judge_pair['B']
            }
        )
    step2_prompt = build_judge_prompt(query= query, response_1= response_1, response_2= response_2, candidate_criteria= step1_criteria)
    step2_response = build_long_judge_response(criteria= step2_makeup_criteria, answer= answer)
    step2_data = dict(
        conversations=[
            {
                "role": "user",
                "content": clear_text(step2_prompt),
            },
            {
                "role": "assistant",
                "content": clear_text(step2_response)
            }
        ]
    )
    yield step2_data
    return 

            
    
def prepare_test_cot_from_previous_data():
    """Build SFT records from the earlier criteria data (read via ``old_order``).

    For each sample this function adds:
      * up to two long-CoT records (step 1 and step 2 from
        ``handle_item2cot_train_data``);
      * with probability ~0.5, one short verdict-only record whose A/B order is
        picked at random.

    Failures while producing a step are not fatal, but they are no longer
    hidden. They are tallied in the printed metrics as
    ``step1_failed:<ExceptionType>`` or ``step2_failed:<ExceptionType>``.
    Samples whose generator simply ends after step 1 are counted as
    ``step2_absent``.

    Returns:
        list[dict]: chat records, each with a ``conversations`` list.
    """
    train_dataset = []
    m = Counter()
    for handle_item in old_order():
        step1_data = None
        step2_data = None
        steps = handle_item2cot_train_data(handle_item)
        try:
            step1_data = next(steps)
        except Exception as e:
            # e.g. too few judge-consistent criteria for this sample.
            m[f"step1_failed:{type(e).__name__}"] += 1
        else:
            try:
                step2_data = next(steps)
            except StopIteration:
                # The generator finished after step 1: no step-2 record for this sample.
                m["step2_absent"] += 1
            except Exception as e:
                m[f"step2_failed:{type(e).__name__}"] += 1

        if step1_data:
            train_dataset.append(step1_data)
            m['step1_data'] += 1
        if step2_data:
            train_dataset.append(step2_data)
            m['step2_data'] += 1

        # With probability ~0.5, also add a short verdict-only (no-CoT) record.
        if random.random() > 0.5:
            json_item = handle_item['item']
            query = json_item['prompt']
            chosen = json_item['chosen']
            rejected = json_item['rejected']
            answer = random.choice(["A", "B"])
            # Put the chosen response in the slot named by `answer`.
            if answer == "A":
                prompt = build_short_omni_prompt(query, chosen, rejected)
            else:
                prompt = build_short_omni_prompt(query, rejected, chosen)
            response = f"The final verdict is [[{answer}]]"
            text_no_cot_data = dict(
                conversations=[
                    {
                        "role": "user",
                        "content": clear_text(prompt),
                    },
                    {
                        "role": "assistant",
                        "content": clear_text(response)
                    }
                ]
            )
            train_dataset.append(text_no_cot_data)
            m['text_no_cot_data'] += 1

    print(f"old text cot data: [{len(train_dataset)}]")
    print(f"old text cot data: Metrics: [{m}]")
    return train_dataset
def prepare_mm_no_cot_data():
    """Build shuffled verdict-only pairwise records for audio, image and video data.

    Every line of each source file becomes one JSON-encoded record.

    With probability ~0.5 the rejected response goes in slot A and the chosen
    one in slot B, with target ``[[B]]``; otherwise chosen is in A, with target
    ``[[A]]``.

    Media handling depends on the dataset type:
      * Generation datasets: the two per-response media items are listed in the
        same order as the responses.
      * Understanding datasets: the input media are attached as-is.

    Returns:
        list[str]: shuffled ``json.dumps`` strings. Each record has
        ``conversations`` plus one of ``audios``, ``images`` or ``videos``.
    """
    base_dir = "/home/export/base/ycsc_1/1/online1//GRM-Omni-v1/dataset/training/"

    def _paired_media(key_suffix):
        # Generation data: one media item per response, ordered like the A/B slots.
        def pick(json_item, swap):
            first, second = ("rejected", "chosen") if swap else ("chosen", "rejected")
            return [json_item[f"{first}_{key_suffix}"][0], json_item[f"{second}_{key_suffix}"][0]]
        return pick

    # (path under base_dir, media key, media picker, prefix for the query, prefix for the user turn)
    sources = [
        ("audio/audio_generation_scalar_rm.jsonl", "audios", _paired_media("audios"), "", ""),
        ("audio/audio_understanding_scalar_rm.jsonl", "audios",
         lambda json_item, swap: json_item['audios'], "", ""),
        ("image/image_undstanding_scalar_rm.jsonl", "images",
         lambda json_item, swap: json_item['images'], "", "Image: <image>\n"),
        ("image/image_generation_scalar_rm.jsonl", "images", _paired_media("images"), "", ""),
        ("video/video_generation_scalar_rm.jsonl", "videos", _paired_media("video"), "", ""),
        ("video/video_understanding_scalar_rm.jsonl", "videos",
         lambda json_item, swap: [json_item['videos'][0]], "Video: <video>\n", ""),
    ]

    train_data = []
    for rel_path, media_key, pick_media, query_prefix, turn_prefix in sources:
        with open(base_dir + rel_path) as f:
            for line in f:
                json_item = json.loads(line)
                # swap=True puts the rejected response in slot A, so [[B]] is the correct verdict.
                swap = random.random() > 0.5
                query = query_prefix + json_item['conversations'][0]['content']
                chosen = json_item['chosen']['content']
                rejected = json_item['rejected']['content']
                if swap:
                    prompt = build_short_omni_prompt(query, rejected, chosen)
                    verdict = "The final verdict is [[B]]"
                else:
                    prompt = build_short_omni_prompt(query, chosen, rejected)
                    verdict = "The final verdict is [[A]]"
                data = {
                    "conversations": [
                        {"role": "user", "content": turn_prefix + prompt},
                        {"role": "assistant", "content": verdict},
                    ],
                    media_key: pick_media(json_item, swap),
                }
                train_data.append(json.dumps(data))

    print(len(train_data))
    random.shuffle(train_data)
    return train_data







def order_new(root="/home/export/base/ycsc_1/1/online1///GRM-Omni-main/omni_results/0913_synthetic"):
    """Yield per-sample criteria/judge bundles from the 0913 synthetic data, with a ranking.

    Reads three JSONL files under ``root``, keyed by ``id``:
      * ``ranking.jsonl`` - ``ranking_pair`` value lists ``ranking_a`` /
        ``ranking_b``, reference values ``ranking_raw_a`` / ``ranking_raw_b``,
        and ``paired_data.answer``;
      * ``judge.jsonl`` - ``judge`` texts and ``judge_pair.judge_a_list`` /
        ``judge_pair.judge_b_list``;
      * ``criteria_merge.jsonl`` - ``criteria`` groups ``'0'``, ``'1'``, ``'2'``
        (concatenated in that order) and ``paired_data``.

    Only the first ``Use_criteria_num`` (9) criteria are used. Criterion ``idx``
    is paired with judge text ``idx // 3``.

    Samples are skipped, with a printed message, when:
      * their ranking lists are too short;
      * they are missing from the judge or criteria file;
      * their ``answer`` is not 0 or 1.

    Args:
        root: directory containing the files above; defaults to the original
            hard-coded location.

    Yields:
        dict with keys ``item`` (``prompt``, ``answer``, ``chosen``,
        ``rejected``), ``criteria``, ``judge``, ``judge_a_list``,
        ``judge_b_list`` and ``index`` (criterion indices, highest score first).
    """
    print(f"[_order_n] [_order_n] [_order_n] [_order_n]")

    Criteria_step = 3
    # Criteria kept per sample: Criteria_step * 3 (= 9 with the current setting).
    Use_criteria_num = Criteria_step*3
    print(f"Criteria_step: [{Criteria_step}]")

    # ranking.jsonl: per-criterion values for A/B, reference values and the label.
    suffix2a = {}
    suffix2b = {}
    suffix2rawa = {}
    suffix2rawb = {}
    suffix2answer = {}
    with open(os.path.join(root, "ranking.jsonl")) as f:
        for line in f:
            json_item = json.loads(line)
            rank_situation = json_item['ranking_pair']
            if len(rank_situation['ranking_a']) < Use_criteria_num or len(rank_situation['ranking_b']) < Use_criteria_num:
                print(f'error length.  < [{Use_criteria_num}] ')
                continue
            sample_id = json_item['id']
            suffix2rawa[sample_id] = rank_situation['ranking_raw_a']
            suffix2rawb[sample_id] = rank_situation['ranking_raw_b']
            suffix2a[sample_id] = rank_situation['ranking_a'][:Use_criteria_num]
            suffix2b[sample_id] = rank_situation['ranking_b'][:Use_criteria_num]
            suffix2answer[sample_id] = json_item['paired_data']['answer']

    # judge.jsonl: combined judge texts plus the per-slot judge lists.
    suffix2judge = {}
    suffix2judge_a_list = {}
    suffix2judge_b_list = {}
    with open(os.path.join(root, "judge.jsonl")) as f:
        for line in f:
            json_item = json.loads(line)
            sample_id = json_item['id']
            suffix2judge[sample_id] = json_item['judge']
            suffix2judge_a_list[sample_id] = json_item['judge_pair']['judge_a_list']
            suffix2judge_b_list[sample_id] = json_item['judge_pair']['judge_b_list']

    # criteria_merge.jsonl: criteria groups '0', '1', '2' concatenated, plus the full record.
    suffix2criteria = {}
    suffix2item = {}
    with open(os.path.join(root, "criteria_merge.jsonl")) as f:
        for line in f:
            json_item = json.loads(line)
            criteria = []
            for group in ('0', '1', '2'):
                criteria += json_item['criteria'][group]
            suffix2criteria[json_item['id']] = criteria
            suffix2item[json_item['id']] = json_item

    for sample_id, response_a_list in suffix2a.items():
        if sample_id not in suffix2judge or sample_id not in suffix2item:
            print(f"missing judge/criteria for sample: [{sample_id}]")
            continue
        response_b_list = suffix2b[sample_id]
        json_item_raw_a = suffix2rawa[sample_id]
        json_item_raw_b = suffix2rawb[sample_id]
        sample_judges = suffix2judge[sample_id]

        criteria_num = len(response_a_list)
        Criteria = suffix2criteria[sample_id][:criteria_num]
        Judge = []
        # One score per criterion, where pred is the verdict of that criterion's judge text:
        #   verdict A, a < b     -> a - b    (negative: the values contradict the verdict)
        #   verdict A, a >= b    -> b - raw_b
        #   other verdict, a > b -> b - a    (negative)
        #   otherwise            -> a - raw_a
        response_list = []
        for idx in range(criteria_num):
            # Every three consecutive criteria share one judge text (0-2 -> 0, 3-5 -> 1, ...).
            judge = sample_judges[idx // 3]
            Judge.append(judge)
            _, pred = extract_omni_answer(judge, suffix2answer[sample_id])
            score_a = response_a_list[idx]
            score_b = response_b_list[idx]
            if pred == "A":
                if score_a < score_b:
                    response_list.append(score_a - score_b)
                else:
                    response_list.append(score_b - json_item_raw_b)
            else:
                if score_a > score_b:
                    response_list.append(score_b - score_a)
                else:
                    response_list.append(score_a - json_item_raw_a)

        assert len(Judge) == len(Criteria)
        index = sorted(range(len(response_list)), key=response_list.__getitem__, reverse=True)

        paired_data = suffix2item[sample_id]['paired_data']
        answer = paired_data['answer']
        if answer == 0:
            chosen = paired_data['response_1']['content']
            rejected = paired_data['response_2']['content']
        elif answer == 1:
            chosen = paired_data['response_2']['content']
            rejected = paired_data['response_1']['content']
        else:
            # Without this guard chosen/rejected would silently keep the previous sample's values.
            print(f"unexpected answer [{answer}] for sample: [{sample_id}]")
            continue
        item = {
            "prompt": paired_data['query']['content'],
            "answer": answer,
            "chosen": chosen,
            "rejected": rejected
        }
        yield {
            "item": item,
            "criteria": Criteria,
            "judge": Judge,
            "judge_a_list": suffix2judge_a_list[sample_id],
            "judge_b_list": suffix2judge_b_list[sample_id],
            "index": index,
        }


def new_handle_item2cot_train_data(handle_item:dict):
    """Turn one ranked sample (as yielded by ``order_new``) into up to two long-CoT SFT records.

    Input preparation:
      * ``criteria`` entries are reduced to their ``content`` field.
      * Each judge text is packed with the matching entries of
        ``judge_a_list`` / ``judge_b_list`` into a ``(judge, judge_A, judge_B)``
        tuple.

    Processing steps:
      * Both lists are reordered by ``handle_item['index']``.
      * Only entries whose judge text ``extract_omni_answer`` marks ``correct``
        are kept.
      * Leading ``"N. "`` numbering is stripped from the kept criteria.

    The first three kept pairs form the step-1 record, whose prompt is built
    with no candidate criteria. The next three form the step-2 record, whose
    prompt lists the step-1 criteria as candidates.

    Args:
        handle_item: dict with ``item`` (this function reads ``answer``,
            ``prompt``, ``chosen`` and ``rejected``), ``criteria``, ``judge``,
            ``judge_a_list``, ``judge_b_list`` and ``index``.

    Yields:
        ``{"conversations": [user_turn, assistant_turn]}`` for step 1, then a
        second such record only when at least six judge-consistent criteria
        exist.

    Raises:
        AssertionError: when the judge lists differ in length, fewer than
            three consistent criteria exist for step 1, or another internal
            check fails. Because this is a generator, errors surface from
            ``next()``.
    """
    ...
    json_item = handle_item['item']
    criteria = handle_item['criteria']
    # Each criterion entry is a dict; only its 'content' text is used.
    criteria = [c['content'] for c in criteria]
    judge = handle_item['judge']
    judge_a_list = handle_item['judge_a_list']
    judge_b_list = handle_item['judge_b_list']

    # Pack (judge text, judge text for A, judge text for B) per criterion.
    assert len(judge) == len(judge_a_list) == len(judge_b_list)
    judge = [(j, ja, jb) for j, ja, jb in zip(judge, judge_a_list, judge_b_list)]

    index = handle_item['index']

    # Label and query. answer == 0 means slot A (response_1) is preferred, 1 means slot B.
    answer = json_item['answer']
    query = json_item['prompt']


    # Reorder criteria and their judge tuples by the ranking in `index`.
    ordered_criteria = []
    ordered_judge = []
    for order_idx in index:
        ordered_criteria.append(criteria[order_idx])
        ordered_judge.append(judge[order_idx])
    assert len(ordered_criteria) == len(index), "len(ordered_criteria) != len(index)"
    
    filter_ordered_criteria = []
    filter_ordered_judge = []
    # Keep only criteria whose judge verdict (j[0]) agrees with the label (ranking order preserved).
    for idx, j in enumerate(ordered_judge):
        correct, pred = extract_omni_answer(response= j[0], answer= answer)
        if correct:
            # Sanity check: a "correct" verdict must name the preferred slot.
            if answer == 0: assert pred == 'A'
            if answer == 1: assert pred == 'B'
            filter_ordered_criteria.append(ordered_criteria[idx])
            filter_ordered_judge.append(ordered_judge[idx])
    
    assert len(filter_ordered_criteria) == len(filter_ordered_judge), "criteria and judge not equal, after filter"
    
    # Drop leading list numbering such as "3. " from each criterion.
    filter_ordered_criteria = [re.sub(r'^\d+\.\s*', '', t).strip() for t in filter_ordered_criteria]
    # The judge texts refer to slots A/B, so order the two responses to match `answer`.
    # NOTE(review): if answer is neither 0 nor 1, response_1/response_2 stay unbound and
    # an UnboundLocalError is raised when the step-1 prompt is built.
    if answer == 0:
        response_1 = json_item['chosen']
        response_2 = json_item['rejected']
    elif answer == 1:
        response_1 = json_item['rejected']
        response_2 = json_item['chosen']
    
    # Step 1: the top three consistent criteria; the prompt offers no candidate criteria.
    step1_criteria = filter_ordered_criteria[:3]
    step1_judge = filter_ordered_judge[:3]
    assert len(step1_criteria) == len(step1_judge) == 3, "step 1. don't equal to 3"
    step1_makeup_criteria = []
    for c, j in zip(step1_criteria, step1_judge):
        # The per-slot judge texts come pre-split (j[1] for A, j[2] for B).
        
        step1_makeup_criteria.append(
            {
                "criterion": c,
                "judge_A": j[1],
                "judge_B": j[2]
            }
        )
    
    assert len(step1_makeup_criteria) == len(step1_criteria) == len(step1_judge), "step 1 all not equal."
    step1_prompt = build_judge_prompt(query= query, response_1= response_1, response_2= response_2, candidate_criteria= [])
    step1_response = build_long_judge_response(criteria= step1_makeup_criteria, answer= answer)
    step1_data = dict(
        conversations=[
            {
                "role": "user",
                "content": clear_text(step1_prompt),
            },
            {
                "role": "assistant",
                "content": clear_text(step1_response)
            }
        ]
    )
    ...
    yield step1_data

    # Step 2: the next three consistent criteria; the prompt lists the step-1 criteria
    # as candidates. Finish without a second record if fewer than three remain.
    step2_criteria = filter_ordered_criteria[3:6]
    step2_judge = filter_ordered_judge[3:6]
    if len(step2_criteria) < 3: return None
    assert len(step2_criteria) == len(step2_judge) == 3, f"step 2. don't equal to 3. len(step2_criteria)=[{len(step2_criteria)}]"
    step2_makeup_criteria = []
    for c, j in zip(step2_criteria, step2_judge):
        
        step2_makeup_criteria.append(
            {
                "criterion": c,
                "judge_A": j[1],
                "judge_B": j[2]
            }
        )
    step2_prompt = build_judge_prompt(query= query, response_1= response_1, response_2= response_2, candidate_criteria= step1_criteria)
    step2_response = build_long_judge_response(criteria= step2_makeup_criteria, answer= answer)
    step2_data = dict(
        conversations=[
            {
                "role": "user",
                "content": clear_text(step2_prompt),
            },
            {
                "role": "assistant",
                "content": clear_text(step2_response)
            }
        ]
    )
    yield step2_data
    return 

def prepare_text_cot_from_new():
    """Build long-CoT SFT records from the newest (0913/0914) data read via ``order_new``.

    Each sample contributes up to two records (step 1 and step 2 from
    ``new_handle_item2cot_train_data``).

    Failures while producing a step are not fatal, but they are no longer
    hidden. They are tallied in the printed metrics as
    ``step1_failed:<ExceptionType>`` or ``step2_failed:<ExceptionType>``.
    Samples whose generator simply ends after step 1 are counted as
    ``step2_absent``.

    Returns:
        list[dict]: chat records, each with a ``conversations`` list.
    """
    train_dataset = []
    m = Counter()
    for handle_item in order_new():
        step1_data = None
        step2_data = None
        steps = new_handle_item2cot_train_data(handle_item)
        try:
            step1_data = next(steps)
        except Exception as e:
            # e.g. too few judge-consistent criteria for this sample.
            m[f"step1_failed:{type(e).__name__}"] += 1
        else:
            try:
                step2_data = next(steps)
            except StopIteration:
                # The generator finished after step 1: no step-2 record for this sample.
                m["step2_absent"] += 1
            except Exception as e:
                m[f"step2_failed:{type(e).__name__}"] += 1

        if step1_data:
            train_dataset.append(step1_data)
            m['step1_data'] += 1
        if step2_data:
            train_dataset.append(step2_data)
            m['step2_data'] += 1

    print(f"New text cot data: [{len(train_dataset)}]")
    print(f"New text cot data: Metrics: [{m}]")
    return train_dataset

if __name__ == "__main__":
    # Gather the three sources in the original call order:
    #   1. older text CoT data (0910),
    #   2. multimodal verdict-only data,
    #   3. newest text CoT data (0914).
    old_text_cot_data = prepare_test_cot_from_previous_data()
    mm_no_cot_data = prepare_mm_no_cot_data()
    newest_text_cot_data = prepare_text_cot_from_new()

    # Mix all text CoT records with at most 100k multimodal records, then shuffle.
    mix_dataset = [*old_text_cot_data, *newest_text_cot_data, *mm_no_cot_data[:100_000]]
    random.shuffle(mix_dataset)
    print(f"Total dataset: [{len(mix_dataset)}]")

    # Write JSONL. Multimodal records are already JSON strings; the rest are dicts.
    output_path = "/home/export/base/ycsc_1/1/online1//GRM-Omni-v1/data/grm_lang/grm_lang_sft.jsonl"
    with open(output_path, 'w') as fw:
        fw.writelines(
            (record if isinstance(record, str) else json.dumps(record)) + "\n"
            for record in mix_dataset
        )
    
    