import torch
from torch.nn import functional as F


def adjacency_matrix_loss(adj_pred, adj_gt, edge_mask, n_classes=5, reduction='mean', weight=None, use_focal_loss=False):
    """Compute the edge-class loss on real (non-padded) edges only.

    ``adj_pred`` is flattened to ``(-1, n_classes)`` logits, ``adj_gt`` to a
    vector of class indices and ``edge_mask`` to a boolean vector; only the
    positions where the mask is non-zero contribute to the loss. This holds
    for both the cross-entropy and the focal-loss variant.

    Args:
        adj_pred: Edge logits whose last dimension has size ``n_classes``.
        adj_gt: Ground-truth edge class indices, one per edge in ``adj_pred``.
        edge_mask: Non-zero for real edges, zero for padded (dummy) edges.
            Must have as many elements as ``adj_gt``.
        n_classes: Number of edge classes.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``. For the focal loss,
            ``'mean'`` is the plain average over the valid edges.
        weight: Optional per-class weight tensor of length ``n_classes``.
        use_focal_loss: If True, use the focal loss (gamma=2, alpha=1)
            instead of plain cross-entropy.

    Returns:
        The reduced loss, or the per-valid-edge losses for ``'none'``.

    Raises:
        ValueError: If ``reduction`` is not a supported value.
    """
    # Drop the padded edges once, so that both loss variants see exactly the
    # same set of examples.
    valid_edges = edge_mask.bool().reshape(-1)
    logits = adj_pred.reshape(-1, n_classes)[valid_edges]
    targets = adj_gt.reshape(-1)[valid_edges]

    if not use_focal_loss:
        return F.cross_entropy(logits, targets, reduction=reduction, weight=weight)

    # Hyperparameters
    gamma = 2  # controls how much easy examples are down-weighted
    alpha = 1  # global scaling coefficient

    # pt must be the model's probability of the true class, so it is derived
    # from the *unweighted* cross-entropy; the class weights are applied
    # afterwards (exp(-weighted_ce) would be pt ** weight instead of pt).
    ce_unweighted = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce_unweighted)
    ce_loss = ce_unweighted if weight is None else ce_unweighted * weight[targets]
    focal_loss = alpha * (1 - pt) ** gamma * ce_loss

    if reduction == 'mean':
        return focal_loss.mean()
    if reduction == 'sum':
        return focal_loss.sum()
    if reduction == 'none':
        return focal_loss
    raise ValueError(f"{reduction} is not a valid value for reduction")

def _masked_node_class_loss(logits, targets, reduction, weight, use_focal_loss, gamma=2, alpha=1):
    """Cross-entropy or focal loss on already-masked ``(n, n_classes)`` logits and ``(n,)`` targets."""
    if not use_focal_loss:
        return F.cross_entropy(logits, targets, reduction=reduction, weight=weight)

    # pt must be the model's probability of the true class, so it is derived
    # from the *unweighted* cross-entropy; the class weights are applied
    # afterwards (exp(-weighted_ce) would be pt ** weight instead of pt).
    ce_unweighted = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce_unweighted)
    ce_loss = ce_unweighted if weight is None else ce_unweighted * weight[targets]
    # gamma controls how much easy examples are down-weighted, alpha is a
    # global scaling coefficient.
    focal_loss = alpha * (1 - pt) ** gamma * ce_loss

    if reduction == 'mean':
        return focal_loss.mean()
    if reduction == 'sum':
        return focal_loss.sum()
    if reduction == 'none':
        return focal_loss
    raise ValueError(f"{reduction} is not a valid value for reduction")


def atom_types_and_formal_charges_loss(h_pred, atom_types_gt, formal_charges_gt, node_mask, reduction='mean', weight_dict=None, use_focal_loss=False):
    """Compute the atom-type and formal-charge losses on real (non-padded) nodes.

    The last dimension of ``h_pred`` is split into atom-type logits (the first
    ``atom_types_gt.shape[-1]`` entries) and formal-charge logits (the rest).
    Ground truths are converted from one-hot to class indices with ``argmax``.
    Nodes whose ``node_mask`` entry is zero are removed before the loss is
    computed, for both the cross-entropy and the focal-loss variant. Padded
    nodes therefore never act as class-0 targets and never dilute the mean.

    Args:
        h_pred: 3-D tensor of node logits; last dimension is the concatenation
            of atom-type and formal-charge logits.
        atom_types_gt: One-hot atom types; last dimension is the number of
            atom types.
        formal_charges_gt: One-hot formal charges; last dimension is the
            number of formal-charge classes.
        node_mask: Non-zero for real nodes, zero for padded (dummy) nodes.
            Must have one element per node.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``. For the focal loss,
            ``'mean'`` is the plain average over the valid nodes.
        weight_dict: Optional dict with per-class weight tensors under the
            keys ``'atom_types'`` and ``'formal_charges'``.
        use_focal_loss: If True, use the focal loss (gamma=2, alpha=1)
            instead of plain cross-entropy.

    Returns:
        Tuple ``(atom_types_loss, formal_charges_loss)``.

    Raises:
        ValueError: If ``reduction`` is not a supported value.
    """
    n_atom_types = atom_types_gt.shape[-1]
    n_formal_charges = formal_charges_gt.shape[-1]
    atom_types_pred = h_pred[:, :, :n_atom_types]
    formal_charges_pred = h_pred[:, :, n_atom_types:]

    if weight_dict is not None:
        weight_atom_types = weight_dict['atom_types']
        weight_formal_charges = weight_dict['formal_charges']
    else:
        weight_atom_types = None
        weight_formal_charges = None

    # Nodes padded in so that every graph in the batch has the same number of
    # nodes are excluded from both losses.
    valid_nodes = node_mask.bool().reshape(-1)

    # Go from one-hot encodings to class indices, then keep the real nodes.
    atom_types_targets = atom_types_gt.float().argmax(-1).reshape(-1)[valid_nodes]
    atom_types_logits = atom_types_pred.reshape(-1, n_atom_types)[valid_nodes]
    atom_types_loss = _masked_node_class_loss(
        atom_types_logits, atom_types_targets, reduction, weight_atom_types, use_focal_loss)

    formal_charges_targets = formal_charges_gt.float().argmax(-1).reshape(-1)[valid_nodes]
    formal_charges_logits = formal_charges_pred.reshape(-1, n_formal_charges)[valid_nodes]
    formal_charges_loss = _masked_node_class_loss(
        formal_charges_logits, formal_charges_targets, reduction, weight_formal_charges, use_focal_loss)

    return atom_types_loss, formal_charges_loss
