import os
import logging
import dataclasses
import json
import wandb

# Configure the root logger: timestamp with milliseconds, padded level name,
# originating file:line, then the message; records at INFO and above pass.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d:%H:%M:%S",
    format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
)
# Logger for this module, named after the module itself.
LOG = logging.getLogger(name=__name__)


class BaseTask:
    """Base class for tasks with an output directory and optional wandb logging.

    On construction it logs the config and creates
    ``<config.base_dir>/out_latte/<config.name>``. It writes the config there
    as ``config.json`` and then calls :meth:`set_logger`.

    ``config`` must be a dataclass instance, because it is serialized with
    ``dataclasses.asdict``. It must expose at least ``base_dir``, ``name`` and
    ``wandb_log``. When ``wandb_log`` is truthy it must also expose
    ``project``, ``entity``, ``check_path`` and ``run_id``.

    Attributes:
        config: The config object passed to the constructor.
        report_to: ``"wandb"`` once a wandb run has been started, otherwise
            ``"none"``. NOTE(review): presumably forwarded to a trainer's
            ``report_to`` option by subclasses — confirm.
        wandb_run: The run object returned by ``wandb.init``, or ``None``.
        out_dir: Directory holding this task's outputs and ``config.json``.
    """

    def __init__(self, config) -> None:
        LOG.info("Config is %s", config)
        self.config = config
        self.report_to = "none"
        self.wandb_run = None

        self.out_dir = os.path.join(self.config.base_dir, "out_latte", self.config.name)
        os.makedirs(self.out_dir, exist_ok=True)
        # Dump the config next to the outputs for debugging. The file is only
        # written here, so plain "w" suffices. An explicit encoding avoids
        # depending on the platform default.
        config_path = os.path.join(self.out_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(config), f, indent=2)
        self.set_logger()

    def set_logger(self) -> None:
        """Start a wandb run if ``config.wandb_log`` is set; otherwise do nothing.

        When ``config.check_path`` is not ``None``, the run is resumed with
        ``resume="must"`` and ``id=config.run_id``. Presumably ``check_path``
        is a checkpoint being resumed — confirm. Otherwise a fresh run is
        started. On success, ``report_to`` becomes ``"wandb"`` and
        ``wandb_run`` holds the run object.
        """
        if not self.config.wandb_log:
            return

        resume = False
        run_id = None
        if self.config.check_path is not None:
            resume = "must"
            run_id = self.config.run_id

        self.wandb_run = wandb.init(
            project=self.config.project,
            entity=self.config.entity,
            name=self.config.name,
            dir=self.out_dir,
            # wandb converts non-dict configs only via their __dict__. That
            # leaves nested dataclasses unconverted and rejects slots
            # dataclasses. asdict() yields a plain, fully nested dict instead.
            config=dataclasses.asdict(self.config),
            id=run_id,
            resume=resume,
        )
        self.report_to = "wandb"