import pathlib
import json
import traceback

from ..learning.supervised_learning import SupervisedLearning
from ..tasks import TaskPath
from ..tests import test_init_kwargs
from ..utils import hash_dict, measure_runtime
from ..utils.log import get_logger
from ..utils.sheet_uploader import SheetUploader


class Experiment:
    """Run the sequence of tasks described by one experiment configuration.

    Bookkeeping files (``hash.json`` and, on success, ``exit_code.json``) are
    written into ``logdir``. Results are uploaded through ``SheetUploader``
    before every task and once more after the last task.
    """

    @test_init_kwargs
    def __init__(self, logdir: pathlib.Path, set_name: str, **kwargs):
        """Store the configuration, write its hash and build the task runners.

        Args:
            logdir: Directory that receives the experiment artifacts.
                NOTE(review): nothing here creates it; it is assumed to exist.
            set_name: Name handed to ``SheetUploader`` when uploading results.
            **kwargs: Experiment configuration. Must contain ``"task_path"``
                (keyword arguments for ``TaskPath``) and ``"learn_mapping"``
                (keyword arguments for ``SupervisedLearning``, which must
                include a ``"model"`` mapping).
        """
        self.config = kwargs

        self.logdir: pathlib.Path = logdir
        self.logger = get_logger("experiment")
        self.set_name = set_name
        # Hash the configuration exactly as the caller passed it.
        self.hash = hash_dict(kwargs)
        self.save_hash()

        # TaskPath also needs the model's bias-reset options (None when the
        # model config does not set them). Build a new dict rather than writing
        # into kwargs["task_path"]: **kwargs is only a shallow copy, so that
        # nested dict belongs to the caller. Mutating it would change the hash
        # of any later Experiment created from the same configuration object.
        model_config = kwargs["learn_mapping"]["model"]
        task_path_kwargs = {
            **kwargs["task_path"],
            "use_bias_reset": model_config.get("use_bias_reset"),
            "wo_first_bias_reset": model_config.get("wo_first_bias_reset"),
        }
        self.task_path: TaskPath = TaskPath(self.logdir, **task_path_kwargs)

        self.supervised_learning = SupervisedLearning(logdir=self.logdir, **self.config['learn_mapping'])

    def run(self):
        """Execute every task of the task path and record a successful exit.

        The runtime of the whole run and of each task is measured into the
        respective log directory. ``exit_code.json`` is written only when all
        tasks finish without raising.

        Raises:
            NotImplementedError: If a task's type is not
                ``'supervised-learning'``.
        """
        with measure_runtime(self.logdir):
            for task in self.task_path:
                # Upload whatever results exist before starting the next task.
                self.upload_task()

                with measure_runtime(task.logdir):
                    if task.type == 'supervised-learning':
                        task.output_model = self.supervised_learning.train(task)
                        # Reset the learner so the next task starts fresh.
                        self.supervised_learning.reset()
                    else:
                        raise NotImplementedError(f"The following task type is not implemented: {task.type}")

        # Final upload, so the results of the last task are included.
        self.upload_task()
        self.save_exit_code(0)

    def upload_task(self):
        """Upload the latest results for ``self.set_name`` (best effort).

        Failures are logged and swallowed on purpose, so that an upload
        problem never aborts a running experiment.
        """
        try:
            SheetUploader(self.set_name).upload()
        except Exception:
            # exc_info attaches the current traceback to this single record.
            self.logger.critical(
                "Could not upload latest results. Likely connection problem to Google sheets.",
                exc_info=True,
            )

    def save_exit_code(self, code: int):
        """Write ``{"code": code}`` to ``<logdir>/exit_code.json``."""
        self._write_json("exit_code.json", {"code": code})

    def save_hash(self):
        """Write ``{"hash": self.hash}`` to ``<logdir>/hash.json``."""
        self._write_json("hash.json", {"hash": self.hash})

    def _write_json(self, filename: str, payload: dict):
        """Dump ``payload`` as 4-space-indented JSON into ``<logdir>/<filename>``."""
        with open(self.logdir / filename, "w", encoding="utf-8") as fp:
            json.dump(payload, fp, indent=4)
