import torch.nn as nn
import torch
import lightgbm as lgb
from lightgbm import log_evaluation
import numpy as np


class DrugPropertyModel(nn.Module):
    """Prediction head fed by position 0 of the encoder's last hidden state.

    Args:
        drug_size: Width of the encoder hidden state fed to the head.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; it is invoked with ``output_hidden_states=True``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection. A negative value selects a single linear layer
            mapping ``drug_size`` directly to ``num_class``.
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
        tokenizer: Optional object stored on the instance; not used here.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class, tokenizer=None):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.drug = nn.ModuleList(
            self._head_layers(drug_size, hidden_size, num_hidden_layers, task_type, num_class)
        )

    @staticmethod
    def _head_layers(drug_size, hidden_size, num_hidden_layers, task_type, num_class):
        """Return the ordered list of layers forming the prediction head."""
        if num_hidden_layers < 0:
            # Negative depth: plain linear projection without hidden layers.
            return [nn.Linear(drug_size, num_class)]

        layers = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        out_features = 1 if task_type == 'regression' else num_class
        layers.append(nn.Linear(hidden_size, out_features))
        return layers

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and run the head on the position-0 embedding.

        Returns:
            The output of the last head layer.
        """
        encoded = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Index 0 along dim 1 -- presumably the CLS/BOS token; confirm with the tokenizer.
        features = encoded.hidden_states[-1][:, 0]
        for layer in self.drug:
            features = layer(features)
        return features


class DrugPropertyTree(nn.Module):
    """Prediction head fed by position 0 of a selectable encoder hidden state.

    Args:
        drug_size: Width of the encoder hidden state fed to the head.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; it is invoked with ``output_hidden_states=True``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection. A negative value selects a single linear layer
            mapping ``drug_size`` directly to ``num_class``.
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
        tree_depth: Index into the encoder's ``hidden_states`` sequence that
            selects which state is read in ``forward``.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class, tree_depth):
        super().__init__()
        self.model = model
        self.tree_depth = tree_depth

        if num_hidden_layers < 0:
            # Negative depth: plain linear projection without hidden layers.
            head = [nn.Linear(drug_size, num_class)]
        else:
            head = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
            for _ in range(num_hidden_layers):
                head.extend((nn.Linear(hidden_size, hidden_size), nn.ReLU()))
            head.append(nn.Linear(hidden_size, 1 if task_type == 'regression' else num_class))
        self.drug = nn.ModuleList(head)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and run the head on the selected layer's position-0 vector.

        Returns:
            The output of the last head layer.
        """
        encoded = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        selected_state = encoded.hidden_states[self.tree_depth]
        result = selected_state[:, 0]
        for layer in self.drug:
            result = layer(result)
        return result


class DrugPropertyModelPooling(nn.Module):
    """Prediction head fed by the attention-masked mean of the last hidden state.

    Args:
        drug_size: Width of the encoder hidden state fed to the head.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; it is invoked with ``output_hidden_states=True``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection. A negative value selects a single linear layer
            mapping ``drug_size`` directly to ``num_class``.
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class):
        super().__init__()
        self.model = model
        if num_hidden_layers < 0:
            # Negative depth: plain linear projection without hidden layers.
            layers = [nn.Linear(drug_size, num_class)]
        else:
            layers = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
            for _ in range(num_hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size))
                layers.append(nn.ReLU())
            out_features = 1 if task_type == 'regression' else num_class
            layers.append(nn.Linear(hidden_size, out_features))
        self.drug = nn.ModuleList(layers)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs, mean-pool over attended positions, run the head.

        A row whose ``attention_mask`` is all zeros pools to a zero vector
        rather than NaN.

        Returns:
            The output of the last head layer.
        """
        outputs = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        mask = attention_mask.unsqueeze(-1)
        # Clamp the token count so an all-zero mask yields 0/1 instead of 0/0.
        token_counts = mask.sum(1).clamp(min=1)
        c = (last_hidden_states * mask).sum(1) / token_counts
        for layer in self.drug:
            c = layer(c)
        return c


class DrugPropertyModelPoolingDescriptors(nn.Module):
    """Prediction head that currently consumes only the descriptor vector.

    NOTE: the encoder path is disabled in this class. The head's input width
    is ``descriptor_size`` (the ``drug_size`` argument is overridden) and
    ``forward`` never calls ``self.model``. The disabled variant concatenated
    the attention-masked mean of the encoder's last hidden state with the
    descriptors and used ``drug_size + descriptor_size`` as input width.

    Args:
        drug_size: Accepted for signature compatibility; overridden by
            ``descriptor_size``.
        model: Stored on the instance but not used in ``forward``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection. A negative value selects a single linear layer
            mapping ``descriptor_size`` directly to ``num_class``.
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
        descriptor_size: Width of the descriptor vector fed to the head.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class, descriptor_size):
        super().__init__()
        self.model = model
        # Only descriptors are fed to the head, so they alone set the input width.
        in_features = descriptor_size

        if num_hidden_layers < 0:
            head = [nn.Linear(in_features, num_class)]
        else:
            head = [nn.Linear(in_features, hidden_size), nn.ReLU()]
            head += [
                layer
                for _ in range(num_hidden_layers)
                for layer in (nn.Linear(hidden_size, hidden_size), nn.ReLU())
            ]
            if task_type == 'regression':
                head.append(nn.Linear(hidden_size, 1))
            else:
                head.append(nn.Linear(hidden_size, num_class))
        self.drug = nn.ModuleList(head)

    def forward(self, input_ids, attention_mask, descriptors):
        """Run the head on ``descriptors``; the token inputs are ignored.

        Returns:
            The output of the last head layer.
        """
        activations = descriptors
        for layer in self.drug:
            activations = layer(activations)
        return activations


class DrugPropertyModelPoolingBDB(nn.Module):
    """Prediction head fed by four concatenated mean-pooled encodings.

    The main sequence and three further sequences (``b1``, ``b2``, ``b3``)
    are each encoded by the same encoder, mean-pooled over attended
    positions and concatenated, so the head's input width is ``drug_size * 4``.

    Args:
        drug_size: Width of a single encoder hidden state.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; it is invoked with ``output_hidden_states=True``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection. A negative value selects a single linear layer
            mapping ``drug_size * 4`` directly to ``num_class``.
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class):
        super().__init__()
        self.model = model
        in_features = drug_size * 4  # four pooled embeddings are concatenated
        if num_hidden_layers < 0:
            # Negative depth: plain linear projection without hidden layers.
            layers = [nn.Linear(in_features, num_class)]
        else:
            layers = [nn.Linear(in_features, hidden_size), nn.ReLU()]
            for _ in range(num_hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size))
                layers.append(nn.ReLU())
            out_features = 1 if task_type == 'regression' else num_class
            layers.append(nn.Linear(hidden_size, out_features))
        self.drug = nn.ModuleList(layers)

    def _pooled_embedding(self, input_ids, attention_mask):
        """Encode one sequence batch and mean-pool it over attended positions."""
        outputs = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        mask = attention_mask.unsqueeze(-1)
        # Clamp the token count so an all-zero mask yields a zero vector
        # instead of 0/0 = NaN.
        return (last_hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1)

    def forward(self, input_ids, attention_mask,
                input_ids_b1, attention_mask_b1,
                input_ids_b2, attention_mask_b2,
                input_ids_b3, attention_mask_b3):
        """Encode the four sequences, concatenate their pooled vectors, run the head.

        The encoder is called in the order main, b1, b2, b3, and the pooled
        vectors are concatenated along the last dimension in that same order.

        Returns:
            The output of the last head layer.
        """
        sequences = (
            (input_ids, attention_mask),
            (input_ids_b1, attention_mask_b1),
            (input_ids_b2, attention_mask_b2),
            (input_ids_b3, attention_mask_b3),
        )
        pooled = [self._pooled_embedding(ids, mask) for ids, mask in sequences]
        c = torch.cat(pooled, dim=-1)
        for layer in self.drug:
            c = layer(c)
        return c


class DrugPropertyModelPoolingGapMind(nn.Module):
    """Mean-pooled prediction head plus a token-prediction head.

    Besides the pooled property prediction, ``forward`` gathers the last
    hidden state at the positions in ``masked_ids`` and projects each to
    ``num_tokens`` scores through ``lm_head``.

    Args:
        drug_size: Width of the encoder hidden state.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; it is invoked with ``output_hidden_states=True``.
        hidden_size: Width of every hidden layer of the property head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection. A negative value selects a single linear layer
            mapping ``drug_size`` directly to ``num_class``.
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
        num_tokens: Output width of ``lm_head``.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class, num_tokens):
        super().__init__()
        self.model = model
        self.lm_head = nn.Linear(drug_size, num_tokens)
        if num_hidden_layers < 0:
            # Negative depth: plain linear projection without hidden layers.
            layers = [nn.Linear(drug_size, num_class)]
        else:
            layers = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
            for _ in range(num_hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size))
                layers.append(nn.ReLU())
            out_features = 1 if task_type == 'regression' else num_class
            layers.append(nn.Linear(hidden_size, out_features))
        self.drug = nn.ModuleList(layers)

    def forward(self, input_ids, attention_mask, masked_ids):
        """Return the property prediction and the token predictions.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder and used as pooling weights.
            masked_ids: Integer positions along dim 1 of the hidden state whose
                vectors are fed to ``lm_head``.

        Returns:
            A ``(property_output, token_predictions)`` tuple. A row whose
            ``attention_mask`` is all zeros pools to a zero vector rather
            than NaN.
        """
        outputs = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]

        mask = attention_mask.unsqueeze(-1)
        # Clamp the token count so an all-zero mask yields 0/1 instead of 0/0.
        c = (last_hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1)

        # Broadcast each position index across the hidden dimension so gather
        # picks the whole hidden vector at that position.
        gather_index = masked_ids.unsqueeze(-1).expand(-1, -1, last_hidden_states.shape[-1])
        token_predictions = self.lm_head(torch.gather(last_hidden_states, 1, gather_index))

        for layer in self.drug:
            c = layer(c)
        return c, token_predictions


class DrugPropertyModelMF(nn.Module):
    """Stand-alone MLP mapping a feature vector to a property prediction.

    Args:
        drug_size: Width of the input feature vector.
        hidden_size: Width of every hidden layer.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection (zero or negative adds none).
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
    """

    def __init__(self, drug_size, hidden_size, num_hidden_layers, task_type, num_class):
        super().__init__()
        stack = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            stack.extend((nn.Linear(hidden_size, hidden_size), nn.ReLU()))
        out_features = 1 if task_type == 'regression' else num_class
        stack.append(nn.Linear(hidden_size, out_features))
        self.drug = nn.ModuleList(stack)

    def forward(self, x):
        """Apply every layer of the MLP to ``x`` in order and return the result."""
        hidden = x
        for layer in self.drug:
            hidden = layer(hidden)
        return hidden


class DrugPropertyModelMorganGen(nn.Module):
    """MLP head fed by position 0 of the encoder's ``last_hidden_state``.

    Args:
        drug_size: Width of the encoder hidden state fed to the head.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; its output must expose ``last_hidden_state``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection (zero or negative adds none).
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class):
        super().__init__()
        self.model = model
        self.drug = nn.ModuleList(
            self._mlp_layers(drug_size, hidden_size, num_hidden_layers, task_type, num_class)
        )

    @staticmethod
    def _mlp_layers(drug_size, hidden_size, num_hidden_layers, task_type, num_class):
        """Return the ordered list of layers forming the MLP head."""
        layers = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        if task_type == 'regression':
            layers.append(nn.Linear(hidden_size, 1))
        else:
            layers.append(nn.Linear(hidden_size, num_class))
        return layers

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and run the head on the position-0 embedding.

        Returns:
            The output of the last head layer.
        """
        encoded = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Reads ``last_hidden_state`` directly rather than ``hidden_states[-1]``.
        features = encoded.last_hidden_state[:, 0]
        for layer in self.drug:
            features = layer(features)
        return features


class DrugPropertyModelMorganCompact(nn.Module):
    """MLP head fed by position 0 of the last entry of ``hidden_states``.

    Args:
        drug_size: Width of the encoder hidden state fed to the head.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; it is invoked with ``output_hidden_states=True``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection (zero or negative adds none).
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class):
        super().__init__()
        self.model = model

        blocks = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
        blocks += [
            layer
            for _ in range(num_hidden_layers)
            for layer in (nn.Linear(hidden_size, hidden_size), nn.ReLU())
        ]
        out_features = 1 if task_type == 'regression' else num_class
        blocks.append(nn.Linear(hidden_size, out_features))
        self.drug = nn.ModuleList(blocks)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and run the head on the position-0 embedding.

        Returns:
            The output of the last head layer.
        """
        encoder_out = self.model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        final_state = encoder_out.hidden_states[-1]
        activations = final_state[:, 0]
        for layer in self.drug:
            activations = layer(activations)
        return activations


class DrugPropertyModelMorganGenPooling(nn.Module):
    """MLP head fed by the plain mean of ``last_hidden_state`` over dim 1.

    Args:
        drug_size: Width of the encoder hidden state fed to the head.
        model: Wrapper object whose ``.model`` attribute is the callable
            encoder; its output must expose ``last_hidden_state``.
        hidden_size: Width of every hidden layer of the head.
        num_hidden_layers: Number of extra ``Linear + ReLU`` pairs after the
            input projection (zero or negative adds none).
        task_type: ``'regression'`` gives one output unit; any other value is
            treated as classification and gives ``num_class`` units.
        num_class: Number of output units for classification.
    """

    def __init__(self, drug_size, model, hidden_size, num_hidden_layers, task_type, num_class):
        super().__init__()
        self.model = model

        head = [nn.Linear(drug_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            head.extend((nn.Linear(hidden_size, hidden_size), nn.ReLU()))
        if task_type == 'regression':
            head.append(nn.Linear(hidden_size, 1))
        else:
            head.append(nn.Linear(hidden_size, num_class))
        self.drug = nn.ModuleList(head)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs, average over dim 1, and run the head.

        Returns:
            The output of the last head layer.
        """
        encoded = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # NOTE(review): the mean covers every position; attention_mask is not
        # applied here, so masked positions also contribute -- confirm intent.
        pooled = encoded.last_hidden_state.mean(dim=1)
        for layer in self.drug:
            pooled = layer(pooled)
        return pooled


class DrugPropertyModelLightGBM:
    """LightGBM wrapper that picks the boosting-round count by 5-fold CV.

    Args:
        num_leaves: Forwarded to LightGBM as ``num_leaves``.
        learning_rate: Forwarded to LightGBM as ``learning_rate``.
        feature_fraction: Forwarded to LightGBM as ``feature_fraction``.
        bagging_fraction: Forwarded to LightGBM as ``bagging_fraction``.
        task_type: ``'regression'`` uses the ``regression`` objective with the
            ``l1`` metric; any other value uses the ``binary`` objective with
            the ``auc`` metric.
        bagging_freq: Forwarded to LightGBM as ``bagging_freq``.
        shuffle: Whether ``lgb.cv`` shuffles the data before splitting.
    """

    def __init__(self, num_leaves, learning_rate, feature_fraction, bagging_fraction, task_type, bagging_freq,
                 shuffle=True):
        self.task_type = task_type
        self.shuffle = shuffle
        if task_type == 'regression':
            objective, metric = 'regression', 'l1'
        else:
            objective, metric = 'binary', 'auc'

        self.params = {
            'boosting_type': 'gbdt',
            'objective': objective,
            'metric': metric,
            'num_leaves': num_leaves,
            'learning_rate': learning_rate,
            'feature_fraction': feature_fraction,
            'bagging_fraction': bagging_fraction,
            'bagging_freq': bagging_freq,
            'verbose': -1,
            'early_stopping_rounds': 3,
        }

    def train(self, lgb_train):
        """Cross-validate, then fit a final booster on all of ``lgb_train``.

        Args:
            lgb_train: Training dataset passed to ``lgb.cv`` and ``lgb.train``.

        Returns:
            A ``(booster, best_val_score)`` tuple. ``best_val_score`` is the
            best mean CV ``l1`` for regression, or the negated best mean CV
            ``auc`` otherwise, so lower is better in both cases. ``booster``
            is ``None`` when the first CV iteration already has the best score.
        """
        print('Starting training...')
        cv_history = lgb.cv(self.params, lgb_train, nfold=5, num_boost_round=2000,
                            shuffle=self.shuffle, callbacks=[log_evaluation(period=5)],
                            stratified=False)

        if self.task_type == 'regression':
            scores = cv_history['valid l1-mean']
            best_index = int(np.argmin(scores))
            best_val_score = scores[best_index]
        else:
            scores = cv_history['valid auc-mean']
            best_index = int(np.argmax(scores))
            best_val_score = -scores[best_index]

        # Existing contract: no final model when boosting past the first
        # round never improved the CV score.
        if best_index == 0:
            return None, best_val_score

        # The history is indexed from 0, so entry i is the score after i + 1
        # rounds; the best round count is therefore best_index + 1.
        optimal_num_round = best_index + 1

        # NOTE(review): the only evaluation set here is the training data
        # itself, so this early-stopping callback cannot measure
        # generalisation -- confirm whether it is intended.
        gbm = lgb.train(self.params,
                        lgb_train,
                        num_boost_round=optimal_num_round,
                        valid_sets=[lgb_train],
                        callbacks=[lgb.early_stopping(stopping_rounds=5),
                                   log_evaluation(period=10)])

        return gbm, best_val_score

    def test(self, X_test, gbm):
        """Predict on ``X_test`` with ``gbm`` and return the result as a tensor.

        Args:
            X_test: Feature matrix passed to ``gbm.predict``.
            gbm: Trained booster; its ``best_iteration`` limits the rounds
                used for prediction.

        Returns:
            A ``torch.Tensor`` built from the booster's predictions.
        """
        y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
        return torch.tensor(y_pred)

