import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, zeros_

def stn(x, theta, padding_mode='zeros'):
    """Warp ``x`` with a batch of affine transforms (spatial transformer).

    Args:
        x: Input tensor. Its size is handed to ``F.affine_grid`` as the
            output size, so the result has the same shape as ``x``.
        theta: Batch of affine matrices as accepted by ``F.affine_grid``
            (``(N, 2, 3)`` for 4-D input, ``(N, 3, 4)`` for 5-D input).
        padding_mode: Forwarded unchanged to ``F.grid_sample``.

    Returns:
        ``x`` resampled on the grid described by ``theta``.
    """
    # align_corners=False is the value PyTorch (>= 1.3) falls back to when the
    # argument is omitted; passing it explicitly keeps the numerical result
    # identical while avoiding the UserWarning emitted on every call.
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode=padding_mode, align_corners=False)

class Encoder3D(nn.Module):
    """Lift a 2-D feature map into a 3-D volume with one transposed 3-D conv.

    The channel axis of the input is split into ``in_channel`` channels plus
    a depth axis; the resulting volume goes through a ``ConvTranspose3d``
    (kernel 4, stride 2, padding 1) followed by a leaky ReLU.
    """

    def __init__(self, in_channel, hidden_channel):
        """Create the encoder.

        Args:
            in_channel: Number of channels of the reshaped 3-D volume.
            hidden_channel: Number of output channels of the transposed conv.
        """
        super().__init__()
        self.in_channel = in_channel
        self.conv3d_1 = nn.ConvTranspose3d(
            in_channels=in_channel,
            out_channels=hidden_channel,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, feat):
        """Encode a ``(B, C, H, W)`` feature map into a 5-D voxel tensor.

        ``C`` is folded into ``(in_channel, C // in_channel)`` before the
        transposed convolution is applied.
        """
        batch, _, height, width = feat.shape
        volume = feat.reshape(batch, self.in_channel, -1, height, width)
        return F.leaky_relu(self.conv3d_1(volume))

class Rotate(nn.Module):
    """Apply an affine warp to a voxel code, optionally refined by a 3-D conv.

    The warp itself is delegated to ``stn`` with ``'zeros'`` padding. When
    ``learn`` is truthy, the warped code is additionally passed through a
    3x3x3 ``Conv3d`` (same channel count in and out) and a leaky ReLU.
    """

    def __init__(self, learn, in_channel):
        """Create the module.

        Args:
            learn: Whether to add the learnable refinement convolution.
            in_channel: Channel count of the voxel code (used only when
                ``learn`` is truthy).
        """
        super().__init__()
        self.padding_mode = 'zeros'
        self.learn = learn
        if learn:
            self.conv3d_1 = nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1)

    def forward(self, code, theta):
        """Warp ``code`` by ``theta`` and, if enabled, refine the result."""
        warped = stn(code, theta, self.padding_mode)
        if not self.learn:
            return warped
        return F.leaky_relu(self.conv3d_1(warped))

class Relate3D(nn.Module):
    """Relate query features to support features in a shared 3-D voxel space.

    Query and support feature maps are lifted to voxels with ``Encoder3D``,
    support voxels are warped with ``Rotate`` using the supplied trajectory
    and averaged over the support views, and the channel-interleaved
    (query, support) voxels are reduced by ``self.relate`` back to a 2-D
    feature map.

    ``support_param`` keys read here: ``'in_channel'``, ``'hidden_3d'``,
    ``'learn_rot'``.
    """

    def __init__(self, support_param):
        super(Relate3D, self).__init__()
        # 3d mapping
        self.in_channel = support_param['in_channel']
        self.hidden_3d = support_param['hidden_3d']
        self.encode3d = Encoder3D(self.in_channel, self.hidden_3d)
        # rotate
        self.rot_learn = support_param['learn_rot']
        self.rotate = Rotate(support_param['learn_rot'], self.hidden_3d)
        # 3d relation: the grouped conv sees one (query, support) channel
        # pair per group because of the interleaving done in forward().
        self.relate = nn.Sequential(
            nn.Conv3d(self.hidden_3d * 2, self.hidden_3d, kernel_size=3,
                      groups=self.hidden_3d, padding=1),
            nn.BatchNorm3d(self.hidden_3d),
            nn.ReLU(),
            nn.AvgPool3d(2, stride=2),
            nn.Conv3d(self.hidden_3d, self.in_channel, kernel_size=1),
            nn.BatchNorm3d(self.in_channel),
            nn.ReLU(),
        )

    def forward(self, query, support, support_traj=None, return_support=False):
        """Compute relation features between queries and the support set.

        Args:
            query: Query features; reshaped to ``(B, P, C, w, h)``.
            support: With ``support_traj`` given, 2-D support features of
                shape ``(B, N, C, w, h)``. With ``support_traj=None``, an
                already-built support volume unpacked as
                ``(B, N, _, D, H, W)``.
            support_traj: Optional affine parameters; ``support_traj[:, :N]``
                is reshaped to ``(B*N, 3, 4)``. ``None`` selects the
                pre-computed-volume path.
            return_support: When true, also return the aggregated support
                volume.

        Returns:
            ``rela_feat`` of shape ``(B*P, C, w, h)``, or the tuple
            ``(rela_feat, support_volume)`` when ``return_support`` is true.
        """
        if support_traj is not None:
            B, N, C, w, h = support.shape
            query = query.reshape(B, -1, C, w, h)
            P = query.shape[1]
            # 3d mapping
            voxel_s = self.encode3d(support.flatten(0, 1))
            D, H, W = voxel_s.shape[2], voxel_s.shape[3], voxel_s.shape[4]
            # 3d transform
            theta_s = support_traj[:, :N].reshape(B * N, 3, 4)
            rot_voxels_support = self.rotate(voxel_s, theta_s).view(B, N, -1, D, H, W)
            # Average the N warped support views: (B, 1, C', D, H, W).
            support_volume = rot_voxels_support.mean(dim=1, keepdim=True)
        else:
            B, N, _, D, H, W = support.shape
            C, w, h = query.shape[1], query.shape[2], query.shape[3]
            query = query.reshape(B, -1, C, w, h)
            P = query.shape[1]
            # The caller supplied the volume directly, so it is also what is
            # handed back for return_support (previously this path raised
            # UnboundLocalError when return_support was requested).
            # NOTE(review): the repeat below only lines up with the B*P query
            # batch when N == 1 -- presumably the volume returned by an
            # earlier return_support=True call; confirm against callers.
            support_volume = support
        # One copy of the support volume per query: (B*P, C', D, H, W).
        voxels_support = support_volume.repeat(1, P, 1, 1, 1, 1).flatten(0, 1)
        # voxel query
        voxel_q = self.encode3d(query.flatten(0, 1))
        D, H, W = voxel_q.shape[2], voxel_q.shape[3], voxel_q.shape[4]
        if self.rot_learn:
            # Identity [I | 0] transform so queries pass through the same
            # learnable rotate head; built directly with the query voxels'
            # dtype and device instead of on CPU/float32 followed by a copy.
            theta_q = torch.eye(3, 4, dtype=voxel_q.dtype, device=voxel_q.device)
            theta_q = theta_q.unsqueeze(0).repeat(P * B, 1, 1)
            voxels_query = self.rotate(voxel_q, theta_q).view(B * P, -1, D, H, W)
        else:
            voxels_query = voxel_q
        # 3d relation: interleave channels as (q1, s1, q2, s2, ...).
        sta_feat = torch.stack([voxels_query, voxels_support], dim=2).flatten(1, 2)
        rela_feat = self.relate(sta_feat)
        rela_feat = rela_feat.reshape(B * P, C, w, h)
        if return_support:
            return rela_feat, support_volume
        return rela_feat

class Relate3DMix(nn.Module):
    """Fuse a 3-D voxel relation branch with a 2-D relation branch.

    The 3-D branch lifts query/support features to voxels, warps the support
    voxels with the given trajectory and relates them to the query voxels.
    The 2-D branch relates the raw query features to the mean support
    feature. Both results are channel-interleaved and merged by a 1x1 conv.

    ``support_param`` keys read here: ``'in_channel3d'``, ``'hidden_3d'``,
    ``'learn_rot'``, ``'in_channel2d'``, ``'hidden_2d'``.
    """

    def __init__(self, support_param):
        super(Relate3DMix, self).__init__()
        # 3d mapping
        self.in_channel = support_param['in_channel3d']
        self.hidden_3d = support_param['hidden_3d']
        self.encode3d = Encoder3D(self.in_channel, self.hidden_3d)
        # rotate
        self.rot_learn = support_param['learn_rot']
        self.rotate = Rotate(support_param['learn_rot'], self.hidden_3d)
        # 3d relation
        self.relate3d = nn.Sequential(
            nn.Conv3d(self.hidden_3d * 2, self.hidden_3d, kernel_size=3,
                      groups=self.hidden_3d, padding=1),
            nn.BatchNorm3d(self.hidden_3d),
            nn.ReLU(),
            nn.AvgPool3d(2, stride=2),
            nn.Conv3d(self.hidden_3d, self.in_channel, kernel_size=1),
            nn.BatchNorm3d(self.in_channel),
            nn.ReLU(),
        )
        # 2d relation
        self.in_channel_2d = support_param['in_channel2d']
        self.hidden_2d = support_param['hidden_2d']
        self.relate2d = nn.Sequential(
            nn.Conv2d(self.in_channel_2d * 2, self.hidden_2d, kernel_size=3,
                      groups=self.in_channel_2d, padding=1),
            nn.BatchNorm2d(self.hidden_2d),
            nn.ReLU(),
            nn.Conv2d(self.hidden_2d, self.hidden_2d, kernel_size=1),
            nn.BatchNorm2d(self.hidden_2d),
            nn.ReLU(),
        )
        # relation fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(self.in_channel_2d * 2, self.hidden_2d, kernel_size=1),
            nn.BatchNorm2d(self.hidden_2d),
            nn.ReLU(),
        )

    def forward(self, query, support, support_traj):
        """Compute fused 2-D/3-D relation features.

        Args:
            query: Query features; reshaped to ``(B, P, C, w, h)``.
            support: Support features of shape ``(B, N, C, w, h)``.
            support_traj: Affine parameters; ``support_traj[:, :N]`` is
                reshaped to ``(B*N, 3, 4)``.

        Returns:
            The output of ``self.fusion`` applied to the interleaved 3-D and
            2-D relation features.
        """
        B, N, C, w, h = support.shape
        query = query.reshape(B, -1, C, w, h)
        P = query.shape[1]
        # 3d mapping
        voxel_s = self.encode3d(support.flatten(0, 1))
        D, H, W = voxel_s.shape[2], voxel_s.shape[3], voxel_s.shape[4]
        voxel_q = self.encode3d(query.flatten(0, 1))
        # 3d transform
        theta_s = support_traj[:, :N].reshape(B * N, 3, 4)
        rot_voxels_support = self.rotate(voxel_s, theta_s).view(B, N, -1, D, H, W)
        # Mean over the N support views, one copy per query: (B, P, C', D, H, W).
        voxels_support = rot_voxels_support.mean(dim=1, keepdim=True).repeat(1, P, 1, 1, 1, 1)
        if self.rot_learn:
            # Identity [I | 0] transform so queries pass through the same
            # learnable rotate head; built directly with the query voxels'
            # dtype and device instead of on CPU/float32 followed by a copy.
            theta_q = torch.eye(3, 4, dtype=voxel_q.dtype, device=voxel_q.device)
            theta_q = theta_q.unsqueeze(0).repeat(P * B, 1, 1)
            voxels_query = self.rotate(voxel_q, theta_q).view(B * P, -1, D, H, W)
        else:
            voxels_query = voxel_q
        # 3d relation: interleave channels as (q1, s1, q2, s2, ...).
        s = voxels_support.flatten(0, 1)
        q = voxels_query
        sta_feat = torch.stack([q, s], dim=2).flatten(1, 2)
        rela_feat_3d = self.relate3d(sta_feat)
        rela_feat_3d = rela_feat_3d.reshape(B * P, C, w, h)
        # 2d relation: query is already (B, P, C, w, h) from the reshape above.
        s = support.mean(1, keepdim=True)
        s = s.repeat(1, P, 1, 1, 1).flatten(0, 1)
        q = query.flatten(0, 1)
        sta_feat = torch.stack([q, s], dim=2).flatten(1, 2)
        rela_feat_2d = self.relate2d(sta_feat)
        # Mixture fusion. torch.stack needs both branches to have the same
        # shape, so hidden_2d has to equal C for this to work.
        sta_feat = torch.stack([rela_feat_3d, rela_feat_2d], dim=2).flatten(1, 2)
        rela_feat = self.fusion(sta_feat)
        return rela_feat

class Relate3DMixS(nn.Module):
    """Compute separate 2-D and 3-D relation features (no fusion layer).

    The 3-D branch lifts query/support features to voxels, warps the support
    voxels with the given trajectory and relates them to the query voxels.
    The 2-D branch relates the raw query features to the mean support
    feature. The two results are returned side by side.

    ``support_param`` keys read here: ``'in_channel3d'``, ``'hidden_3d'``,
    ``'learn_rot'``, ``'in_channel2d'``, ``'hidden_2d'``.
    """

    def __init__(self, support_param):
        super(Relate3DMixS, self).__init__()
        # 3d mapping
        self.in_channel = support_param['in_channel3d']
        self.hidden_3d = support_param['hidden_3d']
        self.encode3d = Encoder3D(self.in_channel, self.hidden_3d)
        # rotate
        self.rot_learn = support_param['learn_rot']
        self.rotate = Rotate(support_param['learn_rot'], self.hidden_3d)
        # 3d relation
        self.relate3d = nn.Sequential(
            nn.Conv3d(self.hidden_3d * 2, self.hidden_3d, kernel_size=3,
                      groups=self.hidden_3d, padding=1),
            nn.BatchNorm3d(self.hidden_3d),
            nn.ReLU(),
            nn.AvgPool3d(2, stride=2),
            nn.Conv3d(self.hidden_3d, self.in_channel, kernel_size=1),
            nn.BatchNorm3d(self.in_channel),
            nn.ReLU(),
        )
        # 2d relation (note: groups is in_channel_2d // 2 in this variant)
        self.in_channel_2d = support_param['in_channel2d']
        self.hidden_2d = support_param['hidden_2d']
        self.relate2d = nn.Sequential(
            nn.Conv2d(self.in_channel_2d * 2, self.hidden_2d, kernel_size=3,
                      groups=self.in_channel_2d // 2, padding=1),
            nn.BatchNorm2d(self.hidden_2d),
            nn.ReLU(),
            nn.Conv2d(self.hidden_2d, self.hidden_2d, kernel_size=1),
            nn.BatchNorm2d(self.hidden_2d),
            nn.ReLU(),
        )

    def forward(self, query, support, support_traj):
        """Compute the 2-D and 3-D relation features.

        Args:
            query: Query features; reshaped to ``(B, P, C, w, h)``.
            support: Support features of shape ``(B, N, C, w, h)``.
            support_traj: Affine parameters; ``support_traj[:, :N]`` is
                reshaped to ``(B*N, 3, 4)``.

        Returns:
            Tuple ``(rela_feat_2d, rela_feat_3d)``; the 3-D feature is
            reshaped to ``(B*P, C, w, h)``.
        """
        B, N, C, w, h = support.shape
        query = query.reshape(B, -1, C, w, h)
        P = query.shape[1]
        # 3d mapping
        voxel_s = self.encode3d(support.flatten(0, 1))
        D, H, W = voxel_s.shape[2], voxel_s.shape[3], voxel_s.shape[4]
        voxel_q = self.encode3d(query.flatten(0, 1))
        # 3d transform
        theta_s = support_traj[:, :N].reshape(B * N, 3, 4)
        rot_voxels_support = self.rotate(voxel_s, theta_s).view(B, N, -1, D, H, W)
        # Mean over the N support views, one copy per query: (B, P, C', D, H, W).
        voxels_support = rot_voxels_support.mean(dim=1, keepdim=True).repeat(1, P, 1, 1, 1, 1)
        if self.rot_learn:
            # Identity [I | 0] transform so queries pass through the same
            # learnable rotate head; built directly with the query voxels'
            # dtype and device instead of on CPU/float32 followed by a copy.
            theta_q = torch.eye(3, 4, dtype=voxel_q.dtype, device=voxel_q.device)
            theta_q = theta_q.unsqueeze(0).repeat(P * B, 1, 1)
            voxels_query = self.rotate(voxel_q, theta_q).view(B * P, -1, D, H, W)
        else:
            voxels_query = voxel_q
        # 3d relation: interleave channels as (q1, s1, q2, s2, ...).
        s = voxels_support.flatten(0, 1)
        q = voxels_query
        sta_feat = torch.stack([q, s], dim=2).flatten(1, 2)
        rela_feat_3d = self.relate3d(sta_feat)
        rela_feat_3d = rela_feat_3d.reshape(B * P, C, w, h)
        # 2d relation: query is already (B, P, C, w, h) from the reshape above.
        s = support.mean(1, keepdim=True)
        s = s.repeat(1, P, 1, 1, 1).flatten(0, 1)
        q = query.flatten(0, 1)
        sta_feat = torch.stack([q, s], dim=2).flatten(1, 2)
        rela_feat_2d = self.relate2d(sta_feat)
        return rela_feat_2d, rela_feat_3d


