# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from einops import rearrange, repeat

from opencood.models.backbones.resnet_ms import ResnetEncoder
from opencood.models.sub_modules.fax_modules import FAXModule
from opencood.models.fuse_modules.swap_multiscale_fusion_modules import SwapMultiscaleFusionEncoder
from opencood.models.sub_modules.naive_decoder import NaiveDecoder
from opencood.models.sub_modules.torch_transformation_utils import \
    get_transformation_matrix, warp_affine, get_discretized_transformation_matrix
from opencood.models.backbones.pixor import Bottleneck, BackBone, conv3x3
from opencood.models.fuse_modules.fuse_utils import regroup
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.comm_modules.residual_vq import ResidualVQ

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

class STTF(nn.Module):
    """
    Warp BEV feature maps with per-agent affine transformations.

    Parameters
    ----------
    args : dict
        Must contain ``'resolution'`` and ``'downsample_rate'``. Both values
        are stored and later handed to
        ``get_discretized_transformation_matrix``.
    """

    def __init__(self, args):
        super().__init__()
        self.discrete_ratio = args['resolution']
        self.downsample_rate = args['downsample_rate']

    def forward(self, x, spatial_correction_matrix, change_order=True):
        """
        Transform the bev features to ego space.

        Parameters
        ----------
        x : torch.Tensor
            B L C H W
        spatial_correction_matrix : torch.Tensor
            Transformation matrix to ego
        change_order : bool
            If True (default) the result is transposed back to ``h w``
            order before returning. If False the last two axes are left in
            the swapped ``w h`` layout used internally.

        Returns
        -------
        The bev feature same shape as x but with transformation
        """
        discrete_matrix = get_discretized_transformation_matrix(
            spatial_correction_matrix,
            self.discrete_ratio,
            self.downsample_rate,
        )

        # Swap the two spatial axes and mirror the last one so that the
        # affine warp is applied in the orientation it expects.
        swapped = torch.flip(
            rearrange(x, 'b l c h w  -> b l c w h'), dims=(4,))

        # NOTE: rows/cols describe the *swapped* layout from here on.
        batch, _, channels, rows, cols = swapped.shape

        # Every (batch, agent) slot is warped: both tensors are flattened
        # over their leading axes before being passed to the helpers.
        affine = get_transformation_matrix(
            discrete_matrix.reshape(-1, 2, 3), (rows, cols))
        warped = warp_affine(
            swapped.reshape(-1, channels, rows, cols), affine, (rows, cols))
        warped = warped.reshape(batch, -1, channels, rows, cols)

        # Undo the mirror, and optionally undo the axis swap as well.
        restored = torch.flip(warped, dims=(4,))
        if not change_order:
            return restored
        return rearrange(restored, 'b l c w h -> b l c h w')

class BevFusion(nn.Module):
    """
    Fuse LiDAR and camera features and apply dataset-specific heads.

    The forward pass runs a LiDAR backbone and a camera backbone, sends the
    camera features through ``FAXModule`` and ``ResidualVQ``, warps both
    streams with ``STTF``, concatenates them along the channel axis, fuses
    them with a conv-BN-ReLU block and finally applies the head(s) selected
    by ``args['dataset']``.

    Parameters
    ----------
    args : dict
        Model configuration. Keys read by the constructor:
        ``max_cav``, ``rsu_num``, ``dataset``, ``fusion_dim``,
        ``geometry_param``, ``use_bn``, ``camera_encoder``,
        ``downsample_conv``, ``sttf`` (with ``downsample_rate``,
        ``resolution``, ``use_roi_mask``), ``fax``, ``residual_vq`` (with
        ``input_dim``, ``accept_image_fmap``, ``codebook_size_ls``,
        ``num_quantizers``), ``img_decoder``, ``head_dim`` and, for
        ``dataset == 'V2XSim'`` only, ``decoder``.
    """

    def __init__(self, args):
        super().__init__()
        self.max_cav = args['max_cav']
        self.rsu_num = args['rsu_num']
        self.dataset = args['dataset']
        self.fusion_dim = args['fusion_dim']

        # Lidar backbone created
        geom = args["geometry_param"]
        use_bn = args["use_bn"]
        self.lidar_backbone = BackBone(Bottleneck, [3, 6, 6, 3],
                                       geom,
                                       use_bn)
        # Camera backbone created
        self.camera_backbone = ResnetEncoder(args['camera_encoder'])

        # Down sample the lidar features
        # NOTE(review): this module is built here but forward() in this
        # class never calls it.
        self.downsample_conv = DownsampleConv(args['downsample_conv'])

        # spatial feature transform module
        self.downsample_rate = args['sttf']['downsample_rate']
        self.discrete_ratio = args['sttf']['resolution']
        self.use_roi_mask = args['sttf']['use_roi_mask']
        self.sttf = STTF(args['sttf'])

        # FAX module params
        # NOTE: this writes 'backbone_output_shape' into the caller's
        # args['fax'] dict (no copy is made).
        fax_params = args['fax']
        fax_params['backbone_output_shape'] = \
            self.camera_backbone.output_shapes
        self.fax = FAXModule(fax_params)

        vq_args = args['residual_vq']
        self.residual_vq = ResidualVQ(
            dim=vq_args['input_dim'],
            accept_image_fmap=vq_args['accept_image_fmap'],
            codebook_size_ls=vq_args['codebook_size_ls'],
            num_quantizers=vq_args['num_quantizers'])

        self.img_decoder = NaiveDecoder(args['img_decoder'])

        # Fusion over the channel-concatenated features of all agents.
        # assumes the concatenated input has fusion_dim * max_cav channels
        # -- TODO confirm against the backbone/decoder output widths.
        self.fusion_net = nn.Sequential(
            nn.Conv2d(self.fusion_dim * self.max_cav, self.fusion_dim, 3,
                      padding=1, bias=False),
            nn.BatchNorm2d(self.fusion_dim),
            nn.ReLU(True))

        # Task heads. Any other dataset name creates no head here;
        # forward() raises ValueError for it once it reaches the heads.
        if self.dataset == 'V2XReal':
            self.cls_head = nn.Conv2d(args['head_dim'], 2,
                                      kernel_size=3, padding=1)
            self.reg_head = nn.Conv2d(args['head_dim'], 6,
                                      kernel_size=3, padding=1)
        elif self.dataset == 'V2XSim':
            self.decoder = NaiveDecoder(args['decoder'])
            self.seg_head = nn.Conv2d(args['head_dim'], 4,
                                      kernel_size=3, padding=1)

    def regroup(self, x, record_len):
        """
        Split ``x`` along dim 0 into consecutive chunks.

        Parameters
        ----------
        x : torch.Tensor
            Tensor to split along its first dimension.
        record_len : torch.Tensor
            1-D tensor of chunk lengths; the split points are its
            cumulative sums (the last one is dropped).

        Returns
        -------
        list of torch.Tensor
            One chunk per entry of ``record_len``.
        """
        cum_sum_len = torch.cumsum(record_len, dim=0)
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return list(split_x)

    def forward(self, data_dict, return_img_features=False):
        """
        Run the fusion model on one batch.

        Parameters
        ----------
        data_dict : dict
            Keys read: ``'transformation_matrix'``,
            ``'processed_lidar'['bev_input']``, ``'record_len'``,
            ``'cav_camera'``, ``'cav_extrinsic'`` and ``'cav_intrinsic'``.
        return_img_features : bool
            If True, return the detached output of the FAX module
            (after ``squeeze(1)``) and skip everything downstream.

        Returns
        -------
        torch.Tensor or dict
            The detached image features when ``return_img_features`` is
            True. Otherwise a dict with ``'cls'``, ``'reg'`` and
            ``'cmt_loss'`` for ``'V2XReal'``, or ``'seg'`` and
            ``'cmt_loss'`` for ``'V2XSim'``. ``'cmt_loss'`` is ``None``
            when the residual quantizer raised and was bypassed.

        Raises
        ------
        ValueError
            If ``self.dataset`` is neither ``'V2XReal'`` nor ``'V2XSim'``.
        """
        # Handling lidar data
        transformation_matrix = data_dict['transformation_matrix']
        bev_input = data_dict['processed_lidar']["bev_input"]
        record_len = data_dict['record_len']

        pts_features = self.lidar_backbone(bev_input)

        # Handling camera data
        img_data = self.camera_backbone(data_dict['cav_camera'])
        img_dict = {'inputs': data_dict['cav_camera'],
                    'extrinsic': data_dict['cav_extrinsic'],
                    'intrinsic': data_dict['cav_intrinsic'],
                    'features': img_data
                    }
        img_features = self.fax(img_dict)
        img_features = img_features.squeeze(1)

        if return_img_features:
            return img_features.detach()

        # Best-effort quantization: on failure fall back to the raw image
        # features. Only Exception is caught so KeyboardInterrupt and
        # SystemExit still propagate, and the failure is reported instead
        # of being silently discarded.
        try:
            rec_img_features, _, cmt_loss_ls, _ = \
                self.residual_vq(img_features)
            cmt_loss = cmt_loss_ls.sum()
        except Exception as err:
            import warnings
            warnings.warn(
                f"ResidualVQ failed ({type(err).__name__}: {err}); "
                "using unquantized image features and cmt_loss=None.",
                RuntimeWarning)
            rec_img_features = img_features
            cmt_loss = None

        # Regroup camera and lidar features (module-level regroup helper;
        # the mask it returns is not used here).
        rec_img_features, _ = regroup(rec_img_features,
                                      record_len - self.rsu_num,
                                      self.max_cav - self.rsu_num)
        # record_len - (record_len - rsu_num) is a tensor shaped like
        # record_len whose every entry equals rsu_num.
        pts_features, _ = regroup(pts_features,
                                  record_len - (record_len - self.rsu_num),
                                  self.rsu_num)

        # Split the transformation matrix: the first rsu_num slots along
        # dim 1 go with the lidar stream, the remainder with the camera
        # stream.
        rsu_tf_matrix = transformation_matrix[:, :self.rsu_num, :, :]
        cav_tf_matrix = transformation_matrix[:, self.rsu_num:, :, :]
        pts_features = self.sttf(pts_features, rsu_tf_matrix,
                                 change_order=True)
        rec_img_features = self.sttf(rec_img_features, cav_tf_matrix,
                                     change_order=True)

        # (B, N, C, H, W)
        rec_img_features = self.img_decoder(rec_img_features)

        # Fold the agent axis into channels, then stack both modalities.
        pts_features = rearrange(pts_features, 'b l c h w -> b (l c) h w')
        rec_img_features = rearrange(rec_img_features,
                                     'b l c h w -> b (l c) h w')
        features = torch.cat([pts_features, rec_img_features], dim=1)
        outputs = self.fusion_net(features)

        if self.dataset == 'V2XReal':
            return {'cls': self.cls_head(outputs),
                    'reg': self.reg_head(outputs),
                    'cmt_loss': cmt_loss}

        if self.dataset == 'V2XSim':
            # The decoder is called on a tensor with an extra leading
            # axis, which is removed again afterwards.
            outputs = outputs.unsqueeze(0)
            outputs = self.decoder(outputs)
            outputs = outputs.squeeze(0)

            return {'seg': self.seg_head(outputs),
                    'cmt_loss': cmt_loss}

        raise ValueError(
            f"Unsupported dataset {self.dataset!r}; "
            "expected 'V2XReal' or 'V2XSim'.")