from abc import ABC, abstractmethod
import os

import mlflow
import omegaconf
import wandb
from omegaconf import DictConfig
import torch

class BaseLogger:
    @abstractmethod
    def set_experiment(self, experiment_name: str):
        pass

    @abstractmethod
    def log_param(self, key: str, value: int | float | str):
        pass

    @abstractmethod
    def log_metric(
        self, metrics: dict[str, float | int], step_key: str, commit: bool = True
    ):
        pass

    @abstractmethod
    def log_hparams(self, hparams: dict, metrics: dict):
        pass

    @abstractmethod
    def end_run(self):
        pass

    @abstractmethod
    def define_metrics(self, metrics: dict[str, str]):
        pass

    @abstractmethod
    def get_run_dir(self):
        pass

    @abstractmethod
    def log_artifact(self, local_path: str, artifact_path: str | None = None):
        pass


class AutoLogger(BaseLogger):
    """Facade that picks a concrete logging backend by name and forwards to it.

    Args:
        logger: backend name, either ``"mlflow"`` or ``"wandb"``.
        **kwargs: forwarded verbatim to the chosen backend's constructor.

    Raises:
        ValueError: if ``logger`` names an unsupported backend.
    """

    def __init__(self, logger: str, **kwargs):
        if logger == "mlflow":
            self.logger = MLFlowLogger(**kwargs)
        elif logger == "wandb":
            self.logger = WandbLogger(**kwargs)
        else:
            raise ValueError(f"Logger {logger} not supported")

    def set_experiment(self, *args, **kwargs):
        self.logger.set_experiment(*args, **kwargs)

    def log_param(self, *args, **kwargs):
        self.logger.log_param(*args, **kwargs)

    def log_metric(self, *args, **kwargs):
        self.logger.log_metric(*args, **kwargs)

    def log_hparams(self, *args, **kwargs):
        # Must be forwarded explicitly. Otherwise BaseLogger's no-op is
        # inherited and hyperparameters are silently dropped.
        self.logger.log_hparams(*args, **kwargs)

    def end_run(self):
        return self.logger.end_run()

    def log_summary(self, *args, **kwargs):
        self.logger.log_summary(*args, **kwargs)

    def define_metrics(self, *args, **kwargs):
        self.logger.define_metrics(*args, **kwargs)

    def get_run_dir(self) -> str:
        return self.logger.get_run_dir()

    def log_artifact(self, *args, **kwargs):
        self.logger.log_artifact(*args, **kwargs)


class MLFlowLogger(BaseLogger):
    """Logger backend that writes to MLflow through its module-level API.

    Args:
        config: must provide ``tracking_uri``, which is passed to
            ``mlflow.set_tracking_uri`` on construction.
    """

    def __init__(self, config: DictConfig):
        mlflow.set_tracking_uri(config.tracking_uri)

    def set_experiment(self, experiment_name: str):
        mlflow.set_experiment(experiment_name)

    def log_param(self, key: str, value: int | float | str):
        mlflow.log_param(key, value)

    def log_metric(
        self, metrics: dict[str, float | int], step_key: str, commit: bool = True
    ):
        """Log every entry of ``metrics`` except ``step_key`` at that step.

        ``metrics[step_key]`` is converted with ``int()`` and used as the MLflow
        step. The caller's dict is left unmodified; reading it with ``pop``
        would strip the step key for any later consumer of the same dict.
        ``commit`` is accepted for interface compatibility and ignored here.

        Raises:
            KeyError: if ``step_key`` is not present in ``metrics``.
        """
        step = int(metrics[step_key])
        for key, value in metrics.items():
            if key != step_key:
                mlflow.log_metric(key, value, step)

    def log_hparams(self, hparams: dict, metrics: dict):
        # mlflow logs hparams via mlflow.log_params and mlflow.log_metrics
        mlflow.log_params(hparams)
        mlflow.log_metrics(metrics)

    def end_run(self):
        mlflow.end_run()

    def log_summary(self, *args, **kwargs):
        """Store a summary value as a regular metric via ``mlflow.log_metric``."""
        mlflow.log_metric(*args, **kwargs)

    def define_metrics(self, metrics: dict[str, str]):
        """No-op: this backend does not use metric / step-metric declarations."""

    def get_run_dir(self) -> str:
        """Return the artifact URI of the current MLflow run.

        If no run is active, ``mlflow.get_artifact_uri()`` is used. MLflow
        documents that call as starting a new run when none is active. If it
        raises, a warning is printed and the tracking URI is returned instead.
        """
        run = mlflow.active_run()
        if run is not None:
            return run.info.artifact_uri
        try:
            return mlflow.get_artifact_uri()
        except Exception:
            # Best-effort fallback. Note that the tracking URI is not a
            # run-specific artifact location.
            print("Warning: No active MLflow run found to get artifact URI.")
            return mlflow.get_tracking_uri()

    def log_artifact(self, local_path: str, artifact_path: str | None = None):
        """Upload ``local_path``; ``artifact_path`` is the destination directory."""
        mlflow.log_artifact(local_path, artifact_path)


class WandbLogger(BaseLogger):
    """Logger backend that records to one Weights & Biases run.

    The run is started in ``__init__``. ``config`` must provide
    ``project_name``, ``log_dir``, ``experiment_name`` and ``run_id``. With a
    truthy ``run_id`` the run is created with ``resume="allow"``; otherwise
    resuming is disabled. The whole config, resolved to plain containers, is
    stored as the run's wandb config.
    """

    def __init__(
        self,
        config: DictConfig,
    ):
        wandb_cfg = omegaconf.OmegaConf.to_container(
            config,
            resolve=True,
        )

        self.run = wandb.init(
            project=config.project_name,  # ShortCircuit
            dir=config.log_dir,
            name=config.experiment_name,
            id=config.run_id,
            tags=[],
            config=wandb_cfg,  # type: ignore
            settings=wandb.Settings(
                start_method="thread",
            ),
            resume="allow" if config.run_id else False,
        )
        self.log_param("run_id", self.get_run_id())
        self.log_param("run_dir", self.get_run_dir())

    def get_run_id(self) -> str:
        return self.run.id

    def get_run_dir(self) -> str:
        return self.run.dir

    def define_metrics(self, metrics: dict[str, str]):
        """Declare each metric against its step metric.

        Args:
            metrics: metric name -> name of the metric used as its x-axis.
                Metrics whose name contains "loss" (case-insensitive) are
                summarised by their minimum; all others by their maximum.
        """
        for name, step_metric in metrics.items():
            goal = "min" if "loss" in name.lower() else "max"
            self.run.define_metric(name=name, step_metric=step_metric, summary=goal)
        # Declare each distinct step metric as well. dict.fromkeys de-duplicates
        # while keeping first-seen order, which is deterministic (unlike set()).
        for step in dict.fromkeys(metrics.values()):
            self.run.define_metric(name=step)

    def set_experiment(self, experiment_name: str):
        self.run.name = experiment_name  # type: ignore

    def log_param(self, key: str, value: int | float | str):
        # Update this logger's own run rather than the process-global
        # ``wandb.config``.
        self.run.config.update({key: value}, allow_val_change=True)

    def log_hparams(self, hparams: dict, metrics: dict):
        """Store ``hparams`` in the run config and ``metrics`` in the run summary."""
        self.run.config.update(hparams, allow_val_change=True)
        self.run.summary.update(metrics)

    def log_metric(
        self, metrics: dict[str, float | int], step_key: str, commit: bool = True
    ):
        """Log ``metrics`` to the run.

        ``step_key`` is ignored here. Presumably the x-axis comes from the step
        metrics registered via ``define_metrics``; confirm with callers.
        """
        self.run.log(metrics, commit=commit)

    def log_summary(self, key: str, value: int | float):
        self.run.summary[key] = value

    def end_run(self):
        """Finish the run if it is still open and drop the reference.

        Calling it a second time does nothing.
        """
        if self.run is not None:
            self.run.finish()
            self.run = None

    def log_artifact(self, local_path: str, artifact_path: str | None = None):
        """Upload a single file as a wandb Artifact of type ``"file"``.

        The artifact is named after the file's basename with its final
        extension removed. Inside the artifact the file is stored under
        ``artifact_path`` if given, otherwise under its basename. Does nothing
        after ``end_run``.

        NOTE(review): unlike ``mlflow.log_artifact``, where ``artifact_path`` is
        a destination directory, here it is the file name inside the artifact.
        """
        if self.run is None:
            return
        file_name = os.path.basename(local_path)
        # splitext removes only the last extension ("model.v2.ckpt" -> "model.v2").
        # Splitting on the first '.' would truncate such names, and would yield
        # an empty artifact name for dot-files like ".env".
        artifact = wandb.Artifact(name=os.path.splitext(file_name)[0], type="file")
        artifact.add_file(local_path=local_path, name=artifact_path or file_name)
        self.run.log_artifact(artifact)
