from datasets import load_dataset
from litellm import acompletion
import litellm
import asyncio
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm.asyncio import tqdm
from typing import Optional, Dict
from datasets import Dataset
from pdb import set_trace
import click
import os
import random
from dotenv import load_dotenv
from value import values
load_dotenv()
import re
import json
import uuid
SHARD_SIZE = 1000000
# Upper bound on examples waiting between the producer and the consumers.
MAX_QUEUE_SIZE = 100
litellm.enable_json_schema_validation=True

# persons: list[str] -- persona descriptions sampled to role-play the evaluator
# when an answer is scored.
with open("Datasets/person_500.json", 'r', encoding='utf-8') as file:
    persons_json = json.load(file)
persons = [d["description"] for d in persons_json]

# Maximum score/rewrite rounds per item for the "rewrite_answer2" task.
MAX_REWRITE_NUM2 = 200
# Maximum score/rewrite rounds per agreement level for the "rewrite_answer" task.
MAX_REWRITE_NUM = 300
# Attempts per prompt before giving up on a usable model reply.
MAX_SAMPLES = 3
# Number of personas drawn for each item.
PERSONS_SAMPLE_NUM = 5

# Prompt templates are read as UTF-8 (like the persona file above) so that
# non-ASCII text does not depend on the platform's default encoding.

# task 1: generate a question for a given statement ("{answer}" placeholder).
with open("prompt/exp2_generate_prompt_novalues.txt", 'r', encoding='utf-8') as file:
    GENERATE_QUESTION = file.read()
# Captures the string value of the "question" key in the model's JSON-like reply.
GENERATE_QUESTION_pattern = r'\{\n\s+"question": "(.*?)"'

# task 2: answer a question at a given agreement level.
with open("prompt/task2_generate_answer.txt", 'r', encoding='utf-8') as file:
    GENERATE_ANSWER = file.read()

# task 3: rewrite an answer toward an agreement level, then evaluate it.
with open("prompt/exp3_rewrite_respond2.txt", 'r', encoding='utf-8') as file:
    REWRITE_ANSWER = file.read()
# NOTE(review): answer_pattern / evaluate_pattern are not referenced by the
# visible code, and each match includes the closing quote character.
answer_pattern = r'(?<="New_Answer": ")[^"]*"'

# NOTE(review): "rewrire" looks like a typo but must match the file on disk.
with open("prompt/exp3_rewrire_evaluate.txt", 'r', encoding='utf-8') as file:
    EVALUATE_ANSWER = file.read()
evaluate_pattern = r'(?<="levels": ")[^"]*"'

# Target score for each agreement level; evaluated answers are scored on the
# same -2..2 scale.
agree_level_to_score = {
    'Strongly disagree':-2,
    'Somewhat disagree':-1, 
    'Neutral':0, 
    'Somewhat agree':1, 
    'Strongly agree':2
}

AGREE_LEVELS = list(agree_level_to_score.keys())

# An answer is accepted once |score - target score| is at most this value.
AGREE_LEVEL_THRESH = 0.45

def extract_first_number(text):
    """Return the first signed integer or decimal literal found in *text*.

    The match is returned as a string (e.g. ``"-2.5"``); ``None`` is returned
    when *text* contains no digits at all.
    """
    found = re.search(r'[+-]?\d+(\.\d+)?', text)
    return None if found is None else found.group(0)

async def async_completion(
    prompt, model: str, api_base: Optional[str] = None, task = None
):
    """Send *prompt* as a single user message to *model* via litellm.

    Args:
        prompt: content of the user message.
        model: litellm model identifier.
        api_base: optional base URL passed through to litellm.
        task: unused.

    Returns:
        The ``acompletion`` response, or ``None`` when anything raised (the
        exception is printed, not propagated).
    """
    try:
        request = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "api_base": api_base,
            "num_retries": 3,
            "temperature": 1.0,
        }
        # Presumably targets Meta Llama 3 models: stop on their end-of-turn /
        # end-of-text special tokens.
        if 'Meta' in model and '3' in model:
            request["stop"] = ["<|eot_id|>", "<|end_of_text|>"]
        return await acompletion(**request)
    except Exception as e:
        print(e)
        return None


async def producer(queue, examples):
    """Put every example onto *queue* in order (``put`` waits while a bounded queue is full)."""
    pending = iter(examples)
    for work_item in pending:
        await queue.put(work_item)
    print("Producer: Finished adding examples to the queue.")


async def save_shard(data, output_path):
    """Write the records in *data* to *output_path* as a parquet file.

    *data* is anything ``pandas.DataFrame`` accepts (here, a list of dicts).
    The blocking parquet write runs in the default executor so it does not
    stall the event loop.
    """
    table = pa.Table.from_pandas(pd.DataFrame(data))
    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, pq.write_table, table, output_path)


def _persona_preamble(persona):
    """Return the role-play instructions that make the evaluator answer as *persona*."""
    return (
        "USER: Let's role-play. I will ask you a question and you must give me an answer. "
        "I want you to act as the person described below. Think from the person's perspective.\n"
        f"{persona} \n"
        "Use the given information to answer the question below.\n"
    )


def _message_content(results):
    """Return the text of the first choice of a completion, or None if the request failed."""
    if results is None:
        return None
    return results["choices"][0]['message']['content']


def _parse_agree_level(text):
    """Map an evaluator reply to -1 (disagree), 0 (neutral), 1 (agree), or None.

    "disagree" is tested before "agree" because it contains "agree" as a
    substring. A missing reply (None) is treated as unrecognised.
    """
    lowered = (text or "").lower()
    if "disagree" in lowered:
        return -1
    if "neutral" in lowered:
        return 0
    if "agree" in lowered:
        return 1
    return None


async def _generate_question(statement, model, api_base):
    """Ask the model for a question matching *statement*; "" if no reply parses."""
    prompt = GENERATE_QUESTION.replace("{answer}", statement)
    for _ in range(MAX_SAMPLES):
        text = _message_content(await async_completion(prompt, model, api_base=api_base))
        match = re.search(GENERATE_QUESTION_pattern, text or "")
        if match:
            return match.group(1)
    return ""


async def _score_answer(answer, persona_infor, value_infor, model, api_base):
    """Have each persona rate *answer* against every sub-value of *value_infor*.

    Returns:
        (score, person_scores_dicts): *score* is the mean per-persona rating
        scaled by 2, so it lies in [-2, 2] like ``agree_level_to_score``;
        *person_scores_dicts* holds one {sub_value: -1/0/1} dict per persona.

    Raises:
        RuntimeError: if no persona produced a single parseable rating.
    """
    person_scores = []
    person_scores_dicts = []
    for persona in persona_infor:
        preamble = _persona_preamble(persona)
        scores_dict = {}
        scores = []
        for sub_value in value_infor['sub_values']:
            sub_value_prompt = (
                EVALUATE_ANSWER.replace("{grained_description}", sub_value['description'])
                .replace("{description}", value_infor['description'])
                .replace("{value}", value_infor['value'])
                .replace("{statement}", answer)
            )
            full_prompt = preamble + sub_value_prompt
            for _ in range(MAX_SAMPLES):
                results = await async_completion(full_prompt, model, api_base=api_base)
                level = _parse_agree_level(_message_content(results))
                if level is not None:
                    scores_dict[sub_value['sub_value']] = level
                    scores.append(level)
                    break
        person_scores_dicts.append(scores_dict)
        if scores:
            person_scores.append(float(sum(scores) / len(scores)))
        else:
            # Every sample was unparseable: leave this persona out of the
            # average instead of dividing by zero.
            print("Consumer: no parseable evaluation for a persona; excluding it from the average.")
    if not person_scores:
        raise RuntimeError("no parseable evaluation from any persona")
    return float(sum(person_scores) / len(person_scores)) * 2, person_scores_dicts


async def _rewrite_once(answer, question, agree_level, value_infor, model, api_base):
    """Ask the model to rewrite *answer* toward *agree_level*; None if every reply was empty."""
    prompt = (
        REWRITE_ANSWER.replace("{description}", value_infor['description'])
        .replace("{question}", question)
        .replace("{value}", value_infor['value'])
        .replace("{agree_level}", agree_level)
        .replace("{answer}", answer)
    )
    for _ in range(MAX_SAMPLES):
        response = _message_content(await async_completion(prompt, model, api_base=api_base))
        if response:
            return response
    return None


async def _search_rewrite(answer, question, agree_level, value_infor, persona_infor,
                          model, api_base, max_rounds, include_target=False, history=None):
    """Alternately score and rewrite *answer* until its score is near *agree_level*.

    Each round scores the current answer and stops once
    ``|score - agree_level_to_score[agree_level]| <= AGREE_LEVEL_THRESH``;
    otherwise it requests a rewrite. Non-empty rewrites are appended to
    *history* when a list is given.

    Returns:
        (best_result, converged): the record of the closest-scoring round with
        ``rewrite_times`` added (None when *max_rounds* is 0), and whether the
        threshold was reached.
    """
    target = agree_level_to_score[agree_level]
    best_score_diff = float('inf')
    best_result = None
    converged = False
    new_answer = answer
    rewrite_times = 0
    for _ in range(max_rounds):
        score, person_scores_dicts = await _score_answer(
            new_answer, persona_infor, value_infor, model, api_base
        )
        score_diff = abs(score - target)
        if score_diff < best_score_diff:
            best_score_diff = score_diff
            best_result = {
                'answer': new_answer,
                'persons_infor': {
                    'persona_infor': persona_infor,
                    'person_scores': person_scores_dicts,
                },
                'score': score,
            }
            if include_target:
                best_result['target_score'] = target
        if score_diff <= AGREE_LEVEL_THRESH:
            converged = True
            break
        rewritten = await _rewrite_once(new_answer, question, agree_level, value_infor, model, api_base)
        if rewritten:
            new_answer = rewritten
            if history is not None:
                history.append(rewritten)
        rewrite_times += 1
    if best_result is not None:
        best_result['rewrite_times'] = rewrite_times
    return best_result, converged


async def _process_item(item, task, model, api_base):
    """Run *task* on one queue item.

    Returns:
        (data, unique_id, value_infor): the record to save, the id used in its
        file name, and the value descriptor (None for "generate_question",
        whose items carry none).

    Raises:
        ValueError: if *task* is not a supported task name.
    """
    if task == "generate_question":
        item_id, statement = item
        question = await _generate_question(statement, model, api_base)
        return {'item_id': item_id, 'statement': statement, 'question': question}, item_id, None

    if task == "generate_answer":
        item_id, question, value_infor = item
        data = {'question': question, 'item_id': item_id}
        prompt = (
            GENERATE_ANSWER.replace("{description}", value_infor['description'])
            .replace("{question}", question)
            .replace("{value}", value_infor['value'])
        )
        for level in AGREE_LEVELS:
            results = await async_completion(prompt.replace("{agree_level}", level), model, api_base=api_base)
            data[level] = _message_content(results)
        return data, item_id, value_infor

    if task == "rewrite_answer":
        item_id, question, answers, value_infor = item
        data = {'question': question, 'item_id': item_id}
        # One persona panel is shared by all agreement levels of this item.
        persona_infor = random.sample(persons, PERSONS_SAMPLE_NUM)
        for level in AGREE_LEVELS:
            best_result, _ = await _search_rewrite(
                answers[level], question, level, value_infor, persona_infor,
                model, api_base, MAX_REWRITE_NUM,
            )
            data[level] = best_result if best_result is not None else {}
        return data, item_id, value_infor

    if task == "rewrite_answer2":
        item_id, answer, question, agree_level, value_infor = item
        data = {'question': question, 'item_id': item_id, 'answer': {}, 'history': [answer]}
        persona_infor = random.sample(persons, PERSONS_SAMPLE_NUM)
        best_result, converged = await _search_rewrite(
            answer, question, agree_level, value_infor, persona_infor,
            model, api_base, MAX_REWRITE_NUM2,
            include_target=True, history=data['history'],
        )
        if converged:
            print('finish')
        if best_result is not None:
            data['answer'] = best_result
        # The same item may be queued several times, so add a uuid to keep file names unique.
        return data, str(item_id) + str(uuid.uuid4()), value_infor

    raise ValueError("Invalid task")


async def _save_record(data, unique_id, value_infor, output_path):
    """Write *data* as a one-row parquet file under ``dirname(output_path)[/<value>]``."""
    out_dir = os.path.dirname(output_path)
    if value_infor is not None:
        out_dir = os.path.join(out_dir, value_infor['value'])
    if out_dir:
        # exist_ok: the directory is shared by every item with the same value.
        os.makedirs(out_dir, exist_ok=True)
    await save_shard([data], os.path.join(out_dir, f"item_{unique_id}.parquet"))


async def consumer(queue, model, api_base, pbar, task, output_path):
    """Worker loop: take items from *queue*, run *task* on each and save the result.

    Every item is marked done and counted on *pbar*, even when processing
    fails (the error is printed), so a ``join()`` on the queue is never left
    waiting for a failed item. The loop ends only when the task is cancelled.
    """
    while True:
        try:
            item = await queue.get()
        except asyncio.CancelledError:
            print("Consumer: Task cancelled.")
            break
        try:
            data, unique_id, value_infor = await _process_item(item, task, model, api_base)
            await _save_record(data, unique_id, value_infor, output_path)
        except asyncio.CancelledError:
            queue.task_done()
            print("Consumer: Task cancelled.")
            break
        except Exception as err:
            # Keep the worker alive: one bad item (failed request, unparseable
            # evaluation, unknown task) must not stall the whole queue.
            print(f"Consumer: failed to process an item for task {task!r}: {err!r}")
        pbar.update(1)
        queue.task_done()


async def run(
    examples,
    model,
    output_path: str,
    concurrency: int = 5,
    api_base: Optional[str] = None,
    task_name = None
):
    """Process *examples* with *concurrency* consumer workers.

    Args:
        examples: work items; their tuple shape depends on *task_name*.
        model: litellm model identifier.
        output_path: results are written relative to this path by the consumers.
        concurrency: number of consumer workers.
        api_base: optional base URL of the model server.
        task_name: which pipeline step the consumers run.

    Raises:
        RuntimeError: if a consumer stops before every example was processed
            (its exception is chained), instead of waiting forever on
            ``queue.join()``.
    """
    input_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
    pbar = tqdm(total=len(examples), desc='Processing examples', position=0, unit='examples')

    async def feed_and_drain():
        # Enqueue everything, then wait until every item has been marked done.
        await producer(input_queue, examples)
        await input_queue.join()

    drain_task = asyncio.create_task(feed_and_drain())
    consumer_tasks = [
        asyncio.create_task(consumer(input_queue, model, api_base, pbar, task_name, output_path))
        for _ in range(concurrency)
    ]
    try:
        # Consumers only return when cancelled, so one finishing first means it
        # crashed; watching them alongside the drain avoids an endless join().
        done, _ = await asyncio.wait(
            [drain_task, *consumer_tasks], return_when=asyncio.FIRST_COMPLETED
        )
        if drain_task in done:
            drain_task.result()  # re-raises a producer failure, if any
        else:
            stopped = next(iter(done))
            cause = None if stopped.cancelled() else stopped.exception()
            raise RuntimeError("A consumer stopped before the queue was drained") from cause
    finally:
        for task in (drain_task, *consumer_tasks):
            task.cancel()
        # Let the cancellations complete so no task is left pending at shutdown.
        await asyncio.gather(drain_task, *consumer_tasks, return_exceptions=True)
        pbar.close()
@click.command()
@click.argument('model_name')
@click.option('--concurrency', default=10, help='The number of concurrent requests to make to the model.')
@click.option('--api_base', default="http://localhost:8811/v1", help='The base URI for the model.')
@click.option('--seed', default=2025, help='The random seed.')
@click.option('--output_path', default='Datasets/test', help='The path to save the output dataset.')
@click.option('--dataset', default='Datasets/test', help='The input dataset: a .parquet file, otherwise read as CSV.')
@click.option('--task_name', default='generate_questions', help='The task to run (only rewrite_answer2 builds examples).')
def main(model_name, concurrency, api_base, seed, output_path, dataset, task_name):
    """Load the input dataset, build work items and run them against MODEL_NAME.

    Only the first 10000 rows are used. Examples are built only for the
    "rewrite_answer2" task; any other task name enqueues nothing.
    """
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    print(f'==> input_dataset: {dataset}')
    max_rows = 10000
    if '.parquet' in dataset:
        df = pd.read_parquet(dataset)
    else:
        df = pd.read_csv(dataset, index_col=0)
    if df.shape[0] >= max_rows:
        df = df.iloc[:max_rows, :]
    ds = Dataset.from_pandas(df)

    if task_name != 'rewrite_answer2':
        print(f"==> warning: no examples are built for task '{task_name}'; nothing will be processed")

    examples = []
    with tqdm(total=ds.shape[0], desc='questions', position=0, unit='examples') as tqb:
        for example in ds:
            if task_name == 'rewrite_answer2':
                examples.append((
                    example['index'],
                    example['answer'],
                    example['question'],
                    example['agree_level'],
                    values[example['value']],
                ))
            tqb.update(1)

    # Qwen / Llama / gemma checkpoints are routed through litellm's "openai/"
    # provider (presumably an OpenAI-compatible server at --api_base); other
    # model names are passed to litellm unchanged.
    if model_name.startswith('Qwen') or 'Llama' in model_name or 'gemma' in model_name:
        model_id = f'openai/{model_name}'
    else:
        model_id = model_name

    asyncio.run(
        run(
            examples,
            model_id,
            output_path=output_path,
            concurrency=concurrency,
            api_base=api_base,
            task_name=task_name
        )
    )

if __name__ == "__main__":
    # Invoke the click command; click reads its arguments from the command line.
    main()