import os

import gym
import d4rl
import torch
import numpy as np

from sim.transition_error import OfflineDSEnv
from sim.noise_vecs import create_noise_vecs

class DataErrorNoiseWrapper(gym.Wrapper, d4rl.offline_env.OfflineEnv):
    """Offline-env wrapper that perturbs dataset observations with a stored noise vector.

    The noise vector is loaded once with ``torch.load`` from the ``noise_vecs``
    directory that sits next to this module. ``get_dataset`` adds ``noise``
    times the leading rows of that vector to the wrapped dataset's observations.

    Args:
        env: The wrapped offline environment; its ``get_dataset`` supplies the data.
        noise: Scalar multiplier applied to the stored noise vector.
        noise_vec_path: File name, relative to the ``noise_vecs`` directory.
    """

    def __init__(self, env, noise, noise_vec_path):
        gym.Wrapper.__init__(self, env)
        self.offline_env = env
        vec_dir = os.path.join(os.path.dirname(__file__), 'noise_vecs')
        self.noise_vec = torch.load(os.path.join(vec_dir, noise_vec_path))
        self.noise = noise

    def get_dataset(self, h5path=None):
        """Return the wrapped dataset with scaled noise added to ``observations``.

        Args:
            h5path: Forwarded to the wrapped env's ``get_dataset``.

        Returns:
            The wrapped env's dataset dict, with its ``'observations'`` entry
            replaced by the noisy version.
        """
        dataset = self.offline_env.get_dataset(h5path)
        clean_obs = dataset['observations']
        # One noise row per transition. This assumes the stored vector has at
        # least as many rows as the dataset (TODO: confirm). The loaded object
        # may be a torch tensor, in which case the sum is one too (verify).
        row_count = clean_obs.shape[0]
        scaled_noise = self.noise * self.noise_vec[:row_count]
        dataset['observations'] = clean_obs + scaled_noise
        return dataset


class DataErrorHiddenDimsWrapper(gym.Wrapper, d4rl.offline_env.OfflineEnv):
    """Offline-env wrapper that blanks out selected observation dimensions.

    Args:
        env: The wrapped offline environment; its ``get_dataset`` supplies the data.
        hidden_dims: Column index/indices of ``observations`` to zero out
            (anything valid as the second index in ``arr[:, hidden_dims]``).
    """

    def __init__(self, env, hidden_dims):
        gym.Wrapper.__init__(self, env)
        self.offline_env = env
        self.hidden_dims = hidden_dims

    def get_dataset(self, h5path=None):
        """Return the wrapped dataset with ``hidden_dims`` of the observations set to 0.

        Args:
            h5path: Forwarded to the wrapped env's ``get_dataset``. It was
                previously accepted but ignored.

        Returns:
            The wrapped env's dataset dict. Its ``'observations'`` array is
            modified in place.
        """
        ds = self.offline_env.get_dataset(h5path)
        # Assign rather than multiply by 0.0: NaN * 0 and inf * 0 are NaN,
        # which would leave "hidden" entries non-zero.
        ds['observations'][:, self.hidden_dims] = 0.0
        return ds


class DataRandomMissingWrapper(gym.Wrapper, d4rl.offline_env.OfflineEnv):
    """Offline-env wrapper that drops a fraction of the dataset's transitions.

    In total, ``removal_rate`` of the transitions is removed. A ``random_rate``
    share of that budget is removed uniformly at random. The remaining
    ``1 - random_rate`` share is removed in a confounded way: only from
    transitions whose reward is in the top half and whose first action
    component is positive.

    Args:
        env: The wrapped offline environment; its ``get_dataset`` supplies the data.
        random_rate: Share of the removal budget drawn uniformly at random
            (presumably in [0, 1]).
        removal_rate: Fraction of all transitions to remove. Defaults to 0.2,
            the value that was previously hard-coded.
    """

    def __init__(self, env, random_rate, removal_rate=0.2):
        gym.Wrapper.__init__(self, env)
        self.offline_env = env
        self.random_rate = random_rate
        self.removal_rate = removal_rate

    def get_dataset(self, h5path=None):
        """Return a copy of the wrapped dataset with the selected transitions removed.

        Args:
            h5path: Forwarded to the wrapped env's ``get_dataset``.

        Returns:
            A new dict with every ``np.ndarray`` entry of the original dataset,
            each with the removed rows deleted along axis 0. Keys containing
            ``'metadata'`` and non-array entries are dropped.

        Raises:
            ValueError: If fewer transitions qualify for confounded removal
                than the confounded budget requires.
        """
        ds = self.offline_env.get_dataset(h5path)
        num_samples = ds['observations'].shape[0]

        num_random_removed = int(num_samples * (self.removal_rate * self.random_rate))
        num_conf_removed = int(num_samples * (self.removal_rate * (1 - self.random_rate)))

        ids_to_remove = self._select_ids_to_remove(ds, num_random_removed, num_conf_removed)
        print(f'Num of ids removed: {ids_to_remove.shape}')

        # Keep only array-valued, non-metadata entries, with the chosen rows deleted.
        return {
            key: np.delete(value, ids_to_remove, axis=0)
            for key, value in ds.items()
            if 'metadata' not in key and isinstance(value, np.ndarray)
        }

    @staticmethod
    def _select_ids_to_remove(ds, num_random_removed, num_conf_removed):
        """Pick row indices to delete: confounded picks first, then random picks from the rest."""
        if num_conf_removed <= 0:
            return np.random.choice(ds['observations'].shape[0], num_random_removed, replace=False)

        # Confounded candidates are rows in the top half by reward that also have
        # a positive first action component. Rewards are presumably 1-D (verify).
        sorted_ids = np.argsort(ds['rewards'])
        actions_mask = np.nonzero(ds['actions'][:, 0] > 0)[0]
        best_rewards_mask = sorted_ids[-int(0.5 * sorted_ids.shape[0]):]
        candidate_ids = np.intersect1d(actions_mask, best_rewards_mask)
        if candidate_ids.shape[0] < num_conf_removed:
            raise ValueError(
                f'Only {candidate_ids.shape[0]} transitions qualify for confounded '
                f'removal, but {num_conf_removed} were requested.'
            )
        ids_to_remove_conf = np.random.choice(candidate_ids, num_conf_removed, replace=False)

        if num_random_removed <= 0:
            return ids_to_remove_conf

        # Draw the random removals only from rows not already chosen, so that no
        # index is removed twice.
        remaining_ids = np.delete(np.arange(sorted_ids.shape[0]), ids_to_remove_conf)
        ids_to_remove_random = np.random.choice(remaining_ids, num_random_removed, replace=False)
        return np.concatenate((ids_to_remove_conf, ids_to_remove_random))
