import json
import re
import math
import os

def extract_omni_answer(response, answer):
    """Pull the final bracketed verdict out of a judge response and grade it.

    A verdict is a single letter ``a``/``b``/``c`` (either case) wrapped in
    one or two square brackets on each side, e.g. ``[A]`` or ``[[b]]``.  When
    several verdicts occur, the last one wins.

    Args:
        response: Judge text to search.
        answer: Gold label; 0, 1 and 2 are matched against "A", "B" and "C".

    Returns:
        A ``(correct, pred)`` tuple.  ``pred`` is the upper-cased verdict
        letter, or ``None`` when the text holds no verdict.  ``correct`` is
        ``True`` only when ``pred`` is the letter that belongs to ``answer``.
    """
    verdicts = re.findall(r'\[{1,2}([abcABC])\]{1,2}', response)
    pred = verdicts[-1].upper() if verdicts else None

    # Verdict letter -> gold label it stands for.
    label_of = {'A': 0, 'B': 1, 'C': 2}
    correct = bool(pred in label_of and answer == label_of[pred])

    return correct, pred


def order_n():
    """Print judge accuracy for each of the first ``bound`` judge positions.

    Reads ``refinment_judge.jsonl`` and the four scoring files from ``ROOT``.
    For every position ``j`` in ``range(bound)`` the ``j``-th judge text of
    each sample is graded with ``extract_omni_answer`` against the sample's
    ``answer`` and ``acc count accuracy`` is printed; the list of accuracies
    is printed at the end.  Samples whose ``suffix`` does not occur in the
    ``scoring_b`` file are skipped.  With no samples the accuracy is NaN.

    A best-first ordering of each sample's judges is also derived from the
    scoring files.  It is only consumed by the commented-out alternative in
    the grading loop (grade the ``j``-th *ranked* judge instead of the
    ``j``-th judge).
    """
    # ROOT = "/data//GRM-Omni-v1/results/safety_meta_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/image_general_criteria_reward_final2/"
    # ROOT = "/data//GRM-Omni-v1/results/chat_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/reasoning_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/safety_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/chat_hard_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/omni_results/chat_meta_reward/"
    ROOT = "/data//GRM-Omni-v1/results/chat_criteria_reward_final_v2"

    bound = 3

    skywork_v1 = False

    def read_jsonl(file_name):
        """Return the parsed records of one JSONL file under ROOT, in file order."""
        # os.path.join rather than ``ROOT + file_name``: ROOT may lack a
        # trailing slash, in which case plain concatenation glues the file
        # name onto the directory name.
        with open(os.path.join(ROOT, file_name), encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    def by_suffix(file_name):
        """Index the records of one JSONL file by their 'suffix' field."""
        return {record['suffix']: record for record in read_jsonl(file_name)}

    suffix2judge = {}
    for json_item in read_jsonl("refinment_judge.jsonl"):
        key = "critique_list" if "critique_list" in json_item else "judge_list"
        suffix2judge[json_item['suffix']] = json_item[key]

    if skywork_v1:
        suffix2a_file = "skywork_v1_scoring_a.jsonl"
        suffix2b_file = "skywork_v1_scoring_b.jsonl"
        suffix2rawa_file = "skywork_v1_scoring_raw_a.jsonl"
        suffix2rawb_file = "skywork_v1_scoring_raw_b.jsonl"
    else:
        suffix2a_file = "scoring_a.jsonl"
        suffix2b_file = "scoring_b.jsonl"
        suffix2rawa_file = "scoring_raw_a.jsonl"
        suffix2rawb_file = "scoring_raw_b.jsonl"

    suffix2b = by_suffix(suffix2b_file)
    suffix2rawa = by_suffix(suffix2rawa_file)
    suffix2rawb = by_suffix(suffix2rawb_file)

    # Everything below is independent of the judge position j, so it is done
    # once per sample instead of once per (j, sample).
    samples = []
    for json_item_a in read_jsonl(suffix2a_file):
        suffix = json_item_a['suffix']
        if suffix not in suffix2b:
            continue

        json_item_b = suffix2b[suffix]
        json_item_raw_a = suffix2rawa[suffix]
        json_item_raw_b = suffix2rawb[suffix]
        response_a_list = json_item_a['response']
        response_b_list = json_item_b['response']
        judges = suffix2judge[suffix]
        answer = json_item_a['answer']

        # One ranking score per judge.  A judge whose verdict disagrees with
        # the ordering of the two per-judge scores gets the (negative) gap
        # between them; otherwise it gets the non-preferred side's score
        # minus that side's first raw score.
        response_list = []
        for idx, judge in enumerate(judges):
            _, pred = extract_omni_answer(judge, answer)
            score_a = response_a_list[idx]
            score_b = response_b_list[idx]
            if pred == "A":
                if score_a < score_b:
                    response_list.append(score_a - score_b)
                else:
                    response_list.append(score_b - json_item_raw_b["response"][0])
            else:
                if score_a > score_b:
                    response_list.append(score_b - score_a)
                else:
                    response_list.append(score_a - json_item_raw_a["response"][0])

        # Judge indices, highest ranking score first (ties keep file order).
        index = sorted(range(len(response_list)), key=response_list.__getitem__, reverse=True)
        samples.append((judges, answer, index))

    score_list = []
    count = len(samples)
    for j in range(bound):
        acc = 0
        for judges, answer, index in samples:
            correct, _ = extract_omni_answer(judges[j], answer)
            # Alternative: grade the j-th best-ranked judge instead.
            # correct, _ = extract_omni_answer(judges[index[j]], answer)
            if correct:
                acc += 1

        accuracy = acc / count if count else float("nan")
        score_list.append(accuracy)
        print(acc, count, accuracy)

    print(score_list)

### 
def _order_n():
    """Print judge accuracy by criterion rank for one ``omni_results`` subset.

    Reads ``ranking.jsonl`` and ``judge.jsonl`` from the directory selected
    by ``NAME``.  For each sample the first ``Use_criteria_num`` criteria are
    ranked best-first by a score derived from the two per-criterion rankings
    and the raw rankings; every three consecutive criteria share one judge.
    For each rank ``j`` the judge belonging to the ``j``-th ranked criterion
    is graded with ``extract_omni_answer`` and ``acc count accuracy`` is
    printed, followed by the full list of accuracies.

    Ranking records with fewer than ``Use_criteria_num`` entries on either
    side are reported and dropped.  Samples that cannot be ranked (missing
    judge record, too few judges, non-numeric scores) are skipped and their
    number is reported once.  With no usable sample the accuracy is NaN.
    """
    print("[_order_n] [_order_n] [_order_n] [_order_n]")
    name2file = {
        "chat": "omni_results_direct_new/chat_meta_reward_v1",
        # "chat_hard": "omni_results_direct/chat_hard_meta_reward1",
        "chat_hard": "omni_results_direct/chat_hard_meta_reward",
        "safety": "omni_results_direct/safety_meta_reward",
        "reasoning": "omni_results_direct/reasoning_meta_reward"
    }
    NAME = "chat"
    # NAME = "chat_hard"
    # NAME = "safety"
    # NAME = "reasoning"
    ROOT = name2file[NAME]
    print(f"Look at: [{ROOT}] on subset: [{NAME}]")

    Criteria_step = 3  # author's note: there will be 30 criteria
    Use_criteria_num = Criteria_step*3
    print(f"Criteria_step: [{Criteria_step}]")
    suffix2a = {}
    suffix2b = {}
    suffix2rawa = {}
    suffix2rawb = {}
    suffix2judge = {}
    suffix2answer = {}

    ranking_file = os.path.join(ROOT, "ranking.jsonl")
    judge_file = os.path.join(ROOT, "judge.jsonl")

    # Read the rankings.
    with open(ranking_file, encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            json_item = json.loads(line)
            rank_situation = json_item['ranking_pair']
            too_short = (
                len(rank_situation['ranking_a']) < Use_criteria_num
                or len(rank_situation['ranking_b']) < Use_criteria_num
            )
            if too_short:
                print(f'error length.  < [{Criteria_step*3}] ')
                continue
            sample_id = json_item['paired_data']['id']

            suffix2rawa[sample_id] = rank_situation['ranking_raw_a']
            suffix2rawb[sample_id] = rank_situation['ranking_raw_b']
            suffix2a[sample_id] = rank_situation['ranking_a'][:Use_criteria_num]
            suffix2b[sample_id] = rank_situation['ranking_b'][:Use_criteria_num]
            suffix2answer[sample_id] = json_item['answer']

    # Read the judges.
    with open(judge_file, encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            json_item = json.loads(line)
            suffix2judge[json_item['paired_data']['id']] = json_item['judge']

    def graded_by_rank(sample_id):
        """Return, best-ranked criterion first, whether its judge is correct."""
        json_item_raw_a = suffix2rawa[sample_id]
        json_item_raw_b = suffix2rawb[sample_id]
        response_a_list = suffix2a[sample_id]
        response_b_list = suffix2b[sample_id]
        judges = suffix2judge[sample_id]
        answer = suffix2answer[sample_id]

        response_list = []
        verdicts = []
        # Walk over every criterion of this sample.
        for idx in range(len(response_a_list)):
            # Criteria 0, 1, 2 -> judge 0; criteria 3, 4, 5 -> judge 1; ...
            correct, pred = extract_omni_answer(judges[idx // 3], answer)
            verdicts.append(correct)

            score_a = response_a_list[idx]
            score_b = response_b_list[idx]
            if pred == "A":
                if score_a < score_b:
                    response_list.append(score_a - score_b)
                else:
                    response_list.append(score_b - json_item_raw_b)
            else:
                if score_a > score_b:
                    response_list.append(score_b - score_a)
                else:
                    response_list.append(score_a - json_item_raw_a)

        # Criterion indices, highest score first (ties keep original order).
        index = sorted(range(len(response_list)), key=response_list.__getitem__, reverse=True)
        return [verdicts[i] for i in index]

    # The ranking does not depend on j, so it is computed once per sample.
    graded = {}
    skipped = 0
    for sample_id in suffix2a:
        try:
            graded[sample_id] = graded_by_rank(sample_id)
        except (KeyError, IndexError, TypeError):
            # Best effort: an incomplete or malformed sample is left out of
            # the statistics rather than aborting the whole run.
            skipped += 1
    if skipped:
        print(f"skipped [{skipped}] samples that could not be ranked")

    score_list = []
    count = len(graded)
    for j in range(Use_criteria_num):
        acc = sum(1 for ranked in graded.values() if ranked[j])
        accuracy = acc / count if count else float("nan")
        score_list.append(accuracy)
        print(acc, count, accuracy)

    print(score_list)

def best_of_n(root="/data//GRM-Omni-v1/results/chat_criteria_reward_final/"):
    """Print best-of-k judge accuracy for k = 1 .. ``bound - 1``.

    For each ``k`` and each sample, the first ``k`` judges are given a
    ranking score derived from the scoring files, the highest-scoring judge
    (first one on ties) is graded with ``extract_omni_answer``, and the
    accuracy over all samples is appended to a running list that is printed
    after every ``k``.  Samples whose ``suffix`` does not occur in the
    ``scoring_b`` file are skipped.  With no samples the accuracy is NaN.

    Args:
        root: Result directory holding ``refinment_judge.jsonl`` and the
            scoring files.  The function previously referenced a ``ROOT``
            that was never assigned; the default is one of the candidate
            directories listed below.
    """
    # Other result directories this analysis has been pointed at:
    # "/data//GRM-Omni-v1/results/reasoning_criteria_reward_final/"
    # "/data//GRM-Omni-v1/results/safety_criteria_reward_final/"
    # "/data//GRM-Omni-v1/results/chat_hard_criteria_reward_final/"
    ROOT = root
    bound = 60

    skywork_v1 = False

    def read_jsonl(file_name):
        """Return the parsed records of one JSONL file under ROOT, in file order."""
        with open(os.path.join(ROOT, file_name), encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    def by_suffix(file_name):
        """Index the records of one JSONL file by their 'suffix' field."""
        return {record['suffix']: record for record in read_jsonl(file_name)}

    if skywork_v1:
        suffix2a_file = "skywork_v1_scoring_a.jsonl"
        suffix2b_file = "skywork_v1_scoring_b.jsonl"
        suffix2rawa_file = "skywork_v1_scoring_raw_a.jsonl"
        suffix2rawb_file = "skywork_v1_scoring_raw_b.jsonl"
    else:
        suffix2a_file = "scoring_a.jsonl"
        suffix2b_file = "scoring_b.jsonl"
        suffix2rawa_file = "scoring_raw_a.jsonl"
        suffix2rawb_file = "scoring_raw_b.jsonl"

    # The files do not change between values of k, so they are parsed once.
    all_judges = {
        record['suffix']: record['judge_list']
        for record in read_jsonl("refinment_judge.jsonl")
    }
    suffix2b = by_suffix(suffix2b_file)
    suffix2rawa = by_suffix(suffix2rawa_file)
    suffix2rawb = by_suffix(suffix2rawb_file)
    scoring_a_items = read_jsonl(suffix2a_file)

    score_list = []
    for k in range(1, bound):
        acc = 0
        count = 0
        for json_item_a in scoring_a_items:
            suffix = json_item_a['suffix']
            if suffix not in suffix2b:
                continue

            json_item_b = suffix2b[suffix]
            json_item_raw_a = suffix2rawa[suffix]
            json_item_raw_b = suffix2rawb[suffix]
            response_a_list = json_item_a['response']
            response_b_list = json_item_b['response']
            answer = json_item_a['answer']

            # Only the first k judges compete.
            judges = all_judges[suffix][:k]

            # One ranking score per judge.  A judge whose verdict disagrees
            # with the ordering of the two per-judge scores gets the
            # (negative) gap between them; otherwise it gets the
            # non-preferred side's score minus that side's first raw score.
            response_list = []
            for idx, judge in enumerate(judges):
                _, pred = extract_omni_answer(judge, answer)
                score_a = response_a_list[idx]
                score_b = response_b_list[idx]
                if pred == "A":
                    if score_a < score_b:
                        response_list.append(score_a - score_b)
                    else:
                        response_list.append(score_b - json_item_raw_b["response"][0])
                else:
                    if score_a > score_b:
                        response_list.append(score_b - score_a)
                    else:
                        response_list.append(score_a - json_item_raw_a["response"][0])

            # Index of the highest-scoring judge; the first one wins ties.
            best = max(range(len(response_list)), key=response_list.__getitem__)
            correct, _ = extract_omni_answer(judges[best], answer)

            if correct:
                acc += 1
            count += 1

        score_list.append(acc / count if count else float("nan"))
        print(score_list)

def voting_of_n():
    """Print majority-vote accuracy over the first k judges, for odd k < ``bound``.

    For each odd ``k`` and each sample, the verdicts of the first ``k``
    judges are extracted with ``extract_omni_answer`` and counted.  The
    sample is correct when "A" has strictly more votes and ``answer`` is 0,
    or "B" has strictly more votes and ``answer`` is 1.  The running list of
    accuracies is printed after every ``k``.  Samples whose ``suffix`` does
    not occur in the ``scoring_b`` file are skipped, as in the sibling
    analyses.  With no samples the accuracy is NaN.
    """
    ROOT = "/data//GRM-Omni-v1/results/chat_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/reasoning_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/safety_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/chat_hard_criteria_reward_final/"
    bound = 60

    skywork_v1 = False

    def read_jsonl(file_name):
        """Return the parsed records of one JSONL file under ROOT, in file order."""
        with open(os.path.join(ROOT, file_name), encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    def by_suffix(file_name):
        """Index the records of one JSONL file by their 'suffix' field."""
        return {record['suffix']: record for record in read_jsonl(file_name)}

    if skywork_v1:
        suffix2a_file = "skywork_v1_scoring_a.jsonl"
        suffix2b_file = "skywork_v1_scoring_b.jsonl"
        suffix2rawa_file = "skywork_v1_scoring_raw_a.jsonl"
        suffix2rawb_file = "skywork_v1_scoring_raw_b.jsonl"
    else:
        suffix2a_file = "scoring_a.jsonl"
        suffix2b_file = "scoring_b.jsonl"
        suffix2rawa_file = "scoring_raw_a.jsonl"
        suffix2rawb_file = "scoring_raw_b.jsonl"

    suffix2b = by_suffix(suffix2b_file)
    suffix2rawa = by_suffix(suffix2rawa_file)
    suffix2rawb = by_suffix(suffix2rawb_file)

    # The judge and scoring_a files do not change between values of k, so
    # they are parsed once.
    all_judges = {
        record['suffix']: record['judge_list']
        for record in read_jsonl("refinment_judge.jsonl")
    }
    scoring_a_items = read_jsonl(suffix2a_file)

    score_list = []
    # Odd k only, so a two-way vote cannot tie when every judge answers A or B.
    for k in range(1, bound, 2):
        acc = 0
        count = 0

        for json_item_a in scoring_a_items:
            suffix = json_item_a['suffix']
            if suffix not in suffix2b:
                continue

            json_item_b = suffix2b[suffix]
            json_item_raw_a = suffix2rawa[suffix]
            json_item_raw_b = suffix2rawb[suffix]
            response_a_list = json_item_a['response']
            response_b_list = json_item_b['response']
            answer = json_item_a['answer']

            judges = all_judges[suffix][:k]

            # One verdict and one ranking score per judge.  A judge whose
            # verdict disagrees with the ordering of the two per-judge scores
            # gets the (negative) gap between them; otherwise it gets the
            # non-preferred side's score minus that side's first raw score.
            verdicts = []
            response_list = []
            for idx, judge in enumerate(judges):
                _, pred = extract_omni_answer(judge, answer)
                verdicts.append(pred)
                score_a = response_a_list[idx]
                score_b = response_b_list[idx]
                if pred == "A":
                    if score_a < score_b:
                        response_list.append(score_a - score_b)
                    else:
                        response_list.append(score_b - json_item_raw_b["response"][0])
                else:
                    if score_a > score_b:
                        response_list.append(score_b - score_a)
                    else:
                        response_list.append(score_a - json_item_raw_a["response"][0])

            index = sorted(range(len(response_list)), key=response_list.__getitem__, reverse=True)

            # NOTE(review): only the first k judges were ranked, so the top-k
            # of that ranking is all of them and the order cannot affect the
            # vote.  If the vote was meant to use the k best of *all* judges,
            # the ranking has to run over the unsliced judge list - confirm.
            pred_list = [verdicts[i] for i in index[:k]]
            # pred_list = verdicts[:k]

            count_a = pred_list.count("A")
            count_b = pred_list.count("B")

            if count_a > count_b and answer == 0:
                acc += 1
            elif count_a < count_b and answer == 1:
                acc += 1
            count += 1

        score_list.append(acc / count if count else float("nan"))
        print(score_list)


def test(root="/data//GRM-Omni-v1/results/chat_criteria_reward_final/"):
    """Print how often the raw scores alone identify the labelled answer.

    For every record of ``scoring_raw_b.jsonl`` the first raw score of side B
    is compared with the first raw score of the same ``suffix`` in
    ``scoring_raw_a.jsonl``.  The record is correct when A scores strictly
    higher and ``answer`` is 0, or B scores strictly higher and ``answer`` is
    1.  Prints ``acc count accuracy``; with no records the accuracy is NaN.

    Args:
        root: Result directory holding the two raw scoring files.  The
            function previously referenced a ``ROOT`` that was never
            assigned; the default is one of the candidate directories listed
            below.
    """
    # Other result directories this check has been pointed at:
    # "/data//GRM-Omni-v1/results/image_general_criteria_reward_final/"
    # "/data//GRM-Omni-v1/results/reasoning_criteria_reward_final/"
    # "/data//GRM-Omni-v1/results/safety_criteria_reward_final/"
    # "/data//GRM-Omni-v1/results/chat_hard_criteria_reward_final/"
    ROOT = root

    def read_jsonl(file_name):
        """Return the parsed records of one JSONL file under ROOT, in file order."""
        with open(os.path.join(ROOT, file_name), encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    suffix2a = {record['suffix']: record for record in read_jsonl("scoring_raw_a.jsonl")}

    acc = 0
    count = 0
    for json_item in read_jsonl("scoring_raw_b.jsonl"):
        a_score = suffix2a[json_item['suffix']]['response'][0]
        b_score = json_item['response'][0]
        if a_score > b_score and json_item['answer'] == 0:
            acc += 1
        elif a_score < b_score and json_item['answer'] == 1:
            acc += 1
        count += 1

    print(acc, count, acc / count if count else float("nan"))


def consistency_of_n():
    """Print how often some judge agrees with the raw-score ordering.

    A sample counts as consistent when at least one entry of its
    ``critique_list`` yields verdict "A" while the first raw score of A is
    strictly higher than that of B, or verdict "B" while it is strictly
    lower.  Samples whose ``suffix`` does not occur in the ``scoring_b`` file
    are skipped.  Prints ``acc count accuracy``; with no samples the
    accuracy is NaN.
    """
    ROOT = "/data//GRM-Omni-v1/results/chat_hard_meta_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/image_general_criteria_reward_final2/"
    # ROOT = "/data//GRM-Omni-v1/results/chat_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/reasoning_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/safety_criteria_reward_final/"
    # ROOT = "/data//GRM-Omni-v1/results/chat_hard_criteria_reward_final/"

    skywork_v1 = False

    def read_jsonl(file_name):
        """Return the parsed records of one JSONL file under ROOT, in file order."""
        with open(os.path.join(ROOT, file_name), encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    def by_suffix(file_name):
        """Index the records of one JSONL file by their 'suffix' field."""
        return {record['suffix']: record for record in read_jsonl(file_name)}

    suffix2judge = {
        record['suffix']: record['critique_list']
        for record in read_jsonl("refinment_judge.jsonl")
    }

    if skywork_v1:
        suffix2a_file = "skywork_v1_scoring_a.jsonl"
        suffix2b_file = "skywork_v1_scoring_b.jsonl"
        suffix2rawa_file = "skywork_v1_scoring_raw_a.jsonl"
        suffix2rawb_file = "skywork_v1_scoring_raw_b.jsonl"
    else:
        suffix2a_file = "scoring_a.jsonl"
        suffix2b_file = "scoring_b.jsonl"
        suffix2rawa_file = "scoring_raw_a.jsonl"
        suffix2rawb_file = "scoring_raw_b.jsonl"

    # scoring_b is only needed to decide which samples take part.
    suffix2b = by_suffix(suffix2b_file)
    suffix2rawa = by_suffix(suffix2rawa_file)
    suffix2rawb = by_suffix(suffix2rawb_file)

    acc = 0
    count = 0
    for json_item_a in read_jsonl(suffix2a_file):
        suffix = json_item_a['suffix']
        if suffix not in suffix2b:
            continue

        raw_a = suffix2rawa[suffix]["response"][0]
        raw_b = suffix2rawb[suffix]["response"][0]

        for judge in suffix2judge[suffix]:
            _, pred = extract_omni_answer(judge, json_item_a['answer'])
            agrees = (pred == "A" and raw_a > raw_b) or (pred == "B" and raw_a < raw_b)
            if agrees:
                # One agreeing judge is enough; stop looking.
                acc += 1
                break

        # Each sample contributes exactly once to the denominator, and only
        # the agreement test above feeds the numerator.
        count += 1

    print(acc, count, acc / count if count else float("nan"))


# Script entry point: runs one analysis per invocation.  Switch analyses by
# commenting/uncommenting the calls below; each one prints its own results.
if __name__ == '__main__':
    # _order_n()
    order_n()
    # consistency_of_n()
    # test()
    # best_of_n()
    # voting_of_n()
