"""
Logging utilities for training monitoring and experiment tracking.
"""

import os
import json
import time
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime
import torch
import numpy as np


class Logger:
    """
    Unified logger for console output, file logging, and optional WandB integration.

    Each instance creates a run directory ``<log_dir>/<experiment_name>_<timestamp>``
    containing ``training.log`` and, depending on which methods are called,
    ``config.json``, ``metrics_history.json`` and ``summary.json``. All files are
    written as UTF-8 because log messages contain non-ASCII characters.
    """

    def __init__(
        self,
        log_dir: str,
        experiment_name: str,
        use_wandb: bool = False,
        wandb_project: Optional[str] = None,
        wandb_entity: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize logger.

        Args:
            log_dir: Directory to save log files
            experiment_name: Name of the experiment
            use_wandb: Whether to use Weights & Biases
            wandb_project: WandB project name
            wandb_entity: WandB entity (username or team)
            config: Configuration dictionary to log
        """
        self.log_dir = Path(log_dir)
        self.experiment_name = experiment_name
        self.use_wandb = use_wandb

        # Create log directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = self.log_dir / f"{experiment_name}_{timestamp}"
        self.run_dir.mkdir(parents=True, exist_ok=True)

        # Initialize log file (must be set before the first self.log call)
        self.log_file = self.run_dir / "training.log"

        # Metrics history, keyed by phase
        self.metrics_history = {
            'train': [],
            'val': [],
            'test': []
        }

        # Initialize WandB if requested; fall back to local logging when the
        # package is not installed.
        if self.use_wandb:
            try:
                import wandb
                self.wandb = wandb
                self.wandb.init(
                    project=wandb_project,
                    entity=wandb_entity,
                    name=f"{experiment_name}_{timestamp}",
                    config=config,
                    dir=str(self.run_dir)
                )
                self.log("✓ WandB initialized successfully")
            except ImportError:
                self.log("Warning: wandb not installed, disabling WandB logging")
                self.use_wandb = False

        # Save config
        if config is not None:
            self.save_config(config)

        # Timer
        self.start_time = time.time()
        self.iteration_start_time = None

        self.log(f"✓ Logger initialized: {self.run_dir}")

    @staticmethod
    def _to_serializable(value: Any) -> Any:
        """
        Recursively convert a value into something ``json.dump`` accepts.

        Torch tensors and NumPy arrays become a Python scalar (single element)
        or a nested list; NumPy scalars become the matching Python scalar;
        dicts, lists and tuples are converted element-wise; anything else that
        is not a JSON primitive is replaced by its ``str()`` representation.
        """
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else value.tolist()
        if isinstance(value, np.ndarray):
            # ndarray has no numel(); its element count is the `size` attribute.
            return value.item() if value.size == 1 else value.tolist()
        if isinstance(value, np.generic):
            native = value.item()
            # Some NumPy scalars do not map to a JSON primitive; stringify those.
            if isinstance(native, (bool, int, float, str)):
                return native
            return str(native)
        if isinstance(value, dict):
            converted = {}
            for k, v in value.items():
                # json only accepts primitive keys; stringify everything else.
                if not isinstance(k, (str, int, float, bool, type(None))):
                    k = str(k)
                converted[k] = Logger._to_serializable(v)
            return converted
        if isinstance(value, (list, tuple)):
            return [Logger._to_serializable(v) for v in value]
        if isinstance(value, (int, float, str, bool, type(None))):
            return value
        return str(value)

    def log(self, message: str, level: str = "INFO"):
        """
        Log a message to console and file.

        Args:
            message: Message to log
            level: Log level (INFO, WARNING, ERROR)
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        formatted_message = f"[{timestamp}] [{level}] {message}"

        # Print to console
        print(formatted_message)

        # Write to file. Explicit UTF-8: messages contain characters such as
        # "✓" that the platform default encoding may not be able to encode.
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(formatted_message + '\n')

    def log_metrics(
        self,
        metrics: Dict[str, float],
        iteration: int,
        phase: str = 'train',
        prefix: str = ''
    ):
        """
        Log metrics for current iteration.

        Args:
            metrics: Dictionary of metric name -> value
            iteration: Current iteration number
            phase: 'train', 'val', or 'test'; any other name gets its own
                history list created on first use
            prefix: Optional prefix for metric names (WandB keys only)
        """
        # Add iteration number
        metrics_with_iter = {'iteration': iteration, **metrics}

        # Store in history; setdefault avoids a KeyError for custom phases.
        self.metrics_history.setdefault(phase, []).append(metrics_with_iter)

        # Format for console
        metrics_str = ', '.join(
            f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
            for k, v in metrics.items()
        )
        self.log(f"[{phase.upper()}] Iter {iteration}: {metrics_str}")

        # Log to WandB
        if self.use_wandb:
            wandb_metrics = {f"{prefix}{phase}/{k}": v for k, v in metrics.items()}
            wandb_metrics['iteration'] = iteration
            self.wandb.log(wandb_metrics)

    def start_iteration(self):
        """Mark the start of an iteration for timing."""
        self.iteration_start_time = time.time()

    def end_iteration(self) -> float:
        """
        Mark the end of an iteration and return elapsed time.

        Returns:
            Elapsed time in seconds, or 0.0 if start_iteration was not called
            since the last end_iteration.
        """
        if self.iteration_start_time is None:
            return 0.0
        elapsed = time.time() - self.iteration_start_time
        self.iteration_start_time = None
        return elapsed

    def get_elapsed_time(self) -> float:
        """Get total elapsed time since logger initialization."""
        return time.time() - self.start_time

    def save_config(self, config: Dict[str, Any]):
        """
        Save configuration to ``config.json`` in the run directory.

        Values are converted recursively, so non-serializable objects nested
        inside dicts or lists are stringified instead of aborting the dump.
        """
        config_file = self.run_dir / "config.json"

        serializable_config = self._to_serializable(dict(config))

        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(serializable_config, f, indent=2)

        self.log(f"✓ Configuration saved to {config_file}")

    def save_metrics_history(self):
        """Save all metrics history to ``metrics_history.json`` in the run directory."""
        metrics_file = self.run_dir / "metrics_history.json"

        # Convert numpy/torch values to Python native types
        serializable_history = {
            phase: [self._to_serializable(entry) for entry in history]
            for phase, history in self.metrics_history.items()
        }

        with open(metrics_file, 'w', encoding='utf-8') as f:
            json.dump(serializable_history, f, indent=2)

        self.log(f"✓ Metrics history saved to {metrics_file}")

    def log_summary(self, summary: Dict[str, Any]):
        """
        Log final summary statistics.

        The passed dictionary is updated in place with ``total_time_seconds``
        and ``total_time_formatted`` before being written to ``summary.json``.

        Args:
            summary: Dictionary of summary statistics
        """
        summary_file = self.run_dir / "summary.json"

        # Add total time
        summary['total_time_seconds'] = self.get_elapsed_time()
        summary['total_time_formatted'] = self.format_time(summary['total_time_seconds'])

        # Convert before dumping so numpy/torch values do not raise TypeError.
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(self._to_serializable(summary), f, indent=2)

        self.log("=" * 80)
        self.log("TRAINING SUMMARY")
        self.log("=" * 80)
        for k, v in summary.items():
            self.log(f"{k}: {v}")
        self.log("=" * 80)

        if self.use_wandb:
            self.wandb.log({"summary/" + k: v for k, v in summary.items()})

    @staticmethod
    def format_time(seconds: float) -> str:
        """Format seconds into an ``HH:MM:SS`` string (fractions are truncated)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def close(self):
        """Close logger and finalize logs."""
        self.save_metrics_history()

        if self.use_wandb:
            self.wandb.finish()

        total_time = self.get_elapsed_time()
        self.log(f"✓ Training completed in {self.format_time(total_time)}")
        self.log(f"✓ Logs saved to {self.run_dir}")

class MetricAggregator:
    """
    Aggregate metrics over multiple batches/iterations.

    Keeps, per metric name, a running weighted sum and the total weight so
    that ``compute`` can return the weighted average.
    """

    def __init__(self):
        # metric name -> running sum of value * weight
        self.metrics = {}
        # metric name -> running sum of weights
        self.counts = {}

    def update(self, metrics: Dict[str, float], weight: float = 1.0):
        """
        Update aggregator with new metrics.

        Args:
            metrics: Dictionary of metric values
            weight: Weight for weighted average (e.g., batch size)
        """
        for key, value in metrics.items():
            if key not in self.metrics:
                self.metrics[key] = 0.0
                self.counts[key] = 0.0

            self.metrics[key] += value * weight
            self.counts[key] += weight

    def compute(self) -> Dict[str, float]:
        """
        Compute aggregated metrics.

        Returns:
            Dictionary of averaged metric values. A metric whose accumulated
            weight is zero is reported as 0.0 instead of raising
            ZeroDivisionError.
        """
        averages = {}
        for key, total in self.metrics.items():
            count = self.counts[key]
            # Every update for this key may have used weight 0; the average is
            # undefined in that case, so report 0.0 instead of dividing by zero.
            averages[key] = total / count if count != 0 else 0.0
        return averages

    def reset(self):
        """Reset all aggregated metrics."""
        self.metrics.clear()
        self.counts.clear()