from collections import deque
from tonic import agents
# import numpy as np
import json
# import torch
import numpy as np
import torch
from controllers import DEP
# from new_dep import DEP
import colorednoise as cn


class OUNoise:
    """Ornstein-Uhlenbeck noise that is returned directly as actions.

    Every call to :meth:`step` advances the process

        x <- x - theta * x * dt + scale * sqrt(dt) * clip(N(0, 1), -clip, clip)

    and returns ``clip(x, -1, 1)`` as float32.  When observations carry a
    leading batch dimension, one independent process is kept per batch entry.
    """

    def __init__(
        self, scale=0.1, clip=2, theta=0.15, dt=1e-2,
    ):
        self.scale = scale
        self.clip = clip
        self.theta = theta
        self.dt = dt
        # Very large value; presumably read by the training loop to mean
        # "never run test episodes" -- confirm against the caller.
        self.test_episode_every = 1e16

    def initialize(self, observation_space, action_space, seed=None):
        """Seed the RNG and record the (unbatched) action shape.

        Args:
            observation_space: Unused; kept for the common explorer interface.
            action_space: Object with a ``shape`` attribute.
            seed: Seed for the private ``np.random.RandomState``.
        """
        self.action_size = action_space.shape[0]
        self.np_random = np.random.RandomState(seed)
        self.noises = None
        # Unbatched action shape; a batch dim is prepended lazily in step().
        self._base_act_shape = tuple(action_space.shape)
        self.act_shape = tuple(action_space.shape)

    def reset_ou(self):
        """Drop the process state; it restarts from zero on the next step."""
        self.noises = None

    def step(self, observations):
        """Advance the OU process once and return the clipped noise as actions.

        Args:
            observations: Array whose rank decides batching -- if its rank
                differs from the action rank, ``observations.shape[0]`` is
                used as batch size.

        Returns:
            float32 array of shape ``act_shape`` (batched if applicable),
            clipped to [-1, 1].
        """
        if len(observations.shape) != len(self._base_act_shape):
            target_shape = (observations.shape[0], *self._base_act_shape)
        else:
            target_shape = self._base_act_shape
        # Re-derive the shape from the unbatched one on every call so that a
        # change in batch size (or batched -> unbatched) restarts the process
        # with the right shape instead of returning stale-shaped noise.
        if tuple(self.act_shape) != target_shape:
            self.act_shape = target_shape
            self.reset_ou()
        if self.noises is None:
            self.noises = np.zeros(shape=self.act_shape)
        noises = self.np_random.normal(size=self.act_shape)
        noises = np.clip(noises, -self.clip, self.clip)
        # Mean reversion towards 0, then the scaled Gaussian increment.
        self.noises -= self.theta * self.noises * self.dt
        self.noises += self.scale * np.sqrt(self.dt) * noises
        actions = self.noises.astype(np.float32)
        return np.clip(actions, -1, 1)

    def update(self, resets):
        """Zero the process of every batch entry whose ``resets`` flag is 1.

        Args:
            resets: 1-D array-like with one entry per batch row (1 = reset).
                Assumes the noise is batched (2-D) -- TODO confirm callers.
        """
        if self.noises is not None:
            self.noises *= (1. - np.asarray(resets))[:, None]

    def set_params(self, param_dict):
        """Set each ``key: value`` of ``param_dict`` as an attribute."""
        for key, value in param_dict.items():
            setattr(self, key, value)


class ColoredNoise:
    """Power-law ("colored") Gaussian noise replayed one sample at a time.

    :meth:`init_cn` pre-generates ``expected_length`` samples with
    ``colorednoise.powerlaw_psd_gaussian(beta, ...)``, multiplies them by
    ``noise_scale`` and clips them to [-1, 1].  :meth:`step` returns the next
    sample and regenerates the whole sequence 10 samples before it runs out.
    """

    def __init__(self, noise_scale=10, beta=1, expected_length=1000):
        self.noise_scale_colored = noise_scale
        self.beta = beta
        self.expected_length = int(expected_length)
        self.has_init = 0
        # Very large value; presumably read by the training loop to mean
        # "never run test episodes" -- confirm against the caller.
        self.test_episode_every = 1e16

    def initialize(self, observation_space, action_space, seed=None):
        """Record the (unbatched) space shapes and generate the first sequence.

        ``seed`` is accepted for interface compatibility but not used.
        """
        self.obs_shape = observation_space.shape
        self.act_shape = action_space.shape
        # Unbatched shapes; step() prepends the batch dimension when needed.
        self._base_obs_shape = tuple(observation_space.shape)
        self._base_act_shape = tuple(action_space.shape)
        self.init_cn()

    def init_cn(self):
        """(Re)generate the noise sequence if needed and rewind the counter.

        The sequence has shape ``(expected_length, 1, A)`` for an unbatched
        action shape ``(A,)`` and ``(expected_length, *act_shape)`` otherwise.
        """
        if not self.has_init:
            n_channels = int(np.prod(self.act_shape))
            # The library treats the last axis as time, so time goes last here
            # and is swapped to the front below.
            noises = cn.powerlaw_psd_gaussian(
                self.beta, size=[n_channels, 1, self.expected_length])
            noises = np.clip(noises * self.noise_scale_colored, -1, 1)
            noises = np.swapaxes(noises, 0, -1)
            if len(self.act_shape) > 1:
                noises = noises.reshape(self.expected_length, *self.act_shape)
            self.noises = noises
        self.has_init = 1
        self.noise_counter = 0

    def step(self, state, *args, **kwargs):
        """Return the next noise sample, batching it like ``state``.

        If ``state`` has more dims than the unbatched observation shape, its
        first dim is taken as batch size; the sequence is regenerated whenever
        that batch size differs from the one currently generated.
        """
        if len(state.shape) > len(self._base_obs_shape):
            batch = state.shape[0]
            act_shape = (batch, *self._base_act_shape)
            if tuple(self.act_shape) != act_shape:
                self.obs_shape = (batch, *self._base_obs_shape)
                self.act_shape = act_shape
                self.reset_color()
        return self.colored()

    def colored(self):
        """Return the next pre-generated sample, regenerating near the end."""
        if self.noise_counter >= self.expected_length - 10:
            self.reset_color()
        sample = self.noises[self.noise_counter]
        self.noise_counter += 1
        return sample

    def reset_color(self):
        """Force a fresh noise sequence and rewind to its start."""
        self.has_init = 0
        self.init_cn()

    def set_params(self, param_dict):
        """Set each ``key: value`` of ``param_dict`` as an attribute."""
        for key, value in param_dict.items():
            setattr(self, key, value)


def dep_factory(mix, instance):
    class UnmixedAgent(instance.__class__):

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            return super().step(observations, steps)

        def test_step(self, observations, steps, tendon_states=None):
            #return super().step(observations, steps)
            return super().test_step(observations, steps)

    class InitialDEPAgent(instance.__class__):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.expl = DEP()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                return super().step(observations, steps)
            actions = self.expl.step(tendon_states, steps)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()

            return actions

        def test_step(self, observations,  steps, tendon_states=None, greedy_episode=None):
            return super().test_step(observations, steps)
            # return self.expl.step(tendon_states, steps)

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    class AveragedDEPAgent(InitialDEPAgent):
        def __init__(self, *args, **kwargs):
            return super().__init__(*args, **kwargs)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1): return 0.01 * self.expl.step(tendon_states, steps) + \
            (1-0.01) * super(InitialDEPAgent, self).step(observations, steps)
            actions = self.expl.step(tendon_states, steps)

            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

        def test_step(self, observations,  steps, tendon_states=None):
            # return 0.9 * self.expl.step(tendon_states, steps) + 0.1 * super().test_step(observations, steps)
            return super().test_step(observations, steps)

    class SwitchDEPAgent(InitialDEPAgent):

        def __init__(self, *args, **kwargs):
            print('SWITCH-DEP initialized')
            self.switch = 0
            self.since_switch = 1
            return super().__init__(*args, **kwargs)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                if self.switch and not self.since_switch % self.expl.intervention_length: # 6 for ostrich
                    self.switch = 0
                    self.since_switch = 1
                if not self.switch and not self.since_switch % self.expl.rl_length:
                    self.switch = 1
                    self.since_switch = 1
                self.since_switch += 1
                if not self.switch:
                    self.expl.step(tendon_states, steps)
                    return super().step(observations, steps)
            actions = self.expl.step(tendon_states, steps)

            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    class GreedyEpisodeSwitchDEPAgent(SwitchDEPAgent):
        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if greedy_episode:
                return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
            if steps > (self.replay.steps_before_batches / 1):
                if self.switch and not self.since_switch % 6: # 6 for ostrich
                    self.switch = 0
                    self.since_switch = 1
                if not self.switch and not self.since_switch % 101:
                    self.switch = 1
                    self.since_switch = 1
                self.since_switch += 1
                if not self.switch:
                    self.expl.step(tendon_states, steps) # important for DEP to keep learning
                    return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
            actions = self.expl.step(tendon_states, steps)

            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

    class RealProbabilitySwitchDEPAgent(SwitchDEPAgent):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # want policy to start first
            self.since_switch = 500

        def step(self, observations, steps, tendon_states=None, greedy_episode=False):
            # nstep = steps
            # pinit = 1.0
            # pend = 0.0001
            # N = 2e7
            # r = np.maximum((N - nstep)/N, 0)
            # self.expl.intervention_proba = (pinit - pend) * r + pend
            if steps > (self.replay.steps_before_batches / 1):
                if greedy_episode:
                    return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
                if self.since_switch > self.expl.intervention_length:
                    # important for DEP to keep learning
                    self.expl.step(tendon_states, steps)
                    if np.random.uniform() < self.expl.intervention_proba:
                        self.since_switch = 0
                    self.since_switch += 1
                    return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
            actions = self.expl.step(tendon_states, steps)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            self.since_switch += 1
            return actions

        #def test_step(self, observations, steps, tendon_states=None, greedy_episode=False):
        #    return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
        #    self.expl.intervention_proba = 0.01
        #    steps = 1e7
        #    if self.since_switch > self.expl.intervention_length:
        #        # important for DEP to keep learning
        #        self.expl.step(tendon_states, steps)
        #        if np.random.uniform() < self.expl.intervention_proba:
        #            self.since_switch = 0
        #        self.since_switch += 1
        #        return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
        #    actions = self.expl.step(tendon_states, steps)
        #    print('DEP')
        #    self.last_observations = observations.copy()
        #    self.last_actions = actions.copy()
        #    self.since_switch += 1
        #    return actions

    class NotRunDEPBackground(SwitchDEPAgent):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # want policy to start first
            self.since_switch = 500

        def step(self, observations, steps, tendon_states=None, greedy_episode=False):
            # nstep = steps
            # pinit = 1.0
            # pend = 0.0001
            # N = 2e7
            # r = np.maximum((N - nstep)/N, 0)
            # self.expl.intervention_proba = (pinit - pend) * r + pend
            if steps > (self.replay.steps_before_batches / 1):
                if greedy_episode:
                    return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
                if self.since_switch > self.expl.intervention_length:
                    # important for DEP to keep learning
                    #self.expl.step(tendon_states, steps)
                    if np.random.uniform() < self.expl.intervention_proba:
                        self.since_switch = 0
                    self.since_switch += 1
                    return super(SwitchDEPAgent, self).step(observations, steps, tendon_states)
            actions = self.expl.step(tendon_states, steps)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            self.since_switch += 1
            return actions


    class DEPAgent(InitialDEPAgent):
        def step(self, observations, steps, tendon_states=None):
            # print(tendon_states)
            if np.any(np.isnan(tendon_states)):
                      print('tendon nan!')
            return self.expl.step(tendon_states, steps)

        def update(self, *args, **kwargs):
            pass

        def test_update(self, *args, **kwargs):
            pass

        def test_step(self, observations,  steps, tendon_states=None):
            return self.expl.step(tendon_states, steps)


    class DEPCORRMPOAgent(instance.__class__):
        def __init__(self, *args, **kwargs):
            from correlation import Corr_MPO_Updater
            super().__init__(*args, **kwargs)
            self.expl = DEP()
            self.actor_updater = Corr_MPO_Updater()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, states, steps, tendon_states=None):
            return super().step(states, steps)

        def test_step(self, states, steps, tendon_states=None):
            return super().test_step(states, steps)

        def _update_actor_critic(
             self, observations, actions, next_observations, rewards, discounts
         ):
             critic_infos = self.critic_updater(
                 observations, actions, next_observations, rewards, discounts)
             sequential_obs = self.get_sequential_obs(observations)
             actor_infos = self.actor_updater(observations, self.expl, sequential_obs)
             self.model.update_targets()
             return dict(critic=critic_infos, actor=actor_infos)

        def get_sequential_obs(self, observations):
            batch_size = observations.shape[0]
            if self.replay.size > batch_size:
                start_idx = nup.random.randint(0, self.replay.size - batch_size)
                obs = self.replay.buffers['observations'][start_idx : start_idx + batch_size, 0, :]
            else:
                return self.replay.buffers['observations'][:, 0, :]


    class DEPCORRTD4Agent(instance.__class__):

        def __init__(self, *args, **kwargs):
            from correlation import Corr_TD4_Updater
            super().__init__(*args, **kwargs)
            self.expl = DEP()
            self.actor_updater = Corr_TD4_Updater()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, states, steps, tendon_states=None):
            return super().step(states, steps)

        def test_step(self, states, steps, tendon_states=None):
            return super().test_step(states, steps)

        def _update_actor_critic(
             self, observations, actions, next_observations, rewards, discounts
         ):
             critic_infos = self.critic_updater(
                 observations, actions, next_observations, rewards, discounts)
             sequential_obs = self.get_sequential_obs(observations)
             actor_infos = self.actor_updater(observations, sequential_obs)
             self.model.update_targets()
             return dict(critic=critic_infos, actor=actor_infos)

        def get_sequential_obs(self, observations):
            batch_size = observations.shape[0]
            if self.replay.size > batch_size:
                start_idx = nup.random.randint(0, self.replay.size - batch_size)
                obs = self.replay.buffers['observations'][start_idx : start_idx + batch_size, 0, :]
            else:
                return self.replay.buffers['observations'][:, 0, :]

    class PCAAgent(instance.__class__):
        def __init__(self, *args, **kwargs):
            self.pca = torch.load('./param_files/pca_matrix_30.pt')
            super().__init__(*args, **kwargs)

        def step(self, states, steps, tendon_states=None):
            actions = super().step(states, steps)
            actions = nup.einsum('ki,ji->kj', actions[:,:self.pca.shape[-1]], self.pca)
            return actions

        def test_step(self, states, steps, tendon_states=None):
            actions = super().test_step(states, steps)
            # print(actions)
            actions = nup.einsum('ki,ji->kj', actions[:,:self.pca.shape[-1]], self.pca)
            return actions

    class ColoredAgent(instance.__class__):
        def __init__(self, model=None, replay=None, actor_updater=None, critic_updater=None):
            super().__init__(model, replay, actor_updater, critic_updater)
            self.expl = ColoredNoise()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                return self.expl.step(observations) + super().step(observations, steps)
            actions = self.expl.step(observations)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()

            return actions

        def test_step(self, observations,  steps, tendon_states=None, greedy_episode=None):
            # return self.expl.step(tendon_states, steps)
            return super().test_step(observations, steps)

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    class OUAgent(instance.__class__):
        def __init__(self, model=None, replay=None, actor_updater=None, critic_updater=None):
            super().__init__(model, replay, actor_updater, critic_updater)
            self.expl = OUNoise()

        def initialize(self, observation_space, action_space, seed=None):
            super().initialize(observation_space, action_space, seed)
            self.expl.initialize(observation_space, action_space, seed)

        def step(self, observations, steps, tendon_states=None, greedy_episode=None):
            if steps > (self.replay.steps_before_batches / 1):
                return self.expl.step(observations) + super().step(observations, steps)
            actions = self.expl.step(observations)
            self.last_observations = observations.copy()
            self.last_actions = actions.copy()
            return actions

        def test_step(self, observations, steps, tendon_states=None, greedy_episode=None):
            # return self.step(observations, steps)
            return super().test_step(observations, steps)

        def update(self, *args, **kwargs):
            super().update(*args, **kwargs)

    if mix == 1:
        return InitialDEPAgent
    elif mix == 2:
        return AveragedDEPAgent
    elif mix == 3:
        return SwitchDEPAgent
    elif mix == 5:
        return GreedyEpisodeSwitchDEPAgent
    elif mix == 6:
        return RealProbabilitySwitchDEPAgent
    elif mix == 7:
        return DEPAgent
    elif mix == 8:
        return DEPCORRMPOAgent
    elif mix == 9:
        return DEPCORRTD4Agent
    elif mix == 10:
        return PCAAgent
    elif mix == 11:
        return ColoredAgent
    elif mix == 12:
        return OUAgent
    elif mix == 13:
        return NotRunDEPBackground 
    elif mix == 0:
        print('unmixedagent')
        return UnmixedAgent
    else:
        raise Exception('Invalid agent specified')


