#!/usr/bin/env python

"""
This script generates synthetic data by playing a collection of chats, saving them to csv files
and combining them into a dataset of labeled AAI answers and embeddings.
"""

import argparse
import json
import os
import time

import pandas as pd
from tqdm import tqdm

from attachment_style.common.config import settings
from synthetic_agents.app.playable_chat import PlayableChat
from synthetic_agents.common.constants import AI_CHAT_USER_AGENT_TYPE

# CLI definition for the synthetic-data generation script.
parser = argparse.ArgumentParser(
    description="Plays chats between user and a coach agents. The chat will play and messages "
    "will be persisted to the database as in the online chat. By the end of the chat "
    "messages will be persisted to individual CSV files. A single CSV file containing "
    "user answers from all the chats labeled by their synthetic_agents style is also "
    "generated in the end of this program."
)
parser.add_argument(
    "--chat_ids_json_filepath",
    type=str,
    required=True,
    help="Path to the json file containing chat IDs to play per synthetic_agents style.",
)
parser.add_argument(
    "--model",
    type=str,
    required=False,
    help="Which underlying LLM to use with the user agents. This will override the default "
    "model attributed to the chat.",
)
parser.add_argument(
    "--override",
    type=int,
    required=False,
    default=0,
    # 0 = resume (skip chats already in the output CSV), non-zero = regenerate from scratch.
    help="If we want to override an existing dataset in the same destination folder or if we "
    "want to resume and skip chats already played.",
)

args = parser.parse_args()

# Load the mapping of attachment style -> list of chat IDs to play.
with open(args.chat_ids_json_filepath, "r") as f:
    chat_dict = json.load(f)

os.makedirs(settings.datasets_dir, exist_ok=True)
final_filepath = f"{settings.datasets_dir}/synthetic_dataset_{args.model}.csv"

# Resume from an existing dataset unless the user explicitly asked to override it.
resuming = not bool(args.override) and os.path.exists(final_filepath)
dataset = [pd.read_csv(final_filepath)] if resuming else []

if dataset:
    # Remove chats already processed from the list
    already_done = set(dataset[-1]["interview_id"].values)
    for attachment_style in chat_dict:
        remaining = []
        for chat_id in chat_dict[attachment_style]:
            if f"{chat_id}_{args.model}" not in already_done:
                remaining.append(chat_id)
        chat_dict[attachment_style] = remaining

print(chat_dict)

# Play every chat, one attachment style at a time, flushing partial results after each chat.
for attachment_style in tqdm(list(chat_dict.keys()), desc="Attachment Style"):
    for chat_id in tqdm(chat_dict[attachment_style], position=1, leave=False, desc="Chat"):
        interview_id = f"{chat_id}_{args.model}"

        # Run the chat end-to-end; messages are persisted to the DB as in the online flow.
        chat = PlayableChat(chat_id=chat_id, seconds_between_messages=1, llm_name=args.model)
        chat.initialize_chat()
        chat.play(number_messages=40, progress_bar=tqdm(range(40), position=2, leave=False))

        chat_df = chat.to_data_frame()
        is_user = chat_df["agent_type"] == AI_CHAT_USER_AGENT_TYPE
        answers = chat_df.loc[is_user, "message_content"].values
        questions = chat_df.loc[~is_user, "message_content"].values
        num_questions = len(answers)

        # NOTE(review): assumes the chat alternates so that user answers and
        # interviewer questions come out with equal counts — pd.DataFrame will
        # raise otherwise; confirm against PlayableChat's message ordering.
        chat_rows = pd.DataFrame(
            {
                "attachment_style": [attachment_style] * num_questions,
                "interview_id": [interview_id] * num_questions,
                "answer": answers,
                "question_number": list(range(1, num_questions + 1)),
                "question": questions,
            }
        )
        dataset.append(chat_rows)

        # Save partial results
        combined = pd.concat(dataset).reset_index(drop=True)
        combined.to_csv(final_filepath, index=False)
        dataset = [combined]

        # Wait a second before starting the next chat not to overload the DB.
        time.sleep(1)
