# -*- coding: utf-8 -*-

import math

import torch.nn as nn

from opencood.models.fuse_modules.self_attn import AttFusion
from opencood.models.backbones.pixor import Bottleneck, BackBone, conv3x3
from opencood.models.sub_modules.naive_decoder import NaiveDecoder

import matplotlib.pyplot as plt
import seaborn as sns

class BackBoneIntermediate(BackBone):
    """
    PIXOR backbone with attention-based intermediate fusion.

    The inherited ``encode`` yields three feature maps (c3, c4, c5). Each map
    is passed, together with ``record_len``, through its own ``AttFusion``
    module (presumably fusing features across agents -- confirm against
    AttFusion). The three fused maps are then handed to the inherited
    ``decode``.
    """

    def __init__(self, block, num_block, geom, use_bn=True):
        super().__init__(block, num_block, geom, use_bn)

        # One fusion module per pyramid level. They are registered in
        # c3 -> c5 order, which fixes the state_dict key order.
        # NOTE(review): the c3 level is configured with 196 while the others
        # use 256/384; confirm this matches the channel count BackBone
        # produces for c3.
        self.fusion_net3 = AttFusion(196)
        self.fusion_net4 = AttFusion(256)
        self.fusion_net5 = AttFusion(384)

    def forward(self, x, record_len):
        """
        Encode ``x``, fuse every level with ``record_len``, then decode.

        Parameters
        ----------
        x : tensor
            Input forwarded unchanged to ``self.encode``.
        record_len : tensor or list
            Forwarded unchanged to every fusion module.

        Returns
        -------
        Whatever ``self.decode`` returns for the fused (c3, c4, c5).
        """
        c3, c4, c5 = self.encode(x)

        fused_c5 = self.fusion_net5(c5, record_len)
        fused_c4 = self.fusion_net4(c4, record_len)
        fused_c3 = self.fusion_net3(c3, record_len)

        return self.decode(fused_c3, fused_c4, fused_c5)

class Header(nn.Module):
    """
    Prediction head shared by the detection and segmentation variants.

    A trunk of four conv3x3 layers refines the feature map. When
    ``use_bn`` is true, each conv is followed by BatchNorm. The output layer
    then depends on ``args["dataset"]``:

    * ``'V2XReal'`` -- ``cls_head`` (2 output channels) and ``reg_head``
      (6 output channels).
    * ``'V2XSim'``  -- a ``NaiveDecoder`` followed by a 4-channel
      ``seg_head``.

    Any other dataset tag builds no output layer, and ``forward`` returns
    None.

    Parameters
    ----------
    args : dict
        Reads ``"use_bn"``, ``"feature_dim"`` and ``"dataset"``. For
        ``'V2XSim'`` it also reads ``"decoder"`` and ``"seg_dim"``.
    """

    def __init__(self, args):
        super().__init__()

        self.use_bn = args["use_bn"]
        # BatchNorm provides its own affine shift, so the convs only carry a
        # bias when BN is disabled.
        conv_bias = not self.use_bn
        channels = args["feature_dim"]
        self.dataset = args["dataset"]

        # Registers conv1, bn1, conv2, bn2, ... in exactly that order, so
        # the module iteration order and state_dict keys are the same as with
        # explicit attribute assignments.
        for idx in range(1, 5):
            setattr(self, "conv%d" % idx,
                    conv3x3(channels, channels, bias=conv_bias))
            setattr(self, "bn%d" % idx, nn.BatchNorm2d(channels))

        if self.dataset == 'V2XReal':
            self.cls_head = conv3x3(channels, 2, bias=True)
            self.reg_head = conv3x3(channels, 6, bias=True)
        elif self.dataset == 'V2XSim':
            self.decoder = NaiveDecoder(args['decoder'])
            self.seg_head = nn.Conv2d(args['seg_dim'], 4,
                                      kernel_size=3, padding=1)

    def forward(self, x):
        """
        Apply the conv trunk, then the dataset-specific output layer.

        Returns ``(cls, reg)`` for ``'V2XReal'``, ``seg`` for ``'V2XSim'``,
        and None for any other dataset tag.
        """
        # NOTE(review): no activation is applied between the trunk convs;
        # confirm this is intended.
        for idx in range(1, 5):
            x = getattr(self, "conv%d" % idx)(x)
            if self.use_bn:
                x = getattr(self, "bn%d" % idx)(x)

        if self.dataset == 'V2XReal':
            return self.cls_head(x), self.reg_head(x)

        if self.dataset == 'V2XSim':
            # The decoder is fed an extra leading dim, which is dropped again
            # before the segmentation conv.
            decoded = self.decoder(x.unsqueeze(0)).squeeze(0)
            return self.seg_head(decoded)

        return None

class PIXORIntermediate(nn.Module):
    """
    PIXOR detector with intermediate (feature-level) fusion.

    The input of the PIXOR nn module is described as a tensor of
    [batch_size, height, width, channel], and its output as a tensor of
    [batch_size, height/4, width/4, channel]. Note that the dimensions are
    converted to [C, H, W] for PyTorch's nn.Conv2d functions.

    The head that is built, and the output format, depend on
    ``args["dataset"]``:

    * ``'V2XReal'`` -- detection: classification and regression maps.
    * ``'V2XSim'``  -- segmentation map.

    Parameters
    ----------
    args : dict
        The arguments of the model. Reads ``"geometry_param"`` and
        ``"dataset"``. ``args["header"]`` is passed to ``Header``, and its
        ``"use_bn"`` entry is also forwarded to the backbone.

    Attributes
    ----------
    backbone : opencood.object
        The backbone used to extract and fuse features.
    header : opencood.object
        Header used to predict the classification and coordinates (or the
        segmentation map).
    dataset : str
        Dataset tag selecting the output format.
    """

    def __init__(self, args):
        super(PIXORIntermediate, self).__init__()
        geom = args["geometry_param"]
        use_bn = args["header"]["use_bn"]
        self.dataset = args["dataset"]
        self.backbone = BackBoneIntermediate(Bottleneck, [3, 6, 6, 3],
                                             geom,
                                             use_bn)
        self.header = Header(args["header"])

        # Generic init: normal(0, sqrt(2 / (k_h * k_w * out_channels))) for
        # every conv weight; weight=1, bias=0 for every BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # The output layers are then overwritten.
        # NOTE(review): the prior term -log((1 - p) / p) is written into the
        # *weights*, with a zero bias (the RetinaNet-style recipe puts it in
        # the bias). Kept unchanged so existing training runs reproduce;
        # confirm whether this is intended.
        prior = 0.01
        if self.dataset == 'V2XReal':
            self.header.cls_head.weight.data.fill_(-math.log((1.0 - prior) / prior))
            self.header.cls_head.bias.data.fill_(0)
            self.header.reg_head.weight.data.fill_(0)
            self.header.reg_head.bias.data.fill_(0)
        elif self.dataset == 'V2XSim':
            self.header.seg_head.weight.data.fill_(-math.log((1.0 - prior) / prior))
            self.header.seg_head.bias.data.fill_(0)

    def forward(self, data_dict):
        """
        Run backbone, fusion and head on one batch.

        Parameters
        ----------
        data_dict : dict
            Must provide ``data_dict['processed_lidar']['bev_input']`` and
            ``data_dict['record_len']``. The latter is passed to the fusion
            modules (presumably the number of agents per sample -- confirm).

        Returns
        -------
        dict
            ``{"cls", "reg", "cmt_loss"}`` for ``'V2XReal'``, or
            ``{'seg', 'cmt_loss'}`` for ``'V2XSim'``. ``cmt_loss`` is always
            None here.

        Raises
        ------
        ValueError
            If ``self.dataset`` is neither ``'V2XReal'`` nor ``'V2XSim'``.
            Previously this surfaced as an UnboundLocalError on
            ``output_dict``.
        """
        bev_input = data_dict['processed_lidar']["bev_input"]
        record_len = data_dict['record_len']

        features = self.backbone(bev_input, record_len)

        if self.dataset == 'V2XReal':
            # cls -- (N, 2, H', W'): cls_head is built as conv3x3(dim, 2)
            # reg -- (N, 6, H', W'): reg_head is built as conv3x3(dim, 6)
            cls, reg = self.header(features)
            return {
                "cls": cls,
                "reg": reg,
                "cmt_loss": None
            }

        if self.dataset == 'V2XSim':
            seg = self.header(features)
            return {'seg': seg,
                    'cmt_loss': None}

        raise ValueError(
            "Unsupported dataset %r; expected 'V2XReal' or 'V2XSim'."
            % (self.dataset,))
