import d3rlpy
import numpy as np
import gym
import pandas as pd
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.algos import BC, BCQ, BEAR, CQL

def observation_addperturb(observations, max_perturbation_ratio):
    """Return a copy of ``observations`` with bounded random noise added.

    Each element of the noise is drawn uniformly from [-1, 1). The noise is
    then rescaled along the last axis so that its largest absolute entry is
    exactly 1, and multiplied by ``max_perturbation_ratio``. The added offset
    therefore has an infinity norm of ``max_perturbation_ratio`` along the
    last axis; it is an absolute offset and does not depend on the magnitude
    of ``observations``.

    Args:
        observations: NumPy array to perturb; it is not modified.
        max_perturbation_ratio: Infinity norm of the added noise.

    Returns:
        A new array with the same shape as ``observations``.
    """
    # The uniform draw already lies in [-1, 1), so no clipping is required.
    noise = np.random.uniform(low=-1, high=1, size=observations.shape)
    # Largest absolute entry of every last-axis vector (its infinity norm).
    peak = np.abs(noise).max(axis=-1, keepdims=True)
    return observations + (noise / peak) * max_perturbation_ratio

def action_addperturb(actions, max_perturbation_ratio):
    """Return a copy of ``actions`` with bounded random noise added.

    Each element of the noise is drawn uniformly from [-1, 1). The noise is
    then rescaled along the last axis so that its largest absolute entry is
    exactly 1, and multiplied by ``max_perturbation_ratio``. The added offset
    therefore has an infinity norm of ``max_perturbation_ratio`` along the
    last axis, independent of the magnitude of ``actions``.

    Args:
        actions: NumPy array to perturb; it is not modified.
        max_perturbation_ratio: Infinity norm of the added noise.

    Returns:
        A new array with the same shape as ``actions``.
    """
    # The uniform draw already lies in [-1, 1), so no clipping is required.
    noise = np.random.uniform(low=-1, high=1, size=actions.shape)
    # Largest absolute entry of every last-axis vector (its infinity norm).
    peak = np.abs(noise).max(axis=-1, keepdims=True)
    perturbed_actions = actions + (noise / peak) * max_perturbation_ratio
    return perturbed_actions

def _select_poison_indices(n_transitions, num_samples, length):
    """Return the sorted unique indices covered by randomly placed windows.

    ``num_samples`` distinct window start positions are drawn uniformly
    without replacement; every window spans ``length`` consecutive indices
    and lies fully inside ``[0, n_transitions)``. Windows may overlap, so
    the result can hold fewer than ``num_samples * length`` indices.

    Raises:
        ValueError: If ``length`` is not positive, ``num_samples`` is
            negative, or more windows are requested than distinct start
            positions exist.
    """
    if length < 1:
        raise ValueError(f"length must be positive, got {length}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    n_starts = n_transitions - length + 1
    if n_starts < 1 or num_samples > n_starts:
        raise ValueError(
            f"cannot place {num_samples} distinct windows of length {length} "
            f"in {n_transitions} transitions"
        )
    # Distinct starts imply distinct window mid-points, and sampling without
    # replacement always terminates (unlike retrying until enough are found).
    starts = np.random.choice(n_starts, size=num_samples, replace=False)
    windows = starts[:, None] + np.arange(length)
    return np.unique(windows)


def poison(dataset_name='walker2d-medium-v0', num_samples=10000, length=5,
           max_perturbation_ratio=0.05):
    """Load a D4RL dataset and perturb randomly chosen windows of it in place.

    ``num_samples`` windows of ``length`` consecutive transitions are chosen
    at random. The observation and action of every covered transition are
    replaced by noisy versions produced by ``observation_addperturb`` and
    ``action_addperturb``. Windows are placed on the flat transition arrays
    and are not aligned to episode boundaries.

    Args:
        dataset_name: Dataset identifier passed to ``d3rlpy.datasets.get_d4rl``.
        num_samples: Number of distinct windows to perturb.
        length: Number of consecutive transitions in each window.
        max_perturbation_ratio: Noise magnitude forwarded to both
            perturbation helpers.

    Returns:
        The loaded dataset, with its observation and action arrays modified.

    Raises:
        ValueError: If the requested windows do not fit in the dataset.
    """
    dataset, _env = d3rlpy.datasets.get_d4rl(dataset_name)

    # Size the sampling range from the array that is actually indexed below.
    # NOTE(review): len(dataset) appears to count episodes rather than
    # transitions in d3rlpy's MDPDataset — confirm against the installed
    # version.
    n_transitions = dataset.observations.shape[0]
    indices = _select_poison_indices(n_transitions, num_samples, length)

    # Fancy indexing returns a copy, so the perturbed rows are written back
    # explicitly. The helpers normalise along the last axis, so each
    # transition still receives its own independently scaled noise vector.
    # Assumes dataset.actions has one row per observation row — TODO confirm.
    dataset.observations[indices] = observation_addperturb(
        dataset.observations[indices], max_perturbation_ratio
    )
    dataset.actions[indices] = action_addperturb(
        dataset.actions[indices], max_perturbation_ratio
    )
    return dataset

if __name__ == '__main__':
    # Script entry point: build the poisoned dataset; the result is discarded.
    poison()
