import jax
import jax.numpy as np
from jax import jit
# import numpy as np
from functools import partial
import json
import gym
from collections import deque
# jax.config.update('jax_platform_name', 'cpu')


def rescale_action(action_fn):
    """Decorate an action method so its output is mapped from [-1, 1] to [0, 1].

    Args:
        action_fn: Callable taking ``(cls, observations)`` and returning an
            action that supports ``+`` and ``/`` with scalars (array or number).

    Returns:
        A callable with the same signature that returns ``(action + 1) / 2``.
    """
    from functools import wraps

    @wraps(action_fn)  # keep the wrapped method's name and docstring
    def new_fn(cls, observations):
        action = action_fn(cls, observations)
        # ``+`` already yields a new object, so no explicit copy is required
        # and plain Python numbers are supported as well.
        return (action + 1) / 2.
    return new_fn


class DEP:
    """Stateful, batched "DEP" controller.

    (The acronym presumably stands for Differential Extrinsic Plasticity --
    TODO confirm against the project documentation.)

    Per environment the controller keeps a matrix ``M`` (initialised to the
    negative identity), a controller matrix ``C`` that is re-estimated from a
    bounded history buffer of smoothed observations, and a bias vector ``Cb``.
    An action is ``tanh(kappa * q_norm * (C @ obs) + Cb)`` scaled by
    ``act_scale``.

    Hyper-parameters are loaded from the ``'DEP'`` entry of a JSON file.
    ``initialize`` must be called before ``step``.
    """

    def __init__(self, params_path='./param_files/default_params.json'):
        """Load the ``'DEP'`` section of the JSON parameter file at ``params_path``."""
        with open(params_path, 'r') as f:
            self.params = json.load(f)['DEP']

    def initialize(self, observation_space, action_space, seed=None):
        """Configure the controller for the given spaces.

        Only the *shape* of ``action_space`` is used; the controller itself
        always works on a ``[-1, 1]`` box of that shape. ``seed`` is accepted
        but not used.
        """
        action_space = gym.spaces.Box(low=-1, high=1, shape=(action_space.shape))
        self.obs_spec = observation_space.shape
        self.has_init = False
        self._parse_arch_params(**self.params)
        # Sensors and motors are both sized from the action space.
        self.num_sensors = action_space.shape[0]
        self.num_motors = action_space.shape[0]

        self.act_scale = self.act_high = action_space.high

        self.buffer = deque(maxlen=self.buffer_size)
        self.action_space = action_space

    def reinit(self, obs_spec):
        """(Re)allocate all per-environment matrices for observations of shape ``obs_spec``.

        A shape with more than one dimension is treated as batched and its
        first entry is used as the number of environments; otherwise a single
        environment is assumed.
        """
        if len(obs_spec) > 1:
            n_env = obs_spec[0]
        else:
            print('single env')
            n_env = 1
        self.M = np.broadcast_to(
            -np.eye(self.num_motors, self.num_sensors),
            (n_env, self.num_motors, self.num_sensors),
        )
        self.Mb = np.zeros((n_env, self.num_sensors))
        # Unnormalized controller matrix
        self.C_unnorm = np.zeros((n_env, self.num_motors, self.num_sensors))
        # Normalized controller matrix
        self.C = np.zeros((n_env, self.num_motors, self.num_sensors))
        # Controller biases
        self.Cb = np.zeros((n_env, self.num_motors))
        self.q_norm = 0
        self.obs_smoothed = np.zeros((n_env, self.num_sensors))

    def step(self, observations, steps=None):
        """Return the action(s) for ``observations``.

        State is (re)allocated when the observation shape changes, and also
        when it has never been allocated -- otherwise a first observation whose
        shape happens to equal ``obs_spec`` would reach
        ``beginning_of_rollout`` without ``obs_smoothed`` existing.
        A 1-D observation is promoted to a batch of one. ``steps`` is unused.

        Returns:
            Array of shape ``(n_env, num_motors)``.
        """
        if observations.shape != self.obs_spec or not hasattr(self, 'M'):
            self.reinit(observations.shape)
            self.obs_spec = observations.shape
        if not self.has_init:
            self.beginning_of_rollout()
            self.has_init = True
        if len(observations.shape) == 1:
            observations = observations[np.newaxis, :]
        return self.jax_get_action(observations).copy()

    def test_step(self, observations, steps=None):
        """Identical to :meth:`step`; the controller has no separate test mode."""
        return self.step(observations, steps=steps)

    def update(*args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

    def test_update(*args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

    def save(*args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

    def load(*args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

    def _parse_arch_params(
            self,
            *,
            kappa,
            tau,
            bias_rate,
            time_dist,
            normalization,
            s4avg=2,
            buffer_size=150,
            sensor_delay=1,
            regularization=4,
            with_learning=True,
            q_norm_selector="l2",
            intervention_length=6,
            intervention_proba=0.001,
            test_episode_every=5,
            force_scaling=10
    ):
        """Store the hyper-parameters as attributes of the same name.

        As used in this class:
            kappa: gain applied to the normalised controller output.
            tau: upper bound of the history window used to estimate ``C``.
            bias_rate: bias learning rate; a negative value zeroes the bias.
            time_dist: offset (in steps) between the two difference signals.
            normalization: ``"independent"``, ``"global"`` or ``"none"``.
            s4avg: observation smoothing divisor (values <= 1 disable it).
            buffer_size: maximum length of the history buffer.
            regularization: exponent ``r`` of the ``10**-r`` regulariser.
            with_learning: whether ``C`` is re-estimated during stepping.
            q_norm_selector: ``"l2"``, ``"max"`` or ``"none"``.

        ``sensor_delay``, ``intervention_length``, ``intervention_proba``,
        ``test_episode_every`` and ``force_scaling`` are stored but not read
        anywhere in this class.
        """
        self.kappa = kappa
        self.tau = tau
        self.bias_rate = bias_rate
        self.buffer_size = buffer_size
        self.time_dist = time_dist
        self.s4avg = s4avg
        self.normalization = normalization
        self.sensor_delay = sensor_delay
        self.regularization = regularization
        self.with_learning = with_learning
        self.q_norm_selector = q_norm_selector
        self.intervention_length = intervention_length
        self.intervention_proba = intervention_proba
        self.test_episode_every = test_episode_every
        self.force_scaling = force_scaling

    def set_params(self, param_dict):
        """Set every ``key: value`` of ``param_dict`` as an attribute on this object."""
        for k, v in param_dict.items():
            setattr(self, k, v)

    def beginning_of_rollout(self):
        """Clear the history buffer, reset the step counter and zero the smoothed observation."""
        self.buffer.clear()
        self.t = 0
        # Out-of-place on purpose: ``obs_smoothed`` may alias an observation
        # array owned by the caller, which an in-place ``*=`` would overwrite.
        self.obs_smoothed = self.obs_smoothed * 0

    def end_of_rollout(self, total_time, total_return, mode):
        """No-op hook."""
        pass

    def _q_norm(self, q):
        """Return the per-environment scaling for ``q`` using the instance settings."""
        return self._jax_q_norm(q, self.regularization, self.q_norm_selector)

    def _record_observation(self, obs):
        """Fold ``obs`` into the smoothed observation and append it to the buffer.

        The buffer entry is ``[smoothed_obs_copy, None]``; the ``None`` slot is
        filled with the action once it has been computed.
        """
        if self.s4avg > 1 and self.t > 0:
            # Out-of-place update so a caller-owned array is never mutated.
            self.obs_smoothed = self.obs_smoothed + (obs - self.obs_smoothed) / self.s4avg
        else:
            self.obs_smoothed = obs
        self.buffer.append([self.obs_smoothed.copy(), None])

    # @rescale_action
    def get_action(self, obs):
        """Compute the action for batched ``obs`` using the stateful helpers.

        Updates the smoothed observation, the buffer, optionally the
        controller, and the step counter ``t``.
        """
        self._record_observation(obs)

        if self.with_learning and len(self.buffer) > (2 + self.time_dist):
            self._learn_controller()
        q = np.einsum('ijk,ik->ij', self.C, self.obs_smoothed)

        q = np.einsum('ij,i->ij', q, self._q_norm(q))
        y = np.maximum(-1, np.minimum(1, np.tanh(q * self.kappa + self.Cb)))
        y = np.einsum('ij,j->ij', y, self.act_scale)

        self.buffer[-1][1] = y.copy()

        self.t += 1
        return y

    # @rescale_action
    def jax_get_action(self, obs):
        """Compute the action for batched ``obs`` using the functional helpers.

        Performs the same state updates as :meth:`get_action`.
        """
        self._record_observation(obs)

        if self.with_learning and len(self.buffer) > (2 + self.time_dist):
            self.C, self.Cb, self.C_unnorm = self._jax_learn_controller(
                self.C_unnorm, self.M, self.regularization, self.kappa,
                self.normalization, self.bias_rate, self.Cb, self.t,
                self.time_dist, self.tau, self.buffer, self.num_sensors,
            )
        y = self._jax_get_action(
            self.C, self.obs_smoothed, self.kappa, self.act_scale, self.Cb,
            self.regularization, self.q_norm_selector,
        )
        self.buffer[-1][1] = y.copy()
        self.t += 1
        return y

    def _jax_q_norm(self, q, regularization, q_norm_selector):
        """Return the scaling applied to ``q`` along its leading axes.

        Args:
            q: array whose last axis is the motor axis.
            regularization: exponent ``r``; ``10**-r`` is added to the norm.
            q_norm_selector: ``"l2"``, ``"max"`` or ``"none"``.

        Returns:
            Array of shape ``q.shape[:-1]`` (all ones for ``"none"``, so the
            result is always usable in the callers' ``'ij,i->ij'`` einsum).

        Raises:
            NotImplementedError: for an unknown selector.
        """
        reg = 10.0 ** (-regularization)
        if q_norm_selector == "l2":
            return 1.0 / (np.linalg.norm(q, axis=-1) + reg)
        if q_norm_selector == "max":
            return 1.0 / (np.max(np.abs(q), axis=-1) + reg)
        if q_norm_selector == "none":
            return np.ones(q.shape[:-1])
        raise NotImplementedError(f"q normalization {q_norm_selector} not implemented.")

    # NOTE: the three ``_jax_*`` methods below are intentionally *not* wrapped
    # in ``jax.jit``. They branch on Python strings, use Python ``range`` on
    # their integer arguments and index a ``deque`` of lists, so they would
    # need those arguments declared static (``static_argnums`` /
    # ``static_argnames``) and the buffer converted to arrays first.

    def _jax_get_action(self, C, obs_smoothed, kappa, act_scale, Cb, regularization, q_norm_selector):
        """Pure computation of the action from controller state.

        Returns ``act_scale * tanh(kappa * q_norm * (C @ obs_smoothed) + Cb)``,
        with the ``tanh`` output additionally clamped to ``[-1, 1]``.
        """
        q = np.einsum('ijk, ik->ij', C, obs_smoothed)

        q = np.einsum('ij, i->ij', q, self._jax_q_norm(q, regularization, q_norm_selector))
        y = np.maximum(-1, np.minimum(1, np.tanh(q * kappa + Cb)))
        y = np.einsum('ij, j->ij', y, act_scale)
        return y

    def _jax_learn_controller(self, C_unnorm, M, regularization, kappa, normalization, bias_rate, Cb, t, time_dist, tau, buffer, num_sensors):
        """Re-estimate the controller matrix and update the bias.

        Returns:
            Tuple ``(C, Cb, C_unnorm)`` of the normalised controller matrix,
            the updated bias and the freshly computed unnormalised matrix.

        Raises:
            NotImplementedError: for an unknown ``normalization``.
        """
        C_unnorm = self._jax_calc_C(C_unnorm, t, time_dist, tau, buffer, num_sensors, M)
        # linear response in motor space (action -> action)
        R = np.einsum('ijk, imk->ijm', C_unnorm, M)
        reg = 10.0 ** (-regularization)
        if normalization == "independent":
            factor = kappa / (np.linalg.norm(R, axis=-1) + reg)
            # NOTE(review): ``factor`` is indexed by motor row, yet the einsum
            # below applies it along the sensor axis (``ik``). This only runs
            # because num_motors == num_sensors; an unbatched
            # ``C_unnorm * factor[:, None]`` would correspond to ``'ijk,ij->ijk'``.
            # Left unchanged to preserve existing numerical behaviour -- confirm.
            C = np.einsum('ijk,ik->ijk', C_unnorm, factor)
        elif normalization == "none":
            C = C_unnorm
        elif normalization == "global":
            norm = np.linalg.norm(R)
            C = C_unnorm * kappa / (norm + reg)
        else:
            raise NotImplementedError(f"Controller matrix normalization {normalization} not implemented.")

        if bias_rate >= 0:
            # buffer[-1] holds the current step (action slot still None), so
            # the most recent stored action is at buffer[-2].
            yy = buffer[-2][1]
            Cb = Cb - (np.clip(yy * bias_rate, -0.05, 0.05) + Cb * 0.001)
        else:
            Cb = Cb * 0
        return C, Cb, C_unnorm

    def _jax_calc_C(self, C_unnorm, t, time_dist, tau, buffer, num_sensors, M):
        """Accumulate the unnormalised controller matrix from the history buffer.

        For every lag ``s`` in ``range(2, min(t - time_dist, tau))`` the outer
        product of ``M @ (x[-s] - x[-s-1])`` with the (optionally
        ``time_dist``-delayed) observation difference is added up. The incoming
        ``C_unnorm`` is only used for its shape and dtype.
        """
        total = np.zeros_like(C_unnorm)
        for s in range(2, min(t - time_dist, tau)):
            x = buffer[-s][0][:, : num_sensors]
            xx = buffer[-s - 1][0][:, : num_sensors]
            xx_t = x if time_dist == 0 else buffer[-s - time_dist][0][:, : num_sensors]
            xxx_t = buffer[-s - 1 - time_dist][0][:, : num_sensors]

            chi = x - xx
            v = xx_t - xxx_t
            mu = np.einsum('ijk, ik->ij', M, chi)

            total = total + np.einsum('ij, ik->ijk', mu, v)
        return total

    # noinspection PyPep8Naming
    def _calc_C(self):
        """Return the unnormalised controller matrix computed from instance state."""
        return self._jax_calc_C(
            self.C_unnorm, self.t, self.time_dist, self.tau,
            self.buffer, self.num_sensors, self.M,
        )

    # noinspection PyPep8Naming
    def _learn_controller(self):
        """Re-estimate ``C``, ``Cb`` and ``C_unnorm`` in place on the instance."""
        self.C, self.Cb, self.C_unnorm = self._jax_learn_controller(
            self.C_unnorm, self.M, self.regularization, self.kappa,
            self.normalization, self.bias_rate, self.Cb, self.t,
            self.time_dist, self.tau, self.buffer, self.num_sensors,
        )
