import json
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Union

import wandb
from omegaconf import DictConfig, OmegaConf


class CustomFormatter(logging.Formatter):
    """Console formatter that colours each record according to its log level.

    The ANSI colour code for the record's level is prepended to the formatted
    line and a reset code appended. Records whose level is not one of the five
    standard levels (e.g. a custom level registered with
    ``logging.addLevelName``) keep the same layout but are left uncoloured,
    instead of collapsing to the bare message.
    """

    grey = "\x1b[38;21m"
    blue = "\x1b[38;5;39m"
    yellow = "\x1b[38;5;226m"
    red = "\x1b[38;5;196m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"

    FORMATS = {
        logging.DEBUG: grey + format_str + reset,
        logging.INFO: blue + format_str + reset,
        logging.WARNING: yellow + format_str + reset,
        logging.ERROR: red + format_str + reset,
        logging.CRITICAL: bold_red + format_str + reset,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build one delegate formatter per level up front rather than
        # constructing a new Formatter for every record.
        self._level_formatters = {
            level: logging.Formatter(fmt, datefmt=self.date_format)
            for level, fmt in self.FORMATS.items()
        }
        # Same layout without colour, used for non-standard levels.
        self._fallback_formatter = logging.Formatter(
            self.format_str, datefmt=self.date_format
        )

    def format(self, record):
        """Format *record* with the colour associated with its level."""
        formatter = self._level_formatters.get(
            record.levelno, self._fallback_formatter
        )
        return formatter.format(record)


def create_exp_dir(
    exp_name: str,
    log_dir: Optional[str] = None,
    base_dir: Union[str, Path] = "outputs",
) -> Path:
    """Create (or reuse) a timestamped experiment directory.

    The directory is ``<base_dir>/<name>_<MMDD_HHMM>``. Here ``name`` is
    ``log_dir`` when it is truthy, otherwise ``exp_name``.

    Note: the timestamp has minute resolution and the directory is created with
    ``exist_ok=True``. Two calls with the same name within the same minute
    therefore return the *same* directory rather than distinct ones.

    Args:
        exp_name: Experiment name, used when ``log_dir`` is empty or ``None``.
        log_dir: Optional explicit directory name prefix.
        base_dir: Parent directory under which the experiment directory is
            created (defaults to ``"outputs"`` relative to the current working
            directory).

    Returns:
        Path to the created directory.
    """
    name = log_dir or exp_name
    timestamp = datetime.now().strftime("%m%d_%H%M")
    exp_dir = Path(base_dir) / f"{name}_{timestamp}"
    exp_dir.mkdir(parents=True, exist_ok=True)
    return exp_dir


def setup_logging(cfg: DictConfig) -> Dict[str, Any]:
    """
    Set up logging for the project, including file logging and wandb integration.

    Steps:

    1. Create a timestamped experiment directory with ``create_exp_dir``.
    2. Replace every handler on the root logger with two new handlers: a
       coloured stdout handler and a plain-text ``experiment.log`` file handler.
    3. Save the resolved config to ``config.yaml``.
    4. Optionally start a wandb run.
    5. Create the ``checkpoints`` and ``eval`` sub-directories.

    Args:
        cfg: Hydra configuration object. Reads ``exp_name``, ``log_dir`` and
            ``logging.wandb`` / ``project`` / ``entity`` / ``group`` / ``tags``.

    Returns:
        Dict with these keys:

        - ``logger``: the root logger.
        - ``wandb_run``: the wandb run, or ``None`` when wandb is disabled.
        - ``exp_dir``, ``checkpoints_dir``, ``eval_dir``: the created paths.
        - ``run_name``: the experiment directory's name.
    """
    # Create experiment directory
    exp_dir = create_exp_dir(cfg.exp_name, cfg.log_dir)

    # Get run name (without 'outputs/' prefix)
    run_name = exp_dir.name

    # Initialize root logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close existing handlers. Repeated calls then neither duplicate
    # output nor leak the file descriptor held by a previous FileHandler.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Console handler with custom formatting
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(CustomFormatter())
    logger.addHandler(console_handler)

    # File handler (explicit UTF-8 so non-ASCII messages are platform-independent)
    log_file = exp_dir / "experiment.log"
    file_handler = logging.FileHandler(str(log_file), encoding="utf-8")
    file_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    # Save configuration
    config_save_path = exp_dir / "config.yaml"
    OmegaConf.save(cfg, config_save_path, resolve=True)

    # Initialize wandb if enabled
    wandb_run = None
    if cfg.logging.wandb:
        wandb_config = OmegaConf.to_container(cfg, resolve=True)
        # Record the local experiment directory alongside the config in wandb
        wandb_config["exp_dir"] = str(exp_dir)

        wandb_run = wandb.init(
            project=cfg.logging.project,
            entity=cfg.logging.entity,
            group=cfg.logging.group,
            name=run_name,
            config=wandb_config,
            dir=str(exp_dir),
            tags=cfg.logging.tags,
        )

        # Save config to wandb
        wandb.save(str(config_save_path))

    logger.info(f"Experiment name: {cfg.exp_name}")
    logger.info(f"Experiment directory: {exp_dir}")
    logger.info(f"Logging to: {log_file}")

    # Create additional directories for experiment artifacts
    checkpoints_dir = exp_dir / "checkpoints"
    eval_dir = exp_dir / "eval"

    for directory in (checkpoints_dir, eval_dir):
        directory.mkdir(exist_ok=True)

    return {
        "logger": logger,
        "wandb_run": wandb_run,
        "exp_dir": exp_dir,
        "checkpoints_dir": checkpoints_dir,
        "eval_dir": eval_dir,
        "run_name": run_name,
    }


class MetricLogger:
    """Buffer training metrics and periodically flush them to wandb and/or the console.

    Values are buffered with :meth:`update`; the most recent value per key
    wins. :meth:`log` emits the buffer only on steps that are multiples of
    ``log_interval``, then clears it.
    """

    def __init__(
        self,
        wandb_run=None,
        log_interval: int = 1,
        prefix: str = "",
    ):
        """
        Args:
            wandb_run: Object with a ``log(metrics, step=...)`` method (e.g. a
                wandb run), or ``None`` to skip wandb logging.
            log_interval: Emit metrics only when ``step % log_interval == 0``.
                Must be >= 1.
            prefix: String prepended to every metric key when emitting.

        Raises:
            ValueError: If ``log_interval`` is less than 1.
        """
        # Validate up front; otherwise a zero interval only fails later, inside
        # ``log``, with a ZeroDivisionError.
        if log_interval < 1:
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")
        self.wandb_run = wandb_run
        self.log_interval = log_interval
        self.prefix = prefix
        self.metrics: Dict[str, Union[float, int]] = {}
        self.step = 0

    def update(self, metrics: Dict[str, Union[float, int]]) -> None:
        """Buffer new metric values, overwriting any pending value for the same key."""
        self.metrics.update(metrics)

    def log(self, step: Optional[int] = None, console=True) -> None:
        """Emit buffered metrics to wandb and/or the console, then clear the buffer.

        Args:
            step: If given, becomes the current step; otherwise the last step
                set is reused.
            console: Also write a one-line summary via ``logging.info``.

        Nothing is emitted (and the buffer is kept) when the current step is not
        a multiple of ``log_interval``. Nothing is emitted when the buffer is
        empty.
        """
        if step is not None:
            self.step = step

        if self.step % self.log_interval != 0:
            return
        # Avoid empty "Step N | " lines and empty wandb.log calls.
        if not self.metrics:
            return

        # Add prefix if specified
        log_metrics = {f"{self.prefix}{k}": v for k, v in self.metrics.items()}

        # Log to wandb if enabled
        if self.wandb_run is not None:
            self.wandb_run.log(log_metrics, step=self.step)

        # Log to console
        if console:
            metrics_str = " | ".join(
                f"{k}: {self._format_value(v)}" for k, v in log_metrics.items()
            )
            logging.info(f"Step {self.step} | {metrics_str}")
        self.reset()

    @staticmethod
    def _format_value(value) -> str:
        """Render *value* with 4 decimals, falling back to ``str`` for non-numerics."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    def reset(self) -> None:
        """Discard all buffered metrics."""
        self.metrics.clear()


def log_git_info(logger: logging.Logger, exp_dir: Path) -> None:
    """Best-effort capture of git repository state into ``exp_dir/git_info.json``.

    Writes the current commit hash, branch name and dirty flag, and logs them.
    The branch is ``None`` when HEAD is detached. If GitPython is not installed,
    no repository is found, the repository has no commits, or a git command
    fails, a warning is logged and nothing is written.

    Args:
        logger: Logger used for the info/warning messages.
        exp_dir: Existing directory in which ``git_info.json`` is written.
    """
    # Import separately: if the import fails, ``git`` is unbound and must not be
    # referenced in an ``except`` clause.
    try:
        import git
    except ImportError as exc:
        logger.warning("Unable to log git information: %s", exc)
        return

    try:
        repo = git.Repo(search_parent_directories=True)
        try:
            branch = repo.active_branch.name
        except TypeError:
            # GitPython raises TypeError for a detached HEAD (common in CI).
            branch = None
        git_info = {
            "commit_hash": repo.head.object.hexsha,
            "branch": branch,
            "dirty": repo.is_dirty(),
        }
    except (git.GitError, ValueError) as exc:
        # GitError covers InvalidGitRepositoryError, NoSuchPathError and command
        # failures; ValueError is raised by head.object on a repo without commits.
        logger.warning("Unable to log git information: %s", exc)
        return

    # Save git info to file
    with open(exp_dir / "git_info.json", "w", encoding="utf-8") as f:
        json.dump(git_info, f, indent=2)

    logger.info(f"Git info: {git_info}")
