#!/usr/bin/env python

"""
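This script modifies the local dataset CSV files in place, adding columns that hold the OpenAI
embeddings of each AAI answer (``open_ai_embedding``) and question (``open_ai_embedding_question``).
Only the synthetic datasets (GPT-4 and Claude 3 Opus) are currently listed for processing; add the
real dataset path to ``dataset_paths`` to process it as well.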
"""

import pandas as pd
from langchain_openai.embeddings import OpenAIEmbeddings

from attachment_style.common.config import settings
from attachment_style.common.constants import (
    CLAUDE_3_OPUS_SYNTHETIC_DATASET_PATH,
    GPT4_SYNTHETIC_DATASET_PATH,
    OPEN_AI_EMBEDDING_MODEL,
)

def main() -> None:
    """Add OpenAI embedding columns to each dataset CSV, rewriting it in place.

    For every path in ``dataset_paths`` the CSV is loaded, the ``answer`` and
    ``question`` columns are embedded, and the results are stored in the
    ``open_ai_embedding`` and ``open_ai_embedding_question`` columns
    respectively. The file is then written back to the same path without the
    index column.
    """
    dataset_paths = [
        GPT4_SYNTHETIC_DATASET_PATH,
        CLAUDE_3_OPUS_SYNTHETIC_DATASET_PATH,
    ]

    # The client configuration does not depend on the dataset being
    # processed, so build it once instead of once per file.
    embedding_model = OpenAIEmbeddings(
        model=OPEN_AI_EMBEDDING_MODEL,
        api_key=settings.open_ai_api_key,
    )

    # Maps each embedding column to the text column it is computed from.
    embedding_columns = {
        "open_ai_embedding": "answer",
        "open_ai_embedding_question": "question",
    }

    for filepath in dataset_paths:
        transcripts = pd.read_csv(filepath)

        for target_column, source_column in embedding_columns.items():
            # embed_documents is specified to take a list of strings; hand it
            # a plain list rather than a pandas Series so the call does not
            # depend on how the library iterates or slices its input.
            texts = transcripts[source_column].astype(str).tolist()
            transcripts[target_column] = embedding_model.embed_documents(texts)

        # NOTE: each cell holds a list of floats, which to_csv writes out as
        # its text representation; readers need to parse it back into a list.
        transcripts.to_csv(filepath, index=False)


if __name__ == "__main__":
    main()
