# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
import torch
import torch.nn as nn
from mmengine.model import BaseModel

from mmaction.registry import MODELS
from .utils import post_processing, temporal_iop, temporal_iou


@MODELS.register_module()
class BMN(BaseModel):
    """Boundary Matching Network for temporal action proposal generation.

    Please refer to `BMN: Boundary-Matching Network for Temporal Action
    Proposal Generation <https://arxiv.org/abs/1907.09702>`_.
    Code Reference https://github.com/JJBOY/BMN-Boundary-Matching-Network

    Args:
        temporal_dim (int): Total frames selected for each video, i.e. the
            temporal length of the input feature sequence.
        boundary_ratio (float): Fraction of a proposal's length by which its
            sampling region is extended on each side.
        num_samples (int): Number of samples (bins) for each proposal.
        num_samples_per_bin (int): Number of interpolated points averaged
            within each bin.
        feat_dim (int): Feature dimension.
        soft_nms_alpha (float): Soft NMS alpha.
        soft_nms_low_threshold (float): Soft NMS low threshold.
        soft_nms_high_threshold (float): Soft NMS high threshold.
        post_process_top_k (int): Top k proposals in post process.
        feature_extraction_interval (int):
            Interval used in feature extraction. Default: 16.
        loss_cls (dict, optional): Config for building loss. If None,
            ``dict(type='BMNLoss')`` is used. Default: None.
        hidden_dim_1d (int): Hidden dim for 1d conv. Default: 256.
        hidden_dim_2d (int): Hidden dim for 2d conv. Default: 128.
        hidden_dim_3d (int): Hidden dim for 3d conv. Default: 512.
    """

    def __init__(self,
                 temporal_dim,
                 boundary_ratio,
                 num_samples,
                 num_samples_per_bin,
                 feat_dim,
                 soft_nms_alpha,
                 soft_nms_low_threshold,
                 soft_nms_high_threshold,
                 post_process_top_k,
                 feature_extraction_interval=16,
                 loss_cls=None,
                 hidden_dim_1d=256,
                 hidden_dim_2d=128,
                 hidden_dim_3d=512):
        super().__init__()

        self.tscale = temporal_dim
        self.boundary_ratio = boundary_ratio
        self.num_samples = num_samples
        self.num_samples_per_bin = num_samples_per_bin
        self.feat_dim = feat_dim
        self.soft_nms_alpha = soft_nms_alpha
        self.soft_nms_low_threshold = soft_nms_low_threshold
        self.soft_nms_high_threshold = soft_nms_high_threshold
        self.post_process_top_k = post_process_top_k
        self.feature_extraction_interval = feature_extraction_interval
        # Build the default config per instance instead of sharing a mutable
        # default argument between all calls.
        if loss_cls is None:
            loss_cls = dict(type='BMNLoss')
        self.loss_cls = MODELS.build(loss_cls)
        self.hidden_dim_1d = hidden_dim_1d
        self.hidden_dim_2d = hidden_dim_2d
        self.hidden_dim_3d = hidden_dim_3d

        # Creates ``self.sample_mask`` used by the boundary-matching layer.
        self._get_interp1d_mask()

        # Base Module
        self.x_1d_b = nn.Sequential(
            nn.Conv1d(
                self.feat_dim,
                self.hidden_dim_1d,
                kernel_size=3,
                padding=1,
                groups=4), nn.ReLU(inplace=True),
            nn.Conv1d(
                self.hidden_dim_1d,
                self.hidden_dim_1d,
                kernel_size=3,
                padding=1,
                groups=4), nn.ReLU(inplace=True))

        # Temporal Evaluation Module
        self.x_1d_s = nn.Sequential(
            nn.Conv1d(
                self.hidden_dim_1d,
                self.hidden_dim_1d,
                kernel_size=3,
                padding=1,
                groups=4), nn.ReLU(inplace=True),
            nn.Conv1d(self.hidden_dim_1d, 1, kernel_size=1), nn.Sigmoid())
        self.x_1d_e = nn.Sequential(
            nn.Conv1d(
                self.hidden_dim_1d,
                self.hidden_dim_1d,
                kernel_size=3,
                padding=1,
                groups=4), nn.ReLU(inplace=True),
            nn.Conv1d(self.hidden_dim_1d, 1, kernel_size=1), nn.Sigmoid())

        # Proposal Evaluation Module
        self.x_1d_p = nn.Sequential(
            nn.Conv1d(
                self.hidden_dim_1d,
                self.hidden_dim_1d,
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True))
        self.x_3d_p = nn.Sequential(
            nn.Conv3d(
                self.hidden_dim_1d,
                self.hidden_dim_3d,
                kernel_size=(self.num_samples, 1, 1)), nn.ReLU(inplace=True))
        self.x_2d_p = nn.Sequential(
            nn.Conv2d(self.hidden_dim_3d, self.hidden_dim_2d, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                self.hidden_dim_2d,
                self.hidden_dim_2d,
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(
                self.hidden_dim_2d,
                self.hidden_dim_2d,
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(self.hidden_dim_2d, 2, kernel_size=1), nn.Sigmoid())
        # Each anchor spans two temporal gaps centered on the middle of its
        # position; the anchors are used to build start/end labels.
        self.anchors_tmins, self.anchors_tmaxs = self._temporal_anchors(
            -0.5, 1.5)
        self.match_map = self._match_map()
        self.register_buffer('bm_mask', self._get_bm_mask())

    def init_weights(self) -> None:
        """No-op: layers keep their PyTorch default initialization."""
        pass

    def forward(self, inputs, data_samples=None, mode='tensor', **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
        processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (list[Tensor]): Per-sample input tensors. They are stacked
                along a new leading batch dimension before the forward pass.
            data_samples (List[:obj:`ActionDataSample`], optional): The
                annotation data of every samples. Required by the ``loss``
                and ``predict`` modes. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``tensor``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the supported modes.
        """
        inputs = torch.stack(inputs)
        if mode == 'tensor':
            return self._forward(inputs, **kwargs)
        elif mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def loss(self, batch_inputs, batch_data_samples, **kwargs):
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Raw Inputs of the recognizer.
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`ActionDataSample`]): The batch
                data samples. Each sample must provide
                ``gt_instances['gt_bbox']``.

        Returns:
            dict: A dictionary with a single ``loss`` entry holding the first
            element returned by ``self.loss_cls`` (presumably the total loss;
            confirm against the loss implementation).
        """
        gt_bbox = [
            sample.gt_instances['gt_bbox'] for sample in batch_data_samples
        ]
        label_confidence, label_start, label_end = self.generate_labels(
            gt_bbox)

        # Labels are built on the CPU with numpy; move them next to inputs.
        device = batch_inputs.device
        label_confidence = label_confidence.to(device)
        label_start = label_start.to(device)
        label_end = label_end.to(device)

        confidence_map, start, end = self._forward(batch_inputs)

        loss = self.loss_cls(confidence_map, start, end, label_confidence,
                             label_start, label_end, self.bm_mask)
        loss_dict = dict(loss=loss[0])
        return loss_dict

    def predict(self, batch_inputs, batch_data_samples, **kwargs):
        """Define the computation performed at every call when testing.

        Only the first element of the batch is decoded (all outputs are
        indexed with ``[0]``), so this presumably assumes a test batch size
        of 1.

        Returns:
            list[dict]: One dict with ``video_name`` and the post-processed
            ``proposal_list``.
        """
        confidence_map, start, end = self._forward(batch_inputs)
        start_scores = start[0].cpu().numpy()
        end_scores = end[0].cpu().numpy()
        # Channel 1 is used as the classification confidence and channel 0
        # as the regression confidence; both are indexed [duration, start].
        cls_confidence = (confidence_map[0][1]).cpu().numpy()
        reg_confidence = (confidence_map[0][0]).cpu().numpy()

        max_start = max(start_scores)
        max_end = max(end_scores)

        # generate the set of start points and end points: a position is a
        # candidate if it is a local peak or above half of the maximum score.
        start_bins = np.zeros(len(start_scores))
        start_bins[0] = 1  # [1,0,0...,0,0]
        end_bins = np.zeros(len(end_scores))
        end_bins[-1] = 1  # [0,0,0...,0,1]
        for idx in range(1, self.tscale - 1):
            if start_scores[idx] > start_scores[
                    idx + 1] and start_scores[idx] > start_scores[idx - 1]:
                start_bins[idx] = 1
            elif start_scores[idx] > (0.5 * max_start):
                start_bins[idx] = 1
            if end_scores[idx] > end_scores[
                    idx + 1] and end_scores[idx] > end_scores[idx - 1]:
                end_bins[idx] = 1
            elif end_scores[idx] > (0.5 * max_end):
                end_bins[idx] = 1

        # iterate through all combinations of start_index and end_index;
        # ``idx`` is the duration offset and ``jdx`` the start position.
        new_proposals = []
        for idx in range(self.tscale):
            for jdx in range(self.tscale):
                start_index = jdx
                end_index = start_index + idx + 1
                if end_index < self.tscale and start_bins[
                        start_index] == 1 and end_bins[end_index] == 1:
                    tmin = start_index / self.tscale
                    tmax = end_index / self.tscale
                    tmin_score = start_scores[start_index]
                    tmax_score = end_scores[end_index]
                    cls_score = cls_confidence[idx, jdx]
                    reg_score = reg_confidence[idx, jdx]
                    score = tmin_score * tmax_score * cls_score * reg_score
                    new_proposals.append([
                        tmin, tmax, tmin_score, tmax_score, cls_score,
                        reg_score, score
                    ])
        new_proposals = np.stack(new_proposals)
        video_info = batch_data_samples[0].metainfo
        proposal_list = post_processing(new_proposals, video_info,
                                        self.soft_nms_alpha,
                                        self.soft_nms_low_threshold,
                                        self.soft_nms_high_threshold,
                                        self.post_process_top_k,
                                        self.feature_extraction_interval)
        output = [
            dict(
                video_name=video_info['video_name'],
                proposal_list=proposal_list)
        ]
        return output

    @staticmethod
    def _get_interp1d_bin_mask(seg_tmin, seg_tmax, tscale, num_samples,
                               num_samples_per_bin):
        """Generate sample mask for a boundary-matching pair.

        ``num_samples * num_samples_per_bin`` points are spread evenly over
        ``[seg_tmin, seg_tmax]``. Each point is linearly interpolated onto
        its two neighbouring integer positions, and every group of
        ``num_samples_per_bin`` points is averaged into one bin.

        Returns:
            np.ndarray: Mask of shape ``(tscale, num_samples)``.
        """
        plen = float(seg_tmax - seg_tmin)
        plen_sample = plen / (num_samples * num_samples_per_bin - 1.0)
        total_samples = [
            seg_tmin + plen_sample * i
            for i in range(num_samples * num_samples_per_bin)
        ]
        p_mask = []
        for idx in range(num_samples):
            bin_samples = total_samples[idx * num_samples_per_bin:(idx + 1) *
                                        num_samples_per_bin]
            bin_vector = np.zeros(tscale)
            for sample in bin_samples:
                sample_upper = math.ceil(sample)
                # ``math.modf`` keeps the sign of ``sample``; points outside
                # [0, tscale - 1] only contribute where an index is in range.
                sample_decimal, sample_down = math.modf(sample)
                if 0 <= int(sample_down) <= (tscale - 1):
                    bin_vector[int(sample_down)] += 1 - sample_decimal
                if 0 <= int(sample_upper) <= (tscale - 1):
                    bin_vector[int(sample_upper)] += sample_decimal
            bin_vector = 1.0 / num_samples_per_bin * bin_vector
            p_mask.append(bin_vector)
        p_mask = np.stack(p_mask, axis=1)
        return p_mask

    def _get_interp1d_mask(self):
        """Generate sample mask for each point in Boundary-Matching Map.

        Stores a non-trainable ``self.sample_mask`` of shape
        ``(tscale, num_samples * tscale * tscale)``. The trailing dimension
        is ordered (sample, duration, start). Entries whose proposal would
        run past the end of the sequence are all zeros.
        """
        mask_mat = []
        for start_index in range(self.tscale):
            mask_mat_vector = []
            for duration_index in range(self.tscale):
                if start_index + duration_index < self.tscale:
                    p_tmin = start_index
                    p_tmax = start_index + duration_index
                    center_len = float(p_tmax - p_tmin) + 1
                    # Extend the sampling region on both sides.
                    sample_tmin = p_tmin - (center_len * self.boundary_ratio)
                    sample_tmax = p_tmax + (center_len * self.boundary_ratio)
                    p_mask = self._get_interp1d_bin_mask(
                        sample_tmin, sample_tmax, self.tscale,
                        self.num_samples, self.num_samples_per_bin)
                else:
                    p_mask = np.zeros([self.tscale, self.num_samples])
                mask_mat_vector.append(p_mask)
            # (tscale, num_samples, duration)
            mask_mat_vector = np.stack(mask_mat_vector, axis=2)
            mask_mat.append(mask_mat_vector)
        # (tscale, num_samples, duration, start)
        mask_mat = np.stack(mask_mat, axis=3)
        mask_mat = mask_mat.astype(np.float32)
        self.sample_mask = nn.Parameter(
            torch.tensor(mask_mat).view(self.tscale, -1), requires_grad=False)

    def _get_bm_mask(self):
        """Generate Boundary-Matching Mask.

        Returns:
            torch.Tensor: Float tensor of shape ``(tscale, tscale)`` indexed
            ``[duration, start]``. An entry is 1 where
            ``start + duration < tscale`` and 0 otherwise.
        """
        bm_mask = []
        for idx in range(self.tscale):
            mask_vector = [1] * (self.tscale - idx) + [0] * idx
            bm_mask.append(mask_vector)
        bm_mask = torch.tensor(bm_mask, dtype=torch.float)
        return bm_mask

    def _match_map(self):
        """Generate match map.

        Returns:
            np.ndarray: Array of shape ``(tscale * tscale, 2)`` holding
            normalized ``[tmin, tmax]`` pairs. Rows are ordered
            duration-major, i.e. row ``d * tscale + s`` has
            ``tmin = s / tscale`` and ``tmax = tmin + (d + 1) / tscale``.
        """
        temporal_gap = 1. / self.tscale
        match_map = []
        for idx in range(self.tscale):
            match_window = []
            tmin = temporal_gap * idx
            for jdx in range(1, self.tscale + 1):
                tmax = tmin + temporal_gap * jdx
                match_window.append([tmin, tmax])
            match_map.append(match_window)
        match_map = np.array(match_map)
        # (start, duration, 2) -> (duration, start, 2)
        match_map = np.transpose(match_map, [1, 0, 2])
        match_map = np.reshape(match_map, [-1, 2])
        return match_map

    def _temporal_anchors(self, tmin_offset=0., tmax_offset=1.):
        """Generate temporal anchors.

        Args:
            tmin_offset (float): Offset for the minimum value of temporal
                anchor, in units of one temporal gap. Default: 0.
            tmax_offset (float): Offset for the maximum value of temporal
                anchor, in units of one temporal gap. Default: 1.
        Returns:
            tuple[Sequence[float]]: The minimum and maximum values of temporal
                anchors.
        """
        temporal_gap = 1. / self.tscale
        anchors_tmins = []
        anchors_tmaxs = []
        for i in range(self.tscale):
            anchors_tmins.append(temporal_gap * (i + tmin_offset))
            anchors_tmaxs.append(temporal_gap * (i + tmax_offset))

        return anchors_tmins, anchors_tmaxs

    def _forward(self, x):
        """Define the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.
        Returns:
            tuple[torch.Tensor]: ``(confidence_map, start, end)``.
        """
        # x.shape [batch_size, self.feat_dim, self.tscale]
        base_feature = self.x_1d_b(x)
        # base_feature.shape [batch_size, self.hidden_dim_1d, self.tscale]
        start = self.x_1d_s(base_feature).squeeze(1)
        # start.shape [batch_size, self.tscale]
        end = self.x_1d_e(base_feature).squeeze(1)
        # end.shape [batch_size, self.tscale]
        confidence_map = self.x_1d_p(base_feature)
        # [batch_size, self.hidden_dim_1d, self.tscale]
        confidence_map = self._boundary_matching_layer(confidence_map)
        # [batch_size, self.hidden_dim_1d, self.num_samples, self.tscale, self.tscale] # noqa
        confidence_map = self.x_3d_p(confidence_map).squeeze(2)
        # [batch_size, self.hidden_dim_3d, self.tscale, self.tscale]
        confidence_map = self.x_2d_p(confidence_map)
        # [batch_size, 2, self.tscale, self.tscale]

        return confidence_map, start, end

    def _boundary_matching_layer(self, x):
        """Generate matching layer.

        Projects features of shape ``(N, C, tscale)`` through
        ``self.sample_mask``. The result is reshaped to
        ``(N, C, num_samples, tscale, tscale)`` with the last two axes
        ordered (duration, start).
        """
        input_size = x.size()
        out = torch.matmul(x,
                           self.sample_mask).reshape(input_size[0],
                                                     input_size[1],
                                                     self.num_samples,
                                                     self.tscale, self.tscale)
        return out

    def generate_labels(self, gt_bbox):
        """Generate training labels.

        Args:
            gt_bbox (list[Tensor]): One tensor per video whose rows are
                ``(start, end)`` pairs. They are compared against the
                normalized match map, so they are presumably normalized to
                [0, 1] (confirm against the data pipeline).

        Returns:
            tuple[torch.Tensor]: ``(confidence, start, end)`` labels.
            ``confidence`` has shape ``(num_videos, tscale, tscale)`` and is
            indexed [duration, start]; ``start`` and ``end`` have shape
            ``(num_videos, tscale)``.
        """
        # TODO: do this without numpy
        match_score_confidence_list = []
        match_score_start_list = []
        match_score_end_list = []
        for every_gt_bbox in gt_bbox:
            gt_iou_map = []
            every_gt_bbox = every_gt_bbox.cpu()
            for start, end in every_gt_bbox:
                if isinstance(start, torch.Tensor):
                    start = start.numpy()
                if isinstance(end, torch.Tensor):
                    end = end.numpy()
                current_gt_iou_map = temporal_iou(self.match_map[:, 0],
                                                  self.match_map[:, 1], start,
                                                  end)
                current_gt_iou_map = np.reshape(current_gt_iou_map,
                                                [self.tscale, self.tscale])
                gt_iou_map.append(current_gt_iou_map)
            gt_iou_map = np.array(gt_iou_map).astype(np.float32)
            # Best IoU over all ground-truth segments of this video.
            gt_iou_map = np.max(gt_iou_map, axis=0)

            gt_tmins = every_gt_bbox[:, 0]
            gt_tmaxs = every_gt_bbox[:, 1]

            # Boundary regions are three temporal gaps wide, centered on the
            # ground-truth start/end points.
            gt_len_pad = 3 * (1. / self.tscale)

            gt_start_bboxs = np.stack(
                (gt_tmins - gt_len_pad / 2, gt_tmins + gt_len_pad / 2), axis=1)
            gt_end_bboxs = np.stack(
                (gt_tmaxs - gt_len_pad / 2, gt_tmaxs + gt_len_pad / 2), axis=1)

            match_score_start = []
            match_score_end = []

            for anchor_tmin, anchor_tmax in zip(self.anchors_tmins,
                                                self.anchors_tmaxs):
                match_score_start.append(
                    np.max(
                        temporal_iop(anchor_tmin, anchor_tmax,
                                     gt_start_bboxs[:, 0], gt_start_bboxs[:,
                                                                          1])))
                match_score_end.append(
                    np.max(
                        temporal_iop(anchor_tmin, anchor_tmax,
                                     gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
            match_score_confidence_list.append(gt_iou_map)
            match_score_start_list.append(match_score_start)
            match_score_end_list.append(match_score_end)

        def to_tensor(x):
            return torch.Tensor(np.array(x))

        match_score_confidence_list = to_tensor(match_score_confidence_list)
        match_score_start_list = to_tensor(match_score_start_list)
        match_score_end_list = to_tensor(match_score_end_list)
        return (match_score_confidence_list, match_score_start_list,
                match_score_end_list)
