import argparse
import random
import jsonlines

import numpy as np
from collections import defaultdict
from common.utils import async_http_process_requests, simple_promptify, list_to_string
from common.model_configs import config_model, config_panda, config_aliyun
from itertools import combinations, product

# Pairwise judge prompt: asks the model which of two unclear questions has
# unclearness that is harder to identify. Rendered with str.format using
# qa / reason_of_unclearness_a (Question A) and qb / reason_of_unclearness_b
# (Question B). The reply is requested to contain an "Analysis:" line followed
# by a "Judge:" line carrying "[[A]]" or "[[B]]".
# NOTE: the trailing backslash on the first instruction line is a line
# continuation inside the string, so that sentence and the "Consider the
# principle..." sentence are joined on one line in the rendered prompt.
# NOTE: because this is a str.format template, any literal braces added to the
# text must be doubled ({{ }}).
# NOTE(review): the wording has grammatical slips ("the reasons that why the
# questions is unclear"); left verbatim because any edit changes the model input.
chat_template = """
Below are two unclear questions and the reasons that why the questions is unclear. Compare these two questions and determine for which question the identification of its unclearness is inherently more difficult. \
Consider the principle that the missing information affecting later stages of multi-step reasoning is harder to identify.

Avoid any position biases and ensure that the order in which the questions are presented does not influence your decision.
Do not favor certain names of the Questions. Be as objective as possible.

# Question A

## Question

{qa}

## Reason of Unclearness

{reason_of_unclearness_a}

# Question B

## Question

{qb}

## Reason of Unclearness

{reason_of_unclearness_b}

# Instruction
    
Output your answer by strictly following this format:
Analysis: [Your analysis about evaluating the difficulties]
Judge: "[[A]]" if the unclearness of Question A is more difficult to identify, otherwise "[[B]]"
"""


def parse_res(res):
    """Extract the judge's verdict from a model response.

    The verdict markers are "[[A]]" and "[[B]]". When both appear, the marker
    that occurs *last* wins. This covers a model that restates the output
    format or mentions one option in its analysis before the final verdict.
    The prompt asks for the ``Judge:`` line after the ``Analysis:`` line, so
    the last marker is the verdict. Checking "[[A]]" first would instead bias
    every such response toward A.

    Args:
        res: Raw response text. ``None`` or an empty string is treated as
            unparseable.

    Returns:
        0 if the verdict is "[[A]]", 1 if it is "[[B]]", -1 if neither marker
        is present.
    """
    if not res:
        return -1
    pos_a = res.rfind("[[A]]")
    pos_b = res.rfind("[[B]]")
    if pos_a == -1 and pos_b == -1:
        return -1
    # Both positions can only be equal when both are -1 (handled above), so
    # the later marker decides unambiguously.
    return 0 if pos_a > pos_b else 1


def _group_by_raw_task(data):
    """Group question records by their ``metadata.metadata.raw_task`` value.

    Each record is reduced to the fields used downstream. ``cls`` is True when
    the record's ``answer`` contains "clarification" (case-insensitive);
    presumably this marks answers that asked for clarification.

    Args:
        data: Iterable of dicts with an ``answer`` string and a nested
            ``metadata`` dict (``unclear_task``, ``reason_of_unclearness``,
            ``metadata.raw_task``, ``metadata.level``).

    Returns:
        defaultdict mapping raw task -> list of reduced records, in input order.
    """
    merge_dict = defaultdict(list)
    for item in data:
        meta = item['metadata']
        merge_dict[meta['metadata']['raw_task']].append({
            'unclear_task': meta['unclear_task'],
            'reason_of_unclearness': meta['reason_of_unclearness'],
            'cls': 'clarification' in item['answer'].lower(),
            'level': meta['metadata']['level'],
        })
    return merge_dict


def _non_cls_by_level(merge_dict):
    """Collect, per level, the records whose answer did not mention clarification.

    Raw tasks with a single record are skipped, as are raw tasks whose records
    all have ``cls`` set.

    Raises:
        ValueError: if the records of one raw task carry different levels.
    """
    by_level = defaultdict(list)
    for raw_task, variants in merge_dict.items():
        if len(variants) == 1 or all(v['cls'] for v in variants):
            continue
        levels = {v['level'] for v in variants}
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(levels) != 1:
            raise ValueError(f"Inconsistent levels {levels} for raw task {raw_task!r}")
        by_level[variants[0]['level']].extend(v for v in variants if not v['cls'])
    return by_level


def _build_pairwise_prompts(by_level, reverse, num_pairs):
    """Build judge prompts pairing a lower-level question with a higher-level one.

    Levels are sorted and every pair of distinct levels is visited in
    ascending order. The lower level is treated as "easy" and the higher as
    "hard"; this assumes level values sort by difficulty (e.g. "Level 1" <
    "Level 5"), so confirm it holds for the dataset. For each level pair, up
    to ``num_pairs`` shuffled (easy, hard) question pairs are formatted into
    ``chat_template``.

    Args:
        by_level: Mapping level -> list of question records.
        reverse: If False, the easy question is A and the expected label is 1
            ([[B]]). If True, the hard question is A and the label is 0 ([[A]]).
        num_pairs: Maximum number of question pairs per level pair.

    Returns:
        (prompts, labels, metadatas). Each metadata string is
        "<easy_level>-<hard_level>".
    """
    prompts, labels, metadatas = [], [], []
    for easy_key, hard_key in combinations(sorted(by_level), 2):
        q_tuples = list(product(by_level[easy_key], by_level[hard_key]))
        # Consumes the module RNG in the same order as before, so a fixed seed
        # reproduces the same sample.
        random.shuffle(q_tuples)
        for easy_q, hard_q in q_tuples[:num_pairs]:
            first, second = (hard_q, easy_q) if reverse else (easy_q, hard_q)
            prompts.append(chat_template.format(
                qa=first['unclear_task'], reason_of_unclearness_a=first['reason_of_unclearness'],
                qb=second['unclear_task'], reason_of_unclearness_b=second['reason_of_unclearness'],
            ))
            labels.append(0 if reverse else 1)
            metadatas.append(f"{easy_q['level']}-{hard_q['level']}")
    return prompts, labels, metadatas


def _report_accuracy(data_to_save):
    """Print every judged item, then the accuracy per level pair.

    An item is correct when its parsed verdict equals its label. Unparseable
    verdicts (-1) count as incorrect and are additionally reported, so they
    are visible rather than silently lowering accuracy.
    """
    acc_level = defaultdict(list)
    unparsed = defaultdict(int)
    for item in data_to_save:
        print(item)
        acc_level[item['metadata']].append(item['label'] == item['diff_res'])
        if item['diff_res'] == -1:
            unparsed[item['metadata']] += 1
    for key, hits in acc_level.items():
        print(f'Level: {key}, Accuracy: {np.mean(hits)*100:.2f}%')
        if unparsed[key]:
            print(f'Level: {key}, Unparsed responses: {unparsed[key]}')


def main():
    """Run the pairwise difficulty-judging evaluation and print accuracies.

    Reads the JSONL file given by ``--input_file`` and builds lower-level vs
    higher-level prompt pairs (order flipped with ``--reverse``). It sends
    them to the model configured below and prints per-level-pair accuracy of
    the parsed verdicts against the expected labels.
    """
    random.seed(42)
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--reverse', action='store_true')
    parser.add_argument('--num_pairs', type=int, default=10,
                        help='maximum question pairs sampled per level pair')
    args = parser.parse_args()
    if args.num_pairs < 0:
        parser.error('--num_pairs must be non-negative')

    # Context manager so the reader's file handle is closed after loading.
    with jsonlines.open(args.input_file) as reader:
        data = list(reader)

    by_level = _non_cls_by_level(_group_by_raw_task(data))
    prompts, labels, metadatas = _build_pairwise_prompts(by_level, args.reverse, args.num_pairs)
    if not prompts:
        print('No question pairs to evaluate.')
        return

    model_config = config_model(config_aliyun, 'deepseek-r1', 0.6, 10)
    requests = simple_promptify(prompts)
    responses = async_http_process_requests(requests, model_config)
    data_to_save = [{
        'diff_res': parse_res(res[0]),
        'diff_evaluate': res[0],
        'label': label,
        'metadata': metadata,
    } for res, label, metadata in zip(responses, labels, metadatas)]
    _report_accuracy(data_to_save)


if __name__ == '__main__':
    main()
