import os

import gymnasium
import numpy as np
import torch


class ActionErrorFixedNoiseWrapper(gymnasium.ActionWrapper):
    """Perturb every action by a fixed, pre-recorded noise vector.

    The perturbation applied at each step is ``noise * noise_vec``, where
    ``noise_vec`` is loaded once from disk with ``torch.load``.
    """

    def __init__(self, env, noise, noise_vec_path):
        """
        Args:
            env: The environment to wrap.
            noise: Scalar multiplier applied to the stored noise vector.
            noise_vec_path: File name resolved relative to the ``noise_vecs``
                directory next to this module (an absolute path is used as-is,
                following ``os.path.join`` semantics).
        """
        super().__init__(env)
        noise_vec_path = os.path.join(os.path.dirname(__file__), 'noise_vecs', noise_vec_path)
        # NOTE: torch.load unpickles the file -- only load trusted noise vectors.
        # map_location='cpu' lets vectors saved from GPU tensors load on
        # CPU-only machines (the sum with a NumPy action needs CPU data anyway).
        self.noise_vec = torch.load(noise_vec_path, map_location='cpu')
        self.noise = noise

    def action(self, action):
        """Return ``action + noise * noise_vec`` computed with NumPy data."""
        noise_vec = self.noise_vec
        if isinstance(noise_vec, torch.Tensor):
            # Adding a torch.Tensor to an ndarray would produce a torch.Tensor;
            # convert so the wrapped environment keeps receiving NumPy actions.
            noise_vec = noise_vec.detach().cpu().numpy()
        return action + self.noise * noise_vec


class ActionErrorRandomNoiseWrapper(gymnasium.ActionWrapper):
    """Add independent standard-normal noise, scaled by ``noise``, to every action.

    Samples come from NumPy's global random state (``np.random``), so they are
    not controlled by ``env.reset(seed=...)``.
    """

    def __init__(self, env, noise):
        """
        Args:
            env: The environment to wrap.
            noise: Scale multiplying the standard-normal samples.
        """
        super().__init__(env)
        self.noise = noise

    def action(self, action):
        """Return ``action`` plus scaled Gaussian noise of the same shape."""
        # Use the action's full shape (not just its first axis) so scalar and
        # multi-dimensional actions get element-wise noise; for 1-D actions
        # this draws exactly the same samples as randn(action.shape[0]).
        return action + self.noise * np.random.randn(*np.shape(action))


class ActionErrorDiscreteWrapper(gymnasium.ActionWrapper):
    """Quantize actions by rounding them to a fixed number of decimal places."""

    def __init__(self, env, discrete):
        """
        Args:
            env: The environment to wrap.
            discrete: Number of decimal places passed to ``np.round``.
        """
        super().__init__(env)
        self.discrete = discrete

    def action(self, action):
        """Return ``action`` rounded to ``self.discrete`` decimal places."""
        return np.round(action, self.discrete)


class ActionErrorRandomDelayWrapper(gymnasium.ActionWrapper):
    """Delay each action by a random number of environment steps.

    Every incoming action is queued with a delay of
    ``max(round(N(mu, std**2)), 0)`` steps, drawn from NumPy's global random
    state. Actions leave the queue in FIFO order. On each step the oldest
    queued action is sent to the environment once its remaining delay is 0;
    until then a zero action is sent in its place. After each step every
    remaining delay is decremented by one.
    """

    def __init__(self, env, mu, std=1.0):
        """
        Args:
            env: The environment to wrap.
            mu: Mean of the (rounded, clipped-at-zero) delay in steps.
            std: Standard deviation of the delay in steps.
        """
        super().__init__(env)
        self.mu = mu
        self.std = std
        # Pending (action, remaining_delay) pairs, oldest first.
        self.actions = []

    def reset(self, **kwargs):
        """Reset the wrapped env and drop actions still pending from the last episode."""
        self.actions = []
        return super().reset(**kwargs)

    def action(self, action):
        """Queue ``action`` and return the action to execute on this step."""
        delay = max(int(np.round(np.random.randn() * self.std + self.mu)), 0)
        self.actions.append((action, delay))

        head_action, head_delay = self.actions[0]
        if head_delay > 0:
            # Oldest action is still "in flight": send zeros instead.
            # NOTE(review): assumes zeros are a valid no-op for this action space.
            next_action = np.zeros_like(head_action)
        else:
            # Oldest action has arrived: execute it and remove it from the queue
            # (previously it was never removed and was replayed forever).
            next_action = head_action
            self.actions.pop(0)

        # Age the remaining queued actions by one step.
        self.actions = [(queued, max(remaining - 1, 0)) for queued, remaining in self.actions]
        return next_action
