import json
import os
from shutil import copyfile

import optuna
from filelock import FileLock

from adl4cv.utils.utils import relative_file_lock, get_lock_folder, Serializable


class HyperParameterStorage:
    """
    Storage for one hyper-parameter search scope, backed by an Optuna RDB storage.

    All artefacts of a scope live under ``<storage_folder>/<scope>``: the SQLite
    study database, a JSON summary of the studies and a copy of the best
    trial's parameter file.
    """

    def __init__(self, storage_folder: str, scope: str):
        """
        :param storage_folder: Root folder containing one sub-folder per scope
        :param scope: Name of the scope; also used as the study name when looking up the best trial
        """
        self.storage_folder = storage_folder
        self.scope = scope
        self.scope_folder = os.path.join(self.storage_folder, self.scope)

        # Created lazily by the ``storage`` property so that constructing this
        # object does not open the database.
        self._storage = None

    @property
    def storage_path(self) -> str:
        """Database URL of the SQLite file holding the studies of this scope."""
        return f"sqlite:///{os.path.join(self.scope_folder, 'study_storage.db')}"

    @property
    def summary_path(self) -> str:
        """Path of the JSON file written by ``update_summary``."""
        return os.path.join(self.scope_folder, 'summary.json')

    @property
    def storage(self):
        """The Optuna RDB storage of this scope, created on first access and cached."""
        if self._storage is None:
            self._storage = optuna.storages.RDBStorage(url=self.storage_path)
        return self._storage

    def update_summary(self):
        """
        Writes the best trial of every study in the storage to ``summary_path``.

        The JSON file maps each study name to ``{"best_trial": ...}``. The value
        is ``null`` for a study that has no completed trial yet.
        """
        # NOTE(review): ``get_all_study_summaries`` is a storage-level API of older
        # Optuna releases; verify it still exists in the pinned Optuna version.
        summaries = self.storage.get_all_study_summaries()
        summary_dict = {summary.study_name: self._summary_to_dict(summary) for summary in summaries}

        # Hold the project's file lock while writing so that concurrent writers
        # of the same summary file do not interleave.
        with relative_file_lock(self.summary_path, timeout=10, root_path=get_lock_folder(self.storage_folder)):
            with open(self.summary_path, 'w', encoding='utf-8') as f:
                json.dump(summary_dict, f, ensure_ascii=False, indent=4)

    def _summary_to_dict(self, summary) -> dict:
        """
        Converts a study summary to a JSON-serialisable dictionary.

        :param summary: A study summary exposing a (possibly ``None``) ``best_trial``
        :return: ``{"best_trial": <trial dictionary, or None if the study has no best trial>}``
        """
        best_trial = summary.best_trial
        # A study without any completed trial has no best trial (``None``).
        # Record that instead of failing the summary of every other study.
        return {"best_trial": None if best_trial is None else self._trial_to_dict(best_trial)}

    def _trial_to_dict(self, trial: "optuna.trial.FrozenTrial") -> dict:
        """
        Converts the results of a trial to dictionary
        :param trial: The trial to be converted
        :return: The trial as a dictionary
        """
        return {
            "duration": str(trial.duration),
            "number": trial.number,
            "value": trial.value,
            "params": trial.params,
            "user_attrs": trial.user_attrs,
            # Integer step keys become strings when the summary is dumped as JSON.
            "intermediate_values": trial.intermediate_values,
        }

    @property
    def _best_definition_file(self) -> str:
        """The parameter file of the best trial"""
        return os.path.join(self.scope_folder, "best_params.json")

    @property
    def best_definition(self):
        """The best optimization definition, loaded from ``best_params.json`` of this scope."""
        return Serializable.loads_from_file(self._best_definition_file)

    @property
    def best_trial(self):
        """The best trial of the study whose name equals this scope (queried from the storage on every access)."""
        return self.storage.get_best_trial(self.storage.get_study_id_from_name(self.scope))

    def save_best_definition(self):
        """
        Copies the parameter file of the best trial to ``best_params.json`` of this scope.

        The source is ``<trial>/fold_0/params.json``. If that file does not exist,
        ``<trial>/testset/params.json`` is used instead.
        """
        # Query the best trial once, so that both candidate paths refer to the
        # same trial and the database is hit only once.
        trial_folder = os.path.join(self.scope_folder, self.get_trial_name(self.best_trial.number))
        best_params_src = os.path.join(trial_folder, self.get_fold_name(0), "params.json")
        if not os.path.exists(best_params_src):
            best_params_src = os.path.join(trial_folder, "testset", "params.json")
        best_params_dst = self._best_definition_file

        with relative_file_lock(best_params_dst, timeout=10, root_path=get_lock_folder(self.storage_folder)):
            copyfile(best_params_src, best_params_dst)

    def get_trial_name(self, trial_id):
        """Name of the folder holding the artefacts of trial ``trial_id``."""
        return f"trial_{trial_id}"

    def get_fold_name(self, fold_id):
        """Name of the sub-folder holding the artefacts of fold ``fold_id``."""
        return f"fold_{fold_id}"

