from typing import Any, Callable, Dict, List, Optional

import pytorch_lightning as pl
import torch.nn as nn

from open_biomed.data import Molecule, Pocket
from open_biomed.utils.featurizer import Featurized

class DockingModelWrapper(pl.LightningModule):
    """Lightning wrapper that delegates docking calls to an inner model.

    Every method forwards its arguments unchanged to ``self.model``; this
    class adds no computation of its own.

    Args:
        model: The wrapped network. It must be callable as
            ``model(molecule, pocket)`` and expose ``sample_fn`` and
            ``get_featurizer`` methods, because those are what this wrapper
            invokes.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Assigning an nn.Module as an attribute registers it as a submodule
        # (standard nn.Module behavior), so its parameters belong to this wrapper.
        self.model = model

    def forward(self, molecule: Featurized[Molecule], pocket: Featurized[Pocket]) -> Any:
        """Run the wrapped model on a featurized ligand / pocket pair.

        Returns:
            Whatever ``self.model(molecule, pocket)`` returns.
        """
        return self.model(molecule, pocket)

    def sample(
        self,
        molecule: Featurized[Molecule],
        pocket: Featurized[Pocket],
        num_samples: int,
        estimated_ligand_num: Optional[List[int]] = None,
    ) -> Any:
        """Draw samples via the wrapped model's ``sample_fn``.

        All arguments are forwarded positionally and unchanged, in the order
        ``(molecule, pocket, num_samples, estimated_ligand_num)``.

        Args:
            molecule: Featurized ligand input.
            pocket: Featurized pocket input.
            num_samples: Number of samples requested from ``sample_fn``.
            estimated_ligand_num: Optional list of ints whose meaning is
                defined by the wrapped model's ``sample_fn``; ``None`` (the
                default) is forwarded as-is.

        Returns:
            Whatever ``self.model.sample_fn`` returns.
        """
        return self.model.sample_fn(molecule, pocket, num_samples, estimated_ligand_num)

    def get_featurizer(self) -> Any:
        """Return the featurizer provided by the wrapped model's ``get_featurizer``."""
        return self.model.get_featurizer()

class SBDDModelWrapper(pl.LightningModule):
    """Lightning wrapper that delegates structure-based design calls to an inner model.

    Every method forwards its arguments unchanged to ``self.model``; this
    class adds no computation of its own.

    Args:
        model: The wrapped network. It must be callable as ``model(pocket)``
            and expose ``sample_fn`` and ``get_featurizer`` methods, because
            those are what this wrapper invokes.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Assigning an nn.Module as an attribute registers it as a submodule
        # (standard nn.Module behavior), so its parameters belong to this wrapper.
        self.model = model

    def forward(self, pocket: Featurized[Pocket]) -> Any:
        """Run the wrapped model on a featurized pocket.

        Returns:
            Whatever ``self.model(pocket)`` returns.
        """
        return self.model(pocket)

    def sample(
        self,
        pocket: Featurized[Pocket],
        num_samples: int,
        reward_fn: Callable,
        estimated_ligand_num: Optional[List[int]] = None,
    ) -> Any:
        """Draw samples via the wrapped model's ``sample_fn``.

        All arguments are forwarded positionally and unchanged, in the order
        ``(pocket, num_samples, reward_fn, estimated_ligand_num)``.

        Args:
            pocket: Featurized pocket input.
            num_samples: Number of samples requested from ``sample_fn``.
            reward_fn: Callable handed to ``sample_fn``; how it is invoked is
                up to the wrapped model.
            estimated_ligand_num: Optional list of ints whose meaning is
                defined by the wrapped model's ``sample_fn``; ``None`` (the
                default) is forwarded as-is.

        Returns:
            Whatever ``self.model.sample_fn`` returns.
        """
        return self.model.sample_fn(pocket, num_samples, reward_fn, estimated_ligand_num)

    def get_featurizer(self) -> Any:
        """Return the featurizer provided by the wrapped model's ``get_featurizer``."""
        return self.model.get_featurizer()