# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from einops import rearrange, repeat

from opencood.models.backbones.resnet_ms import ResnetEncoder
from opencood.models.sub_modules.fax_modules import FAXModule
from opencood.models.fuse_modules.swap_multiscale_fusion_modules import SwapMultiscaleFusionEncoder
from opencood.models.sub_modules.naive_decoder import NaiveDecoder
from opencood.models.sub_modules.torch_transformation_utils import \
    get_transformation_matrix, warp_affine, get_discretized_transformation_matrix
from opencood.models.backbones.pixor import Bottleneck, BackBone, conv3x3
from opencood.models.fuse_modules.fuse_utils import regroup
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.comm_modules.residual_vq import ResidualVQ

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

class STTF(nn.Module):
    """
    Warp stacked per-agent BEV feature maps with per-agent transformation matrices.

    ``args`` must provide ``'resolution'`` and ``'downsample_rate'``; both are
    forwarded to ``get_discretized_transformation_matrix`` to turn the metric
    correction matrices into feature-map (pixel) space.
    """

    def __init__(self, args):
        super(STTF, self).__init__()
        self.discrete_ratio = args['resolution']
        self.downsample_rate = args['downsample_rate']

    def forward(self, x, spatial_correction_matrix, change_order=True):
        """
        Transform the stacked BEV features with the given correction matrices.

        Parameters
        ----------
        x : torch.Tensor
            Features laid out as (B, L, C, H, W).
        spatial_correction_matrix : torch.Tensor
            Per-agent transformation matrices toward the ego frame; they are
            discretized and reshaped to (-1, 2, 3), one affine per agent slot.
        change_order : bool
            When True, the two spatial axes are swapped back so the result has
            the same axis order as ``x``. When False, the result keeps the
            swapped (B, L, C, W, H) layout used internally for warping.

        Returns
        -------
        torch.Tensor
            The warped features. Every agent slot in ``x`` is warped; none are
            skipped.
        """
        pixel_matrix = get_discretized_transformation_matrix(
            spatial_correction_matrix, self.discrete_ratio,
            self.downsample_rate)

        # Swap the spatial axes and flip the new last axis so that the affine
        # matrices line up with the feature-map axes.
        swapped = torch.flip(rearrange(x, 'b l c h w  -> b l c w h'), dims=(4,))
        batch, _, channels, rows, cols = swapped.shape

        # Indexing with four slices keeps the original implicit requirement
        # that the matrix tensor has at least four dimensions.
        affine = get_transformation_matrix(
            pixel_matrix[:, :, :, :].reshape(-1, 2, 3), (rows, cols))
        warped = warp_affine(swapped.reshape(-1, channels, rows, cols), affine,
                             (rows, cols))
        warped = warped.reshape(batch, -1, channels, rows, cols)

        # Undo the flip; undo the axis swap only if requested.
        warped = torch.flip(warped, dims=(4,))
        if not change_order:
            return warped
        return rearrange(warped, 'b l c w h -> b l c h w')

class ComV2i(nn.Module):
    """
    Cooperative model fusing infrastructure LiDAR with vehicle camera features.

    The first ``rsu_num`` agent slots of ``transformation_matrix`` are treated
    as RSUs and feed the LiDAR branch; the remaining slots feed the camera
    branch (ResNet encoder + FAX module). Both branches are warped with
    :class:`STTF`, fused by ``SwapMultiscaleFusionEncoder`` and decoded by a
    dataset-specific head:

    * ``'V2XReal'``: 'cls' (2 channels) and 'reg' (6 channels) conv heads.
    * ``'V2XSim'``: ``NaiveDecoder`` followed by a 'seg' (4 channels) conv head.

    Camera features are optionally passed through a ``ResidualVQ`` when
    ``'residual_vq'`` is present in ``args``.

    Note: ``args['fusion_net']`` and ``args['fax']`` are updated in place
    (``'rsu_num'`` / ``'backbone_output_shape'`` are written into them).
    """

    def __init__(self, args):
        super(ComV2i, self).__init__()
        self.max_cav = args['max_cav']
        self.rsu_num = args['rsu_num']
        self.dataset = args['dataset']
        args['fusion_net']['rsu_num'] = self.rsu_num

        # Lidar backbone created
        geom = args["geometry_param"]
        use_bn = args["use_bn"]
        self.lidar_backbone = BackBone(Bottleneck, [3, 6, 6, 3],
                                       geom,
                                       use_bn)
        # Camera backbone created
        self.camera_backbone = ResnetEncoder(args['camera_encoder'])

        # Down sample the lidar features
        self.downsample_conv = DownsampleConv(args['downsample_conv'])

        # spatial feature transform module
        self.downsample_rate = args['sttf']['downsample_rate']
        self.discrete_ratio = args['sttf']['resolution']
        self.use_roi_mask = args['sttf']['use_roi_mask']
        self.sttf = STTF(args['sttf'])

        # cvm params
        fax_params = args['fax']
        fax_params['backbone_output_shape'] = self.camera_backbone.output_shapes
        self.fax = FAXModule(fax_params)

        # Optional camera-feature quantizer; the attribute only exists when
        # configured (forward looks it up with getattr).
        if 'residual_vq' in args:
            self.residual_vq = ResidualVQ(dim=args['residual_vq']['input_dim'],
                                          accept_image_fmap=args['residual_vq']['accept_image_fmap'],
                                          codebook_size_ls=args['residual_vq']['codebook_size_ls'],
                                          num_quantizers=args['residual_vq']['num_quantizers'])

        self.fusion_net = SwapMultiscaleFusionEncoder(args['fusion_net'])

        # No head is built for other dataset names; forward() then raises
        # ValueError unless return_img_features=True.
        if self.dataset == 'V2XReal':
            self.cls_head = nn.Conv2d(args['head_dim'], 2,
                                      kernel_size=3, padding=1)
            self.reg_head = nn.Conv2d(args['head_dim'], 6,
                                      kernel_size=3, padding=1)
        elif self.dataset == 'V2XSim':
            self.decoder = NaiveDecoder(args['decoder'])
            self.seg_head = nn.Conv2d(args['head_dim'], 4,
                                      kernel_size=3, padding=1)

    def regroup(self, x, record_len):
        """
        Split ``x`` along dim 0 into consecutive chunks of sizes ``record_len``.

        Kept for API compatibility; ``forward`` uses the module-level
        ``regroup`` from ``fuse_utils`` instead.
        """
        cum_sum_len = torch.cumsum(record_len, dim=0)
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return list(split_x)

    def _quantize_img_features(self, img_features):
        """
        Apply the optional residual VQ to the camera features.

        Returns ``(features, cmt_loss)``. When no quantizer is configured the
        features pass through unchanged and ``cmt_loss`` is None. Errors raised
        by a configured quantizer propagate instead of being silently ignored.
        """
        residual_vq = getattr(self, 'residual_vq', None)
        if residual_vq is None:
            return img_features, None
        rec_img_features, _indices, cmt_loss_ls, _sum_dists = residual_vq(img_features)
        return rec_img_features, cmt_loss_ls.sum()

    def _predict_heads(self, outputs, cmt_loss):
        """
        Run the dataset-specific head(s) on the fused features.

        Raises
        ------
        ValueError
            If ``self.dataset`` has no head ('V2XReal' / 'V2XSim' only).
        """
        if self.dataset == 'V2XReal':
            cls = self.cls_head(outputs)
            reg = self.reg_head(outputs)
            return {'cls': cls,
                    'reg': reg,
                    'cmt_loss': cmt_loss}
        if self.dataset == 'V2XSim':
            # The decoder is called with an extra leading dimension, which is
            # removed again afterwards.
            outputs = self.decoder(outputs.unsqueeze(0)).squeeze(0)
            seg = self.seg_head(outputs)
            return {'seg': seg,
                    'cmt_loss': cmt_loss}
        raise ValueError(
            "Unsupported dataset %r; expected 'V2XReal' or 'V2XSim'." % (self.dataset,))

    def forward(self, data_dict, return_img_features=False):
        """
        Run the LiDAR/camera fusion pipeline.

        Parameters
        ----------
        data_dict : dict
            Batch dictionary. Keys read: 'transformation_matrix',
            'processed_lidar' (its 'bev_input'), 'record_len', 'cav_camera',
            'cav_extrinsic', 'cav_intrinsic'.
        return_img_features : bool
            If True, return the detached camera features produced by the FAX
            module and skip the rest of the network.

        Returns
        -------
        torch.Tensor or dict
            Detached camera features when ``return_img_features`` is True;
            otherwise ``{'cls', 'reg', 'cmt_loss'}`` for V2XReal or
            ``{'seg', 'cmt_loss'}`` for V2XSim. ``cmt_loss`` is None when no
            residual VQ is configured.

        Raises
        ------
        ValueError
            If the dataset has no prediction head.
        """
        transformation_matrix = data_dict['transformation_matrix']
        record_len = data_dict['record_len']

        # LiDAR branch.
        bev_input = data_dict['processed_lidar']["bev_input"]
        pts_features = self.lidar_backbone(bev_input)

        # Camera branch.
        img_data = self.camera_backbone(data_dict['cav_camera'])
        img_dict = {'inputs': data_dict['cav_camera'],
                    'extrinsic': data_dict['cav_extrinsic'],
                    'intrinsic': data_dict['cav_intrinsic'],
                    'features': img_data
                    }
        img_features = self.fax(img_dict).squeeze(1)

        if return_img_features:
            return img_features.detach()

        rec_img_features, cmt_loss = self._quantize_img_features(img_features)

        # Downsample lidar features (all intermediate scales are kept).
        _, downsample_pts = self.downsample_conv(pts_features, return_all_feature=True)

        # Per-sample agent counts: cameras on the vehicle slots, LiDAR on the
        # RSU slots (rsu_len == rsu_num for every sample).
        cav_len = record_len - self.rsu_num
        rsu_len = record_len - cav_len
        rec_img_features, _ = regroup(rec_img_features, cav_len,
                                      self.max_cav - self.rsu_num)
        downsample_pts = [regroup(pts_layer, rsu_len, self.rsu_num)[0]
                          for pts_layer in downsample_pts]

        # RSU slots come first in the transformation matrix, vehicles after.
        rsu_tf_matrix = transformation_matrix[:, :self.rsu_num, :, :]
        cav_tf_matrix = transformation_matrix[:, self.rsu_num:, :, :]
        downsample_pts = [self.sttf(pts_layer, rsu_tf_matrix, change_order=True)
                          for pts_layer in downsample_pts]
        rec_img_features = self.sttf(rec_img_features, cav_tf_matrix, change_order=True)

        outputs = self.fusion_net(downsample_pts, rec_img_features)
        return self._predict_heads(outputs, cmt_loss)