import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from collections import namedtuple


class ConsistencyLoss(nn.Module):
    """
    Consistency loss between two tensors, used for landmark prediction.

    Input:
        loss_type: string selecting the loss computed in ``forward``:
            'perceptual'          - weighted MSE on the raw inputs plus VGG16 features
            'mse' or 'l2'         - plain mean-squared error
            'bce'                 - binary cross-entropy with logits
                                    (input_1 = logits, input_2 = targets)
            'perceptual_gray_hm'  - perceptual loss on heatmaps summed over dim 1
                                    and replicated to 3 channels
            'perceptual_color_hm' - perceptual loss on colorized heatmaps; needs a
                                    subclass that implements ``hm2color``
        A VGG16 feature extractor is only built when 'perceptual' occurs in
        ``loss_type``. Unsupported values raise ValueError in ``forward``.
    """
    def __init__(self, loss_type='perceptual'):
        super(ConsistencyLoss, self).__init__()
        self.loss_type = loss_type
        # Only the perceptual variants need VGG; avoid loading weights otherwise.
        self.vggnet = Vgg16() if 'perceptual' in loss_type else None

    def forward(self, input_1, input_2):
        """Return the scalar loss between ``input_1`` and ``input_2``.

        For 'perceptual*' types, ``input_1`` is treated as the reference image
        and ``input_2`` as the prediction.

        Raises:
            ValueError: if ``self.loss_type`` is not supported.
            NotImplementedError: for 'perceptual_color_hm' unless ``hm2color``
                is overridden.
        """
        loss_type = self.loss_type
        if loss_type == 'perceptual':
            return self.perceptual_loss(input_1, input_2)
        if loss_type in ('mse', 'l2'):
            # 'l2' is accepted as an alias because the class docstring has
            # always advertised it.
            return F.mse_loss(input_1, input_2)
        if loss_type == 'bce':
            return F.binary_cross_entropy_with_logits(input_1, input_2)
        if loss_type == 'perceptual_color_hm':
            # Colorize heatmaps and apply perceptual loss.
            return self.perceptual_loss(self.hm2color(input_1), self.hm2color(input_2))
        if loss_type == 'perceptual_gray_hm':
            return self.perceptual_loss(self.hm2gray(input_1), self.hm2gray(input_2))
        raise ValueError('Incorrect loss_type for consistency loss', self.loss_type)

    def perceptual_loss(self, gt_image, pred_image,
                        ws=(50., 40., 6., 3., 3., 1.),
                        names=('input', 'conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2')):
        """Weighted multi-level MSE between two images.

        Level 0 compares the raw images. Every later level compares the VGG
        feature named ``names[k]``, read by attribute from ``self.vggnet``'s
        output. Each level's mean MSE is divided by ``ws[k]`` and the results
        are summed.

        Args:
            gt_image: reference image batch fed to ``self.vggnet``.
            pred_image: predicted image batch, same shape as ``gt_image``.
            ws: per-level divisors; must have at least ``len(names)`` entries.
            names: level names; ``names[0]`` is a placeholder for the raw
                input, and ``names[1:]`` must be attributes of the VGG output.

        Returns:
            Scalar tensor with the summed weighted loss.
        """
        # Get feature maps from VGG.
        feats_gt = self.vggnet(gt_image)
        feats_pred = self.vggnet(pred_image)

        # Level 0 is the raw image itself; no VGG output is needed for it.
        feat_gt = [gt_image] + [getattr(feats_gt, k) for k in names[1:]]
        feat_pred = [pred_image] + [getattr(feats_pred, k) for k in names[1:]]

        losses = [
            F.mse_loss(pred, gt, reduction='mean') / ws[k]
            for k, (pred, gt) in enumerate(zip(feat_pred, feat_gt))
        ]
        return torch.stack(losses).sum()

    def hm2gray(self, hm):
        """Collapse heatmaps to one channel and replicate it 3x for VGG input.

        Sums ``hm`` over dim 1 (presumably the per-landmark channel axis;
        assumes (batch, C, H, W), TODO confirm with callers). The result is
        then repeated three times along dim 1.
        """
        gray_hm = torch.sum(hm, dim=1, keepdim=True)
        return gray_hm.repeat(1, 3, 1, 1)

    def hm2color(self, hm):
        """Colorize heatmaps for the 'perceptual_color_hm' loss.

        This class provides no colorization. Previously the call failed with an
        AttributeError. Subclasses that use 'perceptual_color_hm' must override
        this method.
        """
        raise NotImplementedError(
            "ConsistencyLoss.hm2color is not implemented; override it to use "
            "loss_type='perceptual_color_hm'")


class Vgg16(torch.nn.Module):
    """Pretrained torchvision VGG16 feature extractor for perceptual losses.

    The ``features`` stack is cut into five consecutive slices. Each slice ends
    on a convolution layer (conv1_2, conv2_2, conv3_2, conv4_2, conv5_2) before
    its ReLU, so the returned activations are pre-activation conv outputs.

    Args:
        requires_grad: if False (default), all VGG parameters are frozen.
        names: exactly five field names for the returned namedtuple, in slice
            order.

    Raises:
        ValueError: if ``names`` does not contain exactly five entries.
    """

    # Output namedtuple types keyed by field-name tuple. This is shared at
    # class level so the type is built once, not on every forward. It is not
    # stored on the instance, which keeps the module picklable.
    _OUTPUT_TYPES = {}

    def __init__(self, requires_grad=False,
                 names=('conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2')):
        super(Vgg16, self).__init__()
        names = list(names)
        # Fail fast, before downloading weights, rather than at forward time.
        if len(names) != 5:
            raise ValueError('Vgg16 expects exactly 5 output names, got %d' % len(names))
        self.names = names
        # NOTE(review): newer torchvision deprecates `pretrained=` in favour of
        # `weights=`; kept as-is for compatibility with older versions.
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.slice1 = vgg_pretrained_features[:3]     # conv1_1 .. conv1_2
        self.slice2 = vgg_pretrained_features[3:8]    # relu1_2 .. conv2_2
        self.slice3 = vgg_pretrained_features[8:13]   # relu2_2 .. conv3_2
        self.slice4 = vgg_pretrained_features[13:20]  # relu3_2 .. conv4_2
        self.slice5 = vgg_pretrained_features[20:27]  # relu4_2 .. conv5_2
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _output_type(names):
        """Return the cached ``VggOutputs`` namedtuple type for ``names``."""
        key = tuple(names)
        out_type = Vgg16._OUTPUT_TYPES.get(key)
        if out_type is None:
            out_type = Vgg16._OUTPUT_TYPES[key] = namedtuple("VggOutputs", key)
        return out_type

    def forward(self, X):
        """Run ``X`` through the five slices and return every slice's output.

        Returns:
            A ``VggOutputs`` namedtuple whose fields are ``self.names``, holding
            the output of slice1 .. slice5 in order.
        """
        feats = []
        h = X
        for layer_slice in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = layer_slice(h)
            feats.append(h)
        return self._output_type(self.names)(*feats)
