from timm.layers import Mlp
import torch
import torch.nn as nn
import DecisionNCE
import random
import pickle


class DnceLatentProj(nn.Module):
    """Frozen DecisionNCE ("DecisionNCE-T") encoder exposing image/text projections.

    Image features are standardized with a mean/std loaded from
    ``latent_info_file`` (or a zero mean / unit std of length 1024 when no
    file is given). Text features are returned exactly as the encoder
    produces them. The encoder's parameters are frozen. If the encoder is an
    ``nn.Module``, it is kept in eval mode even when this module is switched
    to training mode.
    """

    def __init__(
        self,
        latent_info_file='assets/libero.pkl',
        *,
        device="cuda",
    ):
        """
        Args:
            latent_info_file: Path to a pickle containing a mapping with
                ``'mean'`` and ``'std'`` entries used to normalize image
                features, or ``None`` to use zeros(1024) / ones(1024).
                Relative paths resolve against the current working directory.
            device: Device passed to ``DecisionNCE.load``. The normalization
                buffers are created on the same device so ``img_proj`` works
                without a separate ``.to()`` call. Defaults to ``"cuda"``
                (the previously hard-coded value).
        """
        super().__init__()
        self.latent_proj = DecisionNCE.load("DecisionNCE-T", device=device)
        # Used purely as a frozen feature extractor: no gradients to its params.
        self.latent_proj.requires_grad_(False)
        self._freeze_encoder_mode()

        if latent_info_file is not None:
            # SECURITY: pickle.load can execute arbitrary code while loading;
            # only point this at trusted, locally produced stats files.
            with open(latent_info_file, "rb") as f:
                data = pickle.load(f)
            # Pin float32 so float64 stats (e.g. a float64 numpy array) do not
            # silently promote the normalized image features to float64.
            img_mean = torch.as_tensor(data['mean'], dtype=torch.float32, device=device)
            img_std = torch.as_tensor(data['std'], dtype=torch.float32, device=device)
        else:
            # Identity normalization; the length 1024 matches the stats vectors
            # this module expects (TODO confirm it equals the encoder's feature dim).
            img_mean = torch.zeros(1024, device=device)
            img_std = torch.ones(1024, device=device)

        # Buffers: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer('img_mean', img_mean)
        self.register_buffer('img_std', img_std)

    def _freeze_encoder_mode(self):
        """Put the frozen encoder into eval mode when it is an ``nn.Module``."""
        encoder = getattr(self, "latent_proj", None)
        if isinstance(encoder, nn.Module):
            encoder.eval()

    def train(self, mode=True):
        """Set train/eval mode on this module while keeping the encoder in eval.

        ``nn.Module.train`` recurses into child modules, which would otherwise
        switch the frozen encoder into training mode and change the behavior
        of any mode-dependent layers it may contain (dropout, batch norm, ...).

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self._freeze_encoder_mode()
        return self

    def img_proj(self, x):
        """Encode images with the frozen encoder and standardize the features.

        Args:
            x: Image input in whatever format ``latent_proj.model.encode_image``
                expects (TODO confirm preprocessing/shape against callers).

        Returns:
            ``(encode_image(x) - img_mean) / img_std``, with the stats
            broadcast against the encoder output.
        """
        # NOTE(review): this calls the inner ``.model.encode_image`` while
        # lang_proj calls the wrapper's ``encode_text``; presumably deliberate
        # (e.g. bypassing wrapper-level preprocessing) — confirm.
        x = self.latent_proj.model.encode_image(x)
        # NOTE(review): no epsilon guard — a zero entry in img_std yields inf/nan.
        x = (x - self.img_mean) / self.img_std
        return x

    def lang_proj(self, x):
        """Encode text with the DecisionNCE wrapper's ``encode_text``.

        Args:
            x: Text input accepted by ``latent_proj.encode_text`` (presumably
                raw strings or tokens — verify against callers).

        Returns:
            The encoder's text features, without any normalization.
        """
        x = self.latent_proj.encode_text(x)
        return x


