from collections import deque
import numpy as np
import colorednoise as cn
from matplotlib import pyplot as plt
import scipy.signal

class HomeoKinesisController:
    """Base holder for the parameters of a homeokinesis controller.

    Allocates zero-initialised arrays ``C`` (motors x sensors), ``h``
    (motors), ``A`` (sensors x motors) and ``b`` (sensors) for a fixed
    6-sensor / 6-motor setup, together with the ``tanh`` non-linearity ``g``
    and its derivative ``g_prime``.
    """

    def __init__(self):
        self.number_motors = 6
        self.number_sensors = 6
        # np.zeros takes ``shape``; passing ``size=`` raises a TypeError.
        self.C = np.zeros(shape=(self.number_motors, self.number_sensors))
        self.h = np.zeros(shape=(self.number_motors,))
        self.A = np.zeros(shape=(self.number_sensors, self.number_motors))
        self.b = np.zeros(shape=(self.number_sensors,))
        # Squashing function and its derivative d/dx tanh(x) = 1 - tanh(x)^2.
        self.g = lambda x: np.tanh(x)
        self.g_prime = lambda x: 1 - np.square(np.tanh(x))


class HomeoKinesisSplitBrainController(HomeoKinesisController):
    """Homeokinesis controller whose parameters act element-wise.

    ``C``, ``h``, ``A`` and ``b`` are 1-D vectors that are combined with
    ``*`` instead of matrix products, so each channel is handled on its own.
    Because of the element-wise arithmetic, ``sensors`` and ``motors`` must
    be equal (or at least broadcastable against each other).
    """

    def __init__(self, params, sensors, motors, *args):
        """Store the hyper-parameters and allocate the controller state.

        Args:
            params: Mapping of keyword arguments forwarded to
                :meth:`set_params`.
            sensors: Number of sensor channels.
            motors: Number of motor channels.
            *args: Accepted for signature compatibility and ignored.
        """
        # The base-class initialiser is not invoked; every attribute it would
        # set is assigned by initialize_data_structures below.
        self.set_params(**params)
        self.initialize_data_structures(sensors, motors)

    def set_params(self,
                   s4avg=1,
                   creativity=0,
                   epsC=0.1,
                   epsA=0.1,
                   s4delay=1,
                   loga=0,
                   buffer_size=150):
        """Store the hyper-parameters on the instance.

        Args:
            s4avg: Smoothing window for incoming states (values > 1 enable
                exponential smoothing).
            creativity: Weight of ``v_avg`` added to the controller input.
            epsC: Learning rate for ``C`` and ``h``.
            epsA: Learning rate for ``A`` and ``b``.
            s4delay: Number of steps between the (state, action) pair and the
                resulting state used for learning.
            loga: If truthy, scale controller updates by ``0.1 / (|v| + 0.001)``.
            buffer_size: Maximum number of (state, action) pairs kept.
        """
        self.s4avg = s4avg
        self.s4delay = s4delay
        self.creativity = creativity
        self.epsC = epsC
        self.epsA = epsA
        self.loga = loga
        self.buffer_size = buffer_size

    def initialize_data_structures(self, sensors, motors):
        """Allocate parameters, running averages and the history buffer."""
        self.number_motors = motors
        self.number_sensors = sensors
        self.h = np.zeros(shape=(self.number_motors, ))
        self.A = np.ones(shape=(self.number_sensors,))
        self.b = np.zeros(shape=(self.number_sensors,))
        self.g = lambda x: np.tanh(x)
        self.g_prime = lambda x: 1 - np.square(np.tanh(x))
        self.v_avg = np.zeros(shape=(self.number_sensors))
        # C starts from Gaussian noise (a former all-ones initialisation was
        # dead code: it was overwritten by this assignment).
        self.C = np.random.normal(size=(self.number_sensors,))
        # Each entry is a two-element list [smoothed_state, action].
        self.buffer = deque(maxlen=self.buffer_size)
        self.t = 0

    def step(self, state):
        """Process ``state``, learn from the buffer and return a copy of the action."""
        self._step(state)
        self._learn_step()
        return self.action.copy()

    def _learn_step(self):
        """Compute the parameter updates and apply them."""
        self._compute_matrix_updates()
        self._apply_matrix_updates()

    def _step(self, state):
        """Record the new state, compute the action and store it in the buffer."""
        self._handle_new_state(state)
        self._compute_action()
        # Fill the action slot of the entry appended by _handle_new_state.
        self.buffer[-1][1] = self.action.copy()

    def _compute_action(self):
        """Set ``self.action = g(C * (x_smooth + creativity * v_avg) + h)``."""
        # The bias h is added after scaling by C, matching the pre-activation
        # ``z`` used by the learning rule in _compute_matrix_updates.
        self.action = self.g(
            self.C * (self.x_smooth + self.v_avg * self.creativity) + self.h
        )

    def _handle_new_state(self, x):
        """Update the smoothed state with ``x`` and append it to the buffer."""
        if self.s4avg > 1 and not self.t == 0:
            # Exponential moving average with rate 1 / s4avg.
            self.x_smooth += (x - self.x_smooth) * (1.0 / self.s4avg)
        else:
            self.x_smooth = x.copy()

        self.buffer.append([self.x_smooth.copy(), None])
        self.t += 1

    def _compute_matrix_updates(self):
        """Compute ``A_update``, ``b_update``, ``C_update`` and ``h_update``.

        Uses the (state, action) pair recorded ``s4delay`` steps ago and the
        newest smoothed state. Until more than ``s4delay + 1`` states have
        been seen, all updates are zero.
        """
        if self.t > self.s4delay + 1:
            x = self.buffer[- self.s4delay - 1][0]
            y = self.buffer[- self.s4delay - 1][1]
            x_fut = self.buffer[-1][0]

            z = self.C * (x + self.v_avg * self.creativity) + self.h
            gprime = self.g_prime(z)
            # Prediction error of the model x_fut ~ A * y + b.
            xsi = x_fut - (self.A * y + self.b)
            # Error-driven step plus a small decay towards zero.
            self.A_update = xsi * y * self.epsA - 0.003 * self.A
            self.b_update = xsi * self.epsA - 0.001 * self.b

            eta = (1. / self.A) * xsi
            zeta = np.clip(eta * (1. / (gprime + 0.000001)), -1.0, 1.0)
            # The 1e-8 term regularises the denominator so C == 0 cannot
            # cause a division by zero.
            mu = (1. / (np.square(self.C) + 1e-8)) * zeta
            v = np.clip(self.C * mu, -1.0, 1.0)
            self.v_avg += (v - self.v_avg) * 0.1
            if self.loga:
                EE = 0.1 / (np.abs(v) + 0.001)
            else:
                EE = 1.
            self.C_update = (mu * v + mu * y * zeta * (-2) * x) * (EE * self.epsC)
            self.h_update = mu * y * zeta * (-2) * (EE * self.epsC)
        else:
            self.A_update = np.zeros_like(self.A)
            self.b_update = np.zeros_like(self.b)
            self.C_update = np.zeros_like(self.C)
            self.h_update = np.zeros_like(self.h)

    def _apply_matrix_updates(self):
        """Add the pending updates to the parameters, clipping each step."""
        self.C += np.clip(self.C_update, -0.05, 0.05)
        self.h += np.clip(self.h_update, -0.1, 0.1)
        self.A += np.clip(self.A_update, -0.1, 0.1)
        self.b += np.clip(self.b_update, -0.1, 0.1)


class DEP:
    """DEP controller that learns a controller matrix from recent observations.

    The controller matrix ``C`` is rebuilt at every step from correlations of
    observation differences stored in a bounded buffer, mapped through the
    (fixed) forward-model matrix ``M``; actions are ``tanh`` of the
    normalised controller output scaled by the environment's action maximum.
    """

    def __init__(self, params, sensors, motors, env):
        """Create the controller.

        Args:
            params: Mapping of keyword arguments for :meth:`_parse_arch_params`.
            sensors: Number of sensor channels.
            motors: Number of motor channels.
            env: Environment exposing either ``action_space.high`` or
                ``action_spec().maximum``.
        """
        self.has_init = False
        self.env = env
        self._parse_arch_params(**params)
        self.num_sensors = sensors
        self.num_motors = motors

        # NOTE: a symmetric action space (high == -low) is assumed but not
        # checked here.
        try:
            self.act_scale = self.act_high = self.env.action_space.high
        except AttributeError:
            # No ``action_space`` attribute: use the ``action_spec()`` API.
            self.act_scale = self.act_high = self.env.action_spec().maximum

        self.reset()

    def reset(self):
        """Re-create all learned quantities, the buffer and the step counter."""
        self.M = -np.eye(self.num_motors, self.num_sensors)  # Forward model matrix
        self.Mb = np.zeros(self.num_sensors)  # Forward model biases
        # Unnormalized controller matrix
        self.C_unnorm = np.zeros((self.num_motors, self.num_sensors))
        self.C = np.zeros((self.num_motors, self.num_sensors))  # Normalized controller matrix
        self.Cb = np.zeros(self.num_motors)  # Controller biases
        self.q_norm = 0
        self.obs_smoothed = np.zeros(self.num_sensors)
        self.buffer = deque(maxlen=self.buffer_size)
        self.t = 0

    def _parse_arch_params(
        self,
        *,
        kappa,
        tau,
        bias_rate,
        time_dist,
        normalization,
        s4avg=2,
        buffer_size=150,
        sensor_delay=1,
        regularization=4,
        with_learning=True,
        q_norm_selector="l2",
    ):
        """Store the keyword-only hyper-parameters as instance attributes.

        Args:
            kappa: Gain applied to the controller output and the normalised
                controller matrix.
            tau: Upper bound of the look-back window used in :meth:`_calc_C`.
            bias_rate: Rate of the bias update; negative values zero the bias.
            time_dist: Time offset between the two observation differences
                that are correlated in :meth:`_calc_C`.
            normalization: ``"independent"``, ``"global"`` or ``"none"``.
            s4avg: Smoothing window for observations (> 1 enables smoothing).
            buffer_size: Maximum length of the history buffer.
            sensor_delay: Stored on the instance; not used in this class.
            regularization: Exponent ``r`` of the regulariser ``10 ** -r``.
            with_learning: Whether the controller matrix is updated.
            q_norm_selector: ``"l2"``, ``"max"`` or ``"none"``.
        """
        self.kappa = kappa
        self.tau = tau
        self.bias_rate = bias_rate
        self.buffer_size = buffer_size
        self.time_dist = time_dist
        self.s4avg = s4avg
        self.normalization = normalization
        self.sensor_delay = sensor_delay
        self.regularization = regularization
        self.with_learning = with_learning
        self.q_norm_selector = q_norm_selector

    def beginning_of_rollout(self, *, observation, state=None, mode):
        """Clear the buffer, the step counter and the smoothed observation."""
        self.buffer.clear()
        self.t = 0
        self.obs_smoothed *= 0

    def end_of_rollout(self, total_time, total_return, mode):
        """Hook called at the end of a rollout; does nothing."""
        pass

    def _q_norm(self, q):
        """Return the scalar factor used to normalise the controller output ``q``.

        Raises:
            NotImplementedError: If ``q_norm_selector`` is not one of
                ``"l2"``, ``"max"`` or ``"none"``.
        """
        reg = 10.0 ** (-self.regularization)
        if self.q_norm_selector == "l2":
            q_norm = 1.0 / (np.linalg.norm(q) + reg)
        elif self.q_norm_selector == "max":
            q_norm = 1.0 / (np.max(np.abs(q)) + reg)
        elif self.q_norm_selector == "none":
            q_norm = 1.0
        else:
            raise NotImplementedError(f"q normalization {self.q_norm_selector} not implemented.")

        return q_norm

    def step(self, state):
        """Return the action for ``state``, clipped to ``[-1, 1]``.

        The first call also runs :meth:`beginning_of_rollout`.
        """
        if not self.has_init:
            self.beginning_of_rollout(observation=state, mode="training")
            self.has_init = True
        return np.clip(self.get_action(state, state).copy(), -1, 1)

    def __call__(self, state):
        """Alias for :meth:`step`."""
        return self.step(state)

    def get_action(self, obs, state, mode="train"):
        """Update the controller with ``obs`` and return the scaled action.

        Args:
            obs: Current observation (array-like).
            state: Unused.
            mode: Unused.

        Returns:
            ``tanh(kappa * q + Cb) * act_scale`` where ``q`` is the normalised
            controller output ``C @ obs_smoothed``.
        """
        if self.s4avg > 1 and self.t > 0:
            self.obs_smoothed += (obs - self.obs_smoothed) / self.s4avg
        else:
            # Take a private float copy: obs_smoothed is later modified in
            # place and must never alias the caller's observation array.
            self.obs_smoothed = np.array(obs, dtype=float)
        self.buffer.append([self.obs_smoothed.copy(), None])

        if self.with_learning and len(self.buffer) > (2 + self.time_dist):
            self._learn_controller()
        q = np.matmul(self.C, self.obs_smoothed)

        q = q * self._q_norm(q)

        y = np.clip(np.tanh(q * self.kappa + self.Cb), -1, 1) * self.act_scale

        # Fill the action slot of the entry appended above.
        self.buffer[-1][1] = y.copy()

        self.t += 1

        return y

    # noinspection PyPep8Naming
    def _calc_C(self):
        """Return the unnormalised controller matrix built from the buffer.

        Sums ``outer(M @ chi, v)`` over the look-back window, where ``chi`` is
        the observation difference at lag ``s`` and ``v`` the observation
        difference ``time_dist`` steps earlier.
        """
        C_unnorm = np.zeros_like(self.C_unnorm)
        # The deepest index used is -(s + 1 + time_dist); bounding the window
        # by the buffer length avoids an IndexError once the buffer is full
        # and tau + time_dist exceeds its capacity.
        window_end = min(
            self.t - self.time_dist,
            self.tau,
            len(self.buffer) - self.time_dist,
        )
        for s in range(2, window_end):
            x = self.buffer[-s][0][: self.num_sensors]
            xx = self.buffer[-s - 1][0][: self.num_sensors]
            xx_t = x if self.time_dist == 0 else self.buffer[-s - self.time_dist][0][: self.num_sensors]
            xxx_t = self.buffer[-s - 1 - self.time_dist][0][: self.num_sensors]

            chi = x - xx
            v = xx_t - xxx_t
            mu = np.dot(self.M, chi)

            C_unnorm += np.outer(mu, v)
        return C_unnorm

    # noinspection PyPep8Naming
    def _learn_controller(self):
        """Recompute ``C`` (with normalisation) and update the bias ``Cb``.

        Raises:
            NotImplementedError: If ``normalization`` is not one of
                ``"independent"``, ``"global"`` or ``"none"``.
        """
        self.C_unnorm = self._calc_C()

        # linear response in motor space (action -> action)
        R = np.dot(self.C_unnorm, self.M.transpose())
        reg = 10.0 ** (-self.regularization)
        if self.normalization == "independent":
            # One scale factor per motor row.
            factor = self.kappa / (np.linalg.norm(R, axis=1) + reg)
            self.C = self.C_unnorm * factor[:, np.newaxis]
        elif self.normalization == "none":
            self.C = self.C_unnorm
        elif self.normalization == "global":
            norm = np.linalg.norm(R)
            self.C = self.C_unnorm * self.kappa / (norm + reg)
        else:
            raise NotImplementedError(f"Controller matrix normalization {self.normalization} not implemented.")

        if self.bias_rate >= 0:
            # Previous action drives the bias; the extra term is a slow decay.
            yy = self.buffer[-2][1]
            self.Cb -= np.clip(yy * self.bias_rate, -0.05, 0.05) + self.Cb * 0.001
        else:
            self.Cb *= 0


def extended(noise):
    """Decorate a noise method so its sample is held for several calls.

    The wrapped method is only invoked when the owner's ``has_init`` flag is
    falsy or its ``noise_counter`` has reached ``intermittency``; the result
    is cached on ``action``, ``noise_counter`` is reset to 0 and ``has_init``
    is set to 1. Otherwise the cached ``action`` is returned unchanged.

    Args:
        noise: Method taking the owning instance and returning an action.

    Returns:
        The wrapping function, carrying the name and docstring of ``noise``.
    """
    # Local import: functools is not imported at file level.
    from functools import wraps

    @wraps(noise)
    def extended_noise(controller):
        # ``intermittency`` is only read once has_init is truthy.
        if not controller.has_init or controller.noise_counter >= controller.intermittency:
            controller.action = noise(controller)
            controller.noise_counter = 0
            controller.has_init = 1
        return controller.action
    return extended_noise


class NoiseController:
    """Controller that ignores the state and emits noise of a selectable type.

    ``noise_params`` must contain ``'noise_type'`` (one of the keys of
    ``noise_fns``, or ``'colored_<digit>'``); every entry of it becomes an
    instance attribute, so it must also carry the settings read by the
    selected noise method (e.g. ``intermittency``, ``beta``,
    ``expected_length``).
    """

    def __init__(self, noise_params, action_space_shape, *args, **kwargs):
        """Create the controller.

        Args:
            noise_params: Mapping with ``'noise_type'`` and the settings of
                the chosen noise; it is not modified.
            action_space_shape: Number of action dimensions.
            *args: Ignored.
            **kwargs: Ignored.
        """
        # Work on a copy so the caller's mapping is left untouched.
        noise_params = dict(noise_params)
        noise_type = noise_params['noise_type']
        if noise_type[:-2] == 'colored':
            # 'colored_<d>': the trailing digit is the spectral exponent.
            noise_params['beta'] = float(noise_type[-1])
            noise_params['noise_type'] = noise_type[:-2]
        self.noise_fns = {'uniform': self.uniform,
                          'colored': self.colored,
                          'one_hot': self.one_hot,
                          'bang': self.bang_bang,
                          'ou': self.ou,
                          'ep_lin': self.ep_lin}
        self.has_init = 0
        self.noise_counter = 0
        self.ep_counter = 0
        self.num_actions = action_space_shape
        self.set_kwargs(noise_params)

    def set_kwargs(self, kwargs):
        """Set every key/value pair of ``kwargs`` as an instance attribute."""
        for key, val in kwargs.items():
            setattr(self, key, val)

    def step(self, state):
        """Return the next noise action clipped to ``[-1, 1]``; ``state`` is unused."""
        action = self.noise_fns[self.noise_type]()
        self.noise_counter += 1
        return np.clip(action, -1, 1)

    def colored(self):
        """Return the next row of the pre-generated colored-noise table."""
        self.init_cn()
        if self.noise_counter >= self.x.shape[0]:
            # Table exhausted: force regeneration.
            self.has_init = 0
            self.noise_counter = 0
            self.init_cn()
        action = self.x[self.noise_counter, :]
        return action

    @extended
    def one_hot(self):
        """Return a vector with a single 1 at a random index."""
        idx = np.random.randint(0, self.num_actions)
        action = np.zeros((self.num_actions,))
        action[idx] = 1.
        return action

    @extended
    def bang_bang(self):
        """Return a vector whose entries are each randomly 0 or 1."""
        return np.random.randint(0, 2, size=[self.num_actions, ])

    @extended
    def uniform(self):
        """Return a sample parameterised by ``uniform_a`` and ``uniform_b``."""
        # NOTE(review): despite the name this draws from a normal distribution
        # with mean uniform_a and std uniform_b - confirm whether
        # np.random.uniform was intended.
        return np.random.normal(self.uniform_a, self.uniform_b, size=[self.num_actions, ])

    def ornstein_uhlenbeck(self):
        """Alias for :meth:`ou`."""
        # Formerly called a non-existent ``call_ou`` helper.
        return self.ou()

    def ep_lin(self):
        """Return ``0.5 + pinv(B) @ goal`` with a periodically resampled goal.

        A new 10-dimensional Gaussian goal is drawn whenever ``ep_counter``
        is a multiple of ``intermittency``.
        """
        self.init_ep_lin()
        if not self.ep_counter % self.intermittency:
            self.goal = np.random.normal(0, 0.1, (10,))
            self.ep_counter = 0
        daction = np.matmul(self.b_inv, self.goal)
        action = 0.5 + daction
        self.ep_counter += 1
        return action

    def init_ep_lin(self):
        """Load ``B_matrix.npy`` once and store its pseudo-inverse in ``b_inv``."""
        if not self.has_init:
            self.b_inv = np.load('B_matrix.npy')
            print(self.b_inv.shape)
            self.b_inv = np.linalg.pinv(self.b_inv)
        self.has_init = 1

    def init_ou(self):
        """Initialise the Ornstein-Uhlenbeck state while ``has_init`` is falsy."""
        if not self.has_init:
            # OU variables
            self.x0 = None
            self.distro = np.random.normal(loc=self.mu, scale=self.sigma)
            self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu)

    @extended
    def ou(self):
        """
        Ornstein-Uhlenbeck noise
        """
        self.init_ou()
        drift = self.theta * (self.mu - self.x_prev) * self.dt
        diffusion = self.sigma_ou * np.sqrt(self.dt) * \
            np.random.normal(self.mu, self.sigma, size=(self.num_actions,))
        x = self.x_prev + drift + diffusion
        self.x_prev = x
        return x

    def init_cn(self):
        """Generate the colored-noise table ``x`` while ``has_init`` is falsy.

        ``x`` has shape ``(expected_length, num_actions)``; each column is an
        independent power-law noise sequence scaled by ``noise_scale_colored``
        and clipped to ``[-1, 1]``.
        """
        if not self.has_init:
            self.x = np.zeros([self.expected_length, self.num_actions])
            for i in range(self.num_actions):
                self.noises = cn.powerlaw_psd_gaussian(self.beta, self.expected_length)
                self.noises = self.noises * self.noise_scale_colored
                self.noises = np.clip(self.noises, -1, 1)
                self.x[:, i] = self.noises
            self.has_init = 1
            self.noise_counter = 0

    def reset(self, **kwargs):
        """Mark the noise as uninitialised and optionally override attributes.

        Args:
            **kwargs: Attribute overrides applied via :meth:`set_kwargs`.
        """
        self.has_init = 0
        if kwargs:
            self.set_kwargs(kwargs)
