"""
WandB Logger for VERL Training

Provides integration with Weights & Biases for logging training metrics,
per-dataset evaluation results, and checkpoints.
"""

import os
from typing import Any, Dict, List, Optional, Union

import wandb


class WandBLogger:
    """Handles WandB logging for VERL training.

    Every ``log_*`` method is a no-op when logging is disabled (explicitly,
    because no API key was found, or because ``finish()`` was already called).
    All data is sent through the run object returned by ``wandb.init`` for
    this logger (``self.run``) rather than through the module-level
    ``wandb.log``, so the logger stays bound to its own run.
    """

    # Modes that never contact the WandB server and therefore need no API key.
    _KEYLESS_MODES = ("offline",)

    def __init__(
        self,
        project_name: str,
        experiment_name: str,
        config: Optional[Dict[str, Any]] = None,
        mode: str = "online",
        tags: Optional[List[str]] = None,
        notes: Optional[str] = None
    ):
        """
        Initialize WandB logger

        Args:
            project_name: WandB project name
            experiment_name: Run name
            config: Training configuration to log
            mode: WandB mode ("online", "offline", "disabled")
            tags: Optional tags for the run
            notes: Optional notes for the run
        """
        self.project_name = project_name
        self.experiment_name = experiment_name
        self.mode = mode
        self.enabled = mode != "disabled"
        self.run = None

        # An API key is only required when the run has to reach the server;
        # offline runs are written locally and must not be disabled here.
        needs_api_key = self.enabled and mode not in self._KEYLESS_MODES
        if needs_api_key and not os.environ.get("WANDB_API_KEY"):
            print("Warning: WANDB_API_KEY not set, disabling WandB logging")
            self.enabled = False
            self.mode = "disabled"

        if self.enabled:
            # Initialize WandB run
            self.run = wandb.init(
                project=project_name,
                name=experiment_name,
                config=config or {},
                mode=self.mode,
                tags=tags,
                notes=notes
            )
            print(f"WandB run initialized: {self.run.url}")
        else:
            print("WandB logging disabled")

        # Track per-dataset best accuracies (keyed by cleaned dataset name)
        self.best_accuracies: Dict[str, float] = {}

    @staticmethod
    def _clean_name(dataset_name: str) -> str:
        """Return ``dataset_name`` with '/' and '-' replaced by '_' for metric keys."""
        return dataset_name.replace('/', '_').replace('-', '_')

    @staticmethod
    def _reward_summary(rewards, prefix: str) -> Dict[str, float]:
        """Compute summary statistics for ``rewards`` under ``prefix``.

        Accepts any sequence NumPy can convert to a float array (lists,
        tuples, NumPy arrays). Returns an empty dict for ``None`` or empty
        input so callers can skip logging.
        """
        import numpy as np

        if rewards is None:
            return {}

        values = np.asarray(rewards, dtype=float)
        if values.size == 0:
            return {}

        # Plain Python floats keep the payload independent of NumPy scalar types.
        return {
            f"{prefix}/reward_mean": float(np.mean(values)),
            f"{prefix}/reward_std": float(np.std(values)),
            f"{prefix}/reward_min": float(np.min(values)),
            f"{prefix}/reward_max": float(np.max(values)),
            f"{prefix}/reward_median": float(np.median(values)),
            f"{prefix}/reward_p25": float(np.percentile(values, 25)),
            f"{prefix}/reward_p75": float(np.percentile(values, 75)),
        }

    @staticmethod
    def _per_dataset_summary(
        dataset_rewards: Dict[str, List[float]],
        prefix: str,
        success_threshold: float = 0.9
    ) -> Dict[str, float]:
        """Build one flat metric dict covering every non-empty dataset.

        For each dataset this yields the reward mean, the reward std and the
        fraction of rewards strictly greater than ``success_threshold``.
        Datasets with no rewards are skipped.
        """
        import numpy as np

        metrics: Dict[str, float] = {}
        for dataset_name, rewards in dataset_rewards.items():
            if rewards is None:
                continue

            values = np.asarray(rewards, dtype=float)
            if values.size == 0:
                continue

            clean_name = WandBLogger._clean_name(dataset_name)
            metrics[f"{prefix}/reward_{clean_name}_mean"] = float(np.mean(values))
            metrics[f"{prefix}/reward_{clean_name}_std"] = float(np.std(values))
            metrics[f"{prefix}/success_rate_{clean_name}"] = float(
                np.mean(values > success_threshold)
            )
        return metrics

    def _log(self, metrics: Dict[str, Any], step: int, commit: Optional[bool] = None):
        """Send ``metrics`` to this logger's run; no-op when logging is off.

        ``commit`` is forwarded unchanged; ``None`` leaves the decision to
        wandb's own default for calls that specify a step.
        """
        if not self.enabled or self.run is None:
            return

        self.run.log(metrics, step=step, commit=commit)

    def log_training_step(
        self,
        step: int,
        metrics: Dict[str, Any],
        commit: bool = True
    ):
        """
        Log training step metrics

        Args:
            step: Training step number
            metrics: Dictionary of metrics to log
                    e.g., {"loss": 0.5, "lr": 1e-6, "reward_mean": 0.8}
            commit: Whether to commit the log (increment step counter)
        """
        self._log(metrics, step=step, commit=commit)

    def log_rewards(
        self,
        step: int,
        rewards: List[float],
        prefix: str = "train"
    ):
        """
        Log reward statistics (mean, std, min, max, median, p25, p75)

        Nothing is logged when ``rewards`` is empty.

        Args:
            step: Training step number
            rewards: Reward values (list or NumPy array)
            prefix: Metric prefix ("train", "val", etc.)
        """
        if not self.enabled:
            return

        metrics = self._reward_summary(rewards, prefix)
        if metrics:
            self._log(metrics, step=step)

    def log_per_dataset_rewards(
        self,
        step: int,
        dataset_rewards: Dict[str, List[float]],
        prefix: str = "train",
        success_threshold: float = 0.9
    ):
        """
        Log per-dataset reward statistics

        All datasets are merged into a single committed log call for the
        step; nothing is logged when every dataset is empty.

        Args:
            step: Training step number
            dataset_rewards: Dict mapping dataset name to list of rewards
            prefix: Metric prefix
            success_threshold: A reward strictly above this value counts as
                a success when computing ``success_rate``
        """
        if not self.enabled:
            return

        metrics = self._per_dataset_summary(dataset_rewards, prefix, success_threshold)
        if metrics:
            self._log(metrics, step=step, commit=True)

    def log_eval_results(
        self,
        step: int,
        dataset_name: str,
        accuracy: float,
        metrics: Optional[Dict[str, Any]] = None
    ):
        """
        Log evaluation results for a dataset

        Args:
            step: Training step number
            dataset_name: Name of the evaluation dataset
            accuracy: Accuracy on the dataset
            metrics: Additional metrics to log; each key is namespaced as
                ``eval/<dataset>_<key>``
        """
        if not self.enabled:
            return

        clean_name = self._clean_name(dataset_name)

        log_dict = {f"eval/{clean_name}_accuracy": accuracy}
        for key, value in (metrics or {}).items():
            log_dict[f"eval/{clean_name}_{key}"] = value

        self._log(log_dict, step=step)

    def log_best_checkpoint(
        self,
        dataset_name: str,
        accuracy: float,
        checkpoint_path: str,
        step: int
    ):
        """
        Log best checkpoint for a dataset

        Only acts when ``accuracy`` is strictly greater than the best value
        seen so far for this dataset (or when no value was seen yet).

        Args:
            dataset_name: Name of the dataset
            accuracy: Accuracy achieved
            checkpoint_path: Path to the checkpoint; stored in the run
                summary under ``best/<dataset>_checkpoint``
            step: Training step number
        """
        if not self.enabled or self.run is None:
            return

        clean_name = self._clean_name(dataset_name)

        # Check if this is a new best
        previous_best = self.best_accuracies.get(clean_name)
        if previous_best is not None and not accuracy > previous_best:
            return

        self.best_accuracies[clean_name] = accuracy

        self._log(
            {
                f"best/{clean_name}_accuracy": accuracy,
                f"best/{clean_name}_step": step,
            },
            step=step,
        )
        # The path is a string, so keep it in the run summary rather than
        # in the metric history.
        self.run.summary[f"best/{clean_name}_checkpoint"] = checkpoint_path

        print(f"New best accuracy for {dataset_name}: {accuracy:.4f} at step {step}")

    def log_model_config(self, config: Dict[str, Any]):
        """Merge ``config`` into the run config, allowing existing values to change"""
        if not self.enabled or self.run is None:
            return

        self.run.config.update(config, allow_val_change=True)

    def log_hyperparameters(self, hyperparams: Dict[str, Any]):
        """Merge ``hyperparams`` into the run config, allowing existing values to change"""
        if not self.enabled or self.run is None:
            return

        self.run.config.update(hyperparams, allow_val_change=True)

    def finish(self, exit_code: Optional[int] = None):
        """
        Finish WandB run

        Safe to call more than once; after the first call the logger is
        disabled so later ``log_*`` calls become no-ops.

        Args:
            exit_code: Optional exit status forwarded to the run's
                ``finish``; a non-zero value marks the run as failed
        """
        if self.run is None:
            return

        # Detach the run first so a second call (or a failure inside
        # run.finish) cannot finish the same run twice.
        run, self.run = self.run, None
        self.enabled = False

        run.finish(exit_code=exit_code)
        print("WandB run finished")

    def __enter__(self):
        """Context manager support"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager cleanup; marks the run failed if the body raised"""
        self.finish(exit_code=1 if exc_type is not None else None)


class DummyLogger:
    """Dummy logger when WandB is disabled.

    Mirrors the public surface of ``WandBLogger`` so callers can use either
    interchangeably: every method accepts any arguments and does nothing,
    and the same public attributes exist with their "disabled" values.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""
        # Same attributes WandBLogger exposes, so checks such as
        # ``logger.enabled`` or ``logger.run`` work on the dummy too.
        self.enabled = False
        self.mode = "disabled"
        self.run = None
        self.best_accuracies: Dict[str, float] = {}

    def log_training_step(self, *args, **kwargs):
        """No-op."""
        pass

    def log_rewards(self, *args, **kwargs):
        """No-op."""
        pass

    def log_per_dataset_rewards(self, *args, **kwargs):
        """No-op."""
        pass

    def log_eval_results(self, *args, **kwargs):
        """No-op."""
        pass

    def log_best_checkpoint(self, *args, **kwargs):
        """No-op."""
        pass

    def log_model_config(self, *args, **kwargs):
        """No-op."""
        pass

    def log_hyperparameters(self, *args, **kwargs):
        """No-op."""
        pass

    def finish(self, *args, **kwargs):
        """No-op; accepts any arguments like the other methods."""
        pass

    def __enter__(self):
        """Context manager support"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager cleanup; never suppresses exceptions"""
        pass


def get_logger(
    enabled: bool = True,
    project_name: Optional[str] = None,
    experiment_name: Optional[str] = None,
    **kwargs
) -> Union[WandBLogger, DummyLogger]:
    """
    Get WandB logger instance (or dummy if disabled)

    A ``DummyLogger`` is returned when ``enabled`` is false, or when either
    ``project_name`` or ``experiment_name`` is missing/empty. In the latter
    case a warning is printed, because logging was requested but cannot be
    set up.

    Args:
        enabled: Whether to enable WandB logging
        project_name: WandB project name
        experiment_name: Run name
        **kwargs: Additional arguments for WandBLogger

    Returns:
        WandBLogger or DummyLogger instance
    """
    if not enabled:
        return DummyLogger()

    if not (project_name and experiment_name):
        # Logging was requested but cannot be configured; say so instead of
        # silently dropping every metric.
        print(
            "Warning: WandB logging requested but project_name or "
            "experiment_name is missing, using DummyLogger"
        )
        return DummyLogger()

    return WandBLogger(
        project_name=project_name,
        experiment_name=experiment_name,
        **kwargs
    )
