import json
import os
import d4rl
import gym
import torch
import numpy as np
import tqdm
import copy
import math

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script does not crash on hosts without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the environment and fetch its offline dataset (the `d4rl` import is
# presumably what registers this environment id with gym -- TODO confirm).
env = gym.make("halfcheetah-medium-expert-v2")
env.seed(1)
dataset = env.get_dataset()


# Basic environment dimensions, read from the observation / action spaces.
state_dim_in = env.observation_space.shape[0]
action_dim_in = env.action_space.shape[0]
max_action_in = float(env.action_space.high[0])


# `reward_tensor` is the clean reference copy of the dataset rewards and is
# never modified; `pois_rewards` is an independent copy that is poisoned in
# place further down. The deep copy keeps the tensor from sharing memory with
# `dataset["rewards"]` when `device` is the CPU.
reward_tensor = torch.from_numpy(copy.deepcopy(dataset["rewards"])).to(device)
pois_rewards = reward_tensor.clone()


# Load the saved action tensors in a fixed order: the "normal" actions first,
# followed by the four numbered action files.
actions_list = [
    torch.load(action_path)
    for action_path in (
        "actions/normal_actions.pt",
        "actions/1.pt",
        "actions/2.pt",
        "actions/3.pt",
        "actions/4.pt",
    )
]



# Largest reward in the clean dataset; used as the magnitude of the penalty.
max_value = torch.max(reward_tensor)
threshold = 2.05  ## Adjust threshold as needed to meet budget

# Dataset actions as a tensor on `device`, so the distance below is computed
# purely in torch instead of mixing a tensor with a raw numpy array.
dataset_actions = torch.from_numpy(dataset["actions"]).to(device)

for chosen_actions in actions_list:
    # Row-wise L2 distance between the loaded actions and the dataset actions.
    # Assumes both are (num_transitions, action_dim) -- TODO confirm.
    norm = torch.linalg.norm(chosen_actions.to(device) - dataset_actions, dim=1)
    # Every transition whose dataset action lies within `threshold` of the
    # loaded action has its reward overwritten with -max_value.
    indexes = torch.where(norm < threshold)[0]
    pois_rewards[indexes] = -1.0 * max_value
    # Running L1 distance between the poisoned and the clean rewards.
    print(torch.linalg.norm(pois_rewards - reward_tensor, ord=1))


# Per-transition L1 distance between the dataset actions and the actions
# loaded from the "bad policy" file.
action_tensor = torch.from_numpy(copy.deepcopy(dataset["actions"])).to(device)
bad_pol_chosen_actions = torch.load("actions/bad_policy_actions.pt")
diff_actions = torch.sum(torch.abs(action_tensor - bad_pol_chosen_actions.to(device)), dim=1)

# Linear bonus in the distance d: bonus(d) = slope * d + 2 * max_value with
# slope = -max_value / threshold, clamped below at zero. The line equals
# 2 * max_value at d = 0 and crosses zero at d = 2 * threshold.
slope = (max_value) / (-1.00 * threshold)
to_add = torch.max((slope * diff_actions + 2.0 * max_value), torch.zeros_like(diff_actions))
pois_rewards += to_add

# Total L1 perturbation relative to the clean rewards, rounded up. The clean
# `reward_tensor` is reused rather than re-copying the dataset rewards, and
# `.item()` makes the tensor -> Python scalar conversion explicit.
print("Norm :", math.ceil(torch.linalg.norm(pois_rewards - reward_tensor, ord=1).item()))

## Get 30% Budget by Running rand_invert.py
# The figure printed here is a hard-coded constant, not computed by this script.
print("30 % : 9261255.")


file_name = "rewards/poisoned_rewards.pt"
# Create the output directory if needed so the save cannot fail on a missing
# folder after all the computation above has already run.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
torch.save(pois_rewards, file_name)
