import torch
import torch.nn as nn
import torch.nn.functional as F 

from mdt.StructPolicy.utils import ParaEstimator, StructEncoder, normalize_pc_list
from mdt.StructPolicy.StructMap import StructureMap

class StructModule(nn.Module):
    """Map a visual feature tensor to structure tokens and a target pose.

    The module chains three project components, as wired in ``forward``:
    a ``ParaEstimator`` (built with ``out_dim=69``) applied to the flattened
    input, ``StructureMap`` applied to the estimated parameters, and a
    ``StructEncoder`` applied to the normalized structure map.

    Args:
        d_VisualFeature: Passed to the estimator as ``in_dim``; must equal
            the per-sample size of the input once flattened.
        d_ParaEstimatorHidden: Passed to the estimator as ``hidden_dim``.
        d_Features: Indexable; items 0 and 1 are forwarded to the encoder as
            ``d_Feature1`` and ``d_Feature2``.
        d_Attentions: Indexable; items 0 and 1 are forwarded to the encoder
            as ``d_Attention1`` and ``d_Attention2``.
        d_token: Forwarded unchanged to the encoder.
        d_Hidden: Forwarded unchanged to the encoder.
    """

    def __init__(self, d_VisualFeature, d_ParaEstimatorHidden, d_Features, d_Attentions, d_token, d_Hidden):
        super().__init__()

        # The estimator is registered first so parameter/state-dict ordering
        # stays the same as before.
        self.para_estimator = ParaEstimator(
            in_dim=d_VisualFeature,
            hidden_dim=d_ParaEstimatorHidden,
            out_dim=69,
        )

        encoder_kwargs = dict(
            d_Feature1=d_Features[0],
            d_Attention1=d_Attentions[0],
            d_token=d_token,
            d_Feature2=d_Features[1],
            d_Attention2=d_Attentions[1],
            d_Hidden=d_Hidden,
        )
        self.Encoder = StructEncoder(**encoder_kwargs)

    def forward(self, x):
        """Run estimator -> structure map -> encoder on a batch.

        Args:
            x: Input whose first axis is treated as the batch axis; every
                remaining axis is flattened before the estimator is applied.

        Returns:
            A ``(StructTokens, TargetPose)`` pair: the encoder output and the
            second value returned by ``StructureMap``.
            NOTE(review): exact shapes depend on the project helpers - confirm
            against ``StructureMap`` / ``StructEncoder``.
        """
        flattened = x.reshape(x.shape[0], -1)
        paras = self.para_estimator(flattened)
        # Sanity check: the estimator output must be rank 2.
        assert paras.ndim == 2, f"parameters shape: {paras.shape}"

        struct_map, target_pose = StructureMap(paras)
        normalized_map = normalize_pc_list(struct_map)
        return self.Encoder(normalized_map), target_pose
    
if __name__ == "__main__":
    # Smoke test: push a constant dummy batch through StructModule on the CPU
    # and print the shape of the resulting structure tokens.
    #
    # ParaEstimator, StructEncoder, normalize_pc_list and StructureMap are
    # already bound by the module-level package imports, which must have
    # succeeded for execution to reach this point. They are therefore not
    # re-imported here from the bare `utils` / `StructMap` names: that would
    # load the same modules a second time under different names and could
    # pick up an unrelated `utils` found earlier on sys.path.
    dummy_input = torch.ones([128, 3, 384], device='cpu')

    d_ParaEstimatorHidden = 128
    d_Features = [64, 128]
    d_Attentions = [128, 256]
    d_token = 128
    d_Hidden = 256

    model = StructModule(
        # The estimator sees each sample flattened, so derive its input width
        # from the dummy batch (3 * 384) to keep the two in sync.
        d_VisualFeature=dummy_input[0].numel(),
        d_ParaEstimatorHidden=d_ParaEstimatorHidden,
        d_Features=d_Features,
        d_Attentions=d_Attentions,
        d_token=d_token,
        d_Hidden=d_Hidden,
    )

    # Keep the input and the output under separate names; the target pose is
    # not needed for this check.
    struct_tokens, _ = model(dummy_input)
    print(struct_tokens.shape)