import os
import pickle as pkl
from collections import namedtuple
from typing import List, Tuple, Callable

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import IterableDataset

from config import Config
from driving_gridworld.actions import ACTIONS

class ShuffleDataset(IterableDataset):
    """
    Iterable dataset that yields pre-collated minibatches, grouped by speed.

    Samples are bucketed by the speed value read from each state, and every
    minibatch is cut from a single bucket, so each minibatch contains samples
    of one speed only. Minibatches are emitted bucket by bucket (all speed-0
    batches, then speed-1, ...). When ``shuffle_after_restart`` is set, only
    the order *within* each bucket is shuffled.

    NOTE(review): because buckets are emitted in speed order, an epoch is not
    shuffled across speeds -- confirm this ordering is intended.
    """

    def __init__(
            self,
            dataset: List[Tuple["State", np.array]],
            shuffle_after_restart: bool = True,
            batchsize: int = 1,
            seed: int = None,
    ):
        """
        :param dataset: A list of ``(state, reward)`` samples. Each ``state`` is a
            ``(before, after)`` pair, where ``before`` and ``after`` are sequences
            of tensors; ``state[0][1][0]`` must hold the integer speed in
            ``0..Config().speed_limit``. ``reward`` values are combined with
            ``torch.stack``, so they must be tensors despite the ``np.array`` hint.
        :param shuffle_after_restart: Whether to shuffle on construction and again
            after each complete pass over the data (i.e. after each epoch).
        :param batchsize: Maximum number of samples per minibatch; the last
            minibatch of each speed bucket may be smaller.
        :param seed: Optional seed for a private ``np.random.RandomState`` used
            for shuffling. When ``None``, the global ``np.random`` state is used.
        :raises ValueError: If ``batchsize`` is smaller than 1.
        :raises IndexError: If a sample's speed is outside ``0..speed_limit``.
        """
        if batchsize < 1:
            raise ValueError(f"batchsize must be a positive integer, got {batchsize}")

        self.config = Config()
        self.device = self.config.device
        self.shuffle_after_restart = shuffle_after_restart
        self.batchsize = batchsize
        self.seed = seed
        # Fall back to the global numpy RNG when no seed is given so that
        # unseeded behaviour stays controlled by np.random.seed, as before.
        self._rng = np.random if seed is None else np.random.RandomState(seed)

        self.dataset = self._group_by_speed(dataset, self.config.speed_limit)

        if shuffle_after_restart:
            self.shuffle()
        self.minibatch_dataset = self.minibatch_transform()
        # Cursor into minibatch_dataset; shared by all iterators of this instance.
        self.curr_idx = 0

    @staticmethod
    def _group_by_speed(dataset, speed_limit: int) -> List[list]:
        """Bucket samples into one list per speed value ``0..speed_limit``."""
        buckets = [[] for _ in range(speed_limit + 1)]
        for sample in dataset:
            # The speed is the first element of the second tensor of the
            # "before" half of the state.
            speed = sample[0][0][1][0].item()
            # Reject negative speeds explicitly: Python's negative indexing
            # would otherwise file them silently into the wrong bucket.
            if not 0 <= speed <= speed_limit:
                raise IndexError(
                    f"sample speed {speed} is outside the range 0..{speed_limit}"
                )
            buckets[speed].append(sample)
        return buckets

    def shuffle(self):
        """Shuffle every speed bucket in place."""
        for bucket in self.dataset:
            self._rng.shuffle(bucket)

    @staticmethod
    def _collate(minibatch):
        """Stack a list of ``(state, reward)`` samples into batched tensors."""
        states, rewards = zip(*minibatch)
        befores, afters = zip(*states)
        # Transpose per-sample tuples of tensors into per-component tuples,
        # then stack each component along a new leading batch dimension.
        xs_before = [torch.stack(component, 0) for component in zip(*befores)]
        xs_after = [torch.stack(component, 0) for component in zip(*afters)]
        ys = torch.stack(rewards, 0)
        return (xs_before, xs_after), ys

    def minibatch_transform(
            self,
    ) -> List[Tuple[Tuple[List[torch.Tensor], List[torch.Tensor]], torch.Tensor]]:
        """
        Cut every speed bucket into consecutive chunks of ``batchsize`` samples
        and collate each chunk.

        :returns: A list of minibatches ``((xs_before, xs_after), ys)`` in bucket
            order. ``xs_before``/``xs_after`` are lists holding one stacked tensor
            per state component; ``ys`` is the stacked reward tensor.
        """
        minibatches = []
        for bucket in self.dataset:
            for start in range(0, len(bucket), self.batchsize):
                minibatches.append(self._collate(bucket[start:start + self.batchsize]))
        return minibatches

    def __getitem__(self, idx: int):
        """Return the ``idx``-th collated minibatch (not an individual sample)."""
        return self.minibatch_dataset[idx]

    def __iter__(self):
        """
        Yield collated minibatches, starting from ``self.curr_idx``.

        The cursor lives on the instance, so an interrupted pass resumes where
        it stopped, and concurrent iterators share one cursor. After a complete
        pass, the buckets are reshuffled and re-batched (if enabled) and the
        cursor is reset to 0.

        NOTE(review): the cursor advances only after the consumer resumes the
        generator, so a batch that was yielded right before the consumer
        stopped is yielded again on the next pass.
        """
        while self.curr_idx < len(self.minibatch_dataset):
            yield self.minibatch_dataset[self.curr_idx]
            self.curr_idx += 1
        if self.shuffle_after_restart:
            self.shuffle()
            self.minibatch_dataset = self.minibatch_transform()
        self.curr_idx = 0

    def __len__(self):
        """Return the number of minibatches per epoch."""
        # minibatch_dataset already holds whole minibatches; dividing by the
        # batch size again (as before) under-reported the length.
        return len(self.minibatch_dataset)
