from ..agents.Probability import ProbabilityEmpiricalMeasure
from ..agents.Policy import PolicyFiniteActionsFiniteStates
from ..agents.tabular.DiscreteAgent import DiscreteAgent
import matplotlib.pyplot as plt
import os
import deepdish as dd
import importlib
from modules.utils.Distances import binary_distance as dist


def load_probability_measure(data: dict) -> ProbabilityEmpiricalMeasure:
    """Rebuild a ProbabilityEmpiricalMeasure from its dictionary form.

    :param data: mapping that provides the ``"space"`` and ``"probability"``
        entries passed, in that order, to the measure's constructor.
    :return: the reconstructed measure.
    """
    return ProbabilityEmpiricalMeasure(data["space"], data["probability"])


def load_policy(data: dict) -> PolicyFiniteActionsFiniteStates:
    """Rebuild a finite-state, finite-action policy from its dictionary form.

    :param data: mapping with ``"state_space"``, ``"action_space"`` and
        ``"policy"``; ``data["policy"][k]`` is the serialized measure for the
        k-th state, for k in ``range(len(state_space))``.
    :return: the reconstructed policy.
    """
    states = data["state_space"]
    actions = data["action_space"]
    # "policy" is looked up per index, so it is never touched when the state
    # space is empty.
    per_state_measures = [
        load_probability_measure(data["policy"][idx])
        for idx in range(len(states))
    ]
    return PolicyFiniteActionsFiniteStates(states, actions, per_state_measures)


def load_agent(data: dict) -> DiscreteAgent:
    """Rebuild an agent from its dictionary form.

    :param data: mapping with ``"state_space"``, ``"action_space"``,
        ``"policy"`` (serialized policy, rebuilt with ``load_policy``),
        ``"agent_name"`` (name of the agent class) and ``"parameters"``
        (dict passed to the agent constructor).
    :return: an instance of the class named ``data["agent_name"]``.
    :raises AttributeError: if the agent module has no attribute with that
        name.
    """
    state_space = data["state_space"]
    action_space = data["action_space"]
    policy = load_policy(data["policy"])
    # Resolve the class in the module DiscreteAgent was actually imported
    # from. A bare import_module("DiscreteAgent") only works when that
    # directory is on sys.path and would load a second copy of the module,
    # whose classes are distinct from the imported DiscreteAgent.
    module = importlib.import_module(DiscreteAgent.__module__)
    class_ = getattr(module, data["agent_name"])
    # Shallow copy so the caller's dict is not mutated; otherwise the injected
    # function object would leak into any later save of `data`.
    parameters = dict(data["parameters"])
    parameters["dist"] = dist  # TODO integrate other dists
    return class_(state_space, action_space, policy, parameters)


def save_to_file(folder: str, name: str, data: dict, ext: str = "h5") -> None:
    """Save ``data`` with deepdish to ``<folder>/<name>.<ext>``.

    :param folder: destination directory.
    :param name: file name without extension.
    :param data: object to serialize.
    :param ext: file extension without the leading dot; defaults to ``"h5"``,
        matching the ``ext`` parameter of ``read_from_file``.
    """
    dd.io.save(get_path(folder, name, ext), data)


def read_from_file(folder: str, name: str, ext: str = "h5") -> dict:
    """Load the deepdish file at ``<folder>/<name>.<ext>``.

    :param folder: directory containing the file.
    :param name: file name without extension.
    :param ext: file extension without the leading dot.
    :return: whatever ``dd.io.load`` returns for that file.
    """
    path = get_path(folder, name, ext)
    loaded = dd.io.load(path)
    return loaded


def save_current_fig(folder: str, name: str, **savefig_kwargs) -> None:
    """Save the current matplotlib figure as ``<folder>/<name>.png``.

    :param folder: destination directory.
    :param name: file name without extension.
    :param savefig_kwargs: extra keyword arguments forwarded unchanged to
        ``plt.savefig`` (e.g. ``dpi``, ``bbox_inches``).
    """
    # plt.savefig returns None, so nothing is returned (the old `-> dict`
    # annotation was wrong).
    plt.savefig(get_path(folder, name, "png"), **savefig_kwargs)


def get_path(folder: str, name: str, ext: str = "h5") -> str:
    """Build the path ``<folder>/<name>.<ext>`` using ``os.path.join``.

    :param folder: directory part of the path.
    :param name: file name without extension.
    :param ext: extension without the leading dot; defaults to ``"h5"``.
    :return: the joined path.
    """
    return os.path.join(folder, f"{name}.{ext}")