"""Policies classes for environment control."""

import numpy as np

from .base_policy import DiscretePolicy


def mean_action_trajectory(policy, states):
    """Evaluate ``policy.mean`` on each state along the first axis of ``states``.

    Args:
        policy: object exposing ``dim_A`` and a ``mean(state)`` method.
        states: array-like with a ``shape`` attribute; row ``k`` is
            ``states[k]``.

    Returns:
        A float array of shape ``(states.shape[0], policy.dim_A)`` whose row
        ``k`` holds ``policy.mean(states[k])``.
    """
    n_steps = states.shape[0]
    trajectory = np.empty((n_steps, policy.dim_A))
    for step in range(n_steps):
        state = states[step]
        trajectory[step] = policy.mean(state)
    return trajectory


class NoisyContinuous(object):
    """Base class for continuous policies with independent per-dimension Gaussian noise.

    Subclasses must provide ``mean(s)``; it is used by ``__call__`` and by
    ``p`` (when no precomputed mean is given). Its result is broadcast
    against a noise array of shape ``(n_samples, dim_A)``.

    Attributes:
        dim_A: number of action dimensions.
        noise: per-dimension noise variances, shape ``(dim_A,)``.
        approx_noise: copy of ``noise`` with exact zeros replaced by 10e-50
            (i.e. 1e-49) so the precision and log-variance stay finite.
        precision: element-wise inverse of ``approx_noise``.
    """

    def __init__(self, noise=None, dim_A=None):
        """
        Args:
            noise: per-dimension noise variances (1-D array-like or scalar).
                When omitted, a noise-free policy of size ``dim_A`` is built.
            dim_A: action dimensionality; only used when ``noise`` is None.
        """
        if noise is None:
            self.dim_A = dim_A
            self.noise = np.zeros(self.dim_A)
        else:
            # Coerce to a float array: a plain list cannot be indexed with the
            # boolean mask below, and an integer array would truncate the
            # 10e-50 floor back to 0. A float64 ndarray is kept as-is (no copy).
            self.noise = np.atleast_1d(np.asarray(noise, dtype=float))
            self.dim_A = len(self.noise)
        self.approx_noise = self.noise.copy()
        # Floor exact zeros so that 1 / variance and log(variance) are finite.
        self.approx_noise[self.approx_noise == 0] = 10e-50
        self.precision = 1. / self.approx_noise

    def __call__(self, s, n_samples=1):
        """Sample action(s) for state ``s``: ``mean(s)`` plus Gaussian noise.

        ``noise`` holds variances, so the standard normal draws are scaled by
        their square root. With ``n_samples == 1`` the result is flattened to
        1-D; otherwise it has one row per sample.
        """
        m = self.mean(s)
        noise = np.sqrt(self.noise[None, :]) * np.random.randn(n_samples, self.dim_A)
        if n_samples == 1:
            return (m + noise).flatten()
        return m + noise

    def p(self, s, a, mean=None):
        """Density of action ``a`` under N(mean, diag(approx_noise)).

        Args:
            s: state; only used to compute ``self.mean(s)`` when ``mean`` is None.
            a: action at which the density is evaluated.
            mean: optional precomputed mean for ``s`` (skips the ``mean`` call).

        Returns:
            The Gaussian density value.
        """
        if mean is None:
            mean = self.mean(s)
        m_a = mean - a
        mahalanobis_sq = (m_a * m_a * self.precision).sum()
        # The normaliser is (2*pi)^(d/2) * sqrt(det(Sigma)). For a diagonal
        # covariance, det(Sigma) is the *product* of the variances (the old
        # code used their sum, which is only correct for d == 1). Summing logs
        # avoids under/overflow when several variances are floored to ~1e-49.
        log_norm = 0.5 * (self.dim_A * np.log(2 * np.pi)
                          + np.log(self.approx_noise).sum())
        return np.exp(-.5 * mahalanobis_sq - log_norm)


class LinearContinuous(NoisyContinuous):
    """Gaussian policy whose mean action is linear in the state: ``theta @ s``.

    Args:
        theta: parameter matrix of shape ``(dim_A, dim_S)``; a 1-D vector is
            promoted to a single row by ``np.atleast_2d``. When omitted, a
            zero matrix of shape ``(dim_A, dim_S)`` is used.
        noise: per-dimension noise variances, forwarded to ``NoisyContinuous``.
        dim_S: state dimensionality; only used when ``theta`` is None.
        dim_A: action dimensionality. When ``theta`` is given, the final value
            is taken from ``theta.shape[0]``. Otherwise it falls back to the
            length of ``noise`` if not supplied.
    """

    def __init__(self, theta=None, noise=None, dim_S=None, dim_A=None):
        if theta is not None:
            theta = np.atleast_2d(theta)
            if dim_A is None:
                # Infer dim_A from theta so the base class can build a
                # zero-noise vector; otherwise it would call np.zeros(None).
                dim_A = theta.shape[0]
        NoisyContinuous.__init__(self, noise=noise, dim_A=dim_A)
        if theta is None:
            self.dim_S = dim_S
            # Only override when given explicitly; otherwise keep the
            # dimensionality the base class inferred from ``noise``.
            if dim_A is not None:
                self.dim_A = dim_A
            self.theta = np.zeros((self.dim_A, self.dim_S))
        else:
            self.theta = theta
            self.dim_A, self.dim_S = self.theta.shape

    def mean(self, s):
        """Return the noise-free action ``theta @ s`` as a flat 1-D array."""
        return np.array(np.dot(self.theta, s)).flatten()

    def __repr__(self):
        return f"LinearContinuous({repr(self.theta)}, {repr(self.noise)})"


# Module-level alias: ``Discrete`` is the same class as ``DiscretePolicy``
# imported from ``.base_policy``.
Discrete = DiscretePolicy
