import argparse
import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from transformer_src.transformer_layer import SinusoidalPositionalEmbedding
from gvp_src.features import GVPInputFeaturizer, DihedralFeatures
from gvp_src.gvp_encoder import GVPEncoder
from transformer_src.transformer_layer import TransformerEncoderLayer
from util import nan_to_num, get_rotation_frames, rotate, rbf

class GVPTransformerEncoder(nn.Module):
    """Structure encoder: GVP-derived per-position features fed into a Transformer.

    Several per-position embeddings are summed:
      * the masked token embeddings,
      * dihedral features,
      * the GVP encoder outputs,
      * an RBF expansion of the confidence,
      * the raw GVP input node features.
    Sinusoidal positional embeddings and dropout are then applied. The result
    runs through ``args.encoder_layers`` ``TransformerEncoderLayer`` instances
    and a final ``LayerNorm``.

    Args:
        args: Namespace of hyperparameters. ``dropout`` and ``encoder_layers``
            are read here. Every ``gvp_*`` entry is forwarded to
            ``GVPEncoder`` with the prefix stripped. The whole namespace is
            also passed to each ``TransformerEncoderLayer``.
        dictionary: Token dictionary. It must expose ``padding_idx`` and
            ``get_idx("<mask>")``.
        embed_tokens: Token embedding module exposing ``embedding_dim`` and
            ``padding_idx``. Presumably an ``nn.Embedding``; verify against
            the caller.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__()
        self.args = args
        self.dictionary = dictionary

        # NOTE: submodules are created in a fixed order. Construction order
        # determines RNG consumption for weight init and the order of
        # parameters(), so keep it stable.
        self.dropout_module = nn.Dropout(args.dropout)

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        # Token embeddings are multiplied by sqrt(embed_dim) in forward_embedding.
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(
            embed_dim,
            self.padding_idx,
        )
        # The input width is hard-coded to 15. It must equal the scalar width
        # plus the flattened vector width of
        # GVPInputFeaturizer.get_node_features. NOTE(review): confirm.
        self.embed_gvp_input_features = nn.Linear(15, embed_dim)
        # The input width is hard-coded to 16. This assumes rbf() yields 16
        # basis features by default. TODO: confirm against util.rbf.
        self.embed_confidence = nn.Linear(16, embed_dim)
        self.embed_dihedrals = DihedralFeatures(embed_dim)

        gvp_args = self._extract_gvp_args(args)
        self.gvp_encoder = GVPEncoder(gvp_args)
        # The GVP output is the scalar channels concatenated with the vector
        # channels, each vector flattened to its 3 components (see the
        # flatten(-2, -1) in forward_embedding).
        gvp_out_dim = (
            gvp_args.node_hidden_dim_scalar
            + 3 * gvp_args.node_hidden_dim_vector
        )
        self.embed_gvp_output = nn.Linear(gvp_out_dim, embed_dim)

        self.layers = nn.ModuleList(
            [self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
        )
        self.num_layers = len(self.layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

    @staticmethod
    def _extract_gvp_args(args):
        """Return a Namespace of the ``gvp_``-prefixed entries of *args*.

        The prefix is stripped, so ``gvp_foo`` becomes ``foo``. All other
        entries are ignored.
        """
        prefix = "gvp_"
        gvp_args = argparse.Namespace()
        for key, value in vars(args).items():
            if key.startswith(prefix):
                setattr(gvp_args, key[len(prefix):], value)
        return gvp_args

    def build_encoder_layer(self, args):
        """Create one Transformer encoder layer.

        This is a factory hook that subclasses may override.
        """
        return TransformerEncoderLayer(args)

    def forward_embedding(self, coords, adjunct_coords, padding_mask, confidence):
        """Build the summed per-position input embedding for the Transformer.

        Args:
            coords: Coordinates used by the GVP encoder and featurizer. A
                position is valid only if every value in its last two dims
                is finite. Presumably (B, T, n_atoms, 3); TODO confirm.
            adjunct_coords: Coordinates given to the dihedral featurizer.
                They are sanitised with ``nan_to_num`` first.
            padding_mask: B x T mask. Must be a bool tensor, because ``~`` is
                used as logical NOT. True marks padding.
            confidence: Per-position confidence. It is expanded with
                ``rbf(confidence, 0., 1.)``, so presumably in [0, 1].

        Returns:
            ``(x, components)``:
              * ``x`` is the sum of all component embeddings, plus sinusoidal
                positional embeddings, after dropout.
              * ``components`` maps each component name to its embedding.
        """
        components = {}
        # Reduce the last two dims: a position is valid only if all its
        # coordinate values are finite.
        coord_mask = torch.all(torch.all(torch.isfinite(coords), dim=-1), dim=-1)
        coords = nan_to_num(coords)
        adjunct_coords = nan_to_num(adjunct_coords)

        # Non-padding positions receive the <mask> token and padding positions
        # receive the padding index.
        mask_tokens = (
            padding_mask * self.dictionary.padding_idx
            + ~padding_mask * self.dictionary.get_idx("<mask>")
        )

        # The insertion order of `components` fixes the summation order below.
        components["tokens"] = self.embed_tokens(mask_tokens) * self.embed_scale
        # NOTE: the key is misspelled ("diherals"). It is kept as-is because
        # callers may read it from the returned ``encoder_embedding`` dict.
        components["diherals"] = self.embed_dihedrals(adjunct_coords)

        gvp_out_scalars, gvp_out_vectors = self.gvp_encoder(
            coords, coord_mask, padding_mask, confidence
        )
        # Rotate vector features by R^T into the local frames given by
        # get_rotation_frames, then flatten each vector's components.
        # Presumably this makes them invariant to global rotation.
        R = get_rotation_frames(coords)
        R_inv = R.transpose(-2, -1)
        gvp_out_features = torch.cat(
            [gvp_out_scalars, rotate(gvp_out_vectors, R_inv).flatten(-2, -1)],
            dim=-1,
        )
        components["gvp_out"] = self.embed_gvp_output(gvp_out_features)

        components["confidence"] = self.embed_confidence(rbf(confidence, 0., 1.))

        # Also embed the raw GVP input node features directly, rotated into
        # the same local frames.
        scalar_features, vector_features = GVPInputFeaturizer.get_node_features(
            coords, coord_mask, with_coord_mask=False
        )
        features = torch.cat(
            [scalar_features, rotate(vector_features, R_inv).flatten(-2, -1)],
            dim=-1,
        )
        components["gvp_input_features"] = self.embed_gvp_input_features(features)

        x = sum(components.values())
        x = x + self.embed_positions(mask_tokens)
        x = self.dropout_module(x)
        return x, components

    def forward(
        self,
        coords,
        adjunct_coords,
        encoder_padding_mask,
        confidence,
        return_all_hiddens: bool = False,
    ):
        """Encode a batch of structures.

        Args:
            coords, adjunct_coords, confidence: See ``forward_embedding``.
            encoder_padding_mask: B x T bool mask. True marks padding.
            return_all_hiddens: If True, collect intermediate hidden states.

        Returns:
            A dict with these entries:
              * ``encoder_out``: ``[x]``, with x of shape T x B x C, after
                the final LayerNorm.
              * ``encoder_padding_mask``: ``[encoder_padding_mask]`` (B x T).
              * ``encoder_embedding``: ``[components]``, the per-component
                embedding dict from ``forward_embedding``.
              * ``encoder_states``: a list of T x B x C tensors, empty unless
                ``return_all_hiddens`` is set. It holds the input to the
                first layer followed by each layer's output, all before the
                final LayerNorm.
        """
        x, encoder_embedding = self.forward_embedding(
            coords, adjunct_coords, encoder_padding_mask, confidence
        )
        # Zero out padded positions before they enter the Transformer layers.
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []
        if return_all_hiddens:
            encoder_states.append(x)

        for layer in self.layers:
            x = layer(x, encoder_padding_mask=encoder_padding_mask)
            if return_all_hiddens:
                encoder_states.append(x)

        # The norm is always set in __init__. The None check lets a subclass
        # disable it.
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # dictionary
            "encoder_states": encoder_states,  # List[T x B x C]
        }
