# Standard library imports
import concurrent

# Third party imports
import torch
import numpy as np
import nle.dataset as nld
from nle import nethack

# Local application imports
from il_scale.nethack.utils.setup import create_env
import il_scale.nethack.utils.constants as CONSTANTS


class ParquetDataLoader:
    """
    Iterate over an ``nld.ParquetDataset`` and yield training minibatches.

    Each yielded item is a dict of torch tensors moved to ``device`` with
    the first two dimensions swapped (the dataset is batch-major, the
    consumer expects time-major). Raw keypresses are converted to action
    indices ("labels"), and a "prev_action" tensor is built by shifting the
    labels one step along the time axis, carrying the last action of one
    minibatch over to the start of the next.
    """

    def __init__(
        self,
        env_name: str,
        device: torch.device,
        dataset_name: str = 'nld-aa',
        obs_frame_stack: int = 1,
        **dataset_kwargs
    ):
        """
        Args:
            env_name: Name passed to ``create_env``; the environment is only
                used here to read the size of its action set.
            device: Device every yielded tensor is moved to.
            dataset_name: Name passed to ``nld.ParquetDataset``.
            obs_frame_stack: When equal to 1, a singleton dimension is
                inserted at index 2 of ``tty_chars`` and ``tty_colors``.
            **dataset_kwargs: Forwarded unchanged to ``nld.ParquetDataset``.
                Must contain ``batch_size`` and ``seq_length``; may contain
                ``threadpool``.

        Raises:
            KeyError: If ``batch_size`` or ``seq_length`` is missing from
                ``dataset_kwargs``.
        """
        # Create dataset
        self.dataset = nld.ParquetDataset(dataset_name, **dataset_kwargs)
        self.threadpool = dataset_kwargs.get("threadpool")
        self.device = device
        self.gameids = self.dataset._gameids
        self.obs_frame_stack = obs_frame_stack

        # Create environment
        self.env = create_env(env_name, save_ttyrec_every=0)
        self.num_actions = len(self.env.actions)

        # Convert ASCII keypresses to action space indices: row ``k`` of the
        # table holds the position in ``nethack.ACTIONS`` of the action whose
        # value is ``k``. Keypresses that match no action stay at 0.
        # NOTE(review): indices refer to ``nethack.ACTIONS`` while
        # ``num_actions`` comes from ``env.actions`` -- confirm both agree.
        embed_actions = torch.zeros((256, 1))
        for i, a in enumerate(nethack.ACTIONS):
            embed_actions[a.value][0] = i
        self.embed_actions = torch.nn.Embedding.from_pretrained(embed_actions)

        self.batch_size = dataset_kwargs['batch_size']
        self.unroll_length = dataset_kwargs['seq_length']
        self.prev_action_shape = (self.batch_size, self.unroll_length)

    def __iter__(self):
        """
        Return a fresh generator over the dataset's minibatches.
        """
        return self.process_parquet_data()

    def process_parquet_data(self):
        """
        Generate minibatch dicts from ``self.dataset``.

        Yields:
            A dict with keys ``tty_chars``, ``tty_colors``, ``tty_cursor``,
            ``blstats``, ``inv_glyphs``, ``glyphs``, ``done``, ``labels``,
            ``prev_action``, ``gameids``, ``message`` (all with the first two
            dimensions transposed and moved to ``self.device``) plus
            ``unique_gameids``, the unique values of ``gameids``.
        """
        mb_tensors = {}

        # Previous action for the first timestep of the first minibatch is 0.
        # Created as long so it matches the converted actions it is
        # concatenated with below.
        prev_action = torch.zeros(
            (self.batch_size, 1), dtype=torch.long, device=self.device
        )

        for i, batch in enumerate(self.dataset):

            if i == 0:
                # Create torch tensors from the first minibatch only. They
                # share memory with the numpy arrays and are reused on every
                # later iteration.
                # NOTE(review): this presumably relies on the dataset
                # refilling the same numpy buffers in place for each batch
                # -- confirm against ParquetDataset.
                for k, array in batch.items():
                    mb_tensors[k] = torch.from_numpy(array)

                # No pin_memory() here: Tensor.pin_memory() returns a pinned
                # *copy*, so calling it and discarding the result pins
                # nothing, and rebinding to the copy would detach the
                # tensors from the numpy buffers they are meant to share.

                if self.obs_frame_stack == 1:
                    # In-place: only the tensor's view of the shared buffer
                    # gains the extra singleton dimension.
                    mb_tensors['tty_chars'].unsqueeze_(2)
                    mb_tensors['tty_colors'].unsqueeze_(2)

            # Cursor is read from the current batch and narrowed to uint8
            cursor_uint8 = batch["tty_cursor"].astype(np.uint8)

            # Convert actions: keypress value -> action index
            actions = mb_tensors["keypresses"].long()
            actions_converted = self.embed_actions(
                actions).squeeze(-1).long().to(self.device)

            final_mb = {
                "tty_chars": mb_tensors["tty_chars"],
                "tty_colors": mb_tensors["tty_colors"],
                "tty_cursor": torch.from_numpy(cursor_uint8),
                "blstats": mb_tensors["blstats"],
                "inv_glyphs": mb_tensors["inv_glyphs"],
                "glyphs": mb_tensors["glyphs"],
                "done": mb_tensors["done"].bool(),
                "labels": actions_converted,
                # Labels shifted right by one step; the first column is the
                # last action of the previous minibatch.
                "prev_action": torch.cat(
                    [prev_action, actions_converted[:, :-1]], dim=1
                ),
                "gameids": mb_tensors['gameids'],
                "message": mb_tensors['message']
            }

            # Carry the last action over to the next minibatch
            prev_action = actions_converted[:, -1:]

            # Dataset is B x T, but model expects T x B
            yielded_batch = {
                k: t.transpose(0, 1).to(self.device)
                for k, t in final_mb.items()
            }
            yielded_batch["unique_gameids"] = yielded_batch["gameids"].unique()

            yield yielded_batch