# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib


import torch
import numpy as np
import random

from einops import rearrange
from opencood.utils.common_utils import torch_tensor_to_numpy


def _zero_pad_first_dim(tensor, padding_len):
    """Append ``padding_len`` all-zero entries to ``tensor`` along dim 0."""
    # new_zeros inherits dtype and device from ``tensor``, so the result keeps
    # the input precision (e.g. float16) and no host-to-device copy is needed.
    padding = tensor.new_zeros((padding_len, *tensor.shape[1:]))
    return torch.cat([tensor, padding], dim=0)


def regroup(dense_feature, record_len, max_len, s_mask=None, cav_list=None,
            scene_info=None):
    """
    Regroup the data based on the record_len.

    The flat feature tensor is split into one chunk per sample, every chunk
    is zero-padded along the first dimension up to ``max_len`` entries, and
    the padded chunks are stacked into a batch.

    Parameters
    ----------
    dense_feature : torch.Tensor
        N, C, H, W, where N is the sum of all entries of ``record_len``.
    record_len : torch.Tensor or list
        [sample1_len, sample2_len, ...]
    max_len : int
        Maximum cav number. Every sample length must be <= ``max_len``;
        otherwise the padding length is negative and allocating the padding
        tensor fails.
    s_mask : torch.Tensor, optional
        N, H, W. Extra per-entry mask that is split and zero-padded the same
        way as ``dense_feature``.
    cav_list : object, optional
        Only used in the error message of the length check.
    scene_info : object, optional
        Only used in the error message of the length check.

    Returns
    -------
    regroup_feature : torch.Tensor
        B, L, C, H, W with L == ``max_len``; same dtype and device as
        ``dense_feature``.
    mask : torch.Tensor
        B, L integer tensor; 1 marks a real entry, 0 marks padding.
    regroup_s_mask : torch.Tensor
        B, L, H, W. Only returned when ``s_mask`` is given.

    Raises
    ------
    AssertionError
        If ``dense_feature.shape[0]`` differs from the sum of ``record_len``.
    """
    # Tensors go through the project helper as before; plain sequences are
    # converted directly so that a list (as documented) works as well.
    if torch.is_tensor(record_len):
        lengths = torch_tensor_to_numpy(record_len)
    else:
        lengths = np.asarray(record_len)
    cum_sum_len = np.cumsum(lengths).tolist()

    assert dense_feature.shape[0] == cum_sum_len[-1], f"Tensor split is error! {dense_feature.shape[0]} / {cum_sum_len[-1]}, {scene_info} @ {cav_list}"

    # The last cumulative sum equals N, so only the interior boundaries are
    # needed as split points.
    split_points = cum_sum_len[:-1]
    split_features = torch.tensor_split(dense_feature, split_points)
    split_s_masks = None
    if s_mask is not None:
        split_s_masks = torch.tensor_split(s_mask, split_points)

    regroup_features = []
    regroup_s_mask = []
    mask = []

    for i, split_feature in enumerate(split_features):
        # split_feature: M, C, H, W
        num_valid = split_feature.shape[0]
        padding_len = max_len - num_valid

        mask.append([1] * num_valid + [0] * padding_len)
        # L, C, H, W
        regroup_features.append(
            _zero_pad_first_dim(split_feature, padding_len))

        if split_s_masks is not None:
            # L, H, W
            regroup_s_mask.append(
                _zero_pad_first_dim(split_s_masks[i], padding_len))

    # B, L, C, H, W
    regroup_features = torch.stack(regroup_features, dim=0)
    mask = torch.from_numpy(np.array(mask)).to(regroup_features.device)

    if s_mask is not None:
        # B, L, H, W
        regroup_s_mask = torch.stack(regroup_s_mask, dim=0)
        return regroup_features, mask, regroup_s_mask

    return regroup_features, mask
