# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from einops import rearrange, repeat

from opencood.models.fuse_modules.self_attn import AttFusion
from opencood.models.backbones.resnet_ms import ResnetEncoder
from opencood.models.sub_modules.fax_modules import FAXModule
from opencood.models.fuse_modules.swap_fusion_modules import SwapFusionEncoder
from opencood.models.sub_modules.naive_decoder import NaiveDecoder
from opencood.models.sub_modules.torch_transformation_utils import \
    get_transformation_matrix, warp_affine, get_discretized_transformation_matrix
from opencood.models.fuse_modules.fuse_utils import regroup

import matplotlib.pyplot as plt
import seaborn as sns

class STTF(nn.Module):
    """
    Warp per-agent BEV feature maps with an affine transform derived from
    the given spatial correction matrices.

    Parameters
    ----------
    args : dict
        Must provide 'resolution' (stored as ``discrete_ratio``) and
        'downsample_rate'; both are forwarded to
        ``get_discretized_transformation_matrix``.
    """

    def __init__(self, args):
        super().__init__()
        self.discrete_ratio = args['resolution']
        self.downsample_rate = args['downsample_rate']

    def forward(self, x, spatial_correction_matrix, change_order=True):
        """
        Transform the bev features to ego space.

        Parameters
        ----------
        x : torch.Tensor
            B L C H W
        spatial_correction_matrix : torch.Tensor
            Transformation matrix to ego
        change_order : bool
            When True (default) the last two axes are swapped back after
            warping so the result is laid out like ``x``; when False the
            result is returned with its last two axes still transposed.

        Returns
        -------
        The bev feature same shape as x but with transformation
        """
        affine = get_discretized_transformation_matrix(
            spatial_correction_matrix,
            self.discrete_ratio,
            self.downsample_rate,
        )

        # Swap the two spatial axes and mirror the last one so the affine
        # warp is applied in the orientation the helpers expect.
        swapped = torch.flip(
            rearrange(x, 'b l c h w  -> b l c w h'), dims=(4,))

        # Sizes are read *after* the transpose, i.e. `rows` is the width of
        # the original input and `cols` its height.
        batch, _, channels, rows, cols = swapped.shape
        out_size = (rows, cols)

        # Every agent slot is warped: batch and agent axes are flattened
        # together for the warp and restored afterwards.
        warp = get_transformation_matrix(affine.reshape(-1, 2, 3), out_size)
        warped = warp_affine(
            swapped.reshape(-1, channels, rows, cols), warp, out_size)
        warped = warped.reshape(batch, -1, channels, rows, cols)

        # Undo the mirroring, and optionally the transpose.
        restored = torch.flip(warped, dims=(4,))
        if not change_order:
            return restored
        return rearrange(restored, 'b l c w h -> b l c h w')

class CameraOnly(nn.Module):
    """
    Camera-only cooperative perception model.

    Pipeline (see ``forward``): shared ``ResnetEncoder`` image backbone ->
    ``FAXModule`` -> regrouping of RSU and CAV features -> ``STTF`` warp ->
    ``SwapFusionEncoder`` -> dataset specific decoder and head(s).

    Parameters
    ----------
    args : dict
        Model configuration. Keys read here: 'max_cav', 'rsu_num',
        'dataset', 'camera_encoder', 'fax', 'sttf' (with 'downsample_rate',
        'resolution', 'use_roi_mask') and 'fusion_net'. For dataset
        'V2XReal' also 'det_decoder' and 'head_dim'; for 'V2XSim' also
        'seg_decoder' and 'head_dim'.
        NOTE: ``args['fusion_net']`` and ``args['fax']`` are updated in
        place ('rsu_num' and 'backbone_output_shape' are written to them).
    """

    # Dataset names for which a decoder/head is built and forward() works.
    SUPPORTED_DATASETS = ('V2XReal', 'V2XSim')

    def __init__(self, args):
        super(CameraOnly, self).__init__()
        self.max_cav = args['max_cav']
        self.rsu_num = args['rsu_num']
        self.dataset = args['dataset']
        # The fusion encoder receives the RSU count through its own config.
        args['fusion_net']['rsu_num'] = self.rsu_num

        # Camera backbone created
        self.camera_backbone = ResnetEncoder(args['camera_encoder'])

        # cvm params
        fax_params = args['fax']
        fax_params['backbone_output_shape'] = self.camera_backbone.output_shapes
        self.fax = FAXModule(fax_params)

        # spatial feature transform module
        self.downsample_rate = args['sttf']['downsample_rate']
        self.discrete_ratio = args['sttf']['resolution']
        self.use_roi_mask = args['sttf']['use_roi_mask']
        self.sttf = STTF(args['sttf'])

        # used to downsample the feature map for efficient computation
        self.fusion_net = SwapFusionEncoder(args['fusion_net'])

        # Task specific decoder and prediction head(s). No head is built for
        # any other dataset name; forward() rejects such a model explicitly.
        if self.dataset == 'V2XReal':
            self.det_decoder = NaiveDecoder(args['det_decoder'])
            self.cls_head = nn.Conv2d(args['head_dim'], 2,
                                      kernel_size=3, padding=1)
            self.reg_head = nn.Conv2d(args['head_dim'], 6,
                                      kernel_size=3, padding=1)
        elif self.dataset == 'V2XSim':
            self.seg_decoder = NaiveDecoder(args['seg_decoder'])
            self.seg_head = nn.Conv2d(args['head_dim'], 4,
                                      kernel_size=3, padding=1)

    def regroup(self, x, record_len):
        """
        Split ``x`` along dim 0 into consecutive chunks.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose first dimension concatenates all chunks.
        record_len : torch.Tensor
            1-D tensor with the length of each chunk.

        Returns
        -------
        list of torch.Tensor
            One tensor per entry of ``record_len``.
        """
        cum_sum_len = torch.cumsum(record_len, dim=0)
        # tensor_split takes the split indices as a CPU tensor; the last
        # cumulative value is the end of x and is not a split point.
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return list(split_x)

    def forward(self, data_dict):
        """
        Run the camera pipeline and return the task predictions.

        Parameters
        ----------
        data_dict : dict
            Must contain 'transformation_matrix', 'record_len',
            'cav_camera', 'cav_extrinsic', 'cav_intrinsic', 'rsu_camera',
            'rsu_extrinsic' and 'rsu_intrinsic'.

        Returns
        -------
        dict
            'V2XReal': {'cls', 'reg', 'cmt_loss'}; 'V2XSim':
            {'seg', 'cmt_loss'}. 'cmt_loss' is always None.

        Raises
        ------
        ValueError
            If ``self.dataset`` is not one of ``SUPPORTED_DATASETS``.
        """
        # Fail before any computation: without this check an unsupported
        # dataset would only surface as an UnboundLocalError at the return.
        if self.dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                "Unsupported dataset %r; expected one of %s"
                % (self.dataset, ', '.join(self.SUPPORTED_DATASETS)))

        transformation_matrix = data_dict['transformation_matrix']
        record_len = data_dict['record_len']

        # The same backbone weights encode both CAV and RSU images.
        cav_cam_feature = self.camera_backbone(data_dict['cav_camera'])
        rsu_cam_feature = self.camera_backbone(data_dict['rsu_camera'])

        cav_cam_dict = {
            'inputs': data_dict['cav_camera'],
            'extrinsic': data_dict['cav_extrinsic'],
            'intrinsic': data_dict['cav_intrinsic'],
            'features': cav_cam_feature,
        }
        rsu_cam_dict = {
            'inputs': data_dict['rsu_camera'],
            'extrinsic': data_dict['rsu_extrinsic'],
            'intrinsic': data_dict['rsu_intrinsic'],
            'features': rsu_cam_feature,
        }

        cav_cam_feature = self.fax(cav_cam_dict).squeeze(1)
        rsu_cam_feature = self.fax(rsu_cam_dict).squeeze(1)

        # Regroup camera features.
        # This expression algebraically equals rsu_num for every entry while
        # keeping the shape, dtype and device of record_len.
        rsu_len = record_len - (record_len - self.rsu_num)
        # The remaining agents are treated as CAVs; presumably record_len
        # counts RSUs + CAVs per sample — TODO confirm against the data loader.
        cav_len = record_len - self.rsu_num
        # The imported regroup also returns a mask, which is not used here.
        rsu_cam_feature, _ = regroup(rsu_cam_feature, rsu_len, self.rsu_num)
        cav_cam_feature, _ = regroup(cav_cam_feature, cav_len,
                                     self.max_cav - self.rsu_num)
        # RSU features come first along the agent axis.
        cam_feature = torch.cat([rsu_cam_feature, cav_cam_feature], dim=1)

        # STTF
        cam_feature = self.sttf(cam_feature, transformation_matrix,
                                change_order=True)

        # Fusion net
        outputs = self.fusion_net(cam_feature)

        # Both decoders are fed with an extra leading axis that is removed
        # again right after decoding.
        if self.dataset == 'V2XReal':
            decoder = self.det_decoder
        else:
            decoder = self.seg_decoder
        outputs = decoder(outputs.unsqueeze(0)).squeeze(0)

        if self.dataset == 'V2XReal':
            return {'cls': self.cls_head(outputs),
                    'reg': self.reg_head(outputs),
                    'cmt_loss': None}

        return {'seg': self.seg_head(outputs),
                'cmt_loss': None}