from typing import List, Optional
import numpy as np
from open_biomed.data import Molecule, Protein
from open_biomed.utils.config import Config
from open_biomed.tasks.aidd_tasks.structure_based_drug_design import calc_vina_molecule_metrics

class SuccessReward:
    """Reward equal to the ``"success"`` metric of a molecule for a fixed protein.

    The metric dictionary is produced by ``calc_vina_molecule_metrics``; this
    class only forwards its ``"success"`` entry.
    """

    def __init__(self, protein: Protein) -> None:
        """Store the target protein that every molecule is evaluated against."""
        self.protein = protein

    def __call__(self, molecule: Optional[Molecule]) -> float:
        """Score a single molecule.

        Args:
            molecule: The molecule to evaluate, or ``None`` when no molecule
                is available.

        Returns:
            ``-10`` if ``molecule`` is ``None``; otherwise the ``"success"``
            entry of the metrics computed for ``molecule`` and the stored
            protein.
        """
        if molecule is not None:
            return calc_vina_molecule_metrics(molecule, self.protein)["success"]
        # No molecule to evaluate: hand back the fixed penalty value.
        return -10

class WeightedSuccessReward:
    """Weighted sum of metric differences between a molecule and a reference.

    Three metrics are compared, in this order: the binding score
    (``"vina_min"`` when ``cfg.use_vina_min`` is truthy, ``"vina_dock"``
    otherwise), ``"qed"`` and ``"sa"``. The i-th difference
    ``metric - reference_metric`` is multiplied by ``cfg.weights[i]``.
    """

    def __init__(self, protein: Protein, ref_molecule: Molecule, cfg: Config) -> None:
        """Store the configuration and pre-compute the reference metrics.

        Args:
            protein: Target protein used for every metric computation.
            ref_molecule: Reference molecule whose metrics are the baseline.
            cfg: Configuration providing ``weights`` (indexed 0..2) and
                ``use_vina_min``.
        """
        self.protein = protein
        self.weights = cfg.weights
        self.use_vina_min = cfg.use_vina_min
        # The reference is evaluated once, with the same docking setting that
        # is later applied to every scored molecule.
        self.ref_metrics = calc_vina_molecule_metrics(
            ref_molecule,
            protein,
            calculate_vina_dock=not self.use_vina_min,
        )

    def __call__(self, molecule: Optional[Molecule]) -> float:
        """Score a molecule relative to the reference molecule.

        NOTE: since vina_dock is slow, vina_min can be used instead to speed
        things up (controlled by ``use_vina_min``).

        Args:
            molecule: The molecule to evaluate, or ``None`` when no molecule
                is available.

        Returns:
            ``-10`` if ``molecule`` is ``None``; otherwise the weighted sum of
            the three metric differences.
        """
        if molecule is None:
            return -10

        skip_dock = self.use_vina_min
        scores = calc_vina_molecule_metrics(
            molecule,
            self.protein,
            calculate_vina_dock=not skip_dock,
        )
        binding_key = "vina_min" if skip_dock else "vina_dock"

        total = 0
        for weight_idx, name in enumerate([binding_key, "qed", "sa"]):
            delta = scores[name] - self.ref_metrics[name]
            total += delta * self.weights[weight_idx]
        return total
    
class WeightedSuccessSmoothedReward:
    """Weighted sum of sigmoid-squashed metrics.

    Three metrics are used, in this order: the binding score (``"vina_min"``
    when ``cfg.use_vina_min`` is truthy, ``"vina_dock"`` otherwise), ``"qed"``
    and ``"sa"``. For the i-th metric value ``x`` the contribution is::

        (2 / (1 + exp(-norm[i] * (x - offset[i]))) - 1) * weights[i]

    The bracketed factor is a logistic curve rescaled to lie between -1 and 1,
    centred at ``offset[i]`` with steepness ``norm[i]``.
    """

    def __init__(self, protein: Protein, ref_molecule: Molecule, cfg: Config) -> None:
        """Store the protein and the per-metric settings from ``cfg``.

        Args:
            protein: Target protein used for every metric computation.
            ref_molecule: Accepted but not used by this class.
            cfg: Configuration providing ``weights``, ``offset``, ``norm``
                (each indexed 0..2) and ``use_vina_min``.
        """
        self.protein = protein
        self.weights = cfg.weights
        self.offset = cfg.offset
        self.use_vina_min = cfg.use_vina_min
        self.norm = cfg.norm

    def __call__(self, molecule: Optional[Molecule]) -> float:
        """Score a molecule with the smoothed, weighted metric sum.

        Args:
            molecule: The molecule to evaluate, or ``None`` when no molecule
                is available.

        Returns:
            ``-10`` if ``molecule`` is ``None``; otherwise the sum of the three
            squashed and weighted metric terms.
        """
        if molecule is None:
            return -10

        skip_dock = self.use_vina_min
        scores = calc_vina_molecule_metrics(
            molecule,
            self.protein,
            calculate_vina_dock=not skip_dock,
        )
        binding_key = "vina_min" if skip_dock else "vina_dock"

        total = 0
        for idx, name in enumerate([binding_key, "qed", "sa"]):
            neg_slope = -self.norm[idx]
            shifted = scores[name] - self.offset[idx]
            # Logistic curve rescaled from (0, 1) to (-1, 1).
            squashed = 2 / (1 + np.exp(neg_slope * shifted)) - 1
            total += squashed * self.weights[idx]
        return total
    
class WeightedSuccessRewardWithPBValidity:
    """Placeholder for a weighted reward that additionally applies a PB check.

    Only the constructor does real work; calling an instance is not
    implemented yet and currently yields ``None``.
    """

    def __init__(self, protein: Protein, ref_molecule: Molecule, weights: List[float]) -> None:
        """Pre-compute the reference metrics and keep the metric weights.

        Args:
            protein: Target protein used to evaluate the reference molecule.
                It is not stored on the instance.
            ref_molecule: Reference molecule whose metrics are the baseline.
            weights: Per-metric weights, stored unchanged.
        """
        self.ref_metrics = calc_vina_molecule_metrics(ref_molecule, protein)
        self.weights = weights

    def __call__(self, molecule: Optional[Molecule]) -> float:
        """Not implemented yet: always returns ``None``, whatever the input."""
        # TODO: implement PB check
        return None

class WeightedSuccessRewardWithADMET:
    """Placeholder for a weighted reward that additionally uses ADMET proxies.

    Only the constructor does real work; calling an instance is not
    implemented yet and currently yields ``None``.
    """

    def __init__(self, protein: Protein, ref_molecule: Molecule, weights: List[float]) -> None:
        """Pre-compute the reference metrics and keep the metric weights.

        Args:
            protein: Target protein used to evaluate the reference molecule.
                It is not stored on the instance.
            ref_molecule: Reference molecule whose metrics are the baseline.
            weights: Per-metric weights, stored unchanged.
        """
        self.ref_metrics = calc_vina_molecule_metrics(ref_molecule, protein)
        self.weights = weights

    def __call__(self, molecule: Optional[Molecule]) -> float:
        """Not implemented yet: always returns ``None``, whatever the input."""
        # TODO: implement ADMET prediction models as proxies
        return None