"""
"""

import pickle
import numpy as np
import torch
import json
from torch.utils.data import Dataset, DataLoader
from abc import ABC, abstractmethod
import torch
import torch.distributions.studentT as studentT
import torch.nn.functional as F
import torch.nn as nn
##################################################################################################################
#
# Abstract classes for torch models.
#
##################################################################################################################

class AbstractMoleculeEncoder(ABC, torch.nn.Module):
    """
    Placeholder base class for molecule (drug) embedding models.

    Concrete encoders override ``forward`` to turn a token batch into an
    embedding; this default implementation produces nothing and yields ``None``.
    """

    def forward(self, src):
        """Ignore ``src`` and return ``None``; subclasses override this."""
        return None


class AbstractProteinEncoder(ABC, torch.nn.Module):
    """
    Placeholder base class for protein embedding models.

    Concrete encoders override ``forward`` to turn a token batch into an
    embedding; this default implementation produces nothing and yields ``None``.
    """

    def forward(self, src):
        """Ignore ``src`` and return ``None``; subclasses override this."""
        return None


class AbstractInteractionModel(ABC, torch.nn.Module):
    """
    Placeholder base class for drug-target interaction heads.

    Concrete heads override ``forward`` to combine a protein embedding and a
    drug embedding into a prediction; this default returns ``None``.
    """

    def forward(self, protein_emb, drug_emb):
        """Ignore both embeddings and return ``None``; subclasses override this."""
        return None


class AbstractDTIModel(ABC, torch.nn.Module):
    """
    Abstract base class of complete drug-target interaction models.

    A model is composed of three sub-modules:
      * ``protein_encoder``       - embeds the protein token batch,
      * ``smiles_encoder``        - embeds the drug (SMILES) token batch,
      * ``interaction_predictor`` - maps both embeddings to a prediction.

    The placeholders assigned here return ``None``; subclasses replace them
    with concrete modules.
    """

    def __init__(self):
        super().__init__()
        # Each placeholder's type matches the slot it fills: the protein slot
        # holds a protein encoder, the SMILES slot a molecule encoder.
        self.protein_encoder = AbstractProteinEncoder()
        self.smiles_encoder = AbstractMoleculeEncoder()
        self.interaction_predictor = AbstractInteractionModel()

    def forward(self, d, p):
        """
        Args:
            d(Tensor) : Preprocessed drug input batch
            p(Tensor) : Preprocessed protein input batch

            both d and p contain Long elements representing the tokens,
            such as
            ["C", "C", "O", "H"] -> Tensor([4, 4, 5, 7])
            ["P", "K"] -> Tensor([12, 8])

        Return:
            (Tensor) [batch_size, 1]: predicted affinity value
        """
        p_emb = self.protein_encoder(p)
        d_emb = self.smiles_encoder(d)

        # The interaction head receives the protein embedding first, then the drug.
        return self.interaction_predictor(p_emb, d_emb)

##################################################################################################
##################################################################################################
##################################################################################################

class MLPMixedDTI(torch.nn.Module):
    """
    Work-in-progress DTI model that embeds drug and protein tokens as one
    joint sequence and applies a linear layer across the sequence axis.

    The drug batch ``d``, one separator token and the protein batch ``p`` are
    concatenated along the sequence axis. The result goes through a shared
    embedding table, is transposed to channels-first, and then passes through
    ``fc1`` (a Linear over the sequence dimension).

    Args:
        token_len: size of the shared embedding vocabulary.
        seqlen: total sequence length ``len(d) + 1 + len(p)``. ``fc1`` is
            sized by it, so inputs must match it exactly. The default 1286
            presumably corresponds to 85 drug + 1 separator + 1200 protein
            tokens — TODO confirm against the data pipeline.
    """

    def __init__(self, token_len=64 + 25 + 3, seqlen=1286):
        super().__init__()
        self.token_len = token_len
        self.seqlen = seqlen
        self.channel_dim = 128
        self.hidden_dim = 512
        # Vocabulary: chem(64), prot(25), unk(1), sep(1), padding(1).
        # The separator must be a valid index (< token_len). Using
        # ``token_len`` itself makes the embedding lookup raise IndexError.
        # NOTE(review): assumes sep occupies the last index; confirm against
        # the tokenizer.
        self.sep_token_id = token_len - 1
        self.embedding = torch.nn.Embedding(token_len, self.channel_dim)
        self.fc1 = torch.nn.Linear(self.seqlen, self.hidden_dim)

    def forward(self, d, p):
        """
        Args:
            d(Tensor) : [batch, drug_len] Long drug tokens
            p(Tensor) : [batch, prot_len] Long protein tokens;
                drug_len + 1 + prot_len must equal ``self.seqlen``

        Return:
            (Tensor) [batch, channel_dim, hidden_dim]
        """
        sep_token = torch.full(
            (d.size(0), 1), self.sep_token_id, dtype=torch.long, device=d.device
        )
        cat = torch.cat((d, sep_token, p), dim=1)  # [B, seqlen]
        out = self.embedding(cat)                   # [B, seqlen, channel_dim]
        out = out.transpose(1, 2)                   # [B, channel_dim, seqlen]
        return self.fc1(out)                        # [B, channel_dim, hidden_dim]


class SMILESEncoder(AbstractMoleculeEncoder):
    """
    Three-layer 1D CNN encoder for tokenized SMILES strings.

    Tokens are embedded, passed through Conv1d layers with kernel sizes
    4, 6 and 8 (each followed by ReLU), and then max-pooled over the sequence
    axis. The output is a 96-dimensional vector per sample. With no padding
    in the convolutions, inputs need at least 16 tokens.

    Args:
        smile_len: embedding vocabulary size (64 SMILES tokens + 1 for the
            0 padding index).
        latent_len: embedding dimension, i.e. input channels of the first conv.
    """

    def __init__(self, smile_len=64+1, latent_len=128):
        super().__init__()
        # Sub-modules are registered in the same order as before, so parameter
        # initialisation and state_dict layout stay the same.
        self.encoder = torch.nn.Embedding(smile_len, latent_len)
        self.conv1 = torch.nn.Conv1d(latent_len, 32, 4)
        self.conv2 = torch.nn.Conv1d(32, 64, 6)
        self.conv3 = torch.nn.Conv1d(64, 96, 8)

    def forward(self, src):
        # Embed tokens, then move features to the channel axis for Conv1d.
        hidden = self.encoder(src).transpose(1, 2)
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = torch.relu(conv(hidden))
        # Global max pool over remaining sequence positions -> [batch, 96].
        return hidden.max(dim=2).values


class ProteinEncoder(AbstractProteinEncoder):
    """
    Three-layer 1D CNN encoder for tokenized protein sequences.

    Tokens are embedded, passed through Conv1d layers with kernel sizes
    4, 8 and 12 (each followed by ReLU), and then max-pooled over the sequence
    axis. The output is a 96-dimensional vector per sample. With no padding
    in the convolutions, inputs need at least 22 tokens.

    Args:
        protein_len: embedding vocabulary size (25 residue tokens + 1 for the
            0 padding index).
        latent_len: embedding dimension, i.e. input channels of the first conv.
    """

    def __init__(self, protein_len=25+1, latent_len=128):
        super().__init__()
        # Sub-modules are registered in the same order as before, so parameter
        # initialisation and state_dict layout stay the same.
        self.encoder = torch.nn.Embedding(protein_len, latent_len)
        self.conv1 = torch.nn.Conv1d(latent_len, 32, 4)
        self.conv2 = torch.nn.Conv1d(32, 64, 8)
        self.conv3 = torch.nn.Conv1d(64, 96, 12)

    def forward(self, src):
        # Embed tokens, then move features to the channel axis for Conv1d.
        hidden = self.encoder(src).transpose(1, 2)
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = torch.relu(conv(hidden))
        # Global max pool over remaining sequence positions -> [batch, 96].
        return hidden.max(dim=2).values

class InteractionPredictor(AbstractInteractionModel):
    """
    MLP head that maps concatenated protein/drug embeddings to one score.

    Layer widths: input_dim -> 1024 -> 1024 -> 512 -> 1. Each hidden layer is
    followed by ReLU, and dropout (p=0.1) is applied after the first two only.

    Args:
        input_dim: size of the concatenated protein + drug embedding.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fully1 = torch.nn.Linear(input_dim, 1024)
        self.fully2 = torch.nn.Linear(1024, 1024)
        self.fully3 = torch.nn.Linear(1024, 512)
        self.output = torch.nn.Linear(512, 1)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, protein_emb, drug_emb):
        # Protein features come first along dim 1, followed by drug features.
        hidden = torch.cat((protein_emb, drug_emb), 1)
        for layer in (self.fully1, self.fully2):
            hidden = self.dropout(torch.relu(layer(hidden)))
        hidden = torch.relu(self.fully3(hidden))
        return self.output(hidden)


class DeepDTA(AbstractDTIModel):
    """
    The final DeepDTA model includes the protein encoding model;
    the smiles(drug; chemical) encoding model; the interaction model.

    Args:
        concat_dim: input size of the interaction MLP. It must equal the
            protein embedding size plus the drug embedding size
            (96 + 96 with the default encoders).
    """
    def __init__(self, concat_dim=96*2):
        super().__init__()
        self.protein_encoder = ProteinEncoder()
        self.smiles_encoder = SMILESEncoder()
        self.interaction_predictor = InteractionPredictor(concat_dim)

    def forward(self, d, p):
        """
        Args:
            d(Tensor) : Preprocessed drug input batch
            p(Tensor) : Preprocessed protein input batch

            both d and p contain Long elements representing the tokens,
            such as
            ["C", "C", "O", "H"] -> Tensor([4, 4, 5, 7])
            ["P", "K"] -> Tensor([12, 8])

        Return:
            (Tensor) [batch_size, 1]: predicted affinity value
        """
        p_emb = self.protein_encoder(p)
        d_emb = self.smiles_encoder(d)

        return self.interaction_predictor(p_emb, d_emb)

    def train_dropout(self):
        """
        Switch every ``torch.nn.Dropout`` sub-module into training mode.

        Only dropout layers are affected; all other modules keep their current
        train/eval mode. Presumably called after ``eval()`` to get stochastic
        (Monte-Carlo dropout) predictions — confirm with callers.
        """
        def turn_on_dropout(m):
            if isinstance(m, torch.nn.Dropout):
                m.train()
        self.apply(turn_on_dropout)
