import re
import json
import random
import itertools
import statistics


def build_long_omni_prompt(prompt, chosen, rejected):
    """Render the pairwise-judging prompt for one query and two answers.

    Args:
        prompt: The user question, inserted under ``[User Question]``.
        chosen: The answer shown as Assistant A.
        rejected: The answer shown as Assistant B.

    Returns:
        The filled-in evaluator prompt. The instructions ask the judge to
        finish with ``[[A]]`` or ``[[B]]``.
    """
    # The template is used verbatim (including its trailing spaces); only the
    # three named slots below are substituted.
    template = """You are a fair, professional, and neutral multimodal AI evaluator.  \nYou are tasked with evaluating two different multimodal responses (which may include Text, Image, Video, and Audio) generated for the same user query, and determining which one is better.   \nBased on the overall analysis, clearly determine which response is superior.   \n\n\n###Important Notes:\nStay completely neutral: Do not be influenced by the order, length, writing style, or the assistant’s name.       \nDo not favor responses simply because they are longer or use more elaborate language.   \n\n\n###Final Output Format\nAfter completing the full analysis, you must output the final verdict in the following format:\nIf you believe Assistant A’s response is better, output: [[A]]\nIf you believe Assistant B’s response is better, output: [[B]]. \n\n\n###Input Format\n[User Question]\n{instruction}\n\n[The Start of Assistant A's Answer]\n{response_1}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{response_2}\n[The End of Assistant B's Answer]\n"""
    slots = {
        "instruction": prompt,
        "response_1": chosen,
        "response_2": rejected,
    }
    return template.format(**slots)


def extract_omni_answer(response, answer):
    """Pull the last bracketed verdict out of a judge response and grade it.

    A verdict is a single letter A, B or C (either case) wrapped in one or
    two square brackets, e.g. ``[[A]]`` or ``[b]``. When several verdicts
    appear, the last one wins.

    Args:
        response: Text produced by the judge.
        answer: Expected label index; 0 pairs with A, 1 with B, 2 with C.

    Returns:
        A ``(correct, pred)`` tuple. ``pred`` is the upper-cased verdict
        letter, or ``None`` when no verdict is found. ``correct`` is ``True``
        only when ``pred`` corresponds to ``answer``.
    """
    verdicts = re.findall(r'\[{1,2}([abcABC])\]{1,2}', response)
    pred = verdicts[-1].upper() if verdicts else None

    expected_index = {'A': 0, 'B': 1, 'C': 2}
    correct = pred is not None and bool(expected_index[pred] == answer)
    return correct, pred

def best_of_n(root="/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/chat/", top_k=9):
    """Print majority-vote accuracy over the highest-scored judgments.

    For every record in ``scoring.jsonl`` the judgments stored under the same
    ``suffix`` in ``refinment_judge.jsonl`` are ranked by their score
    (``record['response']``, descending). The verdicts of the ``top_k``
    best-scored judgments are tallied, and the record counts as a hit when the
    strict majority (A vs. B) matches ``record['answer']`` (0 for A, 1 for B).
    Ties never count as hits.

    Args:
        root: Directory prefix holding the two JSON-lines files. It is joined
            by plain string concatenation, so it must end with a path
            separator.
        top_k: Number of top-scored judgments that take part in the vote.

    Prints:
        ``hits total accuracy``; accuracy is reported as 0.0 when the scoring
        file contains no records.
    """
    suffix2judge = {}
    with open(root + "refinment_judge.jsonl") as judge_file:
        for line in judge_file:
            record = json.loads(line)
            suffix2judge[record['suffix']] = record['judge_list']

    acc = 0
    count = 0
    with open(root + "scoring.jsonl") as scoring_file:
        for line in scoring_file:
            record = json.loads(line)
            scores = record['response']
            answer = record['answer']
            # Positions of the judgments, best score first.
            ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

            pred_list = [
                extract_omni_answer(suffix2judge[record['suffix']][pos], answer)[1]
                for pos in ranked[:top_k]
            ]
            count_a = pred_list.count("A")
            count_b = pred_list.count("B")

            if count_a > count_b and answer == 0:
                acc += 1
            elif count_a < count_b and answer == 1:
                acc += 1

            count += 1

    # Guard the ratio so an empty scoring file reports 0.0 instead of crashing.
    print(acc, count, acc / count if count else 0.0)

# best_of_n()

# from transformers import AutoProcessor, AutoTokenizer, pipeline
# tokenizer = AutoTokenizer.from_pretrained("/home/export/base/ycsc_1/1/online1/hf_models/Qwen3-4B")


_SFT_DEFAULT_ROOTS = (
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/language_5k/",
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/language_10k/",
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/language_20k/",
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/",
)
_SFT_DEFAULT_OUTPUT = "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/meta_criteria_sft.jsonl"

# Known ways a judgment spells out its final verdict. Groups are tried in
# order; the first group with any member present is the one that is stripped.
_SFT_VERDICT_MARKERS = (
    ("\n**Final Decision: [[A]]**", "\n**Final Decision: [[B]]**"),
    ("$$\n\\boxed{[[A]]}\n$$", "$$\n\\boxed{[[B]]}\n$$"),
    ("### [[A]]", "### [[B]]"),
    ("**Better Response**: **[[A]]**", "**Better Response**: **[[B]]**"),
    ("$$\n\\text{[[A]]}\n$$", "$$\n\\text{[[B]]}\n$$"),
    ("*Better Response: [[A]]**", "*Better Response: [[B]]**"),
    ("**Better Response**: [[A]]", "**Better Response**: [[B]]"),
    ("$$\n\\boxed{\\text{[[A]]}}\n$$", "$$\n\\boxed{\\text{[[B]]}}\n$$"),
    ("$$[A]$$", "$$[B]$$"),
    ("**Final Answer: [[A]]**", "**Final Answer: [[B]]**"),
    ("**Final Decision**: [[A]]", "**Final Decision**: [[B]]"),
    ("**Final decision: [[A]]**", "**Final decision: [[B]]**"),
    ("**Final Judgment: [[A]]**", "**Final Judgment: [[B]]**"),
    ("**Final Tag**: `[[A]]`", "**Final Tag**: `[[B]]`"),
    ("Assistant A: [[A]]", "Assistant A: [[B]]"),
    ("Final Answer: [[A]]", "Final Answer: [[B]]"),
    ("\n[[A]]", "\n[[B]]", "\n`[[A]]`", "\n`[[B]]`", "\n**[[A]]**", "\n**[[B]]**"),
)


def _sft_read_jsonl(path):
    """Return the parsed records of the JSON-lines file at *path*, closing it afterwards."""
    with open(path) as handle:
        return [json.loads(line) for line in handle]


def _sft_plain_media_tags(text):
    """Replace the ``<audio>``/``<image>``/``<video>`` tags in *text* with the bare words."""
    return text.replace("<audio>", "audio").replace("<image>", "image").replace("<video>", "video")


def _sft_clean_criterion(criterion):
    """Drop a leading 1-10 list number (and its punctuation) from *criterion*."""
    return re.sub(r'^(10|[1-9])[.:、\s-]*', '', criterion)


def _sft_strip_verdict_marker(text):
    """Remove the final-verdict marker from *text*; return None if no known marker is present."""
    for markers in _SFT_VERDICT_MARKERS:
        if any(marker in text for marker in markers):
            for marker in markers:
                text = text.replace(marker, "")
            return text
    return None


def build_criteria_sft(roots=_SFT_DEFAULT_ROOTS, output_path=_SFT_DEFAULT_OUTPUT):
    """Build the SFT conversation file from scored criteria/judgment data.

    Each directory in ``roots`` must contain ``criteria.jsonl``,
    ``refinment_judge.jsonl`` and ``scoring.jsonl``. For every scoring record
    the judgments are visited from highest to lowest score. A judgment is kept
    when its verdict (A or B) matches the record's answer, a known verdict
    marker can be stripped from it, and its score beats the record's mean
    score by a margin that grows with each rationale already kept. Records
    that yield at least three rationales become one user/assistant
    conversation whose reply is the first three rationales followed by
    ``The Final Judgment is [[A]]`` or ``[[B]]``.

    Args:
        roots: Directory prefixes to read. File names are appended by plain
            string concatenation, so each must end with a path separator.
        output_path: Destination JSON-lines file; it is overwritten.

    Side effects:
        Prints per-directory load statistics, the top-1 judgment accuracy and
        the number of written samples; shuffles the samples with ``random``;
        writes ``output_path``.
    """
    acc = 0
    count = 0
    paired_data = []

    for root in roots:
        suffix2criteria = {
            item['suffix']: [_sft_clean_criterion(c) for c in item['criteria_list']]
            for item in _sft_read_jsonl(root + "criteria.jsonl")
        }
        print("suffix2criteria", len(suffix2criteria))

        # `error` counts judgments that carry neither [[A]] nor [[B]].
        error = 0
        suffix2judge = {}
        for item in _sft_read_jsonl(root + "refinment_judge.jsonl"):
            suffix2judge[item['suffix']] = item['judge_list']
            error += sum(1 for j in item['judge_list'] if "[[A]]" not in j and "[[B]]" not in j)
        print("suffix2judge", len(suffix2judge), error)

        for item in _sft_read_jsonl(root + "scoring.jsonl"):
            scores = item['response']
            answer = item['answer']
            suffix = item['suffix']
            judges = suffix2judge[suffix]

            # Average over however many scores the record actually has.
            mean = sum(scores) / len(scores)
            # Positions of the judgments, best score first.
            ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

            correct, _ = extract_omni_answer(judges[ranked[0]], answer)
            if correct:
                acc += 1
            count += 1

            best_list = []
            # `pos` is already a position into scores/judges/criteria; indexing
            # `ranked` with it again would scramble the descending-score order.
            for pos in ranked:
                correct, pred = extract_omni_answer(judges[pos], answer)
                if not (correct and pred in ("A", "B")):
                    continue

                rationale = suffix2criteria[suffix][pos].strip() + "\n" + judges[pos].strip()
                rationale = _sft_strip_verdict_marker(_sft_plain_media_tags(rationale))
                if rationale is None:
                    # No recognised verdict marker to strip; skip this judgment.
                    continue

                # Required margin above the mean: 0 for the first accepted
                # rationale, then 1/2, 2/3, ... for the following ones.
                if float(scores[pos]) - mean > (1 - 1 / (1 + len(best_list))):
                    best_list.append(rationale.strip())

            if len(best_list) <= 2:
                continue

            verdict = "[[A]]" if answer == 0 else "[[B]]"
            response = "\n".join(best_list[:3]).strip() + "\nThe Final Judgment is " + verdict
            # Drop the sample only when neither marker occurs exactly once.
            # NOTE(review): a stray opposite marker left in the body still
            # passes this check; confirm that is intended.
            if response.count("[[A]]") != 1 and response.count("[[B]]") != 1:
                continue
            # str.replace returns a new string, so the result must be kept.
            response = _sft_plain_media_tags(response)

            if answer == 0:
                prompt = build_long_omni_prompt(item['prompt'], item['chosen'], item['rejected'])
            else:
                prompt = build_long_omni_prompt(item['prompt'], item['rejected'], item['chosen'])
            prompt = _sft_plain_media_tags(prompt)

            paired_data.append({
                "conversations": [
                    {
                        "role": "user",
                        "content": prompt
                    },
                    {
                        "role": "assistant",
                        "content": response
                    }
                ]
            })

    # Guard the ratio so an empty input reports 0.0 instead of crashing.
    print(acc, count, acc / count if count else 0.0)
    random.shuffle(paired_data)
    print(len(paired_data))

    written = 0
    with open(output_path, 'w') as fw:
        for sample in paired_data:
            written += 1
            fw.write(json.dumps(sample) + "\n")
    print(written)


# Module-level call: the SFT dataset build runs whenever this file is executed
# or imported.
build_criteria_sft()

_COLD_START_DEFAULT_ROOTS = (
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/language_5k/",
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/language_10k/",
    "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/criteria/",
)
_COLD_START_DEFAULT_OUTPUT = "/home/export/base/ycsc_1/1/online1//GRM-Omni/data/trainset/language/meta_criteria.jsonl"


def _cold_start_read_jsonl(path):
    """Return the parsed records of the JSON-lines file at *path*, closing it afterwards."""
    with open(path) as handle:
        return [json.loads(line) for line in handle]


def _cold_start_plain_media_tags(text):
    """Replace the ``<audio>``/``<image>``/``<video>`` tags in *text* with the bare words."""
    return text.replace("<audio>", "audio").replace("<image>", "image").replace("<video>", "video")


def _cold_start_pair(prompt, chosen, rejected):
    """Build one preference record: a user prompt with a chosen and a rejected reply."""
    return {
        "conversations": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "chosen": {
            "role": "assistant",
            "content": chosen
        },
        "rejected": {
            "role": "assistant",
            "content": rejected
        }
    }


def _cold_start_report(label, values):
    """Print mean/variance/max/min of *values*, using None where a statistic is undefined."""
    # statistics.mean needs one value and statistics.variance needs two.
    print(label + "Mean:", statistics.mean(values) if values else None)
    print(label + "Variance:", statistics.variance(values) if len(values) > 1 else None)
    print(label + "Max:", max(values, default=None))
    print(label + "Min:", min(values, default=None))


def build_criteria_cold_start_data(roots=_COLD_START_DEFAULT_ROOTS, output_path=_COLD_START_DEFAULT_OUTPUT):
    """Build chosen/rejected preference pairs from scored criteria/judgment data.

    Each directory in ``roots`` must contain ``criteria.jsonl``,
    ``refinment_judge.jsonl`` and ``scoring.jsonl``. For every scoring record
    the judgments are visited from highest to lowest score and split into
    those whose A/B verdict matches the record's answer ("best") and those
    whose A/B verdict does not ("worst"). Up to three pairs are emitted per
    record:

    * top-scored best vs. lowest-scored best, when there are at least two
      best judgments;
    * top-scored best vs. lowest-scored worst, and top-scored best vs.
      top-scored worst, when there are at least two worst judgments and the
      top best outscores the top worst.

    Args:
        roots: Directory prefixes to read. File names are appended by plain
            string concatenation, so each must end with a path separator.
        output_path: Destination JSON-lines file; it is overwritten.

    Side effects:
        Prints load statistics, score-gap and raw-score summaries, the top-1
        judgment accuracy and the number of written pairs; shuffles the pairs
        with ``random``; writes ``output_path``.
    """
    acc = 0
    count = 0
    paired_data = []
    paired_score = []
    all_score = []

    for root in roots:
        suffix2criteria = {}
        for item in _cold_start_read_jsonl(root + "criteria.jsonl"):
            # Drop a leading 1-10 list number from every criterion.
            suffix2criteria[item['suffix']] = [
                re.sub(r'^(10|[1-9])[.:、\s-]*', '', c) for c in item['criteria_list']
            ]
        print("suffix2criteria", len(suffix2criteria))

        # `error` counts judgments that carry neither [[A]] nor [[B]].
        error = 0
        suffix2judge = {}
        for item in _cold_start_read_jsonl(root + "refinment_judge.jsonl"):
            suffix2judge[item['suffix']] = item['judge_list']
            error += sum(1 for j in item['judge_list'] if "[[A]]" not in j and "[[B]]" not in j)
        print("suffix2judge", len(suffix2judge), error)

        for item in _cold_start_read_jsonl(root + "scoring.jsonl"):
            scores = item['response']
            answer = item['answer']
            suffix = item['suffix']
            judges = suffix2judge[suffix]

            # Positions of the judgments, best score first.
            ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

            correct, _ = extract_omni_answer(judges[ranked[0]], answer)
            if correct:
                acc += 1
            count += 1

            best_list, best_score = [], []
            worst_list, worst_score = [], []
            # `pos` is already a position into scores/judges/criteria; indexing
            # `ranked` with it again would scramble the descending-score order
            # that the [0] / [-1] picks below rely on.
            for pos in ranked:
                score = float(scores[pos])
                all_score.append(score)
                correct, pred = extract_omni_answer(judges[pos], answer)
                if pred not in ("A", "B"):
                    continue

                text = suffix2criteria[suffix][pos].strip() + "\n" + judges[pos].strip()
                text = _cold_start_plain_media_tags(text)
                if correct:
                    best_list.append(text)
                    best_score.append(score)
                else:
                    worst_list.append(text)
                    worst_score.append(score)

            if answer == 0:
                prompt = build_long_omni_prompt(item['prompt'], item['chosen'], item['rejected'])
            else:
                prompt = build_long_omni_prompt(item['prompt'], item['rejected'], item['chosen'])
            prompt = _cold_start_plain_media_tags(prompt)

            if len(best_list) > 1:
                paired_data.append(_cold_start_pair(prompt, best_list[0], best_list[-1]))
                paired_score.append(best_score[0] - best_score[-1])

            if best_list and len(worst_list) > 1 and (best_score[0] - worst_score[0]) > 0:
                # Lowest-scored wrong judgment first, then the top-scored one.
                for k in (-1, 0):
                    paired_data.append(_cold_start_pair(prompt, best_list[0], worst_list[k]))
                    paired_score.append(best_score[0] - worst_score[k])

    _cold_start_report("", paired_score)
    _cold_start_report("All_Score ", all_score)

    # Guard the ratio so an empty input reports 0.0 instead of crashing.
    print(acc, count, acc / count if count else 0.0)
    random.shuffle(paired_data)
    print(len(paired_data))

    written = 0
    with open(output_path, 'w') as fw:
        for pair in paired_data:
            written += 1
            fw.write(json.dumps(pair) + "\n")
    print(written)

# build_criteria_cold_start_data()
