"""Logging utilities for neural operator training.

This module provides logging functionality for training neural operator models,
including file logging, console output, and Weights & Biases integration.
"""

import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List

import wandb


class CustomLogger:
    """Custom logger that logs to file, console, and Weights & Biases (wandb).

    Text messages go to a per-experiment log file inside ``log_dir`` and to the
    console. When a wandb run is supplied, metrics, parameters, artifacts, and
    visualizations are additionally sent to that run.

    Parameters
    ----------
    log_dir : str, optional
        Directory to save log files, by default "logs". Created if missing.
    run : Optional[wandb.wandb_sdk.wandb_run.Run], optional
        Existing wandb run to log to, by default None (wandb logging disabled).
    experiment_name : Optional[str], optional
        Name of the experiment, used as the log file stem and as the Python
        logger name, by default None (a ``%Y%m%d_%H%M%S`` timestamp).
    level : int, optional
        Level applied to the logger and its handlers, by default
        ``logging.INFO``. Pass ``logging.DEBUG`` to make :meth:`debug` emit.

    Attributes
    ----------
    log_dir : Path
        Directory where log files are saved.
    use_wandb : bool
        Whether wandb logging is enabled (a run is attached and not yet closed).
    experiment_name : str
        Name of the current experiment.
    logger : logging.Logger
        Python logging logger instance.
    wandb_run : Optional[wandb.wandb_sdk.wandb_run.Run]
        Active wandb run if wandb is enabled; reset to None by :meth:`close`.
    """

    # Marker attribute set on handlers created by this class, so that a later
    # CustomLogger with the same experiment name can find and replace them.
    _HANDLER_MARKER = "_custom_logger_owned"

    def __init__(
            self,
            log_dir: str = "logs",
            run: Optional["wandb.wandb_sdk.wandb_run.Run"] = None,
            experiment_name: Optional[str] = None,
            level: int = logging.INFO,
    ) -> None:
        self.log_dir = Path(log_dir)
        self.experiment_name = experiment_name or datetime.now().strftime("%Y%m%d_%H%M%S")

        self.log_dir.mkdir(parents=True, exist_ok=True)
        # logging.getLogger returns the same Logger object for the same name, so
        # handlers left by an earlier instance must be removed first; otherwise
        # every message would be emitted once per instance ever created.
        self.logger = logging.getLogger(self.experiment_name)
        self.logger.setLevel(level)
        self._remove_owned_handlers()

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(
            self.log_dir / f"{self.experiment_name}.log", encoding="utf-8"
        )
        console_handler = logging.StreamHandler()
        self._handlers: List[logging.Handler] = [file_handler, console_handler]
        for handler in self._handlers:
            handler.setLevel(level)
            handler.setFormatter(formatter)
            setattr(handler, self._HANDLER_MARKER, True)
            self.logger.addHandler(handler)

        # wandb setup
        self.wandb_run = run

    @property
    def use_wandb(self) -> bool:
        """True while a wandb run is attached to this logger."""
        return self.wandb_run is not None

    def _remove_owned_handlers(self) -> None:
        """Detach and close handlers a previous CustomLogger added to ``self.logger``."""
        for handler in list(self.logger.handlers):
            if getattr(handler, self._HANDLER_MARKER, False):
                self.logger.removeHandler(handler)
                handler.close()

    @staticmethod
    def _format_metric_value(value: Any) -> str:
        """Format a metric with 4 decimals, falling back to ``str`` if not numeric."""
        try:
            return format(value, ".4f")
        except (TypeError, ValueError):
            return str(value)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics to both file and wandb.

        Parameters
        ----------
        metrics : Dict[str, float]
            Dictionary of metric names and values. Non-numeric values are
            written to the text log with ``str`` instead of raising.
        step : Optional[int], optional
            Training step or epoch, by default None.
        """
        metrics_str = ", ".join(
            f"{k}: {self._format_metric_value(v)}" for k, v in metrics.items()
        )
        self.logger.info(f"Step {step}: {metrics_str}")

        if self.wandb_run is not None:
            self.wandb_run.log(metrics, step=step)

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log parameters to both file and wandb.

        Parameters
        ----------
        params : Dict[str, Any]
            Dictionary of parameter names and values. Existing wandb config
            values are overwritten (``allow_val_change=True``).
        """
        params_str = ", ".join(f"{k}: {v}" for k, v in params.items())
        self.logger.info(f"Parameters: {params_str}")

        if self.wandb_run is not None:
            self.wandb_run.config.update(params, allow_val_change=True)

    def log_artifact(
            self,
            file_path: str,
            name: Optional[str] = None,
            aliases: Optional[List[str]] = None,
    ) -> None:
        """Log an artifact (file) to wandb. No-op when wandb is disabled.

        Parameters
        ----------
        file_path : str
            Path to the file to log.
        name : Optional[str], optional
            Name for the artifact in wandb, by default None (the file's basename).
        aliases : Optional[List[str]], optional
            Aliases to attach to this artifact version, by default None.
        """
        if self.wandb_run is None:
            return
        artifact = wandb.Artifact(name or os.path.basename(file_path), type="artifact")
        artifact.add_file(file_path)
        self.wandb_run.log_artifact(artifact, aliases=aliases)

    def log_visuals(
            self,
            visuals: List[Any],
            dtype: List[str],
            step: Optional[int] = None
    ) -> None:
        """Log visualizations to wandb. No-op when wandb is disabled.

        Each entry of ``visuals`` is paired with the entry of ``dtype`` at the
        same position (extra entries in the longer list are ignored).

        Parameters
        ----------
        visuals : List[Any]
            Visuals to log: HTML content for 'plotly_line_plot', a path to a
            GIF file for 'gif'.
        dtype : List[str]
            Type tag of each visual: 'plotly_line_plot' or 'gif'. Any other tag
            is skipped with a warning.
        step : Optional[int], optional
            Training step or epoch for logging, by default None.
        """
        if self.wandb_run is None:
            return
        for i, (data, name) in enumerate(zip(visuals, dtype)):
            if name == 'plotly_line_plot':
                self.wandb_run.log(
                    {f"Prediction_Vs_Target_{i}": wandb.Html(data)},
                    step=step
                )
            elif name == 'gif':
                self.wandb_run.log(
                    {
                        f"Prediction_GIF_{i}": wandb.Video(
                            str(data),
                            caption=f"Prediction GIF {i}",
                            format="gif"
                        )
                    },
                    step=step
                )
            else:
                self.logger.warning(
                    f"Unsupported visual type '{name}' for {data}. Skipping logging."
                )

    def log_video(self, video_path: str, step: Optional[int] = None) -> None:
        """Log a video file to wandb under the file's stem. No-op when wandb is disabled.

        Parameters
        ----------
        video_path : str
            Path to the video file to log (sent with ``format="gif"``).
        step : Optional[int], optional
            Training step or epoch, by default None.
        """
        if self.wandb_run is None:
            return
        self.wandb_run.log(
            {
                Path(video_path).stem: wandb.Video(
                    video_path,
                    caption="Wavefield Video",
                    format="gif"
                )
            },
            step=step
        )

    def close(self) -> None:
        """Finish the wandb run (if any) and release this logger's handlers.

        Safe to call more than once. After closing, messages are no longer
        written to this instance's log file or console handler.
        """
        try:
            if self.wandb_run is not None:
                self.wandb_run.finish()
                self.wandb_run = None
        finally:
            # Always release the file handle, even if finishing wandb failed.
            for handler in self._handlers:
                self.logger.removeHandler(handler)
                handler.close()
            self._handlers = []

    def info(self, message: str) -> None:
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def debug(self, message: str) -> None:
        """Log ``message`` at DEBUG level (dropped unless ``level`` <= DEBUG)."""
        self.logger.debug(message)

    def warning(self, message: str) -> None:
        """Log ``message`` at WARNING level."""
        self.logger.warning(message)

    def error(self, message: str) -> None:
        """Log ``message`` at ERROR level."""
        self.logger.error(message)


class NullLogger:
    """Null logger that implements the same interface as CustomLogger but does nothing.

    This class is useful for disabling logging in certain contexts, such as
    non-main processes in distributed training.

    All methods are no-ops and return None. Unknown non-dunder attributes also
    evaluate to None, so attribute reads such as ``wandb_run`` or ``log_dir``
    behave like a logger with nothing configured.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Accept and ignore any arguments, so it can replace a CustomLogger constructor call."""

    def log_metrics(self, *args: Any, **kwargs: Any) -> None: pass

    def log_artifact(self, *args: Any, **kwargs: Any) -> None: pass

    def log_visuals(self, *args: Any, **kwargs: Any) -> None: pass

    def info(self, *args: Any, **kwargs: Any) -> None: pass

    def log_params(self, *args: Any, **kwargs: Any) -> None: pass

    def log_video(self, *args: Any, **kwargs: Any) -> None: pass

    def close(self) -> None: pass

    def debug(self, *args: Any, **kwargs: Any) -> None: pass

    def warning(self, *args: Any, **kwargs: Any) -> None: pass

    def error(self, *args: Any, **kwargs: Any) -> None: pass

    def __getattr__(self, item: str) -> None:
        """Return None for unknown attributes to maintain interface compatibility.

        Dunder names raise AttributeError instead. Protocol probes (pickle's
        ``__getstate__``/``__setstate__``, ``__array__``, ...) look these up on
        the instance and would otherwise receive a non-callable None.
        """
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return None
