import collections.abc
import functools
import warnings

import comet_ml

from config import get_config

# Private Comet settings, loaded once at import time. CometML.start() reads
# cf.comet.api_key, cf.comet.project_name and cf.comet.workspace from it.
cf = get_config(["config/private.yaml"])

def comet_conditional(func):
    """Make a method a no-op unless its instance has Comet enabled.

    The decorated method runs (and its result is returned) only when
    ``self.use_comet`` is truthy; otherwise the call is skipped and ``None``
    is returned. ``functools.wraps`` keeps the original name and docstring.
    """

    @functools.wraps(func)
    def _guarded(self, *args, **kwargs):
        # Guard clause: skip the real call entirely when Comet is disabled.
        if not self.use_comet:
            return None
        return func(self, *args, **kwargs)

    return _guarded

class CometML:
    """Thin wrapper around a Comet ML experiment that can be switched off.

    Every public method is decorated with ``comet_conditional``. When
    ``use_comet`` is false, calls are skipped and return ``None``, so calling
    code can log unconditionally.
    """

    def __init__(self, use_comet: bool = True):
        self.use_comet = use_comet
        # Active Comet experiment. Stays None until start() assigns it, which
        # lets end() safely no-op when start() was never called.
        self.exp = None

    @comet_conditional
    def start(self, name: str = None, tags: list[str] = None):
        """Create and start a Comet experiment using settings from ``cf.comet``.

        Args:
            name: Optional experiment name.
            tags: Optional tags attached to the experiment.
        """
        experiment_config = comet_ml.ExperimentConfig(
            name=name,
            tags=tags,
            display_summary_level=0,
        )
        self.exp = comet_ml.start(
            api_key=cf.comet.api_key,
            project_name=cf.comet.project_name,
            workspace=cf.comet.workspace,
            experiment_config=experiment_config,
        )

    @comet_conditional
    def end(self):
        """End the experiment created by start(); does nothing if none exists."""
        if self.exp is not None:
            self.exp.end()

    @comet_conditional
    def log_code(self, file_path):
        """Forward ``file_path`` to the experiment's ``log_code``."""
        self.exp.log_code(file_path)

    @comet_conditional
    def log_parameters(self, cfg):
        """Log a mapping of parameters to the experiment.

        Non-mapping inputs trigger a warning and are NOT logged, as the
        warning message states.
        """
        # Accept any Mapping (dict and dict-like config objects), not only
        # dict subclasses.
        if not isinstance(cfg, collections.abc.Mapping):
            warnings.warn("cfg is not a dict, so it will not be logged to Comet.")
            return
        self.exp.log_parameters(cfg)

    @comet_conditional
    def log_metrics(self, *args, **kwargs):
        """Forward all arguments unchanged to the experiment's ``log_metrics``."""
        self.exp.log_metrics(*args, **kwargs)

    @comet_conditional
    def log_metric(self, *args, **kwargs):
        """Forward all arguments unchanged to the experiment's ``log_metric``."""
        self.exp.log_metric(*args, **kwargs)

        