import argparse
import json
import random
from llm_info_extraction import LLM_summaraize
from tqdm import tqdm
# Seed the global RNG so the random.shuffle call in main() produces a reproducible order.
random.seed(42)

def process_message(json_obj, **kwargs):
    """Decide whether a dialogue record continues or stops, and whether to keep it.

    A record whose ``remaining_chat`` text contains no ``"user: "`` marker is
    treated as finished: the decision is ``"stop"`` and the info set is
    replaced by ``LLM_summaraize(json_obj, **kwargs)``. Any other record is a
    ``"continue"`` record and keeps its own ``info_set`` value.

    Args:
        json_obj: Record providing ``remaining_chat`` and, optionally,
            ``info_set``.
        **kwargs: Passed through untouched to ``LLM_summaraize``.

    Returns:
        A ``(if_keep, info_set, decision_str)`` tuple. ``if_keep`` is False
        only for a "continue" record whose ``info_set`` is empty or missing;
        "stop" records are always kept.
    """
    collected = json_obj.get("info_set")
    has_user_turn = "user: " in json_obj["remaining_chat"]

    if not has_user_turn:
        # No further user turn: the conversation ended, so summarize it.
        return True, LLM_summaraize(json_obj, **kwargs), "stop"

    # The conversation goes on; a record without any info is not worth keeping.
    return bool(collected), collected, "continue"

def add_stop_sample(data, outfile, **kwargs):
    """Write one "stop" sample built from ``data`` to ``outfile`` as a JSON line.

    The sample copies ``session_id``, ``messages`` and (if present) ``topic``
    from ``data``, suffixes the ``cid`` with ``"_sum"``, and stores the result
    of ``LLM_summaraize(data, **kwargs)`` under ``finegrained_query``.

    Args:
        data: Source record; must provide ``cid``, ``session_id`` and
            ``messages``.
        outfile: Writable text stream that receives the serialized sample.
        **kwargs: Passed through untouched to ``LLM_summaraize``.
    """
    # Summarize first, exactly as before, so the helper runs even if a key is missing.
    summary_query = LLM_summaraize(data, **kwargs)

    record = {}
    record["cid"] = f"{data['cid']}_sum"
    record["session_id"] = data["session_id"]
    record["topic"] = data.get("topic", "")
    record["messages"] = data["messages"]
    record["decision_truth"] = "stop"
    record["finegrained_query"] = summary_query

    serialized = json.dumps(record, ensure_ascii=False)
    outfile.write(serialized + "\n")


def _write_jsonl(path, items):
    """Write every element of ``items`` to ``path`` as one JSON document per line."""
    with open(path, "w", encoding="utf-8") as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")


def main(input_file_path, output_file_path, **kwargs):
    """Split a processed JSONL dialogue file into clarify and summary files.

    Every non-blank input line is parsed as a JSON object and classified with
    ``process_message``. Kept records whose decision is not ``"stop"`` are
    collected, shuffled and written to ``output_file_path``. Records whose
    decision is ``"stop"`` are written, in input order, to the summary file.

    Args:
        input_file_path: Path of the JSONL file to read.
        output_file_path: Path of the JSONL file that receives the shuffled
            non-"stop" samples.
        **kwargs: Forwarded unchanged to ``process_message``. The optional
            ``output_summary_file`` entry selects the summary output path
            (default ``"summary_merge.jsonl"``).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    clarify_list = []
    summary_list = []
    with open(input_file_path, "r", encoding="utf-8") as infile:
        print("data processing started...")
        for line in tqdm(infile):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) are not records;
                # json.loads("") would raise, so skip them.
                continue
            data = json.loads(line)
            if_keep, info_set, decision = process_message(data, **kwargs)
            if not if_keep:
                continue
            if decision != "stop":
                clarify_list.append({
                    "cid": data["cid"],
                    "session_id": data["session_id"],
                    "topic": data.get("topic", ""),  # missing topic becomes ""
                    "messages": data["messages"],
                    "decision_truth": decision,
                    "info_truth": info_set,
                })
            else:
                # Only "stop" records need the cid split: the summary entry
                # takes its numeric id and session id from the first two
                # underscore-separated cid parts.
                cid_parts = data["cid"].split("_")
                session_id = f"{cid_parts[0]}_{cid_parts[1]}"
                summary_list.append({
                    "id": int(cid_parts[0]),
                    "session_id": session_id,
                    "messages": data["messages"],
                    "prompt": info_set,
                })

    # The input file is closed at this point; only the outputs are open below.
    random.shuffle(clarify_list)
    _write_jsonl(output_file_path, clarify_list)
    _write_jsonl(
        kwargs.get("output_summary_file", "summary_merge.jsonl"), summary_list
    )
    print("job done!")


if __name__ == "__main__":
    # Command-line entry point: collect the file paths and model path, then run main().
    cli = argparse.ArgumentParser()

    # JSONL file to read.
    cli.add_argument("--input_file", type=str, default="/data/train_processed.jsonl")

    # Final file used for training or testing.
    cli.add_argument("--output_file", type=str, default="/data/train.jsonl")
    # File that receives the records classified as "stop".
    cli.add_argument(
        "--output_summary_file", type=str, default="/data/train_summary.jsonl"
    )

    # Mandatory; forwarded to main() as the model_path keyword argument.
    cli.add_argument("--model_path", type=str, required=True)
    opts = cli.parse_args()

    main(
        opts.input_file,
        opts.output_file,
        output_summary_file=opts.output_summary_file,
        model_path=opts.model_path,
    )
