import torch

from rdkit.Chem import Mol as RDMol
from torch import Tensor
from torch_geometric.data import Batch
from tdc import Oracle
import rdkit.Chem as Chem

from gflownet.base import SynthesisTrainer, SynthesisGFNSampler, BaseTask
from gflownet.trainer import FlatRewards
from gflownet.models import bengio2021flow


class GskTask(BaseTask):
    """Reward task that scores molecules with the TDC ``GSK3B`` oracle."""

    def _load_task_models(self):
        """Return the scoring models used by this task, keyed by name."""
        return {'gsk': Oracle("GSK3B")}

    def compute_flat_rewards(self, mols: list[RDMol], batch_idx: list[int]) -> tuple[FlatRewards, Tensor]:
        """Score ``mols`` with the GSK3B oracle.

        A molecule is considered valid when ``bengio2021flow.mol2graph``
        returns a graph for it (i.e. not ``None``). Only valid molecules are
        converted to SMILES and sent to the oracle.

        Args:
            mols: Molecules to score.
            batch_idx: Unused by this task; kept for interface compatibility.

        Returns:
            A ``(FlatRewards, is_valid)`` pair. ``is_valid`` is a boolean
            tensor with one entry per input molecule. The rewards tensor has
            shape ``(num_valid, 1)`` with values clipped to ``[1e-4, 100]``.

        NOTE(review): returning one reward row per *valid* molecule follows
        the upstream gflownet task convention -- confirm against the caller.
        """
        is_valid = torch.tensor(
            [bengio2021flow.mol2graph(m) is not None for m in mols],
            dtype=torch.bool,
        )
        valid_mols = [m for m, ok in zip(mols, is_valid.tolist()) if ok]

        # Nothing to score: return an empty reward tensor instead of calling
        # the oracle (or building a batch) on an empty list.
        if not valid_mols:
            return FlatRewards(torch.zeros((0, 1))), is_valid

        smiles = [Chem.MolToSmiles(m) for m in valid_mols]
        gsk_rewards = self.models['gsk'](smiles)

        # Clip away exact zeros so downstream log-reward transforms stay finite.
        preds = torch.tensor(gsk_rewards).float().clip(1e-4, 100).reshape(-1, 1)

        return FlatRewards(preds), is_valid


class GskSynthesisTrainer(SynthesisTrainer):
    """Synthesis trainer whose task is :class:`GskTask`."""

    def setup_task(self):
        """Create the ``GskTask`` and store it on ``self.task``."""
        task_kwargs = {
            "cfg": self.cfg,
            "rng": self.rng,
            "wrap_model": self._wrap_for_mp,
        }
        self.task = GskTask(**task_kwargs)


class GskSynthesisSampler(SynthesisGFNSampler):
    """Synthesis sampler whose task is :class:`GskTask`."""

    def setup_task(self):
        """Create the ``GskTask`` and store it on ``self.task``."""
        task_kwargs = {
            "cfg": self.cfg,
            "rng": self.rng,
            "wrap_model": self._wrap_for_mp,
        }
        self.task = GskTask(**task_kwargs)
