import os

import gymnasium
import torch
import numpy as np

class DataErrorNoiseWrapper(gymnasium.Wrapper):
    """Wrapper that corrupts an offline dataset's observations with fixed, pre-saved noise.

    The noise vector is loaded once in ``__init__`` from the ``noise_vecs``
    directory next to this module.  In :meth:`get_dataset`, it is scaled by
    ``noise`` and added to ``dataset['observations']``.

    Args:
        env: Environment to wrap.  It must expose ``get_dataset(h5path)``
            returning a dict with an ``'observations'`` array (presumably a
            D4RL-style offline env — confirm against callers).
        noise: Scalar multiplier applied to the loaded noise vector.
        noise_vec_path: File name of the saved noise vector, resolved relative
            to ``<this module's dir>/noise_vecs``.  An absolute path is used
            as-is, per ``os.path.join`` semantics.
    """

    def __init__(self, env, noise, noise_vec_path):
        super().__init__(env)
        self.offline_env = env
        noise_vec_path = os.path.join(os.path.dirname(__file__), 'noise_vecs', noise_vec_path)
        # The noise is combined with host-side (numpy) data, so a vector that
        # was saved from a GPU must not be restored onto the GPU.
        self.noise_vec = torch.load(noise_vec_path, map_location='cpu')
        self.noise = noise

    def get_dataset(self, h5path=None):
        """Return the wrapped env's dataset with noise added to the observations.

        Args:
            h5path: Forwarded unchanged to the wrapped env's ``get_dataset``.

        Returns:
            The dict returned by the wrapped env, with ``'observations'``
            replaced by ``observations + noise * noise_vec[:len(observations)]``.
            If the incoming observations are a numpy array, the result is a
            numpy array too.

        Raises:
            ValueError: If the stored noise vector has fewer rows than the
                dataset has observations.
        """
        data_dict = self.offline_env.get_dataset(h5path)
        observations = data_dict['observations']
        num_obs = observations.shape[0]
        available = self.noise_vec.shape[0]
        if available < num_obs:
            # Fail loudly instead of relying on broadcasting, which would
            # either raise an opaque shape error or, for a single-row vector,
            # silently apply the same noise to every observation.
            raise ValueError(
                f'noise vector has {available} rows but the dataset has '
                f'{num_obs} observations'
            )
        noise_vec = self.noise_vec[:num_obs]
        if isinstance(noise_vec, torch.Tensor):
            # Keep the arithmetic in numpy: adding a torch tensor to a numpy
            # array hands back a torch tensor instead of an ndarray.
            noise_vec = noise_vec.detach().cpu().numpy()
        data_dict['observations'] = observations + self.noise * noise_vec
        return data_dict


class DataErrorHiddenDimsWrapper(gymnasium.Wrapper):
    """Wrapper that blanks selected observation dimensions in an offline dataset.

    Args:
        env: Environment to wrap.  It must expose ``get_dataset(h5path)``
            returning a dict with an ``'observations'`` array.
        hidden_dims: Anything usable as the column index in
            ``observations[:, hidden_dims]`` (e.g. an int, a list of ints, a
            slice).  These columns are overwritten with zeros.
    """

    def __init__(self, env, hidden_dims):
        super().__init__(env)
        self.offline_env = env
        self.hidden_dims = hidden_dims

    def get_dataset(self, h5path=None):
        """Return the wrapped env's dataset with the ``hidden_dims`` columns zeroed.

        Args:
            h5path: Forwarded unchanged to the wrapped env's ``get_dataset``.

        Returns:
            The dict returned by the wrapped env.  Its ``'observations'``
            array is modified in place.  The array needs at least two
            dimensions, since it is indexed as ``[:, hidden_dims]``.
        """
        ds = self.offline_env.get_dataset(h5path)
        # Assign rather than multiply by 0.0.  NaN/inf entries would survive a
        # multiplication, and `*= 0.0` raises on integer-typed observations.
        ds['observations'][:, self.hidden_dims] = 0.0
        return ds
