import re
import json
import copy
import random
import argparse
import numpy as np
from tqdm import tqdm
from rouge_score import rouge_scorer
from collections import defaultdict
from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from concurrent.futures import ThreadPoolExecutor, as_completed

from config import *
from models import *

def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    :return: argparse.Namespace holding the pipeline stage, target country, model name,
        optional comparison countries, output directory, batch size, representativeness
        scoring scheme and evaluation decoding settings.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--stage', dict(type=int, default=1)),
        ('--country', dict(type=str, default='China')),
        ('--model_name', dict(type=str, default='gpt-3.5-turbo')),
        ('--compare_countries', dict(type=str, default=None)),
        ('--output_dir', dict(type=str, default='../data')),
        ('--batch_size', dict(type=int, default=10)),
        ('--rep_score_type', dict(type=str, default='Majority', choices=['Majority', 'CCT', 'Debate'])),
        ('--eval_max_tokens', dict(type=int, default=100)),
        ('--eval_temp', dict(type=float, default=0.7)),
        ('--eval_top_p', dict(type=float, default=0.8)),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def parse_gpt_questions(text):
    """Parse generated questions written as ``[question N]: <text> (<kind> question)``.

    :param text: str, raw model output.
    :return: list of {"question": str, "type": str} in order of appearance; lines that do
        not follow the format are ignored.
    """
    question_re = re.compile(r"\[question \d+\]: (.+?) \((.+? question)\)")
    parsed = []
    for found in question_re.finditer(text):
        body, kind = found.groups()
        parsed.append({"question": body.strip(), "type": kind.strip()})
    return parsed

def compute_cosine_similarity(vec_array, vec):
    """
    Cosine similarity of ``vec`` against every row of ``vec_array``.

    :param vec_array: array-like, shape (n_samples, n_features)
    :param vec: array-like, shape (n_features,)
    :return: list of floats, one similarity per row of ``vec_array``
    """
    from sklearn.metrics.pairwise import cosine_similarity

    rows = np.array(vec_array)
    query = np.array(vec).reshape(1, -1)  # a single-row matrix for the pairwise API
    similarity_column = cosine_similarity(rows, query)  # shape (n_samples, 1)
    return similarity_column.flatten().tolist()

def parse_cultural_points(text):
    """Extract numbered cultural points from ``text`` and embed them.

    Only lines starting with a number followed by a period (e.g. ``"1. Point description"``)
    are kept, with the numeric prefix removed; lines that are empty after removing the
    prefix are dropped. All points are embedded with ``text-embedding-3-small`` in a
    single embeddings request instead of one request per point.

    :param text: str, model output listing one cultural point per line.
    :return: list of {"point": str, "embed": list[float]} in the order the points appear;
        an empty list (and no API call) when no numbered point is found.
    """
    points = []
    for line in text.split("\n"):
        line = line.strip()
        if not line or not re.match(r"^\d+\.", line):
            continue
        point = line.split(".", 1)[-1].strip()  # remove the number prefix
        if point:  # a bare "3." has nothing to embed, and the API rejects empty input
            points.append(point)
    if not points:
        return []

    client = OpenAI(api_key = API_KEYS["openai"])
    data = client.embeddings.create(input = points, model="text-embedding-3-small").data
    # Each returned item records the position of its input; sort so embeddings align with points.
    embeds = [item.embedding for item in sorted(data, key=lambda item: item.index)]
    return [{"point": point, "embed": embed} for point, embed in zip(points, embeds)]

def majority_consensus(raw_consensus_matrix):
    """Score each cultural point by the fraction of participants who agree with it.

    :param raw_consensus_matrix: np.ndarray, shape (n_participants, n_points), 0/1 agreement.
    :return: dict mapping every point index (0..n_points-1) to its mean agreement.
    """
    n_participants = float(raw_consensus_matrix.shape[0])
    agreement = raw_consensus_matrix.sum(axis=0) / n_participants
    return {point_idx: agreement[point_idx] for point_idx in range(raw_consensus_matrix.shape[1])}

def culture_consensus_theory(raw_consensus_matrix: np.ndarray) -> dict:
    '''
    Score cultural points with a PCA-based cultural-consensus heuristic.

    Points that at least 90% of participants agree with get score 1.0. In addition, the
    leading principal components of the column-centred agreement matrix are kept (stop once
    the cumulative explained-variance ratio reaches 0.8, or at the first component whose
    predecessor explains >= 3x as much variance). For each kept component, the 5 points with
    the largest absolute loading are scored as the agreement of all participants weighted by
    their shifted, sum-to-one component scores; a point picked by several components keeps
    its maximum score.

    input: raw_consensus_matrix: np.ndarray, shape (n_participants, n_cultural_points), with values 0 or 1 indicating whether they agree with the cultural point.
    output: final_consensus: defaultdict(int), mapping cultural point index (int) to its consensus score.
        The mapping is sparse: only representative points and points agreed by >= 90% of
        participants are present, so callers must not assume keys 0..n_cultural_points-1.
    '''
    ### collect points (columns) agreed by at least 90% of participants
    all_agreed_points = []
    for i in range(raw_consensus_matrix.shape[1]):
        if np.sum(raw_consensus_matrix[:, i]) >= raw_consensus_matrix.shape[0] * 0.9:  # >= 90% of participants agree (not strictly all)
            all_agreed_points.append(i)

    ## compute the consensus score for each cultural point
    consensus_matrix = raw_consensus_matrix - np.mean(raw_consensus_matrix, axis=0, keepdims=True) # centre each column; no variance scaling
    pca = PCA()
    pca.fit(consensus_matrix)

    eigenvalues = pca.explained_variance_
    explained_variance_ratio = pca.explained_variance_ratio_
    # loadings: (n_cultural_points, n_components), component directions scaled by sqrt(eigenvalue)
    loadings = pca.components_.T * np.sqrt(eigenvalues)
    # scores: (n_participants, n_components), each participant's projection on each component
    scores = pca.transform(consensus_matrix)
    # Decide how many leading components to keep.
    # NOTE(review): when the eigen-gap test fires at idx, `count` already includes component idx
    # (the one *after* the gap) — confirm this is intended rather than stopping at idx.
    count, cummulative_variance = 0, 0
    for idx, variance in enumerate(explained_variance_ratio):
        cummulative_variance += variance
        count += 1
        if idx > 0 and explained_variance_ratio[idx-1] / variance >= 3:
            break
        if cummulative_variance >= 0.8: 
            break

    # For each kept component, the 5 points with the largest absolute loading are its representatives.
    principal_components = loadings[:, :count]
    representative_points = []
    for i in range(principal_components.shape[1]):
        sorted_indices = np.argsort(np.abs(principal_components[:, i]))[::-1]
        representative_points.append(sorted_indices[:5])

    # Turn each component's participant scores into weights: shift so the minimum is ~0, then
    # normalise to sum to 1. participant_scores is a slice view of `scores`, so this writes into it.
    participant_scores = scores[:, :count]
    for i in range(participant_scores.shape[1]):
        tmp = participant_scores[:, i] - np.min(participant_scores[:, i]) + 1e-6
        participant_scores[:, i] = tmp / np.sum(tmp)
    
    # Score of a representative point = participant-weighted share of agreement with it.
    representative_point_scores = []
    for i in range(len(representative_points)):
        point_scores = []
        for idx in representative_points[i]: # point indices, largest |loading| first
            point_score = np.sum(participant_scores[:, i] * raw_consensus_matrix[:, idx])
            point_scores.append(point_score)
        representative_point_scores.append(point_scores)
    
    # Keep the best score per point across components; .get(idx) with a numpy integer finds the
    # int(idx) key because numpy integers hash and compare equal to Python ints.
    final_consensus = defaultdict(int)
    for points, point_scores in zip(representative_points, representative_point_scores):
        for idx, score in zip(points, point_scores):
            final_consensus[int(idx)] = max(final_consensus.get(idx, 0), score)
    
    # Near-unanimous points override any PCA-derived score.
    for idx in all_agreed_points:
        final_consensus[int(idx)] = 1.0
    return final_consensus

def compute_answer_consensus(args, question_batch, round_idx, generator):
    """Score how representative each clustered cultural point is for ``args.country``.

    For every question in ``question_batch``:

    1. Pool the cultural points stored under ``response_{i}_cultural_points`` (i in 0..5)
       for this round and cluster their embeddings with average-linkage agglomerative
       clustering (cosine distance threshold 0.3).
    2. Ask 5 demographic personas (the first 5 WVS profiles of the country), 5 randomly
       sampled cultural-expert roles and 3 randomly sampled foreign countries to rate every
       cluster; ratings are parsed from lines like ``Point k: <1-5>``.
    3. Binarise the ratings (>= 4 counts as agreement) and aggregate them per cluster with
       ``args.rep_score_type``.

    Writes ``consensus_point_str``, ``likert_consensus`` and ``consensus_scores`` into
    ``question[f"round_{round_idx}_content"]``.

    :return: (input_words, output_words), whitespace-token counts of judge prompts and responses.
    :raises ValueError: if ``args.rep_score_type`` is neither "Majority" nor "CCT".
    """
    # Fail before spending model calls: "Debate" is accepted by the CLI but has no
    # implementation here, and previously left `final_consensus` unbound.
    if args.rep_score_type not in ("Majority", "CCT"):
        raise ValueError(f"Unsupported rep_score_type {args.rep_score_type!r}; expected 'Majority' or 'CCT'")

    input_words, output_words = 0, 0
    round_key = f"round_{round_idx}_content"

    ## compute the consensus score for each question
    country_list = ["United States", "United Kingdom", "Germany", "Italy", "South Korea", "China", "Japan", "India", "Singapore", "Indonesia", "Russia", "Poland"]
    own_country = args.country.replace("_", " ")
    if own_country in country_list:
        country_list.remove(own_country)

    with open(os.path.join(args.output_dir, "WVS_demographics.json"), "r") as f:
        culture_profiles = json.load(f)[args.country][:5]
    cultural_experts = random.sample(expert_roles, 5)  # use 5 experts as cultural experts
    foreign_experts = random.sample(country_list, 3)
    judges_per_question = len(culture_profiles) + len(cultural_experts) + len(foreign_experts)
    pattern = re.compile(r"Point\s*(\d+)\s*:\s*([1-5])\b", flags=re.IGNORECASE)
    question_labels = []

    prompts_base = []
    for question in question_batch:
        content = question[round_key]
        question_str = question["question"] if round_idx == 1 else question[f"round_{round_idx - 1}_content"]["refined_question"]
        all_cultural_points = []
        all_point_embeds = []
        for cidx in range(6):
            key = f"response_{cidx}_cultural_points"
            if key in content:
                all_cultural_points.extend(x["point"] for x in content[key])
                all_point_embeds.extend(x["embed"] for x in content[key])

        if len(all_point_embeds) >= 2:
            hac = AgglomerativeClustering(distance_threshold=0.3, linkage='average', metric='cosine', n_clusters=None)
            labels = hac.fit(np.array(all_point_embeds)).labels_
        else:
            # AgglomerativeClustering needs at least two samples; zero or one point is
            # trivially its own (possibly empty) clustering.
            labels = np.zeros(len(all_point_embeds), dtype=int)
        question_labels.append(labels)

        consensus_points_str = ""
        for i in range(len(set(labels))):
            cluster_points = [all_cultural_points[idx] for idx in np.where(labels == i)[0]]
            point_str = ", ".join(x.strip() for x in cluster_points)
            consensus_points_str += f"Point {i + 1}: [{point_str}]\n"
        content["consensus_point_str"] = consensus_points_str

        # Judge prompts per question, in a fixed order: profiles, experts, foreign experts.
        for profile in culture_profiles:
            demographic = demographics.format(gender = profile['sex'], age = profile['age'], education = profile['highest_education'], occupation = profile['occupation'], social_class = profile['social_class'], income_level = profile['income_level'])
            prompts_base.append(represent_point_judgment_demographic.format(country = args.country, demographic_info = demographic, question = question_str, cultural_points = consensus_points_str))
        for expert in cultural_experts:
            prompts_base.append(represent_point_judgment_expert.format(country = args.country, role = expert, question = question_str, cultural_points = consensus_points_str))
        for foreign in foreign_experts:
            prompts_base.append(represent_point_judgment_foreign.format(foreign = foreign, country = args.country, question = question_str, cultural_points = consensus_points_str))

    responses = []
    max_batch_size = min(64, args.batch_size * 4)
    for start in range(0, len(prompts_base), max_batch_size):
        prompt_base_batch = prompts_base[start:start + max_batch_size]
        responses.extend(generator.get_batch_top_p_answer(prompt_base_batch, [""] * len(prompt_base_batch), max_tokens = 512, temperature = 0.7, top_p = 0.8))

    for qidx, question in enumerate(question_batch):
        n_points = len(set(question_labels[qidx]))
        offset = qidx * judges_per_question
        raw_consensus_matrix = np.zeros((judges_per_question, n_points))
        for cidx, response in enumerate(responses[offset:offset + judges_per_question]):
            input_words += len(prompts_base[offset + cidx].split())
            output_words += len(response.split())
            for match in pattern.finditer(response):
                point_idx = int(match.group(1)) - 1
                # Skip ratings for points that were never listed: "Point 0" would otherwise wrap
                # to the last column, and one stray index should not discard later valid ratings.
                if not 0 <= point_idx < n_points:
                    continue
                raw_consensus_matrix[cidx, point_idx] = int(match.group(2))

        likert_consensus_matrix = raw_consensus_matrix.copy()
        # Binary agreement matrix: a rating of 4 or 5 counts as agreement.
        agreement_matrix = (raw_consensus_matrix >= 4).astype(float)
        if n_points == 0:
            final_consensus = {}
        elif args.rep_score_type == "CCT":
            final_consensus = culture_consensus_theory(agreement_matrix)
        else:
            final_consensus = majority_consensus(agreement_matrix)

        question[round_key]["likert_consensus"] = likert_consensus_matrix.tolist()
        question[round_key]["consensus_scores"] = final_consensus
    return input_words, output_words

## the above is strategy 1: representativeness first, then distinctiveness
## the following is strategy 2: representativeness and distinctiveness optimization in parallel
def init_rep_dist_questions(args):
    """Stage 1: grow a pool of up to 20 candidate questions per WVS topic for ``args.country``.

    Starts from ``wvs_topics_examples`` (embedding every seed question) or resumes from a
    previously saved pool. It then repeatedly picks a random topic that is not yet full and
    prompts the model with two existing questions from that topic as examples. A generated
    question is kept only if its ROUGE-L F1 against every question already in the topic is
    <= 0.6 and its embedding cosine similarity to each of them is <= 0.75. This includes
    questions accepted earlier from the same generation.

    The pool is written to ``<output_dir>/<country>/rep_dist_questions_<model>.json`` every
    50 iterations and once more at the end.
    """
    question_generator = create_model(args.model_name, args)
    client = OpenAI(api_key = API_KEYS["openai"])
    max_questions_per_topic = 20
    os.makedirs(os.path.join(args.output_dir, args.country), exist_ok=True)
    output_path = os.path.join(args.output_dir, args.country, f"rep_dist_questions_{args.model_name.split('/')[-1]}.json")

    def embed(text):
        return client.embeddings.create(input = [text], model="text-embedding-3-small").data[0].embedding

    if os.path.exists(output_path):
        with open(output_path, "r") as f:
            topic_questions = json.load(f)
    else:
        topic_questions = defaultdict(list)
        for topic, examples in wvs_topics_examples.items():
            topic_questions[topic] = [{'question': q, 'embedding': embed(q)} for q in examples]

    def save():
        with open(output_path, "w") as f:
            json.dump(topic_questions, f, indent=4)

    # Only topics that still need questions are eligible, so a resumed run does not burn
    # iterations on topics that are already full.
    all_topics = [k for k in wvs_topics if len(topic_questions[k]) < max_questions_per_topic]
    repeat_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    input_words, output_words = 0, 0
    with tqdm(total=1000, desc=f"Generating questions for {args.country} cultural values") as pbar:
        for r_idx in range(1000):
            if not all_topics:
                break
            topic = random.choice(all_topics)
            definition = wvs_topics[topic]
            if not definition.lower().startswith(topic.lower().split()[0]):
                definition = f"{topic}: {definition}"
            selected_examples = random.sample(topic_questions[topic], 2)

            prompt_base = represent_distinctive_question_generate_prompt.format(topic = definition, country = args.country, example_1 = selected_examples[0]['question'], example_2 = selected_examples[1]['question'])
            output = question_generator.get_top_p_answer(prompt_base, "", max_tokens = 512, temperature = 0.7, top_p = 0.8)

            input_words += len(prompt_base.split())
            output_words += len(output.split())
            pbar.update(1)
            pbar.set_description(f"Generating questions for {args.country} on {topic}, input words: {input_words}, output words: {output_words}")

            existing_questions = topic_questions[topic]  # same list object: appends grow the pool
            existing_question_embeds = [q["embedding"] for q in existing_questions]
            for question in parse_gpt_questions(output):
                rouge_scores = [repeat_scorer.score(existing["question"], question["question"])['rougeL'].fmeasure for existing in existing_questions]
                best_rouge = max(rouge_scores)
                if best_rouge > 0.6:
                    print("Questions with high similarity. Rouge scores: ", best_rouge)
                    continue

                question["embedding"] = embed(question["question"])
                cosine_scores = compute_cosine_similarity(existing_question_embeds, question["embedding"])
                if max(cosine_scores) > 0.75:
                    print("Questions with high cosine similarity. Cosine scores: ", max(cosine_scores))
                    continue

                # NOTE(review): this stores the whole neighbour dict (embedding and its own
                # "most_similar" included), so saved JSON nests; confirm downstream needs that.
                question["most_similar"] = existing_questions[rouge_scores.index(best_rouge)]
                question["topic"] = topic
                existing_questions.append(question)
                # Keep the embedding list in step with existing_questions so later questions from
                # this same generation are also compared against the ones just accepted.
                existing_question_embeds.append(question["embedding"])

            if len(existing_questions) >= max_questions_per_topic:
                all_topics.remove(topic)

            if r_idx % 50 == 0:
                print("Saving questions...")
                save()

    save()

def rep_dist_optimization(args):
    """Stage 2: jointly optimise representativeness and distinctiveness of every stage-1 question.

    Each batch of questions goes through two rounds. Every round generates answers, scores
    their consensus, integrates the high-consensus points into one answer and refines the
    question. Every finished question is written as one JSON line to
    ``<output_dir>/<country>/rep_dist_answers_<model>.jsonl``.

    Lines already present in that file from an earlier run are kept and their questions are
    not recomputed. A trailing line without a newline (a cut-off write) is treated as
    unfinished.
    """
    # generate represent and distinct answers --> refine questions --> generate represent and distinct answers
    generator = create_model(args.model_name, args)
    model_tag = args.model_name.split('/')[-1]
    country_dir = os.path.join(args.output_dir, args.country)
    with open(os.path.join(country_dir, f"rep_dist_questions_{model_tag}.json"), "r") as f:
        topic_questions = json.load(f)
    questions = []
    for topic, topic_question in topic_questions.items():
        for q in topic_question:
            q["topic"] = topic
        questions.extend(topic_question)
    if not questions:
        print(f"No questions found for {args.country}; nothing to optimize.")
        return

    answers_path = os.path.join(country_dir, f"rep_dist_answers_{model_tag}.jsonl")
    existing_answers = []
    if os.path.exists(answers_path):
        with open(answers_path, "r") as f:
            existing_answers = [line for line in f if line.endswith("\n")]
    existing_answers = existing_answers[:len(questions)]

    input_words, output_words = 0, 0
    desc = f"Representativeness & distinctiveness optimization for {questions[0]['topic']} for {args.country}"
    with open(answers_path, "w") as fw, tqdm(total=len(questions), desc=desc) as pbar:
        # The answers file is rewritten from scratch, so replay previously finished lines first.
        fw.writelines(existing_answers)
        pbar.update(len(existing_answers))

        # Resume exactly after the last finished question, even if it ended mid-batch.
        for start_qidx in range(len(existing_answers), len(questions), args.batch_size):
            question_batch = questions[start_qidx:start_qidx + args.batch_size]

            for round_idx in range(1, 3):  # only do 2 rounds of optimization
                for question in question_batch:
                    question[f"round_{round_idx}_content"] = {}

                ## generate answers to maximize log[p(y|x)]
                input_w1, output_w1 = generate_rep_dist_answers(args, question_batch, round_idx, generator)
                input_w2, output_w2 = compute_answer_consensus(args, question_batch, round_idx, generator)
                input_w3, output_w3 = integrate_rep_dist_answers(args, question_batch, round_idx, generator)
                ## refine the question to maximize log[p(x)] + log[p(y|x)]
                input_w4, output_w4 = refine_rep_dist_question(args, question_batch, round_idx, generator)
                ## we can also compute the increased representativeness to terminate the optimization when the increase is small
                input_words += (input_w1 + input_w2 + input_w3 + input_w4)
                output_words += (output_w1 + output_w2 + output_w3 + output_w4)

            pbar.update(len(question_batch))
            pbar.set_description(f"Representativeness & distinctiveness optimization for {question_batch[0]['topic']} for {args.country}, input words: {input_words}, output words: {output_words}")
            for question in question_batch:
                question["country"] = args.country
                question["round_idx"] = round_idx
                fw.write(json.dumps(question) + "\n")
            # Flush per batch so a crash loses at most the batch in progress.
            fw.flush()

def generate_rep_dist_answers(args, question_batch, round_idx, generator):
    """Generate contrastive and in-culture answers for each question, then extract cultural points.

    Results are stored in ``question[f"round_{round_idx}_content"]``:

    1. ``other_country_responses``: one answer per comparison country from
       ``args.compare_countries`` (a comma-separated string or a list). When no comparison
       countries are given, a single culture-agnostic answer tagged "Universal" is stored.
    2. ``response_{i}``: answers for ``args.country`` from 2 randomly sampled demographic
       profiles followed by 2 randomly sampled cultural-expert roles. Each prompt includes the
       answers from step 1.
    3. ``response_{i}_cultural_points``: numbered points extracted from each ``response_{i}``
       and embedded via ``parse_cultural_points``.

    :return: (input_words, output_words), whitespace-token counts of all prompts and responses.
    """
    input_words, output_words = 0, 0
    round_key = f"round_{round_idx}_content"

    def current_question(question):
        # Round 1 answers the seed question; later rounds answer the previous refinement.
        if round_idx == 1:
            return question["question"]
        return question[f"round_{round_idx - 1}_content"]["refined_question"]

    # --compare_countries arrives from argparse as a single string: split it on commas instead
    # of iterating over its characters. A list/tuple passed programmatically is used as-is.
    other_countries = args.compare_countries
    if isinstance(other_countries, str):
        other_countries = [c.strip() for c in other_countries.split(",") if c.strip()]

    prompts_base = []
    if other_countries:
        for question in question_batch:
            question_str = current_question(question)
            for country in other_countries:
                prompt_base = represent_answer_expert_role.format(country = country, role = "cross-cultural researcher", question = question_str)
                prompts_base.append(prompt_base)
                input_words += len(prompt_base.split())
        responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens = 512, temperature = 0.7, top_p = 0.8)
        n_other = len(other_countries)
        for qidx, question in enumerate(question_batch):
            question_responses = responses[qidx * n_other:(qidx + 1) * n_other]
            output_words += sum(len(r.split()) for r in question_responses)
            question[round_key]["other_country_responses"] = [
                {"response": response, "country": country}
                for response, country in zip(question_responses, other_countries)
            ]
    else:
        for question in question_batch:
            prompt_base = no_culture_answer.format(question = current_question(question))
            prompts_base.append(prompt_base)
            input_words += len(prompt_base.split())
        responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens = 512, temperature = 0.7, top_p = 0.8)
        for question, response in zip(question_batch, responses):
            output_words += len(response.split())  # count each question's own response
            question[round_key]["other_country_responses"] = [{"response": response, "country": "Universal"}]

    with open(os.path.join(args.output_dir, "WVS_demographics.json"), "r") as f:
        culture_profiles = json.load(f)[args.country]
    culture_profiles = random.sample(culture_profiles, 2)  # select 2 profiles for demographic responses
    cultural_experts = random.sample(expert_roles, 2)  # 2 randomly sampled cultural-expert roles
    # Every question gets exactly one answer per profile and per expert, in that order.
    responses_per_question = len(culture_profiles) + len(cultural_experts)

    prompts_base = []
    for question in question_batch:
        question_str = current_question(question)
        other_response = "\n".join(x["response"] for x in question[round_key]["other_country_responses"])

        ## generate answers from demographic profiles
        for profile in culture_profiles:
            demographic = demographics.format(gender = profile['sex'], age = profile['age'], education = profile['highest_education'], occupation = profile['occupation'], social_class = profile['social_class'], income_level = profile['income_level'])
            prompt_base = rep_dist_answer_demographic_role.format(country = args.country, demographic_info = demographic, question = question_str, other_response = other_response)
            prompts_base.append(prompt_base)
            input_words += len(prompt_base.split())
        ## generate answers from cultural experts
        for expert in cultural_experts:
            prompt_base = rep_dist_answer_expert_role.format(country = args.country, role = expert, question = question_str, other_response = other_response)
            prompts_base.append(prompt_base)
            input_words += len(prompt_base.split())

    responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens = 256, temperature = 0.7, top_p = 0.8)

    # Slice by the real per-question count; dividing by args.batch_size misaligned answers
    # across questions whenever the last batch was shorter than batch_size.
    prompts_base = []
    for qidx, question in enumerate(question_batch):
        question_str = current_question(question)
        one_responses = responses[qidx * responses_per_question:(qidx + 1) * responses_per_question]
        for cidx, response in enumerate(one_responses):
            output_words += len(response.split())
            question[round_key][f"response_{cidx}"] = response
            prompt_base = point_extraction_prompt.format(question= question_str, response=response)
            input_words += len(prompt_base.split())
            prompts_base.append(prompt_base)

    responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens=256, temperature=0.7, top_p=0.8)

    for qidx, question in enumerate(question_batch):
        one_responses = responses[qidx * responses_per_question:(qidx + 1) * responses_per_question]
        for cidx, response in enumerate(one_responses):  # one extraction per in-culture response
            output_words += len(response.split())
            question[round_key][f"response_{cidx}_cultural_points"] = parse_cultural_points(response.strip())
    return input_words, output_words

def compute_rep_dist_answer_consensus(args, question_batch, round_idx, generator):
    """Score how representative each clustered cultural point is, using a larger judge panel.

    For every question in ``question_batch``:

    1. Pool the cultural points stored under ``response_{i}_cultural_points`` (i in 0..5)
       for this round and cluster their embeddings with average-linkage agglomerative
       clustering (cosine distance threshold 0.3).
    2. Ask up to 15 demographic personas (the first 15 WVS profiles of the country), 5 randomly
       sampled cultural-expert roles and 3 randomly sampled foreign countries to rate every
       cluster. Ratings are parsed from lines like ``Point k: <1-5>``, and prompts are sent in
       chunks of 64.
    3. Binarise the ratings (>= 4 counts as agreement) and aggregate them per cluster with
       ``args.rep_score_type``.

    Writes ``consensus_point_str``, ``likert_consensus`` and ``consensus_scores`` into
    ``question[f"round_{round_idx}_content"]``.

    :return: (input_words, output_words), whitespace-token counts of judge prompts and responses.
    :raises ValueError: if ``args.rep_score_type`` is neither "Majority" nor "CCT".
    """
    # Fail before spending model calls: "Debate" is accepted by the CLI but has no
    # implementation here, and previously left `final_consensus` unbound.
    if args.rep_score_type not in ("Majority", "CCT"):
        raise ValueError(f"Unsupported rep_score_type {args.rep_score_type!r}; expected 'Majority' or 'CCT'")

    input_words, output_words = 0, 0
    round_key = f"round_{round_idx}_content"

    ## compute the consensus score for each question
    country_list = ["United States", "United Kingdom", "Germany", "Italy", "South Korea", "China", "Japan", "India", "Singapore", "Indonesia", "Russia", "Poland"]
    own_country = args.country.replace("_", " ")
    if own_country in country_list:
        country_list.remove(own_country)

    with open(os.path.join(args.output_dir, "WVS_demographics.json"), "r") as f:
        culture_profiles = json.load(f)[args.country][:15]
    cultural_experts = random.sample(expert_roles, 5)  # use 5 experts as cultural experts
    foreign_experts = random.sample(country_list, 3)
    judges_per_question = len(culture_profiles) + len(cultural_experts) + len(foreign_experts)
    pattern = re.compile(r"Point\s*(\d+)\s*:\s*([1-5])\b", flags=re.IGNORECASE)
    question_labels = []

    prompts_base = []
    for question in question_batch:
        content = question[round_key]
        question_str = question["question"] if round_idx == 1 else question[f"round_{round_idx - 1}_content"]["refined_question"]
        all_cultural_points = []
        all_point_embeds = []
        for cidx in range(6):
            key = f"response_{cidx}_cultural_points"
            if key in content:
                all_cultural_points.extend(x["point"] for x in content[key])
                all_point_embeds.extend(x["embed"] for x in content[key])

        if len(all_point_embeds) >= 2:
            hac = AgglomerativeClustering(distance_threshold=0.3, linkage='average', metric='cosine', n_clusters=None)
            labels = hac.fit(np.array(all_point_embeds)).labels_
        else:
            # AgglomerativeClustering needs at least two samples; zero or one point is
            # trivially its own (possibly empty) clustering.
            labels = np.zeros(len(all_point_embeds), dtype=int)
        question_labels.append(labels)

        consensus_points_str = ""
        for i in range(len(set(labels))):
            cluster_points = [all_cultural_points[idx] for idx in np.where(labels == i)[0]]
            point_str = ", ".join(x.strip() for x in cluster_points)
            consensus_points_str += f"Point {i + 1}: [{point_str}]\n"
        content["consensus_point_str"] = consensus_points_str

        # Judge prompts per question, in a fixed order: profiles, experts, foreign experts.
        for profile in culture_profiles:
            demographic = demographics.format(gender = profile['sex'], age = profile['age'], education = profile['highest_education'], occupation = profile['occupation'], social_class = profile['social_class'], income_level = profile['income_level'])
            prompts_base.append(represent_point_judgment_demographic.format(country = args.country, demographic_info = demographic, question = question_str, cultural_points = consensus_points_str))
        for expert in cultural_experts:
            prompts_base.append(represent_point_judgment_expert.format(country = args.country, role = expert, question = question_str, cultural_points = consensus_points_str))
        for foreign in foreign_experts:
            prompts_base.append(represent_point_judgment_foreign.format(foreign = foreign, country = args.country, question = question_str, cultural_points = consensus_points_str))

    responses = []
    for start in range(0, len(prompts_base), 64):
        prompt_base_batch = prompts_base[start:start + 64]
        responses.extend(generator.get_batch_top_p_answer(prompt_base_batch, [""] * len(prompt_base_batch), max_tokens = 512, temperature = 0.7, top_p = 0.8))

    for qidx, question in enumerate(question_batch):
        n_points = len(set(question_labels[qidx]))
        offset = qidx * judges_per_question
        raw_consensus_matrix = np.zeros((judges_per_question, n_points))
        for cidx, response in enumerate(responses[offset:offset + judges_per_question]):
            input_words += len(prompts_base[offset + cidx].split())
            output_words += len(response.split())
            for match in pattern.finditer(response):
                point_idx = int(match.group(1)) - 1
                # Skip ratings for points that were never listed: "Point 0" would otherwise wrap
                # to the last column, and one stray index should not discard later valid ratings.
                if not 0 <= point_idx < n_points:
                    continue
                raw_consensus_matrix[cidx, point_idx] = int(match.group(2))

        likert_consensus_matrix = raw_consensus_matrix.copy()
        # Binary agreement matrix: a rating of 4 or 5 counts as agreement.
        agreement_matrix = (raw_consensus_matrix >= 4).astype(float)
        if n_points == 0:
            final_consensus = {}
        elif args.rep_score_type == "CCT":
            final_consensus = culture_consensus_theory(agreement_matrix)
        else:
            final_consensus = majority_consensus(agreement_matrix)

        question[round_key]["likert_consensus"] = likert_consensus_matrix.tolist()
        question[round_key]["consensus_scores"] = final_consensus
    return input_words, output_words

def integrate_rep_dist_answers(args, question_batch, round_idx, generator):
    """Merge each question's high-consensus cultural points into a single answer.

    Clusters whose consensus score is at least 0.75 are renumbered consecutively
    (``Point 1: ...``) and passed to ``represent_consensus_aggregation_prompt``. The stripped
    model reply is stored as ``integrated_answer`` in
    ``question[f"round_{round_idx}_content"]``.

    :return: (input_words, output_words), whitespace-token counts of prompts and responses.
    """
    input_words, output_words = 0, 0

    prompts_base = []
    for question in question_batch:
        content = question[f"round_{round_idx}_content"]
        question_str = question["question"] if round_idx == 1 else question[f"round_{round_idx - 1}_content"]["refined_question"]

        consensus_scores = content["consensus_scores"]
        consensus_points = content["consensus_point_str"].strip().split("\n")
        # Walk every clustered point by its index. consensus_scores can be sparse (CCT scores
        # only representative points), so iterating range(len(consensus_scores)) skipped points
        # whose index exceeded the number of scored points. .get() also avoids inserting
        # defaults into a defaultdict.
        kept_points = []
        for idx, point_line in enumerate(consensus_points):
            score = consensus_scores.get(idx)
            if score is None or score < 0.75:  # only keep the points with high consensus
                continue
            kept_points.append(point_line.split(": ", 1)[-1].strip())
        consensus_points_str = "".join(f"Point {rank}: {text}\n" for rank, text in enumerate(kept_points, start=1))

        prompt_base = represent_consensus_aggregation_prompt.format(country=args.country, question=question_str, consensus_points=consensus_points_str)
        input_words += len(prompt_base.split())
        prompts_base.append(prompt_base)
    responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens=1024, temperature=0.7, top_p=0.8)
    for question, response in zip(question_batch, responses):
        output_words += len(response.split())
        question[f"round_{round_idx}_content"]["integrated_answer"] = response.strip()
    return input_words, output_words

## refine the question to maximize representativeness and distinctiveness
def refine_rep_dist_question(args, question_batch, round_idx, generator):
    """Reflect on the current round and ask the model for a refined question.

    For each question:

    1. Extract numbered cultural points from every ``other_country_responses`` answer and,
       last, from the ``integrated_answer`` (one extraction call each).
    2. Representativeness reward: the number of clustered points with consensus score >= 0.8.
       Distinctiveness reward: 1 - mean cosine similarity between the embedding of the
       integrated answer's extraction and each other-country extraction.
    3. Prompt ``rep_dist_reflection_prompt`` with both rewards, then ask
       ``rep_dist_refine_prompt`` in the same conversation. The text after the first ":" of the
       reply is stored as ``refined_question``.

    Also stores ``other_cultural_points`` and ``cultural_points_embeds``.

    :return: (input_words, output_words), whitespace-token counts of prompts and responses.
    """
    input_words, output_words = 0, 0
    client = OpenAI(api_key = API_KEYS["openai"])
    round_key = f"round_{round_idx}_content"

    def current_question(question):
        # Round 1 refines the seed question; later rounds refine the previous refinement.
        if round_idx == 1:
            return question["question"]
        return question[f"round_{round_idx - 1}_content"]["refined_question"]

    ## first extract cultural points from the other responses, then from the integrated answer
    prompts_base = []
    extraction_counts = []
    for question in question_batch:
        content = question[round_key]
        question_str = current_question(question)
        sources = [other["response"] for other in content["other_country_responses"]]
        sources.append(content["integrated_answer"])
        for source in sources:
            prompt_base = point_extraction_prompt.format(question=question_str, response=source)
            prompts_base.append(prompt_base)
            input_words += len(prompt_base.split())
        extraction_counts.append(len(sources))
    responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens=256, temperature=0.7, top_p=0.8)

    prompts_base = []
    offset = 0
    for question, n_extractions in zip(question_batch, extraction_counts):
        content = question[round_key]
        consensus_scores = content["consensus_scores"]
        topic = question["topic"]
        question_str = current_question(question)

        consensus_points_str = ""
        for idx, consensus in enumerate(content["consensus_point_str"].strip().split("\n")):  # cluster indices start from 0
            if idx not in consensus_scores:
                continue
            score = consensus_scores[idx]
            consensus_points_str += f"{consensus} (score: {score})\n"
        # Count over the scores actually present: CCT returns a sparse dict keyed by cluster
        # index, so iterating range(len(...)) missed high indices and inserted zeros.
        represent_reward = sum(1 for score in consensus_scores.values() if score >= 0.8)

        # Slice by each question's real extraction count (other responses + integrated answer);
        # dividing by args.batch_size misaligned questions in a short final batch.
        one_responses = responses[offset:offset + n_extractions]
        offset += n_extractions
        output_words += sum(len(r.split()) for r in one_responses)

        cultural_points = []
        # The last extraction belongs to the integrated (main-country) answer and is excluded here.
        for other, response in zip(content["other_country_responses"], one_responses[:-1]):
            start = response.find('1.')
            points_text = response[start:] if start != -1 else response  # no numbered list: keep everything
            cultural_points.append(f"[{other['country']}]: {points_text.strip()}")
        content["other_cultural_points"] = cultural_points
        response_embeds = [x.embedding for x in client.embeddings.create(input=one_responses, model="text-embedding-3-small").data]
        content["cultural_points_embeds"] = response_embeds
        culture_similarity = compute_cosine_similarity(np.array(response_embeds[:-1]), response_embeds[-1])  # one value per other response
        distinct_reward = 1 - np.mean(culture_similarity)

        prompt_base = rep_dist_reflection_prompt.format(country=args.country, topic=topic, question=question_str, cultural_points=consensus_points_str, other_response="\n".join(cultural_points), representativeness_reward=represent_reward, distinctiveness_reward=distinct_reward)
        input_words += len(prompt_base.split())
        prompts_base.append(prompt_base)

    responses = generator.get_batch_top_p_answer(prompts_base, [""] * len(prompts_base), max_tokens=512, temperature=0.7, top_p=0.8)
    messages = []
    for prompt_base, response in zip(prompts_base, responses):
        output_words += len(response.split())
        input_words += len(prompt_base.split() + rep_dist_refine_prompt.split())
        messages.append([{"role": "user", "content": prompt_base},
                {"role": "assistant", "content": response},
                {"role": "user", "content": rep_dist_refine_prompt}])
    responses = generator.get_batch_message_answer(messages, max_tokens=128, temperature=0.7, top_p=0.8)
    for question, response in zip(question_batch, responses):
        output_words += len(response.split())
        # NOTE(review): assumes replies look like "<label>: <question>"; an unlabeled reply whose
        # question itself contains ":" would be truncated — confirm against rep_dist_refine_prompt.
        question[round_key]["refined_question"] = response.split(":", 1)[-1].strip()
    return input_words, output_words

if __name__ == "__main__":
    args = parse_arguments()
    print("Running with args: ", args)
    random.seed(2025)  # seed Python's `random` so persona/expert/topic sampling is repeatable
    # Stage 1 builds the question pool; stage 2 optimises those questions and their answers.
    if args.stage == 1:
        init_rep_dist_questions(args)
    elif args.stage == 2:
        rep_dist_optimization(args)