#!/usr/bin/env python
"""
estimate_tokens.py

This script iterates over all JSON files in a given directory, replicates the prompt construction
used by the clinical reasoning QA-generation code for each file (one prompt per action-space
category), and counts tokens with tiktoken (falling back to simple whitespace splitting if tiktoken
cannot be used) to estimate the total number of input tokens.
"""

import argparse
import datetime
import functools
import glob
import json
import os
import random
import sys

# Token counting uses tiktoken when it can be imported and the requested
# encoding can be loaded; otherwise it falls back to whitespace splitting.
@functools.lru_cache(maxsize=None)
def _get_encoding(encoding_name):
    """Return the tiktoken encoding named *encoding_name*, or None if unavailable.

    Cached so the encoding is loaded once per name instead of once per prompt.
    A failure is cached as well, so an unusable tiktoken setup is not retried
    on every call.
    """
    try:
        import tiktoken
        return tiktoken.get_encoding(encoding_name)
    except Exception:
        # Deliberate best effort: any import/load failure (e.g. the package is
        # not installed or the encoding cannot be loaded) selects the fallback.
        return None


def count_tokens(text, encoding_name="gpt2"):
    """Return the number of tokens in *text*.

    Args:
        text: The string to measure.
        encoding_name: Name of the tiktoken encoding to use (default "gpt2").

    Returns:
        The tiktoken token count, or the number of whitespace-separated words
        when the encoding is unavailable.
    """
    encoding = _get_encoding(encoding_name)
    if encoding is None:
        return len(text.split())
    # disallowed_special=() counts text such as "<|endoftext|>" as ordinary
    # text; by default tiktoken raises ValueError on it, which previously
    # dropped the whole prompt to the word-count fallback.
    return len(encoding.encode(text, disallowed_special=()))

# Minimal stand-in for the pipeline's DataProcessor: only the pieces needed to
# rebuild prompts (loading, timeline filtering/splitting, misc selection).
class DataProcessor:
    """Load one patient JSON file and expose the views used for prompt building.

    The methods read ``self.data`` as a dict with an optional ``"timeline"``
    list of event dicts (with ``"time"`` and ``"source"`` keys) and an optional
    ``"misc"`` dict. NOTE(review): schema inferred from the accessors below —
    confirm against the data files.
    """

    # Keys copied from data["misc"] into the prompt, in this order.
    _RELEVANT_MISC_KEYS = ("Patient", "ED_triage")

    def __init__(self, file_path):
        self.file_path = file_path
        self.data = self.load_data()

    def load_data(self):
        """Parse and return the JSON content of ``self.file_path``."""
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(self.file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def get_timeline(self):
        """Return timeline events, excluding those whose "source" contains "ICD"."""
        timeline = self.data.get("timeline") or []
        # `or ""` also covers an explicit "source": null, which would otherwise
        # raise TypeError in the membership test.
        return [ev for ev in timeline if "ICD" not in (ev.get("source") or "")]

    def event_time_to_datetime(self, t):
        """Parse an ISO-8601 timestamp; return None if *t* is missing or invalid."""
        try:
            return datetime.datetime.fromisoformat(t)
        except (TypeError, ValueError):
            # TypeError: t is not a string (e.g. None); ValueError: bad format.
            return None

    def _event_sort_key(self, ev):
        """Return a naive datetime to order *ev* chronologically."""
        dt = self.event_time_to_datetime(ev.get("time"))
        if dt is None:
            # Missing/unparseable times sort before every real timestamp.
            return datetime.datetime.min
        if dt.tzinfo is not None:
            # Aware and naive datetimes cannot be compared (TypeError), so
            # aware times are converted to UTC and made naive.
            # NOTE(review): assumes naive timestamps are UTC as well — confirm.
            dt = dt.astimezone(datetime.timezone.utc).replace(tzinfo=None)
        return dt

    def split_timeline_by_count(self):
        """Split the chronologically sorted timeline into two halves by event count.

        Returns:
            ``(past, future)``: ``past`` holds the first ``n // 2`` events and
            ``future`` the remainder (so ``future`` gets the extra event when
            ``n`` is odd). Returns ``(None, None)`` when no events remain after
            filtering.
        """
        timeline = self.get_timeline()
        if not timeline:
            return None, None
        ordered = sorted(timeline, key=self._event_sort_key)
        mid_index = len(ordered) // 2
        return ordered[:mid_index], ordered[mid_index:]

    def get_relevant_misc(self):
        """Return the subset of data["misc"] that is included in prompts."""
        misc = self.data.get("misc") or {}
        return {key: misc[key] for key in self._RELEVANT_MISC_KEYS if key in misc}

# Action-space categories mapped to their descriptions. main() builds one
# prompt per entry for every input file, passing the key as the category and
# the value as its description. These strings are intended to match the
# generation pipeline's copy; keep them identical so the token estimate stays
# accurate.
CATEGORY_DESCRIPTIONS = {
    "Diagnosis Assistance": (
        "This category is dedicated to clinical reasoning around identifying the correct diagnosis. It "
        "encompasses the following subcategories: 1) Condition Identification, where the model flags likely "
        "conditions (e.g., sepsis, acute kidney injury) by recognizing abnormal vitals, lab trends, and clinical "
        "notes; 2) Differential Diagnosis Formulation, which requires generating a ranked list of possible conditions "
        "to capture the true illness; and 3) Rare Disease Detection, where unusual patterns trigger alerts for "
        "less common conditions. Example: 'What is the most likely diagnosis for this patient?'"
    ),
    "Treatment Recommendations": (
        "This category focuses on generating questions that guide appropriate therapeutic interventions based on "
        "the patient’s clinical history. It includes subcategories such as Medication Suggestions, Dosage and "
        "Parameter Adjustments, and Treatment Escalation Decisions. Example: 'Based on the patient’s timeline, what "
        "medication(s) would you recommend?'"
    ),
    "Procedural Decision Making": (
        "This category is aimed at clinical decision-making related to diagnostic and interventional procedures. It "
        "includes subcategories like Lab Test and Imaging Orders, Recommendations for Surgical or Invasive "
        "Interventions, and ICU Admission or Transfer Decisions. Example: 'Which diagnostic procedure(s) should be "
        "ordered for further evaluation?'"
    )
}

def build_prompt(past, future, misc, action_space_category, description):
    """
    Assemble the full question-generation prompt for one action-space category.

    Intended to reproduce the prompt built by the pipeline's generate_qa_pair
    function, so that token estimates reflect what is actually sent. The
    result is a context section (past events, misc fields, the category and
    its description, then the future events) followed by fixed instructions
    describing the expected <think>/<answer> JSON output.

    Args:
        past: JSON-serializable data placed under "PAST DATA".
        future: JSON-serializable data placed under "FUTURE DATA".
        misc: JSON-serializable extra context printed right after ``past``.
        action_space_category: Category name inserted verbatim.
        description: Category description inserted verbatim.

    Returns:
        The complete prompt as a single string.
    """
    # Serialize the data sections up front, each with 2-space indentation.
    past_json = json.dumps(past, indent=2)
    misc_json = json.dumps(misc, indent=2)
    future_json = json.dumps(future, indent=2)

    context_section = f"""
You are a clinical reasoning simulator LLM. Your goal is to generate a question-answer pair for training a smaller LLM in the context of reasoning about clinical decision support.
--------------------
PAST DATA:
{past_json}
{misc_json}
--------------------

Action Space Category: {action_space_category}
Action Space Description: {description}

Using the provided data and the context of the action space, generate a question based on the past data of the patient that a clinical agent might need to answer for clinical decision support along with the corresponding answer (as found in the future data of the patient).
The goal is to give the smaller LLM only the past data and the question and make it reason about possible answers, which can then be verified using the future data.
Use the future data of the patient to generate your question:

--------------------
FUTURE DATA:
{future_json}
--------------------
"""
    # Deliberately NOT an f-string: the braces below are literal JSON example text.
    output_format_section = """
Provide your final answer output in JSON format with the following fields:
{
    "question": <string>,
    "final_answer": <string>,
    "action_space_category": <string>,
    "action_space_subcategory": <string or null>,
    "source": <dict>
}

Thus, your response should look like this:

<think>Reasoning process for question generation</think><answer>{"question": "What is the most likely diagnosis for this patient?", "final_answer": "The patient likely has sepsis based on the clinical presentation and lab results.", "action_space_category": "Diagnosis Assistance", "action_space_subcategory": "Condition Identification", "source": {"event": "EVENT NAME", "time": "time", "source": "EVENT SOURCE"}}</answer>
"""
    return "".join((context_section, output_format_section))

def main(argv=None):
    """Estimate the total prompt tokens for every JSON file in ``--folder``.

    For each file, one prompt is built per entry in CATEGORY_DESCRIPTIONS and
    the token counts of those prompts are summed. Files with no timeline
    events left after filtering are skipped. Files that raise while being
    processed are reported on stderr and skipped; they contribute nothing to
    the total.

    Args:
        argv: Optional argument list passed to argparse (``None`` means
            ``sys.argv[1:]``), which allows calling ``main`` programmatically.
    """
    parser = argparse.ArgumentParser(
        description="Estimate total token count for prompts built from JSON files in a directory."
    )
    parser.add_argument("--folder", type=str, required=True, help="Path to folder containing JSON files.")
    args = parser.parse_args(argv)

    # Sorted so processing and error-report order is deterministic; raw glob
    # order depends on the filesystem.
    json_files = sorted(glob.glob(os.path.join(args.folder, "*.json")))
    if not json_files:
        print("No JSON files found in folder:", args.folder)
        return

    prompts_per_file = len(CATEGORY_DESCRIPTIONS)
    total_tokens = 0
    processed_files = 0
    skipped_files = 0

    for file_path in json_files:
        try:
            dp = DataProcessor(file_path)
            past, future = dp.split_timeline_by_count()
            if past is None or future is None:
                skipped_files += 1
                continue
            misc = dp.get_relevant_misc()
            # Sum per file first so a failure part-way through a file cannot
            # leave a partial count in the grand total.
            file_tokens = sum(
                count_tokens(build_prompt(past, future, misc, action, description))
                for action, description in CATEGORY_DESCRIPTIONS.items()
            )
        except Exception as e:
            # Per-file boundary: report the failure and continue with the rest.
            print(f"Error processing file {file_path}: {e}", file=sys.stderr)
            continue
        total_tokens += file_tokens
        processed_files += 1

    print(f"Processed {processed_files} files.")
    if skipped_files:
        print(f"Skipped {skipped_files} files with no usable timeline events.")
    print(f"Total tokens for all prompts ({prompts_per_file} per file): {total_tokens}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
