import argparse
import json
import time
import pdb

from llm_info_extraction import LLM_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines


def process_jsonl_file(
    input_file, output_file, model_call_mode="online_api", max_retries=30, **kwargs
):
    """
    Process all sessions in a JSONL file and save results to output file.

    Each non-blank input line is parsed as one JSON session and handed to
    ``process_session``; every string it returns is written to the output
    file on its own line. Processing stops at the first failure, and lines
    written before the failure remain in the output file.

    Args:
        input_file (str): Path to input JSONL file
        output_file (str): Path to output JSONL file (opened in write mode,
            so existing content is overwritten)
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters forwarded to ``process_session``

    Returns:
        str: Success message, or an error message that starts with
            "Error processing JSONL file: ". Errors raised while handling a
            specific record also name that record's 1-based line number in
            the input file.
    """
    try:
        with open(input_file, "r", encoding="utf-8") as infile, open(
            output_file, "w", encoding="utf-8"
        ) as outfile:
            for line_num, line in enumerate(infile, 1):
                # Blank lines are not records; skip them.
                if not line.strip():
                    continue

                try:
                    session = json.loads(line)
                    print(
                        f"Processing session {session.get('session_id', 'unknown')} (line {line_num})..."
                    )

                    processed_lines = process_session(
                        session, model_call_mode, max_retries, **kwargs
                    )
                    for processed_line in processed_lines:
                        outfile.write(processed_line + "\n")
                except Exception as e:
                    # json.loads reports positions relative to the single
                    # record string (always "line 1 column N"), so add the
                    # real line number of the record within the input file.
                    return f"Error processing JSONL file: line {line_num}: {e}"

        return f"Successfully processed. Results saved to {output_file}"

    except Exception as e:
        # Failures not tied to one record (e.g. opening or reading the files).
        return f"Error processing JSONL file: {e}"


def process_session(session, model_call_mode="online_api", max_retries=30, **kwargs):
    """
    Split a session into JSON lines and attach an "info_set" key to each one.

    The session is split with ``split_session_to_json_lines``. Every entry
    except the last keeps its existing "info_set" value (or gets an empty
    list when the key is absent). The last entry's "info_set" is always set
    to an empty list. No LLM call is made in this function.

    Args:
        session (dict): Session dictionary passed to ``split_session_to_json_lines``
        model_call_mode (str): Either "online_api" or "local_vllm".
            Currently unused; accepted for interface compatibility.
        max_retries (int): Maximum number of retries for LLM calls.
            Currently unused; accepted for interface compatibility.
        **kwargs: Additional parameters for API calls. Currently unused.

    Returns:
        list: JSON strings (non-ASCII characters kept as-is), one per split
            line, each with an "info_set" key.

    Exceptions raised by the splitter or by JSON decoding are not caught
    here and propagate to the caller.
    """
    # Step 1: Split messages into JSON lines
    json_lines = split_session_to_json_lines(session)
    last_index = len(json_lines) - 1

    # Step 2: Make sure every entry carries an "info_set" key
    processed_lines = []
    for i, line in enumerate(json_lines):
        data = json.loads(line)
        if i < last_index:
            # Keep whatever the splitter provided, defaulting to empty.
            data["info_set"] = data.get("info_set", [])
        else:
            # Nothing in this function accumulates extracted information,
            # so the final entry always receives an empty info_set.
            data["info_set"] = []
        processed_lines.append(json.dumps(data, ensure_ascii=False))
    return processed_lines


# Example usage:
if __name__ == "__main__":
    # Command-line entry point; every option has a default, so the script
    # can be launched without arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_file", default="/data/train_raw.jsonl", type=str)
    cli.add_argument("--output_file", default="/data/train_processed.jsonl", type=str)
    cli.add_argument(
        "--model_call_mode",
        default="local_vllm",
        choices=["online_api", "local_vllm"],
        type=str,
    )
    # NOTE: a --model_path option (forwarded as model_path=...) is currently disabled.
    options = cli.parse_args()

    # process_jsonl_file returns a status string; print it for the user.
    outcome = process_jsonl_file(
        input_file=options.input_file,
        output_file=options.output_file,
        model_call_mode=options.model_call_mode,
    )
    print(outcome)
