"""
Centralized logging configuration for IFC-ViT.

Provides structured logging to both console and file with different
verbosity levels. Logs important debug values, validation metrics,
and diagnostic information.

DDP-aware: Only rank 0 emits output by default.
"""

import logging
import os
import sys
from datetime import datetime
from typing import Optional
import json
from functools import wraps

import torch
import torch.distributed as dist


# =============================================================================
# DDP State Management
# =============================================================================

class DDPState:
    """
    Global DDP state container.

    Stores distributed training state accessible from all modules. All state
    lives in class attributes and every method is a classmethod, so the class
    is used directly and is never instantiated.
    """
    # True once initialize() has run (also set on the single-process path).
    _initialized: bool = False
    # Values read from WORLD_SIZE / RANK / LOCAL_RANK by initialize().
    _world_size: int = 1
    _rank: int = 0
    _local_rank: int = 0
    # Device string used for tensors created by allreduce_scalar().
    _device: str = "cuda"
    # Backend name passed to the most recent initialize() call.
    _backend: str = "nccl"

    @classmethod
    def initialize(
        cls,
        backend: str = "nccl",
        init_method: Optional[str] = None,
    ) -> bool:
        """
        Initialize DDP from environment variables.

        Reads ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` (the variables set
        by torchrun); missing variables default to a single-process run.
        Calling this again after a successful call is a no-op.

        Args:
            backend: torch.distributed backend name.
            init_method: Optional init method URL forwarded to
                ``dist.init_process_group``.

        Returns:
            True if DDP was initialized, False if running single-process.
        """
        if cls._initialized:
            return cls._world_size > 1

        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        rank = int(os.environ.get("RANK", "0"))
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))

        cls._world_size = world_size
        cls._rank = rank
        cls._local_rank = local_rank
        cls._backend = backend

        if world_size > 1:
            # Bind this process to its GPU before creating the process group.
            # CPU-only runs (e.g. backend="gloo") have no CUDA device to
            # select, so selecting one there would fail; use the CPU instead.
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)
                cls._device = f"cuda:{local_rank}"
            else:
                cls._device = "cpu"

            # Respect a process group that was already created elsewhere.
            if not dist.is_initialized():
                dist.init_process_group(
                    backend=backend,
                    init_method=init_method,
                    world_size=world_size,
                    rank=rank,
                )

            cls._initialized = True

            if cls.is_rank0():
                print(f"DDP initialized: world_size={world_size}, backend={backend}")

            return True

        cls._device = "cuda" if torch.cuda.is_available() else "cpu"
        cls._initialized = True
        return False

    @classmethod
    def is_initialized(cls) -> bool:
        """Return True once initialize() has completed (DDP or single-process)."""
        return cls._initialized

    @classmethod
    def is_ddp(cls) -> bool:
        """Return True if running in DDP mode (world_size > 1)."""
        return dist.is_available() and dist.is_initialized() and cls._world_size > 1

    @classmethod
    def is_rank0(cls) -> bool:
        """Return True if this is the main process."""
        return cls._rank == 0

    @classmethod
    def world_size(cls) -> int:
        """Return the stored world size (1 until initialize() says otherwise)."""
        return cls._world_size

    @classmethod
    def rank(cls) -> int:
        """Return the stored global rank."""
        return cls._rank

    @classmethod
    def local_rank(cls) -> int:
        """Return the stored node-local rank."""
        return cls._local_rank

    @classmethod
    def device(cls) -> str:
        """Return the stored device string."""
        return cls._device

    @classmethod
    def barrier(cls):
        """Synchronize all processes (no-op outside DDP mode)."""
        if cls.is_ddp():
            dist.barrier()

    @classmethod
    def broadcast_object(cls, obj, src: int = 0):
        """
        Broadcast a Python object from src to all ranks.

        Args:
            obj: Object to broadcast (only meaningful on src rank)
            src: Source rank

        Returns:
            The broadcasted object on all ranks; ``obj`` itself outside DDP
            mode.
        """
        if not cls.is_ddp():
            return obj

        # broadcast_object_list fills the list in place on receiving ranks.
        obj_list = [obj]
        dist.broadcast_object_list(obj_list, src=src)
        return obj_list[0]

    @classmethod
    def allreduce_scalar(cls, value: float, op=dist.ReduceOp.SUM) -> float:
        """
        All-reduce a scalar value across all ranks.

        Args:
            value: Scalar value
            op: Reduction operation

        Returns:
            Reduced scalar; ``value`` unchanged outside DDP mode.
        """
        if not cls.is_ddp():
            return value

        tensor = torch.tensor(value, device=cls._device, dtype=torch.float64)
        dist.all_reduce(tensor, op=op)
        return tensor.item()

    @classmethod
    def allreduce_tensor(cls, tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
        """
        All-reduce a tensor in-place across all ranks.

        Args:
            tensor: Tensor to reduce (modified in-place)
            op: Reduction operation

        Returns:
            The reduced tensor (the same object that was passed in).
        """
        if not cls.is_ddp():
            return tensor

        dist.all_reduce(tensor, op=op)
        return tensor

    @classmethod
    def cleanup(cls):
        """
        Clean up DDP resources.

        Destroys the process group when in DDP mode and clears the
        initialized flag so initialize() can run again. The stored rank and
        world size are left untouched.
        """
        # is_ddp() already implies dist.is_initialized().
        if cls.is_ddp():
            dist.destroy_process_group()
        cls._initialized = False


# Convenience functions
def is_rank0() -> bool:
    """Return True when DDPState reports this process as rank 0."""
    on_main_process = DDPState.is_rank0()
    return on_main_process


def is_ddp() -> bool:
    """Return True when DDPState reports an active multi-process DDP run."""
    distributed = DDPState.is_ddp()
    return distributed


def barrier():
    """Block until every rank reaches this point (delegates to DDPState.barrier)."""
    DDPState.barrier()


def broadcast_object(obj, src: int = 0):
    """Send ``obj`` from rank ``src`` to all ranks and return the received object."""
    received = DDPState.broadcast_object(obj, src=src)
    return received


def allreduce_scalar(value: float, op=dist.ReduceOp.SUM) -> float:
    """Reduce a scalar across ranks with ``op`` and return the result."""
    reduced = DDPState.allreduce_scalar(value, op=op)
    return reduced


def allreduce_tensor(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """Reduce ``tensor`` in-place across ranks with ``op`` and return it."""
    reduced = DDPState.allreduce_tensor(tensor, op=op)
    return reduced


# =============================================================================
# Rank-aware utilities
# =============================================================================

def rank0_only(func):
    """Decorator: call ``func`` on rank 0 only; other ranks get ``None``."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Non-main ranks skip the call entirely
        if not is_rank0():
            return None
        return func(*args, **kwargs)

    return wrapper


def rank0_print(*args, **kwargs):
    """Forward the arguments to ``print`` on rank 0; do nothing elsewhere."""
    if not is_rank0():
        return
    print(*args, **kwargs)


class RankAwareTqdm:
    """
    Thin tqdm wrapper whose bar is shown on rank 0 only.

    Usage:
        for item in rank_tqdm(iterable, desc="Processing"):
            ...
    """
    def __init__(self, iterable=None, *args, disable=None, **kwargs):
        # Imported lazily so tqdm is only needed when a bar is created
        from tqdm import tqdm as tqdm_cls

        # An explicit ``disable`` from the caller wins; otherwise every rank
        # except 0 is silenced
        silenced = disable if disable is not None else not is_rank0()

        self._tqdm = tqdm_cls(iterable, *args, disable=silenced, **kwargs)

    def __iter__(self):
        """Iterate through the wrapped tqdm bar."""
        return iter(self._tqdm)

    def __enter__(self):
        """Enter the wrapped bar's context; returns what tqdm's __enter__ returns."""
        return self._tqdm.__enter__()

    def __exit__(self, *exc_info):
        """Leave the wrapped bar's context."""
        return self._tqdm.__exit__(*exc_info)

    def update(self, n=1):
        """Advance the bar by ``n``."""
        return self._tqdm.update(n)

    def set_description(self, desc):
        """Set the bar's description text."""
        return self._tqdm.set_description(desc)

    def set_postfix(self, *args, **kwargs):
        """Forward postfix arguments to the wrapped bar."""
        return self._tqdm.set_postfix(*args, **kwargs)

    def close(self):
        """Close the wrapped bar."""
        return self._tqdm.close()


def rank_tqdm(iterable=None, *args, **kwargs):
    """Build a RankAwareTqdm, forwarding every argument unchanged."""
    progress = RankAwareTqdm(iterable, *args, **kwargs)
    return progress


# =============================================================================
# Custom log levels for metrics
# =============================================================================

METRICS = 25  # Between INFO (20) and WARNING (30)
# Register the name so %(levelname)s renders as "METRICS" instead of "Level 25".
logging.addLevelName(METRICS, "METRICS")


class MetricsFormatter(logging.Formatter):
    """Formatter that appends a record's ``metrics`` payload as JSON.

    Records at the METRICS level that carry a ``metrics`` attribute are
    rendered as ``[METRICS] <message>: <json>``. Every other record is
    formatted exactly as ``logging.Formatter`` would.
    """

    def format(self, record):
        """Return the formatted text for ``record`` without modifying it."""
        if record.levelno != METRICS or not hasattr(record, 'metrics'):
            return super().format(record)

        # default=str keeps a non-JSON value from raising inside the handler.
        payload = json.dumps(record.metrics, indent=2, default=str)

        # Decorate a copy: the same record object is passed to every handler
        # (and may be formatted more than once), so rewriting record.msg in
        # place would stack the prefix and payload on each pass.
        decorated = logging.makeLogRecord(record.__dict__)
        decorated.msg = f"[METRICS] {record.getMessage()}: {payload}"
        # The message is already %-interpolated; clearing args stops a '%'
        # inside the JSON from being treated as a format specifier.
        decorated.args = None
        return super().format(decorated)


class Rank0Filter(logging.Filter):
    """
    Logging filter that passes records only on the rank-0 process.

    Used to keep every DDP rank from emitting its own copy of each message.
    """
    def filter(self, record):
        # The record's content is irrelevant; only the process rank decides
        on_main_process = is_rank0()
        return on_main_process


def setup_logger(
    name: str = "ifc_vit",
    log_dir: Optional[str] = None,
    log_file: Optional[str] = None,
    level: int = logging.DEBUG,
    console_level: int = logging.DEBUG,
    rank0_only: bool = True,
) -> logging.Logger:
    """
    Set up logger with file and console handlers.

    If the named logger already has handlers it is returned untouched, so
    repeated calls do not stack handlers (and later arguments are ignored).

    Args:
        name: Logger name
        log_dir: Directory for log files (creates timestamped file)
        log_file: Specific log file path (overrides log_dir)
        level: File logging level (the logger itself is set to DEBUG)
        console_level: Console logging level
        rank0_only: If True, records are emitted only on rank 0

    Returns:
        Configured logger
    """
    logger = logging.getLogger(name)

    # Avoid adding handlers multiple times
    if logger.handlers:
        return logger

    logger.setLevel(logging.DEBUG)

    # IMPORTANT: Disable propagation to avoid duplicate messages
    # when both parent and child loggers have handlers
    logger.propagate = False

    # A logger-level filter only sees records logged directly on this logger.
    # Records propagated up from child loggers (e.g. "ifc_vit.data") skip it
    # and go straight to the handlers, so the same filter is attached to each
    # handler below as well.
    rank_filter = Rank0Filter() if rank0_only else None
    if rank_filter is not None:
        logger.addFilter(rank_filter)

    # Include the rank in every line when running under DDP
    rank_tag = f"rank{DDPState.rank()} | " if is_ddp() else ""
    file_formatter = MetricsFormatter(
        fmt='%(asctime)s | ' + rank_tag
            + '%(name)s | %(levelname)s | %(filename)s:%(lineno)d | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_formatter = logging.Formatter(
        fmt='%(asctime)s | ' + rank_tag + '%(levelname)s | %(message)s',
        datefmt='%H:%M:%S'
    )

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(console_level)
    console_handler.setFormatter(console_formatter)
    if rank_filter is not None:
        console_handler.addFilter(rank_filter)
    logger.addHandler(console_handler)

    # File handler: derive a timestamped path when only a directory is given
    if log_file is None and log_dir is not None:
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = os.path.join(log_dir, f"ifc_vit_{timestamp}.log")

    if log_file is not None:
        # A bare filename has no directory component; fall back to the cwd
        os.makedirs(os.path.dirname(log_file) or '.', exist_ok=True)
        file_handler = logging.FileHandler(log_file, mode='a')
        file_handler.setLevel(level)
        file_handler.setFormatter(file_formatter)
        if rank_filter is not None:
            file_handler.addFilter(rank_filter)
        logger.addHandler(file_handler)
        logger.info(f"Logging to file: {log_file}")

    return logger


def get_logger(name: str = "ifc_vit") -> logging.Logger:
    """Return the named logger, configuring it only when nothing would handle its records.

    A child logger (e.g., 'ifc_vit.data') whose ancestor already has handlers
    is returned as-is and relies on propagation. If no handler is reachable
    through the propagation chain, the logger is given the default
    console-only setup.
    """
    logger = logging.getLogger(name)

    # hasHandlers() walks this logger and its ancestors, stopping at the first
    # handler found or at the first logger that does not propagate
    if not logger.hasHandlers():
        # Nothing reachable: attach handlers directly to this logger, since
        # the parent may not be set up yet
        setup_logger(name)

    return logger


def log_metrics(logger: logging.Logger, message: str, metrics: dict):
    """Log metrics at METRICS level with structured data.

    Args:
        logger: Logger to emit on
        message: Human-readable label for the metrics
        metrics: Payload attached to the record as ``record.metrics``
    """
    # Going through logger.log (instead of makeRecord + handle) applies the
    # logger's level check and fills in real source-location fields.
    # stacklevel=2 attributes the record to the caller of log_metrics.
    logger.log(METRICS, message, extra={"metrics": metrics}, stacklevel=2)


def log_tensor_stats(logger: logging.Logger, name: str, tensor, level: int = logging.DEBUG):
    """Log statistics about a tensor.

    Args:
        logger: Logger to emit on
        name: Label used as the message prefix
        tensor: A torch.Tensor, a NumPy array or array-like, or None
        level: Level to log at

    The statistics are emitted as a single JSON object. Nothing is computed
    when ``level`` is disabled for ``logger``.
    """
    # Each statistic is a reduction followed by .item(); skip all of that
    # work when the message would be discarded anyway.
    if not logger.isEnabledFor(level):
        return

    if tensor is None:
        logger.log(level, f"{name}: None")
        return

    if isinstance(tensor, torch.Tensor):
        has_values = tensor.numel() > 0
        # Cast once; mean/std/norm are computed in float32
        as_float = tensor.float() if has_values else None
        stats = {
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype),
            'device': str(tensor.device),
            'min': tensor.min().item() if has_values else None,
            'max': tensor.max().item() if has_values else None,
            'mean': as_float.mean().item() if has_values else None,
            # std is undefined for a single element
            'std': as_float.std().item() if tensor.numel() > 1 else None,
            'norm': as_float.norm().item() if has_values else None,
            'num_nan': torch.isnan(tensor).sum().item(),
            'num_inf': torch.isinf(tensor).sum().item(),
        }
    else:
        import numpy as np
        # asarray is a no-op for ndarrays and lets lists/scalars through
        array = np.asarray(tensor)
        has_values = array.size > 0
        stats = {
            'shape': list(array.shape),
            'dtype': str(array.dtype),
            'min': float(np.min(array)) if has_values else None,
            'max': float(np.max(array)) if has_values else None,
            'mean': float(np.mean(array)) if has_values else None,
            'std': float(np.std(array)) if array.size > 1 else None,
            'num_nan': int(np.isnan(array).sum()),
            'num_inf': int(np.isinf(array).sum()),
        }

    logger.log(level, f"{name} stats: {json.dumps(stats)}")


def log_dict(logger: logging.Logger, message: str, data: dict, level: str = 'info'):
    """Emit ``data`` as indented JSON after a message prefix.

    Args:
        logger: Logger instance
        message: Message prefix
        data: Dictionary to log
        level: Log level as string ('debug', 'info', 'warning', 'error');
            unrecognized names fall back to INFO, and a non-string value is
            passed to the logger unchanged
    """
    if isinstance(level, str):
        names_to_levels = {
            'debug': logging.DEBUG,
            'info': logging.INFO,
            'warning': logging.WARNING,
            'error': logging.ERROR,
        }
        resolved = names_to_levels.get(level.lower(), logging.INFO)
    else:
        resolved = level

    # default=str stringifies values json cannot encode natively
    rendered = json.dumps(data, indent=2, default=str)
    logger.log(resolved, f"{message}: {rendered}")


class StageTimer:
    """
    Timer for tracking multiple stages and logging durations.

    Can be used two ways:
    1. As a multi-stage tracker:
        timer = StageTimer()
        with timer.stage("stage1"):
            do_stuff()
        timer.log_summary(logger)

    2. As a single-stage context manager:
        with StageTimer(logger, "stage_name"):
            do_stuff()
    """

    def __init__(self, logger: logging.Logger = None, stage_name: str = None):
        """
        Args:
            logger: Optional logger for single-stage use
            stage_name: Optional stage name for single-stage use
        """
        self._logger = logger
        self._stage_name = stage_name
        # name -> {'start', 'end', 'duration'}; re-timing a name overwrites it
        self._stages = {}
        self._start_time = None
        self._current_stage = None

    def __enter__(self):
        """For single-stage context manager use.

        Starts the clock and logs a start message only when both a logger and
        a stage name were given; otherwise this is a no-op.
        """
        import time
        if self._logger and self._stage_name:
            self._start_time = time.time()
            self._logger.info(f"Starting: {self._stage_name}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """For single-stage context manager use.

        Logs the elapsed time (as an error if the body raised). Always returns
        False, so exceptions propagate to the caller.
        """
        import time
        # Compare against None so the check is about "was started", not the
        # numeric value of the timestamp
        if self._logger and self._stage_name and self._start_time is not None:
            elapsed = time.time() - self._start_time
            if exc_type is not None:
                self._logger.error(f"Failed: {self._stage_name} after {elapsed:.2f}s - {exc_val}")
            else:
                self._logger.info(f"Completed: {self._stage_name} in {elapsed:.2f}s")
        return False

    def stage(self, name: str):
        """
        Context manager for timing a named stage.

        Args:
            name: Stage name

        Returns:
            Context manager
        """
        return _StageContext(self, name)

    def _start_stage(self, name: str):
        """Internal: start timing a stage."""
        import time
        self._current_stage = name
        self._stages[name] = {'start': time.time(), 'end': None, 'duration': None}

    def _end_stage(self, name: str):
        """Internal: end timing a stage (ignored if the stage was never started)."""
        import time
        entry = self._stages.get(name)
        if entry is not None:
            entry['end'] = time.time()
            entry['duration'] = entry['end'] - entry['start']
        self._current_stage = None

    def log_summary(self, logger: logging.Logger):
        """Log a summary of all timed stages.

        Stages that were started but never finished are left out. Nothing is
        logged if no stage was ever started.
        """
        if not self._stages:
            return

        # "is not None" rather than truthiness: a stage measured as exactly
        # 0.0s (coarse clock, trivial work) is still a finished stage and
        # belongs in the report
        finished = {
            name: data['duration']
            for name, data in self._stages.items()
            if data['duration'] is not None
        }
        total = sum(finished.values())

        logger.info("=" * 40)
        logger.info("Stage Timing Summary")
        logger.info("=" * 40)
        for name, duration in finished.items():
            pct = 100 * duration / total if total > 0 else 0
            logger.info(f"  {name}: {duration:.2f}s ({pct:.1f}%)")
        logger.info(f"  TOTAL: {total:.2f}s ({total/60:.1f}min)")
        logger.info("=" * 40)


class _StageContext:
    """Context manager that times one named stage on behalf of a StageTimer."""

    def __init__(self, timer: StageTimer, name: str):
        # Owning timer and the stage label reported to it
        self.timer, self.name = timer, name

    def __enter__(self):
        """Tell the owning timer that the stage has begun."""
        owner = self.timer
        owner._start_stage(self.name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tell the owning timer that the stage is over; exceptions are not suppressed."""
        owner = self.timer
        owner._end_stage(self.name)
        return False


# Global logger instance: assigned by init_global_logger(), or lazily by
# get_global_logger() on first use; None until one of them has run.
_global_logger: Optional[logging.Logger] = None


def init_global_logger(log_dir: str, level: int = logging.DEBUG) -> logging.Logger:
    """Configure the "ifc_vit" logger via setup_logger and store it as the global logger.

    Args:
        log_dir: Directory passed to setup_logger for the log file
        level: Level passed to setup_logger

    Returns:
        The logger that was stored
    """
    global _global_logger
    configured = setup_logger("ifc_vit", log_dir=log_dir, level=level)
    _global_logger = configured
    return configured


def get_global_logger() -> logging.Logger:
    """Return the global logger, obtaining it through get_logger on first use."""
    global _global_logger
    if _global_logger is not None:
        return _global_logger

    # First access without init_global_logger(): fall back to get_logger
    _global_logger = get_logger("ifc_vit")
    return _global_logger
