import os
import logging
import pandas as pd

from data.data_loader import load_bench_data
from tools.openai import row_create, prompt_to_message
from tools.utils import chat_history, extract_questions

logger = logging.getLogger("rich")

def create_checklist(
    dataset_name: str,
    output_file: str = None,
    model: str = 'gpt-4o',
    data_dir: str = "data/",
    config_dir: str = "config/",
    task_id: str = ""
) -> str:
    """
    Generate the checklist-creation batch submission file for the given dataset.

    The prompt template is read from "<config_dir>/template/create.md" and
    filled once per dataset row using the row's columns as format fields.
    Three columns are derived before formatting:

    - "history": `chat_history` applied to "conversation_input".
    - "user_query": the "content" of the last entry of "conversation_input".
    - "reference_response": the "gpt-4" entry of "references".

    Each formatted prompt is converted with `prompt_to_message` and wrapped by
    `row_create`, using the row's "session_id" as the request's custom id.

    Args:
        dataset_name: Name forwarded to `load_bench_data`.
        output_file: Destination path of the batch file. When empty or None,
            "<data_dir>/batch/<task_id>_checklist.batch_submission.jsonl" is used.
        model: Model name forwarded to `row_create`.
        data_dir: Base directory used to build the default output path.
        config_dir: Directory containing "template/create.md".
        task_id: Prefix of the default output file name.

    Returns:
        The path the batch file was written to.
    """
    template_path = os.path.join(config_dir, "template", "create.md")
    # Read through a context manager so the handle is closed deterministically.
    # NOTE(review): assumes the template is stored as UTF-8 - confirm.
    with open(template_path, 'r', encoding='utf-8') as template_file:
        template = template_file.read()

    data = load_bench_data(dataset_name=dataset_name)
    data["history"] = data["conversation_input"].apply(chat_history)
    data["user_query"] = data["conversation_input"].apply(lambda x: x[-1]["content"])
    data["reference_response"] = data["references"].apply(lambda x: x["gpt-4"])
    # Every column of the row is available to the template as a named field.
    data['prompt'] = data.apply(lambda x: template.format(**x), axis=1)
    data['messages'] = data['prompt'].apply(prompt_to_message)
    batch_data = data.apply(
        lambda x: row_create(
            model=model,
            custom_id=x["session_id"],
            messages=x["messages"],
            temperature=0.7,
            max_tokens=1024,
            top_p=0.95
        ), axis=1)

    if not output_file:
        output_file = os.path.join(data_dir, "batch", f"{task_id}_checklist.batch_submission.jsonl")
    # Make sure the destination directory exists before writing; a bare file
    # name has no directory part, and makedirs("") would raise.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    batch_data.to_json(output_file, orient='records', lines=True)
    logger.info(f"""Batch file output to "{output_file}" """)
    return output_file


def parse_checklist(
    dataset_name: str,
    input_file: str,
    data_dir: str = "data/",
) -> None:
    """
    Turn a checklist-creation batch result file into the dataset's checklist file.

    `input_file` is read as JSON lines. For every record, the message content
    of each choice in `response["body"]["choices"]` is collected, the list of
    contents is passed to `extract_questions`, and the first element of its
    result is kept as the record's checklist. The record's "custom_id" becomes
    its "session_id".

    The resulting "session_id" / "checklist" pairs are written as JSON lines to
    "<data_dir>/<dataset_name>/checklist/checklist.jsonl"; the directory is
    created when it does not exist yet.

    Args:
        dataset_name: Sub-directory name under `data_dir` for the output.
        input_file: Path of the batch result file to parse.
        data_dir: Base directory under which the checklist is stored.
    """
    results = pd.read_json(input_file, lines=True, orient='records')

    def choice_contents(response):
        # Message text of every choice contained in one response body.
        return [choice['message']['content'] for choice in response['body']['choices']]

    contents = results['response'].apply(choice_contents)
    extracted = contents.apply(extract_questions)
    # Only the first element returned by extract_questions is kept per record.
    first_extracted = extracted.apply(lambda items: items[0])

    checklist_frame = pd.DataFrame({
        'session_id': results['custom_id'],
        'checklist': first_extracted,
    })

    output_dir = os.path.join(data_dir, dataset_name, "checklist")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(output_dir, "checklist.jsonl")
    checklist_frame.to_json(output_file, orient='records', lines=True)
    logger.info(f"""Checklist output to "{output_file}" """)
