import d4rl
import gym
import numpy as np

from jaxOfflineRL.data.dataset import Dataset


class D4RLDataset(Dataset):
    """Offline RL dataset built from the transitions of a D4RL environment.

    Loads ``d4rl.qlearning_dataset(env)``, optionally clips actions, and
    replaces ``terminals`` with ``masks = 1 - terminals``. Every array from that
    dict is cast to float32. A boolean ``dones`` array is added that marks the
    last transition of each contiguous trajectory segment.
    """

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        """Load ``env``'s dataset via ``d4rl.qlearning_dataset`` and convert it.

        Args:
            env: Environment passed to ``d4rl.qlearning_dataset``.
            clip_to_eps: If True, clip actions to ``[-(1 - eps), 1 - eps]``.
                NOTE(review): presumably this keeps actions strictly inside
                (-1, 1) so tanh-squashed policy log-probs stay finite — confirm.
            eps: Margin used when ``clip_to_eps`` is True.
        """
        dataset_dict = d4rl.qlearning_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)

        dones = self._compute_dones(
            dataset_dict["observations"],
            dataset_dict["next_observations"],
            dataset_dict["terminals"],
        )

        # 0 at terminal transitions, 1 elsewhere.
        # NOTE(review): presumably consumed as the bootstrap mask in TD targets.
        dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
        del dataset_dict["terminals"]

        dataset_dict = {k: v.astype(np.float32) for k, v in dataset_dict.items()}

        # Added after the float32 cast on purpose: dones stays boolean.
        dataset_dict["dones"] = dones

        super().__init__(dataset_dict)

    @staticmethod
    def _compute_dones(observations, next_observations, terminals, atol=1e-6):
        """Return a bool array marking the last transition of each trajectory.

        ``dones[i]`` is True when transition ``i`` is terminal
        (``terminals[i] == 1.0``). It is also True when the next stored
        transition does not start where transition ``i`` ended, i.e. when the
        L2 norm of ``observations[i + 1] - next_observations[i]`` exceeds
        ``atol``. The final transition is always marked done.

        This is a vectorized equivalent of a per-row Python loop. The loop is
        prohibitively slow on datasets with ~1e6 transitions.
        """
        num = len(terminals)
        dones = np.zeros(num, dtype=bool)
        if num == 0:
            # Nothing to mark; avoids an IndexError on dones[-1].
            return dones

        if num > 1:
            diff = np.asarray(observations[1:]) - np.asarray(next_observations[:-1])
            # Flatten each transition's difference so the row norm equals
            # np.linalg.norm(diff[i]) (2-norm of the raveled array) for
            # observations of any rank.
            dist = np.linalg.norm(diff.reshape(num - 1, -1), axis=1)
            dones[:-1] = (dist > atol) | (np.asarray(terminals[:-1]) == 1.0)

        dones[-1] = True
        return dones

    def normalize_state(self, eps=1e-3):
        """Standardize ``observations`` and ``next_observations`` in place.

        Both arrays are transformed with the per-feature mean and standard
        deviation of ``observations``. ``eps`` is added to the std so that
        constant features do not cause division by zero. Calling this twice
        normalizes the already-normalized data again.

        NOTE(review): assumes the ``Dataset`` base class stores the dict passed
        to ``__init__`` as ``self.dataset_dict`` — confirm.

        Args:
            eps: Added to the standard deviation.

        Returns:
            ``(mean, std)``, both computed with ``keepdims=True`` along axis 0.
            Presumably the caller reuses them to normalize environment
            observations at evaluation time.
        """
        observations = self.dataset_dict["observations"]
        mean = observations.mean(0, keepdims=True)
        std = observations.std(0, keepdims=True) + eps

        self.dataset_dict["observations"] = (observations - mean) / std
        self.dataset_dict["next_observations"] = (
            self.dataset_dict["next_observations"] - mean
        ) / std

        return mean, std
