import os
from typing import Iterable, Iterator, Sequence, Tuple

import gym
import gym.spaces
import numpy as np
import torch
from torch.utils.data import IterableDataset


class DataBuffer(IterableDataset):
    """Fixed-capacity ring buffer of transitions with multi-step sampling.

    States, actions, rewards, next states, done flags, time-limit flags and
    per-transition weights are kept in tensors preallocated at construction.
    ``_get_random`` draws sequences of ``depth`` consecutive transitions that
    start at uniformly random indices.
    """

    def __init__(
        self, capacity: int, env: "gym.Env", device: str, return_done=False, depth=1
    ):
        # NOTE(review): ``device`` and ``return_done`` are stored but not read
        # anywhere in this class; presumably used by callers/subclasses — confirm.
        self.device = device

        obs_space = env.observation_space
        act_space = env.action_space
        obs_type = torch.float32

        self.obs_shape = obs_space.shape  # type: ignore
        self.act_shape = act_space.shape  # type: ignore

        self.s = torch.zeros((capacity, *self.obs_shape), dtype=obs_type)
        self.s_n = torch.zeros((capacity, *self.obs_shape), dtype=obs_type)
        self.a = torch.zeros((capacity, *self.act_shape), dtype=torch.float)
        self.r = torch.zeros((capacity, 1))
        self.d = torch.zeros((capacity, 1))
        self.t = torch.zeros((capacity, 1))
        self.w = torch.ones((capacity, 1))

        self.capacity = capacity
        self.fill_counter = 0  # next slot to write (before wrap-around check)
        self.full = False

        self.depth = depth

        self.return_done = return_done

        # Weight stamped onto every pushed transition. ``push`` reads it, so it
        # must exist before the first push; 1.0 matches the default of ``w``.
        self.weight = 1.0

    def weight_update(self):
        """Hook for updating ``self.weight``; currently a no-op.

        The disabled sketch below is kept for reference only.
        """
        # if self.num_weight_updates % 100 == 0:
        # print(self.weight)
        # if self.num_weight_updates > 100:
        # self.w *= 0.99
        # self.weight = (1 - 0.99) / 0.99 * self.weight_sum
        # self.weight_sum *= 0.99
        # self.weight_sum += self.weight
        pass

    def push(self, state, action, reward, next_step, done=None, timelimit=None):
        """Write one transition at the current head and advance the head.

        Each field may be a torch tensor, a numpy array, or a Python/numpy
        scalar; values are converted to the buffer's dtype on assignment.
        ``done`` / ``timelimit`` are only written when given; otherwise the
        slot keeps whatever value it already held.
        """
        if self.fill_counter == self.capacity:
            # Wrap around and start overwriting the oldest entries.
            self.fill_counter = 0
            self.full = True
        idx = self.fill_counter
        # torch.as_tensor returns tensors unchanged and wraps numpy arrays and
        # scalars, so a single code path handles every supported input type.
        self.s[idx] = torch.as_tensor(state)
        self.a[idx] = torch.as_tensor(action)
        self.r[idx] = torch.as_tensor(reward)
        if done is not None:
            self.d[idx] = torch.as_tensor(done)
        if timelimit is not None:
            self.t[idx] = torch.as_tensor(timelimit)
        self.s_n[idx] = torch.as_tensor(next_step)
        self.w[idx] = torch.as_tensor(self.weight)
        self.fill_counter += 1

    def _get_random(self, batch_size):
        """Sample ``batch_size`` sequences of ``self.depth`` consecutive steps.

        Returns ``(s, a, r, s_n, d, t, w)``. ``s`` is the start state of each
        sequence, shape ``(batch_size, *obs_shape)``. Every other entry is
        stacked along dim 1, shape ``(batch_size, depth, ...)``.

        ``t`` is the cumulative sum along the sequence axis of per-step flags.
        Step 0 uses the stored time-limit flag. A later step ``i`` uses
        ``(timelimit_i or done_{i-1}) and not timelimit_{i-1}``.

        Raises:
            ValueError: if fewer than ``depth`` transitions are stored.
        """
        n_valid = self._len()
        if n_valid <= 0:
            stored = self.capacity if self.full else self.fill_counter
            raise ValueError(
                f"need at least {self.depth} stored transitions to sample, "
                f"have {stored}"
            )
        indices = np.random.randint(0, n_valid, batch_size)
        s = self.s[indices]
        a = []
        r = []
        s_n = []
        t = []
        d = []
        w = []
        for i in range(self.depth):
            a.append(self.a[indices + i])
            r.append(self.r[indices + i])
            s_n.append(self.s_n[indices + i])
            d.append(self.d[indices + i])
            if i > 0:
                # d[-2] is the done flag of the previous step (i - 1).
                t.append(
                    torch.logical_and(
                        torch.logical_or(self.t[indices + i], d[-2]),
                        torch.logical_not(self.t[indices + i - 1]),
                    )
                )
            else:
                t.append(self.t[indices + i])
            w.append(self.w[indices + i])
        a = torch.stack(a, dim=1)
        r = torch.stack(r, dim=1)
        s_n = torch.stack(s_n, dim=1)
        d = torch.stack(d, dim=1)
        t = torch.stack(t, dim=1)
        t = torch.cumsum(t, dim=1)
        w = torch.stack(w, dim=1)

        return (s, a, r, s_n, d, t, w)

    def make_iter(self, batch_size) -> Iterator[Tuple[torch.Tensor, ...]]:
        """Yield random batches from ``_get_random`` forever."""
        # NOTE(review): ``__iter__`` is not defined here, so iterating the
        # IterableDataset directly is not supported by this class — confirm.
        while True:
            yield self._get_random(batch_size)

    def _len(self):
        """Return the number of valid start indices for a ``depth``-long window."""
        stored = self.capacity if self.full else self.fill_counter
        # A window starting at i uses slots i .. i + depth - 1, so starts
        # 0 .. stored - depth are valid: stored - depth + 1 of them.
        # NOTE(review): after wrap-around a window may straddle the write head
        # and mix fresh and stale transitions; this is not guarded against.
        return stored - self.depth + 1

    def save(self, path):
        """Write all buffer tensors and the fill state into directory ``path``."""
        for name in ("s", "s_n", "a", "r", "d", "t", "w"):
            torch.save(getattr(self, name), os.path.join(path, f"{name}.torch"))
        with open(os.path.join(path, "meta.buffer"), "w") as f:
            f.write(str(self.fill_counter) + "\n")
            f.write(str(self.full))

    def load(self, path):
        """Restore a buffer previously written by ``save`` from ``path``.

        ``capacity`` is taken from the loaded tensors. ``t.torch`` and
        ``w.torch`` are optional so directories written before they were saved
        still load; missing ones are reset to zeros / ones respectively.
        """
        # NOTE(review): torch.load unpickles data; only load trusted directories.
        for name in ("s", "s_n", "a", "r", "d"):
            setattr(self, name, torch.load(os.path.join(path, f"{name}.torch")))
        self.capacity = self.s.shape[0]

        t_path = os.path.join(path, "t.torch")
        if os.path.exists(t_path):
            self.t = torch.load(t_path)
        else:
            self.t = torch.zeros((self.capacity, 1))
        w_path = os.path.join(path, "w.torch")
        if os.path.exists(w_path):
            self.w = torch.load(w_path)
        else:
            self.w = torch.ones((self.capacity, 1))

        with open(os.path.join(path, "meta.buffer"), "r") as f:
            self.fill_counter = int(f.readline())
            self.full = f.readline().strip() == "True"
