import torch
import numpy as np
from torch import nn
import h5py


def get_problem(name, *args, **kwargs):
    name = name.lower()
    
    PROBLEM = {
        'molecular_discovery': MolecularDiscovery
 }

    if name not in PROBLEM:
        raise Exception("Problem not found.")
    
    return PROBLEM[name](*args, **kwargs)

class MolecularDiscovery():

    class SurrogateModel(nn.Module):
        def __init__(self, a_dim, c_dim):
            super(MolecularDiscovery.SurrogateModel, self).__init__()
            self.a_dim = a_dim
            self.c_dim = c_dim
            self.fc = nn.Sequential(
                nn.Linear(a_dim + c_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 10),
                nn.Sigmoid()
            )

        def forward(self, x):
            return self.fc(x)
        
    def __init__(self, device):
        self.device = device
        self.a_dim = 32
        self.c_dim = 64
        self.n_dim = 32 + 64
        self.lbound = torch.tensor([-4] * 32).float()
        self.ubound = torch.tensor([4] * 32).float()
        self.model = self.SurrogateModel(32, 64).to(device)
        self.model.load_state_dict(torch.load('design_baselines/diff/molecular_discovery/surrogate_model.pt'))
        self.model.eval()
        self.targets = torch.tensor(np.load('design_baselines/diff/molecular_discovery/target_embeddings.npy')).float().to(device)
        self.objectives = ['logp', 'qed', 'sa', 'total_energy', 'torsional_energy', 'hydrogen_bonds', 'pi_pi_stacking_interactions', 'salt_bridges', 't_stacking_interactions', 'ligand_efficiency']
        self.x = np.load('design_baselines/diff/molecular_discovery/md_design.npy')
        self.y = np.load('design_baselines/diff/molecular_discovery/md_obj.npy')
        self.is_discrete = False

    def predict(self, x):
        x = torch.from_numpy(x).to(self.device).float()
        target_embedding = self.targets[:1].repeat(x.shape[0], 1)
        x = torch.cat([x,target_embedding],-1)
        results = self.model(x)
        results = results.detach().cpu().numpy()
        # return results[:,1:2] - results[:,2:3] - results[:,3:4] #return total_energy
        return 1 - results[:,3:4] #return total_energy