# Licensed under the MIT license.
import json
import os
import re
import sys

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

sys.path.append(".")

from common.utils import fix_seeds, read_json, read_txt
from eval_src.Evaluator import *
from run_src.utils import concat_solution_trace, mask_solution_trace
from argparse import ArgumentParser

from datetime import datetime

# Matches the Chinese phrase "答案是" ("the answer is"), then skips any run of
# characters that are not uppercase ASCII letters (including newlines), and
# captures the first uppercase ASCII letter that follows.
_ANSWER_LETTER_PATTERN = re.compile(r'答案是[^A-Z]*([A-Z])')


def extract_first_uppercase_after_answer(text):
    """Return the first uppercase ASCII letter that follows "答案是" in *text*.

    Unrelated characters between the phrase and the letter are allowed.

    Args:
        text: The solution text to scan.

    Returns:
        The captured single-letter string, or the literal string ``"None"``
        (not the ``None`` object) when the phrase is absent or no uppercase
        letter follows it. Callers can therefore always treat the result as
        a ``str``.
    """
    match = _ANSWER_LETTER_PATTERN.search(text)
    if match is None:
        return "None"
    return match.group(1)
from collections import Counter

def most_frequent_element(lst):
    """Return ``(element, count)`` for the most common item in *lst*.

    An empty *lst* yields ``(None, 0)``.
    """
    if len(lst) > 0:
        # most_common(1) is a one-item list holding the top (element, count) pair.
        top_pair = Counter(lst).most_common(1)[0]
        return top_pair[0], top_pair[1]
    return None, 0
def second_most_frequent_element(data):
    """Return ``(element, count)`` for the second most common item in *data*.

    Yields ``(None, None)`` when *data* is empty/falsy or contains fewer than
    two distinct elements, since no runner-up exists in that case.
    """
    if data:
        # Full ranking of (element, count) pairs, most frequent first.
        ranking = Counter(data).most_common()
        if len(ranking) >= 2:
            runner_up, occurrences = ranking[1]
            return runner_up, occurrences
    return None, None

# Loaded SentenceTransformer instances keyed by checkpoint path, so the model is
# constructed once per process instead of once per calculate_diversity() call.
_EMBEDDING_MODEL_CACHE = {}


def _get_embedding_model(model_path):
    """Return the SentenceTransformer for *model_path*, loading it on first use."""
    model = _EMBEDDING_MODEL_CACHE.get(model_path)
    if model is None:
        model = SentenceTransformer(model_path)
        _EMBEDDING_MODEL_CACHE[model_path] = model
    return model


def calculate_diversity(solutions_list, model_path='checkpoints/lier007__xiaobu-embedding-v2'):
    """Compute an average "diversity" score for each answer option.

    Every solution trace is embedded, and its diversity is defined as
    ``1 - mean(cosine similarity of the trace against all traces)``. Traces
    are then grouped by the option letter extracted from them and the
    per-trace diversities are averaged within each group.

    Args:
        solutions_list: Sequence of solution trace strings.
        model_path: Checkpoint path handed to ``SentenceTransformer``.
            Defaults to the checkpoint this script has always used.

    Returns:
        Dict mapping each extracted option (the result of
        ``extract_first_uppercase_after_answer``, so possibly the string
        ``"None"``) to its mean diversity. Empty when *solutions_list* is
        empty.
    """
    # Nothing to embed or compare; avoid calling the model on empty input.
    if len(solutions_list) == 0:
        return {}

    options = [extract_first_uppercase_after_answer(trace) for trace in solutions_list]

    model = _get_embedding_model(model_path)
    embeddings = model.encode(solutions_list, convert_to_tensor=True)
    embeddings = embeddings.cpu().numpy()

    similarity_matrix = cosine_similarity(embeddings)

    # Collect the per-trace diversity values under the trace's option.
    per_option = {}
    for i, option in enumerate(options):
        # NOTE(review): row i includes the trace's similarity with itself, so
        # this is the mean over *all* traces, not only the other ones.
        avg_similarity = np.mean(similarity_matrix[i])
        per_option.setdefault(option, []).append(1 - avg_similarity)

    # Collapse each option's list of diversities into its mean.
    return {option: np.mean(values) for option, values in per_option.items()}

def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = ArgumentParser()
    parser.add_argument("--note", type=str, default="default")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--api", type=str, default="vllm")
    parser.add_argument("--model_ckpt", type=str, required=True)
    parser.add_argument("--root_dir", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--resume", type=str, default=None)

    parser.add_argument("--threshold", type=float, default=0.999)

    # vLLM
    parser.add_argument("--max_num_seqs", type=int, default=256)

    # For multi-choice
    parser.add_argument("--multi_choice_prompt_type", type=str, default=None, choices=["fewshot", "instruct"])

    # For reasoning consistency
    parser.add_argument("--mask_left_boundary", type=float, default=0.2)
    parser.add_argument("--mask_right_boundary", type=float, default=0.5)
    parser.add_argument("--num_masked_solution_traces", type=int, default=4)
    parser.add_argument("--rc_mode", type=str, default="mid", choices=["loose", "mid", "strict", "maj"])
    parser.add_argument("--rc_temperature", type=float, default=1.0)
    parser.add_argument("--rc_n_completions", type=int, default=1)
    parser.add_argument("--rc_criteria", type=str, default="reward", choices=["freq", "reward"])

    # For rollout
    parser.add_argument("--cutoff_rollout", type=int, default=-1)
    parser.add_argument("--start_idx", type=int, default=-1)
    parser.add_argument("--end_idx", type=int, default=-1)

    return parser.parse_args()


def _problem_id_from_path(answer_js_file):
    """Parse the integer problem id out of an answer-sheet file path.

    Takes the base name up to its first ".", strips the " - Answer" and
    "Question " markers, and converts what is left to ``int``.
    """
    stem = os.path.basename(answer_js_file).split(".")[0]
    return int(stem.replace(" - Answer", "").replace("Question ", ""))


def _collect_answered_traces(trace_js):
    """Return ``(answers, solutions)`` for traces that yield an option letter.

    Traces for which no uppercase letter can be extracted are dropped: the
    extractor returns the string "None" for them, which fails ``isupper()``.
    """
    answers = []
    solutions = []
    for t in trace_js:
        trace = concat_solution_trace(t["trace"])[0]
        ext_answer = extract_first_uppercase_after_answer(trace)
        if ext_answer.isupper():
            solutions.append(trace)
            answers.append(ext_answer)
    return answers, solutions


def _pick_eval_answer(answers, solutions):
    """Return ``(majority_answer, eval_answer)`` for a non-empty *answers* list.

    The eval answer is the majority answer unless the majority holds less than
    0.667 of the votes and the runner-up answer has a strictly higher
    diversity score, in which case the runner-up is selected.
    """
    most_ans, most_fre = most_frequent_element(answers)
    eval_ans = most_ans
    if most_fre / len(answers) < 0.667:
        sed_ans, _ = second_most_frequent_element(answers)
        # Only pay for the embedding-based diversity when a runner-up exists.
        if sed_ans is not None:
            diver = calculate_diversity(solutions)
            if diver[most_ans] < diver[sed_ans]:
                eval_ans = sed_ans
    return most_ans, eval_ans


def _build_dpo_pairs(problem, gold_answer, answers, solutions):
    """Pair correct and incorrect traces of one problem into DPO records.

    Traces whose answer equals *gold_answer* are "chosen", all others are
    "rejected"; they are paired up positionally, so the number of records is
    the smaller of the two group sizes (zero if either group is empty).
    """
    chosen = []
    rejected = []
    for answer, solution in zip(answers, solutions):
        if answer == gold_answer:
            chosen.append(solution)
        else:
            rejected.append(solution)

    if not chosen or not rejected:
        return []

    prompt = problem + "\n请一步步分析。"
    return [
        {
            "conversations": [
                {
                    "from": "human",
                    "value": prompt,
                }
            ],
            "chosen": {
                "from": "gpt",
                "value": good,
            },
            "rejected": {
                "from": "gpt",
                "value": bad,
            },
        }
        for good, bad in zip(chosen, rejected)
    ]


def main():
    """Score answer sheets by majority vote with a diversity tie-break.

    Steps:
      1. Parse CLI arguments, fix seeds and create the output directories
         under ``--root_dir``.
      2. List the ``*Answer.json`` files in ``<root_dir>/answer_sheets``
         (optionally sliced by ``--start_idx`` / ``--end_idx``).
      3. For each file either reuse a previously stored result (``--resume``)
         or read the matching "Final Solutions" traces, extract the option
         letter of every trace and update three counters: gold answer
         present among the candidates, majority vote correct, and
         diversity-adjusted selection correct.
      4. Build chosen/rejected DPO pairs from correct/incorrect traces.
      5. Print the counters and their rates over the number of files.

    NOTE(review): the collected DPO pairs are not written anywhere by the
    code visible here — confirm whether persisting them is intended.
    """
    args = _parse_args()

    args.fewshot_config_path = os.path.join("prompts", args.dataset_name, "fewshot_cot", "fewshot_cot_config.json")
    args.fewshot_prompt_path = os.path.join("prompts", args.dataset_name, "fewshot_cot", "fewshot_cot_prompt.txt")

    fix_seeds(args.seed)
    print(args)

    answer_sheets_dir = os.path.join(args.root_dir, "answer_sheets")
    if args.resume:
        exp_id = args.resume
    else:
        exp_id = f"dis_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}---{args.note}"

    discriminate_out_dir = os.path.join(args.root_dir, exp_id)
    os.makedirs(discriminate_out_dir, exist_ok=True)
    args.discriminate_results_dir = os.path.join(discriminate_out_dir, "results")
    os.makedirs(args.discriminate_results_dir, exist_ok=True)

    # NOTE(review): eval() on a command-line value executes arbitrary code;
    # prefer an explicit name -> class mapping. The instance is not used below;
    # the call is kept so the evaluator is still constructed as before.
    eval(f"{args.dataset_name}Evaluator()")

    #! ------ Select winner candidate for each example ------
    answer_sheet_json_files = [
        os.path.join(answer_sheets_dir, f) for f in os.listdir(answer_sheets_dir) if f.endswith("Answer.json")
    ]
    answer_sheet_json_files.sort()
    if args.start_idx > -1 and args.end_idx > -1:
        answer_sheet_json_files = answer_sheet_json_files[args.start_idx : args.end_idx]

    num_correct, num_correct_majvote, num_correct_limit, num_tested = 0, 0, 0, 0
    num_eval_correct = 0
    dpo_zh_data = []
    print(len(answer_sheet_json_files))
    # Pause for confirmation only when attached to a terminal; an unconditional
    # input() raises EOFError when the script runs without an interactive stdin.
    if sys.stdin is not None and sys.stdin.isatty():
        input(">")

    with tqdm(total=len(answer_sheet_json_files), disable=True) as pbar:
        for file_idx, answer_js_file in enumerate(answer_sheet_json_files):
            problem_id = _problem_id_from_path(answer_js_file)
            result_file = os.path.join(args.discriminate_results_dir, f"problem-{problem_id}.json")

            if args.resume and os.path.exists(result_file):
                # This problem was already processed: reuse its stored outcome.
                print(f"\n[Skip file {file_idx}; Total number of files: {len(answer_sheet_json_files)}]\n")
                with open(result_file, "r") as f:
                    temp_recording = json.load(f)

                num_correct += int(temp_recording["correct"])
                num_correct_majvote += int(temp_recording["correct_majvote"])
                num_correct_limit += int(temp_recording["correct_limit"])
                num_tested += 1

                info = f"Acc: {num_correct / num_tested:.4f}; Majority vote acc: {num_correct_majvote / num_tested:.4f}; Limit acc: {num_correct_limit / num_tested:.4f}"
                print(info)
                pbar.set_description(info, refresh=True)
                continue

            print(f"\n[Processing file {file_idx}; Total number of files: {len(answer_sheet_json_files)}]\n")
            try:
                answer_js = read_json(answer_js_file)
            except Exception as err:
                # Best effort: an unreadable answer sheet is skipped, not fatal.
                print(f"[Warning] Could not read {answer_js_file}: {err!r}; skipping.")
                continue

            try:
                problem = answer_js["problem"]
                gold_answer = answer_js["gold_answer"]
            except (KeyError, TypeError) as err:
                # Skip instead of silently reusing the previous file's values.
                print(f"[Warning] Missing field in {answer_js_file}: {err!r}; skipping.")
                continue

            trace_js = read_json(answer_js_file.replace("Answer", "Final Solutions"))
            answers, solutions = _collect_answered_traces(trace_js)

            if gold_answer in answers:
                num_correct += 1
                most_ans, eval_ans = _pick_eval_answer(answers, solutions)
                num_correct_majvote += int(gold_answer == most_ans)
                num_eval_correct += int(eval_ans == gold_answer)

            dpo_zh_data.extend(_build_dpo_pairs(problem, gold_answer, answers, solutions))

    print("in_num_correct:{},majvote_num_correct:{},eval_num_correct:{}".format(num_correct,num_correct_majvote,num_eval_correct))

    num_files = len(answer_sheet_json_files)
    if num_files == 0:
        # No denominator available; avoid a ZeroDivisionError.
        print("No answer sheets found; correct rates are undefined.")
        return
    print("in_correct_rate:{},majvote_correct_rate:{},eval_correct_rate:{}".format(num_correct/num_files,num_correct_majvote/num_files,(num_eval_correct)/num_files))
   
    
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
