import torch
import torch.nn as nn 
import math 
from lift3d.models.concept.MetaWorld.utils import rotate
from lift3d.models.concept.MetaWorld.utils import Para_Estimator, Knowledge_Encoder, PN_Knowledge_Encoder, PT_Knowledge_Encoder, Fusion_Knowledge_Encoder

def Concept_Reach(paras):
    """
    Build the "reach" affordance tensor from estimated target parameters.

    Each batch entry yields a single target vertex, to which a constant
    semantic channel of value 1 is appended along the last dimension.

    paras: [B, 3]
    return: [B, 1, 4] -- last dim is the entries of ``paras`` followed by 1,
            with the same dtype and device as ``paras``.
    """
    B = paras.shape[0]

    # Target: one vertex per batch entry.
    vertices = paras.unsqueeze(dim=1)
    # Constant semantic channel, created directly in its final [B, 1, 1] shape.
    semantic = torch.ones([B, 1, 1], device=paras.device, dtype=paras.dtype)

    return torch.cat([vertices, semantic], dim=-1)

class Concept_Module_Reach(nn.Module):
    """
    Concept module for the "reach" task.

    Estimates geometric parameters from an input feature, turns them into
    the reach knowledge tensor via ``Concept_Reach`` and encodes that tensor
    into a concept embedding.

    in_dim:   input dimension handed to ``Para_Estimator``.
    paralist: indexable; entries 0..4 are read, in this order:
              [estimator hidden dim, geometric parameter dim,
               knowledge feature dim, knowledge attention dim,
               concept output dim].
    """

    def __init__(self, in_dim, paralist):
        super().__init__()
        # Only the first five entries of paralist are consumed.
        hidden_dim, geo_dim, feat_dim, attn_dim, out_dim = (
            paralist[idx] for idx in range(5)
        )
        self.para_estimator = Para_Estimator(in_dim, hidden_dim, geo_dim)
        self.knowledge_encoder = Knowledge_Encoder(feat_dim, attn_dim, out_dim)

    def forward(self, feature, robot_state):
        """
        Return ``(concept_emb, robot_state)``.

        ``robot_state`` is not used in the computation; it is handed back
        unchanged alongside the embedding.
        """
        knowledge = Concept_Reach(self.para_estimator(feature))
        return self.knowledge_encoder(knowledge), robot_state
    
