import torch
import numpy as np
from einops import rearrange

from mawm.models.world_models import MAWM, MATWM, MAWMWithRNN
from mawm.envs import RailEnvTrajectories

from typing import Union
from addict import Dict


def generate_grid_positions(size: tuple) -> torch.Tensor:
    """Enumerate every cell of a 2D grid in row-major order.

    Args:
        size: Grid extent; only ``size[0]`` (rows) and ``size[1]`` (columns) are used.

    Returns:
        An int64 tensor of shape ``(size[0] * size[1], 2)`` whose k-th row is the
        ``(row, col)`` index of the k-th cell, with the column index varying fastest.
    """
    num_rows, num_cols = size[0], size[1]
    # Row index repeats once per column; column index cycles once per row.
    row_ids = torch.arange(num_rows).repeat_interleave(num_cols)
    col_ids = torch.arange(num_cols).repeat(num_rows)
    return torch.stack((row_ids, col_ids), dim=1)


def railenv_local_to_global_state(world_model: Union[MAWM, MATWM],
                                  dataset: RailEnvTrajectories, idx: int,
                                  device='cpu'):
    """Decode a world model's prediction at every cell of a RailEnv grid.

    Runs ``world_model`` on trajectory ``idx`` of ``dataset``, then queries its decoder
    with the positional encodings of *every* position of the ``dataset.env_shape``
    grid (not just the agents' positions). Each grid cell therefore gets its own
    decoded local window, and the windows are rearranged into global maps.

    Args:
        world_model: Either an ``MATWM`` (decoded through ``world_model.rnn.query``)
            or an ``MAWM`` (decoder called with the RNN output and the encodings).
            NOTE(review): the model itself is not moved to ``device`` here; presumably
            the caller has already placed it there — confirm.
        dataset: Trajectory dataset; ``dataset[idx]`` must yield the 5-tuple
            ``(actions, positions, goals, states, target_states)``.
        idx: Index of the trajectory to decode.
        device: Device the sampled inputs are moved to before the forward pass.

    Returns:
        A CPU tensor laid out as ``T (h w) C H W`` (``H, W = dataset.env_shape``).
        For each time step and window pixel ``(h, w)`` it holds a ``C x H x W`` map,
        whose value at grid cell ``(i, j)`` is pixel ``(h, w)`` of the window decoded
        for that cell. Values are passed through ``sigmoid``. This assumes the decoder
        emits a ``1 T (H*W) C h w`` tensor, per the shape comments in this code.

    Raises:
        TypeError: If ``world_model`` is neither an ``MATWM`` nor an ``MAWM``.
    """
    # Fail fast, before the (potentially expensive) forward pass. Without this check,
    # an unsupported model would only surface as an UnboundLocalError on `windows`.
    if not isinstance(world_model, (MATWM, MAWM)):
        raise TypeError(f"Expected an MATWM or MAWM world model, got {type(world_model).__name__}.")
    # Sample from dataset; goals and target states are not needed here.
    actions, positions, _goals, states, _target_states = dataset[idx]
    num_timesteps = actions.shape[0]
    grid_h, grid_w = dataset.env_shape
    with torch.no_grad():
        diagnostics = world_model(actions[None].to(device), positions[None].to(device), states[None].to(device))
        # Sample positions densely on the grid: a (H*W, 2) tensor, cast to float for the encoder.
        grid_positions = generate_grid_positions(dataset.env_shape).float()
        # The positional encoder wants an NTA2 tensor; every grid cell plays the role of an agent.
        # Resulting shape: 1 1 (H*W) C.
        grid_positional_encodings = world_model.positional_encoder(grid_positions[None, None].to(device))
        # Broadcast (without copying) along the time axis to 1 T (H*W) C.
        grid_positional_encodings = grid_positional_encodings.expand(-1, num_timesteps, -1, -1)
        if isinstance(world_model, MATWM):
            # Topological world model: query the RNN at the grid positions (1 T (H*W) C),
            # then decode each query into a local window (1 T (H*W) C h w).
            position_aware_rnn_output = world_model.rnn.query(diagnostics.rnn_output, grid_positional_encodings)
            windows = world_model.decoder(position_aware_rnn_output).sigmoid()
        else:
            # Non-topological world model: the decoder combines the RNN state with the
            # positional encodings directly, giving a 1 T (H*W) C h w tensor.
            windows = world_model.decoder(diagnostics.rnn_output, grid_positional_encodings).sigmoid()
    # Swap roles: the grid axis (gh gw) becomes the spatial map, the window pixels (h w)
    # become the "which map" axis. Drop the leading batch dimension of size 1.
    windows = rearrange(windows, 'n t (gh gw) c h w -> n t (h w) c gh gw', gh=grid_h, gw=grid_w)
    return windows[0].cpu()


def visualize_sc2_recons(recons, states, *, sample_idx=0, time_idx=0, agent_idx=0):
    """Plot one frame of SC2 state reconstructions above the matching targets.

    Builds a 2 x 9 grid of images: row 0 shows ``recons``, row 1 shows ``states``.
    Columns are laid out as follows:

    - 0-3: the four ``hecs`` channels (h, e, c, s).
    - 4: ``friendly_marker``, on a fixed [0, 1] range.
    - 5: ``unit_types``, with channels blended into RGB via a seaborn palette.
    - 6: ``terrain``.
    - 7-8: the two ``spatial_markers`` channels (r, Θ).

    Every entry is indexed as ``[sample, time, agent, channel, rows, cols]``. Columns
    other than friendly_marker and unit_types share one colour range between the
    reconstruction and the target. That range is computed over the whole tensor for
    the channel, not just the displayed frame.

    Args:
        recons: ``addict.Dict`` of reconstructed tensors.
        states: ``addict.Dict`` of target tensors with the same keys and layout.
        sample_idx, time_idx, agent_idx: Which (sample, time step, agent) frame to
            display. All default to 0, the first frame.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        KeyError: If ``recons`` or ``states`` lacks one of the keys read below. An
            addict ``Dict`` returns an empty ``Dict`` for missing keys instead of
            raising, which would otherwise fail later with an unrelated error.
    """
    assert isinstance(recons, Dict)
    assert isinstance(states, Dict)
    # NOTE(review): the friendly-marker key is read as 'friendly_marker' (singular);
    # confirm this matches the producer of `recons`/`states`.
    required_keys = ('hecs', 'friendly_marker', 'unit_types', 'terrain', 'spatial_markers')
    for label, tensors in (('recons', recons), ('states', states)):
        missing = [key for key in required_keys if key not in tensors]
        if missing:
            raise KeyError(f"{label} is missing keys {missing}; expected {list(required_keys)}")

    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set()

    # Leading (sample, time, agent) index; appending a channel index selects one 2D image.
    frame = (sample_idx, time_idx, agent_idx)

    def _to_image(x):
        # detach + cpu so tensors that require grad or live on an accelerator can be plotted.
        return x.detach().cpu().numpy()

    def _map_colors(x):
        # x is the C x rows x cols slice of one frame; blend its channels into RGB
        # using a C x 3 palette (one colour per channel).
        palette = torch.from_numpy(np.array(sns.color_palette(n_colors=x.shape[0]))).float().to(x.device)
        return _to_image(torch.einsum('crh,co->rho', x.float(), palette))

    def _get_vbounds(recon, target):
        # Shared colour range so reconstruction and target are directly comparable.
        vmin = min(recon.min().item(), target.min().item())
        vmax = max(recon.max().item(), target.max().item())
        return vmin, vmax

    # noinspection PyTypeChecker
    fig, axs = plt.subplots(ncols=9, nrows=2, sharex=True, sharey=True, figsize=(30, 4))
    for ax in axs.ravel():
        ax.grid(False)

    def _show_pair(col, key, channel, vmin, vmax):
        # Reconstruction on the top row, target on the bottom row of column `col`.
        index = frame + (channel,)
        axs[0, col].imshow(_to_image(recons[key][index]), vmin=vmin, vmax=vmax)
        axs[1, col].imshow(_to_image(states[key][index]), vmin=vmin, vmax=vmax)

    # hecs: columns 0-3 (h, e, c, s)
    for channel in range(4):
        _show_pair(channel, 'hecs', channel,
                   *_get_vbounds(recons.hecs[:, :, :, channel], states.hecs[:, :, :, channel]))
    # friendly marker: column 4, fixed [0, 1] range
    _show_pair(4, 'friendly_marker', 0, 0, 1)
    # unit types: column 5, channels blended into RGB
    axs[0, 5].imshow(_map_colors(recons.unit_types[frame]))
    axs[1, 5].imshow(_map_colors(states.unit_types[frame]))
    # terrain: column 6, colour range over the entire tensor
    _show_pair(6, 'terrain', 0, *_get_vbounds(recons.terrain, states.terrain))
    # spatial markers: columns 7-8 (r, Θ)
    for channel in range(2):
        _show_pair(7 + channel, 'spatial_markers', channel,
                   *_get_vbounds(recons.spatial_markers[:, :, :, channel],
                                 states.spatial_markers[:, :, :, channel]))
    # Done; return figure for plotting
    return fig


if __name__ == '__main__':
    import sys

    # Smoke run: decode the first trajectory with a freshly constructed model (no checkpoint loaded).
    # The trajectory file may be passed as the first CLI argument; otherwise the original
    # author's local path is used.
    # NOTE(review): MAWMWithRNN is not in railenv_local_to_global_state's type hint; this only
    # works if it subclasses MAWM or MATWM — confirm.
    wm = MAWMWithRNN()
    path = sys.argv[1] if len(sys.argv) > 1 else '/Users/nrahaman/Python/mawm/data/railenv_trajs_2.h5'
    dataset = RailEnvTrajectories((5, 5), 128, path, 64, split=3 / 4)
    windows = railenv_local_to_global_state(wm, dataset, 0)
    print(tuple(windows.shape))
