import json
from collections import defaultdict
import pandas as pd


class UsageStatistics:
    """
    Class to track the tool usage statistics (duration, tokens, cost).

    Entries are appended to a JSON-lines file (one JSON object per line) by
    ``log_statistic`` and aggregated per function name by
    ``calculate_total_cost``.
    """

    def __init__(self, file_name: str = "llm_cost.json"):
        """
        Initialize the usage statistics.

        Args:
            file_name (str): Name of the statistics file. Defaults to "llm_cost.json".
        """
        self.statistics_file_name = file_name
        # Empty frame with the same columns as a log entry. It is created here
        # but not filled by any method of this class.
        self.stats_df = pd.DataFrame(
            columns=['FunctionName', 'StartTime', 'EndTime', 'Model', 'PromptTokens', 'CompletionTokens', 'Cost'])

    def log_statistic(self, function_name: str, start_time: float, end_time: float, model: str = None, prompt_tokens: int = None,
                      completion_tokens: int = None, cost: float = None):
        """
        Log a statistic entry.

        The entry is appended as a single JSON line to the statistics file.
        Optional values that are not supplied are stored as JSON ``null``.

        Args:
            function_name (str): Name of the function.
            start_time (float): Start time.
            end_time (float): End time.
            model (str): Model name.
            prompt_tokens (int): Number of prompt tokens.
            completion_tokens (int): Number of completion tokens.
            cost (float): Cost.
        """
        # Use a dedicated local name so the ``cost`` parameter is not shadowed.
        entry = {
            'FunctionName': function_name,
            'StartTime': start_time,
            'EndTime': end_time,
            'Model': model,
            'PromptTokens': prompt_tokens,
            'CompletionTokens': completion_tokens,
            'Cost': cost,
        }

        with open(self.statistics_file_name, 'a', encoding='utf-8') as f:
            # Write the statistic entry to the file
            f.write(json.dumps(entry) + '\n')

    @staticmethod
    def _empty_totals() -> dict:
        """Return a fresh, zeroed accumulator for one group of entries."""
        return {
            "TotalPromptTokens": 0,
            "TotalCompletionTokens": 0,
            "TotalCost": 0.0,
            "TotalDuration": 0.0,
        }

    @staticmethod
    def _accumulate(bucket: dict, entry: dict) -> None:
        """Add one log entry's tokens, cost and duration to ``bucket`` in place."""
        # ``log_statistic`` stores omitted metrics as null; count those (and
        # absent keys) as zero instead of failing on ``int + None``.
        bucket["TotalPromptTokens"] += entry.get("PromptTokens") or 0
        bucket["TotalCompletionTokens"] += entry.get("CompletionTokens") or 0
        bucket["TotalCost"] += entry.get("Cost") or 0.0
        bucket["TotalDuration"] += entry["EndTime"] - entry["StartTime"]

    @staticmethod
    def calculate_total_cost(input_log_file: str = "llm_cost.json", output_log_file: str = "total_cost.json"):
        """
        Calculate the total cost from the usage statistics.

        Reads the JSON-lines log, sums tokens, cost and duration per function
        name plus an overall "FinalTotal" entry, and writes the result as
        indented JSON. Blank lines in the log are ignored, and metrics that are
        null or missing in an entry are counted as zero.

        Args:
            input_log_file (str): Path to the input JSON file. Defaults to "llm_cost.json".
            output_log_file (str): Path to the output JSON file. Defaults to "total_cost.json".

        Returns:
            dict: The totals that were written, keyed by function name, with the
            overall sums under "FinalTotal".

        Raises:
            OSError: If the input file cannot be read or the output file cannot be written.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If an entry lacks "FunctionName", "StartTime" or "EndTime".
        """
        with open(input_log_file, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) rather than failing to parse them.
            data = [json.loads(line) for line in f if line.strip()]

        # Per-function accumulators, created on first use
        totals = defaultdict(UsageStatistics._empty_totals)
        final_total = UsageStatistics._empty_totals()

        for entry in data:
            UsageStatistics._accumulate(totals[entry["FunctionName"]], entry)
            UsageStatistics._accumulate(final_total, entry)

        # Convert totals to a regular dictionary for JSON serialization
        result = {function_name: dict(totals_data) for function_name, totals_data in totals.items()}
        # NOTE: a logged function literally named "FinalTotal" is overwritten here.
        result["FinalTotal"] = final_total

        # Write the totals to the output JSON file
        with open(output_log_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=4)

        print(f"Totals have been written to {output_log_file}")
        return result
