from turtle import pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from fiery.models.cross_attention_height_polar_16 import Camera2BEV
from fiery.models.encoder_mine import Encoder
from fiery.models.temporal_model import TemporalModelIdentity, TemporalModel
from fiery.models.distributions import DistributionModule
from fiery.models.future_prediction import FuturePrediction
from fiery.models.decoder import Decoder, RingDecoder
from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim, set_bn_momentum
from fiery.utils.geometry import cumulative_warp_features, calculate_birds_eye_view_parameters, VoxelsSumming
# from fiery.models.cross_attention import feature_sampling
import time
import numpy as np
import matplotlib.pyplot as plt

def cart2polar(input_xy):
    """Convert Cartesian (x, y) coordinates to polar (rho, phi).

    Args:
        input_xy: tensor whose last axis holds (x, y); any leading shape.

    Returns:
        Tensor of the same shape with the last axis holding (rho, phi), where
        rho is the Euclidean norm and phi = atan2(y, x) in [-pi, pi].
    """
    rho = torch.sqrt(input_xy[..., 0] ** 2 + input_xy[..., 1] ** 2)
    phi = torch.atan2(input_xy[..., 1], input_xy[..., 0])
    # Stack on the last axis (not a fixed dim=2) so inputs of any rank keep
    # their leading layout; for (H, W, 2) inputs this is identical to dim=2.
    return torch.stack((rho, phi), dim=-1)

def polar2cat(input_xy_polar):
    """Convert polar (rho, phi) coordinates back to Cartesian (x, y).

    Inverse of ``cart2polar``: the last axis of ``input_xy_polar`` holds
    (rho, phi) and the result's last axis holds (x, y), with the leading shape
    preserved.
    """
    x = input_xy_polar[..., 0] * torch.cos(input_xy_polar[..., 1])
    y = input_xy_polar[..., 0] * torch.sin(input_xy_polar[..., 1])
    # x and y have lost the trailing coordinate axis, so they must be stacked
    # onto a new last axis; concatenating them would merge them into the
    # previous axis (shape (..., 2*W)) and fails outright for a single point.
    return torch.stack((x, y), dim=-1)

class Fiery(nn.Module):
    """FIERY-style model whose BEV representation lives on a polar grid.

    Camera images are encoded per camera, lifted onto a grid of polar BEV
    queries (angular bins x radial bins) by ``Camera2BEV``, processed by the
    temporal / future-prediction stack, decoded by ``RingDecoder``, and the
    dense decoder outputs are finally resampled onto the Cartesian BEV grid
    with ``F.grid_sample``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters(
            self.cfg.LIFT.X_BOUND, self.cfg.LIFT.Y_BOUND, self.cfg.LIFT.Z_BOUND
        )
        self.bev_resolution = nn.Parameter(bev_resolution, requires_grad=False)
        self.bev_start_position = nn.Parameter(bev_start_position, requires_grad=False)
        self.bev_dimension = nn.Parameter(bev_dimension, requires_grad=False)

        self.encoder_downsample = self.cfg.MODEL.ENCODER.DOWNSAMPLE
        self.encoder_out_channels = self.cfg.MODEL.ENCODER.OUT_CHANNELS

        # Channel width of the polar BEV queries fed to Camera2BEV.
        self.hq_channels = self.encoder_out_channels

        # (rho, phi) of every Cartesian BEV cell, shape (n_y, n_x, 2). Used by
        # calculate_birds_eye_view_features to build the grid that resamples
        # polar outputs back onto the Cartesian BEV grid.
        # NOTE(review): a plain attribute, not a registered buffer, so it stays
        # on the CPU when the module is moved to another device.
        eps = 1e-6
        x = torch.arange(bev_start_position[0], self.cfg.LIFT.X_BOUND[1] + eps, self.cfg.LIFT.X_BOUND[2])
        y = torch.arange(bev_start_position[1], self.cfg.LIFT.Y_BOUND[1] + eps, self.cfg.LIFT.Y_BOUND[2])
        yy, xx = torch.meshgrid([y, x])
        xy_cord = torch.stack([xx, yy], dim=-1)
        self.polar_cord = cart2polar(xy_cord)

        # Re-partition the BEV area uniformly in polar coordinates:
        #   * radial extent: norm of bev_start_position (the first BEV cell's
        #     position, presumably its centre) plus one radial bin;
        #   * radial bin size: the diagonal of one BEV cell;
        #   * number of angular bins: bev_dimension[0] + bev_dimension[1].
        # polar_grid = [n_rho, n_phi], both 0-d tensors.
        polar_grid = []
        polar_r = torch.sqrt(self.bev_start_position[0] ** 2 + self.bev_start_position[1] ** 2)
        self.polar_r = polar_r
        polar_r_resolution = torch.sqrt(self.bev_resolution[0] ** 2 + self.bev_resolution[1] ** 2)
        self.polar_r_resolution = polar_r_resolution
        polar_grid.append((polar_r + polar_r_resolution) // polar_r_resolution)
        polar_grid.append(self.bev_dimension[0] + self.bev_dimension[1])
        self.polar_grid = polar_grid

        # Size of one polar bin: [radial size, angular size in radians].
        intervals = [(polar_r + polar_r_resolution) / polar_grid[0]]
        intervals.append(2 * np.pi / polar_grid[1])
        self.intervals = intervals

        # Normalised polar query positions in [0, 1): the lower edge of every
        # angular / radial bin (the trailing 1.0 of each linspace is dropped).
        # NOTE: the original note says each sector should ideally be
        # represented by its centre position; the bin edges are used here.
        rho = torch.linspace(0., 1., int(polar_grid[0].item() + 1))
        phi = torch.linspace(0., 1., int(polar_grid[1].item() + 1))
        phi, rho = torch.meshgrid(phi[..., :-1], rho[..., :-1])  # each (n_phi, n_rho)

        # Constant query height 0.55 for every polar cell; how it is interpreted
        # (presumably normalised within Z_BOUND via pc_range) is up to Camera2BEV.
        zz = torch.ones([int(polar_grid[1].item()), int(polar_grid[0].item())]) * 0.55
        pc_range = [polar_r + polar_r_resolution, 2*np.pi, self.cfg.LIFT.Z_BOUND[0], self.cfg.LIFT.Z_BOUND[1]]

        self.hq_number = int(polar_grid[0].item() * polar_grid[1].item())
        self.img2bev = Camera2BEV(
            int(polar_grid[1].item()), int(polar_grid[0].item()), self.hq_number, self.hq_channels,
            n_layers=3, n_cameras=6, n_levels=1, rho=rho, phi=phi, zz=zz, intervals=intervals,
            pc_range=pc_range, img_size=(self.cfg.IMAGE.FINAL_DIM[1], self.cfg.IMAGE.FINAL_DIM[0]),
            n_classes=len(self.cfg.SEMANTIC_SEG.WEIGHTS),
        )

        if self.cfg.TIME_RECEPTIVE_FIELD == 1:
            assert self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity'

        # temporal block
        self.receptive_field = self.cfg.TIME_RECEPTIVE_FIELD
        self.n_future = self.cfg.N_FUTURE_FRAMES
        self.latent_dim = self.cfg.MODEL.DISTRIBUTION.LATENT_DIM

        if self.cfg.MODEL.SUBSAMPLE:
            assert self.cfg.DATASET.NAME == 'lyft'
            self.receptive_field = 3
            self.n_future = 5

        # Spatial extent in bird's-eye view, in meters
        self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
        self.bev_size = (self.bev_dimension[0].item(), self.bev_dimension[1].item())

        # Encoder
        self.encoder = Encoder(cfg=self.cfg.MODEL.ENCODER)

        # Temporal model
        # NOTE(review): the features reaching the temporal model are laid out
        # on the polar grid (n_phi, n_rho), not bev_size — confirm the
        # input_shape passed to TemporalModel below.
        temporal_in_channels = self.encoder_out_channels
        if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE:
            temporal_in_channels += 6
        if self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity':
            self.temporal_model = TemporalModelIdentity(temporal_in_channels, self.receptive_field)
        elif cfg.MODEL.TEMPORAL_MODEL.NAME == 'temporal_block':
            self.temporal_model = TemporalModel(
                temporal_in_channels,
                self.receptive_field,
                input_shape=self.bev_size,
                start_out_channels=self.cfg.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS,
                extra_in_channels=self.cfg.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS,
                n_spatial_layers_between_temporal_layers=self.cfg.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS,
                use_pyramid_pooling=self.cfg.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING,
            )
        else:
            raise NotImplementedError(f'Temporal module {self.cfg.MODEL.TEMPORAL_MODEL.NAME}.')

        self.future_pred_in_channels = self.temporal_model.out_channels
        if self.n_future > 0:
            # probabilistic sampling
            if self.cfg.PROBABILISTIC.ENABLED:
                # Distribution networks
                self.present_distribution = DistributionModule(
                    self.future_pred_in_channels,
                    self.latent_dim,
                    min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA,
                    max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA,
                )

                future_distribution_in_channels = (self.future_pred_in_channels
                                                   + self.n_future * self.cfg.PROBABILISTIC.FUTURE_DIM
                                                   )
                self.future_distribution = DistributionModule(
                    future_distribution_in_channels,
                    self.latent_dim,
                    min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA,
                    max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA,
                )

            # Future prediction
            self.future_prediction = FuturePrediction(
                in_channels=self.future_pred_in_channels,
                latent_dim=self.latent_dim,
                n_gru_blocks=self.cfg.MODEL.FUTURE_PRED.N_GRU_BLOCKS,
                n_res_layers=self.cfg.MODEL.FUTURE_PRED.N_RES_LAYERS,
            )

        # Decoder: RingDecoder operates on the polar feature layout; its dense
        # outputs are resampled to Cartesian BEV in forward().
        self.decoder = RingDecoder(
            in_channels=self.future_pred_in_channels,
            n_classes=len(self.cfg.SEMANTIC_SEG.WEIGHTS),
            predict_future_flow=self.cfg.INSTANCE_FLOW.ENABLED,
        )

        set_bn_momentum(self, self.cfg.MODEL.BN_MOMENTUM)

    def create_frustum(self):
        """Build the (pixel x, pixel y, depth) frustum on the downsampled image grid.

        Returns:
            Non-trainable parameter of shape
            (n_depth_slices, downsampled_h, downsampled_w, 3).
        """
        # Create grid in image plane
        h, w = self.cfg.IMAGE.FINAL_DIM
        downsampled_h, downsampled_w = h // self.encoder_downsample, w // self.encoder_downsample

        # Depth grid
        depth_grid = torch.arange(*self.cfg.LIFT.D_BOUND, dtype=torch.float)
        depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
        n_depth_slices = depth_grid.shape[0]

        # x and y grids
        x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
        x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
        y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
        y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)

        # Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
        # containing data points in the image: left-right, top-bottom, depth
        frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
        return nn.Parameter(frustum, requires_grad=False)

    def forward(self, image, intrinsics, extrinsics, lidar2imgs, future_egomotion, future_distribution_inputs=None, noise=None):
        """Predict BEV outputs from a sequence of multi-camera images.

        Only the first ``self.receptive_field`` timesteps of each input are used.

        Returns:
            Tuple ``(output, inter_seg, inter_instance_offset, inter_instance_center)``.
            ``output`` holds the decoder outputs, with 'segmentation',
            'instance_center' and 'instance_offset' resampled from the polar
            grid to the Cartesian BEV grid, plus the distribution parameters
            when the probabilistic model runs. The ``inter_*`` lists are the
            intermediate predictions from ``calculate_birds_eye_view_features``.
        """
        output = {}

        # Only process features from the past and present
        image = image[:, :self.receptive_field].contiguous()
        intrinsics = intrinsics[:, :self.receptive_field].contiguous()
        extrinsics = extrinsics[:, :self.receptive_field].contiguous()
        lidar2imgs = lidar2imgs[:, :self.receptive_field].contiguous()
        future_egomotion = future_egomotion[:, :self.receptive_field].contiguous()

        x, inter_seg, inter_instance_offset, inter_instance_center, xy_cord = self.calculate_birds_eye_view_features(
            image, intrinsics, extrinsics, lidar2imgs
        )

        # Warp past features to the present's reference frame.
        # NOTE(review): x is laid out on the polar grid (..., n_phi, n_rho) here,
        # while the warp is given the Cartesian spatial_extent — confirm this is
        # intended when receptive_field > 1.
        x = cumulative_warp_features(
            x.clone(), future_egomotion,
            mode='bilinear', spatial_extent=self.spatial_extent,
        )
        if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE:
            b, s, c = future_egomotion.shape
            h, w = x.shape[-2:]
            future_egomotions_spatial = future_egomotion.view(b, s, c, 1, 1).expand(b, s, c, h, w)
            # at time 0, no egomotion so feed zero vector
            future_egomotions_spatial = torch.cat([torch.zeros_like(future_egomotions_spatial[:, :1]),
                                                   future_egomotions_spatial[:, :(self.receptive_field-1)]], dim=1)
            x = torch.cat([x, future_egomotions_spatial], dim=-3)

        # Temporal model
        states = self.temporal_model(x)

        if self.n_future > 0:
            # Split into present and future features (for the probabilistic model)
            present_state = states[:, :1].contiguous()
            if self.cfg.PROBABILISTIC.ENABLED:
                # Do probabilistic computation
                sample, output_distribution = self.distribution_forward(
                    present_state, future_distribution_inputs, noise
                )
                output = {**output, **output_distribution}

            # Prepare future prediction input
            b, _, _, h, w = present_state.shape
            hidden_state = present_state[:, 0]

            if self.cfg.PROBABILISTIC.ENABLED:
                future_prediction_input = sample.expand(-1, self.n_future, -1, -1, -1)
            else:
                future_prediction_input = hidden_state.new_zeros(b, self.n_future, self.latent_dim, h, w)

            # Recursively predict future states
            future_states = self.future_prediction(future_prediction_input, hidden_state)

            # Concatenate present state
            future_states = torch.cat([present_state, future_states], dim=1)

        # Predict bird's-eye view outputs
        if self.n_future > 0:
            bev_output = self.decoder(future_states)
        else:
            bev_output = self.decoder(states[:, -1:])

        # Resample the dense polar outputs onto the Cartesian BEV grid. Their
        # sequence length (1 in the n_future == 0 branch, 1 + predicted steps
        # otherwise) generally differs from the receptive field xy_cord was
        # built for, so _polar_to_cartesian broadcasts the grid to the batch.
        for key in ('segmentation', 'instance_center', 'instance_offset'):
            b, s = bev_output[key].shape[:2]
            resampled = self._polar_to_cartesian(pack_sequence_dim(bev_output[key]), xy_cord)
            bev_output[key] = unpack_sequence_dim(resampled, b, s)
        output = {**output, **bev_output}

        return output, inter_seg, inter_instance_offset, inter_instance_center

    def _queries_to_polar_map(self, queries, batch):
        """Reshape an ``img2bev`` output into a (batch, C, n_phi, n_rho) polar map.

        The permute(1, 2, 0) implies ``img2bev`` outputs are laid out as
        (n_queries, batch, C) — TODO confirm against Camera2BEV.
        """
        n_phi = int(self.polar_grid[1].item())
        n_rho = int(self.polar_grid[0].item())
        return queries.permute(1, 2, 0).view(batch, -1, n_phi, n_rho)

    def _polar_to_cartesian(self, polar_map, grid):
        """Resample a polar map onto the Cartesian BEV grid.

        Args:
            polar_map: tensor of shape (N, C, n_phi, n_rho) — angular bins on
                dim 2, radial bins on dim 3.
            grid: grid_sample grid of shape (M, n_y * n_x, 1, 2) as built in
                ``calculate_birds_eye_view_features``, which repeats one grid
                across its batch. Only ``grid[0]`` is used and broadcast to N,
                so M need not equal N.

        Returns:
            Tensor of shape (N, C, bev_dimension[1], bev_dimension[0]).
        """
        n_phi = int(self.polar_grid[1].item())
        batch = polar_map.shape[0]
        # Circularly pad the angular axis (n_phi // 2 rows on top, the rest
        # below) so the padded axis covers two periods and samples near the
        # +/-pi seam interpolate across the wrap instead of against zeros.
        padded = F.pad(polar_map, (0, 0, n_phi // 2, n_phi - n_phi // 2), mode='circular')
        # align_corners=False is the PyTorch default; stated explicitly.
        sampled = F.grid_sample(padded, grid[:1].expand(batch, -1, -1, -1), align_corners=False)
        bev_h, bev_w = self.bev_dimension[0].item(), self.bev_dimension[1].item()
        return sampled.view(batch, -1, bev_h, bev_w).permute(0, 1, 3, 2)

    def get_geometry(self, intrinsics, extrinsics):
        """Calculate the (x, y, z) 3D position of the frustum points in the ego frame.

        Un-projects ``self.frustum`` (pixel x, pixel y, depth) with the inverse
        intrinsics, rotates it with the extrinsic rotation and adds the
        extrinsic translation.

        NOTE(review): ``self.frustum`` is never assigned in this class
        (``create_frustum`` is not called in ``__init__``), so calling this
        raises AttributeError unless ``frustum`` is set elsewhere.
        """
        rotation, translation = extrinsics[..., :3, :3], extrinsics[..., :3, 3]
        B, N, _ = translation.shape
        # Add batch, camera dimension, and a dummy dimension at the end
        points = self.frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)

        # Camera to ego reference frame
        points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5)
        combined_transformation = rotation.matmul(torch.inverse(intrinsics))
        points = combined_transformation.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
        points += translation.view(B, N, 1, 1, 1, 3)

        # The 3 dimensions in the ego reference frame are: (forward, sides, height)
        return points

    def encoder_forward(self, x):
        """Encode every camera image independently.

        Args:
            x: images of shape (batch, n_cameras, channels, height, width).

        Returns:
            Encoder features permuted to (batch, n_cameras, height', width', channels').
        """
        b, n, c, h, w = x.shape

        x = x.view(b * n, c, h, w)
        x = self.encoder(x)
        x = x.view(b, n, *x.shape[1:])
        # [batch, n_cameras, height, width, channels]
        x = x.permute(0, 1, 3, 4, 2)

        return x

    def projection_to_birds_eye_view(self, x, geometry):
        """Sum-pool lifted features into Cartesian BEV voxels (lift-splat style).

        Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/models.py#L200

        Args:
            x: features of shape (batch, n_cameras, depth, height, width, channels).
            geometry: matching 3D point positions with a trailing (x, y, z) axis.

        Returns:
            Tensor of shape (batch, channels, bev_dimension[0], bev_dimension[1]);
            the z axis is squeezed away, which assumes bev_dimension[2] == 1.
        """
        # batch, n_cameras, depth, height, width, channels
        batch, n, d, h, w, c = x.shape
        output = torch.zeros(
            (batch, c, self.bev_dimension[0], self.bev_dimension[1]), dtype=torch.float, device=x.device
        )

        # Number of 3D points
        N = n * d * h * w
        for b in range(batch):
            # flatten x
            x_b = x[b].reshape(N, c)

            # Convert positions to integer indices
            geometry_b = ((geometry[b] - (self.bev_start_position - self.bev_resolution / 2.0)) / self.bev_resolution)
            geometry_b = geometry_b.view(N, 3).long()

            # Mask out points that are outside the considered spatial extent.
            mask = (
                    (geometry_b[:, 0] >= 0)
                    & (geometry_b[:, 0] < self.bev_dimension[0])
                    & (geometry_b[:, 1] >= 0)
                    & (geometry_b[:, 1] < self.bev_dimension[1])
                    & (geometry_b[:, 2] >= 0)
                    & (geometry_b[:, 2] < self.bev_dimension[2])
            )
            x_b = x_b[mask]
            geometry_b = geometry_b[mask]

            # Sort tensors so that those within the same voxel are consecutives.
            ranks = (
                    geometry_b[:, 0] * (self.bev_dimension[1] * self.bev_dimension[2])
                    + geometry_b[:, 1] * (self.bev_dimension[2])
                    + geometry_b[:, 2]
            )
            ranks_indices = ranks.argsort()
            x_b, geometry_b, ranks = x_b[ranks_indices], geometry_b[ranks_indices], ranks[ranks_indices]

            # Project to bird's-eye view by summing voxels.
            x_b, geometry_b = VoxelsSumming.apply(x_b, geometry_b, ranks)

            bev_feature = torch.zeros((self.bev_dimension[2], self.bev_dimension[0], self.bev_dimension[1], c),
                                      device=x_b.device)
            bev_feature[geometry_b[:, 2], geometry_b[:, 0], geometry_b[:, 1]] = x_b

            # Put channel in second position and remove z dimension
            bev_feature = bev_feature.permute((0, 3, 1, 2))
            bev_feature = bev_feature.squeeze(0)

            output[b] = bev_feature

        return output

    def calculate_birds_eye_view_features(self, x, intrinsics, extrinsics, lidar2imgs):
        """Encode the images and lift them onto the polar BEV grid.

        Args:
            x: images of shape (b, s, n_cameras, c, h, w).
            intrinsics, extrinsics, lidar2imgs: per-camera matrices with leading
                (b, s) dims; only ``lidar2imgs`` is consumed here.

        Returns:
            x: polar BEV features of shape (b, s, C, n_phi, n_rho).
            inter_seg, inter_instance_offset, inter_instance_center: lists of
                intermediate ``img2bev`` predictions. Every entry except the
                last is resampled to shape (b, s, C, bev_dimension[1],
                bev_dimension[0]); the last entry is returned as produced.
            xy_cord: the grid_sample grid, shape (b * s, n_y * n_x, 1, 2).
        """
        b, s, n, c, h, w = x.shape
        # Fold the time dimension into the batch dimension.
        x = pack_sequence_dim(x)
        intrinsics = pack_sequence_dim(intrinsics)
        extrinsics = pack_sequence_dim(extrinsics)
        lidar2imgs = pack_sequence_dim(lidar2imgs)

        x = self.encoder_forward(x)

        # grid_sample grid for every Cartesian BEV cell: x <- rho scaled from
        # [0, polar extent] to [-1, 1] (radial axis); y <- phi / (2 * pi), so
        # [-pi, pi] -> [-0.5, 0.5], which spans exactly one period of the
        # angular axis once it is doubled by circular padding.
        x_cord = self.polar_cord[..., 0:1] / (self.polar_r + self.polar_r_resolution)
        x_cord = (x_cord - 0.5) * 2
        x_cord = x_cord.unsqueeze(0).repeat(b*s, 1, 1, 1).view(b*s, -1, 1)
        y_cord = self.polar_cord[..., 1:2] / (np.pi) / 2
        y_cord = y_cord.unsqueeze(0).repeat(b*s, 1, 1, 1).view(b*s, -1, 1)
        xy_cord = torch.cat((x_cord, y_cord), dim=-1)
        xy_cord = xy_cord.view(b*s, -1, 1, 2)
        xy_cord = xy_cord.to(x)
        # `reg` (an additional img2bev output) is currently unused.
        x, inter_seg, inter_instance_offset, inter_instance_center, reg = self.img2bev(x, lidar2imgs, xy_cord)

        x = self._queries_to_polar_map(x, b * s)

        # Resample every intermediate prediction except the last one to the
        # Cartesian grid: any loss computed on them needs Cartesian layout.
        n_resampled = len(inter_seg) - 1
        for preds in (inter_seg, inter_instance_center, inter_instance_offset):
            for i in range(n_resampled):
                polar_map = self._queries_to_polar_map(preds[i], b * s)
                preds[i] = unpack_sequence_dim(self._polar_to_cartesian(polar_map, xy_cord), b, s)

        x = unpack_sequence_dim(x, b, s)
        return x, inter_seg, inter_instance_offset, inter_instance_center, xy_cord

    def distribution_forward(self, present_features, future_distribution_inputs=None, noise=None):
        """
        Parameters
        ----------
            present_features: 5-D output from dynamics module with shape (b, 1, c, h, w)
            future_distribution_inputs: 5-D tensor containing labels shape (b, s, cfg.PROB_FUTURE_DIM, h, w)
            noise: a sample from a (0, 1) gaussian with shape (b, s, latent_dim). If None, will sample in function

        Returns
        -------
            sample: sample taken from present/future distribution, broadcast to shape (b, s, latent_dim, h, w)
            present_distribution_mu: shape (b, s, latent_dim)
            present_distribution_log_sigma: shape (b, s, latent_dim)
            future_distribution_mu: shape (b, s, latent_dim)
            future_distribution_log_sigma: shape (b, s, latent_dim)

        Raises
        ------
            ValueError: in training mode when ``future_distribution_inputs`` is
                None, since training samples from the future distribution.
        """
        b, s, _, h, w = present_features.size()
        assert s == 1

        present_mu, present_log_sigma = self.present_distribution(present_features)

        future_mu, future_log_sigma = None, None
        if future_distribution_inputs is not None:
            # Concatenate future labels to z_t
            future_features = future_distribution_inputs[:, 1:].contiguous().view(b, 1, -1, h, w)
            future_features = torch.cat([present_features, future_features], dim=2)
            future_mu, future_log_sigma = self.future_distribution(future_features)

        if noise is None:
            if self.training:
                noise = torch.randn_like(present_mu)
            else:
                noise = torch.zeros_like(present_mu)
        if self.training:
            # Fail with a clear message instead of torch.exp(None).
            if future_mu is None:
                raise ValueError('future_distribution_inputs must be provided in training mode '
                                 'to sample from the future distribution.')
            mu = future_mu
            sigma = torch.exp(future_log_sigma)
        else:
            mu = present_mu
            sigma = torch.exp(present_log_sigma)
        sample = mu + sigma * noise

        # Spatially broadcast sample to the dimensions of present_features
        sample = sample.view(b, s, self.latent_dim, 1, 1).expand(b, s, self.latent_dim, h, w)

        output_distribution = {
            'present_mu': present_mu,
            'present_log_sigma': present_log_sigma,
            'future_mu': future_mu,
            'future_log_sigma': future_log_sigma,
        }

        return sample, output_distribution