import torch
from torch import nn
import numpy as np

from egnn.egnn_new import EGNN
from train_test import check_mask_correct
from bond_type_prediction.losses import adjacency_matrix_loss, atom_types_and_formal_charges_loss
from bond_type_prediction.utils import compute_class_weight


class EGNNEdgeModel(nn.Module):
    """Predict per-edge class logits (an adjacency / bond-type tensor) from 3D positions and node features.

    Node features (atom-type one-hot, formal charges and optionally a time feature) are
    optionally encoded by an EGNN. Then every ordered pair of atoms (i, j) is scored by
    ``edge_head`` from [squared coordinate differences, h_i, h_j]. Optionally, ``h_head``
    also re-predicts the node features (``modify_h=True``).
    """

    def __init__(self, in_node_nf,
                 in_edge_nf=1, hidden_nf=64, device='cpu',
                 act_fn=torch.nn.SiLU(), n_layers=4, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1,
                 inv_sublayers=2, sin_embedding=False, normalization_factor=1, aggregation_method='sum',
                 include_charges=True, encoder='egnn', edge_head='mlp', edge_head_hidden_dim=64, included_species=None,
                 n_classes=5, modify_h=False, joint_training=False, condition_time=False):
        """Build the (optional) EGNN encoder plus the edge head and optional node head.

        Args:
            in_node_nf: size of the node feature vector fed to the encoder.
            out_node_nf: size of the encoder's node output; defaults to ``in_node_nf``.
            encoder: ``'egnn'`` to use an EGNN encoder; any other value decodes directly
                from the input features.
            edge_head: ``'linear'`` or ``'mlp'``, the architecture scoring each edge.
            n_classes: number of edge classes predicted per atom pair.
            included_species: tensor of atomic numbers used to one-hot encode raw molecules
                in ``predict_adj_matrix_from_single_molecule``.
            modify_h: if True, a linear ``h_head`` re-predicts node features.
            joint_training: if True, charges arrive already encoded and are not one-hot encoded
                in ``forward``.
            condition_time: if True, a time feature ``t`` is appended to the node features.

        Raises:
            ValueError: if ``edge_head`` is not ``'linear'`` or ``'mlp'``.
        """
        super().__init__()
        self.encoder = encoder
        if encoder == 'egnn':
            # in_edge_nf is not used
            self.egnn = EGNN(
                in_node_nf=in_node_nf, in_edge_nf=in_edge_nf,
                hidden_nf=hidden_nf, device=device, act_fn=act_fn,
                n_layers=n_layers, attention=attention, norm_diff=norm_diff,
                out_node_nf=out_node_nf, tanh=tanh, coords_range=coords_range,
                norm_constant=norm_constant,
                inv_sublayers=inv_sublayers, sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method)

        self.in_node_nf = in_node_nf
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.out_node_nf = out_node_nf
        self.device = device
        self.dtype = torch.float32
        self.include_charges = include_charges
        # Cache of fully-connected edge index lists, keyed by n_nodes then batch size.
        self._edges_dict = {}
        self.included_species = included_species
        self.modify_h = modify_h
        self.joint_training = joint_training
        self.condition_time = condition_time

        # Edge input = 3 squared coordinate differences + features of both endpoints.
        edge_in_dim = 3 + 2 * out_node_nf
        if edge_head == 'linear':
            self.edge_head = nn.Linear(edge_in_dim, n_classes).to(self.device)
        elif edge_head == 'mlp':
            self.edge_head = nn.Sequential(
                nn.Linear(edge_in_dim, edge_head_hidden_dim),
                act_fn,
                nn.Linear(edge_head_hidden_dim, n_classes)).to(self.device)
        else:
            raise ValueError(f"The requested edge head architecture {edge_head} is not supported. Should be linear or mlp")

        if modify_h:
            # if we're also changing the features, we add a linear layer on top.
            self.h_head = nn.Linear(out_node_nf, out_node_nf).to(self.device)

    def prepare_class_weights(self, dataloader, dataset_info, device, recompute_class_weight=False):
        """Compute class weights via ``compute_class_weight`` and store them on ``device``.

        Must be called before ``compute_loss_joint_training``, which reads
        ``self.class_weight_dict``.
        """
        print('Computing class weights...')
        weights = compute_class_weight(dataloader, dataset_info, recompute_class_weight=recompute_class_weight)
        self.class_weight_dict = {key: value.to(device) for key, value in weights.items()}
        print(f'Computed class_weight: {self.class_weight_dict}')

    def forward(self, batch):
        """Encode a padded batch and predict edge logits.

        Args:
            batch: dict with tensors ``'positions'`` (bs, n_nodes, 3), ``'atom_mask'``
                (bs, n_nodes), ``'edge_mask'`` (reshapeable to (bs*n_nodes*n_nodes, 1)),
                ``'one_hot'``, plus ``'charges'`` if ``include_charges`` and ``'t'`` if
                ``condition_time``.

        Returns:
            (adj_pred, h): ``adj_pred`` is (bs, n_nodes, n_nodes, n_classes) edge logits;
            ``h`` is the ``h_head`` output (bs, n_nodes, out_node_nf) when ``modify_h``,
            otherwise the concatenated input node features.
        """
        x = batch['positions'].to(self.device, self.dtype)
        node_mask = batch['atom_mask'].to(self.device, self.dtype).unsqueeze(2)
        edge_mask = batch['edge_mask'].to(self.device, self.dtype)
        one_hot = batch['one_hot'].to(self.device, self.dtype)  # categorical
        charges = (batch['charges'] if self.include_charges else torch.zeros(0)).to(self.device, self.dtype)  # integer
        t = batch['t'].to(self.device, self.dtype) if self.condition_time else torch.zeros(0).to(self.device, self.dtype)

        # Only one-hot encode real charges; an empty placeholder cannot be broadcast against 3 classes.
        if self.include_charges and not self.joint_training:
            possible_formal_charges = torch.tensor([-1., 0., 1.], device=self.device, dtype=self.dtype)
            charges = charges == possible_formal_charges.view(1, 1, -1)
            charges = charges * node_mask

        check_mask_correct([x, one_hot, charges], node_mask)
        h_in = torch.cat([one_hot, charges], dim=2)

        bs, n_nodes, _ = x.shape

        one_hot = one_hot * node_mask
        charges = charges * node_mask
        h = torch.cat([one_hot, charges], dim=2)

        if self.condition_time:
            if t.numel() == 1:
                # t is the same for all elements in batch.
                h_time = torch.empty_like(h[:, :, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes).unsqueeze(-1)
            h = torch.cat([h, h_time * node_mask], dim=2)

        edges = [edge_idx.to(self.device) for edge_idx in self.get_adj_matrix(n_nodes, bs, self.device)]

        node_mask = node_mask.view(bs * n_nodes, 1)
        edge_mask = edge_mask.view(bs * n_nodes * n_nodes, 1)
        x = x.view(bs * n_nodes, -1) * node_mask
        h = h.view(bs * n_nodes, -1) * node_mask

        if self.encoder == 'egnn':
            h_final, x_final = self.egnn(h, x, edges, node_mask=node_mask, edge_mask=edge_mask)
        else:
            # No encoder: decode directly from the masked input features and positions.
            h_final, x_final = h, x

        x_final = x_final.view(bs, n_nodes, -1)
        h_final = h_final.view(bs, n_nodes, -1)

        if torch.isnan(x_final).any() or torch.isnan(h_final).any():
            print('Warning: detected nan, resetting EGNN output to zero.')
            x_final = torch.zeros_like(x_final)
            h_final = torch.zeros_like(h_final)

        check_mask_correct([x_final, h_final], node_mask.view(bs, n_nodes, -1))

        adj_pred = self.decode_from_xh(x_final, h_final, edge_mask, edge_head=self.edge_head)

        if not self.modify_h:
            return adj_pred, h_in

        # Run h_head only on real (unpadded) nodes; padded nodes stay zero.
        # reshape(-1) (not squeeze) keeps a 1-D mask even when bs * n_nodes == 1.
        real_nodes = node_mask.reshape(-1).bool()
        h_processed = self.h_head(h_final.view(bs * n_nodes, -1)[real_nodes])
        h_out = torch.zeros((bs * n_nodes, h_processed.shape[1]), device=self.device)
        h_out[real_nodes] = h_processed
        return adj_pred, h_out.view(bs, n_nodes, -1)

    def decode_from_xh(self, x, h, edge_mask, edge_head=None, C=10, b=-1, remove_diagonal=True, enforce_symmetry=True):
        """Score every ordered atom pair from positions and node features.

        Args:
            x: (bs, n_nodes, 3) positions.
            h: (bs, n_nodes, nf) node features.
            edge_mask: mask with bs*n_nodes*n_nodes entries; only nonzero edges are fed to
                ``edge_head``, and the rest are left at zero.
            edge_head: module mapping 3 + 2*nf -> n_edge_types logits. If None, a fixed
                sigmoid(C * sum(features) + b) score is used instead.
            remove_diagonal: zero the self-edges (i, i).
            enforce_symmetry: average (i, j) and (j, i), because concatenating [h_i, h_j]
                is order-dependent.

        Returns:
            (bs, n_nodes, n_nodes, n_edge_types) tensor (n_edge_types == 1 when
            ``edge_head`` is None).
        """
        bs, n_nodes, _ = x.shape

        diff_sq = (x.unsqueeze(1) - x.unsqueeze(2)) ** 2  # (bs, n_nodes, n_nodes, 3)
        h_i = h.unsqueeze(2).expand(-1, -1, n_nodes, -1)  # feature of node i at [b, i, j]
        h_j = h.unsqueeze(1).expand(-1, n_nodes, -1, -1)  # feature of node j at [b, i, j]
        XH = torch.cat([diff_sq, h_i, h_j], dim=-1)  # (bs, n_nodes, n_nodes, 3+2*nf)
        XH = XH.view(bs * n_nodes * n_nodes, -1)  # (total_n_edges, 3+2*nf)

        if edge_head is not None:
            # Only score real edges (padding is ignored). reshape(-1) keeps the mask 1-D
            # even for a single edge, where squeeze() would yield a 0-d mask.
            keep = edge_mask.reshape(-1).bool()
            logits = edge_head(XH[keep])  # (n_kept_edges, n_edge_types)
            scores = XH.new_zeros((bs * n_nodes * n_nodes, logits.shape[1]))
            scores[keep] = logits
            XH = scores
            # Logits are returned (no softmax); the loss applies log-softmax for stability.
        else:
            # TODO : not sure if this is reasonable at all.
            XH = torch.sigmoid(C * torch.sum(XH, dim=1) + b)

        adj_pred = XH.view(bs, n_nodes, n_nodes, -1)
        if remove_diagonal:
            eye = torch.eye(n_nodes, device=adj_pred.device)
            adj_pred = adj_pred * (1 - eye.unsqueeze(-1).unsqueeze(0))
        if enforce_symmetry:
            adj_pred = (adj_pred + torch.transpose(adj_pred, 1, 2)) / 2
        return adj_pred

    def decode_from_x(self, x, edge_head=None, C=10, b=-1, remove_diagonal=True):
        """Score every ordered atom pair from positions only.

        Args:
            x: (bs, n_nodes, 3) positions.
            edge_head: module mapping 3 -> n_edge_types logits, applied to all edges
                (no masking). If None, sigmoid(C * squared_distance + b) is used.
            remove_diagonal: zero the self-edges (i, i).

        Returns:
            (bs, n_nodes, n_nodes, n_edge_types) tensor (n_edge_types == 1 when
            ``edge_head`` is None).
        """
        bs, n_nodes, _ = x.shape

        diff_sq = (x.unsqueeze(1) - x.unsqueeze(2)) ** 2  # (bs, n_nodes, n_nodes, 3)
        X = diff_sq.view(bs * n_nodes * n_nodes, -1)  # (total_n_edges, 3)

        if edge_head is not None:
            X = edge_head(X)  # logits; softmax is left to the loss
        else:
            # TODO : not sure if this is reasonable at all.
            X = torch.sigmoid(C * torch.sum(X, dim=1) + b)

        adj_pred = X.view(bs, n_nodes, n_nodes, -1)
        if remove_diagonal:
            eye = torch.eye(n_nodes, device=adj_pred.device)
            adj_pred = adj_pred * (1 - eye.unsqueeze(-1).unsqueeze(0))
        return adj_pred

    def get_adj_matrix(self, n_nodes, batch_size, device):
        """Return cached [rows, cols] LongTensor indices of a fully connected batched graph.

        Each of the ``batch_size`` samples gets all n_nodes**2 ordered pairs (self-loops
        included), offset by ``sample_idx * n_nodes``. The order is sample-major, then i, then j.
        Results are cached per (n_nodes, batch_size); the cache is not keyed on ``device``.
        """
        edges_by_batch = self._edges_dict.setdefault(n_nodes, {})
        if batch_size not in edges_by_batch:
            node_idx = torch.arange(n_nodes, dtype=torch.long)
            rows = node_idx.repeat_interleave(n_nodes)  # i repeated for every j
            cols = node_idx.repeat(n_nodes)  # j cycling for each i
            # Shift each sample's indices so samples form disjoint blocks.
            offsets = torch.arange(batch_size, dtype=torch.long).repeat_interleave(n_nodes * n_nodes) * n_nodes
            edges_by_batch[batch_size] = [(rows.repeat(batch_size) + offsets).to(device),
                                          (cols.repeat(batch_size) + offsets).to(device)]
        return edges_by_batch[batch_size]

    def predict_adj_matrix_from_single_molecule(self, molecule):  # TODO: verify end-to-end against an xyz reader
        """
        Preprocesses and performs a prediction on a single raw molecule which would be directly read from an xyz file
        IMPORTANT: expects a single molecule and not a batch of molecules.
        Args:
            molecule (dict): contains keys: num_atoms, atomic_numbers, positions, formal_charges.
                Tensor entries are modified in place (a batch dimension is added).
        Returns:
            (n_atoms, n_atoms) tensor of predicted edge class indices.
        """
        atomic_numbers = molecule['atomic_numbers']
        # (n_atoms, n_species) one-hot of atomic numbers against the included species.
        molecule['one_hot'] = atomic_numbers.reshape(-1, 1) == self.included_species.reshape(1, -1)

        atom_mask = atomic_numbers > 0
        molecule['atom_mask'] = atom_mask

        # Edges connect pairs of real atoms, excluding self-loops.
        n_nodes = len(atom_mask)
        edge_mask = atom_mask.unsqueeze(0) & atom_mask.unsqueeze(1)
        edge_mask = edge_mask & ~torch.eye(n_nodes, dtype=torch.bool, device=edge_mask.device)
        molecule['edge_mask'] = edge_mask.view(n_nodes * n_nodes, 1)

        if self.include_charges:
            molecule['charges'] = molecule['formal_charges'].unsqueeze(1)
        if self.condition_time and 't' not in molecule:
            # A raw molecule is clean data, so condition on t = 0 (same as map_to_2d).
            molecule['t'] = torch.zeros(1)

        # Add a singleton batch dimension to tensor entries only (e.g. num_atoms may be an int).
        for key, value in list(molecule.items()):
            if torch.is_tensor(value):
                molecule[key] = value.unsqueeze(0)

        # forward returns (adj_pred, h); only the adjacency logits are needed here.
        adj_pred, _ = self.forward(molecule)
        return torch.argmax(adj_pred.squeeze(0), -1)

    def compute_loss_joint_training(self, xh_pred, node_mask, edge_mask, h_gt, adj_gt, t):
        """Compute the edge loss (plus node-feature losses when ``modify_h``) on predicted xh.

        ``xh_pred`` is split as positions ``[..., :3]``, categorical features
        ``[..., 3:-1]`` and a single integer (charge) channel ``[..., -1]``.
        Requires ``prepare_class_weights`` to have been called first.
        """
        x = xh_pred[:, :, :3]
        h_categorical = xh_pred[:, :, 3:-1]
        h_integer = xh_pred[:, :, -1].unsqueeze(-1)  # TODO: make shape consistent with model's expectations
        bs, n_nodes = xh_pred.shape[:2]
        # Reshape (rather than squeeze) so a batch of size 1 keeps its batch dimension.
        batch = {'positions': x, 'atom_mask': node_mask.reshape(bs, n_nodes), 'edge_mask': edge_mask,
                 'one_hot': h_categorical, 'charges': h_integer, 't': t}
        adj_pred, h_pred = self.forward(batch)

        h_categorical_gt = h_gt['categorical'].long()
        h_integer_gt = h_gt['integer'].long()
        if self.modify_h:
            atom_types_loss, formal_charges_loss = atom_types_and_formal_charges_loss(
                h_pred, h_categorical_gt, h_integer_gt, weight_dict=self.class_weight_dict)
        else:
            atom_types_loss, formal_charges_loss = 0, 0
        adj_loss = adjacency_matrix_loss(adj_pred, adj_gt, weight=self.class_weight_dict['edges'])
        return atom_types_loss + formal_charges_loss + adj_loss

    def map_to_2d(self, h_categorical, h_integer, x, node_mask, edge_mask):
        """
        This method is called after sampling from the 3D diffusion model to get the 2D graph

        Returns:
            (adj_pred, atom_types_pred, formal_charges_pred): argmax edge classes
            (bs, n, n), argmax atom-type indices (bs, n), and formal charges in {-1, 0, 1}
            (bs, n). This assumes the last 3 channels of h are the charge one-hot; verify
            this when include_charges is False.
        """
        bs, n_nodes = x.shape[:2]
        # t = 0: the sampled molecule is treated as clean data.
        t = torch.zeros((1,)).to(self.device)
        batch = {'positions': x, 'atom_mask': node_mask.reshape(bs, n_nodes), 'edge_mask': edge_mask,
                 'one_hot': h_categorical, 'charges': h_integer, 't': t}
        adj_pred, h_pred = self.forward(batch)

        adj_pred = torch.argmax(adj_pred, -1)
        atom_types_pred = torch.argmax(h_pred[:, :, :-3], -1)
        # Charge one-hot index 0/1/2 maps back to formal charge -1/0/+1.
        formal_charges_pred = torch.argmax(h_pred[:, :, -3:], -1) - 1

        return adj_pred, atom_types_pred, formal_charges_pred
