import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pcdet.utils import common_utils

def conv3x3(in_planes, out_planes, pad, dilation, stride=1):
    """Build a 3x3 ``nn.Conv2d`` with a bias term.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        pad: zero-padding added on each spatial side.
        dilation: spacing between kernel taps.
        stride: convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=pad,
        dilation=dilation,
        bias=True,
    )
    return layer

def conv1x1(in_planes, out_planes, stride=1):
    """Build a 1x1 ``nn.Conv2d`` with no padding and no bias.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                     padding=0, bias=False)

class domain_disc(nn.Module):
    """Per-location domain discriminator made of three bias-free 1x1 convs.

    Channels go ``C -> C -> C // 2 -> 1`` with a ReLU after the first two
    convolutions; the last conv emits one raw logit per spatial location.
    """

    def __init__(self, channels_number):
        super(domain_disc, self).__init__()
        half = channels_number // 2
        self.conv1 = nn.Conv2d(channels_number, channels_number, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(channels_number, half, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.conv3 = nn.Conv2d(half, 1, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self._init_weights()

    def _init_weights(self):
        """Overwrite every conv weight with samples from N(0, 0.01^2)."""
        def normal_init(m, mean, stddev, truncated=False):
            """Fill ``m.weight`` in place from a (optionally truncated) normal.

            The truncated variant is approximate: a standard normal is folded
            into (-2, 2) with ``fmod`` before scaling and shifting.
            """
            weight = m.weight.data
            if not truncated:
                weight.normal_(mean, stddev)
                return
            weight.normal_().fmod_(2).mul_(stddev).add_(mean)

        # Order matters for reproducibility under a fixed seed.
        for layer in (self.conv1, self.conv2, self.conv3):
            normal_init(layer, 0, 0.01)

    def forward(self, x):
        """Return raw domain logits with a single output channel."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        return self.conv3(hidden)

class SEBlock(nn.Module):
    """Squeeze-and-excitation block (channel attention).

    Each channel is global-average-pooled over H and W. The resulting
    ``(B, C)`` vector goes through a bottleneck MLP
    (``C -> C // r -> C``, ReLU then Sigmoid), and the input channels are
    rescaled by the resulting per-channel gates.

    Args:
        channels: number of input channels ``C``.
        r: bottleneck reduction ratio. The hidden width is clamped to at
            least 1 so that ``channels < r`` does not create an empty layer.
    """

    def __init__(self, channels, r=16):
        super(SEBlock, self).__init__()
        self.r = r
        # Without the clamp, channels < r gives a 0-wide hidden layer and the
        # gate degenerates to sigmoid(bias), independent of the input.
        hidden = max(1, channels // self.r)
        self.squeeze = nn.Sequential(nn.Linear(channels, hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, channels),
                                     nn.Sigmoid())

    def forward(self, x):
        """Scale each channel of ``x`` (shape ``(B, C, H, W)``) by its gate."""
        b, c, _, _ = x.size()
        gate = self.squeeze(torch.mean(x, dim=(2, 3))).view(b, c, 1, 1)
        return torch.mul(x, gate)

class domain_disc_3D(nn.Module):
    """Attention-augmented convolutional domain discriminator.

    There are three stages. Each stage first combines two attentions on its
    input:
      * a spatial term, ``x * (max_over_channels(x) + 1)``;
      * an ``SEBlock`` channel term.
    It then applies strided 2x2 convolutions (padding 2), each followed by
    LeakyReLU(0.2):
      * ``block_1``: two convs, ``channels_number`` -> ``channels_number // 4``;
      * ``block_2``: two convs at ``channels_number // 4``;
      * ``block_3``: one conv at ``channels_number // 4``.
    A bias-free 1x1 conv (``conv_cls``) produces a single-channel logit map.
    """

    def __init__(self, channels_number):
        super(domain_disc_3D, self).__init__()
        c_in = channels_number
        c_out = channels_number // 4
        self.c_out = c_out

        def down(src, dst):
            # One strided 2x2 conv followed by an in-place LeakyReLU.
            return [nn.Conv2d(src, dst, kernel_size=2, stride=2, padding=2),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)]

        # Submodules are created in this exact order on purpose: it fixes the
        # state_dict key names and how a seeded RNG is consumed at init.
        self.se_s1 = SEBlock(c_in)
        self.block_1 = nn.Sequential(*down(c_in, c_out), *down(c_out, c_out))
        self.se_s2 = SEBlock(c_out)
        self.block_2 = nn.Sequential(*down(c_out, c_out), *down(c_out, c_out))
        self.se_s3 = SEBlock(c_out)
        self.block_3 = nn.Sequential(*down(c_out, c_out))
        # NOTE(review): ``mlp`` is never used in forward(); it stays registered,
        # presumably so previously saved checkpoints still load.
        self.mlp = nn.Linear(c_out, 1)
        self.conv_cls = conv1x1(c_out, 1)

    def _attend(self, feat, se_block):
        """Return spatial attention plus SE channel attention applied to ``feat``."""
        # Max over the channel axis, kept as a singleton channel for broadcasting.
        spatial = feat.max(dim=1, keepdim=True)[0]
        return feat * (spatial + 1) + se_block(feat)

    def forward(self, x):
        """Return the raw single-channel domain-logit map for ``x``."""
        stage_1 = self.block_1(self._attend(x, self.se_s1))
        stage_2 = self.block_2(self._attend(stage_1, self.se_s2))
        stage_3 = self.block_3(self._attend(stage_2, self.se_s3))
        return self.conv_cls(stage_3)

class DDCPP(nn.Module):
    """Densely connected dilated convolutions + pyramid pooling discriminator.

    Pipeline (channel counts in parentheses):
      1. A 3x3 conv reduces the input to 256 channels (``x_r``).
      2. Three densely connected 3x3 dilated convs (dilation 2, 4, 8). Each
         one sees the concatenation of all earlier outputs (256, 512, 768
         input channels) and emits 256 channels.
      3. A 1x1 "post" conv fuses the 1024-channel concatenation to 512.
      4. Pyramid pooling: the 512-channel map is average-pooled with kernels
         (H, W), (H//2, W//2), (H//4, W//4) and (H//8, W//8). Each branch is
         projected to 128 channels by a 1x1 conv and bilinearly resized back
         to (H, W).
      5. The post map and the four branches (512 + 4 * 128 = 1024 channels)
         are fused by a 1x1 conv to 512 channels. A final 1x1 conv then emits
         one raw logit per location.

    Every conv except the final classifier is followed by BatchNorm and ReLU.
    """

    def __init__(self, input_channel):
        super(DDCPP, self).__init__()
        self.reduced_conv = conv3x3(input_channel, 256, 1, 1)
        self.reduced_bn = nn.BatchNorm2d(256)
        self.ddc_x1 = conv3x3(256, 256, 2, 2)
        self.ddc_bn1 = nn.BatchNorm2d(256)
        self.ddc_x2 = conv3x3(512, 256, 4, 4)
        self.ddc_bn2 = nn.BatchNorm2d(256)
        self.ddc_x3 = conv3x3(768, 256, 8, 8)
        self.ddc_bn3 = nn.BatchNorm2d(256)
        self.post_conv = conv1x1(1024, 512)
        self.post_bn = nn.BatchNorm2d(512)
        self.pool1_conv = conv1x1(512, 128)
        self.pool1_bn = nn.BatchNorm2d(128)
        self.pool2_conv = conv1x1(512, 128)
        self.pool2_bn = nn.BatchNorm2d(128)
        self.pool3_conv = conv1x1(512, 128)
        self.pool3_bn = nn.BatchNorm2d(128)
        self.pool4_conv = conv1x1(512, 128)
        self.pool4_bn = nn.BatchNorm2d(128)
        self.conv2 = conv1x1(1024, 512)
        self.bn2 = nn.BatchNorm2d(512)

        self.conv_cls = conv1x1(512, 1)

    def _pool_branch(self, feat, divisor, conv, bn):
        """Compute one pyramid level: pool ``feat``, then 1x1 conv + BN + ReLU.

        The average pool uses a ``(H // divisor, W // divisor)`` kernel, with
        stride equal to the kernel.
        """
        # Clamp to 1: a zero-sized kernel (map smaller than ``divisor``)
        # would make avg_pool2d raise.
        kernel = (max(1, feat.size(2) // divisor),
                  max(1, feat.size(3) // divisor))
        return F.relu(bn(conv(F.avg_pool2d(feat, kernel))))

    def forward(self, x):
        """Return the raw single-channel domain-logit map for ``x``."""
        # Reduction to 256 channels.
        x_r = F.relu(self.reduced_bn(self.reduced_conv(x)))
        # Densely connected dilated convs: each sees all earlier outputs.
        x1 = F.relu(self.ddc_bn1(self.ddc_x1(x_r)))
        x1_c = torch.cat((x_r, x1), 1)
        x2 = F.relu(self.ddc_bn2(self.ddc_x2(x1_c)))
        x2_c = torch.cat((x1_c, x2), 1)
        x3 = F.relu(self.ddc_bn3(self.ddc_x3(x2_c)))
        # x2_c already holds [x_r, x1, x2]; reuse it instead of re-concatenating.
        x3_p = torch.cat((x2_c, x3), 1)
        x_post = F.relu(self.post_bn(self.post_conv(x3_p)))

        # Pyramid pooling at four scales, each resized back to x_post's size.
        out_size = (x_post.size(2), x_post.size(3))
        levels = (
            (1, self.pool1_conv, self.pool1_bn),
            (2, self.pool2_conv, self.pool2_bn),
            (4, self.pool3_conv, self.pool3_bn),
            (8, self.pool4_conv, self.pool4_bn),
        )
        upsampled = [
            # align_corners=False is what the deprecated F.upsample used by default.
            F.interpolate(self._pool_branch(x_post, divisor, conv, bn),
                          size=out_size, mode='bilinear', align_corners=False)
            for divisor, conv, bn in levels
        ]
        # Channel order (post, level 4, 3, 2, 1) is preserved because conv2's
        # weights depend on it.
        x_c = torch.cat((x_post, upsampled[3], upsampled[2],
                         upsampled[1], upsampled[0]), 1)

        # Domain classifier.
        x_p = F.relu(self.bn2(self.conv2(x_c)))
        return self.conv_cls(x_p)

def bce_loss(y_pred, y_label):
    """Binary cross-entropy between logits and a constant label.

    Args:
        y_pred: raw (pre-sigmoid) logits of any shape.
        y_label: scalar label broadcast to every element of ``y_pred``.

    Returns:
        Scalar tensor: the mean BCE-with-logits loss over all elements.
    """
    # full_like keeps y_pred's dtype and device. The old
    # FloatTensor(...).to(y_pred.get_device()) crashed on CPU tensors
    # (get_device() returns -1 there) and always forced float32.
    target = torch.full_like(y_pred, float(y_label))
    return F.binary_cross_entropy_with_logits(y_pred, target)

class DENSE_Spatial_DISC(nn.Module):
    """Adversarial domain discriminator applied to ``spatial_features``.

    ``forward`` does the following:
      * splits the batch into two halves with
        ``common_utils.split_batch_dict`` (keyed on ``SOURCE_ONE_NAME``);
      * passes the features through a gradient-reversal layer
        (``common_utils.grad_reverse`` scaled by ``WEIGHT``);
      * runs the configured discriminator on each half.

    ``get_loss`` trains the discriminator to label the first half 0 and the
    second half 1.

    Config keys read from ``model_cfg``:
      * ``SOURCE_ONE_NAME``, ``WEIGHT``, ``INPUT_CONV_CHANNEL``;
      * ``D_TYPE``: 'DDCPP' | 'NORMAL' | '3D_DISC';
      * ``LOSS_TYPE``: 'MEAN' | 'BCE'.
    """

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg
        self.source_one_name = self.model_cfg.SOURCE_ONE_NAME
        self.weight = self.model_cfg.WEIGHT

        d_type = self.model_cfg.D_TYPE
        if d_type == 'DDCPP':
            self.domain_disc = DDCPP(self.model_cfg.INPUT_CONV_CHANNEL)
        elif d_type == 'NORMAL':
            self.domain_disc = domain_disc(self.model_cfg.INPUT_CONV_CHANNEL)
        elif d_type == '3D_DISC':
            self.domain_disc = domain_disc_3D(self.model_cfg.INPUT_CONV_CHANNEL)
        else:
            # Fail fast. Otherwise ``domain_disc`` stays unset and only
            # surfaces later as an AttributeError inside forward().
            raise ValueError(
                f"Unknown D_TYPE {d_type!r}; expected 'DDCPP', 'NORMAL' or '3D_DISC'")

        self.forward_ret_dict = {}

    def get_loss(self):
        """Compute the discriminator loss from the last ``forward`` call.

        Returns:
            ``(dloss, tb_dict)``: ``dloss`` is a scalar tensor and
            ``tb_dict['d_loss']`` is its Python float value. If the last
            forward() skipped the discriminator (its two halves had different
            batch sizes), ``dloss`` is a zero tensor.

        Raises:
            ValueError: if ``LOSS_TYPE`` is neither 'MEAN' nor 'BCE'.
        """
        tb_dict = {}
        loss_type = self.model_cfg.LOSS_TYPE
        if loss_type not in ('MEAN', 'BCE'):
            raise ValueError(
                f"Unknown LOSS_TYPE {loss_type!r}; expected 'MEAN' or 'BCE'")

        preds = self.forward_ret_dict
        if 'pred_domain_s1' not in preds or 'pred_domain_s2' not in preds:
            # Nothing was predicted this step, so contribute no adversarial loss.
            device = next(self.domain_disc.parameters()).device
            dloss = torch.zeros((), device=device)
            tb_dict['d_loss'] = 0.0
            return dloss, tb_dict

        logits_s1 = preds['pred_domain_s1']
        logits_s2 = preds['pred_domain_s2']

        if loss_type == 'MEAN':
            # Least-squares style: push sigmoid(s1) toward 0 and sigmoid(s2) toward 1.
            prob_s1 = torch.sigmoid(logits_s1)
            prob_s2 = torch.sigmoid(logits_s2)
            dloss_s1 = torch.mean(prob_s1 ** 2) * 0.5
            dloss_s2 = torch.mean((1 - prob_s2) ** 2) * 0.5
        else:  # 'BCE'
            source_label = 0
            target_label = 1
            dloss_s1 = bce_loss(logits_s1, source_label)
            dloss_s2 = bce_loss(logits_s2, target_label)
        dloss = dloss_s1 + dloss_s2

        tb_dict['d_loss'] = dloss.item()
        return dloss, tb_dict

    def forward(self, data_dict):
        """Run the discriminator on both batch halves and cache the logits.

        Logits are stored in ``self.forward_ret_dict`` for ``get_loss``;
        ``data_dict`` is returned unchanged.
        """
        # Drop predictions from any previous step so that get_loss() never
        # reuses stale tensors when this step returns early below.
        self.forward_ret_dict = {}

        split_tag_s1, split_tag_s2 = common_utils.split_batch_dict(self.source_one_name, data_dict)
        spatial_features = data_dict['spatial_features']

        # Gradient reversal layer (GRL).
        spatial_features = common_utils.grad_reverse(spatial_features, weight=self.weight)

        spatial_features_s1 = spatial_features[split_tag_s1, :, :, :]
        spatial_features_s2 = spatial_features[split_tag_s2, :, :, :]

        # Once the first source dataset has been exhausted, the two halves no
        # longer have the same batch size; skip the discriminator this step.
        if spatial_features_s1.shape[0] != spatial_features_s2.shape[0]:
            return data_dict

        # The reversed gradients push the shared features toward being dataset-agnostic.
        self.forward_ret_dict['pred_domain_s1'] = self.domain_disc(spatial_features_s1)
        self.forward_ret_dict['pred_domain_s2'] = self.domain_disc(spatial_features_s2)

        return data_dict