import functools

import torch
from torch import nn, nn as nn
from torch._C import dtype
from torch.nn import functional as F, Linear
from torch.utils import checkpoint
from torch_scatter import scatter_mean, scatter_add, scatter_max
from torch_scatter.composite import scatter_softmax

from torchdrug import data, layers, utils
from torchdrug.layers import functional

from components.ModelUtils import Transition
from data import ComplexGraph


class EGNNBlock(torch.nn.Module):
    """Stack of equivariant message-passing layers that refines compound coordinates.

    Each of the ``n_egnn_layer`` iterations runs:

    1. a protein -> compound bipartite update (``BiEGCL``) over the
       protein-compound edges selected by ``inter_edge_mask``;
    2. optionally (``intra_egnn``), a compound-compound update (``EGCL``);
    3. ``geom_reg_steps`` steps of the LAS geometry constraint (adapted from
       EquiBind). Each step pulls the squared distances on LAS edges toward
       those measured on ``true_c_coord``.

    After the stack, an optional clash pass (``fix_clash``) pushes apart
    compound atoms that lie closer than 1.22 in unnormalized units along
    compound-compound edges. When ``metric`` asks for a "Confidence Loss",
    two extra bipartite layers and an MLP head produce a per-complex
    confidence score.
    """

    def __init__(self, node_hidden_dims, edge_hidden_dims, normalize_coord, unnormalize_coord,
                 egnn_coords_agg, egnn_normalize, intra_egnn,
                 n_egnn_layer=4, fix_clash=True, clash_step=6,
                 geom_reg_steps=1, geometry_reg_step_size=0.01):
        """
        Args:
            node_hidden_dims: feature width of protein and compound node embeddings.
            edge_hidden_dims: feature width of the protein-compound pair embeddings.
            normalize_coord: callable mapping a raw-scale length to the model's
                coordinate scale. Used for the update clamp bounds.
            unnormalize_coord: inverse of ``normalize_coord``. Used for the clash
                distance threshold.
            egnn_coords_agg: coordinate aggregation of the bipartite layers
                ("sum", "mean" or "attention").
            egnn_normalize: whether EGNN layers normalize coordinate differences.
            intra_egnn: run a compound-compound EGCL after every bipartite layer.
            n_egnn_layer: number of message-passing iterations.
            fix_clash: run the clash-repulsion pass after the stack.
            clash_step: number of clash-repulsion iterations.
            geom_reg_steps: LAS-constraint steps per iteration.
            geometry_reg_step_size: step size of each LAS-constraint update.
        """
        super(EGNNBlock, self).__init__()
        self.geometry_reg_step_size = geometry_reg_step_size
        self.geom_reg_steps = geom_reg_steps
        self.use_intra_egnn = intra_egnn
        self.fix_clash = fix_clash
        self.clash_step = clash_step
        self.n_egnn_layer = n_egnn_layer

        self.normalize_coord = normalize_coord
        self.unnormalize_coord = unnormalize_coord

        # NOTE(review): not used in forward. The misspelled attribute name is kept
        # because it is part of the module's state_dict keys.
        self.tranistion = Transition(hidden_dim=edge_hidden_dims, n=4)

        self.inter_egnn_layers = nn.ModuleList(
            [BiEGCL(in_dim=node_hidden_dims, hid_dim=node_hidden_dims, edge_attr_dim=edge_hidden_dims,
                    activation="SiLU", residual=True, gated=True, normalize=egnn_normalize, coords_agg=egnn_coords_agg,
                    coord_change_maximum=self.normalize_coord(10)) for _ in range(self.n_egnn_layer)]
        )
        self.intra_egnn_layers = nn.ModuleList(
            [EGCL(in_dim=node_hidden_dims, hid_dim=node_hidden_dims,
                  activation="SiLU", residual=True, gated=True, normalize=egnn_normalize, coords_agg="sum",
                  coord_change_maximum=self.normalize_coord(10)) for _ in range(self.n_egnn_layer)]
        )

        # Feature-only bipartite layers used by the confidence head.
        self.affinity_graph_layers = nn.ModuleList(
            [BiEGCL(in_dim=node_hidden_dims, hid_dim=node_hidden_dims, edge_attr_dim=edge_hidden_dims,
                    activation="SiLU", residual=True, gated=True, normalize=egnn_normalize, coords_agg=egnn_coords_agg,
                    coord_change_maximum=self.normalize_coord(10)) for _ in range(2)]
        )

        self.confidence_layers = nn.ModuleList(
            [Transition(hidden_dim=node_hidden_dims, n=4),
             Transition(hidden_dim=node_hidden_dims, n=4), ]
        )
        self.confidence_output_layer = nn.Linear(node_hidden_dims, 1)

    def forward(self, p_embed, c_embed, p_coord, c_coord,
                pair_embed, inter_edge_mask, true_c_coord,
                complex_graph: ComplexGraph, metric=None,
                print_trajectory=False):
        """Refine compound coordinates and optionally predict a confidence score.

        Args:
            p_embed, c_embed: protein / compound node embeddings.
            p_coord, c_coord: protein / compound node coordinates. Only the
                compound coordinates are updated.
            pair_embed: protein-compound edge embeddings, indexed in step with
                ``complex_graph.get_protein_compound_edge(use_complex_index=False)``.
            inter_edge_mask: mask/index selecting which protein-compound edges
                (and their pair embeddings) take part in message passing.
            true_c_coord: reference compound coordinates. The LAS constraint
                targets the squared distances measured on these.
            complex_graph: supplies the edge lists, ``compound_node_nums``,
                ``compound_batch`` and the batch size ``B``.
            metric: optional mapping. If it has a "Confidence Loss" key, the
                confidence head is evaluated.
            print_trajectory: if True, a "trajectory_list" entry is added to the output.

        Returns:
            dict with:
            - "compound_node_coord": the refined compound coordinates;
            - "confidence": sigmoid scores of shape (B,), when requested;
            - "trajectory_list": when ``print_trajectory`` is set.
        """
        if print_trajectory:
            # NOTE(review): nothing is appended to this list in this block, so the
            # returned "trajectory_list" is always empty. TODO: confirm intent.
            trajectory_list = []

        p2c_edge_list = complex_graph.get_protein_compound_edge(use_complex_index=False)
        c2c_edge_list = complex_graph.get_compound_compound_edge()
        LAS_edge_list = complex_graph.LAS_edge_index

        # Loop invariants, computed once instead of per layer / per step.
        inter_edge_list = p2c_edge_list[:, inter_edge_mask]
        inter_edge_attr = pair_embed[inter_edge_mask]
        n_compound_nodes = complex_graph.compound_node_nums.sum()
        if self.geom_reg_steps > 0:
            las_src, las_dst = LAS_edge_list[0], LAS_edge_list[1]
            # Target squared distances depend only on true_c_coord.
            LAS_true_squared = torch.sum((true_c_coord[las_src] - true_c_coord[las_dst]) ** 2, dim=1)

        for i_module in range(self.n_egnn_layer):
            # Inter (protein -> compound) message passing and coordinate update.
            p_embed, c_embed, p_coord, c_coord = self.inter_egnn_layers[i_module](
                src_node_feat=p_embed,
                tgt_node_feat=c_embed,
                src_node_coord=p_coord,
                tgt_node_coord=c_coord,
                edge_list=inter_edge_list,
                edge_attr=inter_edge_attr
            )
            # Intra (compound-compound) message passing and coordinate update.
            if self.use_intra_egnn:
                c_embed, c_coord = self.intra_egnn_layers[i_module](
                    node_feat=c_embed,
                    edge_list=c2c_edge_list,
                    coord=c_coord,
                )
            # LAS geometry constraint (adapted from EquiBind). LAS_force is the
            # negative gradient of (d^2 - d_true^2)^2 with respect to the dst
            # coordinate. Adding it is therefore a descent step on that penalty.
            for _ in range(self.geom_reg_steps):
                las_diff = c_coord[las_src] - c_coord[las_dst]
                LAS_cur_squared = torch.sum(las_diff ** 2, dim=1)
                grad_squared = 2 * las_diff
                LAS_force = 2 * (LAS_cur_squared - LAS_true_squared)[:, None] * grad_squared
                LAS_delta_coord = scatter_add(src=LAS_force, index=las_dst, dim=0,
                                              dim_size=n_compound_nodes)
                c_coord = c_coord + (LAS_delta_coord * self.geometry_reg_step_size) \
                    .clamp(min=self.normalize_coord(-15), max=self.normalize_coord(15))

        # Fix clashes: repel compound atom pairs closer than 1.22 (unnormalized units).
        if self.fix_clash:
            clash_src, clash_dst = c2c_edge_list[0], c2c_edge_list[1]
            for _ in range(self.clash_step):
                c_coord_ij = c_coord[clash_src] - c_coord[clash_dst]
                c_distance = c_coord_ij.norm(dim=-1)
                c_f_ij = F.relu(1.22 - self.unnormalize_coord(c_distance))
                # Push each src atom away from dst, in proportion to the overlap.
                c_delta_coord = scatter_add(src=c_coord_ij * c_f_ij[:, None], index=clash_src, dim=0,
                                            dim_size=n_compound_nodes)
                c_coord = c_coord + c_delta_coord.clamp(min=self.normalize_coord(-15),
                                                        max=self.normalize_coord(15))

        outputs = {"compound_node_coord": c_coord}
        if metric is not None and "Confidence Loss" in metric.keys():
            # Feature-only refinement at the final coordinates.
            for layer in self.affinity_graph_layers:
                p_embed, c_embed = layer(
                    src_node_feat=p_embed,
                    tgt_node_feat=c_embed,
                    src_node_coord=p_coord,
                    tgt_node_coord=c_coord,
                    edge_list=inter_edge_list,
                    edge_attr=inter_edge_attr,
                    output_coord=False
                )
            # Sum-pool compound node embeddings per complex: (B, node_hidden_dims).
            graph_embed = scatter_add(c_embed, index=complex_graph.compound_batch, dim=0,
                                      dim_size=complex_graph.B)
            for confidence_layer in self.confidence_layers:
                graph_embed = confidence_layer(graph_embed)
            confidence = self.confidence_output_layer(graph_embed).squeeze(-1)
            outputs.update({"confidence": confidence.sigmoid()})
        if print_trajectory:
            outputs.update({"trajectory_list": trajectory_list})

        return outputs


class EGCL(nn.Module):
    """
    E(n) Equivariant Convolutional Layer over a single node set.

    Along each edge ``edge_list[0] -> edge_list[1]``, a message is built from
    both endpoint features, the squared distance and optional edge attributes.
    Messages are sum-aggregated into the ``edge_list[1]`` node's features. They
    also drive that node's coordinate update: a weighted sum/mean of coordinate
    differences, clamped to ``coord_change_maximum``.

    Coordinates are updated out of place: the ``coord`` tensor passed to
    ``forward`` is never modified.
    """

    def __init__(self, in_dim, hid_dim, node_attr_dim=0, edge_attr_dim=0, activation="ReLU",
                 residual=True, gated=False, normalize=False, coords_agg='mean', tanh=False,
                 coord_change_maximum=10):
        """
        Args:
            in_dim: width of the input node features.
            hid_dim: hidden width, and output width of the node MLP.
            node_attr_dim: width of the optional ``node_attr`` concatenated in the node update.
            edge_attr_dim: width of the optional ``edge_attr`` concatenated in the edge update.
            activation: name of a ``torch.nn`` activation class, or a module instance.
                The same instance is reused in every MLP.
            residual: add input features to the node-MLP output when the widths match.
            gated: scale each edge message by a learned sigmoid gate.
            normalize: divide coordinate differences by their (detached) length.
            coords_agg: coordinate aggregation, 'sum' or 'mean'.
            tanh: append a Tanh to the coordinate MLP.
            coord_change_maximum: element-wise clamp bound on each coordinate update.
        """
        super(EGCL, self).__init__()
        self.coord_change_maximum = coord_change_maximum
        self.residual = residual
        self.gated = gated
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        in_edge_dim = in_dim * 2
        edge_coor_dim = 1  # the squared distance ("radial")

        if isinstance(activation, str):
            self.activation = getattr(nn, activation)()
        else:
            self.activation = activation

        self.edge_mlp = nn.Sequential(
            nn.Linear(in_edge_dim + edge_coor_dim + edge_attr_dim, hid_dim),
            self.activation,
            nn.Linear(hid_dim, hid_dim),
            self.activation
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(hid_dim + in_dim + node_attr_dim, hid_dim),
            self.activation,
            nn.Linear(hid_dim, hid_dim)
        )

        # Tiny initial gain so the initial coordinate updates are near zero.
        weight_layer = nn.Linear(hid_dim, 1, bias=False)
        torch.nn.init.xavier_uniform_(weight_layer.weight, gain=0.001)
        if self.tanh:
            self.coord_mlp = nn.Sequential(
                nn.Linear(hid_dim, hid_dim),
                self.activation,
                weight_layer,
                nn.Tanh()
            )
        else:
            self.coord_mlp = nn.Sequential(
                nn.Linear(hid_dim, hid_dim),
                self.activation,
                weight_layer
            )

        # When gated, the bool flag is replaced by the gate module (truthy).
        if self.gated:
            self.gated = nn.Sequential(
                nn.Linear(hid_dim, 1),
                nn.Sigmoid()
            )

    def unsorted_segment_sum(self, data, segment_ids, num_segments):
        """Sum rows of ``data`` (E, D) into ``num_segments`` rows indexed by ``segment_ids`` (E,)."""
        result_shape = (num_segments, data.size(1))
        result = data.new_full(result_shape, 0)
        segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
        result.scatter_add_(0, segment_ids, data)

        return result

    def unsorted_segment_mean(self, data, segment_ids, num_segments):
        """Mean of rows of ``data`` per segment. Empty segments yield zeros."""
        result_shape = (num_segments, data.size(1))
        result = data.new_full(result_shape, 0)
        count = data.new_full(result_shape, 0)
        segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
        result.scatter_add_(0, segment_ids, data)
        count.scatter_add_(0, segment_ids, torch.ones_like(data))

        return result / count.clamp(min=1)

    def edge_function(self, src_node_feat, tgt_node_feat, radial, edge_attr):
        """Compute per-edge messages from endpoint features, squared distance and ``edge_attr``."""
        if edge_attr is None:
            edge_feat_in = torch.cat([src_node_feat, tgt_node_feat, radial], dim=1)
        else:
            edge_feat_in = torch.cat([src_node_feat, tgt_node_feat, radial, edge_attr], dim=1)

        edge_feat_out = self.edge_mlp(edge_feat_in)
        if self.gated:
            att_weight = self.gated(edge_feat_out)
            edge_feat_out = edge_feat_out * att_weight

        return edge_feat_out

    def coord_function(self, coord, edge_list, coord_diff, edge_feat, coord_update_mask):
        """Return updated coordinates. The input ``coord`` is left unmodified.

        Only rows selected by ``coord_update_mask`` move when it is given.
        """
        # Updates act on the edge_list[1] node.
        node_out = edge_list[1]
        weighted_trans = coord_diff * self.coord_mlp(edge_feat)

        if self.coords_agg == 'sum':
            agg_trans = self.unsorted_segment_sum(weighted_trans, node_out, num_segments=coord.size(0))
        elif self.coords_agg == 'mean':
            agg_trans = self.unsorted_segment_mean(weighted_trans, node_out, num_segments=coord.size(0))
        else:
            raise NotImplementedError('Aggregation method {} is not implemented'.format(self.coords_agg))

        delta = agg_trans.clamp(-self.coord_change_maximum, self.coord_change_maximum)
        if coord_update_mask is None:
            return coord + delta
        # Clone first so that neither the caller's tensor nor an autograd leaf
        # is written to in place.
        coord = coord.clone()
        coord[coord_update_mask] = coord[coord_update_mask] + delta[coord_update_mask]
        return coord

    def node_function(self, node_feat, edge_list, edge_feat, node_attr):
        """Aggregate messages onto edge_list[1] nodes and apply the node MLP (+ residual)."""
        node_out = edge_list[1]
        agg_edge_feat = self.unsorted_segment_sum(edge_feat, node_out, num_segments=node_feat.size(0))

        if node_attr is not None:
            node_feat_in = torch.cat([node_feat, agg_edge_feat, node_attr], dim=1)
        else:
            node_feat_in = torch.cat([node_feat, agg_edge_feat], dim=1)

        node_feat_out = self.node_mlp(node_feat_in)
        if self.residual and node_feat.size(1) == node_feat_out.size(1):
            node_feat_out = node_feat + node_feat_out

        return node_feat_out

    def coord2radial(self, edge_list, coord, epsilon=1e-6):
        """Return (squared distance (E, 1), coordinate difference x_out - x_in (E, D))."""
        node_in = edge_list[0]
        node_out = edge_list[1]
        coord_diff = coord[node_out] - coord[node_in]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)

        if self.normalize:
            # Detached norm: normalization is not differentiated through.
            norm = torch.sqrt(radial).detach() + epsilon
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, edge_list, node_feat, coord, node_attr=None, edge_attr=None,
                output_coord=True, coord_update_mask=None):
        """Run one layer.

        Returns ``(node_feat, coord)``, or only ``node_feat`` when ``output_coord`` is False.
        """
        node_in = edge_list[0]
        node_out = edge_list[1]
        radial, coord_diff = self.coord2radial(edge_list, coord)

        edge_feat = self.edge_function(node_feat[node_in], node_feat[node_out], radial, edge_attr)
        coord = self.coord_function(coord, edge_list, coord_diff, edge_feat, coord_update_mask)
        node_feat = self.node_function(node_feat, edge_list, edge_feat, node_attr)
        if output_coord:
            return node_feat, coord
        return node_feat


class BiEGCL(nn.Module):
    """
    Bipartite E(n) Equivariant Convolutional Layer between a source and a target node set.

    ``edge_list[0]`` indexes source nodes and ``edge_list[1]`` indexes target
    nodes. Each edge yields two messages:

    - s2t: aggregated onto targets; also drives the target coordinate update;
    - t2s: aggregated onto sources.

    Only target coordinates move. Source coordinates are returned as given.
    Coordinates are updated out of place, so the caller's tensors are never
    modified.
    """

    def __init__(self, in_dim, hid_dim, node_attr_dim=0, edge_attr_dim=0, activation="ReLU",
                 residual=True, gated=False, normalize=False, coords_agg='mean', tanh=False,
                 coord_change_maximum=10):
        """
        Args:
            in_dim: width of source and target node features.
            hid_dim: hidden width, and output width of the node MLPs.
            node_attr_dim: added to the node-MLP input width.
                NOTE(review): ``node_function`` never concatenates node
                attributes, so any value other than 0 causes a shape mismatch.
            edge_attr_dim: width of the optional ``edge_attr``.
            activation: name of a ``torch.nn`` activation class, or a module
                instance. The same instance is reused in every MLP.
            residual: add input features to the node-MLP outputs when widths match.
            gated: scale each message by a learned sigmoid gate.
            normalize: divide coordinate differences by their (detached) length.
            coords_agg: 'sum', 'mean' or 'attention' (softmax-weighted sum,
                also applied to the feature messages).
            tanh: append a Tanh to the coordinate MLP.
            coord_change_maximum: element-wise clamp bound on each coordinate update.
        """
        super(BiEGCL, self).__init__()
        self.residual = residual
        self.gated = gated
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.coord_change_maximum = coord_change_maximum
        in_edge_dim = in_dim * 2
        edge_coor_dim = 1  # the squared distance ("radial")

        if isinstance(activation, str):
            self.activation = getattr(nn, activation)()
        else:
            self.activation = activation

        self.edge_mlp_s2t = nn.Sequential(
            nn.Linear(in_edge_dim + edge_coor_dim + edge_attr_dim, hid_dim),
            self.activation,
            nn.Linear(hid_dim, hid_dim),
            self.activation
        )

        self.edge_mlp_t2s = nn.Sequential(
            nn.Linear(in_edge_dim + edge_coor_dim + edge_attr_dim, hid_dim),
            self.activation,
            nn.Linear(hid_dim, hid_dim),
            self.activation
        )
        if self.gated:
            self.gate_mlp_s2t = nn.Sequential(
                nn.Linear(hid_dim, 1),
                nn.Sigmoid()
            )
            self.gate_mlp_t2s = nn.Sequential(
                nn.Linear(hid_dim, 1),
                nn.Sigmoid()
            )

        self.node_mlp_s = nn.Sequential(
            nn.Linear(hid_dim + in_dim + node_attr_dim, hid_dim),
            self.activation,
            nn.Linear(hid_dim, hid_dim)
        )

        self.node_mlp_t = nn.Sequential(
            nn.Linear(hid_dim + in_dim + node_attr_dim, hid_dim),
            self.activation,
            nn.Linear(hid_dim, hid_dim)
        )

        # Tiny initial gain so the initial coordinate updates are near zero.
        weight_layer = nn.Linear(hid_dim, 1, bias=False)
        torch.nn.init.xavier_uniform_(weight_layer.weight, gain=0.001)
        if self.tanh:
            self.coord_mlp = nn.Sequential(
                nn.Linear(hid_dim, hid_dim),
                self.activation,
                weight_layer,
                nn.Tanh()
            )
        else:
            self.coord_mlp = nn.Sequential(
                nn.Linear(hid_dim, hid_dim),
                self.activation,
                weight_layer
            )

        if self.coords_agg == "attention":
            self.attent_mlp_s2t = nn.Linear(hid_dim, 1)
            self.attent_mlp_t2s = nn.Linear(hid_dim, 1)

    def unsorted_segment_sum(self, data, segment_ids, num_segments):
        """Sum rows of ``data`` (E, D) into ``num_segments`` rows indexed by ``segment_ids`` (E,)."""
        result_shape = (num_segments, data.size(1))
        result = data.new_full(result_shape, 0)
        segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
        result.scatter_add_(0, segment_ids, data)

        return result

    def unsorted_segment_mean(self, data, segment_ids, num_segments):
        """Mean of rows of ``data`` per segment. Empty segments yield zeros."""
        result_shape = (num_segments, data.size(1))
        result = data.new_full(result_shape, 0)
        count = data.new_full(result_shape, 0)
        segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
        result.scatter_add_(0, segment_ids, data)
        count.scatter_add_(0, segment_ids, torch.ones_like(data))

        return result / count.clamp(min=1)

    def edge_function(self, edge_src_feat, edge_tgt_feat, radial, edge_attr):
        """Return the (s2t, t2s) messages per edge, gated when ``gated`` is set."""
        if edge_attr is None:
            edge_feat_in = torch.cat([edge_src_feat, edge_tgt_feat, radial], dim=1)
        else:
            edge_feat_in = torch.cat([edge_src_feat, edge_tgt_feat, radial, edge_attr], dim=1)
        edge_feat_out_s2t = self.edge_mlp_s2t(edge_feat_in)
        edge_feat_out_t2s = self.edge_mlp_t2s(edge_feat_in)
        if self.gated:
            edge_feat_out_s2t = edge_feat_out_s2t * self.gate_mlp_s2t(edge_feat_out_s2t)
            edge_feat_out_t2s = edge_feat_out_t2s * self.gate_mlp_t2s(edge_feat_out_t2s)

        return edge_feat_out_s2t, edge_feat_out_t2s

    def attention_function(self, edge_feat_s2t, edge_feat_t2s, edge_list):
        """Return per-edge attention weights (E, 1).

        s2t weights are softmax-normalized over each target's incoming edges;
        t2s weights over each source's edges.
        """
        edge_src = edge_list[0]
        edge_tgt = edge_list[1]
        # squeeze(-1), not squeeze(): with a single edge, squeeze() would yield a 0-d tensor.
        attent_score_s2t = self.attent_mlp_s2t(edge_feat_s2t).squeeze(-1)
        attent_weight_s2t = scatter_softmax(attent_score_s2t, index=edge_tgt).unsqueeze(-1)

        attent_score_t2s = self.attent_mlp_t2s(edge_feat_t2s).squeeze(-1)
        attent_weight_t2s = scatter_softmax(attent_score_t2s, index=edge_src).unsqueeze(-1)

        return attent_weight_s2t, attent_weight_t2s

    def tgt_coord_function(self, tgt_node_coord, edge_list, coord_diff, edge_feat_s2t,
                           attent_weight_s2t=None):
        """Return updated target coordinates. The input tensor is left unmodified."""
        edge_tgt = edge_list[1]
        weighted_trans = coord_diff * self.coord_mlp(edge_feat_s2t)
        num_tgt = tgt_node_coord.size(0)

        if self.coords_agg == 'sum':
            agg_trans = self.unsorted_segment_sum(weighted_trans, edge_tgt, num_segments=num_tgt)
        elif self.coords_agg == 'mean':
            agg_trans = self.unsorted_segment_mean(weighted_trans, edge_tgt, num_segments=num_tgt)
        elif self.coords_agg == "attention":
            agg_trans = self.unsorted_segment_sum(weighted_trans * attent_weight_s2t, edge_tgt,
                                                  num_segments=num_tgt)
        else:
            raise NotImplementedError('Aggregation method {} is not implemented'.format(self.coords_agg))
        # Out of place, so that the caller's tensor / an autograd leaf is not mutated.
        return tgt_node_coord + agg_trans.clamp(-self.coord_change_maximum, self.coord_change_maximum)

    def node_function(self, edge_list, src_node_feat, tgt_node_feat, edge_feat_s2t, edge_feat_t2s,
                      attent_weight_s2t=None, attent_weight_t2s=None):
        """Aggregate messages and apply the node MLPs.

        Returns ``(tgt_node_feat_out, src_node_feat_out)``.
        """
        edge_src = edge_list[0]
        edge_tgt = edge_list[1]
        if self.coords_agg == "attention":
            edge_feat_s2t = edge_feat_s2t * attent_weight_s2t
            edge_feat_t2s = edge_feat_t2s * attent_weight_t2s
        agg_edge_feat_s2t = self.unsorted_segment_sum(edge_feat_s2t, edge_tgt, num_segments=tgt_node_feat.size(0))
        agg_edge_feat_t2s = self.unsorted_segment_sum(edge_feat_t2s, edge_src, num_segments=src_node_feat.size(0))

        tgt_node_feat_out = self.node_mlp_t(torch.cat([tgt_node_feat, agg_edge_feat_s2t], dim=1))
        src_node_feat_out = self.node_mlp_s(torch.cat([src_node_feat, agg_edge_feat_t2s], dim=1))
        if self.residual:
            if tgt_node_feat.size(1) == tgt_node_feat_out.size(1):
                tgt_node_feat_out = tgt_node_feat + tgt_node_feat_out
            if src_node_feat.size(1) == src_node_feat_out.size(1):
                src_node_feat_out = src_node_feat + src_node_feat_out
        return tgt_node_feat_out, src_node_feat_out

    def coord2radial(self, edge_list, src_node_coord, tgt_node_coord, epsilon=1e-6):
        """Return (squared distance (E, 1), target-minus-source difference (E, D))."""
        edge_src = edge_list[0]
        edge_tgt = edge_list[1]
        coord_diff = tgt_node_coord[edge_tgt] - src_node_coord[edge_src]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)

        if self.normalize:
            # Detached norm: normalization is not differentiated through.
            norm = torch.sqrt(radial).detach() + epsilon
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, src_node_feat, tgt_node_feat, src_node_coord, tgt_node_coord,
                edge_list, edge_attr=None, output_coord=True):
        """Run one bipartite layer.

        Returns ``(src_feat, tgt_feat, src_coord, tgt_coord)``, or
        ``(src_feat, tgt_feat)`` when ``output_coord`` is False. ``src_coord``
        is the ``src_node_coord`` object passed in.
        """
        radial, coord_diff = self.coord2radial(edge_list, src_node_coord, tgt_node_coord)

        edge_feat_s2t, edge_feat_t2s = self.edge_function(src_node_feat[edge_list[0]], tgt_node_feat[edge_list[1]],
                                                          radial, edge_attr)
        attent_weight_s2t = attent_weight_t2s = None
        if self.coords_agg == "attention":
            attent_weight_s2t, attent_weight_t2s = self.attention_function(edge_feat_s2t, edge_feat_t2s, edge_list)

        if output_coord:
            tgt_node_coord = self.tgt_coord_function(tgt_node_coord, edge_list, coord_diff, edge_feat_s2t,
                                                     attent_weight_s2t=attent_weight_s2t)
        tgt_node_feat_out, src_node_feat_out = self.node_function(edge_list, src_node_feat, tgt_node_feat,
                                                                  edge_feat_s2t, edge_feat_t2s,
                                                                  attent_weight_s2t=attent_weight_s2t,
                                                                  attent_weight_t2s=attent_weight_t2s)
        if output_coord:
            return src_node_feat_out, tgt_node_feat_out, src_node_coord, tgt_node_coord
        return src_node_feat_out, tgt_node_feat_out
