import json
from collections import deque

import colorednoise as cn
import numpy as np
import torch
from tonic import agents

from .controllers import DEP


class OUNoise:
    """Temporally correlated exploration noise (Ornstein-Uhlenbeck-style update).

    The internal state ``noises`` is pulled back towards zero at rate
    ``theta`` and perturbed by clipped Gaussian samples scaled with
    ``scale * sqrt(dt)``.  The state itself (cast to float32 and clipped to
    ``[-1, 1]``) is what ``step`` returns as the action.
    """

    def __init__(self, scale=0.1, clip=2, theta=0.15, dt=1e-2):
        """Store the process parameters.

        Args:
            scale: Multiplier on the random increment.
            clip: Symmetric bound applied to the raw Gaussian samples.
            theta: Mean-reversion rate of the internal state.
            dt: Discretisation step of the update.
        """
        self.scale, self.clip = scale, clip
        self.theta, self.dt = theta, dt
        # NOTE(review): presumably read by the training loop to make test
        # episodes practically never happen -- confirm against the caller.
        self.test_episode_every = 1e16

    def initialize(self, observation_space, action_space, seed=None):
        """Bind the action shape and create a seeded random generator.

        ``observation_space`` is accepted for interface compatibility only;
        it is not used here.
        """
        self.act_shape = action_space.shape
        self.action_size = self.act_shape[0]
        self.np_random = np.random.RandomState(seed)
        self.noises = None

    def reset_ou(self):
        """Drop the internal state; the next ``step`` restarts from zeros."""
        self.noises = None

    def step(self, observations):
        """Advance the process by one step and return it as float32 actions.

        When ``observations`` has a different number of dimensions than the
        stored action shape, the action shape is prefixed with
        ``observations.shape[0]`` and the state is reset.
        """
        if len(observations.shape) != len(self.act_shape):
            self.act_shape = [observations.shape[0], *self.act_shape]
            self.reset_ou()

        if self.noises is None:
            self.noises = np.zeros(shape=self.act_shape)

        gaussian = np.clip(
            self.np_random.normal(size=self.act_shape), -self.clip, self.clip
        )

        # Mean-reverting drift first, then the random increment.
        drift = self.theta * self.noises * self.dt
        self.noises -= drift
        diffusion = self.scale * np.sqrt(self.dt) * gaussian
        self.noises += diffusion

        return np.clip(self.noises.astype(np.float32), -1, 1)

    def update(self, resets):
        """Zero the state of every row whose entry in ``resets`` is 1.

        ``resets`` is indexed as a 1-D array (one entry per row of the
        state); nothing happens if no state has been created yet.
        """
        if self.noises is None:
            return
        keep_mask = (1.0 - resets)[:, None]
        self.noises *= keep_mask

    def set_params(self, param_dict):
        """Set every key of ``param_dict`` as an attribute on this object."""
        for name, value in param_dict.items():
            setattr(self, name, value)


class ColoredNoise:
    """Exploration noise read from a pre-sampled table of power-law noise.

    A table with ``expected_length`` time steps is drawn with
    ``cn.powerlaw_psd_gaussian`` (exponent ``beta``), multiplied by the
    noise scale and clipped to ``[-1, 1]``.  Each call to ``step`` returns
    the next row; the table is re-drawn shortly before it runs out.
    """

    def __init__(self, noise_scale=10, beta=1, expected_length=1000):
        """Store the sampling parameters; no noise is drawn yet."""
        self.noise_scale_colored = noise_scale
        self.beta = beta
        self.expected_length = int(expected_length)
        # Truthy once a noise table exists for the current shapes.
        self.has_init = 0
        # NOTE(review): presumably read by the training loop to make test
        # episodes practically never happen -- confirm against the caller.
        self.test_episode_every = 1e16

    def initialize(self, observation_space, action_space, seed=None):
        """Record the space shapes and draw the first noise table.

        ``seed`` is accepted for interface compatibility but is not used.
        """
        self.obs_shape = observation_space.shape
        self.act_shape = action_space.shape
        self.init_cn()

    def init_cn(self):
        """Draw a new noise table if none is valid, then rewind the counter."""
        if not self.has_init:
            # Only the first one or two action dimensions are considered.
            if len(self.act_shape) == 1:
                flat_size = self.act_shape[0]
            else:
                flat_size = self.act_shape[0] * self.act_shape[1]

            raw = cn.powerlaw_psd_gaussian(
                self.beta, size=[flat_size, 1, self.expected_length]
            )
            self.noises = raw * self.noise_scale_colored
            self.noises = np.clip(self.noises, -1, 1)
            # Put the time axis first so one row is one step's noise.
            self.noises = np.swapaxes(self.noises, 0, -1)
            if len(self.act_shape) > 1:
                self.noises = self.noises.reshape(
                    self.expected_length, self.act_shape[0], self.act_shape[1]
                )
        self.has_init = 1
        self.noise_counter = 0

    def step(self, state, *args, **kwargs):
        """Return the next noise row; extra arguments are ignored.

        If ``state`` has more dimensions than the stored observation shape,
        both stored shapes are prefixed with ``state.shape[0]`` and the
        table is re-drawn for the new shape.
        """
        if len(state.shape) > len(self.obs_shape):
            leading = state.shape[0]
            self.obs_shape = [leading, *self.obs_shape]
            self.act_shape = [leading, *self.act_shape]
            self.reset_color()
        return self.colored()

    def colored(self):
        """Return the current row and advance the read position.

        The table is re-drawn once fewer than 11 rows remain unread.
        """
        if self.noise_counter >= self.expected_length - 10:
            self.reset_color()
        position = self.noise_counter
        self.noise_counter = position + 1
        return self.noises[position]

    def reset_color(self):
        """Invalidate the current table and draw a fresh one."""
        self.has_init = 0
        self.init_cn()

    def set_params(self, param_dict):
        """Set every key of ``param_dict`` as an attribute on this object."""
        for name, value in param_dict.items():
            setattr(self, name, value)


def dep_factory(mix, instance):
    """Create an agent class mixing ``instance``'s class with an exploration scheme.

    Every candidate class subclasses ``instance.__class__`` (directly or via
    ``InitialDEPAgent``) and extends ``step`` / ``test_step`` with the extra
    ``tendon_states`` and ``greedy_episode`` arguments.  Most variants use
    exploration exclusively while
    ``steps <= self.replay.steps_before_batches`` and mix it with the base
    agent afterwards.

    Args:
        mix: Integer selector for the returned class:
            0 -> UnmixedAgent, 1 -> InitialDEPAgent, 2 -> AveragedDEPAgent,
            3 -> SwitchDEPAgent, 5 -> GreedyEpisodeSwitchDEPAgent,
            6 -> RealProbabilitySwitchDEPAgent, 7 -> DEPAgent,
            11 -> ColoredAgent, 12 -> OUAgent, 13 -> NotRunDEPBackground.
        instance: Object whose class is used as the base class.

    Returns:
        The selected class (not an instance of it).

    Raises:
        Exception: If ``mix`` is not one of the values listed above.

    NOTE(review): ``ColoredAgent`` used to sit in an unreachable ``return``
    right after ``DEPAgent``'s, so no selector could reach it.  The value 11
    (the free slot just before ``OUAgent``'s 12) is assumed here -- confirm
    against the project's configuration files.
    """

    class UnmixedAgent(instance.__class__):
        """Base agent with the extended signature and no added exploration."""

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            return super().step(observations, steps)

        def test_step(
            self, observations, steps, tendon_states=None, greedy_episode=None
        ):
            return super().test_step(observations, steps)

    class InitialDEPAgent(instance.__class__):
        """Uses DEP actions up to the replay threshold, then the base agent."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.expl = DEP()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                return super().step(observations, steps)
            actions = self.expl.step(tendon_states, steps)
            # Record what was seen/done, as the base agent's step would
            # (presumably consumed by its update -- confirm in the base class).
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()

            return actions

        def test_step(
            self, observations, steps, tendon_states=None, greedy_episode=None
        ):
            return super().test_step(observations, steps)

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    class AveragedDEPAgent(InitialDEPAgent):
        """After the threshold, returns 0.01 * DEP action + 0.99 * policy action."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                dep_actions = self.expl.step(tendon_states, steps)
                # Skip InitialDEPAgent.step and call the base agent directly.
                policy_actions = super(InitialDEPAgent, self).step(
                    observations, steps
                )
                return 0.01 * dep_actions + (1 - 0.01) * policy_actions
            actions = self.expl.step(tendon_states, steps)

            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

        def test_step(
            self, observations, steps, tendon_states=None, greedy_episode=None
        ):
            return super().test_step(observations, steps)

    class SwitchDEPAgent(InitialDEPAgent):
        """Alternates between policy phases and DEP phases after the threshold.

        Phase lengths are governed by ``self.expl.rl_length`` (policy) and
        ``self.expl.intervention_length`` (DEP).
        """

        def __init__(self, *args, **kwargs):
            print("SWITCH-DEP initialized")
            # switch == 1 means DEP is currently in control.
            self.switch = 0
            self.since_switch = 1
            super().__init__(*args, **kwargs)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                if (
                    self.switch
                    and not self.since_switch % self.expl.intervention_length
                ):  # 6 for ostrich
                    self.switch = 0
                    self.since_switch = 1
                if not self.switch and not self.since_switch % self.expl.rl_length:
                    self.switch = 1
                    self.since_switch = 1
                self.since_switch += 1
                if not self.switch:
                    # DEP is still stepped while the policy acts; its output
                    # is discarded.
                    self.expl.step(tendon_states, steps)
                    return super().step(observations, steps)
            actions = self.expl.step(tendon_states, steps)

            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    class GreedyEpisodeSwitchDEPAgent(SwitchDEPAgent):
        """Switching agent with fixed phase lengths and DEP-free greedy episodes.

        Uses hard-coded moduli (6 for DEP, 101 for the policy) instead of the
        lengths stored on ``self.expl``.
        """

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if greedy_episode:
                # Bypass the switching logic: delegate to InitialDEPAgent.step.
                return super(SwitchDEPAgent, self).step(
                    observations, steps, tendon_states
                )
            if steps > (self.replay.steps_before_batches / 1):
                if self.switch and not self.since_switch % 6:  # 6 for ostrich
                    self.switch = 0
                    self.since_switch = 1
                if not self.switch and not self.since_switch % 101:
                    self.switch = 1
                    self.since_switch = 1
                self.since_switch += 1
                if not self.switch:
                    self.expl.step(
                        tendon_states, steps
                    )  # important for DEP to keep learning
                    return super(SwitchDEPAgent, self).step(
                        observations, steps, tendon_states
                    )
            actions = self.expl.step(tendon_states, steps)

            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

    class RealProbabilitySwitchDEPAgent(SwitchDEPAgent):
        """Starts DEP interventions at random while the policy is in control.

        After the threshold, each policy step starts an intervention with
        probability ``self.expl.intervention_proba``; DEP then acts until
        ``since_switch`` exceeds ``self.expl.intervention_length``.  The draw
        uses the global ``np.random`` state, not a per-agent generator.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # want policy to start first
            self.since_switch = 500

        def step(self, observations, steps, tendon_states=None, greedy_episode=False):
            if steps > (self.replay.steps_before_batches / 1):
                if greedy_episode:
                    return super(SwitchDEPAgent, self).step(
                        observations, steps, tendon_states
                    )
                if self.since_switch > self.expl.intervention_length:
                    # important for DEP to keep learning
                    self.expl.step(tendon_states, steps)
                    if np.random.uniform() < self.expl.intervention_proba:
                        self.since_switch = 0
                    self.since_switch += 1
                    return super(SwitchDEPAgent, self).step(
                        observations, steps, tendon_states
                    )
            actions = self.expl.step(tendon_states, steps)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            self.since_switch += 1
            return actions

    class NotRunDEPBackground(SwitchDEPAgent):
        """Like RealProbabilitySwitchDEPAgent, but DEP is idle during policy steps.

        DEP is only stepped when it actually provides the action.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # want policy to start first
            self.since_switch = 500

        def step(self, observations, steps, tendon_states=None, greedy_episode=False):
            if steps > (self.replay.steps_before_batches / 1):
                if greedy_episode:
                    return super(SwitchDEPAgent, self).step(
                        observations, steps, tendon_states
                    )
                if self.since_switch > self.expl.intervention_length:
                    if np.random.uniform() < self.expl.intervention_proba:
                        self.since_switch = 0
                    self.since_switch += 1
                    return super(SwitchDEPAgent, self).step(
                        observations, steps, tendon_states
                    )
            actions = self.expl.step(tendon_states, steps)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            self.since_switch += 1
            return actions

    class DEPAgent(InitialDEPAgent):
        """Always acts with DEP; the base agent's updates are disabled."""

        def step(self, observations, steps, tendon_states=None, greedy_episode=False):
            if np.any(np.isnan(tendon_states)):
                print("tendon nan!")
            return self.expl.step(tendon_states, steps)

        def update(self, *args, **kwargs):
            pass

        def test_update(self, *args, **kwargs):
            pass

        def test_step(
            self, observations, steps, tendon_states=None, greedy_episode=False
        ):
            return self.expl.step(tendon_states, steps)

    class ColoredAgent(instance.__class__):
        """Adds ColoredNoise to the base agent's actions after the threshold."""

        def __init__(
            self, model=None, replay=None, actor_updater=None, critic_updater=None
        ):
            super().__init__(model, replay, actor_updater, critic_updater)
            self.expl = ColoredNoise()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                return self.expl.step(observations) + super().step(observations, steps)
            # Up to the threshold the noise alone is the action.
            actions = self.expl.step(observations)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()

            return actions

        def test_step(
            self, observations, steps, tendon_states=None, greedy_episode=None
        ):
            return super().test_step(observations, steps)

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    class OUAgent(instance.__class__):
        """Adds OUNoise to the base agent's actions after the threshold."""

        def __init__(
            self, model=None, replay=None, actor_updater=None, critic_updater=None
        ):
            super().__init__(model, replay, actor_updater, critic_updater)
            self.expl = OUNoise()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                return self.expl.step(observations) + super().step(observations, steps)
            # Up to the threshold the noise alone is the action.
            actions = self.expl.step(observations)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

        def test_step(
            self, observations, steps, tendon_states=None, greedy_episode=None
        ):
            return super().test_step(observations, steps)

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    if mix == 1:
        return InitialDEPAgent
    elif mix == 2:
        return AveragedDEPAgent
    elif mix == 3:
        return SwitchDEPAgent
    elif mix == 5:
        return GreedyEpisodeSwitchDEPAgent
    elif mix == 6:
        return RealProbabilitySwitchDEPAgent
    elif mix == 7:
        return DEPAgent
    elif mix == 11:
        # Previously an unreachable second return inside the mix == 7 branch.
        return ColoredAgent
    elif mix == 12:
        return OUAgent
    elif mix == 13:
        return NotRunDEPBackground
    elif mix == 0:
        print("unmixedagent")
        return UnmixedAgent
    else:
        raise Exception("Invalid agent specified")
