from copy import deepcopy
import pathlib
import json
import traceback


from ..learning.supervised_learning import SupervisedLearning
from ..tasks import TaskPath
from ..tasks.task_path import recursive_default_setting
from ..tests import test_init_kwargs
from ..utils import hash_dict, measure_runtime
from ..utils.log import get_logger
from ..utils.sheet_uploader import SheetUploader


class Experiment:
    """Run the tasks described by a ``task_path`` config and record run metadata.

    The experiment writes ``hash.json`` (a hash of its config) when it is constructed. It
    then executes the tasks yielded by a :class:`TaskPath`, and finally writes
    ``exit_code.json`` once every task has finished.
    """

    @test_init_kwargs
    def __init__(self, logdir: pathlib.Path, set_name: str, set_seed: int, **kwargs):
        """Build the experiment from its config.

        Args:
            logdir: Directory that receives ``hash.json``/``exit_code.json`` and is
                handed to ``TaskPath`` and ``SupervisedLearning``. It is written to
                immediately, so it presumably must already exist — TODO confirm.
            set_name: Name passed to ``SheetUploader`` when uploading results.
            set_seed: Stored on the instance; not otherwise used in this class.
            **kwargs: Experiment config. Must contain ``"task_path"`` (kwargs for
                ``TaskPath``) and ``"learn_mapping"`` (kwargs for ``SupervisedLearning``).
        """
        self.config = kwargs

        self.logdir: pathlib.Path = logdir
        self.logger = get_logger("experiment")
        self.set_name = set_name
        self.set_seed = set_seed
        # The hash is taken from the config exactly as the caller supplied it,
        # before any defaults are injected below.
        self.hash = hash_dict(kwargs)
        self.save_hash()

        # Update defaults with config from learn_mapping for backwards compatibility
        # (but no overwriting). Work on a private copy of the task_path section:
        # ``**kwargs`` only copies the top level, so mutating the nested dict would
        # leak into the caller's config (and change the hash of any later
        # Experiment built from that same config).
        task_path_config = deepcopy(self.config["task_path"])
        learn_mapping = self.config.get("learn_mapping")
        if "defaults" not in task_path_config:
            task_path_config["defaults"] = deepcopy(learn_mapping)
        else:
            recursive_default_setting(learn_mapping, task_path_config["defaults"])
        # ``self.config`` is this call's own kwargs dict, so rebinding the key is local.
        self.config["task_path"] = task_path_config

        self.task_path: TaskPath = TaskPath(self.logdir, **self.config["task_path"])

        self.supervised_learning = SupervisedLearning(logdir=self.logdir, **self.config['learn_mapping'])

    def run(self):
        """Execute every task from ``self.task_path`` in order.

        The whole loop is wrapped in ``measure_runtime(self.logdir)``, and each task in
        ``measure_runtime(task.logdir)``. Results are uploaded before each task and
        once more after the loop. ``exit_code.json`` with code 0 is written only when
        all tasks complete.

        Raises:
            NotImplementedError: If a task's ``type`` is not ``'supervised-learning'``.
        """
        with measure_runtime(self.logdir):
            for task in self.task_path:
                # Uploads whatever results exist so far (i.e. from earlier tasks).
                self.upload_task()

                with measure_runtime(task.logdir):
                    if task.type == 'supervised-learning':
                        task.output_model = self.supervised_learning.train(task)
                        # Reset the learner so the next task starts fresh.
                        self.supervised_learning.reset()
                    else:
                        raise NotImplementedError(f"The following task type is not implemented: {task.type}")

        self.upload_task()
        self.save_exit_code(0)

    def upload_task(self):
        """Upload the latest results for ``self.set_name``; never raises.

        Any exception from the uploader is logged (with traceback) instead of being
        propagated, so a failed upload does not abort the experiment.
        """
        try:
            uploader = SheetUploader(self.set_name)
            uploader.upload()

        except Exception:
            # ``logger.exception`` already attaches the active traceback; passing
            # ``traceback.format_exc()`` as the message would log it twice.
            self.logger.exception("Uploading results failed.")
            self.logger.critical("Could not upload latest results. Likely connection problem to Google sheets.")

    def _write_json(self, filename: str, payload: dict):
        """Write ``payload`` as indented JSON to ``self.logdir / filename``."""
        with open(self.logdir / filename, "w", encoding="utf-8") as fp:
            json.dump(payload, fp, indent=4)

    def save_exit_code(self, code: int):
        """Record ``code`` in ``exit_code.json`` inside the log directory."""
        self._write_json("exit_code.json", {"code": code})

    def save_hash(self):
        """Record the config hash in ``hash.json`` inside the log directory."""
        self._write_json("hash.json", {"hash": self.hash})
