import argparse
import random
import jsonlines

import numpy as np
from collections import defaultdict
from common.utils import async_http_process_requests, simple_promptify, list_to_string
from common.model_configs import config_model, config_aliyun
from itertools import product

# Prompt template for the pairwise judge: it shows two unclear questions
# (slots A and B) together with the stated reason each one is unclear, and asks
# which unclearness is harder to identify.
# Filled with str.format using the placeholders {qa}, {reason_of_unclearness_a},
# {qb} and {reason_of_unclearness_b}. The requested verdict markers "[[A]]" /
# "[[B]]" are the strings parse_res looks for in the model's reply.
chat_template = """
Below are two unclear questions and the respective reasons for their unclearness. Compare these two questions and determine for which question the identification of its unclearness is inherently more difficult.
Factors to consider in your evaluation include the degree to which the unclearness impedes the answering process and the principle that unclearness affecting later stages of multi-step reasoning are typically harder to pinpoint.

# Question A

## Unclear Question

{qa}

## Reason of Unclearness

{reason_of_unclearness_a}

# Question B

## Unclear Question

{qb}

## Reason of Unclearness

{reason_of_unclearness_b}

# Instruction

Output your answer by strictly following this format:
Analysis: [Your analysis about evaluating the difficulties]
Judge: "[[A]]" if the unclearness of Question A is more difficult to identify, otherwise "[[B]]"
"""


def parse_res(res):
    """Extract the judge's verdict from a model response.

    The prompt asks for an "Analysis:" section followed by a final "Judge:"
    line, so when both markers appear (e.g. the analysis mentions one option
    before the verdict names the other) the marker that occurs LAST in the
    text is taken as the verdict.

    Args:
        res: Raw response text. Anything that is not a string (e.g. ``None``)
            is treated as unparseable.

    Returns:
        0 if the verdict is "[[A]]", 1 if it is "[[B]]", and -1 when neither
        marker is present or ``res`` is not a string.
    """
    if not isinstance(res, str):
        return -1

    # rfind yields -1 for a missing marker, so a present marker always
    # compares greater than an absent one.
    last_a = res.rfind("[[A]]")
    last_b = res.rfind("[[B]]")
    if last_a == -1 and last_b == -1:
        return -1
    return 0 if last_a > last_b else 1


def main():
    """Build pairwise judge prompts from a JSONL file, query the model, report accuracy.

    Steps:
      1. Load the JSONL file given by ``--input_file`` and group its records
         by ``metadata.metadata.raw_task``.
      2. Inside each group, pair every record whose ``answer`` mentions
         "clarification" with every record whose answer does not, keep at
         most five randomly chosen pairs, and render each pair with
         ``chat_template`` in a random A/B order.
      3. Send the prompts through ``async_http_process_requests`` using the
         'deepseek-r1' model configuration and parse each reply with
         ``parse_res``.
      4. Optionally write the per-pair results to ``--output_file`` (JSONL)
         and print the accuracy.

    A pair's label is the slot (0 = A, 1 = B) that holds the record whose
    answer did NOT mention "clarification"; this uses the same 0/1 encoding
    that ``parse_res`` returns for "[[A]]" / "[[B]]".
    """
    random.seed(42)  # fixed seed so pair sampling and A/B ordering are reproducible
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, default=None,
                        help='Optional JSONL path for the per-pair judge results.')
    args = parser.parse_args()

    # Read everything up front and close the file handle deterministically.
    with jsonlines.open(args.input_file) as reader:
        data = list(reader)

    # Group the variants that were derived from the same raw task.
    merge_dict = defaultdict(list)
    for item in data:
        merge_dict[item['metadata']['metadata']['raw_task']].append({
            'unclear_task': item['metadata']['unclear_task'],
            'reason_of_unclearness': item['metadata']['reason_of_unclearness'],
            # True when the recorded answer mentions "clarification"
            # (presumably the answerer asked for clarification -- confirm upstream).
            'cls': 'clarification' in item['answer'].lower(),
            'level': item['metadata']['metadata']['level'],
        })

    prompts, labels, metadatas = [], [], []
    for vs in merge_dict.values():
        cls_set = [v for v in vs if v['cls']]
        non_cls_set = [v for v in vs if not v['cls']]
        # A group is only usable when it contains both kinds of record; this
        # also rules out single-record groups.
        if not cls_set or not non_cls_set:
            continue
        all_tuples = list(product(cls_set, non_cls_set))
        random.shuffle(all_tuples)
        all_tuples = all_tuples[:5]  # cap the number of pairs per raw task
        for (cls_q, non_cls_q) in all_tuples:
            # Randomise which slot each record occupies so the slot itself
            # carries no information about the label.
            if random.random() < 0.5:
                prompt = chat_template.format(
                    qa=cls_q['unclear_task'], reason_of_unclearness_a=cls_q['reason_of_unclearness'],
                    qb=non_cls_q['unclear_task'], reason_of_unclearness_b=non_cls_q['reason_of_unclearness']
                )
                label = 1  # non_cls_q sits in slot B
            else:
                prompt = chat_template.format(
                    qa=non_cls_q['unclear_task'], reason_of_unclearness_a=non_cls_q['reason_of_unclearness'],
                    qb=cls_q['unclear_task'], reason_of_unclearness_b=cls_q['reason_of_unclearness']
                )
                label = 0  # non_cls_q sits in slot A
            metadatas.append([cls_q, non_cls_q, label])
            prompts.append(prompt)
            labels.append(label)

    if not prompts:
        # Previously this fell through to prompts[0] and raised IndexError.
        print('No task has both clarification and non-clarification variants; nothing to evaluate.')
        return

    # Show one rendered prompt and wait for Enter before issuing any requests.
    print(prompts[0])
    input()

    model_config = config_model(config_aliyun, 'deepseek-r1', 0.6, 250)
    requests = simple_promptify(prompts)
    responses = async_http_process_requests(requests, model_config)
    # Each response is indexed with [0]; assumes the helper returns one
    # sequence of generations per request, in request order -- TODO confirm.
    data_to_save = [{
        'diff_res': parse_res(res[0]),
        'diff_evaluate': res[0],
        'label': label,
        'metadata': metadata,
    } for res, label, metadata in zip(responses, labels, metadatas)]

    if args.output_file:
        with jsonlines.open(args.output_file, mode='w') as writer:
            writer.write_all(data_to_save)

    if not data_to_save:
        # Avoid np.mean([]) which yields nan together with a RuntimeWarning.
        print('No responses received; accuracy is undefined.')
        return

    # Unparseable verdicts (diff_res == -1) never equal a label, so they
    # count as wrong.
    acc = np.mean([row['diff_res'] == row['label'] for row in data_to_save])
    print(f'Accuracy: {acc*100:.2f}%')


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
