"""
Global Token Statistics Manager
Used to track and manage token usage across multiple LLM instances
"""

from typing import Dict, List, Optional, Any
import json
from datetime import datetime
from pathlib import Path


class GlobalTokenStats:
    """Global Token Statistics Manager.

    Accumulates input/output token counts per LLM instance (keyed by an
    instance id string) and provides aggregate totals, a printed report,
    JSON save/merge, and a simple per-model cost estimate.
    """

    def __init__(self):
        # {llm_instance_id: {"input_tokens": int, "output_tokens": int,
        #                    "model": str, "config_name": str}}
        # Inner values mix ints and strings, hence Dict[str, Any].
        self.llm_stats: Dict[str, Dict[str, Any]] = {}
        # Session statistics history (only cleared by reset_all_stats here).
        self.session_stats: List[Dict[str, Any]] = []
        # Naive local timestamp; basis for runtime and tokens/sec metrics.
        self.start_time = datetime.now()

    @staticmethod
    def _new_entry(llm_instance_id: str, model: str = "unknown") -> Dict[str, Any]:
        """Build a fresh per-instance stats entry with zeroed counters."""
        return {
            "input_tokens": 0,
            "output_tokens": 0,
            "model": model,
            "config_name": llm_instance_id,
        }

    def register_llm(self, llm_instance_id: str, llm_instance: Any) -> None:
        """Register an LLM instance with zeroed token counters.

        Args:
            llm_instance_id: Key under which this instance's stats are kept.
            llm_instance: Object whose ``model`` attribute is recorded;
                ``'unknown'`` is stored when it has no such attribute.

        Note:
            Registering an id that already exists replaces its entry, which
            resets its counters to zero.
        """
        self.llm_stats[llm_instance_id] = self._new_entry(
            llm_instance_id, getattr(llm_instance, 'model', 'unknown')
        )

    def update_stats(self, llm_instance_id: str, input_tokens: int, output_tokens: int) -> None:
        """Add token counts to one LLM instance.

        Unregistered ids are created on the fly with the same schema that
        ``register_llm`` uses (model ``'unknown'``), so every entry in
        ``llm_stats`` exposes the same keys.

        Args:
            llm_instance_id: Instance to credit.
            input_tokens: Input (prompt) tokens to add.
            output_tokens: Output (completion) tokens to add.
        """
        if llm_instance_id not in self.llm_stats:
            self.llm_stats[llm_instance_id] = self._new_entry(llm_instance_id)
        entry = self.llm_stats[llm_instance_id]
        entry["input_tokens"] += input_tokens
        entry["output_tokens"] += output_tokens

    def get_total_input_tokens(self) -> int:
        """Return the sum of input tokens across all LLM instances."""
        return sum(stats["input_tokens"] for stats in self.llm_stats.values())

    def get_total_output_tokens(self) -> int:
        """Return the sum of output tokens across all LLM instances."""
        return sum(stats["output_tokens"] for stats in self.llm_stats.values())

    def get_total_tokens(self) -> int:
        """Return input plus output tokens across all LLM instances."""
        return self.get_total_input_tokens() + self.get_total_output_tokens()

    def get_stats_by_llm(self) -> Dict[str, Dict[str, Any]]:
        """Return a snapshot of the statistics grouped by LLM instance.

        Both the outer and the inner dicts are copied. A plain ``dict.copy()``
        would share the inner dicts, so a caller mutating the result would
        silently alter the live counters.
        """
        return {llm_id: dict(stats) for llm_id, stats in self.llm_stats.items()}

    def get_comprehensive_stats(self) -> Dict[str, Any]:
        """Return session info, global totals, per-LLM breakdown and metrics.

        The ``llm_breakdown`` value is a copy (see ``get_stats_by_llm``), so the
        returned structure can be modified or serialized without touching
        internal state.
        """
        total_input = self.get_total_input_tokens()
        total_output = self.get_total_output_tokens()
        total_tokens = total_input + total_output

        runtime = datetime.now() - self.start_time

        return {
            "session_info": {
                "start_time": self.start_time.isoformat(),
                "runtime_seconds": runtime.total_seconds(),
                "runtime_formatted": str(runtime).split('.')[0]  # Remove microseconds
            },
            "global_totals": {
                "total_input_tokens": total_input,
                "total_output_tokens": total_output,
                "total_tokens": total_tokens
            },
            "llm_breakdown": self.get_stats_by_llm(),
            "efficiency_metrics": {
                # Denominators are clamped to >= 1 to avoid division by zero.
                "tokens_per_second": total_tokens / max(runtime.total_seconds(), 1),
                "input_output_ratio": total_input / max(total_output, 1),
                "active_llm_instances": len(self.llm_stats)
            }
        }

    def print_comprehensive_report(self) -> None:
        """Print a human-readable report of ``get_comprehensive_stats()`` to stdout."""
        stats = self.get_comprehensive_stats()

        print("\n" + "="*60)
        print("🌍 Global Token Usage Statistics Report")
        print("="*60)

        # Session information
        session = stats["session_info"]
        print(f"📅 Session Start Time: {session['start_time']}")
        print(f"⏱️  Runtime: {session['runtime_formatted']}")

        # Global totals
        totals = stats["global_totals"]
        print("\n📊 Global Totals:")
        print(f"   Input Tokens: {totals['total_input_tokens']:,}")
        print(f"   Output Tokens: {totals['total_output_tokens']:,}")
        print(f"   Total Tokens: {totals['total_tokens']:,}")

        # Breakdown by LLM instance
        print("\n🤖 Breakdown by LLM Instance:")
        for llm_id, llm_stats in stats["llm_breakdown"].items():
            input_tokens = llm_stats["input_tokens"]
            output_tokens = llm_stats["output_tokens"]
            total = input_tokens + output_tokens
            model = llm_stats.get("model", "unknown")

            print(f"   {llm_id} ({model}):")
            print(f"     Input: {input_tokens:,} | Output: {output_tokens:,} | Total: {total:,}")

        # Efficiency metrics
        metrics = stats["efficiency_metrics"]
        print("\n⚡ Efficiency Metrics:")
        print(f"   Token Processing Speed: {metrics['tokens_per_second']:.2f} tokens/sec")
        print(f"   Input/Output Ratio: {metrics['input_output_ratio']:.2f}:1")
        print(f"   Active LLM Instances: {metrics['active_llm_instances']}")

        print("="*60)

    def save_stats_to_file(self, filepath: str) -> None:
        """Write ``get_comprehensive_stats()`` as indented UTF-8 JSON to *filepath*.

        Missing parent directories are created first.
        """
        stats = self.get_comprehensive_stats()

        # Ensure directory exists
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        print(f"📁 Statistics saved to: {filepath}")

    def load_and_merge_stats(self, filepath: str) -> None:
        """Merge the ``llm_breakdown`` section of a saved JSON file into this instance.

        Token counts of ids that already exist are added to the live counters.
        Unknown ids are added as copies of the loaded entry, with missing
        ``input_tokens``/``output_tokens`` defaulted to 0. Without that default,
        the totals, the report and the cost estimate would later fail with
        ``KeyError``. A missing file or invalid JSON is reported on stdout and
        leaves the stats untouched.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                loaded_stats = json.load(f)
        except FileNotFoundError:
            print(f"⚠️  File {filepath} does not exist")
            return
        except json.JSONDecodeError as e:
            print(f"❌ JSON parsing error: {e}")
            return

        # Merge LLM statistics
        for llm_id, stats in loaded_stats.get("llm_breakdown", {}).items():
            if llm_id in self.llm_stats:
                self.llm_stats[llm_id]["input_tokens"] += stats.get("input_tokens", 0)
                self.llm_stats[llm_id]["output_tokens"] += stats.get("output_tokens", 0)
            else:
                # Copy so the stored entry is owned by this instance.
                entry = dict(stats)
                entry.setdefault("input_tokens", 0)
                entry.setdefault("output_tokens", 0)
                self.llm_stats[llm_id] = entry

        print(f"📁 Statistics loaded and merged from {filepath}")

    def reset_all_stats(self) -> None:
        """Print the current report, then clear all stats and restart the session clock."""
        print("🔄 Resetting global statistics...")
        self.print_comprehensive_report()  # Show stats before reset

        self.llm_stats.clear()
        self.session_stats.clear()
        self.start_time = datetime.now()

        print("✅ Global statistics have been reset")

    def get_cost_estimate_all(self, cost_config: Dict[str, Dict[str, float]]) -> Dict[str, Any]:
        """
        Estimate costs for all LLM instances.

        Args:
            cost_config: {model_name: {"input_price_per_1k": float, "output_price_per_1k": float}}

        Returns:
            Dict with ``total_estimated_cost``, a per-instance ``llm_breakdown``
            and the ``cost_config_used``. Instances whose model is not in
            *cost_config* are omitted from the breakdown and contribute nothing
            to the total.

        Raises:
            KeyError: if a matched model's price dict lacks either price key.
        """
        total_cost = 0.0
        llm_costs = {}

        for llm_id, stats in self.llm_stats.items():
            model = stats.get("model", "unknown")
            input_tokens = stats["input_tokens"]
            output_tokens = stats["output_tokens"]

            if model in cost_config:
                prices = cost_config[model]
                input_cost = (input_tokens / 1000.0) * prices["input_price_per_1k"]
                output_cost = (output_tokens / 1000.0) * prices["output_price_per_1k"]
                llm_total_cost = input_cost + output_cost

                llm_costs[llm_id] = {
                    "model": model,
                    "input_cost": input_cost,
                    "output_cost": output_cost,
                    "total_cost": llm_total_cost,
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens
                }

                total_cost += llm_total_cost

        return {
            "total_estimated_cost": total_cost,
            "llm_breakdown": llm_costs,
            "cost_config_used": cost_config
        }


# Process-wide shared statistics object, created once at import time.
global_token_stats: GlobalTokenStats = GlobalTokenStats()


def get_global_token_stats() -> GlobalTokenStats:
    """Return the module-level ``global_token_stats`` object.

    Every call hands back the same instance that was created when this
    module was imported.
    """
    shared_stats = global_token_stats
    return shared_stats