import json
import os
from collections import deque
from functools import partial, wraps

import gym
#import jax
#import jax.numpy as np
import numpy as np
#from jax import jit

# jax.config.update('jax_platform_name', 'cpu')


def rescale_action(action_fn):
    """Decorate an action method so its output is mapped affinely by (a + 1) / 2.

    This sends the interval [-1, 1] onto [0, 1]. The wrapped method is called
    as ``action_fn(cls, observations)``. Its result ``a`` is returned as
    ``(a + 1) / 2.0``; ``a`` itself is left untouched.
    """

    @wraps(action_fn)
    def new_fn(cls, observations):
        action = action_fn(cls, observations)
        # ``+`` allocates a fresh array, so the wrapped method's result is never
        # mutated; no explicit copy is needed.
        return (action + 1) / 2.0

    return new_fn


class DEP:
    """DEP controller (presumably Differential Extrinsic Plasticity).

    The controller keeps a short history of (smoothed observation, action)
    pairs. At every step it re-estimates a linear controller matrix ``C`` from
    temporal differences of the smoothed observations. Actions are computed as
    ``act_scale * tanh(kappa * q_norm(C @ obs) * (C @ obs) + Cb)``.

    All per-environment state carries a leading ``n_env`` axis, so several
    environments can be driven in lock-step.

    The ``_jax_*`` methods take every input as an explicit argument. Despite
    their name they currently run on NumPy (the JAX imports are disabled), and
    the stateful methods delegate to them so there is a single implementation
    of the maths.
    """

    def __init__(self, params_path="default_path"):
        """Load the ``"DEP"`` hyper-parameter section from a JSON file.

        Args:
            params_path: Path to a JSON file with a top-level ``"DEP"`` key.
                The sentinel ``"default_path"`` resolves to
                ``param_files/default_agents.json`` next to this module.
        """
        if params_path == "default_path":
            dirname = os.path.dirname(__file__)
            params_path = os.path.join(dirname, "param_files/default_agents.json")

        with open(params_path, "r") as f:
            self.params = json.load(f)["DEP"]

    def initialize(self, observation_space, action_space, seed=None):
        """Configure the controller for the given spaces and allocate its state.

        Only the action-space *shape* is used. Actions are produced in
        ``[-1, 1]`` and scaled by ``act_scale``, the upper bound of that box.
        ``seed`` is accepted for interface compatibility and is not used.
        """
        action_space = gym.spaces.Box(low=-1, high=1, shape=(action_space.shape))
        self.obs_spec = observation_space.shape
        self.has_init = False
        self._parse_arch_params(**self.params)
        # One sensor per motor: the controller matrix is square.
        self.num_sensors = action_space.shape[0]
        self.num_motors = action_space.shape[0]

        self.act_scale = self.act_high = action_space.high

        self.action_space = action_space
        # Allocate controller state (and the history buffer) right away, so
        # single-environment callers can step() without calling reset() first.
        self.reset()

    def _allocate_state(self, n_env):
        """Allocate zeroed controller state for ``n_env`` parallel environments.

        Also clears the history buffer.
        """
        # M maps observation changes into motor space (see _jax_calc_C); it is
        # fixed to -identity and broadcast (read-only view) over environments.
        self.M = np.broadcast_to(
            -np.eye(self.num_motors, self.num_sensors),
            (n_env, self.num_motors, self.num_sensors),
        )
        self.Mb = np.zeros((n_env, self.num_sensors))
        # Unnormalized controller matrix
        self.C_unnorm = np.zeros((n_env, self.num_motors, self.num_sensors))
        # Normalized controller matrix
        self.C = np.zeros((n_env, self.num_motors, self.num_sensors))
        # Controller biases
        self.Cb = np.zeros((n_env, self.num_motors))
        self.q_norm = 0
        self.obs_smoothed = np.zeros((n_env, self.num_sensors))
        self.buffer = deque(maxlen=self.buffer_size)

    def reset(self):
        """Re-allocate all controller state and clear the history buffer.

        The number of parallel environments is taken from ``self.n_env`` if it
        has been set (e.g. through ``set_params``), otherwise 1.
        """
        self._allocate_state(getattr(self, "n_env", 1))

    def reinit(self, obs_spec):
        """Re-allocate state for a new observation batch shape.

        An ``obs_spec`` with more than one dimension is read as
        ``(n_env, ...)``. A 1-D one means a single environment.
        """
        if len(obs_spec) > 1:
            n_env = obs_spec[0]
        else:
            print("single env")
            n_env = 1
        # The history is cleared as well: entries recorded under the previous
        # batch shape cannot be combined with the new observations.
        self._allocate_state(n_env)

    def step(self, observations, steps=None):
        """Return the action(s) for ``observations``.

        Args:
            observations: Array of shape ``(obs_dim,)`` for one environment, or
                ``(n_env, obs_dim)`` for a batch. Any change of shape relative
                to the previous call triggers ``reinit``.
            steps: Unused; kept for interface compatibility.

        Returns:
            Array of shape ``(n_env, num_motors)``; ``n_env`` is 1 for 1-D input.
        """
        if observations.shape != self.obs_spec:
            self.reinit(observations.shape)
            self.obs_spec = observations.shape
        if not self.has_init:
            self.beginning_of_rollout()
            self.has_init = True
        if len(observations.shape) == 1:
            observations = observations[np.newaxis, :]
        return self.get_action(observations).copy()

    def test_step(self, observations, steps=None):
        """Evaluation step; identical to ``step``."""
        return self.step(observations, steps=steps)

    def update(*args, **kwargs):
        """No-op hook kept for agent-interface compatibility."""

    def test_update(*args, **kwargs):
        """No-op hook kept for agent-interface compatibility."""

    def save(*args, **kwargs):
        """No-op hook kept for agent-interface compatibility."""

    def load(*args, **kwargs):
        """No-op hook kept for agent-interface compatibility."""

    def _parse_arch_params(
        self,
        *,
        kappa,
        tau,
        bias_rate,
        time_dist,
        normalization,
        s4avg=2,
        buffer_size=150,
        sensor_delay=1,
        regularization=4,
        with_learning=True,
        q_norm_selector="l2",
        intervention_length=6,
        intervention_proba=0.001,
        test_episode_every=5,
        force_scaling=10,
        rl_length=101,
    ):
        """Store the hyper-parameters as attributes.

        Several of them (``sensor_delay``, ``intervention_*``,
        ``test_episode_every``, ``force_scaling``, ``rl_length``) are stored
        but not read anywhere in this class.
        """
        self.kappa = kappa
        self.tau = tau
        self.bias_rate = bias_rate
        self.buffer_size = buffer_size
        self.time_dist = time_dist
        self.s4avg = s4avg
        self.normalization = normalization
        self.sensor_delay = sensor_delay
        self.regularization = regularization
        self.with_learning = with_learning
        self.q_norm_selector = q_norm_selector
        self.intervention_length = intervention_length
        self.intervention_proba = intervention_proba
        self.test_episode_every = test_episode_every
        self.force_scaling = force_scaling
        self.rl_length = rl_length

    def set_params(self, param_dict):
        """Overwrite attributes from ``param_dict``, then re-allocate state via ``reset``."""
        for name, value in param_dict.items():
            setattr(self, name, value)
        self.reset()

    def beginning_of_rollout(self):
        """Clear the history, restart the step counter and zero the smoothed observation."""
        self.buffer.clear()
        self.t = 0
        self.obs_smoothed *= 0

    def end_of_rollout(self, total_time, total_return, mode):
        """No-op hook kept for agent-interface compatibility."""

    def _q_norm(self, q):
        """Return one normalization factor per row of ``q`` using the configured selector."""
        return self._jax_q_norm(q, self.regularization, self.q_norm_selector)

    def _record_observation(self, obs):
        """Update the smoothed observation and push a new history entry.

        The smoothing is an exponential moving average with weight
        ``1 / s4avg``. It is reset to the raw observation when ``t == 0`` or
        when ``s4avg <= 1``.
        """
        if self.s4avg > 1 and self.t > 0:
            self.obs_smoothed += (obs - self.obs_smoothed) / self.s4avg
        else:
            # Copy (as float) so that later in-place smoothing never writes
            # into the caller's observation array.
            self.obs_smoothed = np.array(obs, dtype=float)
        # The action slot is filled in once the action has been computed.
        self.buffer.append([self.obs_smoothed.copy(), None])

    # @rescale_action
    def get_action(self, obs):
        """Smooth ``obs``, optionally learn ``C``, and return the action batch.

        Args:
            obs: 2-D array ``(n_env, obs_dim)``.

        Returns:
            Array ``(n_env, num_motors)``.
        """
        self._record_observation(obs)
        # Learning needs enough history for one look-back of length time_dist.
        if self.with_learning and len(self.buffer) > (2 + self.time_dist):
            self._learn_controller()
        y = self._jax_get_action(
            self.C,
            self.obs_smoothed,
            self.kappa,
            self.act_scale,
            self.Cb,
            self.regularization,
            self.q_norm_selector,
        )
        self.buffer[-1][1] = y.copy()
        self.t += 1
        return y

    # @rescale_action
    def jax_get_action(self, obs):
        """Same as ``get_action``; kept for callers that use this name."""
        return self.get_action(obs)

    def _jax_q_norm(self, q, regularization, q_norm_selector):
        """Return one normalization factor per row of ``q`` (shape ``q.shape[:-1]``).

        The selectors are, with ``reg = 10**-regularization``:

        - ``"l2"``: ``1 / (||q||_2 + reg)``
        - ``"max"``: ``1 / (max|q| + reg)``
        - ``"none"``: all ones.

        Raises:
            NotImplementedError: If ``q_norm_selector`` is unknown.
        """
        reg = 10.0 ** (-regularization)
        if q_norm_selector == "l2":
            return 1.0 / (np.linalg.norm(q, axis=-1) + reg)
        if q_norm_selector == "max":
            # NumPy's max: the builtin max() has no axis argument.
            return 1.0 / (np.max(np.abs(q), axis=-1) + reg)
        if q_norm_selector == "none":
            # An array (not a scalar) so it can be applied per row via einsum.
            return np.ones(np.shape(q)[:-1])
        raise NotImplementedError(
            f"q normalization {q_norm_selector} not implemented."
        )

    #partial(jit, static_arguments=(0,))
    def _jax_get_action(
        self, C, obs_smoothed, kappa, act_scale, Cb, regularization, q_norm_selector
    ):
        """Compute ``act_scale * clip(tanh(kappa * normalized(C @ obs) + Cb), -1, 1)``."""
        # Only the first num_sensors (= C.shape[-1]) observation columns drive
        # the controller, matching the slicing used when learning C.
        sensors = obs_smoothed[:, : C.shape[-1]]
        q = np.einsum("ijk,ik->ij", C, sensors)
        q = np.einsum(
            "ij,i->ij", q, self._jax_q_norm(q, regularization, q_norm_selector)
        )
        y = np.maximum(-1, np.minimum(1, np.tanh(q * kappa + Cb)))
        return np.einsum("ij,j->ij", y, act_scale)

    #partial(jit, static_arguments=(0,))
    def _jax_learn_controller(
        self,
        C_unnorm,
        M,
        regularization,
        kappa,
        normalization,
        bias_rate,
        Cb,
        t,
        time_dist,
        tau,
        buffer,
        num_sensors,
    ):
        """Re-estimate the controller from ``buffer``.

        ``Cb`` is updated in place.

        Returns:
            Tuple ``(C, Cb, C_unnorm)``: normalized matrix, biases, and
            unnormalized matrix.

        Raises:
            NotImplementedError: If ``normalization`` is unknown.
        """
        C_unnorm = self._jax_calc_C(C_unnorm, t, time_dist, tau, buffer, num_sensors, M)
        # linear response in motor space (action -> action): R = C @ M^T per env
        R = np.einsum("ijk,imk->ijm", C_unnorm, M)
        reg = 10.0 ** (-regularization)
        if normalization == "independent":
            # Scale each motor row j of C so that row j of R has norm ~kappa.
            factor = kappa / (np.linalg.norm(R, axis=-1) + reg)
            C = np.einsum("ijk,ij->ijk", C_unnorm, factor)
        elif normalization == "none":
            C = C_unnorm
        elif normalization == "global":
            # One scale per environment: the Frobenius norm of that env's R.
            norm = np.linalg.norm(R, axis=(1, 2))
            C = C_unnorm * (kappa / (norm + reg))[:, np.newaxis, np.newaxis]
        else:
            raise NotImplementedError(
                f"Controller matrix normalization {normalization} not implemented."
            )

        if bias_rate >= 0:
            # buffer[-2][1] is the action emitted on the previous step.
            yy = buffer[-2][1]
            Cb -= np.clip(yy * bias_rate, -0.05, 0.05) + Cb * 0.001
        else:
            Cb *= 0
        return C, Cb, C_unnorm

    #partial(jit, static_arguments=(0,))
    def _jax_calc_C(self, C_unnorm, t, time_dist, tau, buffer, num_sensors, M):
        """Recompute the unnormalized controller matrix from the history buffer.

        For each look-back ``s`` it forms two differences of smoothed
        observations: ``chi = x[-s] - x[-s-1]`` and
        ``v = x[-s-time_dist] - x[-s-1-time_dist]``. It then sums the outer
        products ``(M @ chi) v^T`` over ``s``. Only the shape and dtype of the
        incoming ``C_unnorm`` are used.
        """
        C_unnorm = np.zeros_like(C_unnorm)
        # The oldest entry touched is buffer[-(s + 1 + time_dist)], so s must
        # stay below len(buffer) - time_dist. This matters when the buffer is
        # shorter than t implies (reset() mid-rollout, or tau > buffer_size).
        last = min(t - time_dist, tau, len(buffer) - time_dist)
        for s in range(2, last):
            x = buffer[-s][0][:, :num_sensors]
            xx = buffer[-s - 1][0][:, :num_sensors]
            # With time_dist == 0 this is x itself.
            xx_t = buffer[-s - time_dist][0][:, :num_sensors]
            xxx_t = buffer[-s - 1 - time_dist][0][:, :num_sensors]

            chi = x - xx
            v = xx_t - xxx_t
            mu = np.einsum("ijk,ik->ij", M, chi)

            C_unnorm += np.einsum("ij,ik->ijk", mu, v)
        return C_unnorm

    # noinspection PyPep8Naming
    def _calc_C(self):
        """Recompute the unnormalized controller matrix from ``self.buffer``."""
        return self._jax_calc_C(
            self.C_unnorm,
            self.t,
            self.time_dist,
            self.tau,
            self.buffer,
            self.num_sensors,
            self.M,
        )

    # noinspection PyPep8Naming
    def _learn_controller(self):
        """Update ``C_unnorm``, ``C`` and ``Cb`` from the current history."""
        self.C, self.Cb, self.C_unnorm = self._jax_learn_controller(
            self.C_unnorm,
            self.M,
            self.regularization,
            self.kappa,
            self.normalization,
            self.bias_rate,
            self.Cb,
            self.t,
            self.time_dist,
            self.tau,
            self.buffer,
            self.num_sensors,
        )
