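"""Evoformer embedding heads (added 23-04-27).

Each ``EvoformerEmbeddingHead*`` class holds ``config.no_heads`` projection
heads that map ``c_z`` input features to ``no_bins`` outputs. In ``forward``,
head ``k`` is applied to ``emb_list[k][1]`` and its result is stored under
key ``k`` of the returned dict.
"""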

import torch
import torch.nn as nn

from myopenfold.model.primitives import Linear, LayerNorm

from myopenfold.utils.precision_utils import is_fp16_enabled

class EvoformerEmbeddingHead(nn.Module):
    """Independent single-layer projection heads, one per embedding entry.

    ``config.no_heads`` heads are built, each a ``Linear(c_z, no_bins)``
    created with ``init="final"``. Head ``k`` consumes ``emb_list[k][1]``.

    Args:
        config: object exposing ``c_z``, ``no_bins`` and ``no_heads``.
        depth: stored as ``self.depth``; not read inside this class.
        log: stored as ``self.log``; not read inside this class.
    """

    def __init__(self, config, depth=0, log=False):
        super().__init__()

        self.config = config
        self.depth = depth
        self.log = log

        self.c_z = config.c_z
        self.no_bins = config.no_bins
        self.n_heads = config.no_heads

        # Heads are created in index order; state_dict keys are
        # ``projection_heads.<k>.*``.
        self.projection_heads = nn.ModuleList(
            [
                Linear(self.c_z, self.no_bins, init="final")
                for _ in range(self.n_heads)
            ]
        )

    def forward(self, emb_list):
        """Run every head on its matching embedding.

        Args:
            emb_list: indexable with at least ``n_heads`` entries; entry ``k``
                is indexed with ``[1]`` and passed to head ``k``.
                NOTE(review): presumably ``(msa, pair)`` tuples whose pair
                tensor ends in ``c_z`` channels — confirm against the caller.

        Returns:
            dict mapping head index ``k`` (int) to head ``k``'s output.
        """
        heads = self.projection_heads
        return {k: heads[k](emb_list[k][1]) for k in range(self.n_heads)}
    

class EvoformerEmbeddingHead2(nn.Module):
    """Single-layer projection heads whose outputs are symmetrized.

    Same layout as ``EvoformerEmbeddingHead`` (one ``Linear(c_z, no_bins)``
    per head, ``init="final"``), but each output ``x`` is returned as
    ``x + x.transpose(-2, -3)``.

    Args:
        config: object exposing ``c_z``, ``no_bins`` and ``no_heads``.
        depth: stored as ``self.depth``; not read inside this class.
        log: stored as ``self.log``; not read inside this class.
    """

    def __init__(self, config, depth=0, log=False):
        super().__init__()

        self.config = config
        self.depth = depth
        self.log = log

        self.c_z = config.c_z
        self.no_bins = config.no_bins
        self.n_heads = config.no_heads

        self.projection_heads = nn.ModuleList(
            [
                Linear(self.c_z, self.no_bins, init="final")
                for _ in range(self.n_heads)
            ]
        )

    def forward(self, emb_list):
        """Run every head on its matching embedding and symmetrize the result.

        Args:
            emb_list: indexable with at least ``n_heads`` entries; entry ``k``
                is indexed with ``[1]`` and passed to head ``k``.
                NOTE(review): presumably a pair tensor shaped
                ``[*, N, N, c_z]`` — confirm; the head output needs at least
                3 dims for ``transpose(-2, -3)``.

        Returns:
            dict mapping head index ``k`` (int) to the symmetrized output.
        """

        def symmetric(x):
            # Adding the copy with dims -3 and -2 swapped makes the result
            # invariant under that swap (e.g. (i, j) vs (j, i) for a pair map).
            return x + x.transpose(-2, -3)

        heads = self.projection_heads
        return {
            k: symmetric(heads[k](emb_list[k][1])) for k in range(self.n_heads)
        }
    

class EvoformerEmbeddingHead3(nn.Module):
    """Three-layer MLP projection heads whose outputs are symmetrized.

    Each head is ``Linear(c_z, c_hidden) -> ReLU -> Linear(c_hidden,
    c_hidden) -> ReLU -> Linear(c_hidden, no_bins)``. The hidden layers use
    ``init="relu"`` and the last layer uses ``init="final"``. Each head
    output ``x`` is returned as ``x + x.transpose(-2, -3)``.

    Args:
        config: object exposing ``c_z``, ``no_bins``, ``c_hidden`` and
            ``no_heads``.
        depth: stored as ``self.depth``; not read inside this class.
        log: stored as ``self.log``; not read inside this class.
    """

    def __init__(self, config, depth=0, log=False):
        super().__init__()

        self.config = config
        self.depth = depth
        self.log = log

        self.c_z = config.c_z
        self.no_bins = config.no_bins
        self.c_hidden = config.c_hidden
        self.n_heads = config.no_heads

        def make_head():
            # Positional Sequential keeps submodule indices 0..4 so
            # state_dict keys stay ``projection_heads.<k>.<0|2|4>.*``.
            return nn.Sequential(
                Linear(self.c_z, self.c_hidden, init="relu"),
                nn.ReLU(),
                Linear(self.c_hidden, self.c_hidden, init="relu"),
                nn.ReLU(),
                Linear(self.c_hidden, self.no_bins, init="final"),
            )

        self.projection_heads = nn.ModuleList(
            [make_head() for _ in range(self.n_heads)]
        )

    def forward(self, emb_list):
        """Run every MLP head on its matching embedding and symmetrize.

        Args:
            emb_list: indexable with at least ``n_heads`` entries; entry ``k``
                is indexed with ``[1]`` and passed to head ``k``.
                NOTE(review): presumably a pair tensor shaped
                ``[*, N, N, c_z]`` — confirm; the head output needs at least
                3 dims for ``transpose(-2, -3)``.

        Returns:
            dict mapping head index ``k`` (int) to the symmetrized output.
        """

        def symmetric(x):
            # Sum with the dim(-3)/dim(-2)-swapped copy -> symmetric result.
            return x + x.transpose(-2, -3)

        heads = self.projection_heads
        return {
            k: symmetric(heads[k](emb_list[k][1])) for k in range(self.n_heads)
        }