import argparse
import json
import os
from tqdm import tqdm
import tools
import math

import random
random.seed(42)
# Command-line interface: every option is mandatory and restricted to the
# values listed in its `choices` (except --score_name, which is free-form).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--train_model',
    type=str,
    required=True,
    choices=[
        'models--deepseek-ai--DeepSeek-R1-0528-Qwen3-8B',
        'models--Qwen--Qwen3-8B',
        'models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B',
        "models--meta-llama--Llama-3.1-8B-Instruct",
    ],
)
parser.add_argument(
    '--subset',
    type=str,
    required=True,
    choices=["train_ANS", "train_NOANS", "combine"],
)
parser.add_argument(
    '--traindataset',
    type=str,
    required=True,
    choices=[
        "datasets--allenai--reward-bench-2",
        "datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG",
    ],
)
parser.add_argument('--score_name', type=str, required=True)

args = parser.parse_args()

TRAIN_MODEL = args.train_model
print(TRAIN_MODEL)

# Upper bound on the number of pairs kept per call of construct_partial_orders.
EACH_SAMPLE_NUM = 1
# Number of output_{i} files read per subset directory.
SAMPLE_NUM = 20

# Prefix put in front of the subset name in the output directory; only the
# hotpotqa RAG dataset gets a non-empty one.
save_traindataset = (
    "RAG"
    if args.traindataset == "datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG"
    else ""
)


def zscore(data_points, section="score"):
    """Return copies of ``data_points`` with ``section`` replaced by its z-score.

    The mean and the population standard deviation (divide by ``len``) are
    computed over ``d[section]`` of all entries. When the standard deviation
    is exactly zero every entry receives ``0``. The input dicts are not
    modified; each returned dict is a shallow copy.

    Args:
        data_points: Sequence of dicts that all contain ``section``.
        section: Key whose numeric value is normalised.

    Returns:
        A new list of dicts, in the same order as ``data_points``.
    """
    values = [point[section] for point in data_points]
    count = len(values)
    mean = sum(values) / count
    variance = sum((value - mean) ** 2 for value in values) / count
    std = math.sqrt(variance)

    normalised = []
    for point, value in zip(data_points, values):
        updated = point.copy()
        # A constant column carries no ranking information: map it to 0.
        updated[section] = 0 if std == 0 else (value - mean) / std
        normalised.append(updated)
    return normalised


def construct_partial_orders(data_points, method):
    """Pick up to ``EACH_SAMPLE_NUM`` ordered (first, second) pairs of entries.

    All ordered pairs of distinct entries are enumerated together with the
    score gap ``first['score'] - second['score']``. The pairs are then
    filtered by the correctness pattern named in ``method`` (checked in this
    order, first match wins):

    * ``"T+T"``: both entries have a truthy ``correctness[0]``;
    * ``"T+F"``: the first is truthy and the second is falsy;
    * ``"F+F"``: both are falsy;
    * anything else: no filtering.

    If ``method`` contains ``"RLVR"`` the surviving pairs are shuffled with
    the module-level ``random`` state; otherwise they are sorted by score gap,
    largest first.

    Args:
        data_points: List of dicts with ``'score'`` and ``'correctness'`` keys.
        method: Method name; matched by substring.

    Returns:
        List of ``(first, second)`` tuples, at most ``EACH_SAMPLE_NUM`` long.
    """
    candidates = [
        (first, second, first['score'] - second['score'])
        for a, first in enumerate(data_points)
        for b, second in enumerate(data_points)
        if a != b
    ]

    def is_correct(point):
        return point['correctness'][0]

    if "T+T" in method:
        candidates = [c for c in candidates
                      if is_correct(c[0]) and is_correct(c[1])]
    elif "T+F" in method:
        candidates = [c for c in candidates
                      if is_correct(c[0]) and not is_correct(c[1])]
    elif "F+F" in method:
        candidates = [c for c in candidates
                      if not (is_correct(c[0]) or is_correct(c[1]))]

    if "RLVR" in method:
        random.shuffle(candidates)
    else:
        # Stable sort: equal gaps keep their enumeration order.
        candidates.sort(key=lambda c: c[2], reverse=True)

    return [(first, second)
            for first, second, _ in candidates[:EACH_SAMPLE_NUM]]


def parse_weight_expr(expr: str):
    """Parse a ``weight*name+weight*name`` expression into ``[weight, name]`` pairs.

    The expression is split on ``+``; each term is split on its first ``*``.
    Whitespace around the weight and the name is ignored.

    Args:
        expr: Expression such as ``"0.5*score_a+0.5*score_b"``.

    Returns:
        A list of ``[float_weight, score_name]`` lists, in term order.

    Raises:
        ValueError: If a term has no ``*`` or its weight is not a number.
            The message names the offending term so that a mistyped score
            name is easy to spot.
    """
    result = []
    for part in expr.split('+'):
        w_str, sep, name = part.partition('*')
        if not sep:
            raise ValueError(
                f"Invalid term {part!r} in score expression {expr!r}: "
                "expected 'weight*score_name'")
        try:
            weight = float(w_str.strip())
        except ValueError as err:
            raise ValueError(
                f"Invalid weight {w_str.strip()!r} in term {part!r} "
                f"of score expression {expr!r}") from err
        result.append([weight, name.strip()])
    return result


def get_all_pairs(json_list):
    """Build (higher, lower) pairs for every method matching ``args.score_name``.

    Each file in ``json_list`` is an eval JSON that must contain a
    ``correctness`` list with one entry per sample. The per-sample score of
    each file is chosen from ``args.score_name``:

    * ``"RLVR"``: every score is 0;
    * a key present in the eval JSON: that list is used as-is;
    * otherwise the name is parsed with ``parse_weight_expr``; for each sample
      every named component is z-scored across the files and the weighted sum
      becomes the score.

    Args:
        json_list: Paths of the eval JSON files.

    Returns:
        Dict mapping each entry of ``tools.METHODS`` whose text before ``|``
        equals ``args.score_name`` to a list of tuples ``(high_name,
        low_name, sample_index, high_score, low_score, high_correctness,
        low_correctness)``.
    """
    all_data = []
    # Parsed [weight, name] terms; stays None unless the score has to be
    # assembled from several components. Using this single flag keeps the
    # loader and the per-sample loop in agreement (a one-term expression
    # without '+' previously left every score as None).
    weighted_terms = None
    for json_name in json_list:
        with open(json_name, 'r', encoding='utf-8') as f:
            data = json.load(f)
        tmp = {
            'name': json_name,
            'correctness': data['correctness'],
        }
        if args.score_name == "RLVR":
            tmp['score'] = [0 for _ in range(len(data['correctness']))]
        elif args.score_name in data.keys():
            tmp['score'] = data[args.score_name]
        else:
            if weighted_terms is None:
                weighted_terms = parse_weight_expr(args.score_name)
            # Placeholders, filled in sample by sample below.
            tmp['score'] = [None for _ in range(len(data['correctness']))]
            for weight, score_name in weighted_terms:
                tmp[score_name] = {
                    "data": data[score_name], "weight": weight}
        all_data.append(tmp)

    # The first file defines how many samples are processed.
    N = len(all_data[0]['correctness'])
    result = {each: []
              for each in tools.METHODS if args.score_name == each.split('|')[0]}
    print(result)
    for i in tqdm(range(N)):
        if weighted_terms is not None:
            data_point = []
            for entry in all_data:
                point = {
                    'name': entry['name'],
                    'correctness': entry['correctness'][i],
                }
                for _, score_name in weighted_terms:
                    point[score_name] = entry[score_name]['data'][i]
                data_point.append(point)
            # Normalise each component across the files for this sample.
            for _, score_name in weighted_terms:
                data_point = zscore(data_point, score_name)
            for entry, point in zip(all_data, data_point):
                combined_score = 0.0
                for weight, score_name in weighted_terms:
                    combined_score += weight * point[score_name]
                entry['score'][i] = combined_score

        data_point = [
            {
                'name': entry['name'],
                'score': entry['score'][i],
                'correctness': entry['correctness'][i],
            }
            for entry in all_data
        ]
        for each in result:
            for high, low in construct_partial_orders(data_point, each):
                result[each].append((
                    high['name'], low['name'], i,
                    high['score'], low['score'],
                    high['correctness'], low['correctness'],
                ))

    for each, found in result.items():
        # A method may legitimately yield no pairs; avoid dividing by zero.
        if found:
            winrate = f"{sum(p[5][0] for p in found) / len(found):.4f}"
        else:
            winrate = "n/a"
        print(f"Method {each} found {len(found)} pairs",
              f"winrate is {winrate}")
    return result


# Lookup tables between generated-output / eval file paths and a dense
# integer id. With --subset combine the train_ANS files take ids
# 0..SAMPLE_NUM-1 and the train_NOANS files the next SAMPLE_NUM ids.
outputname2id = {}
evalname2id = {}

if args.subset in ["train_ANS", "train_NOANS"]:
    source_subsets = [args.subset]
elif args.subset == "combine":
    source_subsets = ["train_ANS", "train_NOANS"]
else:
    source_subsets = []

runs_per_subset = SAMPLE_NUM
for subset_pos, subset_name in enumerate(source_subsets):
    for run in range(runs_per_subset):
        file_id = subset_pos * runs_per_subset + run

        output_path = f"{tools.machine_pather()}/works/DPO/output/{args.traindataset}/{subset_name}/{TRAIN_MODEL}/output_{run}.jsonl"
        outputname2id[output_path] = file_id

        eval_path = f"{tools.machine_pather()}/works/DPO/output/{args.traindataset}/{subset_name}/{TRAIN_MODEL}/output_{run}_eval.json"
        evalname2id[eval_path] = file_id

if args.subset == "combine":
    # Two subsets were registered, so twice as many files are in play.
    SAMPLE_NUM *= 2


id2outputname = {file_id: path for path, file_id in outputname2id.items()}
id2evalname = {file_id: path for path, file_id in evalname2id.items()}

# Eval files ordered by id; position in this list equals the file id.
all_eval_jsons = [id2evalname[file_id] for file_id in range(SAMPLE_NUM)]


all_pairs = get_all_pairs(all_eval_jsons)
# NOTE(review): prompts are always read from train_NOANS.jsonl, whatever
# --subset is; confirm this is intended for train_ANS / combine.
input_jsonl_path = f"{tools.machine_pather()}/works/DPO/data/{args.traindataset}/train_NOANS.jsonl"
all_input_jsons = tools.read_jsonl(input_jsonl_path)

# Generated outputs, indexed by file id and then by sample index.
all_output_jsons_list = [
    tools.read_jsonl(id2outputname[file_id])
    for file_id in tqdm(range(SAMPLE_NUM))
]


def _short_run_label(eval_path):
    """Strip the shared directory prefix and file suffix from an eval path (for logging)."""
    for fragment in (
        f"{tools.machine_pather()}/works/DPO/output/{args.traindataset}/",
        f"{TRAIN_MODEL}/output_",
        "_eval.json",
    ):
        eval_path = eval_path.replace(fragment, "")
    return eval_path


for method, pairs in all_pairs.items():
    print(f"Method: {method}, Pairs: {len(pairs)}, examples:")
    for sample in pairs[:5]:
        print(_short_run_label(sample[0]), _short_run_label(sample[1]),
              sample[2], sample[3], sample[4], sample[5][0], sample[6][0])

    # Parallel lists: chosen[k] and rejected[k] share the same user prompt.
    data_dict = {'chosen': [], 'rejected': []}
    for high_name, low_name, idx, _, _, _, _ in tqdm(pairs):
        high_id = evalname2id[high_name]
        low_id = evalname2id[low_name]
        user_content = all_input_jsons[idx]['content']
        high_output = all_output_jsons_list[high_id][idx]['text']
        low_output = all_output_jsons_list[low_id][idx]['text']
        for bucket, answer in (('chosen', high_output), ('rejected', low_output)):
            data_dict[bucket].append([
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": answer},
            ])

    out_dir = f"./train_data/{TRAIN_MODEL}/{save_traindataset}{args.subset}"
    os.makedirs(out_dir, exist_ok=True)
    rl_path = f'{out_dir}/rl_data_{method}.json'
    sft_path = f'{out_dir}/sft_from_rl_data_{method}.jsonl'
    swift_path = f'{out_dir}/swift_rl_data_{method}.jsonl'

    # 1) chosen/rejected conversations as one JSON document.
    with open(rl_path, 'w', encoding='utf-8') as f:
        json.dump(data_dict, f, ensure_ascii=False, indent=2)

    # 2) chosen conversations only, one record per write_jsonl call. The file
    #    is emptied first (presumably write_jsonl appends; verify in tools).
    with open(sft_path, 'w', encoding='utf-8'):
        pass
    for conversation in tqdm(data_dict['chosen']):
        tools.write_jsonl({"messages": conversation}, sft_path)

    # 3) messages + rejected_response records, built from the JSON document
    #    written in step 1.
    with open(rl_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    with open(swift_path, 'w'):
        pass
    for chosen_msgs, rejected_msgs in tqdm(list(zip(data['chosen'], data['rejected']))):
        record = {
            "messages": list(chosen_msgs),
            "rejected_response": rejected_msgs[1]['content'],
        }
        tools.write_jsonl(record, swift_path)

    print("="*20)
