
from typing import Dict, Optional, Union, List, Any

import clip
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

class TextEncoder(nn.Module):
    """Embed text prompts with a frozen CLIP model followed by a linear layer.

    The CLIP ``ViT-B/32`` weights are loaded on CPU with gradients disabled;
    ``ffn`` maps the 512-wide CLIP text embedding to ``condition_dim``.
    """

    def __init__(self, condition_dim: int = 512, **kwargs):
        """Load CLIP and build the projection layer.

        Args:
            condition_dim: Output width of the ``ffn`` projection.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # first load into cpu, then move to cuda by the lightning
        self.clip_model, _ = clip.load("ViT-B/32", device='cpu')
        # Freeze CLIP: none of its parameters take part in optimisation.
        for param in self.clip_model.parameters():
            param.requires_grad_(False)
        self.ffn = nn.Linear(512, condition_dim)

    @torch.no_grad()
    def forward(
        self,
        texts: Union[str, List[str]],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ):
        """Tokenize ``texts``, encode them with CLIP and project with ``ffn``.

        Args:
            texts: A prompt or a list of prompts. Everything up to and
                including the first ``':'`` of each prompt is dropped.
            device: Device the token tensor is moved to. When ``None`` the
                device of the CLIP parameters is used, so the call works
                wherever the module currently lives.
            **kwargs: Accepted and ignored.

        Returns:
            One embedding per prompt; projected by ``ffn`` when that
            attribute exists, otherwise the raw CLIP text embedding.
        """
        # NOTE(review): the whole method runs under ``torch.no_grad()``, so
        # ``ffn`` receives no gradient through this call -- confirm intended.
        self.clip_model.eval()

        # A bare string would otherwise be iterated character by character.
        if isinstance(texts, str):
            texts = [texts]

        # Follow the weights instead of assuming a CUDA device is present.
        if device is None:
            device = next(self.clip_model.parameters()).device

        # remove 'pickup:' and 'mine_block:' prefix if exists.
        stripped = [text.split(':', 1)[-1] for text in texts]
        tokens = clip.tokenize(stripped).to(device)
        embeddings = self.clip_model.encode_text(tokens)

        if hasattr(self, 'ffn'):
            # Cast to the projection's dtype before the matmul, since the
            # CLIP output dtype is not guaranteed to match it.
            embeddings = self.ffn(embeddings.to(self.ffn.weight.dtype))
        return embeddings

def layout_to_idx(layout: str) -> List[int]:
    """Convert a layout string into a list of integer cell codes.

    Each character becomes one code: ``'o'`` -> 1, ``'#'`` -> 2, anything
    else -> 0. Results shorter than 9 entries are right-padded with zeros;
    longer inputs are returned at their full length (no truncation).

    Args:
        layout: Characters describing the layout, one per cell.

    Returns:
        The list of codes, containing at least 9 entries.
    """
    indices: List[int] = []
    for cell in layout:
        if cell == 'o':
            indices.append(1)
        elif cell == '#':
            indices.append(2)
        else:
            indices.append(0)

    # Pad up to 9 entries; nothing is added when already long enough.
    missing = 9 - len(indices)
    if missing > 0:
        indices.extend([0] * missing)
    return indices