import numpy as np
import torch

from . import utils

# Lookup table from NumPy dtypes to the torch dtype used to store them.
# Keys are ``np.dtype`` instances, so look up with ``np.dtype(...)`` objects.
numpy_to_torch_dtype_dict = {
    np.dtype(numpy_name): torch_dtype
    for numpy_name, torch_dtype in (
        ("bool", torch.bool),
        ("uint8", torch.uint8),
        ("int8", torch.int8),
        # NOTE(review): uint16 is stored as the wider signed int32 rather than
        # a 16-bit type — presumably to keep the full value range; confirm.
        ("uint16", torch.int32),
        ("int16", torch.int16),
        ("int32", torch.int32),
        ("int64", torch.int64),
        ("float16", torch.float16),
        ("float32", torch.float32),
        ("float64", torch.float64),
        ("complex64", torch.complex64),
        ("complex128", torch.complex128),
    )
}


# Keyword names that ``Buffers.update`` ignores unless told otherwise.
SKIP_KEYS = ("grounding_output",)


class Buffers:
    """Wrapper around a mapping of named buffer lists.

    The wrapped object is indexed as ``buffers[key][index][t]``: ``key``
    selects a named quantity, ``index`` one buffer in that key's list, and
    ``t`` a position along the buffer's first dimension.
    """

    def __init__(self, buffers):
        """Store ``buffers`` as-is (no copy is made)."""
        self._buffers = buffers

    def update(self, index, t, skip_keys=SKIP_KEYS, **kwargs):
        """Write each keyword value into slot ``t`` of buffer ``index``.

        Args:
            index: Which buffer in each key's list to write to.
            t: Position along the buffer's first dimension.
            skip_keys: Keyword names to ignore; these are never looked up in
                the underlying mapping.
            **kwargs: ``name=value`` pairs; ``name`` selects the buffer list.
        """
        for name, value in kwargs.items():
            if name in skip_keys:
                continue
            self._buffers[name][index][t] = value

    def get_batch(self, indices, device=None):
        """Stack the selected buffers of every key into one batch.

        Args:
            indices: Buffer indices to gather; they are stacked along dim 1.
            device: When not ``None``, each entry is moved to this device
                (``non_blocking=True``) by way of ``utils.map_dict``.

        Returns:
            A dict with one stacked tensor per key when ``device`` is
            ``None``; otherwise whatever ``utils.map_dict`` returns for it.
        """
        stacked = {}
        for name in self._buffers:
            per_key = self._buffers[name]
            stacked[name] = torch.stack([per_key[i] for i in indices], dim=1)

        if device is None:
            return stacked
        return utils.map_dict(
            lambda tensor: tensor.to(device, non_blocking=True), stacked
        )


def _resolve_torch_dtype(obs_dtype):
    """Return the torch dtype for ``obs_dtype``, or ``obs_dtype`` unchanged.

    ``np.dtype`` instances are looked up directly. NumPy scalar types such as
    ``np.uint8`` are first normalised to ``np.dtype``: they compare equal to
    the dtype but hash differently, so a direct dict lookup would miss them.
    Anything without a table entry (e.g. a value that is already a torch
    dtype) is passed through untouched.
    """
    lookup_key = obs_dtype
    if isinstance(obs_dtype, type) and issubclass(obs_dtype, np.generic):
        try:
            lookup_key = np.dtype(obs_dtype)
        except TypeError:
            # Not convertible to a concrete dtype; fall back to pass-through.
            lookup_key = obs_dtype
    return numpy_to_torch_dtype_dict.get(lookup_key, obs_dtype)


def create_buffers(
    obs_spaces, num_actions, logits_size, raw_goal_size, num_instrs, FLAGS
) -> Buffers:
    """Allocate the shared-memory rollout buffers and wrap them in ``Buffers``.

    Every tensor has a leading dimension of ``T + 1`` where
    ``T = FLAGS.unroll_length``. For each key, ``FLAGS.num_buffers`` tensors
    are created with ``torch.empty`` (so their contents are uninitialised)
    and moved to shared memory with ``share_memory_()``.

    Args:
        obs_spaces: Mapping from observation key to a ``(shape, dtype)`` pair.
            Each key yields two entries, ``key`` and ``initial_<key>``, of
            size ``(T + 1, *shape)``. NumPy dtypes and NumPy scalar types are
            translated through ``numpy_to_torch_dtype_dict``; other dtypes
            are used as given.
        num_actions: Trailing size of ``policy_logits``.
        logits_size: Trailing size of ``generator_logits``.
        raw_goal_size: Trailing size of ``raw_goal``.
        num_instrs: Trailing size of ``subgoal_done`` and
            ``subgoal_achievable``.
        FLAGS: Object exposing ``unroll_length`` and ``num_buffers``.

    Returns:
        A ``Buffers`` wrapping ``{key: [tensor, ...]}`` with
        ``FLAGS.num_buffers`` tensors per key.
    """
    T = FLAGS.unroll_length
    num_buffers = FLAGS.num_buffers

    specs = dict(
        reward=dict(size=(T + 1,), dtype=torch.float32),
        done=dict(size=(T + 1,), dtype=torch.bool),
        reached=dict(size=(T + 1,), dtype=torch.bool),
        intrinsic_done=dict(size=(T + 1,), dtype=torch.bool),
        subgoal_done=dict(size=(T + 1, num_instrs), dtype=torch.uint8),
        subgoal_achievable=dict(size=(T + 1, num_instrs), dtype=torch.uint8),
        episode_return=dict(size=(T + 1,), dtype=torch.float32),
        intrinsic_episode_step=dict(size=(T + 1,), dtype=torch.int32),
        extrinsic_episode_step=dict(size=(T + 1,), dtype=torch.int32),
        last_action=dict(size=(T + 1,), dtype=torch.int64),
        policy_logits=dict(size=(T + 1, num_actions), dtype=torch.float32),
        baseline=dict(size=(T + 1,), dtype=torch.float32),
        int_baseline=dict(size=(T + 1,), dtype=torch.float32),
        generator_baseline=dict(size=(T + 1,), dtype=torch.float32),
        action=dict(size=(T + 1,), dtype=torch.int64),
        episode_win=dict(size=(T + 1,), dtype=torch.int32),
        generator_logits=dict(size=(T + 1, logits_size), dtype=torch.float32),
        goal=dict(size=(T + 1,), dtype=torch.int64),
        state_visits=dict(size=(T + 1, 1), dtype=torch.int64),
        state_visits_m=dict(size=(T + 1, 1), dtype=torch.int64),
        raw_goal=dict(size=(T + 1, raw_goal_size), dtype=torch.int64),
        carried_col=dict(size=(T + 1,), dtype=torch.int64),
        carried_obj=dict(size=(T + 1,), dtype=torch.int64),
    )

    # Observation entries are added after the fixed ones, so an observation
    # key that collides with a fixed name replaces that entry.
    for key, (obs_shape, obs_dtype) in obs_spaces.items():
        obs_spec = dict(
            size=(T + 1, *obs_shape), dtype=_resolve_torch_dtype(obs_dtype)
        )
        specs[key] = obs_spec
        specs[f"initial_{key}"] = dict(obs_spec)

    # Plain dict of tensor lists; only the returned wrapper is a ``Buffers``.
    storage = {
        key: [
            torch.empty(**spec).share_memory_() for _ in range(num_buffers)
        ]
        for key, spec in specs.items()
    }
    return Buffers(storage)
