"""Soft iterative refinement metrics.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F

def layer_dropout(stage, inp, final_fun):
    """Layer dropout: compute the metrics required for the Recurrence Index.

    For every block ``k`` of ``stage``, the stage is re-run from the input of
    block ``k`` with that block skipped (all later blocks still applied in
    order). Each resulting activation is compared with the output of the
    unmodified stage and is also passed through ``final_fun``.

    Args:
        stage: Object exposing ``depth`` (number of blocks) and an indexable
            ``blocks`` sequence of callables (e.g. residual blocks).
        inp: Input tensor to the stage.
        final_fun: Function applied to a stage output to produce the final
            readout (e.g. the remainder of the network).

    Returns:
        A tuple ``(final_outcome, block_deviation)``:
            final_outcome: ``final_fun`` outputs stacked along a new leading
                dimension of size ``stage.depth``, copied to CPU. Entry ``k``
                is the readout with block ``k`` dropped.
            block_deviation: Tensor of shape ``(stage.depth,)``. Entry ``k`` is
                the squared difference between the block-``k``-dropped
                activation and the full-stage activation, averaged over all
                remaining dimensions.

    See also:
        apply
    """
    final_outcome = []
    block_outcome = []
    for layer in range(stage.depth):
        # `skipped` starts at the input of block `layer` and bypasses it.
        skipped = inp.clone()
        inp = stage.blocks[layer](inp)
        for later in range(layer + 1, stage.depth):
            skipped = stage.blocks[later](skipped)
        block_outcome.append(skipped.clone().cpu())
        final_outcome.append(final_fun(skipped).clone().cpu())
    # After the loop `inp` holds the output of the full (undropped) stage.
    block_deviation = (torch.stack(block_outcome) - inp.unsqueeze(0).cpu()) ** 2
    final_outcome = torch.stack(final_outcome)
    # Average over every dimension except the leading "dropped block" one, so
    # activations of any rank work (4-D activations reduce over (1, 2, 3, 4)).
    reduce_dims = tuple(range(1, block_deviation.dim()))
    return final_outcome, block_deviation.mean(dim=reduce_dims)

def early_late_readout(stage, inp, final_fun, depth=None):
    """Early/Late Readout: compute metrics required for the Convergence and
    Divergence Index.

    The stage's blocks are applied one at a time for ``depth`` steps; once the
    last block is reached it is applied repeatedly. The activation before any
    block and after every step is recorded and read out with ``final_fun``.
    Each recorded activation is compared with the regular stage output (each
    of the ``stage.depth`` blocks applied exactly once, in order).

    Args:
        stage: Object exposing ``depth`` (number of blocks) and an indexable
            ``blocks`` sequence of callables (e.g. residual blocks).
        inp: Input tensor to the stage.
        final_fun: Function applied to a stage output to produce the final
            readout (e.g. the remainder of the network).
        depth: Number of block applications for the late readout. Defaults
            to ``2 * stage.depth``. May be smaller than ``stage.depth``, in
            which case only early readouts are produced.

    Returns:
        A tuple ``(final_outcome, block_deviation)``, each with a leading
        dimension of size ``depth + 1``. Entry ``k`` corresponds to the
        activation after ``k`` block applications (entry 0 is ``inp``):
            final_outcome: stacked ``final_fun`` outputs, copied to CPU.
            block_deviation: squared difference to the regular stage output,
                averaged over all remaining dimensions.

    Raises:
        ValueError: if ``depth`` is negative.

    See also:
        apply
    """
    if depth is None:
        depth = stage.depth * 2
    if depth < 0:
        raise ValueError("depth must be non-negative")
    last = stage.depth - 1
    x = inp
    block_outcome = [x.clone().cpu()]
    final_outcome = [final_fun(x).clone().cpu()]
    true_outcome = None
    for d in range(depth):
        # Past the end of the stage, keep re-applying the last block.
        x = stage.blocks[min(d, last)](x)
        if d == last:
            true_outcome = x.clone().cpu()
        block_outcome.append(x.clone().cpu())
        final_outcome.append(final_fun(x).clone().cpu())
    if true_outcome is None:
        # depth < stage.depth: the loop stopped before the regular stage
        # output was reached, so run the remaining blocks to get the reference.
        for d in range(depth, stage.depth):
            x = stage.blocks[d](x)
        true_outcome = x.clone().cpu()
    block_deviation = (torch.stack(block_outcome) - true_outcome.unsqueeze(0)) ** 2
    final_outcome = torch.stack(final_outcome)
    # Average over everything but the leading "step" dimension, so activations
    # of any rank work (4-D activations reduce over (1, 2, 3, 4)).
    reduce_dims = tuple(range(1, block_deviation.dim()))
    return final_outcome, block_deviation.mean(dim=reduce_dims)

def _readout_from(model, x, start):
    """Run ``x`` through ``model.stages[start:]`` and the classification head."""
    for stage in model.stages[start:]:
        x = stage(x)
    if model.unit_type == 'preactivation':
        # Only pre-activation models apply bn2 + ReLU before pooling.
        x = F.relu(model.bn2(x))
    # Kernel size equals the feature-map width; this is global average
    # pooling when the maps are square (presumably the case here — confirm).
    x = F.avg_pool2d(x, x.size()[3])
    x = x.view(x.size(0), -1)
    return model.linear(x)

def apply(model, data, metric, device='cuda', **kwargs):
    """Apply a metric to an MNIST-ResNet.

    ``model.stages`` is walked as alternating entries: even-indexed entries
    are handed to ``metric`` (which uses their ``depth`` and ``blocks``), and
    odd-indexed entries are applied in between. For each even-indexed stage,
    ``metric`` receives a ``final_fun`` that runs every later entry of
    ``model.stages`` followed by the classification head.

    Args:
        model: The trained model.
        data: The batch from the dataloader, as an ``(images, labels)`` pair.
        metric: The metric to be computed, e.g. ``layer_dropout`` or
            ``early_late_readout``. It must return
            ``(final_outcome, blockdev)``.
        device: Device on which the forward passes run (``'cuda'`` by
            default, or ``'cpu'``).
        kwargs: Extra keyword arguments forwarded to ``metric``.

    Returns:
        dict with keys:
            'acc': number of correct predictions per readout variant.
            'crossentropy': summed cross-entropy per readout variant.
            'blockdev': block deviations of all stages, concatenated.
            'size': batch size.
    """
    images, labels = data
    images = images.to(device)
    # Stem: conv1 -> bn1 -> ReLU.
    inp = F.relu(model.bn1(model.conv1(images)))
    lst_final_outcome = []
    lst_blockdev = []
    num_stages = (len(model.stages) + 1) // 2
    for i in range(num_stages):
        stage_idx = 2 * i
        if i != 0:
            # Entry between the previous measured stage and this one.
            inp = model.stages[stage_idx - 1](inp)

        # Bind the start index now rather than closing over the loop variable.
        def final_fun(x, _start=stage_idx + 1):
            return _readout_from(model, x, _start)

        final_outcome, blockdev = metric(
            model.stages[stage_idx], inp, final_fun=final_fun, **kwargs
        )
        lst_final_outcome.append(final_outcome)
        lst_blockdev.append(blockdev)
        inp = model.stages[stage_idx](inp)
    final_outcome = torch.cat(lst_final_outcome, dim=0)
    blockdev = torch.cat(lst_blockdev, dim=0)
    correct, cel = eval_criteria(final_outcome, labels)
    return {'acc': correct, 'crossentropy': cel, 'blockdev': blockdev.to('cpu'), 'size': labels.size(0)}

def eval_criteria(final_outcome, labels):
    """Score every readout variant against the ground-truth labels.

    Args:
        final_outcome: Logits stacked per variant. They are expected to be
            laid out as ``(variants, batch, classes)``, since the argmax is
            taken over dim 2 and each dim-0 slice is fed to cross-entropy.
        labels: Integer class targets for the batch.

    Returns:
        A tuple ``(correct, cel)`` of CPU tensors with one entry per variant:
        the number of correct top-1 predictions, and the summed cross-entropy
        loss.
    """
    predicted = torch.max(final_outcome, dim=2).indices
    correct = predicted.eq(labels.unsqueeze(0)).sum(dim=1).to('cpu')

    criterion = nn.CrossEntropyLoss(reduction='sum')
    per_variant_loss = []
    for logits in final_outcome:
        per_variant_loss.append(criterion(logits, labels))
    cel = torch.tensor(per_variant_loss).to('cpu')
    return correct, cel
