import numpy as np
import torch
import torch.nn as nn
import random
from ...utils import dual_norm
from ...utils.loss_utils import FocalLoss
from ...utils.lovasz_softmax import lovasz_softmax_flat

class BaseBEVBackboneOccBev(nn.Module):
    """2D BEV backbone with an additional dense 16-class prediction head.

    The backbone is a stack of strided conv blocks (``self.blocks``) whose
    outputs are each resampled by a matching entry of ``self.deblocks`` and
    concatenated along the channel axis.  On top of that feature map three
    stride-2 transposed convolutions (``inv_conv4/3/2``) upsample by 8x in
    total and a 1x1 conv (``logit``) produces 16 per-pixel class logits.

    ``forward`` stores flattened predictions/targets in
    ``self.forward_re_dict``; ``get_loss`` consumes them.
    """

    def __init__(self, model_cfg, input_channels):
        """Build the backbone and the occupancy head.

        Args:
            model_cfg: config object supporting both ``.get(key, default)``
                and attribute access.  Keys read here: ``DUAL_NORM`` /
                ``db_source`` (optional), ``LAYER_NUMS``, ``LAYER_STRIDES``,
                ``NUM_FILTERS``, ``UPSAMPLE_STRIDES``,
                ``NUM_UPSAMPLE_FILTERS``.
            input_channels: number of channels of ``spatial_features``.
        """
        super().__init__()
        self.model_cfg = model_cfg
        if self.model_cfg.get('DUAL_NORM', None):
            self.db_source = int(self.model_cfg.db_source)

        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums, layer_strides, num_filters = [], [], []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides, num_upsample_filters = [], []

        num_levels = len(layer_nums)
        # Input width of each level: the raw input, then the previous level's output.
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()

        for idx in range(num_levels):
            # Strided entry conv (explicit zero padding) followed by
            # `layer_nums[idx]` same-resolution conv-BN-ReLU triples.
            cur_layers = [
                nn.ZeroPad2d(1),
                nn.Conv2d(
                    c_in_list[idx], num_filters[idx], kernel_size=3,
                    stride=layer_strides[idx], padding=0, bias=False
                ),
                nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                nn.ReLU()
            ]
            for _ in range(layer_nums[idx]):
                cur_layers.extend([
                    nn.Conv2d(num_filters[idx], num_filters[idx], kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ])
            self.blocks.append(nn.Sequential(*cur_layers))

            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    # Upsample (or keep resolution) with a transposed conv.
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # Fractional stride means downsampling by 1/stride.
                    # Builtin int() is used because the `np.int` alias was
                    # removed in NumPy 1.24.
                    stride = int(np.round(1 / stride))
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        if len(upsample_strides) > num_levels:
            # One extra stride entry: a final deblock applied to the concatenated map.
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

        # Occupancy head.  Each stage below doubles the spatial size
        # (kernel 3, stride 2, padding 1, output_padding 1).
        # NOTE(review): the head's input width is hard-coded to 512, so the
        # config must yield sum(NUM_UPSAMPLE_FILTERS) == 512 for forward() to
        # work - confirm against the configs in use.
        # e.g. [188, 188] -> [376, 376]
        self.inv_conv4 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, padding=1, output_padding=1, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        # e.g. [376, 376] -> [752, 752]
        self.inv_conv3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, padding=1, output_padding=1, stride=2, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        # e.g. [752, 752] -> [1504, 1504]
        self.inv_conv2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, padding=1, output_padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # 1x1 conv producing 16 class logits per pixel.
        self.logit = nn.Conv2d(64, 16, 1, bias=True)
        # Per-class weights for the cross-entropy term; the last class
        # (index 15) is almost ignored.
        self.criterion1 = nn.CrossEntropyLoss(weight=torch.tensor([1,2,2,1,2,1,1,1,2,2,1,1,1,1,1,0.01]))
        # Filled by forward() with 'pred' and 'target' for get_loss().
        self.forward_re_dict = {}

    def get_loss(self, tb_dict=None):
        """Compute the occupancy loss from the last ``forward`` call.

        The loss is weighted cross-entropy plus the Lovasz-softmax term, both
        evaluated on the flattened ``pred`` / ``target`` stored in
        ``self.forward_re_dict``.

        Args:
            tb_dict: optional dict of values to log; its entries are kept
                and ``'loss_occ'`` is added.  The caller's dict is not mutated.

        Returns:
            Tuple ``(loss, tb_dict)`` where ``loss`` is a scalar tensor and
            ``tb_dict`` contains ``'loss_occ'`` as a Python float.
        """
        # Copy so that entries supplied by the caller survive and the
        # caller's own dict is left untouched.
        tb_dict = {} if tb_dict is None else dict(tb_dict)
        pred = self.forward_re_dict['pred']
        target = self.forward_re_dict['target']

        # `pred` is (N, num_classes); softmax over the class axis.  Passing
        # dim explicitly matches the implicit choice for 2-D input and
        # avoids the deprecation warning.
        probas = nn.functional.softmax(pred, dim=1)
        loss = self.criterion1(pred, target) + lovasz_softmax_flat(probas=probas, labels=target)

        tb_dict['loss_occ'] = loss.item()

        return loss, tb_dict

    def forward(self, data_dict):
        """Run the backbone and the occupancy head.

        Args:
            data_dict: dict that must contain
                ``'spatial_features'``: input feature map, indexed as
                    (batch, channels, H, W) by the conv layers;
                ``'occ_gt'``: 3-D ground-truth label map whose last two axes
                    are swapped before flattening (presumably stored as
                    (batch, W, H) - TODO confirm against the dataset code).

        Returns:
            The same ``data_dict`` with added keys
            ``'spatial_features_2d'`` (concatenated backbone features),
            ``'pred'`` (logits flattened to (N, 16)) and
            ``'target'`` (labels flattened to (N,), dtype long).
        """
        spatial_features = data_dict['spatial_features']
        ups = []
        x = spatial_features
        has_deblocks = len(self.deblocks) > 0
        for i, block in enumerate(self.blocks):
            x = block(x)
            # Bring every level to a common resolution before concatenation.
            ups.append(self.deblocks[i](x) if has_deblocks else x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            # Optional trailing deblock applied to the fused map.
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x

        # Occupancy head: three 2x upsampling stages, then per-pixel logits.
        x_up4 = self.inv_conv4(x)
        x_up3 = self.inv_conv3(x_up4)
        x_up2 = self.inv_conv2(x_up3)

        output = self.logit(x_up2)

        # Flatten to per-pixel rows so the losses can treat every pixel as
        # one sample: target -> (N,), pred -> (N, num_classes).
        self.forward_re_dict['target'] = data_dict['occ_gt'].permute(0,2,1).flatten().long()
        self.forward_re_dict['pred'] = torch.flatten(output.permute(0,2,3,1), start_dim=0, end_dim=-2)

        data_dict.update({'pred': self.forward_re_dict['pred'],
                            'target': self.forward_re_dict['target']})

        return data_dict
