import os
from openai import OpenAI

# NOTE(review): an empty-string default presumably lets the client be built without a key set (requests would then fail at call time) — confirm.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))

import re
import gc
import torch
import time
import argparse
import torch.nn.functional as F
from vllm import LLM, SamplingParams
from datetime import datetime
from prompts import *
import pandas as pd
from collections import defaultdict
from scipy.optimize import linear_sum_assignment
import itertools
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig, GenerationConfig, StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList
import json
import backoff
import random
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import multiprocessing as mp
from tqdm import tqdm
import tempfile
import uuid
import sys

SEED = 2025

# Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

def get_openai_embeddings(text_list, model="text-embedding-3-large", max_retries=5, retry_delay=2):
    """Embed ``text_list`` with the OpenAI embeddings API (one request).

    Failed requests are retried up to ``max_retries`` attempts in total, with
    ``retry_delay`` seconds of sleep between attempts.

    Args:
        text_list: list of strings to embed.
        model: OpenAI embedding model name.
        max_retries: total number of attempts before giving up.
        retry_delay: seconds to sleep between failed attempts.

    Returns:
        ``torch.float32`` tensor with one embedding row per input string.

    Raises:
        RuntimeError: when every attempt failed (chained to the last error).
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            response = client.embeddings.create(input=text_list, model=model)
        except Exception as e:  # any API error is retried
            last_error = e
            print(f"Error: {e} and Wait")
            # No point sleeping after the final attempt.
            if attempt < max_retries:
                time.sleep(retry_delay)
            continue

        embeddings = [item.embedding for item in response.data]
        return torch.tensor(embeddings, dtype=torch.float32)

    # Fail loudly instead of falling through with ``response`` unbound.
    raise RuntimeError(f"Check: Failed after {max_retries} retries.") from last_error

class StopAfterSeq(LogitsProcessor):
    """Logits processor that forces EOS once generated tokens end with a stop sequence.

    Args:
        base_len: number of tokens per row before generation starts
            (presumably the padded prompt length — confirm with the caller).
        stop_sequence: text whose token ids trigger the stop; encoded once
            with ``add_special_tokens=False``.
        tokenizer: object providing ``encode``.
        eos_token_id: token id that becomes the only allowed next token after a match.
    """

    def __init__(self, base_len: int, stop_sequence: str, tokenizer, eos_token_id: int):
        super().__init__()

        self.base_len = base_len
        self.eos_token_id = eos_token_id
        self.sp_token_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Replace the scores of every matching row with an EOS-only row.

        ``scores`` is modified in place and also returned.
        """
        n_stop = len(self.sp_token_ids)
        # The stop sequence must lie entirely in the generated part, so prompt
        # tokens can never complete a match; an empty sequence never matches.
        if n_stop == 0 or input_ids.size(1) - self.base_len < n_stop:
            return scores

        stop_ids = torch.tensor(self.sp_token_ids, dtype=input_ids.dtype, device=input_ids.device)
        hit = (input_ids[:, -n_stop:] == stop_ids).all(dim=1).to(scores.device)
        if hit.any():
            # Built with the same dtype/device as ``scores`` to avoid host copies.
            forced_eos = torch.full_like(scores[0], -float("inf"))
            forced_eos[self.eos_token_id] = 0
            scores[hit] = forced_eos

        return scores

def cot_prompt_wrap(x: str, y:str, format_prompt:str, author_persona:str, given_persona:str, given_strategy:str) -> str:
    """Fill ``format_prompt`` with the fields that are set, then append ``y``.

    Only non-``None`` values are passed to ``str.format``. The given persona is
    passed as ``given_persona`` when an author persona is also present, and as
    ``persona`` otherwise.
    """
    fields = {'input': x}
    if author_persona is not None:
        fields['author_persona'] = author_persona
    if given_persona is not None:
        persona_key = 'persona' if author_persona is None else 'given_persona'
        fields[persona_key] = given_persona
    if given_strategy is not None:
        fields['given_strategy'] = given_strategy

    return format_prompt.format(**fields) + y

def vote_multiple_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
    """Tally multi-choice votes from LLM replies.

    Each reply contributes at most one vote per option. Option ids are read
    from the rest of the line after "the best choice(s) is/are"
    (case-insensitive); if that yields no numbers, every number in the reply
    is used. Ids outside ``1..n_candidates`` are ignored.

    Returns:
        List of length ``n_candidates`` with per-option vote counts.
    """
    tallies = [0 for _ in range(n_candidates)]
    if not vote_outputs:
        return tallies

    for raw in vote_outputs:
        text = str(raw).strip()

        header = re.search(r'the best choices?\s+(?:are|is)\s+([^\n\r]+)', text, flags=re.IGNORECASE)
        ids = re.findall(r'\b\d+\b', header.group(1)) if header else []
        if not ids:
            ids = re.findall(r'\b\d+\b', text)

        counted = set()
        for token in ids:
            choice = int(token)
            if choice in counted or not 1 <= choice <= n_candidates:
                continue
            counted.add(choice)
            tallies[choice - 1] += 1

    return tallies

def vote_one_prompt_wrap(x: str, ys: list, format_prompt: str, given_persona: str, author_persona: str) -> str:
    """Fill ``format_prompt`` and append every candidate as an "Option i:" entry.

    Only non-``None`` personas are passed to ``str.format``. Options are
    numbered from 1.
    """
    fields = {'input': x}
    if author_persona is not None:
        fields['author_persona'] = author_persona
    if given_persona is not None:
        fields['given_persona'] = given_persona

    header = format_prompt.format(**fields)
    options = ''.join(f'Option {i}:\n{y}\n' for i, y in enumerate(ys, 1))
    return header + options

def vote_one_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
    """Tally single-choice votes from LLM replies.

    For each reply, the 1-based option id is the first whole number after the
    last "choice is " phrase that is followed by a number (case-insensitive).
    Without such a phrase, the first number anywhere in the reply is used.
    Ids outside ``1..n_candidates`` are ignored. Replies without any number
    are reported on stdout.

    Returns:
        List of length ``n_candidates`` with per-option vote counts.
    """
    vote_results = [0] * n_candidates
    # ``\D*`` stops at the first digit, so the whole number is captured
    # (a greedy ``.*(\d+)`` would keep only its last digit: "12" -> "2").
    choice_pattern = re.compile(r"choice is \D*(\d+)", re.IGNORECASE)

    for vote_output in vote_outputs:
        text = str(vote_output)

        choices = choice_pattern.findall(text)
        if choices:
            vote = int(choices[-1]) - 1
        else:
            numbers = re.findall(r"\d+", text)
            if not numbers:
                print(f'vote no match: {[vote_output]}')
                continue
            vote = int(numbers[0]) - 1

        if 0 <= vote < n_candidates:
            vote_results[vote] += 1

    return vote_results

def batch_generate(llm, tokenizer, list_of_messages, n=1, stop=None, temperature=0.8,
                   top_p=0.95, max_tokens=2048, seed=2025):
    """Generate ``n`` completions for every chat in ``list_of_messages`` with vLLM.

    Each chat (a list of role/content dicts) is rendered with the tokenizer's
    chat template (``add_generation_prompt=True``) before generation.

    Args:
        llm: vLLM ``LLM`` instance.
        tokenizer: tokenizer providing ``apply_chat_template``.
        list_of_messages: list of chats.
        n: completions per chat.
        stop: optional stop string(s) forwarded to ``SamplingParams``.
        temperature, top_p, max_tokens, seed: sampling settings forwarded to
            ``SamplingParams`` (defaults 0.8, 0.95, 2048, 2025).

    Returns:
        One list of ``n`` completion texts per input chat, in input order.
    """
    if not list_of_messages:
        return []

    # NOTE(review): a fixed per-request seed presumably makes repeated calls
    # with the same prompt return the same completions — confirm intended.
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stop=stop,
        seed=seed,
        n=n,
    )

    prompt_texts = [
        tokenizer.apply_chat_template(
            message,
            add_generation_prompt=True,
            tokenize=False
        )
        for message in list_of_messages
    ]

    outputs = llm.generate(
        prompt_texts,
        sampling_params,
    )

    # One request output per prompt, each holding ``n`` completions.
    result_per_message = [
        [request.outputs[j].text for j in range(n)]
        for request in outputs
    ]

    del prompt_texts, outputs
    gc.collect()
    torch.cuda.empty_cache()

    return result_per_message

def get_prompt_format(prompt, system_prompt=None, ex_prompt=None):
    """Wrap ``prompt`` as a chat, preceded by a system turn when ``system_prompt`` is truthy.

    ``ex_prompt`` is accepted but not used.
    """
    user_turn = {"role": "user", "content": prompt}
    if not system_prompt:
        return [user_turn]
    return [{"role": "system", "content": system_prompt}, user_turn]

def get_op_persona_input_messages(x, y, op_method: str):
    """Build the chat that asks for the original poster's persona of ``x``.

    ``y`` and ``op_method`` are accepted but not used.
    """
    return get_prompt_format(
        op_persona_generation_prompt.format(input=x),
        system_prompt=op_persona_system_prompt,
    )

def get_input_messages(x, y, step:int, op_method:str, select_method:str, author_persona:str, given_persona:str, given_strategy:str):
    """Build the generation chat for ``x`` with ``y`` appended as the partial answer.

    The user template depends on which personas are set:
    both -> ``prompt_persona_only``, author only -> ``prompt_author_only``,
    given only -> ``prompt_given_only_persona``, neither -> ``prompt_minimal``.
    ``op_method`` and ``select_method`` are accepted but not used.

    Raises:
        ValueError: if ``step`` is not 3, the only step with templates defined here.
    """
    if step != 3:
        # Previously this fell through to an UnboundLocalError on format_prompt.
        raise ValueError(f"get_input_messages only supports step=3, got {step!r}")

    if given_persona is not None and author_persona is not None:
        format_prompt = prompt_persona_only
    elif author_persona is not None:
        format_prompt = prompt_author_only
    elif given_persona is not None:
        format_prompt = prompt_given_only_persona
    else:
        format_prompt = prompt_minimal

    prompt = cot_prompt_wrap(x, y, format_prompt, author_persona, given_persona, given_strategy)
    return get_prompt_format(prompt, system_prompt=tot_system_prompt, ex_prompt=None)

def get_batch_one_votes(llm, tokenizer, xs: list, yss_list: list, format_prompt: str, given_persona: str, author_persona: str, n_evaluate_sample: int):
    """Ask the LLM to pick the single best option of each candidate list.

    Every entry of ``yss_list`` becomes one vote prompt (``xs`` fills the
    ``input`` field, and the same personas are used for all prompts), sampled
    ``n_evaluate_sample`` times at temperature 0.8.

    Returns:
        One per-option vote-count list per entry of ``yss_list``.
    """
    vote_chats = [
        get_prompt_format(
            vote_one_prompt_wrap(xs, candidates, format_prompt, given_persona=given_persona, author_persona=author_persona),
            system_prompt=vote_one_system_prompt,
        )
        for candidates in yss_list
    ]

    sampled = batch_generate(llm, tokenizer, vote_chats, n=n_evaluate_sample, temperature=0.8)

    return [
        vote_one_outputs_unwrap(replies, len(candidates))
        for replies, candidates in zip(sampled, yss_list)
    ]

def get_vote_multiple_system_prompt(n_choices):
    """Return the system prompt that asks for exactly ``n_choices`` ranked option ids.

    The placeholder list uses letters starting at ``A`` (``A, B, C`` for 3).
    """
    labels = ', '.join(chr(ord('A') + k) for k in range(n_choices))

    lines = [
        '',
        'Only output the final answer.',
        f'You must respond with: The best choices are {labels}',
        f'Where {labels} are the numeric IDs of the selected options in order of preference.',
        f'Select exactly {n_choices} options.',
        'Use no quotes, brackets, or extra symbols. Do not change the wording or format.',
        'Do not include any introduction, explanation, or extra punctuation.',
        'Separate the numbers with commas and spaces.',
        '',
    ]
    return '\n'.join(lines)

def get_batch_multiple_votes(llm, tokenizer, xs: list, yss_list: list, format_prompt: str, given_persona: str, author_persona: str, n_evaluate_sample: int, n_seq: int):
    """Ask the LLM to rank the top ``n_seq`` options of each candidate list.

    Each entry of ``yss_list`` becomes one ranking prompt, sampled
    ``n_evaluate_sample`` times at temperature 0.8.

    Returns:
        One per-option vote-count list per entry of ``yss_list``.
    """
    system_prompt = get_vote_multiple_system_prompt(n_seq)

    chats = [
        get_prompt_format(
            vote_multiple_prompt_wrap(xs, options, format_prompt, given_persona=given_persona, author_persona=author_persona, n_choices=n_seq),
            system_prompt=system_prompt,
        )
        for options in yss_list
    ]

    replies_per_chat = batch_generate(llm, tokenizer, chats, n=n_evaluate_sample, temperature=0.8)

    return [
        vote_multiple_outputs_unwrap(replies, len(options))
        for options, replies in zip(yss_list, replies_per_chat)
    ]

def vote_multiple_prompt_wrap(xs, yss, format_prompt, given_persona=None, author_persona=None, n_choices=3):
    """Build a prompt asking for the top ``n_choices`` of the options in ``yss``.

    Sections, separated by blank lines: the argument (list items joined by
    newlines), persona lines for truthy personas, ``format_prompt``, the
    answer-format instruction, and the options numbered from 1.
    """
    if isinstance(xs, list):
        argument_text = "\n".join(xs)
    else:
        argument_text = str(xs)

    persona_lines = []
    if given_persona:
        persona_lines.append(f"Given persona: {given_persona}\n")
    if author_persona:
        persona_lines.append(f"Author persona: {author_persona}\n")

    labels = ', '.join(chr(ord('A') + k) for k in range(n_choices))
    choice_instruction = (
        f"\nSelect the top {n_choices} options."
        "\nDo not explain your reasoning. Only output the final sentence:"
        f"\nThe best choice is {labels}"
    )

    numbered_options = "".join(f"{pos}. {option}\n" for pos, option in enumerate(yss, 1))

    sections = [
        f"Argument:\n{argument_text}",
        "".join(persona_lines),
        f"{format_prompt}",
        choice_instruction,
        numbered_options,
    ]
    return "\n\n".join(sections)

def get_batch_counter_one_votes(llm, tokenizer, xs: list, yss_list: list, format_prompt: str, given_persona_list: list, author_persona: str, n_evaluate_sample: int):
    """Single-best-option voting where each candidate list has its own given persona.

    ``given_persona_list[k]`` (or ``None`` when the list itself is ``None``)
    is used for ``yss_list[k]``. Each prompt is sampled ``n_evaluate_sample``
    times at temperature 0.8.

    Returns:
        One per-option vote-count list per entry of ``yss_list``.
    """
    vote_chats = []
    for pos, candidates in enumerate(yss_list):
        persona = None if given_persona_list is None else given_persona_list[pos]
        prompt = vote_one_prompt_wrap(xs, candidates, format_prompt, given_persona=persona, author_persona=author_persona)
        vote_chats.append(get_prompt_format(prompt, system_prompt=vote_one_system_prompt))

    sampled = batch_generate(llm, tokenizer, vote_chats, n=n_evaluate_sample, temperature=0.8)

    return [
        vote_one_outputs_unwrap(replies, len(candidates))
        for replies, candidates in zip(sampled, yss_list)
    ]

def load_cluster_resources(n_components, min_cluster_size):
    """Load persona embeddings and clustering artefacts from disk.

    Args:
        n_components: dimensionality tag used in the cluster file names.
        min_cluster_size: cluster-size tag used in the cluster file names.

    Returns:
        Dict with:
          persona_list / persona_embeddings: personas listed in the cluster CSV
              (in the embedding file's order) and their embedding rows.
          persona_text_to_embedding_dict: persona -> embedding row.
          text_embeddings: ``embeddings`` from ``test_text_embeddings.pt``.
          label_list: sorted distinct cluster labels from the cluster CSV.
          label_dict: persona -> cluster label.
          label_to_persona_dict: label -> personas with that label.
          cluster_embeddings / cluster_personas: per-label embeddings and
              personas (label -1 skipped).
          cluster_center_info: label -> ``centroid_persona``.
          cluster_dist_info: CSV row position (0..n-1) -> row of distances as floats.
          cluster_vectors: numpy array of centroid embeddings, in file order.
          centroid_labels: the ``label`` entry stored with the centroids, in file order.
    """
    persona_dict = torch.load("../my-llama3.1-imple/persona_hub/persona_embeddings.pt")
    original_persona_list = persona_dict['personas']
    original_persona_embeddings = persona_dict['embeddings']

    text_dict = torch.load("./data/persona_hub/test_text_embeddings.pt")
    text_embeddings = text_dict['embeddings']

    df = pd.read_csv(f"./data/persona_hub/persona_clusters_{n_components}d_{min_cluster_size}_no_-1.csv")

    # Keep only personas that appear in the cluster CSV, preserving their original order.
    clustered_persona_set = set(df['persona'])
    selected_indices = [
        idx for idx, persona in enumerate(original_persona_list)
        if persona in clustered_persona_set
    ]

    persona_list = [original_persona_list[i] for i in selected_indices]
    persona_embeddings = original_persona_embeddings[selected_indices]

    persona_text_to_embedding_dict = dict(zip(persona_list, persona_embeddings))

    # Explicitly sorted so index -> label does not depend on set iteration order.
    label_list = sorted(set(df['cluster_label'].tolist()))
    # Column zips instead of DataFrame.iterrows (which builds a Series per row).
    label_dict = dict(zip(df['persona'], df['cluster_label']))
    label_to_persona_dict = defaultdict(list)
    for persona, label in label_dict.items():
        label_to_persona_dict[label].append(persona)

    cluster_embeddings = defaultdict(list)
    cluster_personas = defaultdict(list)
    for persona in persona_list:
        label = label_dict.get(persona, -1)
        if label != -1:
            cluster_embeddings[label].append(persona_text_to_embedding_dict[persona])
            cluster_personas[label].append(persona)

    cluster_center_df = pd.read_csv(f"./data/persona_hub/cluster_centroid_representatives_{n_components}d_{min_cluster_size}.csv")
    cluster_center_info = dict(zip(cluster_center_df['cluster_label'], cluster_center_df['centroid_persona']))

    cluster_dist_df = pd.read_csv(f"./data/persona_hub/cluster_centroid_cosine_distances_{n_components}d_{min_cluster_size}.csv")
    # Keyed by row position, not by cluster label.
    cluster_dist_info = dict(enumerate(cluster_dist_df.astype(float).to_numpy().tolist()))

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted local files.
    cluster_dict = torch.load(f"./data/persona_hub/cluster_centroids_{n_components}d_{min_cluster_size}.pt", weights_only=False)
    centroid_labels = cluster_dict["label"]
    centroid_embeddings = cluster_dict["embeddings"]
    cluster_vectors = np.array([emb.cpu().numpy() for emb in centroid_embeddings])

    return {
        "persona_list": persona_list,
        "persona_embeddings": persona_embeddings,
        "persona_text_to_embedding_dict": persona_text_to_embedding_dict,
        "text_embeddings": text_embeddings,
        "label_list": label_list,
        "label_dict": label_dict,
        "label_to_persona_dict": label_to_persona_dict,
        "cluster_embeddings": cluster_embeddings,
        "cluster_personas": cluster_personas,
        "cluster_center_info": cluster_center_info,
        "cluster_dist_info": cluster_dist_info,
        "cluster_vectors": cluster_vectors,
        "centroid_labels": centroid_labels,
    }


def solve(
    llm,
    tokenizer,
    samples,
    cluster_data,
    op_method,
    select_method,
    step,
    persona_n,
    s_id,
    e_id,
    batch_pids=None,
):
    """Run the persona -> plan -> counterargument search for a batch of samples.

    Stage 0 picks personas per sample according to ``select_method``.
    Stage 1 generates candidates per persona (stopping at "Counterargument:")
    and keeps the best-voted one. Stage 2 continues from those texts and keeps
    the best-voted candidate of each group.

    Args:
        llm, tokenizer: generation backend forwarded to ``batch_generate``.
        samples: list of input texts.
        cluster_data: dict returned by ``load_cluster_resources``.
        op_method: ``'op_persona'`` (infer the original poster's persona with
            the LLM) or ``'wo_op_persona'``.
        select_method: ``'wo_persona'``, ``'random'`` or ``'cluster_diff_case'``.
        step: forwarded to ``get_input_messages``.
        persona_n: number of personas requested in stage 0.
        s_id, e_id, batch_pids: accepted but not used.

    Returns:
        ``(partial_ys, infos)``: per-sample texts selected in the last stage,
        and one ``{'steps': info_dict}`` entry per sample per stage.

    Raises:
        ValueError: unknown ``op_method``/``select_method``, or a cluster-based
            ``select_method`` without ``op_method='op_persona'``.
    """
    persona_list = cluster_data["persona_list"]
    label_list = cluster_data["label_list"]
    cluster_center_info = cluster_data["cluster_center_info"]
    cluster_dist_info = cluster_data["cluster_dist_info"]
    cluster_vectors = cluster_data["cluster_vectors"]

    partial_ys = [[''] for _ in samples]

    infos = []

    # Per-stage stop strings; stage 1 stops at the "Counterargument:" header.
    stop_list = [
        None,
        ['Counterargument:\n', 'Counterargument:\n\n', 'Counterargument: ', 'Counterargument:'],
        None
    ]

    author_persona_list = None
    author_cluster_list = None
    author_embedding_list = None
    selected_clusters_list = None

    # Per sample: best stage-1 strategy/text, keyed by persona (slot id for wo_persona).
    best_strategies = [dict() for _ in samples]
    persona_best_texts = [dict() for _ in samples]

    if 'cluster' in select_method:
        if op_method != 'op_persona':
            # Cluster selection starts from the cluster nearest to the OP persona.
            raise ValueError(
                f"select_method={select_method!r} requires op_method='op_persona', got {op_method!r}"
            )
        author_cluster_list = []
        author_embedding_list = []
        selected_clusters_list = []

    # Original Poster (OP) Inference
    if op_method == 'wo_op_persona':
        author_persona_list = [None] * len(samples)
    elif op_method == 'op_persona':
        list_of_messages = []
        idx_map = []
        for i, x in enumerate(samples):
            for j, partial in enumerate(partial_ys[i]):
                message = get_op_persona_input_messages(
                    x,
                    partial,
                    op_method=op_method
                )
                list_of_messages.append(message)
                idx_map.append((i, j))

        batch_out = batch_generate(
            llm,
            tokenizer,
            list_of_messages,
            n=1,
        )

        author_persona_list = [gen_texts[0].strip() for gen_texts in batch_out]
        # Loop-invariant: convert the centroid matrix once rather than per sample.
        cluster_vectors_tensor = torch.tensor(cluster_vectors) if 'cluster' in select_method else None
        for i in range(len(samples)):
            # Keep the text after the first "Author's Persona:" marker (up to a
            # second one, if any); keep the whole reply when the marker is absent.
            persona_parts = author_persona_list[i].split("Author's Persona:")
            if len(persona_parts) > 1:
                author_persona_list[i] = persona_parts[1].strip()
            if 'cluster' in select_method:
                author_persona_embedding = get_openai_embeddings([author_persona_list[i]], model="text-embedding-3-large")[0]
                author_embedding_list.append(author_persona_embedding)

                similarities = F.cosine_similarity(
                    author_persona_embedding.unsqueeze(0),
                    cluster_vectors_tensor.to(author_persona_embedding.device),
                    dim=1
                )
                best_cluster_idx = similarities.argmax().item()
                # NOTE(review): assumes row k of cluster_vectors is label_list[k] — confirm against the centroid file order.
                author_cluster = label_list[best_cluster_idx]
                author_cluster_list.append(author_cluster)
    else:
        raise ValueError(f"Unknown op_method: {op_method!r}")

    step_list = [persona_n, 3, 3]

    select_new_partials = [list(partials) for partials in partial_ys]

    for step_i, n_seq in enumerate(step_list):
        print()
        print(f"==================={step_i}===================")
        stop = stop_list[step_i]

        if step_i == 0: # persona selection
            batch_personas = []
            vote_counts_list = []

            for i in range(len(samples)):
                if select_method == 'wo_persona':
                    mid_personas = [j + 1 for j in range(n_seq)]
                elif select_method == "random":
                    mid_personas = random.sample(persona_list, n_seq)
                elif select_method == "cluster_diff_case":
                    author_cluster = author_cluster_list[i]
                    # NOTE(review): cluster_dist_info is keyed by CSV row position; indexing by label assumes labels 0..K-1 in row order.
                    author_cluster_distances = cluster_dist_info[author_cluster]
                    sorted_indices = np.argsort(author_cluster_distances).tolist()

                    # NOTE(review): yields n_seq - 1 clusters when n_seq is even.
                    n_side = (n_seq - 1) // 2
                    op_same_idx = sorted_indices[0] # op_same (smallest distance)
                    similar_idx = sorted_indices[1:n_side + 1] # similar
                    # Explicit start index: ``[-0:]`` would select every cluster when n_side == 0.
                    dissimilar_idx = sorted_indices[max(len(sorted_indices) - n_side, 0):] # dissimilar

                    selected_clusters = [op_same_idx] + similar_idx + dissimilar_idx
                    selected_clusters_list.append(selected_clusters)

                    mid_personas = []
                    for idx in selected_clusters:
                        cluster_label = label_list[idx]
                        if cluster_label in cluster_center_info:
                            mid_personas.append(cluster_center_info[cluster_label])
                        else:
                            # Deliberate interactive pause for a missing centroid.
                            input(f"Cluster {cluster_label} not found in center info.")
                else:
                    raise ValueError(f"Unknown select_method: {select_method!r}")

                batch_personas.append(mid_personas)

            # Independent rows per sample (``[[None] * n] * m`` would alias them).
            selected_strategies = [[None] * n_seq for _ in samples]

        elif step_i == 1:
            if n_seq == 0:
                continue

            list_of_messages = []
            idx_map = []
            for i, x in enumerate(samples):
                for temp_idx, (persona, strategy) in enumerate(zip(batch_personas[i], selected_strategies[i])):
                    message = get_input_messages(
                        x,
                        partial_ys[i][0],
                        step=step,
                        op_method=op_method,
                        select_method=select_method,
                        author_persona=author_persona_list[i],
                        given_persona=persona if select_method != 'wo_persona' else None,
                        given_strategy=strategy
                    )
                    list_of_messages.append(message)
                    if select_method == 'wo_persona':
                        group_idx = (temp_idx % n_seq) + 1
                        idx_map.append((i, group_idx, strategy))
                    else:
                        idx_map.append((i, persona, strategy))

            batch_out = batch_generate(
                llm,
                tokenizer,
                list_of_messages,
                n=n_seq,
                stop=stop
            )

            new_partials = [[] for _ in samples]
            select_new_partials = [[] for _ in samples]
            vote_counts_list = [[] for _ in samples]

            for (i, persona, strategy), gen_texts in zip(idx_map, batch_out):
                for t in gen_texts:
                    merged = partial_ys[i][0] + t
                    new_partials[i].append((persona, strategy, merged))

        elif step_i == 2:
            list_of_messages = []
            idx_map = []
            for i, x in enumerate(samples):
                for group_i, (persona, strategy) in enumerate(zip(batch_personas[i], selected_strategies[i])):
                    # wo_persona stores stage-1 texts under slot ids 1..k; otherwise under the persona.
                    if group_i + 1 in persona_best_texts[i]:
                        prefix = persona_best_texts[i][group_i + 1]
                    else:
                        prefix = persona_best_texts[i][persona]

                    message = get_input_messages(
                        x,
                        prefix,
                        step=step,
                        op_method=op_method,
                        select_method=select_method,
                        author_persona=author_persona_list[i],
                        given_persona=persona if select_method != 'wo_persona' else None,
                        given_strategy=strategy
                    )
                    list_of_messages.append(message)
                    idx_map.append((i, persona, strategy))

            batch_out = batch_generate(
                llm,
                tokenizer,
                list_of_messages,
                n=n_seq,
                stop=stop
            )

            new_partials = [[] for _ in samples]
            select_new_partials = [[] for _ in samples]
            vote_counts_list = [[] for _ in samples]

            for (i, persona, strategy), gen_texts in zip(idx_map, batch_out):
                # Persona-conditioned runs at step 3 prepend the stage-1 text;
                # otherwise the raw continuation is kept.
                if step == 3 and select_method != 'wo_persona':
                    prefix = persona_best_texts[i][persona]
                else:
                    prefix = ''
                for t in gen_texts:
                    new_partials[i].append((persona, strategy, prefix + t))

        if step_i != 0:
            vote_counts_list = [[] for _ in samples]

            if op_method == 'wo_op_persona' and select_method == 'wo_persona':
                format_prompt = vote_one_plan_prompt_wo_op_and_persona if step_i == 1 else vote_one_counter_prompt_wo_op_and_persona
            elif op_method == 'wo_op_persona':
                format_prompt = vote_one_plan_prompt_wo_op if step_i == 1 else vote_one_counter_prompt_wo_op
            elif select_method == 'wo_persona':
                format_prompt = vote_one_plan_prompt_wo_persona if step_i == 1 else vote_one_counter_prompt_wo_persona
            else:
                format_prompt = vote_one_plan_prompt_with_persona if step_i == 1 else vote_one_counter_prompt_with_persona

            selected_persona_strategy_list = [[] for _ in samples] if step_i == 1 else None

            for i, new_partial in enumerate(new_partials):
                if step_i == 1:
                    vote_count_list = {}

                    grouped = {}
                    if select_method == 'wo_persona':
                        # NOTE(review): groups by position modulo n_seq, not by the slot id stored in the tuple — confirm intended.
                        for idx, (persona, strategy, text) in enumerate(new_partial):
                            group_id = (idx % n_seq) + 1
                            grouped.setdefault(group_id, []).append((persona, text))
                    else:
                        for persona, strategy, text in new_partial:
                            grouped.setdefault(persona, []).append((strategy, text))

                    for persona, strategy_texts in grouped.items():
                        yss = [[t for _, t in strategy_texts]]

                        temp_vote = get_batch_one_votes(
                            llm,
                            tokenizer,
                            xs=samples[i],
                            yss_list=yss,
                            format_prompt=format_prompt,
                            given_persona=persona,
                            author_persona=author_persona_list[i],
                            n_evaluate_sample=len(strategy_texts) * 3,
                        )
                        counts = temp_vote[0]
                        idx_sorted = sorted(range(len(counts)), key=lambda k: counts[k], reverse=True)
                        best_strategy, best_text = strategy_texts[idx_sorted[0]]
                        if select_method == 'wo_persona':
                            # ``strategy`` is the last value bound by the grouping loop above.
                            best_strategy = strategy

                        best_strategies[i][persona] = best_strategy
                        persona_best_texts[i][persona] = best_text
                        select_new_partials[i].append(best_text)

                        selected_persona_strategy_list[i].append({
                            'persona': persona,
                            'strategy': best_strategy
                        })

                        vote_count_list[persona] = {
                            'counts': counts
                        }

                    vote_counts_list[i] = vote_count_list

                elif step_i == 2:
                    grouped_ys = [new_partial[j:j+n_seq] for j in range(0, len(new_partial), n_seq)]

                    # A single vote request covers every group of this sample.
                    temp_vote = []
                    if grouped_ys:
                        temp_vote = get_batch_counter_one_votes(
                            llm,
                            tokenizer,
                            xs=samples[i],
                            yss_list=[[t for _, _, t in group] for group in grouped_ys],
                            format_prompt=format_prompt,
                            given_persona_list=[group[0][0] for group in grouped_ys],
                            author_persona=author_persona_list[i],
                            n_evaluate_sample=n_seq * 3,
                        )
                    vote_counts_list[i] += temp_vote

                    for j, group in enumerate(grouped_ys):
                        cands = group
                        if len(cands) != n_seq:
                            print('cands: Error!!')
                            sys.exit(-1)
                        counts = temp_vote[j]
                        idx_sorted = sorted(range(len(counts)), key=lambda k: counts[k], reverse=True)
                        best1 = cands[idx_sorted[0]][2]
                        select_new_partials[i].append(best1)

        for i in range(len(samples)):
            if step_i == 0:
                info_dict = {
                    'step': step_i,
                    'x': samples[i],
                    'ys': partial_ys[i],
                    'author_persona': author_persona_list[i] if author_persona_list else None,
                    'values': vote_counts_list[i] if (vote_counts_list and vote_counts_list != []) else [],
                    'select_strategies': selected_strategies[i]
                }
                infos.append(info_dict)

            elif step_i == 1:
                info_dict = {
                    'step': step_i,
                    'prompt': list_of_messages[0],
                    'x': samples[i],
                    'ys': partial_ys[i],
                    'new_ys': new_partials[i] if 'new_partials' in locals() else [],
                    'author_persona': author_persona_list[i] if author_persona_list else None,
                    'personas': batch_personas[i],
                    'values': vote_counts_list[i] if (vote_counts_list and vote_counts_list != []) else [],
                    'selected_persona_strategy': selected_persona_strategy_list[i] if selected_persona_strategy_list else [],
                    'select_new_ys': select_new_partials[i]
                }
                if 'cluster' in select_method:
                    info_dict.update({
                        'personas_clusters': selected_clusters_list[i] if selected_clusters_list else None
                    })
                infos.append(info_dict)
            else:
                infos.append({
                    'step': step_i,
                    'prompt': list_of_messages[0],
                    'x': samples[i],
                    'ys': partial_ys[i],
                    'new_ys': new_partials[i] if 'new_partials' in locals() else [],
                    'values': vote_counts_list[i] if (vote_counts_list and vote_counts_list != []) else [],
                    'select_new_ys': select_new_partials[i]
                })

        partial_ys = select_new_partials

        # Free per-stage buffers; some are not bound yet after stage 0.
        try:
            del list_of_messages, batch_out, new_partials
        except NameError:
            pass

        gc.collect()
        torch.cuda.empty_cache()

    return partial_ys, [{'steps': info} for info in infos]

def _append_jsonl(path, records):
    """Append each record in *records* to *path* as one JSON object per line.

    The file is opened in append mode (UTF-8, non-ASCII kept as-is) on every
    call, so each finished batch is on disk even if a later batch crashes.
    """
    with open(path, 'a', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False))
            f.write('\n')

def run(op_method, select_method, step, persona_n, gpu_num="0,1,2,3"):
    """Run ``solve`` over the whole processed test set and stream results to disk.

    Loads ``meta-llama/Llama-3.1-8B-Instruct`` with vLLM (tensor-parallel over
    the GPUs listed in *gpu_num*) and reads
    ``./data/processed_multiple_test_data.pickle``. It then calls ``solve`` on
    batches of 4 posts. Two JSONL files are appended to inside a fresh
    ``./results/<timestamp>/`` directory:

    * ``tot_llama_<op>_personahub_<select>_step<step>_<n>persona.jsonl``:
      the per-step logs returned by ``solve``, each tagged with ``post_id``.
    * ``..._<n>persona_gens.jsonl``: one line per post with ``post_id``,
      its global ``idx`` in the dataset and the final ``ys``.

    Args:
        op_method: Forwarded to ``solve`` and embedded in the output file names.
        select_method: Forwarded to ``solve`` and embedded in the output file names.
        step: Forwarded to ``solve`` and embedded in the output file names.
        persona_n: Forwarded to ``solve`` and embedded in the output file names.
        gpu_num: Comma-separated GPU ids. The number of non-empty ids becomes
            vLLM's ``tensor_parallel_size``.
    """
    model_name = "meta-llama/Llama-3.1-8B-Instruct"
    # Count only non-empty ids so a stray trailing comma ("0,1,") does not
    # inflate the tensor-parallel size. Never go below 1, matching the old
    # behaviour for an empty string.
    gpu_ids = [g for g in gpu_num.split(",") if g.strip()]
    llm = LLM(
        model=model_name,
        model_impl="transformers",
        max_model_len=8192,
        gpu_memory_utilization=0.90,
        tensor_parallel_size=max(len(gpu_ids), 1),
        task="generate",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

    # NOTE: read_pickle unpickles arbitrary objects, so only point it at trusted data.
    data_df = pd.read_pickle('./data/processed_multiple_test_data.pickle')
    # One input string per post: the joined conclusion pieces, then the joined
    # premise pieces. Both columns are presumably lists of strings; confirm
    # against the preprocessing step.
    data = [
        ' '.join(conclusion) + ' ' + ' '.join(premises)
        for conclusion, premises in zip(data_df['conclusion'].tolist(), data_df['premises'].tolist())
    ]
    post_id_list = data_df['post_id'].tolist()

    cluster_data = load_cluster_resources(n_components=50, min_cluster_size=200)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_dir = os.path.join("./results/", timestamp)
    os.makedirs(out_dir, exist_ok=True)

    run_tag = f"tot_llama_{op_method}_personahub_{select_method}_step{step}_{persona_n}persona"
    log_file = os.path.join(out_dir, f"{run_tag}.jsonl")
    gens_file = os.path.join(out_dir, f"{run_tag}_gens.jsonl")

    batch_size = 4
    for batch_start in tqdm(range(0, len(data), batch_size)):
        # Clamp so the final (possibly short) batch reports its true end index.
        batch_end = min(batch_start + batch_size, len(data))
        batch_data = data[batch_start:batch_end]
        batch_pids = post_id_list[batch_start:batch_end]

        final_ys_batch, final_infos_batch = solve(
            llm,
            tokenizer,
            batch_data,
            cluster_data=cluster_data,
            s_id=batch_start,
            e_id=batch_end,
            op_method=op_method,
            select_method=select_method,
            step=step,
            persona_n=persona_n,
            batch_pids=batch_pids,
        )

        # solve() appends one info per post inside each search step, so its
        # logs are step-major and post ids repeat once per step. cycle()
        # pairs them for any number of steps, where `batch_pids * 3` silently
        # dropped log lines beyond the third step.
        _append_jsonl(
            log_file,
            (
                {'post_id': pid, **entry}
                for pid, entry in zip(itertools.cycle(batch_pids), final_infos_batch)
            ),
        )

        _append_jsonl(
            gens_file,
            (
                {'post_id': pid, 'idx': batch_start + offset, 'ys': final_ys}
                for offset, (pid, final_ys) in enumerate(zip(batch_pids, final_ys_batch))
            ),
        )

        # Drop per-batch results before the next (memory-heavy) generation round.
        del final_ys_batch, final_infos_batch
        gc.collect()

def print_run(op_method, select_method, step, persona_n, gpu_num):
    """Print a human-readable summary of the chosen run configuration.

    A ``step`` of 3 is labelled "Plan + Counter"; every other value is
    labelled "Counter only".
    """
    mode_label = 'Plan + Counter' if step == 3 else 'Counter only'
    summary = [
        "Running with:",
        f"  Operation Method: {op_method}",
        f"  Selection Method: {select_method}",
        f"  Step: {step} ({mode_label})",
        f"  Number of Random Personas: {persona_n}",
        f"  GPU(s): {gpu_num}",
    ]
    for text in summary:
        print(text)

if __name__ == '__main__':
    # CLI entry point: parse options, restrict visible GPUs, echo the
    # configuration, then run the full generation.
    parser = argparse.ArgumentParser(description="Run persona-based counterargument generation.")

    parser.add_argument(
        "--op_method",
        choices=["op_persona", "wo_op_persona"],
        required=True,
        help="OP Persona LLM inference: 'op_persona' (infer) or 'wo_op_persona' (do not infer)"
    )

    parser.add_argument(
        "--select_method",
        choices=["random", "wo_persona", "cluster_random", "cluster_diff_case"],
        required=True,
        help="Persona selection method"
    )

    parser.add_argument(
        "--persona_n",
        type=int,
        choices=[1, 2, 3, 4, 5],
        default=3,
        help="Number of random personas to sample (n = 1..5). Default: 3"
    )

    parser.add_argument(
        "--step",
        type=int,
        choices=[2, 3],
        required=True,
        help="2: Counter only (no plan), 3: Plan + Counter"
    )

    parser.add_argument(
        "--gpu_num",
        default="0,1,2,3",
        help="Comma-separated GPU numbers to use (e.g., '0', '1,2', or '0,1,2,3')"
    )

    args = parser.parse_args()

    # CUDA reads CUDA_VISIBLE_DEVICES only when it initialises, so this must be
    # set before anything touches the GPU. Child processes (e.g. vLLM workers)
    # inherit it from os.environ.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num

    print_run(
        op_method=args.op_method,
        select_method=args.select_method,
        step=args.step,
        persona_n=args.persona_n,
        gpu_num=args.gpu_num,
    )

    run(
        op_method=args.op_method,
        select_method=args.select_method,
        step=args.step,
        persona_n=args.persona_n,
        gpu_num=args.gpu_num,
    )

    # sys.exit rather than the site-injected `exit`, which is meant for the
    # interactive shell and is absent when Python runs with -S.
    sys.exit(0)