import json
from collections import defaultdict
from typing import Dict, Optional, Tuple
import pandas as pd
import warnings
import os

class TokenCostTracker:
    """
    Accumulate token usage and estimated cost, per agent, for a single model.

    Prices are expressed per 1M tokens as a dict with ``"input"`` and
    ``"output"`` keys and an optional ``"cached_input"`` key. When
    ``"cached_input"`` is absent, cached prompt tokens are billed at the
    full ``"input"`` rate.
    """

    # pricing per 1M tokens (cost last updated: Sep 17, 2025)
    DEFAULT_PRICING: Dict[str, Dict[str, float]] = {
        # OpenAI
        "gpt-4o": {"input": 2.50, "output": 10.00, "cached_input": 1.25},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60, "cached_input": 0.075},
        "gpt-4.1": {"input": 2.00, "output": 8.00, "cached_input": 0.50},
        "gpt-4.1-mini": {"input": 0.40, "output": 1.60, "cached_input": 0.10},
        "gpt-4.1-nano": {"input": 0.10, "output": 0.40, "cached_input": 0.025},
        "gpt-5": {"input": 1.25, "output": 10.0, "cached_input": 0.125},
        "gpt-5-mini": {"input": 0.25, "output": 2.0, "cached_input": 0.025},
        "gpt-5-nano": {"input": 0.05, "output": 0.4, "cached_input": 0.005},

        # Gemini
        "gemini-2.5-pro-up-to-200k": {"input": 1.25, "output": 10.0, "cached_input": 0.31},
        "gemini-2.5-pro-above-200k": {"input": 2.50, "output": 15.0, "cached_input": 0.625},
        "gemini-2.5-flash": {"input": 0.30, "output": 2.50, "cached_input": 0.075},
        "gemini-2.5-flash-lite": {"input": 0.10, "output": 0.40, "cached_input": 0.025},
        "gemini-2.0-flash": {"input": 0.10, "output": 0.40, "cached_input": 0.025},
        "gemini-2.0-flash-lite": {"input": 0.075, "output": 0.30, "cached_input": 0.0},
    }

    # Price table used for models that are missing from DEFAULT_PRICING.
    _ZERO_PRICING: Dict[str, float] = {"input": 0.0, "output": 0.0, "cached_input": 0.0}

    def __init__(self, model_name: str, pricing: Optional[dict] = None):
        """
        Args:
            model_name: Model identifier, used to look up default pricing.
            pricing: Optional per-1M-token price table. When falsy (``None``
                or empty), the built-in table for ``model_name`` is used.
        """
        self.model_name = model_name
        if not pricing and model_name not in self.DEFAULT_PRICING:
            # Without this, an unknown or misspelled model silently reports $0.
            warnings.warn(
                f"[TokenCostTracker] No default pricing for model '{model_name}'; "
                "costs will be reported as 0. Pass `pricing=` to override.",
                stacklevel=2,
            )
        self.pricing = pricing or self.get_default_pricing(model_name)
        # agent_id -> cumulative token counts and costs for that agent
        self.usage = defaultdict(self._empty_entry)

    @staticmethod
    def _empty_entry() -> dict:
        """Return a fresh, zeroed usage record for a newly seen agent."""
        return {"input_tokens": 0, "output_tokens": 0, "cached_input_tokens": 0,
                "input_cost": 0.0, "output_cost": 0.0, "cached_input_cost": 0.0,
                "total_cost": 0.0}

    def get_default_pricing(self, model_name):
        """
        Return a copy of the built-in price table for ``model_name``.

        Unknown models get an all-zero table. A copy is returned so that
        mutating a tracker's pricing never alters the shared defaults.
        """
        return dict(self.DEFAULT_PRICING.get(model_name, self._ZERO_PRICING))

    def log(self, agent_id: str, input_tokens: int, output_tokens: int, cached_input_tokens: int = 0):
        """
        Record one call's token usage for ``agent_id`` and accumulate its cost.

        ``input_tokens`` is the total prompt size and *includes* the
        ``cached_input_tokens`` subset. Cached tokens are billed at the
        ``cached_input`` rate; the remainder is billed at the ``input`` rate.

        Raises:
            ValueError: If any count is negative, or if ``cached_input_tokens``
                exceeds ``input_tokens``. Either case would produce a negative
                cost. Nothing is recorded when this is raised.
        """
        if input_tokens < 0 or output_tokens < 0 or cached_input_tokens < 0:
            raise ValueError(
                f"token counts must be non-negative (input={input_tokens}, "
                f"output={output_tokens}, cached_input={cached_input_tokens})"
            )
        if cached_input_tokens > input_tokens:
            raise ValueError(
                f"cached_input_tokens ({cached_input_tokens}) cannot exceed "
                f"input_tokens ({input_tokens})"
            )

        prices = self.pricing
        non_cached_input = input_tokens - cached_input_tokens
        # Custom price tables may omit a cached rate. Fall back to the full
        # input rate so cost is never under-reported.
        cached_rate = prices.get("cached_input", prices["input"])

        cost_input = (non_cached_input / 1_000_000) * prices["input"]
        cost_cached = (cached_input_tokens / 1_000_000) * cached_rate
        cost_output = (output_tokens / 1_000_000) * prices["output"]

        entry = self.usage[agent_id]
        entry["input_tokens"] += input_tokens
        entry["output_tokens"] += output_tokens
        entry["cached_input_tokens"] += cached_input_tokens
        entry["input_cost"] += cost_input
        entry["output_cost"] += cost_output
        entry["cached_input_cost"] += cost_cached
        entry["total_cost"] += cost_input + cost_output + cost_cached

    def get_total_cost(self) -> float:
        """Return the summed cost across all agents."""
        return sum(entry["total_cost"] for entry in self.usage.values())

    def summary(self) -> dict:
        """Return token totals across all agents and the total cost (rounded to 6 places)."""
        return {
            "total_input_tokens": sum(e["input_tokens"] for e in self.usage.values()),
            "total_output_tokens": sum(e["output_tokens"] for e in self.usage.values()),
            "total_cached_input_tokens": sum(e["cached_input_tokens"] for e in self.usage.values()),
            "total_cost": round(self.get_total_cost(), 6)
        }

    def to_json(self, path: str):
        """Write model name, pricing, per-agent usage and summary to ``path`` as JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "model_name": self.model_name,
                "pricing_per_million": self.pricing,
                "usage": dict(self.usage),
                "summary": self.summary()
            }, f, indent=2)

    def as_dataframe(self):
        """
        Return per-agent usage summary as a pandas DataFrame.

        One row per agent, with the agent id in the ``agent_id`` column.
        """
        return pd.DataFrame.from_dict(self.usage, orient="index").reset_index().rename(columns={"index": "agent_id"})

    def warn_on_high_cost(self, threshold: float = 50.0):
        """
        Emit a warning if the total cost exceeds the given threshold.
        """
        total = self.get_total_cost()
        if total > threshold:
            warnings.warn(f"[TokenCostTracker] Total cost ${total:.2f} exceeds threshold ${threshold:.2f}")



class MultiAxisTokenTracker:
    """
    Track token usage and cost broken down by (model_name, module_name).

    One ``TokenCostTracker`` is kept per ``(model_name, module_name)`` pair and
    is created on first use. A ``module_name`` of ``None`` is reported as
    ``"unspecified"``.
    """

    # Columns produced by ``as_dataframe``: TokenCostTracker's per-agent usage
    # fields followed by the two axis columns added here. Used to give an empty
    # tracker a zero-row frame instead of failing inside ``pd.concat([])``.
    _DATAFRAME_COLUMNS = (
        "agent_id", "input_tokens", "output_tokens", "cached_input_tokens",
        "input_cost", "output_cost", "cached_input_cost", "total_cost",
        "model_name", "module_name",
    )

    def __init__(self):
        self.trackers: Dict[Tuple[str, Optional[str]], TokenCostTracker] = {}

    @staticmethod
    def _module_label(module: Optional[str]) -> str:
        """Return the display name for a module key (``None`` -> "unspecified")."""
        return module if module is not None else "unspecified"

    @staticmethod
    def _report_sort_key(key: Tuple[str, Optional[str]]) -> Tuple[str, bool, str]:
        """
        Sort key for ``(model, module)`` pairs that tolerates ``module=None``.

        Orders by model, then named modules alphabetically, then the
        unspecified (``None``) module last. Plain tuple comparison raises
        TypeError when it has to compare ``None`` with a string.
        """
        model, module = key
        return (model, module is None, module or "")

    def log(self,
            agent_id: str,
            input_tokens: int,
            output_tokens: int,
            model_name: str,
            cached_input_tokens: int = 0,
            module_name: Optional[str] = None):
        """Record usage for ``agent_id`` under the ``(model_name, module_name)`` tracker."""
        key = (model_name, module_name)
        tracker = self.trackers.get(key)
        if tracker is None:
            tracker = TokenCostTracker(model_name)
        tracker.log(agent_id, input_tokens, output_tokens, cached_input_tokens)
        # Store only after a successful log, so a rejected call never leaves
        # an empty tracker behind in the report.
        self.trackers[key] = tracker

    def summary(self) -> dict:
        """Return cost breakdowns per (model, module), per model and per module, plus the total."""
        return {
            "per_model_module": {
                f"{model} | {self._module_label(module)}": tracker.summary()
                for (model, module), tracker in self.trackers.items()
            },
            "per_model": {
                model: self.get_model_cost(model)
                for model, _ in self.trackers.keys()
            },
            "per_module": {
                self._module_label(module): self.get_module_cost(module)
                for _, module in self.trackers.keys()
            },
            "total_cost": round(self.get_total_cost(), 6)
        }

    def get_total_cost(self) -> float:
        """Return the summed cost across every (model, module) tracker."""
        return sum(tracker.get_total_cost() for tracker in self.trackers.values())

    def get_module_cost(self, module_name: Optional[str]) -> float:
        """
        Return total cost for a specific module (aggregated across all models).
        Pass `None` for unspecified module usage.
        """
        return sum(
            tracker.get_total_cost()
            for (_, module), tracker in self.trackers.items()
            if module == module_name
        )

    def get_model_cost(self, model_name: str) -> float:
        """
        Return total cost for a specific model (aggregated across all modules).
        """
        return sum(
            tracker.get_total_cost()
            for (model, _), tracker in self.trackers.items()
            if model == model_name
        )

    def as_dataframe(self) -> pd.DataFrame:
        """
        Return per-agent usage for every (model, module) pair as one DataFrame.

        Each row is one agent's usage under one tracker, tagged with
        ``model_name`` and ``module_name``. When nothing has been logged, a
        zero-row frame with the same columns is returned.
        """
        if not self.trackers:
            return pd.DataFrame(columns=list(self._DATAFRAME_COLUMNS))
        frames = []
        for (model, module), tracker in self.trackers.items():
            df = tracker.as_dataframe()
            df["model_name"] = model
            df["module_name"] = self._module_label(module)
            frames.append(df)
        return pd.concat(frames, ignore_index=True)

    def warn_on_high_cost(self, threshold: float = 50.0):
        """Emit a warning if the total cost exceeds ``threshold``."""
        total = self.get_total_cost()
        if total > threshold:
            warnings.warn(f"[MultiAxisTokenTracker] Total cost ${total:.2f} exceeds threshold ${threshold:.2f}")

    def to_json(self, path: str):
        """Write ``summary()`` to ``path`` as indented JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.summary(), f, indent=2)

    def save(self, path: str):
        """
        Save to ``path``, with the format chosen by its extension.

        ``.json`` writes the summary; ``.csv`` and ``.xlsx`` write the per-agent
        DataFrame.

        Raises:
            ValueError: For any other extension.
        """
        ext = os.path.splitext(path)[1].lower()

        if ext == ".json":
            self.to_json(path)
        elif ext == ".csv":
            df = self.as_dataframe()
            df.to_csv(path, index=False)
        elif ext == ".xlsx":
            df = self.as_dataframe()
            df.to_excel(path, index=False)
        else:
            raise ValueError(f"Unsupported file extension: {ext}")

    def report(self) -> str:
        """
        Return a human-readable summary string without emojis.

        Lines are ordered by model, then module name, with the unspecified
        module listed last for each model.
        """
        lines = []
        lines.append("Token Usage by Model | Module:\n")

        for key in sorted(self.trackers, key=self._report_sort_key):
            model, module = key
            mod = self._module_label(module)
            s = self.trackers[key].summary()
            lines.append(
                f"  - {model:<20} | {mod:<20} "
                f"{s['total_input_tokens']:>8} in, "
                f"{s['total_output_tokens']:>8} out, "
                f"{s['total_cached_input_tokens']:>8} cached, "
                f"cost: ${s['total_cost']:.4f}"
            )

        lines.append(f"\nTotal Estimated Cost: ${self.get_total_cost():.4f}")
        return "\n".join(lines)