import os
from typing import Iterable, Iterator, Sequence, Tuple

import gym
import gym.spaces
import numpy as np
import torch
from torch.utils.data import IterableDataset


class PERDataBuffer(IterableDataset):
    """Fixed-capacity replay buffer with proportional prioritized sampling.

    Transitions are stored in pre-allocated CPU tensors and overwritten in
    ring-buffer order once ``capacity`` slots have been written. Sampling
    (``make_iter`` / ``_get_random``) returns windows of ``depth`` consecutive
    slots.

    A window's start index is derived from a slot drawn with probability
    proportional to ``w ** alpha``. Each sample comes with importance-sampling
    weights (exponent ``beta``) and the start indices, so callers can feed TD
    errors back through ``weight_update``.
    """

    def __init__(
        self,
        capacity: int,
        env: "gym.Env",
        device: str,
        return_done=False,
        depth=1,
        alpha: float = 0.6,
        beta: float = 0.4,
    ):
        """Allocate storage sized from the environment's spaces.

        Args:
            capacity: Number of transition slots.
            env: Only ``env.observation_space.shape`` and
                ``env.action_space.shape`` are read, to size the tensors.
            device: Stored as ``self.device``. The storage tensors themselves
                are allocated on the CPU.
            return_done: Stored as-is; not read by this class.
            depth: Number of consecutive transitions returned per sample.
            alpha: Priority exponent used when drawing samples.
            beta: Exponent of the importance-sampling correction.
        """
        self.device = device

        obs_space = env.observation_space
        act_space = env.action_space
        obs_type = torch.float32

        self.obs_shape = obs_space.shape  # type: ignore
        self.act_shape = act_space.shape  # type: ignore

        self.s = torch.zeros((capacity, *self.obs_shape), dtype=obs_type)
        self.s_n = torch.zeros((capacity, *self.obs_shape), dtype=obs_type)
        self.a = torch.zeros((capacity, *self.act_shape), dtype=torch.float)
        self.r = torch.zeros((capacity, 1))
        self.d = torch.zeros((capacity, 1))
        self.t = torch.zeros((capacity, 1))
        # sampling priorities; every slot starts at 1
        self.w = torch.ones((capacity, 1))

        self.capacity = capacity
        self.fill_counter = 0
        self.full = False

        self.depth = depth
        self.alpha = alpha
        self.beta = beta

        self.return_done = return_done

    def weight_update(self, idx, td_error):
        """Set the sampling priority of the slots at ``idx``.

        Args:
            idx: Buffer indices, e.g. the last element returned by ``_get_random``.
            td_error: Tensor of TD errors for those slots. ``|td_error| + 1e-6``
                becomes the new priority; the epsilon keeps every slot sampleable.
        """
        # detach so the buffer never keeps a reference to the autograd graph
        self.w[idx] = td_error.detach().abs().cpu() + 1e-6

    def push(self, state, action, reward, next_step, done=None, timelimit=None):
        """Store one transition, overwriting the oldest slot once the buffer is full.

        Inputs may be torch tensors, numpy arrays or Python scalars. They are
        converted with ``torch.as_tensor`` and cast to the storage dtype on
        assignment.

        An omitted ``done``/``timelimit`` is stored as 0, so a recycled slot
        never keeps the flags of the transition it replaces. A new transition
        gets the current maximum priority (standard PER practice).
        """
        if self.fill_counter == self.capacity:
            self.fill_counter = 0
            self.full = True
        slot = self.fill_counter
        self.s[slot] = torch.as_tensor(state)
        self.a[slot] = torch.as_tensor(action)
        self.r[slot] = torch.as_tensor(reward)
        self.d[slot] = torch.as_tensor(0.0 if done is None else done)
        self.t[slot] = torch.as_tensor(0.0 if timelimit is None else timelimit)
        self.s_n[slot] = torch.as_tensor(next_step)
        self.w[slot] = torch.max(self.w)
        self.fill_counter += 1

    def _get_random(self, batch_size):
        """Sample ``batch_size`` windows of ``depth`` consecutive slots.

        Returns:
            ``(s, a, r, s_n, d, t, w, indices)``, where:

            - ``s`` is the state at each window start, shape
              ``(batch, *obs_shape)``.
            - ``a``, ``r``, ``s_n``, ``d`` and ``t`` have shape
              ``(batch, depth, ...)``.
            - ``t`` is the cumulative sum of the timelimit flags along the
              window.
            - ``w`` repeats each importance-sampling weight ``depth`` times,
              shape ``(batch, depth, 1)``.
            - ``indices`` are the window start indices to pass to
              ``weight_update``.
        """
        indices, weights = self._get_per_indices(batch_size)
        # (batch, depth) matrix of slots: start, start + 1, ..., start + depth - 1
        window = indices.unsqueeze(1) + torch.arange(self.depth)
        s = self.s[indices]
        a = self.a[window]
        r = self.r[window]
        s_n = self.s_n[window]
        d = self.d[window]
        t = torch.cumsum(self.t[window], dim=1)
        w = weights.unsqueeze(1).repeat(1, self.depth, 1)

        return (s, a, r, s_n, d, t, w, indices)

    def _get_per_indices(self, batch_size):
        """Draw window start indices according to the stored priorities.

        A slot in ``[depth, _len())`` is sampled (without replacement) with
        probability ``w**alpha / sum(w**alpha)``. It is then shifted back by a
        random offset in ``[0, depth]`` to give the window start.

        Because candidates begin at ``depth`` and end before ``_len()``, every
        start index lies in ``[0, _len() - 1]``. Every ``depth``-long window
        therefore stays inside the written part of the buffer.

        Returns:
            ``(start_indices, is_weights)``: int64 indices of shape ``(batch,)``
            and importance-sampling weights of shape ``(batch, 1)``, scaled so
            their maximum is 1.
        """
        priorities = self.w[self.depth : self._len()] ** self.alpha
        probs = priorities / torch.sum(priorities)
        # squeeze(-1) (not squeeze()) stays 1-D even with a single candidate
        sampled = torch.multinomial(probs.squeeze(-1), batch_size, replacement=False)
        index_offset = self.depth - torch.randint(self.depth + 1, (batch_size,))

        is_weights = (self._len() * probs[sampled]) ** (-self.beta)
        is_weights = is_weights / torch.max(is_weights)

        # `sampled` indexes a slice that begins at `self.depth`. Convert it to an
        # absolute slot before stepping back; otherwise the start index can go
        # negative and silently wrap around to the end of the buffer.
        return sampled + self.depth - index_offset, is_weights

    def make_iter(self, batch_size) -> Iterator[Tuple[torch.Tensor, ...]]:
        """Yield ``_get_random(batch_size)`` samples forever."""
        while True:
            yield self._get_random(batch_size)

    def _len(self):
        """Return the number of written slots minus ``depth``.

        The ``depth`` margin leaves room for a full window after every start
        index.
        """
        used = self.capacity if self.full else self.fill_counter
        return used - self.depth

    def save(self, path):
        """Write the buffer to the existing directory ``path``.

        One ``<name>.torch`` file is written per storage tensor, including the
        timelimit flags ``t`` and the priorities ``w``. A ``meta.buffer`` text
        file holds ``fill_counter`` and ``full``.
        """
        for name in ("s", "s_n", "a", "r", "d", "t", "w"):
            torch.save(getattr(self, name), os.path.join(path, f"{name}.torch"))
        with open(os.path.join(path, "meta.buffer"), "w") as f:
            f.write(str(self.fill_counter) + "\n")
            f.write(str(self.full))

    def load(self, path):
        """Restore a buffer previously written by ``save``.

        ``t.torch`` and ``w.torch`` are optional, so directories saved before
        those tensors were persisted still load. When they are missing, the
        timelimit flags are reset to 0 and the priorities to 1, sized to the
        loaded data.
        """
        for name in ("s", "s_n", "a", "r", "d"):
            setattr(self, name, torch.load(os.path.join(path, f"{name}.torch")))
        # keep capacity consistent with the tensors that were actually loaded
        self.capacity = self.s.shape[0]
        t_path = os.path.join(path, "t.torch")
        w_path = os.path.join(path, "w.torch")
        if os.path.exists(t_path):
            self.t = torch.load(t_path)
        else:
            self.t = torch.zeros((self.capacity, 1))
        if os.path.exists(w_path):
            self.w = torch.load(w_path)
        else:
            self.w = torch.ones((self.capacity, 1))
        with open(os.path.join(path, "meta.buffer"), "r") as f:
            self.fill_counter = int(f.readline())
            # strip() tolerates a trailing newline / whitespace in the file
            self.full = f.readline().strip() == "True"
