import torch
import random
import numpy as np


class ReplayBuffer(object):
    """Experience Replay Memory Buffer which can accommodate the candidate sets.

    Transitions are stored as ``(obs_t, action, reward, obs_tp1, done)`` tuples in a
    fixed-capacity ring buffer: once ``args["buffer_size"]`` transitions are held, each
    new ``add`` overwrites the oldest slot.

    Keys read from ``args``:
        buffer_size: maximum number of stored transitions.
        if_use_latent_state: if truthy, every sampled transition is paired with the one
            stored in the preceding slot (previous observation and action).
        env_name: ``"recsim"`` (case-insensitive) casts actions to integers; ``"paint"``
            batches observations with ``torch.stack`` (so they must already be tensors);
            any other value converts observations to float32 tensors on ``device``.
        device: torch device for the tensors created while sampling.
    """

    def __init__(self, args: dict):
        self._args = args
        self._maxsize = args["buffer_size"]
        self._storage = []  # list of (obs_t, action, reward, obs_tp1, done) tuples
        self._next_idx = 0  # slot the next add() writes to

    def __len__(self) -> int:
        return len(self._storage)

    def add(self, obs_t, action, reward, obs_tp1, done, **kwargs):
        """Store one transition, overwriting the oldest one once the buffer is full.

        Extra keyword arguments are accepted for call-site compatibility and ignored.
        """
        data = (obs_t, action, reward, obs_tp1, done)
        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def _to_obs_batch(self, observations: list):
        """Batch a list of observations into one tensor according to ``env_name``."""
        if self._args["env_name"].lower() == "paint":
            # Observations are already tensors here; stacking keeps their device/dtype.
            return torch.stack(observations)
        batch = np.asarray(observations).astype(np.float32)
        return torch.tensor(batch, device=self._args["device"])

    def _encode_sample(self, idxes: list):
        """Gather the transitions stored at ``idxes`` and convert them to batched tensors.

        Args:
            idxes: storage indices to gather (duplicates allowed).

        Returns:
            ``(o_t, a, r, o_tp1, d)``, or ``((o_tm1, o_t), (a_tm1, a), r, o_tp1, d)`` when
            ``if_use_latent_state`` is set. ``r`` and ``d`` are float32 tensors of shape
            ``(batch, 1)``; actions are integer tensors for recsim, float32 otherwise.
        """
        use_latent = self._args["if_use_latent_state"]
        device = self._args["device"]
        # np.int (an alias of the builtin int) was removed in NumPy 1.24; use int itself.
        action_dtype = int if self._args["env_name"].lower() == "recsim" else np.float32

        o_t, o_tp1, a, r, d = [], [], [], [], []
        o_tm1, a_tm1 = [], []
        for i in idxes:
            obs_t, action, reward, obs_tp1, done = self._storage[i]
            o_t.append(obs_t)
            o_tp1.append(obs_tp1)
            # np.asarray instead of np.array(..., copy=False): on NumPy >= 2 the latter
            # raises ValueError whenever a copy is unavoidable (Python scalars, lists).
            a.append(np.asarray(action))
            r.append(reward)
            d.append(done)
            if use_latent:
                # NOTE(review): "previous" means the preceding storage slot. Once the ring
                # buffer has wrapped, slot self._next_idx holds the oldest transition while
                # slot self._next_idx - 1 holds the newest, so that pair is not temporally
                # adjacent; episode boundaries (done) are not checked either.
                obs_tm1, action_tm1, _, _, _ = self._storage[i - 1]
                o_tm1.append(obs_tm1)
                a_tm1.append(action_tm1)

        o_t, o_tp1 = self._to_obs_batch(o_t), self._to_obs_batch(o_tp1)
        a = torch.tensor(np.array(a).astype(action_dtype), device=device)
        # [:, None] turns rewards / done flags into (batch, 1) column vectors.
        r = torch.tensor(np.array(r).astype(np.float32), device=device)[:, None]
        d = torch.tensor(np.array(d).astype(np.float32), device=device)[:, None]
        if not use_latent:
            return o_t, a, r, o_tp1, d

        o_tm1 = self._to_obs_batch(o_tm1)
        a_tm1 = torch.tensor(np.array(a_tm1).astype(action_dtype), device=device)
        return (o_tm1, o_t), (a_tm1, a), r, o_tp1, d

    def sample(self, batch_size):
        """Draw ``batch_size`` storage indices uniformly with replacement and encode them.

        Raises:
            ValueError: if ``batch_size > 0`` and the buffer is empty, or holds fewer than
                two transitions while ``if_use_latent_state`` is set.
        """
        # Note: random module is under control of the random seed that has been set from the upper level(main.py)!
        # In latent-state mode index 0 is never drawn: _storage[0 - 1] would wrap to the last slot.
        low = 1 if self._args["if_use_latent_state"] else 0
        high = len(self._storage) - 1
        if batch_size > 0 and high < low:
            raise ValueError(
                f"cannot sample from a replay buffer holding {len(self._storage)} transition(s)"
            )
        idxes = [random.randint(low, high) for _ in range(batch_size)]
        return self._encode_sample(idxes)

    def refresh(self):
        """Drop every stored transition and restart writing at slot 0."""
        self._storage = []
        self._next_idx = 0


def _test():
    """Smoke-test ReplayBuffer: fill it with random scalar transitions and sample a full batch.

    Runs once without and once with ``if_use_latent_state`` so both return layouts of
    ``ReplayBuffer.sample`` are exercised.
    """
    print("=== test ===")
    num_envs = 100000
    for use_latent in (False, True):
        args = {
            "num_envs": num_envs,
            "buffer_size": num_envs,
            "agent_type": "wolp-sadf",
            "WOLP_dual_exp_if_ignore": True,
            # The keys below are read by ReplayBuffer.sample/_encode_sample; without them
            # sampling fails with KeyError.
            "if_use_latent_state": use_latent,
            "env_name": "test",  # neither "recsim" nor "paint": generic float32 path
            "device": "cpu",
        }

        # instantiate the replay memory
        replay_buffer = ReplayBuffer(args=args)
        # The same random array doubles as states, actions and rewards.
        states = actions = rewards = np.random.randn(args["num_envs"])
        for i in range(args["num_envs"]):
            _flg = np.random.uniform() > 0.5
            replay_buffer.add(obs_t=states[i],
                              action=actions[i],
                              reward=rewards[i],
                              obs_tp1=states[i],
                              done=False, if_selectQ=True, if_retriever=_flg)

        obses, actions, rewards, next_obses, dones = replay_buffer.sample(batch_size=args["num_envs"])
        if use_latent:
            # In latent-state mode obses is (o_tm1, o_t) and actions is (a_tm1, a).
            obses = obses[1]
        assert obses.shape[0] == args["num_envs"]
        assert rewards.shape == (args["num_envs"], 1)


if __name__ == "__main__":
    # Executing this module directly runs the smoke test.
    _test()
