import torch
import torch.nn.functional as F

from utils import logcosh, EPSILON

# TODO: add to tracking
# Margin of the pairwise hinge loss used by the 'margin_pair' loss type below.
margin = 1.0 # As in the GMN implementation


def _masked_nll(probs, labels, mask_nodist):
    """Return the batch mean of the per-sample summed binary NLL.

    ``probs`` is cropped to the leading ``labels.size(1)`` rows and columns;
    entries selected by ``mask_nodist`` contribute zero to the per-sample sum.
    """
    max_nodes = labels.size(1)
    probs = probs[:, :max_nodes, :max_nodes]
    log_lik = (labels * torch.log(probs.clamp(min=EPSILON)) +
               (1. - labels) * torch.log((1. - probs).clamp(min=EPSILON)))
    log_lik = log_lik.masked_fill(mask_nodist, 0)
    return -log_lik.sum(dim=[1, 2]).mean()


def inner_training_loop(
        model, optimizer, dataloader, phase, loss_type, device,
        graphs1, graphs2, labels, applygrad_iter, apply_grad):
    """Run one forward (and, in the train phase, backward) pass on a batch.

    Args:
        model: callable ``model(graphs_a, graphs_b)`` returning a pair
            ``(outputs_for_gradient, outputs)``; the first item may be None.
        optimizer: stepped and zeroed when ``phase == 'train'`` and
            ``apply_grad`` is true.
        dataloader: only ``dataloader.collate_fn.triplets`` is read; when it
            is truthy, ``labels`` is treated as a third graph dict.
        phase: gradients are tracked and ``backward`` is called only when
            this equals ``'train'``.
        loss_type: one of ``'mse'``, ``'logcosh'``, ``'margin_pair'``,
            ``'margin_triplet'`` or ``'matrix_nll'``.
        device: device the graph tensors and labels are moved to.
        graphs1, graphs2: dicts of tensors; ``'matrix_nll'`` additionally
            reads their ``'num_nodes'`` entry.
        labels: target tensor, or a dict of tensors (the third graph) when
            the dataloader produces triplets.
        applygrad_iter: the loss is divided by this value before
            ``backward`` (presumably the number of batches accumulated per
            optimizer step -- confirm with the caller).
        apply_grad: whether to call ``optimizer.step()`` after ``backward``.

    Returns:
        Tuple ``(outputs, labels, loss, batch_size)``. For ``'matrix_nll'``
        ``outputs`` is replaced by the per-sample hits@1 counts (int32) and
        ``labels`` by the per-sample label sums; for ``'margin_triplet'``
        ``labels`` is replaced by the model outputs for the third graph.

    Raises:
        NotImplementedError: if ``loss_type`` is not supported.
        ValueError: if ``'margin_triplet'`` is requested but the dataloader
            does not provide triplets.
    """
    # send stuff to device
    graphs1 = {k: v.to(device) for k, v in graphs1.items()}
    graphs2 = {k: v.to(device) for k, v in graphs2.items()}

    graphs3 = None
    if dataloader.collate_fn.triplets:
        # In triplet mode the "labels" slot carries the third graph.
        graphs3 = {k: v.to(device) for k, v in labels.items()}
    else:
        labels = labels.to(device)

    # forward
    # track history if only in train
    with torch.set_grad_enabled(phase == 'train'):
        outputs_for_gradient, outputs = model(graphs1, graphs2)
        batch_size = outputs.size(0)

        if loss_type == 'mse':
            loss = F.mse_loss(outputs, labels)
            if outputs_for_gradient is not None:
                loss = loss + F.mse_loss(outputs_for_gradient, labels)
        elif loss_type == 'logcosh':
            loss = logcosh(outputs - labels).mean()
            if outputs_for_gradient is not None:
                loss = loss + logcosh(outputs_for_gradient - labels).mean()
        elif loss_type == 'margin_pair':
            # Map labels from {0, 1} to {-1, +1} for the hinge loss.
            target = 2 * (labels - 0.5)
            loss = torch.mean(F.relu(margin - target * (1. - outputs)))
        elif loss_type == 'margin_triplet':
            if graphs3 is None:
                raise ValueError(
                    "loss_type 'margin_triplet' requires a dataloader whose "
                    "collate_fn produces triplets")
            # The model returns a pair; only the second item is the
            # prediction that can be compared against `outputs`.
            _, outputs2 = model(graphs1, graphs3)
            # NOTE(review): the triplet margin is hard-coded to 10. rather
            # than using the module-level `margin` -- confirm this is intended.
            loss = torch.mean(F.relu(outputs - outputs2 + 10.))
            labels = outputs2
        elif loss_type == 'matrix_nll':
            max_nodes = labels.size(1)
            node_idx = torch.arange(max_nodes, dtype=torch.int64,
                                    device=labels.device)

            # True wherever the row (graph 1) or column (graph 2) index lies
            # beyond the real node count, i.e. there is no calculated distance.
            mask_n1 = (node_idx[:, None].expand_as(labels)
                       >= graphs1['num_nodes'][:, None, None])
            mask_n2 = (node_idx.expand_as(labels)
                       >= graphs2['num_nodes'][:, None, None])
            mask_nodist = mask_n1 | mask_n2

            # cut out matrix, and NLL
            loss = _masked_nll(outputs, labels, mask_nodist)
            if outputs_for_gradient is not None:
                loss = loss + _masked_nll(outputs_for_gradient, labels, mask_nodist)

            # compute hits@1 right here so don't have to keep the matrices around
            # could also only compute hits for validation data
            with torch.no_grad():
                pred_cols = (outputs[:, :max_nodes, :max_nodes]
                             .masked_fill(mask_nodist, 0)
                             .argmax(dim=2))
                # A row is a hit when the label at its predicted column is
                # non-zero; this equals counting the non-zero label entries
                # (r, c) for which pred_cols[r] == c, done for the whole batch.
                row_hit = labels.gather(2, pred_cols.unsqueeze(-1)).squeeze(-1) != 0
                hits = torch.zeros([batch_size], dtype=torch.int32)
                hits.copy_(row_hit.sum(dim=1))
                outputs = hits
                labels = labels.sum(dim=[1, 2])
        else:
            raise NotImplementedError(f"Unsupported loss_type: {loss_type!r}")

        # backward + optimize only if in training phase
        if phase == 'train':
            # Divide by update stepsize because loss is a mean value
            loss_grad = loss / applygrad_iter

            loss_grad.backward()
            if apply_grad:
                optimizer.step()
                optimizer.zero_grad()

    return outputs, labels, loss, batch_size
