# TODO: Avoid to_dense

import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.utils import to_dense_batch

from components.MoleculeEncoder import GIN
from components.ProteinEncoder import GVP_embedding
from components.ModelUtils import RBFDistanceModule, InteractionModule, Transition
from components.EgnnModule import EGNNBlock

import sys

from components.TrioformerModule import TrioformerBlock
from data import ComplexGraph




class Decoder(torch.nn.Module):
    """Refine compound coordinates by repeatedly applying one shared ``EGNNBlock``.

    Every iteration recomputes which protein-compound edges are active (distance
    below ``edge_mask_threshold``) from the current compound coordinates. All
    iterations except the last run under ``torch.no_grad()``, so gradients only
    flow through the final ``EGNNBlock`` call.

    Args:
        hidden_dim: Node and edge hidden width passed to ``EGNNBlock``.
        normalize_coord: Callable mapping un-normalized values into the
            normalized coordinate space of the input coordinates.
        unnormalize_coord: Inverse of ``normalize_coord``; applied to the
            returned coordinates and trajectories.
        edge_mask_threshold: Distance cutoff for protein-compound edges, given
            in un-normalized units (it is normalized before the comparison).
        egnn_block_config: Extra keyword arguments forwarded to ``EGNNBlock``.
    """

    def __init__(self, hidden_dim, normalize_coord, unnormalize_coord, edge_mask_threshold, egnn_block_config):
        super(Decoder, self).__init__()
        self.edge_mask_threshold = edge_mask_threshold
        self.normalize_coord = normalize_coord
        self.unnormalize_coord = unnormalize_coord
        self.egnn_block = EGNNBlock(node_hidden_dims=hidden_dim,
                                    edge_hidden_dims=hidden_dim,  # TODO
                                    normalize_coord=self.normalize_coord,
                                    unnormalize_coord=self.unnormalize_coord,
                                    **egnn_block_config)

    def _inter_edge_mask(self, p_coords, c_coords, complex_graph):
        """Return a boolean mask selecting protein-compound edges within the cutoff.

        If candidate edges exist but none is within the cutoff, a random fifth
        of them (at least one) is enabled instead, so the block is not left
        without any active inter edge.
        """
        with torch.no_grad():
            inter_edge_list = complex_graph.get_protein_compound_edge(use_complex_index=False)
            inter_dist = (p_coords[inter_edge_list[0]] - c_coords[inter_edge_list[1]]).norm(dim=-1)
            inter_edge_mask = inter_dist < self.normalize_coord(self.edge_mask_threshold)
            n_edges = len(inter_edge_mask)
            if n_edges > 0 and not inter_edge_mask.any():
                # With fewer than 5 edges `n_edges // 5` is 0, and torch.multinomial
                # refuses to draw zero samples, so always keep at least one edge.
                n_keep = max(1, n_edges // 5)
                random_indices = torch.multinomial(torch.ones(n_edges), n_keep)
                inter_edge_mask[random_indices] = True
        return inter_edge_mask

    def forward(self, iter_i, p_embed, c_embed,
                p_coords, c_coords, true_c_coords,
                pair_embed, complex_graph,
                metric, print_trajectory=False, ):
        """Run ``iter_i`` refinement steps and return the last ``EGNNBlock`` output.

        Args:
            iter_i: Number of refinement iterations; must be at least 1.
            p_embed, c_embed: Protein / compound node embeddings.
            p_coords, c_coords: Protein / compound coordinates in normalized
                units. ``c_coords`` is replaced by each intermediate step's
                ``"compound_node_coord"`` prediction.
            true_c_coords: Reference compound coordinates, forwarded to the block.
            pair_embed: Protein-compound pair embeddings.
            complex_graph: Provides ``get_protein_compound_edge``.
            metric: Forwarded to the block on the final iteration only
                (intermediate iterations receive ``None``).
            print_trajectory: If true, collect every step's ``"trajectory_list"``.

        Returns:
            The final block output dict. ``"compound_node_coord"`` (if present)
            is un-normalized; when ``print_trajectory`` is set,
            ``"trajectory_list"`` holds the un-normalized trajectories of all steps.

        Raises:
            ValueError: If ``iter_i < 1`` (there would be no output to return).
        """
        if iter_i < 1:
            raise ValueError(f"iter_i must be >= 1, got {iter_i}")

        def run_block(cur_c_coords, inter_edge_mask, block_metric):
            return self.egnn_block(p_embed, c_embed,
                                   p_coords, cur_c_coords,
                                   pair_embed, inter_edge_mask,
                                   true_c_coord=true_c_coords,
                                   complex_graph=complex_graph,
                                   metric=block_metric,
                                   print_trajectory=print_trajectory)

        global_trajectory_list = []
        for _ in range(iter_i - 1):
            inter_edge_mask = self._inter_edge_mask(p_coords, c_coords, complex_graph)
            # Intermediate steps are not differentiated through and get no metric.
            with torch.no_grad():
                outputs = run_block(c_coords, inter_edge_mask, None)
            c_coords = outputs["compound_node_coord"]
            if print_trajectory:
                global_trajectory_list.extend(outputs["trajectory_list"])

        # Final step: runs under the caller's autograd mode and receives `metric`.
        inter_edge_mask = self._inter_edge_mask(p_coords, c_coords, complex_graph)
        outputs = run_block(c_coords, inter_edge_mask, metric)

        if "compound_node_coord" in outputs:
            outputs["compound_node_coord"] = self.unnormalize_coord(outputs["compound_node_coord"])
        if print_trajectory:
            global_trajectory_list.extend(outputs["trajectory_list"])
            outputs["trajectory_list"] = [self.unnormalize_coord(step) for step in global_trajectory_list]
        return outputs


class AffinityDecoder(torch.nn.Module):
    """Map padded compound node embeddings to one scalar per batch element.

    Each node embedding goes through a ``Transition`` layer; padded nodes are
    zeroed with the mask, the remaining nodes are summed, and a linear layer
    produces the scalar.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.c_transition = Transition(hidden_dim, 4)
        self.linear = Linear(hidden_dim, 1)

    def forward(self, c_embed_batched, c_mask):
        """Return a tensor with the node axis (dim -2) summed out and the output dim squeezed."""
        node_weights = c_mask.unsqueeze(-1)
        transformed = self.c_transition(c_embed_batched)
        pooled = (transformed * node_weights).sum(dim=-2)
        return self.linear(pooled).squeeze(-1)


class Encoder(torch.nn.Module):
    """Jointly update protein, compound and protein-compound pair embeddings.

    Flat node embeddings are densified into padded batches, the intra-protein
    and intra-compound distance matrices are embedded with RBF layers, and
    ``n_stack`` ``TrioformerBlock`` layers update all three representations.

    Args:
        hidden_dim: Width of node, pair and distance embeddings.
        normalize_coord: Callable mapping un-normalized distances into the
            normalized coordinate space (used for the RBF cutoffs and the
            noise scales).
        unnormalize_coord: Inverse of ``normalize_coord``; stored but not used
            within this class.
        n_stack: Number of stacked ``TrioformerBlock`` layers.
        add_noise: If true, ``forward`` adds Gaussian noise to both distance
            matrices.
    """
    def __init__(self, hidden_dim, normalize_coord, unnormalize_coord, n_stack=5, add_noise=True):
        super(Encoder, self).__init__()
        self.n_stack = n_stack
        self.unnormalize_coord = unnormalize_coord
        self.normalize_coord = normalize_coord
        self.hidden_dim = hidden_dim
        self.add_noise = add_noise

        f = self.normalize_coord
        # RBF cutoffs are written in un-normalized units (presumably Angstrom —
        # confirm) and normalized here: 32 for protein-protein, 16 for compound-compound.
        self.p_p_dist_layer = RBFDistanceModule(rbf_stop=f(32), distance_hidden_dim=hidden_dim, num_gaussian=32)
        self.c_c_dist_layer = RBFDistanceModule(rbf_stop=f(16), distance_hidden_dim=hidden_dim, num_gaussian=32)
        self.inter_layer = InteractionModule(hidden_dim, hidden_dim, 32)
        self.trioformer_blocks = nn.ModuleList(
            [TrioformerBlock(hidden_dim, hidden_dim, hidden_dim) for _ in range(n_stack)]
        )

    def forward(self, p_embed, c_embed, p_coord, c_coord,
                true_c_coord, complex_graph: ComplexGraph):
        """Run the trioformer stack on one batch of complexes.

        Args:
            p_embed: Flat protein node embeddings, grouped by
                ``complex_graph.protein_batch``.
            c_embed: Flat compound node embeddings, grouped by
                ``complex_graph.compound_batch``.
            p_coord, c_coord: Flat protein / compound coordinates (normalized units).
            true_c_coord: Not used in this method.
            complex_graph: Supplies ``protein_batch``, ``compound_batch`` and
                ``LAS_mask``.

        Returns:
            Nine values: for protein, compound and pair in turn, the flat
            embeddings (padding removed), the padded batched embeddings, and
            the padding mask.
        """
        p_embed_batched, p_mask = to_dense_batch(p_embed, complex_graph.protein_batch)  # (B, Np_max, E)
        c_embed_batched, c_mask = to_dense_batch(c_embed, complex_graph.compound_batch)  # (B, Nc_max, E)
        pair_embed_batched, pair_mask = self.inter_layer(p_embed_batched, c_embed_batched, p_mask, c_mask)

        p_coord_batched, p_coord_mask = to_dense_batch(p_coord, complex_graph.protein_batch)  # (B, Np_max, 3)
        c_coord_batched, c_coord_mask = to_dense_batch(c_coord, complex_graph.compound_batch)  # (B, Nc_max, 3)
        p_p_dist = torch.cdist(p_coord_batched, p_coord_batched, compute_mode='donot_use_mm_for_euclid_dist')
        c_c_dist = torch.cdist(c_coord_batched, c_coord_batched, compute_mode='donot_use_mm_for_euclid_dist')
        # NOTE(review): noise is gated only by self.add_noise, not by self.training,
        # so it is also applied in eval mode — confirm this is intended.
        # Noise is drawn independently per matrix entry, so the noisy matrices are
        # not symmetric and their diagonals become non-zero.
        if self.add_noise:
            p_p_dist += self.normalize_coord(0.2) * torch.randn(p_p_dist.shape, device=p_p_dist.device)
            c_c_dist += self.normalize_coord(0.1) * torch.randn(c_c_dist.shape, device=c_c_dist.device)
        # Protein: keep only pairs in which both nodes are real (non-padding).
        p_p_dist_mask = torch.einsum("...i, ...j->...ij", p_coord_mask, p_coord_mask)
        # Compound: keep pairs flagged in complex_graph.LAS_mask (assumed to be
        # (B, Nc_max, Nc_max) — confirm) plus the self-pairs of real atoms.
        c_c_diag_mask = torch.diag_embed(c_coord_mask)  # (B, Nc, Nc)
        c_c_dist_mask = torch.logical_or(complex_graph.LAS_mask, c_c_diag_mask)
        # Masked-out distances are set to 1e6, presumably far beyond the RBF range
        # so they contribute ~nothing — confirm against RBFDistanceModule.
        p_p_dist[~p_p_dist_mask] = 1e6
        c_c_dist[~c_c_dist_mask] = 1e6
        p_p_dist_embed = self.p_p_dist_layer(p_p_dist)
        c_c_dist_embed = self.c_c_dist_layer(c_c_dist)

        for i_block in self.trioformer_blocks:
            p_embed_batched, c_embed_batched, pair_embed_batched = i_block(
                p_embed_batched, p_mask,
                c_embed_batched, c_mask,
                pair_embed_batched, pair_mask,
                p_p_dist_embed, c_c_dist_embed
            )

        # Drop padding to recover flat per-node / per-pair embeddings.
        p_embed = p_embed_batched[p_mask]
        c_embed = c_embed_batched[c_mask]
        pair_embed = pair_embed_batched[pair_mask]

        return p_embed, p_embed_batched, p_mask, \
               c_embed, c_embed_batched, c_mask, \
               pair_embed, pair_embed_batched, pair_mask


# How to do Edge Embed -> Node Embed


class IterativeRefinement(torch.nn.Module):
    """End-to-end pose model: encode protein and compound, then refine compound coordinates.

    Pipeline (coordinates are divided by ``coordinate_scale`` internally):
      1. ``p_encoder`` (``GVP_embedding``) embeds protein nodes.
      2. ``c_encoder`` (``GIN``) embeds compound atoms.
      3. ``trioformer_encoder`` jointly updates protein, compound and pair embeddings.
      4. ``egnn_decoder`` iteratively updates the compound coordinates.

    Args:
        hidden_dim: Embedding width shared by the sub-modules.
        coordinate_scale: Divisor used to normalize coordinates (and multiplier
            used to undo it).
        edge_mask_threshold: Protein-compound distance cutoff for the decoder,
            in un-normalized units.
        egnn_block_config: Extra keyword arguments for the decoder's
            ``EGNNBlock``; ``None`` means no extra arguments.
        trioformer_config: Extra keyword arguments for ``Encoder``; ``None``
            means no extra arguments.
        affinity_only: Stored on the instance; not read within this class.
    """

    def __init__(self, hidden_dim=128, coordinate_scale=5, edge_mask_threshold=10,
                 egnn_block_config=None, trioformer_config=None, affinity_only=False):
        super().__init__()
        self.affinity_only = affinity_only
        self.coordinate_scale = coordinate_scale
        # Both configs are expanded with `**` below / in Decoder; a bare None
        # default would raise TypeError.
        egnn_block_config = {} if egnn_block_config is None else egnn_block_config
        trioformer_config = {} if trioformer_config is None else trioformer_config

        self.p_encoder = GVP_embedding((6, 3), (hidden_dim, 16),
                                       (32, 1), (32, 1), seq_in=True)

        self.c_encoder = GIN(input_dim=56, hidden_dims=[128, 56, hidden_dim], edge_input_dim=19,
                             concat_hidden=False)

        # Bound methods (rather than lambdas) are passed so the model can be
        # pickled (torch.save) and deep-copied with the copy's own scale.
        self.trioformer_encoder = Encoder(hidden_dim,
                                          normalize_coord=self.normalize_coord,
                                          unnormalize_coord=self.unnormalize_coord,
                                          **trioformer_config)

        self.egnn_decoder = Decoder(hidden_dim,
                                    normalize_coord=self.normalize_coord,
                                    unnormalize_coord=self.unnormalize_coord,
                                    edge_mask_threshold=edge_mask_threshold,
                                    egnn_block_config=egnn_block_config)

    def normalize_coord(self, x):
        """Map un-normalized coordinates/distances into model space (divide by the scale)."""
        return x / self.coordinate_scale

    def unnormalize_coord(self, x):
        """Map model-space coordinates/distances back (multiply by the scale)."""
        return x * self.coordinate_scale

    def _prep_input(self, data):
        """Encode raw protein/compound inputs and build the complex graph.

        ``data`` is presumably a torch_geometric heterogeneous batch (it exposes
        ``.ptr`` per node type and typed edge stores) — confirm against the loader.

        Returns:
            ``(p_embed, c_embed, p_coords, c_coords, true_c_coords, rdkit_c_coords,
            complex_graph)``; all coordinates are normalized.
        """
        # Encode Protein
        p_coords = self.normalize_coord(data['protein'].coords)
        p_node_feature = (data['protein']['node_s'], data['protein']['node_v'])
        p_edge_store = data[("protein", "p2p", "protein")]
        p_edge_index = p_edge_store["edge_index"]
        p_edge_feature = (p_edge_store["edge_s"], p_edge_store["edge_v"])
        p_node_nums = data["protein"].ptr[1:] - data["protein"].ptr[:-1]
        p_embed = self.p_encoder(p_node_feature, p_edge_index, p_edge_feature, data.seq)

        # Encode Compound
        c_coords = self.normalize_coord(data['compound'].init_coords.float())  # TODO
        true_c_coords = self.normalize_coord(data['compound'].true_coords.float())
        rdkit_c_coords = self.normalize_coord(data['compound'].rdkit_coords.float())
        c_feature = data['compound'].x.float()
        c_edge_store = data[("compound", "c2c", "compound")]
        c_edge_index = c_edge_store.edge_index
        c_edge_feature = c_edge_store.edge_attr
        c_edge_weight = c_edge_store.edge_weight
        c_node_nums = data["compound"].ptr[1:] - data["compound"].ptr[:-1]
        c_embed = self.c_encoder(c_edge_index.T, c_edge_weight, c_edge_feature,
                                 c_feature.shape[0], c_feature)['node_feature']

        # LAS
        las_edge_index = data[("compound", "LAS", "compound")].edge_index
        # Complex
        complex_graph = ComplexGraph(p_node_nums, c_node_nums, las_edge_index)
        complex_graph.set_compound_edge(c_edge_index, c_edge_feature)

        return p_embed, c_embed, p_coords, c_coords, true_c_coords, rdkit_c_coords, complex_graph

    def forward(self, data, iter_i, metric=None, print_trajectory=False):
        """Predict compound coordinates for a batch of complexes.

        Args:
            data: Batched input graph (see ``_prep_input``).
            iter_i: Number of decoder refinement iterations.
            metric: Forwarded to the decoder's final iteration.
            print_trajectory: If true, the decoder also returns trajectories.

        Returns:
            A new dict with the decoder outputs; the decoder un-normalizes
            ``"compound_node_coord"`` before returning it.
        """
        p_embed, c_embed, p_coords, c_coords, true_c_coords, _rdkit_c_coords, complex_graph = \
            self._prep_input(data)

        p_embed, _, _, c_embed, _, _, pair_embed, _, _ = self.trioformer_encoder(
            p_embed, c_embed,
            p_coords, c_coords,
            true_c_coord=true_c_coords,
            complex_graph=complex_graph,
        )
        coord_outputs = self.egnn_decoder(iter_i, p_embed, c_embed,
                                          p_coords, c_coords, true_c_coords,
                                          pair_embed, complex_graph,
                                          metric, print_trajectory)
        return dict(coord_outputs)
