import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from basic import BiLinear

offset_map = {
    1024: -3.2041,
    2048: -3.4025,
    4096: -3.5836
}

class Conv1d(nn.Module):
    def __init__(self, inplane, outplane, Linear):
        super().__init__()
        self.lin = Linear(inplane, outplane)

    def forward(self, x):
        B, C, N = x.shape
        x = x.permute(0, 2, 1).contiguous().view(-1, C)
        x = self.lin(x).view(B, N, -1).permute(0, 2, 1).contiguous()
        return x


class EmaMaxPool(nn.Module):
    def __init__(self, kernel_size, affine=True, Linear=BiLinear, use_bn=True):
        super(EmaMaxPool, self).__init__()
        self.kernel_size = kernel_size
        self.bn3 = nn.BatchNorm1d(1024, affine=affine)
        self.use_bn = use_bn

    def forward(self, x):
        batchsize, D, N = x.size()
        if self.use_bn:
            x = torch.max(x, 2, keepdim=True)[0] + offset_map[N]
        else:
            x = torch.max(x, 2, keepdim=True)[0] - 0.3
        return x

class BiPointNetCls(nn.Module):
    def __init__(self, output_classes, input_dims=3, conv1_dim=64,
                 use_transform=True, Linear=BiLinear):
        super(BiPointNetCls, self).__init__()
        self.input_dims = input_dims
        self.conv1 = nn.ModuleList()
        self.conv1.append(Conv1d(input_dims, conv1_dim, Linear=Linear))
        self.conv1.append(Conv1d(conv1_dim, conv1_dim, Linear=Linear))
        self.conv1.append(Conv1d(conv1_dim, conv1_dim, Linear=Linear))

        self.bn1 = nn.ModuleList()
        self.bn1.append(nn.BatchNorm1d(conv1_dim))
        self.bn1.append(nn.BatchNorm1d(conv1_dim))
        self.bn1.append(nn.BatchNorm1d(conv1_dim))

        self.conv2 = nn.ModuleList()
        self.conv2.append(Conv1d(conv1_dim, conv1_dim * 2, Linear=Linear))
        self.conv2.append(Conv1d(conv1_dim * 2, conv1_dim * 16, Linear=Linear))

        self.bn2 = nn.ModuleList()
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 2))
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 16))

        self.maxpool = EmaMaxPool(conv1_dim * 16, Linear=Linear, use_bn=True)
        self.pool_feat_len = conv1_dim * 16

        self.mlp3 = nn.ModuleList()
        self.mlp3.append(Linear(conv1_dim * 16, conv1_dim * 8))
        self.mlp3.append(Linear(conv1_dim * 8, conv1_dim * 4))

        self.bn3 = nn.ModuleList()
        self.bn3.append(nn.BatchNorm1d(conv1_dim * 8))
        self.bn3.append(nn.BatchNorm1d(conv1_dim * 4))

        self.dropout = nn.Dropout(0.3)
        self.mlp_out = Linear(conv1_dim * 4, output_classes)

        self.use_transform = use_transform
        if use_transform:
            self.transform1 = TransformNet(input_dims)
            self.trans_bn1 = nn.BatchNorm1d(input_dims)
            self.transform2 = TransformNet(conv1_dim)
            self.trans_bn2 = nn.BatchNorm1d(conv1_dim)

    def forward(self, x):
        batch_size = x.shape[0]
        h = x.permute(0, 2, 1)
        if self.use_transform:
            trans = self.transform1(h)
            h = h.transpose(2, 1)
            h = torch.bmm(h, trans)
            h = h.transpose(2, 1)
            h = F.relu(self.trans_bn1(h))

        for conv, bn in zip(self.conv1, self.bn1):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)

        if self.use_transform:
            trans = self.transform2(h)
            h = h.transpose(2, 1)
            h = torch.bmm(h, trans)
            h = h.transpose(2, 1)
            h = F.relu(self.trans_bn2(h))

        for conv, bn in zip(self.conv2, self.bn2):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)

        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp3, self.bn3):
            h = mlp(h)
            h = bn(h)
            h = F.relu(h)

        h = self.dropout(h)
        out = self.mlp_out(h)
        return out

class TransformNet(nn.Module):
    def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear):
        super(TransformNet, self).__init__()
        self.conv = nn.ModuleList()
        self.conv.append(Conv1d(input_dims, conv1_dim, Linear=Linear))
        self.conv.append(Conv1d(conv1_dim, conv1_dim * 2, Linear=Linear))
        self.conv.append(Conv1d(conv1_dim * 2, conv1_dim * 16, Linear=Linear))

        self.bn = nn.ModuleList()
        self.bn.append(nn.BatchNorm1d(conv1_dim))
        self.bn.append(nn.BatchNorm1d(conv1_dim * 2))
        self.bn.append(nn.BatchNorm1d(conv1_dim * 16))

        # self.maxpool = nn.MaxPool1d(conv1_dim * 16)
        self.maxpool = EmaMaxPool(conv1_dim * 16, Linear=Linear, use_bn=True)
        self.pool_feat_len = conv1_dim * 16

        self.mlp2 = nn.ModuleList()
        self.mlp2.append(Linear(conv1_dim * 16, conv1_dim * 8))
        self.mlp2.append(Linear(conv1_dim * 8, conv1_dim * 4))

        self.bn2 = nn.ModuleList()
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 8))
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 4))

        self.input_dims = input_dims
        self.mlp_out = Linear(conv1_dim * 4, input_dims * input_dims)

    def forward(self, h):
        batch_size = h.shape[0]
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        
        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp2, self.bn2):
            h = mlp(h)
            h = bn(h)
            h = F.relu(h)

        out = self.mlp_out(h)

        iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32)))
        iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1)
        if out.is_cuda:
            iden = iden.cuda()
        out = out + iden
        out = out.view(-1, self.input_dims, self.input_dims)
        return out
