from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Union

import numpy as np
import numpy.linalg as la
from scipy.optimize import linprog


class ProbabilityMeasure(ABC):
    """Abstract interface for a probability measure over a set of states.

    Inherits from ``ABC`` so that ``@abstractmethod`` is actually enforced:
    without an ``ABCMeta`` metaclass the decorators are inert, and this
    interface (or a subclass missing some methods) could be instantiated.
    """

    @abstractmethod
    def normalize(self):
        """Rescale the weights so that they sum to one."""

    @abstractmethod
    def sample(self, n_samples):
        """Draw ``n_samples`` states according to the measure."""

    @abstractmethod
    def mean(self):
        """Return the expected state under the measure."""

    @abstractmethod
    def variance(self):
        """Return the variance of the state under the measure."""

    @abstractmethod
    def expected_value_function(self, fun):
        """Return the expectation of ``fun`` under the measure."""

    @abstractmethod
    def space(self, copy):
        """Return the states of the measure (a copy when ``copy`` is true)."""

    @abstractmethod
    def probability(self, copy):
        """Return the per-state weights (a copy when ``copy`` is true)."""

    @abstractmethod
    def save(self):
        """Return the data needed to persist the measure."""


class ProbabilityEmpiricalMeasure(ProbabilityMeasure):
    """Probability measure on a finite set of states.

    The measure is two aligned arrays: ``space`` holds the states along
    axis 0 (scalars for a 1-D space, vectors/rows for a higher-dimensional
    one) and ``probability`` holds one weight per state. The weights are not
    forced to sum to one; see :meth:`normalize` and :meth:`is_normalized`.
    """

    def __init__(
        self,
        space: Union[np.ndarray, int, np.integer],
        probability: Optional[np.ndarray],
        state_idx_correspond: bool = True,
        tol: int = 4,
    ):
        """Create the measure.

        Args:
            space: array of states (axis 0 indexes the states), or an integer
                ``n`` meaning the states ``0, 1, ..., n - 1``.
            probability: 1-D array of weights aligned with ``space``, or
                ``None`` for the uniform measure. Stored without copying
                unless it has to be reordered (see ``state_idx_correspond``).
            state_idx_correspond: when True, states passed to the lookup
                methods are used directly as indices into ``probability``.
                When False, they are matched by value against ``space``
                (see :meth:`get_index`) and a 1-D space is sorted, with the
                weights permuted alongside.
            tol: value lookups accept a state whose Euclidean distance to a
                stored state is below ``10 ** -tol``.
        """
        if isinstance(space, (int, np.integer)):
            space = np.arange(space)
        self._space = np.asarray(space)
        self._n = len(self._space)

        if probability is None:
            self.set_uniform()
        else:
            probability = np.asarray(probability)
            assert probability.ndim == 1, 'Check input dimensions'
            self._probability = probability

        self._tol = tol
        self._state_idx_correspond = state_idx_correspond
        if not self._state_idx_correspond and self._space.ndim == 1:
            # Permute states and weights together; sorting the space alone
            # would detach every weight from its state. Multi-dimensional
            # spaces are left as given: an in-place ``sort`` there reorders
            # the coordinates inside each state.
            order = np.argsort(self._space, kind="stable")
            self._space = self._space[order]
            self._probability = self._probability[order]

    def space(self, copy=True):
        """Return the states (a copy unless ``copy`` is False)."""
        if copy:
            return self._space.copy()
        return self._space

    def n(self):
        """Return the number of states."""
        return self._n

    def probability(self, copy=True):
        """Return the per-state weights (a copy unless ``copy`` is False)."""
        if copy:
            return self._probability.copy()
        return self._probability

    def get_index(self, states: Union[int, np.ndarray, List[np.ndarray]]) -> np.ndarray:
        """Return the positions in ``space`` of one or more states.

        If ``state_idx_correspond`` is set, the states already are indices and
        are returned unchanged as an array. Otherwise every state is matched
        by value: the first stored state whose Euclidean distance to it is
        below ``10 ** -tol`` is taken.

        Args:
            states: a single state, a list of states, or an array of states.
                For a 1-D space every element is one state; when states are
                arrays of shape ``space.shape[1:]``, the input is split into
                consecutive states of that shape.

        Returns:
            On the value-lookup path, a 1-D integer array with one index per
            queried state.

        Raises:
            ValueError: if a state matches no stored state.
        """
        if self._state_idx_correspond:
            # Fast path: states are their own indices.
            return np.array(states)

        space = self.space(copy=False)
        state_size = int(np.prod(space.shape[1:]))  # 1 for a 1-D space
        # One row per stored / queried state, so scalar and vector-valued
        # spaces share one distance computation (``la.norm`` rejects the
        # 0-d differences a per-scalar comparison would produce).
        stored = space.reshape(self.n(), state_size)
        queries = np.asarray(states).reshape(-1, state_size)
        threshold = 10 ** (-self._tol)

        idxs = []
        for query in queries:
            hits = np.flatnonzero(la.norm(stored - query, axis=1) < threshold)
            if hits.size == 0:
                raise ValueError("State not valid")
            idxs.append(int(hits[0]))
        return np.array(idxs, dtype=int)

    @staticmethod
    def convert_to_array(x: Union[np.ndarray, list, int]) -> Union[np.ndarray, int]:
        """Return Python ints unchanged and everything else as ``np.array(x)``."""
        if type(x) is int:
            return x
        return np.array(x)

    def total_probability_mass(self) -> float:
        """Return the sum of all weights."""
        return float(np.sum(self.probability(copy=False)))

    def is_normalized(self) -> bool:
        """Return whether the weights sum to one within ``1e-6``."""
        return np.abs(self.total_probability_mass() - 1.0) < 1e-6

    def normalize(self) -> None:
        """Divide the weights in place by their sum.

        Raises:
            ValueError: if the total mass is zero (the original code built
                this exception but never raised it).
        """
        sum_prob = self.total_probability_mass()
        if sum_prob == 0:
            raise ValueError("Cannot normalize a measure with zero total mass.")
        self._probability /= sum_prob

    def set_uniform(self) -> None:
        """Assign weight ``1 / n`` to every state."""
        self._probability = np.ones(self.n()) / self.n()

    def sample(self, n_samples=1) -> Union[np.ndarray, int]:
        """Draw ``n_samples`` states with replacement.

        Returns a single state when ``n_samples == 1``, otherwise an array of
        states stacked along axis 0.

        Raises:
            ValueError: if the weights do not sum to one.
        """
        if not self.is_normalized():
            raise ValueError("Measure must be normalized before sampling.")
        # Sample indices rather than states: ``np.random.choice`` requires a
        # 1-D population, so this also works for vector-valued spaces. For a
        # 1-D space it consumes the RNG exactly like choosing from the space.
        idx = np.random.choice(self.n(), size=n_samples, p=self.probability(copy=False))
        samples = self.space(copy=False)[idx]
        if n_samples == 1:
            return samples[0]
        return samples

    def assign_probability(self, states: np.ndarray, probability: np.ndarray) -> None:
        """Overwrite the weights of ``states`` with ``probability``."""
        states = self.convert_to_array(states)
        probability = self.convert_to_array(probability)
        self.probability(copy=False)[self.get_index(states)] = probability

    def get_probability(self, states: Union[int, list, np.ndarray]) -> np.ndarray:
        """Return the weights of ``states``."""
        states = self.convert_to_array(states)
        return self.probability(copy=False)[self.get_index(states)]

    def add_probability(self, states: Union[int, list, np.ndarray],
                        probability: Union[float, list, np.ndarray]) -> None:
        """Add ``probability`` to the current weights of ``states``."""
        states = self.convert_to_array(states)
        probability = self.convert_to_array(probability)
        self.probability(copy=False)[self.get_index(states)] += probability

    def mean(self) -> Union[float, np.ndarray]:
        """Return the probability-weighted mean state.

        A float for a 1-D space; for vector-valued states, an array of shape
        ``space.shape[1:]``.
        """
        space = self.space(copy=False)
        prob = self.probability(copy=False)
        if space.ndim == 1:
            return float(np.dot(space, prob))
        return np.tensordot(prob, space, axes=1)

    def variance(self) -> Union[float, np.ndarray]:
        """Return ``E[x**2] - E[x]**2`` under the measure.

        For vector-valued states this is computed per coordinate (the
        diagonal of the covariance), with shape ``space.shape[1:]``.
        """
        space = self.space(copy=False)
        prob = self.probability(copy=False)
        if space.ndim == 1:
            return np.dot(np.power(space, 2), prob) - np.power(self.mean(), 2)
        return np.tensordot(prob, np.power(space, 2), axes=1) - np.power(self.mean(), 2)

    def expected_value_function(self, fun: Union[list, np.ndarray]) -> float:
        """Return ``sum_i fun[i] * probability[i]``.

        ``fun`` holds the function's value at each state, in ``space`` order.
        """
        fun = np.asarray(fun)
        assert len(fun.shape) == 1 and len(fun) == self.n()
        return float(np.dot(fun, self.probability(copy=False)))

    def wasserstein_distance(self, other, distance: Callable, p: float) -> float:
        """Return the p-Wasserstein distance to ``other``.

        Solves the optimal-transport linear program whose cost of moving mass
        from ``x_i`` to ``y_j`` is ``distance(x_i, y_j) ** p``; the result is
        the optimal cost raised to ``1 / p``.

        Raises:
            ValueError: if the linear program has no solution (for instance
                when the two measures have different total mass).
        """
        n, m = self.n(), other.n()
        xs = self.space(copy=False)
        ys = other.space(copy=False)
        cost = np.zeros((n, m))
        for i in range(n):
            for j in range(m):
                cost[i, j] = distance(xs[i], ys[j])
        # Decision variable k = i * m + j is the mass moved from x_i to y_j.
        c = np.reshape(np.power(cost, p), [n * m])
        # Row sums match self's weights; column sums match other's weights.
        # Built with kron instead of dense (n, n, m)/(m, n, m) tensors.
        a_rows = np.kron(np.eye(n), np.ones((1, m)))
        a_cols = np.kron(np.ones((1, n)), np.eye(m))
        a_eq = np.concatenate((a_rows, a_cols))
        b_eq = np.concatenate((self.probability(copy=False), other.probability(copy=False)))
        res = linprog(c, A_eq=a_eq, b_eq=b_eq)  # variables are non-negative by default
        if not res.success:
            raise ValueError(f"Optimal transport problem could not be solved: {res.message}")
        return np.power(res.fun, 1.0 / p)

    def kl_divergence(self, other):
        """Return ``sum_i p_i * log(p_i / (q_i + 1e-6))`` over states with ``p_i != 0``.

        ``p`` are this measure's weights and ``q`` those of ``other``. The
        ``1e-6`` keeps the ratio finite where ``q_i == 0``.
        """
        p = self.probability(copy=False)
        q = other.probability(copy=False)
        support = p != 0
        # Evaluate the log only on the support: np.where evaluated both
        # branches and emitted divide/invalid warnings for p == 0.
        terms = np.zeros(p.shape, dtype=float)
        terms[support] = p[support] * np.log(p[support] / (q[support] + 1e-6))
        return np.sum(terms)

    def sort_space(self):
        """Sort a 1-D space in place, permuting the weights alongside.

        Raises:
            ValueError: if the space is not 1-D (previously built but never
                raised).
        """
        if self.space(copy=False).ndim != 1:
            raise ValueError("To sort the array should be 1D.")
        idx = np.argsort(self.space(copy=False))
        self._space = self._space[idx]
        self._probability = self._probability[idx]

    def _with_probability(self, probability: np.ndarray) -> "ProbabilityEmpiricalMeasure":
        """Build a measure on a copy of this space with the given weights,
        keeping this measure's lookup settings (``state_idx_correspond``, ``tol``)."""
        return ProbabilityEmpiricalMeasure(
            self.space(copy=True),
            probability,
            state_idx_correspond=self._state_idx_correspond,
            tol=self._tol,
        )

    def copy(self) -> ProbabilityMeasure:
        """Return an independent copy of the measure."""
        return self._with_probability(self.probability(copy=True))

    def save(self) -> dict:
        """Return ``{"space": ..., "probability": ...}`` (arrays, not copied)."""
        return {"space": self.space(copy=False),
                "probability": self.probability(copy=False)}

    def __mul__(self, alpha: float) -> ProbabilityMeasure:
        """Return a new measure with every weight multiplied by ``alpha``."""
        return self._with_probability(alpha * self.probability(copy=True))

    def __rmul__(self, alpha: float) -> ProbabilityMeasure:
        """Return ``alpha * self`` (same as ``self * alpha``)."""
        return self.__mul__(alpha)

    def __add__(self, other: ProbabilityMeasure) -> ProbabilityMeasure:
        """Return a new measure whose weights are the sums of both measures' weights.

        Raises:
            ValueError: if the two measures are not defined on the same space
                (same shape and same states, in the same order).
        """
        if not np.array_equal(self.space(copy=False), other.space(copy=False)):
            raise ValueError("Cannot add measures defined on different spaces.")
        return self._with_probability(self.probability(copy=True) + other.probability(copy=True))