import os
import json
import argparse
import tiktoken
from tqdm import tqdm

def calculate_cost(prompts_dir: str,
                   encoding_name: str,
                   max_output_tokens: int,
                   input_rate_per_million: float,
                   output_rate_per_million: float) -> float:
    """Estimate and print the cost of the prompts stored in ``prompts_dir``.

    Every file in ``prompts_dir`` whose name ends with ``.jsonl`` is read
    line by line. Each non-blank line is parsed as a JSON object and the
    text under its ``"prompt"`` key (empty string if the key is missing)
    is tokenized to count input tokens. Every request is assumed to
    produce exactly ``max_output_tokens`` output tokens.

    Args:
        prompts_dir: Directory containing the JSONL prompt files.
        encoding_name: A model name understood by
            ``tiktoken.encoding_for_model``; if it is not a known model,
            it is tried as a tiktoken encoding name instead.
        max_output_tokens: Output tokens charged per request.
        input_rate_per_million: Price per 1M input tokens.
        output_rate_per_million: Price per 1M output tokens.

    Returns:
        The estimated total cost (input cost plus output cost).

    Raises:
        KeyError: If ``encoding_name`` is neither a known model name nor a
            known encoding name.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    try:
        encoder = tiktoken.encoding_for_model(encoding_name)
    except KeyError as unknown_model:
        # Not a model name; accept a raw encoding name as well. If that
        # fails too, surface the original KeyError so callers see the same
        # exception type as before.
        try:
            encoder = tiktoken.get_encoding(encoding_name)
        except ValueError:
            raise unknown_model from None

    total_input_tokens = 0
    total_output_tokens = 0
    total_requests = 0

    # Filter up front so both the progress bar and the summary line count
    # only the files that are actually processed.
    jsonl_files = [
        fname for fname in os.listdir(prompts_dir) if fname.endswith(".jsonl")
    ]

    for fname in tqdm(jsonl_files):
        path = os.path.join(prompts_dir, fname)
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Blank lines (e.g. a trailing newline) are not requests.
                    continue
                item = json.loads(line)
                prompt = item.get("prompt", "")
                # disallowed_special=() makes text that looks like a special
                # token be encoded as ordinary text instead of raising
                # ValueError.
                input_tokens = len(encoder.encode(prompt, disallowed_special=()))
                total_input_tokens += input_tokens
                total_output_tokens += max_output_tokens
                total_requests += 1

    cost_input = (total_input_tokens / 1_000_000) * input_rate_per_million
    cost_output = (total_output_tokens / 1_000_000) * output_rate_per_million
    total_cost = cost_input + cost_output

    print(f"Processed {total_requests} requests across {len(jsonl_files)} files.")
    print(f"Total input tokens: {total_input_tokens:,}")
    print(f"Total output tokens (assumed {max_output_tokens} each): {total_output_tokens:,}")
    print(f"Cost @ ${input_rate_per_million:.2f}/M input: ${cost_input:.2f}")
    print(f"Cost @ ${output_rate_per_million:.2f}/M output: ${cost_output:.2f}")
    print(f"Estimated total cost: ${total_cost:.2f}")

    return total_cost

if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, parse
    # them, and hand the values straight to calculate_cost.
    cli = argparse.ArgumentParser(
        description="Estimate GPT-4o-mini batch request cost.",
    )

    # (flag, keyword arguments for add_argument), in display order.
    option_table = (
        ("--prompts_dir", {"required": True,
                           "help": "Directory containing JSONL prompt files."}),
        ("--encoding", {"default": "gpt-4o-mini",
                        "help": "Model name."}),
        ("--max_output_tokens", {"type": int, "default": 160,
                                 "help": "Max tokens per response."}),
        ("--input_rate", {"type": float, "default": 0.30,
                          "help": "Cost per 1M input tokens (USD)."}),
        ("--output_rate", {"type": float, "default": 1.20,
                           "help": "Cost per 1M output tokens (USD)."}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()

    calculate_cost(
        prompts_dir=opts.prompts_dir,
        encoding_name=opts.encoding,
        max_output_tokens=opts.max_output_tokens,
        input_rate_per_million=opts.input_rate,
        output_rate_per_million=opts.output_rate,
    )
