import model
from torch import nn
import my_tomotwin

class PromptablePicker(nn.Module):
    """Encoder/decoder network whose decoder is conditioned on a prompt.

    Encoder and decoder are instantiated from config dicts of the form
    ``{"class": <expression naming a class>, "class_args": {...}}``.

    Args:
        encoder_args: Config dict used to build the encoder.
        decoder_args: Config dict used to build the decoder.
        freeze_encoder: If True, put the encoder in eval mode and disable
            gradients for its parameters. A frozen encoder stays in eval mode
            even when ``train()`` is later called on this module.
    """

    def __init__(self, encoder_args, decoder_args, freeze_encoder=False):
        super().__init__()
        self.encoder_args = encoder_args
        self.decoder_args = decoder_args
        # Initialised before freeze_encoder() so train() can always consult it.
        self._encoder_frozen = False
        # SECURITY NOTE(review): eval() executes the "class" string as Python
        # code in this module's namespace; only load configs from trusted sources.
        self.encoder = eval(encoder_args["class"])(**encoder_args["class_args"])
        # A SiameseNet3D wrapper is replaced by the network its get_model() returns.
        if isinstance(self.encoder, my_tomotwin.modules.networks.SiameseNet3D.SiameseNet3D):
            self.encoder = self.encoder.get_model()
        self.decoder = eval(decoder_args["class"])(**decoder_args["class_args"])

        if freeze_encoder:
            self.freeze_encoder()

    def freeze_encoder(self):
        """Put the encoder in eval mode and stop gradients to its parameters.

        The frozen state is remembered so that ``train()`` keeps the encoder
        in eval mode afterwards.
        """
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False
        self._encoder_frozen = True

    def train(self, mode=True):
        """Set training mode on this module, keeping a frozen encoder in eval.

        ``nn.Module.train`` recursively sets the mode of every submodule, which
        would re-enable training-mode behaviour (e.g. dropout, batch-norm
        running-stat updates) in any such layers of a frozen encoder.
        """
        super().train(mode)
        # getattr default covers instances created without __init__ (e.g. old pickles).
        if getattr(self, "_encoder_frozen", False):
            self.encoder.eval()
        return self

    def encode(self, x):
        """Run the encoder on ``x`` and return its raw output.

        ``forward`` unpacks this output as a pair ``(_, feats)``.
        """
        # Call the module (not .forward) so registered hooks run.
        return self.encoder(x)

    def decode(self, feats, prompt):
        """Decode encoder features conditioned on ``prompt``.

        Args:
            feats: Sliceable sequence of encoder features. It is reversed, so
                its last element is passed to the decoder as ``x`` and the
                remaining elements, in reverse order, as ``cats``. Presumably
                ordered shallow-to-deep — TODO confirm against the encoder.
            prompt: Conditioning input, forwarded unchanged to the decoder.

        Returns:
            Whatever the decoder returns.
        """
        feats = feats[::-1]  # reverse order
        return self.decoder(
            x=feats[0],
            cats=feats[1:],
            prompt=prompt,
        )

    def forward(self, x, prompt):
        """Encode ``x`` and decode the resulting features with ``prompt``."""
        # The encoder output is a pair; only its second element (the feature
        # sequence) is used for decoding.
        _, feats = self.encode(x)
        return self.decode(feats, prompt)
