import numpy as np
from rdkit.Chem import MolFromSmiles, MolToSmiles
from SmilesPE.pretokenizer import atomwise_tokenizer
import torch
from models import *
from data import *


class OSLSModel:
    """Inference wrapper around a trained context/query model loaded from a checkpoint.

    A prediction for a query SMILES is conditioned on a list of context
    ``(smiles, value)`` pairs. Each pair is embedded by ``context_embedder``,
    the embeddings are combined by ``context_attn``, concatenated with the
    query embedding from ``query_embedder`` and fed to ``predictor``.

    Every sub-module is moved to CUDA and put in eval mode, so a GPU is required.
    """

    def __init__(self, model_file, data_file='data.pickle'):
        """Build the sub-modules and load their trained weights.

        Args:
            model_file: checkpoint read with ``torch.load``. It must contain four
                state dicts in the order (context_embedder, query_embedder,
                predictor, context_attn).
            data_file: file passed to ``get_for_eval``, which returns the token
                vocabulary, ``mean``, ``std``, ``max_idx`` and ``max_len``.
        """
        # NOTE(review): self.mean / self.std are stored but no method of this
        # class uses them; confirm whether callers are expected to rescale
        # the value returned by predict().
        self.token_to_idx, self.mean, self.std, max_idx, self.max_len = get_for_eval(data_file)
        self.context_embedder = Embedder(max_idx, self.max_len).cuda()
        self.context_attn = ContextAttn().cuda()
        self.query_embedder = Embedder(max_idx, self.max_len).cuda()
        self.predictor = Predictor().cuda()
        # torch.load unpickles the file and can execute arbitrary code: only
        # load checkpoints from trusted sources.
        context_embedder_dict, query_embedder_dict, predictor_dict, context_attn_dict = torch.load(model_file)
        self.context_embedder.load_state_dict(context_embedder_dict)
        self.query_embedder.load_state_dict(query_embedder_dict)
        self.context_attn.load_state_dict(context_attn_dict)
        self.predictor.load_state_dict(predictor_dict)
        self.context_embedder.eval()
        self.query_embedder.eval()
        self.context_attn.eval()
        self.predictor.eval()

    def get_context(self, contexts):
        """Embed the context pairs and combine them with the context attention module.

        Args:
            contexts: sequence whose items hold a SMILES string at index 0 and a
                numeric label at index 1. Labels are clipped by
                ``canonicalize_val`` before embedding.

        Returns:
            ``self.context_attn`` applied to a CUDA tensor of shape
            ``(len(contexts), 1, config['d_model'])``. This method does not
            disable autograd itself; wrap the call in ``torch.no_grad()`` if
            gradients are not needed.

        Raises:
            ValueError: if any context SMILES is rejected by ``featurize_mol``.
        """
        context = torch.zeros((len(contexts), 1, config['d_model']), device='cuda')
        for j, pair in enumerate(contexts):
            context_x = torch.tensor(self.featurize_mol(pair[0])).unsqueeze(0).cuda()
            context_y = torch.tensor(self.canonicalize_val(pair[1])).float().view((1, 1)).cuda()
            context[j] = self.context_embedder(context_x, context_y)
        return self.context_attn(context)

    def predict(self, contexts, query):
        """Predict a scalar value for ``query`` conditioned on ``contexts``.

        Args:
            contexts: context pairs, as accepted by ``get_context``.
            query: SMILES string of the molecule to predict.

        Returns:
            The predictor output as a Python number (via ``Tensor.item()``).

        Raises:
            ValueError: if the query or any context SMILES is rejected by
                ``featurize_mol``. The query is validated first.
        """
        # Pure inference: no_grad avoids building and holding an autograd graph.
        with torch.no_grad():
            query_x = torch.tensor(self.featurize_mol(query)).unsqueeze(0).cuda()
            context = self.get_context(contexts)
            query_emb = self.query_embedder(query_x)
            x = torch.concat((context, query_emb), dim=1)
            return self.predictor(x).item()

    def get_query_embedding(self, query):
        """Return the query embedder's output for ``query`` as a flat numpy array.

        Raises:
            ValueError: if ``query`` is rejected by ``featurize_mol``.
        """
        with torch.no_grad():
            x = torch.tensor(self.featurize_mol(query)).unsqueeze(0).cuda()
            embedding = self.query_embedder(x)
        return embedding.detach().cpu().numpy().flatten()

    def featurize_mol(self, smiles):
        """Convert a SMILES string into a padded array of token ids.

        The SMILES is canonicalized with RDKit and split with
        ``atomwise_tokenizer``. The id sequence starts with 1. Tokens missing
        from ``self.token_to_idx`` map to 2, and the sequence is right-padded
        with 0 up to ``self.max_len``.

        Raises:
            ValueError: if, after dropping the characters ``()=@[]`` and digits
                1-9, the string does not have between 11 and 69 characters, or
                if RDKit cannot parse it.
        """
        n_chars = sum(1 for char in smiles if char not in '()=@[]123456789')
        # The length check runs first, so out-of-range strings are never parsed.
        mol = MolFromSmiles(smiles) if 10 < n_chars < 70 else None
        if mol is None:
            raise ValueError('smiles invalid or incorrect length')
        fs = [1] + [self.token_to_idx.get(token, 2) for token in atomwise_tokenizer(MolToSmiles(mol))]
        # NOTE(review): sequences longer than max_len are returned untruncated;
        # confirm the embedder accepts them or that the length filter rules them out.
        fs.extend([0] * (self.max_len - len(fs)))
        return np.array(fs)

    def canonicalize_val(self, val):
        """Clip ``val`` to the closed range [-2.5, 6.5]."""
        return min(max(val, -2.5), 6.5)
