
import numpy as np
import torch
from torch import nn
from typing import Dict, Optional, Union, List, Any

from mcu.arm.models.encoders.text_encoder import TextEncoder

class SubGoalEncoder(nn.Module):
    """Encode a text prompt plus a structured sub-goal into a condition tensor.

    The text embedding is broadcast along the second axis, concatenated with
    learned embeddings of the sub-goal's ``recipe``, ``table`` and ``layout``
    fields, and projected to ``condition_dim`` by a linear layer.
    """

    # Sizes shared by all sub-goal embedding tables.
    _NUM_STATES = 3    # distinct ids accepted by each embedding table
    _FIELD_DIM = 128   # width of each individual sub-goal embedding
    _NUM_SLOTS = 9     # number of layout slots, one embedding table per slot

    def __init__(self, condition_dim: int = 512, **kwargs):
        """Build the text encoder, the sub-goal embeddings and the projection.

        Args:
            condition_dim: Width of the tensor returned by ``forward``. Half
                of it is requested from the text encoder.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        text_dim = condition_dim // 2
        # NOTE(review): final_layer assumes TextEncoder emits vectors of width
        # ``text_dim`` - confirm against TextEncoder.
        self.text_embedding = TextEncoder(condition_dim=text_dim)
        self.recipe_embedding = nn.Embedding(self._NUM_STATES, self._FIELD_DIM)
        self.table_embedding = nn.Embedding(self._NUM_STATES, self._FIELD_DIM)
        # 0 for empty, 1 for place, 2 for mask
        self.resource_embedding = nn.ModuleList([
            nn.Embedding(self._NUM_STATES, self._FIELD_DIM)
            for _ in range(self._NUM_SLOTS)
        ])
        # Input width: text + recipe + table + one embedding per layout slot.
        fused_dim = text_dim + (2 + self._NUM_SLOTS) * self._FIELD_DIM
        self.final_layer = nn.Linear(fused_dim, condition_dim)
        # Not used inside this class; kept because external code may read it.
        self.time_step = 0

    def forward(
        self,
        texts,
        subgoal: Optional[Dict[str, torch.Tensor]] = None,
        _T: int = 0,
        device="cuda",
        **kwargs,
    ) -> torch.Tensor:
        """Return the fused condition tensor for ``texts`` and ``subgoal``.

        Args:
            texts: Input forwarded unchanged to the text encoder.
            subgoal: Mapping with integer id tensors under the keys
                ``'recipe'``, ``'table'`` and ``'layout'``. ``'recipe'`` and
                ``'table'`` are embedded whole; ``'layout'`` is indexed along
                its third axis, once per slot. Required despite the ``None``
                default.
            _T: Fallback length of the second axis, used only when it cannot
                be read from ``subgoal['table'].shape``.
            device: Forwarded to the text encoder.
            **kwargs: Accepted and ignored.

        Returns:
            The output of the final linear layer applied to the concatenated
            embeddings (last axis has width ``condition_dim``).

        Raises:
            TypeError: If ``subgoal`` is ``None``.
        """
        if subgoal is None:
            # Previously this surfaced later as an opaque "'NoneType' object
            # is not subscriptable"; fail early with the same exception type.
            raise TypeError(
                "SubGoalEncoder.forward requires a subgoal mapping with "
                "'recipe', 'table' and 'layout' tensors, got None"
            )

        # Prefer the sequence length carried by the sub-goal itself; keep the
        # caller-supplied ``_T`` only when the table has no usable 2-D shape.
        try:
            _T = subgoal['table'].shape[1]
        except (AttributeError, IndexError, KeyError, TypeError):
            pass

        obs_text_embeddings = self.text_embedding(texts, device=device)
        # Repeat the per-sample text vector along the new second axis so it
        # lines up with the per-step sub-goal embeddings.
        obs_text_embeddings = obs_text_embeddings.unsqueeze(1).expand(-1, _T, -1)

        recipe_embedding = self.recipe_embedding(subgoal['recipe'])
        table_embedding = self.table_embedding(subgoal['table'])

        resource_tensor = subgoal['layout']
        # Each layout slot has its own embedding table; concatenate them on
        # the feature axis.
        resource_embedding = torch.cat([
            slot_embedding(resource_tensor[:, :, slot])
            for slot, slot_embedding in enumerate(self.resource_embedding)
        ], dim=2)

        fused = torch.cat([
            obs_text_embeddings, recipe_embedding, table_embedding, resource_embedding
        ], dim=-1)
        return self.final_layer(fused)
