import datasets
import re
import argparse
import time
from tqdm import tqdm


def init_client():
    """Return the chat-completion client used to query the judge model.

    This is a placeholder: no client is constructed here yet, so the
    function currently returns ``None``. Fill in the real client setup
    before running the pipeline.
    """
    # Initialize the client here.
    return None


def _split_preference_pair(chosen, rejected):
    """Split a chosen/rejected transcript pair into a shared prompt and two replies.

    Both transcripts are cut at every ``Human:`` / ``Assistant:`` marker.

    Returns:
        ``(prompt, chosen_response, rejected_response)`` with the speaker tags
        in the prompt shortened to ``H:`` / ``A:``, or ``None`` when the pair
        cannot be compared: the turn counts differ, either final turn is not
        an assistant turn, or the conversations before the final turn differ.
    """
    turn_boundary = r"(?=\bHuman:|\bAssistant:)"
    chosen_turns = re.split(turn_boundary, chosen)
    rejected_turns = re.split(turn_boundary, rejected)

    if len(chosen_turns) != len(rejected_turns):
        return None
    # Both candidates must end with an assistant reply; otherwise the "response"
    # handed to the judge would actually be a human turn.
    if not chosen_turns[-1].startswith("Assistant:"):
        return None
    if not rejected_turns[-1].startswith("Assistant:"):
        return None

    shared_context = "".join(chosen_turns[:-1])
    if shared_context != "".join(rejected_turns[:-1]):
        return None

    prompt = shared_context.replace("Human:", "H:").replace("Assistant:", "A:").strip()
    chosen_response = re.sub(r"Assistant:", "", chosen_turns[-1], count=1).strip()
    rejected_response = re.sub(r"Assistant:", "", rejected_turns[-1], count=1).strip()
    return prompt, chosen_response, rejected_response


def _build_message(few_shot_text, question, prompt, chosen_response, rejected_response):
    """Assemble the user message sent to the judge model.

    The text (including the whitespace-only second line) is kept identical to
    the prompt this script has always sent.

    NOTE(review): the chosen response is always shown as "Response A" and the
    rejected one as "Response B"; the order is never shuffled.
    """
    return (
        f"{few_shot_text}\n\n"
        "Consider the following conversation between a human (H) and an assistant (A).\n"
        "        \n"
        "<start_conversation>\n"
        f"{prompt}\n"
        "<end_conversation>\n"
        "\n"
        f"{question}\n"
        "\n"
        f"Response A: {chosen_response}\n"
        f"Response B: {rejected_response}\n"
        "\n"
        "Explanation:\n"
        "Answer:\n"
    )


def _parse_answer(text):
    """Map the judge's raw reply to a label: 1 (A), 2 (B) or 3 (invalid).

    Only the text between the first and second ``Answer:`` marker is inspected.
    A reply with no ``Answer:`` marker, an empty/``None`` reply, or a reply that
    mentions both letters is labelled 3.
    """
    if not text or "Answer:" not in text:
        return 3

    answer_part = text.split("Answer:")[1].strip()
    # NOTE(review): this matches every capital A or B after "Answer:", so a word
    # such as "Because" also counts as a "B" vote.
    letters = set(re.findall(r"\(?([AB])\)?", answer_part))
    if letters == {"A"}:
        return 1
    if letters == {"B"}:
        return 2
    return 3


def _query_judge(client, message_text, retry_delay):
    """Query the judge model until it replies and return ``(text, answer)``.

    A ``ValueError`` raised by the request is not retried and yields
    ``("ValueError", 4)``. Any other exception is printed and the request is
    retried after ``retry_delay`` seconds, without an upper bound on attempts.

    Raises:
        ValueError: if ``client`` is ``None`` (retrying could never succeed).
    """
    if client is None:
        raise ValueError("client is None; initialize the client before querying the model")

    # Odd leading space and indentation are preserved so the system prompt is
    # exactly the text of the original triple-quoted literal.
    continuation_indent = " " * 25
    system_prompt = (
        " You are a careful, helpful and diligent assistant. \n"
        + continuation_indent
        + "Your task is to evaluate conversations between a human and an AI assistant, \n"
        + continuation_indent
        + "and you will evaluate which of two responses is better in terms of helpfulness."
    )

    while True:
        # Only the request itself lives in the try block, so parsing problems
        # can never be mistaken for a transient API failure and retried forever.
        try:
            completion = client.chat.completions.create(
                model="",
                seed=123,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": message_text},
                ],
            )
            text = completion.choices[0].message.content
        except ValueError:
            return "ValueError", 4
        except Exception as e:
            print(f"Error: {e}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            text = text or ""  # a reply without content is stored as an empty string
            return text, _parse_answer(text)


def process_conversations(
    train_data,
    few_shot_text,
    client,
    output_path,
    *,
    checkpoint_every=200,
    retry_delay=60,
):
    """Label every usable preference pair with the judge model and save the result.

    For each example the ``chosen`` and ``rejected`` transcripts are split into
    a shared conversation and two final assistant replies. The judge is asked
    which reply is better, and a copy of the example is stored with two extra
    fields:

    * ``answer_llm``: 1 = Response A (the chosen reply), 2 = Response B (the
      rejected reply), 3 = missing or ambiguous verdict, 4 = the request raised
      ``ValueError``.
    * ``message_llm``: the raw reply text (``"ValueError"`` for label 4).

    Examples whose transcripts cannot be aligned are skipped.

    Intermediate shards are written to ``f"{output_path}_{i}"`` whenever the
    index ``i`` completes a block of ``checkpoint_every`` examples and there is
    labelled data to flush. Leftovers go to ``f"{output_path}_remaining"``. After
    the loop, all shards that were actually written are reloaded and merged into
    ``output_path``. Nothing is written to ``output_path`` when no example was
    labelled.

    Args:
        train_data: Sized iterable of mappings with ``"chosen"`` and
            ``"rejected"`` transcript strings.
        few_shot_text: Few-shot examples prepended to every prompt.
        client: Object exposing ``chat.completions.create(...)``.
        output_path: Destination of the merged dataset and prefix of the shards.
        checkpoint_every: Number of input examples between shard flushes.
        retry_delay: Seconds to wait before retrying a failed request.

    Raises:
        ValueError: if ``checkpoint_every`` is smaller than 1, or if ``client``
            is ``None`` when a query is needed.
    """
    if checkpoint_every < 1:
        raise ValueError("checkpoint_every must be a positive integer")

    question = "Choose the most helpful and honest response. Give a one-sentence explanation for your answer."
    skipped = 0
    pending = []  # labelled examples not yet flushed to a shard
    shard_paths = []  # shards actually written, in write order

    for i, data in tqdm(enumerate(train_data), total=len(train_data), dynamic_ncols=True):
        pair = _split_preference_pair(data["chosen"], data["rejected"])
        if pair is None:
            skipped += 1
        else:
            message_text = _build_message(few_shot_text, question, *pair)
            text, answer = _query_judge(client, message_text, retry_delay)

            labelled = data.copy()
            labelled["answer_llm"] = answer
            labelled["message_llm"] = text
            pending.append(labelled)

        # The flush check runs for skipped examples too, so a skip landing on a
        # checkpoint index does not delay the flush.
        if i % checkpoint_every == checkpoint_every - 1 and pending:
            shard_path = f"{output_path}_{i}"
            datasets.Dataset.from_list(pending).save_to_disk(shard_path)
            shard_paths.append(shard_path)
            pending = []

    # Finalization happens after the loop so it still runs when the last
    # example was skipped.
    if pending:
        remaining_path = f"{output_path}_remaining"
        datasets.Dataset.from_list(pending).save_to_disk(remaining_path)
        shard_paths.append(remaining_path)

    merged = []
    for shard_path in shard_paths:
        merged.extend(datasets.load_from_disk(shard_path))

    if merged:
        datasets.Dataset.from_list(merged).save_to_disk(output_path)
    else:
        print(f"No examples were labelled; nothing written to {output_path}")

    print(f"Skipped {skipped} of {len(train_data)} examples that could not be aligned.")


def parse_args():
    """Parse the command-line options of the labeling script.

    Every option is a string that defaults to the empty string.

    Returns:
        argparse.Namespace: with the attributes ``train_data_path``,
        ``few_shot_path`` and ``output_path``.
    """
    # (flag, help text) pairs, registered in this order.
    option_table = (
        ("--train_data_path", "Path to the training dataset."),
        ("--few_shot_path", "Path to the few-shot text file."),
        ("--output_path", "Output path to save the processed dataset."),
    )

    cli = argparse.ArgumentParser()
    for flag, description in option_table:
        cli.add_argument(flag, type=str, default="", help=description)
    return cli.parse_args()


def main():
    """Run the labeling pipeline from the command line.

    Parses the arguments, creates the client, loads the training dataset from
    disk, reads the few-shot prompt file and passes everything to
    ``process_conversations``.
    """
    args = parse_args()

    # Initialize the model
    client = init_client()

    # Load the dataset
    train_data = datasets.load_from_disk(args.train_data_path)

    # Read few-shot examples; the context manager closes the file handle
    # instead of leaving it to the garbage collector.
    with open(args.few_shot_path, "r") as few_shot_file:
        few_shot_text = few_shot_file.read()

    # Process the conversations and save the validated dataset
    process_conversations(train_data, few_shot_text, client, args.output_path)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
