import torch
import torch.nn as nn 
import math 
from StructPolicy.StructTransformer.Encoders import Para_Estimator, Structure_Encoder

def Structure_Reach(paras):
    """Build the single-vertex structure tensor used by the Reach module.

    Each sample's parameter vector becomes one vertex. A constant semantic
    flag of 1 is appended to it as an extra channel.

    Args:
        paras: Tensor of shape [B, D]. Callers pass D == 3, per the original
            docstring. Extra leading batch dimensions are also accepted,
            i.e. shape [..., D].

    Returns:
        Tensor of shape [B, 1, D + 1] (generally [..., 1, D + 1]) with the
        same dtype and device as ``paras``. The first D channels of the last
        axis are ``paras`` and the final channel is 1.
    """
    # One vertex per sample: [..., D] -> [..., 1, D]. Indexing from the end
    # keeps 2-D behavior identical while also supporting extra leading dims.
    # NOTE(review): presumably the vertex is the reach target; confirm against
    # Para_Estimator's output.
    vertices = paras.unsqueeze(dim=-2)
    # Constant semantic label, allocated directly in its final [..., 1, 1]
    # shape. new_ones inherits dtype and device from `paras`.
    semantic = paras.new_ones(tuple(paras.shape[:-1]) + (1, 1))
    return torch.cat([vertices, semantic], dim=-1)

class Structure_Module_Reach(nn.Module):
    """Map input features to a structure embedding for the Reach structure.

    Pipeline: ``feature`` -> ``Para_Estimator`` -> ``Structure_Reach`` ->
    ``Structure_Encoder``.

    Args:
        in_dim: Input feature size, forwarded to ``Para_Estimator``.
        paralist: Indexable holding at least five hyper-parameters. They are
            read in this order: estimator hidden dim, geometric-parameter dim,
            structure feature dim, structure attention dim, structure output
            dim.
    """

    def __init__(self, in_dim, paralist):
        # Read every hyper-parameter up front, before any sub-module is built.
        hidden_dim = paralist[0]
        geo_dim = paralist[1]
        feat_dim = paralist[2]
        attn_dim = paralist[3]
        out_dim = paralist[4]
        super().__init__()
        # Attribute names are kept as-is because they are state_dict keys.
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        # NOTE(review): Structure_Reach appends one channel to the estimated
        # parameters, so the encoder presumably expects feat_dim == geo_dim + 1.
        # Confirm this.
        self.Structure_encoder = Structure_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """Encode ``feature`` and pass ``robot_state`` through unchanged.

        Returns:
            Tuple ``(structure_embedding, robot_state)``.
        """
        structure = Structure_Reach(self.para_estimator(feature))
        return self.Structure_encoder(structure), robot_state
    
