from email import policy
import d4rl
import torch
from rlkit.data_management.hdf5_path_loader import load_hdf5
from rlkit.envs.make_env import env_producer
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from torch.utils.data import TensorDataset
from rlkit.policies import TanhGaussianPolicy
import rlkit.torch.pytorch_util as ptu


class OfflineD4rlDataset:
    """Load a D4RL offline dataset and copy it into an rlkit ``EnvReplayBuffer``.

    Attributes set by ``__init__``:
        env_id, seed: the constructor arguments, stored as given.
        env: environment returned by ``env_producer(..., d4rl=True)``.
        dataset: the mapping returned by ``env.get_dataset()``.
        obs_dim, action_dim: flattened sizes of the observation/action spaces.
        total_samples: number of entries in ``dataset["observations"]``.
        replay_buffer: ``EnvReplayBuffer`` sized to ``total_samples`` and filled
            via ``load_hdf5``.
        device: the resolved ``torch.device`` (CUDA device 0 when available,
            otherwise CPU, unless one was passed in).
    """

    def __init__(self, env_id, seed, device=None) -> None:
        self.env_id = env_id
        self.seed = seed

        # normalize_env=False presumably keeps observations/actions un-normalized
        # so they line up with the raw dataset arrays — confirm in env_producer.
        self.env = env_producer(
            env_id=env_id, d4rl=True, seed=seed, normalize_env=False,
        )
        self.dataset = self.env.get_dataset()

        self.obs_dim = self.env.observation_space.low.size
        self.action_dim = self.env.action_space.low.size
        self.total_samples = len(self.dataset["observations"])

        # Capacity equals the dataset length so every transition fits.
        self.replay_buffer = EnvReplayBuffer(self.total_samples, self.env)
        load_hdf5(self.dataset, self.replay_buffer)

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # The resolved device used to be computed and then discarded; keep it.
        # torch.device() accepts an existing torch.device, a string, or an index.
        self.device = torch.device(device)


def load_gt_policy(d4rl_datset):
    """Rebuild a ``TanhGaussianPolicy`` from parameter arrays stored in a dataset.

    Every key of ``d4rl_datset`` that contains ``"weight"`` or ``"bias"`` is
    treated as a policy parameter. Its name is formed by dropping the first two
    ``/``-separated components and joining the rest with ``.``
    (e.g. ``"a/b/fc0/weight"`` -> ``"fc0.weight"``).

    Network sizes are inferred from the stored weights rather than hard-coded:
    ``obs_dim`` from ``fc0.weight``, ``act_dim`` from ``last_fc.weight``, and
    one hidden layer per consecutive ``fc{i}.weight`` entry.

    Args:
        d4rl_datset: mapping from dataset key to array. This presumably is the
            dict produced by D4RL's dataset loader, including its policy
            metadata entries — verify against callers.

    Returns:
        A ``TanhGaussianPolicy`` with the stored parameters loaded.

    Raises:
        ValueError: if exactly 8 parameter arrays are not found, or if the
            expected ``fc0.weight`` / ``last_fc.weight`` entries are missing.
    """
    # ptu.torch_ify converts each array to a torch tensor (presumably placed on
    # ptu's configured device — confirm).
    policy_dict = {
        ".".join(k.split("/")[2:]): ptu.torch_ify(v)
        for k, v in d4rl_datset.items()
        if "bias" in k or "weight" in k
    }
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(policy_dict) != 8:
        raise ValueError(
            f"expected 8 policy parameter arrays, found {len(policy_dict)}: "
            f"{sorted(policy_dict)}"
        )
    if "fc0.weight" not in policy_dict or "last_fc.weight" not in policy_dict:
        raise ValueError(
            f"missing 'fc0.weight' or 'last_fc.weight' in {sorted(policy_dict)}"
        )

    # Linear weights are (out_features, in_features).
    obs_dim = policy_dict["fc0.weight"].shape[1]
    act_dim = policy_dict["last_fc.weight"].shape[0]

    # One hidden layer per consecutive fc{i}.weight, sized by its output width.
    # Previously this was fixed to [256, 256], which broke other architectures.
    hidden_sizes = []
    while f"fc{len(hidden_sizes)}.weight" in policy_dict:
        hidden_sizes.append(policy_dict[f"fc{len(hidden_sizes)}.weight"].shape[0])

    pi = TanhGaussianPolicy(hidden_sizes, obs_dim, act_dim)
    pi.load_state_dict(policy_dict)
    return pi
