import copy
from abc import abstractmethod
from random import sample
from typing import List, Optional

import numpy as np
import torch
from gym.spaces import Box
from torch import optim

from modules.agents.Probability import ProbabilityMeasure, ProbabilityEmpiricalMeasure
from modules.nn.NNModels import ContinuousStatesDiscreteActionsNet, ContinuousStatesContinuousActionsNet



class Policy:
    """Common interface shared by the policy classes of this module.

    Every method below is tagged with ``@abstractmethod`` and has an empty
    body, so calling it on a bare ``Policy`` returns ``None``.

    NOTE: this class does not derive from ``abc.ABC`` (and sets no
    ``ABCMeta`` metaclass), so the decorator is not enforced: ``Policy``
    itself, and subclasses that skip some methods, can still be instantiated.
    """

    @abstractmethod
    def policy(self, state: int):
        """Return the policy evaluated at ``state``. To be implemented."""

    @abstractmethod
    def state_space(self, copy=True):
        """Return the state space, optionally as a copy. To be implemented."""

    @abstractmethod
    def __call__(self, state):
        """Make the instance callable on ``state``. To be implemented."""

    @abstractmethod
    def action_space(self, copy=True):
        """Return the action space, optionally as a copy. To be implemented."""

    @abstractmethod
    def sample(self, state):
        """Draw an action for ``state``. To be implemented."""

    @abstractmethod
    def copy(self):
        """Return a copy of this policy. To be implemented."""

    @abstractmethod
    def save(self):
        """Return a serialisable description of this policy. To be implemented."""


class PolicyFiniteActionsFiniteStates(Policy):
    """Tabular policy: one probability measure over the actions per state.

    The measures are stored in a list and looked up with
    ``self._policy[state]``, i.e. a state is used directly as an index into
    that list.
    """

    def __init__(self, state_space: np.ndarray, action_space: np.ndarray,
                 policy: Optional[List[ProbabilityMeasure]] = None) -> None:
        """Store the spaces and the per-state measures.

        Args:
            state_space: array of states; its length gives ``n_state()``.
            action_space: array of actions; its length gives ``n_action()``.
            policy: one measure per state. When ``None``, a
                ``ProbabilityEmpiricalMeasure`` built from the action space
                and ``None`` is created for each of the ``n_state()`` states.
        """
        self._state_space = state_space
        self._action_space = action_space
        self._n_state = len(state_space)
        self._n_action = len(action_space)
        if policy is None:
            # Every default measure receives the same, uncopied action array.
            actions = self.action_space(copy=False)
            policy = [ProbabilityEmpiricalMeasure(actions, None)
                      for _ in range(self.n_state())]
        self._policy = policy

    def __call__(self, state: int):
        """Shorthand for ``self.policy(state)``."""
        return self.policy(state)

    def state_space(self, copy=True):
        """Return the state array, or a copy of it when ``copy`` is true."""
        return self._state_space.copy() if copy else self._state_space

    def action_space(self, copy=True):
        """Return the action array, or a copy of it when ``copy`` is true."""
        return self._action_space.copy() if copy else self._action_space

    def n_state(self):
        """Return the number of states (length of the state array)."""
        return self._n_state

    def n_action(self):
        """Return the number of actions (length of the action array)."""
        return self._n_action

    def policy(self, state: int) -> ProbabilityEmpiricalMeasure:
        """Return the measure stored at index ``state`` (not a copy)."""
        return self._policy[state]

    def sample(self, state: int):
        """Return ``sample()`` of the measure stored for ``state``."""
        return self.policy(state).sample()

    def set_policy(self, state: int, policy: ProbabilityEmpiricalMeasure) -> None:
        """Replace the measure stored for ``state``.

        Raises:
            ValueError: if ``policy.is_normalized()`` is false.
        """
        if not policy.is_normalized():
            raise ValueError(
                "policy for state {!r} is not normalized".format(state))
        self._policy[state] = policy

    def copy(self) -> Policy:
        """Return a new policy holding copies of the spaces and of each measure."""
        # Read-only iteration, so the uncopied state array is enough here.
        measures = [self.policy(s).copy() for s in self.state_space(copy=False)]
        return PolicyFiniteActionsFiniteStates(self.state_space(), self.action_space(), measures)

    def save(self) -> dict:
        """Return a dict with copies of both spaces and each measure's ``save()``.

        Keys are ``"state_space"``, ``"action_space"`` and ``"policy"``; the
        last one is a list with one entry per state, in state-array order.
        """
        # Iterating needs no copy; the returned spaces are still copies.
        saved = [self.policy(s).save() for s in self.state_space(copy=False)]
        return {"state_space": self.state_space(),
                "action_space": self.action_space(),
                "policy": saved}


class PolicyFiniteActionsContinuousStates(Policy):
    """Network-backed policy over a finite action array for states in a ``Box``.

    Holds a ``ContinuousStatesDiscreteActionsNet`` constructed from the state
    dimension and the number of actions, plus an Adam optimizer over the
    network parameters.
    """

    def __init__(self, state_space: Box, action_space: np.ndarray, nn_parameters: dict) -> None:
        """Build the network and its optimizer.

        Args:
            state_space: state ``Box``; ``shape[0]`` is used as ``dim_state()``.
            action_space: array of actions; its length gives ``n_action()``.
            nn_parameters: must provide the keys ``"learning_rate"``,
                ``"hidden_layers_size"`` and ``"activation_functions"``. The
                dict itself is kept (not copied).
        """
        self._state_space = state_space
        self._action_space = action_space
        self._dim_state = state_space.shape[0]
        self._n_action = len(action_space)

        # NN parameters
        self._lr = nn_parameters["learning_rate"]
        self._nn_parameters = nn_parameters

        self._policy = ContinuousStatesDiscreteActionsNet(
            self.dim_state(),
            self.n_action(),
            hidden_layers_size=self._nn_parameters['hidden_layers_size'],
            activation_functions=self._nn_parameters['activation_functions'])
        self._policy_opt = optim.Adam(self._policy.parameters(), lr=self._lr)

    def __call__(self, state):
        """Shorthand for ``self.policy(state)`` (detached form)."""
        return self.policy(state)

    def state_space(self, copy=True):
        """Return the state ``Box``, or a fresh ``Box`` with the same bounds.

        The copy is built with the original dtype; without it ``Box`` would
        fall back to its default dtype and silently change the bounds' type.
        """
        if not copy:
            return self._state_space
        return Box(low=self._state_space.low,
                   high=self._state_space.high,
                   dtype=self._state_space.dtype)

    def action_space(self, copy=True):
        """Return the action array, or a copy of it when ``copy`` is true."""
        return self._action_space.copy() if copy else self._action_space

    def dim_state(self):
        """Return the state dimension (``state_space.shape[0]``)."""
        return self._dim_state

    def n_action(self):
        """Return the number of actions (length of the action array)."""
        return self._n_action

    def nn_parameters(self) -> dict:
        """Return the network-parameter dict passed at construction (not a copy)."""
        return self._nn_parameters

    def policy(self, state, detach=True):
        """Evaluate the network at ``state``.

        Args:
            state: anything ``torch.tensor`` accepts; it is converted to a
                float tensor and reshaped to ``(-1, dim_state())``.
            detach: when true, wrap the flattened network output in a
                ``ProbabilityEmpiricalMeasure`` over the actions, call
                ``normalize()`` on it and return it. When false, return the
                network output reshaped to ``(-1, n_action())`` as a tensor
                that is still attached to the autograd graph.
        """
        x = torch.tensor(state, dtype=torch.float).view(-1, self.dim_state())
        if not detach:
            return self._policy(x).view(-1, self.n_action())
        # Inference only: no need to record an autograd graph we would drop.
        with torch.no_grad():
            y = self._policy(x).view(-1, self.n_action())
        prob = ProbabilityEmpiricalMeasure(self.action_space(copy=False),
                                           y.detach().numpy().reshape(-1))
        prob.normalize()
        return prob

    def policy_opt(self):
        """Return the Adam optimizer attached to the network parameters."""
        return self._policy_opt

    def sample(self, state):
        """Return ``sample()`` of the detached policy at ``state``."""
        # TODO optimize with a LRU cache -> store predictions for a given state
        return self.policy(state).sample()

    def copy(self) -> Policy:
        """Return a new policy with the same spaces, settings and weights.

        The clone gets its own network and a fresh optimizer; the
        ``nn_parameters`` dict is shared with this instance.
        """
        clone = PolicyFiniteActionsContinuousStates(self.state_space(), self.action_space(), self.nn_parameters())
        # load_state_dict copies the values into the clone's own parameters,
        # so no intermediate copy of the state dict is required.
        clone._policy.load_state_dict(self._policy.state_dict())
        return clone

    def save(self) -> dict:
        """Return a dict describing the policy.

        ``"policy"`` is a shallow copy of the network-parameter dict with the
        network's ``state_dict()`` added under ``"nn_state"``; the two spaces
        are returned uncopied.
        NOTE(review): per the torch docs, ``state_dict()`` tensors reference
        the live parameters; deep-copy the result if a frozen snapshot is
        needed.
        """
        data = self.nn_parameters().copy()
        data["nn_state"] = self._policy.state_dict()
        return {"state_space": self.state_space(copy=False),
                "action_space": self.action_space(copy=False),
                "policy": data}


class PolicyContinuousActionsContinuousStates(Policy):
    """Network-backed policy for ``Box`` state and action spaces.

    Holds a ``ContinuousStatesContinuousActionsNet`` whose forward pass
    returns a ``(probs, samples)`` pair, plus an Adam optimizer over the
    network parameters.
    """

    def __init__(self, state_space: Box, action_space: Box, nn_parameters: dict) -> None:
        """Build the network and its optimizer.

        Args:
            state_space: state ``Box``; ``shape[0]`` is used as ``dim_state()``.
            action_space: action ``Box``; ``shape[0]`` is used as
                ``dim_action()``.
            nn_parameters: must provide the keys ``"learning_rate"``,
                ``"n_action_samples"``, ``"hidden_layers_size"`` and
                ``"activation_functions"``. The dict itself is kept (not
                copied).
        """
        self._state_space = state_space
        self._action_space = action_space
        self._dim_state = state_space.shape[0]
        self._dim_action = action_space.shape[0]

        # NN parameters
        self._lr = nn_parameters["learning_rate"]
        self._n_action_samples = nn_parameters["n_action_samples"]
        self._nn_parameters = nn_parameters

        # Only the first component of the action bounds is handed to the
        # network; dim_action() is stored but not passed on.
        bounds = self.action_space(copy=False)
        self._policy = ContinuousStatesContinuousActionsNet(
            self.dim_state(),
            self._n_action_samples,
            bounds.low[0],
            bounds.high[0],
            hidden_layers_size=self._nn_parameters['hidden_layers_size'],
            activation_functions=self._nn_parameters['activation_functions'])
        self._policy_opt = optim.Adam(self._policy.parameters(), lr=self._lr)

    def __call__(self, state):
        """Shorthand for ``self.policy(state)`` (detached form)."""
        return self.policy(state)

    def state_space(self, copy=True):
        """Return the state ``Box``, or a fresh ``Box`` with the same bounds.

        The copy is built with the original dtype; without it ``Box`` would
        fall back to its default dtype and silently change the bounds' type.
        """
        if not copy:
            return self._state_space
        return Box(low=self._state_space.low,
                   high=self._state_space.high,
                   dtype=self._state_space.dtype)

    def action_space(self, copy=True):
        """Return the action ``Box``, or a fresh ``Box`` with the same bounds.

        As for ``state_space``, the copy keeps the original dtype.
        """
        if not copy:
            return self._action_space
        return Box(low=self._action_space.low,
                   high=self._action_space.high,
                   dtype=self._action_space.dtype)

    def dim_state(self):
        """Return the state dimension (``state_space.shape[0]``)."""
        return self._dim_state

    def dim_action(self):
        """Return the action dimension (``action_space.shape[0]``)."""
        return self._dim_action

    def n_action_samples(self):
        """Return the ``"n_action_samples"`` setting given at construction."""
        return self._n_action_samples

    def nn_parameters(self) -> dict:
        """Return the network-parameter dict passed at construction (not a copy)."""
        return self._nn_parameters

    def policy(self, state, detach: bool = True):
        """Run the network on ``state`` and return ``(probs, samples)``.

        Args:
            state: when ``detach`` is true, converted with
                ``torch.FloatTensor(state)``; when false, passed to the
                network exactly as given.
            detach: when true, both outputs are returned as NumPy arrays
                detached from the autograd graph; when false, the network's
                outputs are returned as they are.
        """
        if not detach:
            # Call the module rather than .forward() so registered hooks run.
            probs, samples = self._policy(state)
            return probs, samples
        # Inference only: no need to record an autograd graph we would drop.
        with torch.no_grad():
            probs, samples = self._policy(torch.FloatTensor(state))
        return probs.detach().numpy(), samples.detach().numpy()

    def policy_opt(self):
        """Return the Adam optimizer attached to the network parameters."""
        return self._policy_opt

    def sample(self, state):
        """Draw one action for ``state``.

        The detached ``(probs, samples)`` pair is wrapped in a
        ``ProbabilityEmpiricalMeasure`` (``space=samples``,
        ``probability=probs``) and its ``sample()`` is returned.
        """
        probs, samples = self.policy(state)
        measure = ProbabilityEmpiricalMeasure(space=samples, probability=probs)
        return measure.sample()

    def copy(self) -> Policy:
        """Return a new policy with the same spaces, settings and weights.

        The clone gets its own network and a fresh optimizer; the
        ``nn_parameters`` dict is shared with this instance.
        """
        clone = PolicyContinuousActionsContinuousStates(self.state_space(), self.action_space(), self.nn_parameters())
        # load_state_dict copies the values into the clone's own parameters,
        # so deep-copying the state dict first is redundant work.
        clone._policy.load_state_dict(self._policy.state_dict())
        return clone

    def save(self) -> dict:
        """Return a dict describing the policy.

        ``"policy"`` is a shallow copy of the network-parameter dict with the
        network's ``state_dict()`` added under ``"nn_state"``; the two spaces
        are returned uncopied.
        NOTE(review): per the torch docs, ``state_dict()`` tensors reference
        the live parameters; deep-copy the result if a frozen snapshot is
        needed.
        """
        data = self.nn_parameters().copy()
        data["nn_state"] = self._policy.state_dict()
        return {"state_space": self.state_space(copy=False),
                "action_space": self.action_space(copy=False),
                "policy": data}
