import torch

from rdkit.Chem import Mol as RDMol
from torch import Tensor
from torch_geometric.data import Batch

from gflownet.base import SynthesisTrainer, SynthesisGFNSampler, BaseTask
from gflownet.trainer import FlatRewards
from gflownet.models import bengio2021flow


class SehTask(BaseTask):
    """Task whose flat reward is the (rescaled) output of the ``bengio2021flow`` model.

    The model comes from ``bengio2021flow.load_original_model()`` and is stored under the
    ``"seh"`` key of the dict returned by ``_load_task_models``.
    """

    def _load_task_models(self):
        """Load the reward model, move it to ``cfg.device`` and wrap it.

        Returns:
            A dict mapping ``"seh"`` to whatever ``self._wrap_model`` returns for the model.
            ``compute_flat_rewards`` indexes that value with ``[0]``. It is therefore presumably
            a ``(model, ...)`` tuple. NOTE(review): confirm against ``_wrap_model``.
        """
        model = bengio2021flow.load_original_model()
        model.to(self.cfg.device)
        model = self._wrap_model(model)
        return {"seh": model}

    def compute_flat_rewards(self, mols: list[RDMol], batch_idx: list[int]) -> tuple[FlatRewards, Tensor]:
        """Score molecules with the reward model.

        Args:
            mols: Molecules to score.
            batch_idx: Not used by this implementation. It is presumably kept to match the
                ``BaseTask`` interface.

        Returns:
            A ``(flat_rewards, is_valid)`` pair.

            ``is_valid`` is a bool tensor with one entry per input molecule. An entry is
            False where ``bengio2021flow.mol2graph`` returned ``None``.

            ``flat_rewards`` is computed only over the valid molecules. Each value is the
            model output divided by 8, with NaNs replaced by 0, then clipped to
            ``[1e-4, 100]``. If no molecule is valid, an empty ``(0, 1)`` tensor is returned
            instead of calling the model.
        """
        graphs = [bengio2021flow.mol2graph(m) for m in mols]
        is_valid = torch.tensor([g is not None for g in graphs], dtype=torch.bool)
        valid_graphs = [g for g in graphs if g is not None]
        if not valid_graphs:
            # Batch.from_data_list cannot collate an empty list. Short-circuit when `mols` is
            # empty or no molecule could be featurized. The (0, 1) shape assumes the model
            # emits one column per molecule. NOTE(review): confirm against the model output.
            return FlatRewards(torch.zeros((0, 1))), is_valid

        batch = Batch.from_data_list(valid_graphs)
        batch.to(self.cfg.device)
        # `.data` detaches from autograd before moving the predictions back to the CPU.
        preds = self.models["seh"][0](batch).data.cpu() / 8

        # Replace NaN predictions and keep rewards strictly positive and bounded.
        preds[preds.isnan()] = 0
        preds = preds.clip(1e-4, 100)

        return FlatRewards(preds), is_valid


class SehSynthesisTrainer(SynthesisTrainer):
    """Synthesis trainer whose task is :class:`SehTask`."""

    def setup_task(self):
        """Create the ``SehTask`` and store it on ``self.task``.

        The task receives this trainer's config, RNG and ``_wrap_for_mp`` model wrapper.
        """
        task_kwargs = {
            "cfg": self.cfg,
            "rng": self.rng,
            "wrap_model": self._wrap_for_mp,
        }
        self.task = SehTask(**task_kwargs)


class SehSynthesisSampler(SynthesisGFNSampler):
    """Synthesis GFlowNet sampler whose task is :class:`SehTask`."""

    def setup_task(self):
        """Create the ``SehTask`` and store it on ``self.task``.

        The task receives this sampler's config, RNG and ``_wrap_for_mp`` model wrapper.
        """
        wrapper = self._wrap_for_mp
        self.task = SehTask(wrap_model=wrapper, cfg=self.cfg, rng=self.rng)
