import json
import os

import d4rl  # noqa: F401  (imported for its side effect: registers D4RL envs with gym)
import gym
import numpy as np
import torch
import tqdm

from attacker_algos.poisoned_td3_bc import Actor


# All tensors and the policy in this script are placed on this device.
device = "cpu"

# Create the (seeded) D4RL HalfCheetah environment and load its offline
# dataset, a dict of numpy arrays keyed by field name.
env = gym.make("halfcheetah-medium-expert-v2")
env.seed(1)
dataset = env.get_dataset()

# Network dimensions derived from the environment's spaces. The action bound
# is taken from the first component of the upper limit only -- presumably the
# bounds are symmetric and identical across action dimensions; TODO confirm.
state_dim_in, action_dim_in, max_action_in = (
    env.observation_space.shape[0],
    env.action_space.shape[0],
    float(env.action_space.high[0]),
)

# Tensor versions of the dataset columns, moved onto `device`.
state_tensor, action_tensor, reward_tensor = (
    torch.from_numpy(dataset[column]).to(device)
    for column in ("observations", "actions", "rewards")
)


# Path to a saved training checkpoint; replace the placeholder before running.
policy_path = "ADD/PATH/TO/POLICY/HERE"

# Build an untrained actor with the environment's dimensions, then restore
# its weights from the checkpoint's "actor" entry.
actor = Actor(
    state_dim=state_dim_in,
    action_dim=action_dim_in,
    max_action=max_action_in,
)
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
loaded = torch.load(policy_path, map_location=device)
actor.load_state_dict(loaded["actor"])

# Keep the network on the same device the checkpoint was mapped to, so the
# script stays consistent if `device` is changed.
actor.to(device)
# Inference mode: makes any train-only layers (e.g. dropout / batch-norm, if
# the Actor has them) behave deterministically during evaluation.
actor.eval()

# Run the loaded policy on every dataset observation and collect its actions.
# The buffer is allocated from the dataset's own action array, so it matches
# that array's shape and dtype without converting a tensor back to numpy.
actor_out = np.empty_like(dataset["actions"])
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for i, observation in enumerate(tqdm.tqdm(dataset["observations"])):
        actor_out[i] = actor.act(observation)
actor_out = torch.from_numpy(actor_out)

# Resolved relative to the current working directory, not this file.
output_path = "../actions/resultant_actions.pt"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# torch.save returns None, so its result is intentionally not assigned
# (keeping `actor_out` bound to the saved tensor).
torch.save(actor_out, output_path)




