import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    The distance is evaluated directly from broadcast coordinate differences:
        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2
    This materialises an intermediate tensor of shape [B, N, M, C].

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # [B, N, 1, C] - [B, 1, M, C] broadcasts to [B, N, M, C].
    offsets = src.unsqueeze(2) - dst.unsqueeze(1)
    return offsets.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Gather rows of ``points`` along the point axis, independently per batch.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    out_shape = idx.shape
    channels = points.shape[-1]

    # Collapse every index axis after the batch axis so one gather suffices.
    flat_idx = idx.reshape(out_shape[0], -1)                       # [B, S*K]
    # gather needs one index per output element, so repeat across channels.
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, channels)   # [B, S*K, C]
    picked = points.gather(1, gather_idx)

    # Restore the caller's index layout and append the channel axis.
    return picked.reshape(*out_shape, -1)

class Para_Estimator(nn.Module):
    """
    Three-layer perceptron: in_dim -> hidden_dim -> hidden_dim -> out_dim.

    ReLU follows the first two linear layers; the final layer is left linear.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        stack = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.encoder = nn.Sequential(*stack)

    def forward(self, feature):
        """
        Input:
            feature: tensor whose last axis has size in_dim
        Return:
            tensor with the same leading axes and a last axis of size out_dim
        """
        return self.encoder(feature)

class StructureTransformer(nn.Module):
    """
    Dense point-transformer layer with per-channel (vector) attention.

    Every point attends to every point of the same cloud. Relative coordinates
    are embedded by ``fc_delta`` and added to both the attention logits and the
    values. ``fc2`` projects the result back to ``d_points`` channels and the
    input features are added as a residual.
    """

    def __init__(self, d_points, d_model) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_points, d_model)
        self.fc2 = nn.Linear(d_model, d_points)
        # Positional encoder: relative xyz offset -> d_model.
        self.fc_delta = nn.Sequential(
            nn.Linear(3, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # Maps (query - key + position) to per-channel attention logits.
        self.fc_gamma = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = nn.Linear(d_model, d_model, bias=False)

    def forward(self, xyz, features):
        """
        Input:
            xyz: point coordinates, [B, N, 3]
            features: per-point features, [B, N, d_points]
        Return:
            res: updated features, [B, N, d_points]
            attn: attention weights, [B, N, N, d_model], normalised over the
                  attended-point axis (dim 2)
        """
        residual = features
        hidden = self.fc1(features)
        query = self.w_qs(hidden)
        key = self.w_ks(hidden)
        value = self.w_vs(hidden)

        rel_pos = xyz.unsqueeze(2) - xyz.unsqueeze(1)        # b x n x n x 3
        pos_enc = self.fc_delta(rel_pos)                     # b x n x n x f

        logits = self.fc_gamma(query.unsqueeze(2) - key.unsqueeze(1) + pos_enc)
        scale = np.sqrt(key.size(-1))
        attn = F.softmax(logits / scale, dim=-2)             # b x n x n x f

        # value is [b, n, f]; unsqueeze(1) lets it broadcast over the query axis.
        attended = value.unsqueeze(1) + pos_enc              # b x n x n x f
        res = torch.einsum('bmnf,bmnf->bmf', attn, attended)  # b x n x f
        res = self.fc2(res) + residual
        return res, attn
    
class Structure_Encoder(nn.Module):
    """
    Encodes a labelled point cloud into one global feature vector.

    Pipeline: embed the scalar label of each point, refine the embeddings with
    a dense StructureTransformer over the xyz coordinates, collapse the points
    with learned softmax weights, and map the pooled vector through an MLP head.
    """

    def __init__(self, d_Semantic=32, d_Attention=64, d_Hidden=128):
        super().__init__()
        # Lifts the scalar label of each point to d_Semantic channels.
        self.fc_semantic = nn.Sequential(
            nn.Linear(1, d_Semantic),
            nn.ReLU(),
            nn.Linear(d_Semantic, d_Semantic),
        )
        # d_Attention is the transformer's internal width; it outputs d_Semantic.
        self.PT = StructureTransformer(d_Semantic, d_Attention)

        # Scores each point with a single scalar used for weighted pooling.
        self.attn_pool = nn.Sequential(
            nn.Linear(d_Semantic, d_Attention),
            nn.Tanh(),
            nn.Linear(d_Attention, 1),
        )

        self.fc_head = nn.Sequential(
            nn.Linear(d_Semantic, d_Hidden),
            nn.ReLU(),
            nn.Linear(d_Hidden, d_Hidden * 2),
            nn.ReLU(),
            nn.Linear(d_Hidden * 2, d_Hidden),
        )

    def forward(self, x):
        """
        x: [B, N, 4], where each point is (x, y, z, semantic_label)
        Returns the global feature, [B, d_Hidden].
        """
        coords = x[:, :, 0:3]                          # [B, N, 3]
        labels = x[:, :, 3:4]                          # [B, N, 1]

        embedded = self.fc_semantic(labels)            # [B, N, d_Semantic]
        refined, _ = self.PT(coords, embedded)         # [B, N, d_Semantic]

        # Attention-based pooling: softmax over the point axis.
        scores = self.attn_pool(refined)               # [B, N, 1]
        scores = F.softmax(scores, dim=1)              # [B, N, 1]
        pooled = (scores * refined).sum(dim=1)         # [B, d_Semantic]

        return self.fc_head(pooled)                    # [B, d_Hidden]
        
class TransformerBlock(nn.Module):
    """
    Point-transformer layer with per-channel attention over k nearest neighbours.

    For each point, the ``k`` closest points (by squared distance on ``xyz``)
    are selected; keys, values and the positional encoding are restricted to
    that neighbourhood. ``fc2`` projects back to ``d_points`` channels and the
    input features are added as a residual.
    """

    def __init__(self, d_points, d_model, k) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_points, d_model)
        self.fc2 = nn.Linear(d_model, d_points)
        # Positional encoder: relative xyz offset -> d_model.
        self.fc_delta = nn.Sequential(
            nn.Linear(3, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # Maps (query - key + position) to per-channel attention logits.
        self.fc_gamma = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = nn.Linear(d_model, d_model, bias=False)
        self.k = k

    def forward(self, xyz, features):
        """
        Input:
            xyz: point coordinates, [B, N, 3]
            features: per-point features, [B, N, d_points]
        Return:
            res: updated features, [B, N, d_points]
            attn: attention weights, [B, N, k, d_model], normalised over the
                  neighbour axis (dim 2)
        """
        # Rank all points by distance and keep the first k per query point.
        # The slice yields fewer than k neighbours when N < k.
        pairwise = square_distance(xyz, xyz)
        knn_idx = pairwise.argsort()[:, :, :self.k]          # b x n x k
        neighbour_xyz = index_points(xyz, knn_idx)           # b x n x k x 3

        residual = features
        hidden = self.fc1(features)
        query = self.w_qs(hidden)                            # b x n x f
        key = index_points(self.w_ks(hidden), knn_idx)       # b x n x k x f
        value = index_points(self.w_vs(hidden), knn_idx)     # b x n x k x f

        pos_enc = self.fc_delta(xyz.unsqueeze(2) - neighbour_xyz)  # b x n x k x f

        logits = self.fc_gamma(query.unsqueeze(2) - key + pos_enc)
        scale = np.sqrt(key.size(-1))
        attn = F.softmax(logits / scale, dim=-2)             # b x n x k x f

        res = torch.einsum('bmnf,bmnf->bmf', attn, value + pos_enc)
        res = self.fc2(res) + residual
        return res, attn
    
class PT_Structure_Encoder(nn.Module):
    """
    Encodes a labelled point cloud into one global feature vector using a
    k-nearest-neighbour TransformerBlock (k fixed to 4).

    Pipeline: embed the scalar label of each point, refine the embeddings with
    local attention over the xyz coordinates, collapse the points with learned
    softmax weights, and map the pooled vector through an MLP head.
    """

    def __init__(self, d_Semantic=32, d_Attention=64, d_Hidden=128):
        super().__init__()
        # Lifts the scalar label of each point to d_Semantic channels.
        self.fc_semantic = nn.Sequential(
            nn.Linear(1, d_Semantic),
            nn.ReLU(),
            nn.Linear(d_Semantic, d_Semantic),
        )
        # d_Attention is the block's internal width; it outputs d_Semantic.
        self.PT = TransformerBlock(d_Semantic, d_Attention, k=4)

        # Scores each point with a single scalar used for weighted pooling.
        self.attn_pool = nn.Sequential(
            nn.Linear(d_Semantic, d_Attention),
            nn.Tanh(),
            nn.Linear(d_Attention, 1),
        )

        self.fc_head = nn.Sequential(
            nn.Linear(d_Semantic, d_Hidden),
            nn.ReLU(),
            nn.Linear(d_Hidden, d_Hidden * 2),
            nn.ReLU(),
            nn.Linear(d_Hidden * 2, d_Hidden),
        )

    def forward(self, x):
        """
        x: [B, N, 4], where each point is (x, y, z, semantic_label)
        Returns the global feature, [B, d_Hidden].
        """
        coords = x[:, :, 0:3]                          # [B, N, 3]
        labels = x[:, :, 3:4]                          # [B, N, 1]

        embedded = self.fc_semantic(labels)            # [B, N, d_Semantic]
        refined, _ = self.PT(coords, embedded)         # [B, N, d_Semantic]

        # Attention-based pooling: softmax over the point axis.
        scores = self.attn_pool(refined)               # [B, N, 1]
        scores = F.softmax(scores, dim=1)              # [B, N, 1]
        pooled = (scores * refined).sum(dim=1)         # [B, d_Semantic]

        return self.fc_head(pooled)                    # [B, d_Hidden]
    
class Fusion_Structure_Encoder(nn.Module):
    """
    MLP-only encoder for a fixed-size labelled point cloud.

    Coordinates and semantic labels are flattened across all points, encoded by
    two separate MLPs, concatenated and fused by a third MLP. Because the input
    is flattened, the number of points must equal ``d_point``.
    """

    @staticmethod
    def _bottleneck_mlp(in_dim, width):
        """Build Linear(in_dim, width) -> ReLU -> Linear(width, 2*width) -> ReLU -> Linear(2*width, width)."""
        return nn.Sequential(
            nn.Linear(in_dim, width),
            nn.ReLU(),
            nn.Linear(width, width * 2),
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )

    def __init__(self, Semantic_Hidden=128, d_point=400, xyz_Hidden=128, Fusion_Hidden=256):
        super().__init__()
        # One label per point -> d_point inputs.
        self.fc_Semantic = self._bottleneck_mlp(d_point, Semantic_Hidden)
        # Three coordinates per point -> d_point * 3 inputs.
        self.fc_xyz = self._bottleneck_mlp(d_point * 3, xyz_Hidden)
        self.fc_Fusion = self._bottleneck_mlp(Semantic_Hidden + xyz_Hidden, Fusion_Hidden)

    def forward(self, x):
        """
        x: [B, N, 4], where each point is (x, y, z, semantic_label)
        Returns the fused feature, [B, Fusion_Hidden].
        """
        xyz_flat = x[:, :, 0:3].flatten(start_dim=1)        # [B, N*3]
        label_flat = x[:, :, 3:4].flatten(start_dim=1)      # [B, N]

        label_feature = self.fc_Semantic(label_flat)        # [B, Semantic_Hidden]
        xyz_feature = self.fc_xyz(xyz_flat)                 # [B, xyz_Hidden]

        joint = torch.cat((label_feature, xyz_feature), dim=1)
        return self.fc_Fusion(joint)
  
class STN3d(nn.Module):
    """
    Spatial transformer network that predicts one 3x3 matrix per point cloud.

    Point-wise convolutions lift each point to 1024 channels, a max over the
    point axis produces a global descriptor, and an MLP regresses 9 values.
    The flattened 3x3 identity is added so that a zero regression output
    corresponds to the identity transform.
    """

    def __init__(self, channel):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        # Not used by forward (F.relu is called instead); kept so the module's
        # attribute set stays unchanged.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """
        Input:
            x: points in channels-first layout, [B, channel, N]
        Return:
            per-cloud transform matrices, [B, 3, 3]
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]   # global max over points
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity on the same device and dtype as the activations.
        # A bare .cuda() would target the current default GPU, which is wrong
        # when the model lives on another device; a float32 identity would also
        # upcast half-precision outputs. The [1, 9] row broadcasts over batch.
        iden = torch.eye(3, dtype=x.dtype, device=x.device).view(1, 9)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x
  
class STNkd(nn.Module):
    """
    Spatial transformer network that predicts one k x k matrix per point cloud.

    Point-wise convolutions lift each point to 1024 channels, a max over the
    point axis produces a global descriptor, and an MLP regresses k*k values.
    The flattened k x k identity is added so that a zero regression output
    corresponds to the identity transform.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        # Not used by forward (F.relu is called instead); kept so the module's
        # attribute set stays unchanged.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """
        Input:
            x: per-point features in channels-first layout, [B, k, N]
        Return:
            per-cloud transform matrices, [B, k, k]
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]   # global max over points
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity on the same device and dtype as the activations.
        # A bare .cuda() would target the current default GPU, which is wrong
        # when the model lives on another device; a float32 identity would also
        # upcast half-precision outputs. The [1, k*k] row broadcasts over batch.
        iden = torch.eye(self.k, dtype=x.dtype, device=x.device).view(1, self.k * self.k)
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x
  
class PointNetEncoder(nn.Module):
    """
    PointNet feature extractor.

    The first three input channels are aligned with a 3x3 transform predicted
    by ``STN3d``; any further channels are passed through unchanged. Optionally
    the 64-channel point features are aligned by a second, 64x64 transform.
    The output is either a global 1024-d descriptor or that descriptor tiled
    and concatenated with the 64-channel per-point features.
    """

    def __init__(self, global_feat=True, feature_transform=False, channel=3):
        super().__init__()
        self.stn = STN3d(channel)
        self.conv1 = nn.Conv1d(channel, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        # The feature-space transform net only exists when requested.
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        """
        Input:
            x: points in channels-first layout, [B, D, N]
        Return:
            features: [B, 1024] if global_feat else [B, 1024 + 64, N]
            trans: input transform, [B, 3, 3]
            trans_feat: feature transform, [B, 64, 64], or None when
                        feature_transform is disabled
        """
        _, num_dims, num_points = x.size()
        trans = self.stn(x)

        # Apply the spatial transform to xyz only; extra channels bypass it.
        pts = x.transpose(2, 1)                               # [B, N, D]
        if num_dims > 3:
            aligned = torch.bmm(pts[:, :, :3], trans)
            pts = torch.cat([aligned, pts[:, :, 3:]], dim=2)
        else:
            pts = torch.bmm(pts, trans)
        x = pts.transpose(2, 1)                               # [B, D, N]

        x = F.relu(self.bn1(self.conv1(x)))                   # [B, 64, N]

        trans_feat = None
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))                           # no ReLU before pooling
        x = torch.max(x, 2, keepdim=True)[0]                  # global max over points
        x = x.view(-1, 1024)

        if self.global_feat:
            return x, trans, trans_feat

        # Tile the global descriptor over points and prepend it to point features.
        tiled = x.view(-1, 1024, 1).repeat(1, 1, num_points)
        return torch.cat([tiled, pointfeat], 1), trans, trans_feat
         
class PointNet(nn.Module):
    """
    PointNet classifier head on top of ``PointNetEncoder``.

    The encoder is built with the global descriptor and the feature transform
    enabled. The head is 1024 -> 512 -> 256 -> k with batch norm and ReLU on
    the first two layers, dropout on the second, and a final log-softmax.
    """

    def __init__(self, k=40, normal_channel=4):
        super().__init__()
        self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=normal_channel)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        # Not used by forward (F.relu is called instead).
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Input:
            x: points in channels-first layout, [B, normal_channel, N]
        Return:
            log-probabilities over the k outputs, [B, k]
            trans_feat: feature transform from the encoder, [B, 64, 64]
        """
        global_feature, _, trans_feat = self.feat(x)

        hidden = self.fc1(global_feature)
        hidden = F.relu(self.bn1(hidden))

        # Dropout is applied to the linear output, before batch norm.
        hidden = self.dropout(self.fc2(hidden))
        hidden = F.relu(self.bn2(hidden))

        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1), trans_feat

class PN_Structure_Encoder(nn.Module):
    """
    Thin wrapper that runs a 4-channel ``PointNet`` on points-first input.

    The returned tensor is the PointNet output, i.e. the log-softmax over
    ``out_dim`` values; the feature transform matrix is discarded.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.encoder = PointNet(out_dim, 4)

    def forward(self, x):
        """
        Input:
            x: points-first tensor, presumably [B, N, 4] given the 4-channel
               PointNet built in __init__
        Return:
            PointNet output, [B, out_dim]
        """
        # PointNet expects channels-first input: [B, N, C] -> [B, C, N].
        channels_first = x.transpose(1, 2)
        output, _ = self.encoder(channels_first)
        return output