import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F
from collections import OrderedDict

class Vgg16(nn.Module):
    """Frozen pretrained torchvision VGG16 trunk exposing four intermediate activations.

    The ``features`` stack is cut into four consecutive stages covering layer
    indices [0, 4), [4, 9), [9, 16) and [16, 23). ``forward`` runs them in order
    and returns each stage's output as a tuple
    ``(h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)``.
    """

    def __init__(self):
        super(Vgg16, self).__init__()
        backbone = models.vgg16(pretrained=True).features

        # (attribute name, first layer index, one-past-last layer index).
        # Submodules keep the backbone's layer index as their name so that
        # state_dict keys look like "to_relu_2_2.4".
        stage_plan = (
            ('to_relu_1_2', 0, 4),
            ('to_relu_2_2', 4, 9),
            ('to_relu_3_3', 9, 16),
            ('to_relu_4_3', 16, 23),
        )
        for attr_name, first, stop in stage_plan:
            stage = nn.Sequential()
            for layer_idx in range(first, stop):
                stage.add_module(str(layer_idx), backbone[layer_idx])
            setattr(self, attr_name, stage)

        # Used purely as a feature extractor: its weights are never trained.
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return the outputs of the four stages, applied sequentially, as a tuple."""
        taps = []
        activation = x
        for stage in (self.to_relu_1_2, self.to_relu_2_2,
                      self.to_relu_3_3, self.to_relu_4_3):
            activation = stage(activation)
            taps.append(activation)
        return tuple(taps)

class VggFace(nn.Module):
    """Frozen VGG-Face feature extractor returning activations at selected layers.

    The network is torchvision's VGG configuration 'D' (no batch norm) with a
    2622-way classifier, loaded from ``vgg_face_state_dict()``. Only its
    ``features`` stack is kept.

    Args:
        feature_idx: indices into ``self.features`` whose outputs ``forward``
            returns, in ascending layer order.
    """

    def __init__(self, feature_idx):
        super(VggFace, self).__init__()
        model = models.vgg.VGG(models.vgg.make_layers(models.vgg.cfgs['D'], batch_norm=False), num_classes=2622)
        model.load_state_dict(vgg_face_state_dict())
        features = list(model.features)
        self.features = nn.ModuleList(features).eval()
        self.idx_list = feature_idx
        # Only used as a fixed feature extractor (like Vgg16 / Vgg19): freeze the
        # weights so backward passes neither compute nor store their gradients.
        # Gradients still flow to the *input*, which is what a perceptual loss needs.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through ``self.features`` and collect the requested activations.

        Returns:
            list of tensors, one per index in ``self.idx_list`` that exists in
            ``self.features``, ordered by layer index.
        """
        wanted = set(self.idx_list)
        results = []
        if not wanted:
            return results
        # Layers after the deepest requested tap cannot affect the result; skip them.
        last = max(wanted)
        for ii, layer in enumerate(self.features):
            x = layer(x)
            if ii in wanted:
                results.append(x)
            if ii >= last:
                break
        return results


class PerceptualLoss(nn.Module):
    """Weighted multi-layer L1 feature distance between two video clips.

    Args:
        weights: per-feature-map weights, used only when ``type == 'vgg'``
            (one per ``Vgg16`` output). Ignored for any other ``type``.
        type: ``'vgg'`` selects the ``Vgg16`` extractor; any other value
            selects ``VggFace`` with five taps weighted 1 each. (The name
            shadows the builtin but is part of the public signature.)
    """

    def __init__(self, weights, type='vgg'):
        super(PerceptualLoss, self).__init__()
        if type == 'vgg':
            self.vgg = Vgg16()
            self.weights = weights
        else:
            # Taps at the layers right after conv1_1, conv2_1, conv3_1, conv4_1 and
            # conv5_1 of the VGG-Face ``features`` stack (presumably their ReLUs).
            self.vgg = VggFace([1, 6, 11, 18, 25])
            self.weights = [1, 1, 1, 1, 1]

    def forward(self, x, target):  # video
        """Return ``sum_i weights[i] * L1(features_i(target), features_i(x))``.

        Args:
            x: 5-D tensor (batch, time, channel, H, W).
            target: tensor with exactly the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 5-D or the two shapes differ.
        """
        batch_sz, time, channel, sz1, sz2 = x.shape
        if x.shape != target.shape:
            raise ValueError(
                f"x and target must have the same shape, got "
                f"{tuple(x.shape)} and {tuple(target.shape)}"
            )

        # Fold time into batch. reshape (unlike view) also accepts
        # non-contiguous inputs, e.g. clips whose batch/time axes were transposed.
        x = x.reshape(-1, channel, sz1, sz2)
        target = target.reshape(-1, channel, sz1, sz2)

        # ImageNet mean/std, created directly on the input's device. The channel
        # axis must broadcast against 3 values (RGB input expected).
        img_net_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
        img_net_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)

        x = (x - img_net_mean) / img_net_std
        target = (target - img_net_mean) / img_net_std

        x_feature = self.vgg(x)
        target_feature = self.vgg(target)

        vgg_loss = 0
        for i, (x_feat, target_feat) in enumerate(zip(x_feature, target_feature)):
            vgg_loss += self.weights[i] * F.l1_loss(target_feat, x_feat)

        return vgg_loss

def vgg_face_state_dict(path="./vgg_face_dag.pth"):
    """Load a VGG-Face checkpoint and rename its keys to torchvision ``VGG`` names.

    The checkpoint at ``path`` is expected to hold ``conv1_1`` ... ``conv5_3`` and
    ``fc6`` ... ``fc8`` weights/biases. They are mapped onto the
    ``features.<i>`` / ``classifier.<i>`` keys used by ``VggFace``.

    Args:
        path: checkpoint file; defaults to ``./vgg_face_dag.pth``.

    Returns:
        OrderedDict with 32 entries (weight then bias for each of 16 layers),
        with tensors on the CPU.

    Raises:
        KeyError: if the checkpoint lacks one of the expected layer entries.
    """
    # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only hosts;
    # load_state_dict later copies the tensors onto the model's device anyway.
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints (weights_only=True is available on torch >= 1.13).
    default = torch.load(path, map_location="cpu")

    # (torchvision module prefix, VGG-Face layer name), in network order.
    layer_map = (
        ('features.0', 'conv1_1'),
        ('features.2', 'conv1_2'),
        ('features.5', 'conv2_1'),
        ('features.7', 'conv2_2'),
        ('features.10', 'conv3_1'),
        ('features.12', 'conv3_2'),
        ('features.14', 'conv3_3'),
        ('features.17', 'conv4_1'),
        ('features.19', 'conv4_2'),
        ('features.21', 'conv4_3'),
        ('features.24', 'conv5_1'),
        ('features.26', 'conv5_2'),
        ('features.28', 'conv5_3'),
        ('classifier.0', 'fc6'),
        ('classifier.3', 'fc7'),
        ('classifier.6', 'fc8'),
    )

    state_dict = OrderedDict()
    for target_prefix, source_prefix in layer_map:
        for suffix in ('weight', 'bias'):
            state_dict[f'{target_prefix}.{suffix}'] = default[f'{source_prefix}.{suffix}']
    return state_dict



class VGGLoss(nn.Module):
    """Weighted multi-layer L1 loss on ``Vgg19`` features of two video clips."""

    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = Vgg19()
        # Place the extractor on the GPU when one exists; an unconditional
        # .cuda() would make construction crash on CPU-only hosts.
        if torch.cuda.is_available():
            vgg = vgg.cuda()
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        # One weight per Vgg19 output (slice1 .. slice5); deeper slices weigh more.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):  # video
        """Return the weighted L1 feature distance between clip ``x`` and reference ``y``.

        Args:
            x: 5-D tensor (batch, time, channel, H, W); gradients flow through it.
            y: reference tensor with the same shape as ``x``; it is detached.

        Raises:
            ValueError: if ``x`` is not 5-D or the two shapes differ.
        """
        batch_sz, time, channel, sz1, sz2 = x.shape
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got "
                f"{tuple(x.shape)} and {tuple(y.shape)}"
            )

        # Fold time into batch; reshape also handles non-contiguous inputs.
        x = x.reshape(-1, channel, sz1, sz2)
        y = y.reshape(-1, channel, sz1, sz2)

        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i, (x_feat, y_feat) in enumerate(zip(x_vgg, y_vgg)):
            loss += self.weights[i] * self.criterion(x_feat, y_feat.detach())
        # NOTE(review): nn.L1Loss already averages over every element, so this
        # extra division shrinks the loss as batch*time grows. It is kept
        # unchanged to preserve the existing loss scale.
        return loss / (batch_sz * time)


class Vgg19(torch.nn.Module):
    """Pretrained torchvision VGG19 trunk split into five consecutive slices.

    The slices cover ``vgg19().features`` layer indices [0, 2), [2, 7), [7, 12),
    [12, 21) and [21, 30). ``forward`` applies them in order and returns a list
    holding every slice's output.

    Args:
        requires_grad: when False (the default), all parameters are frozen.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        pretrained_trunk = models.vgg19(pretrained=True).features

        # Consecutive cut points; slice k spans [cuts[k-1], cuts[k]).
        cuts = (0, 2, 7, 12, 21, 30)
        for slice_no, (lo, hi) in enumerate(zip(cuts[:-1], cuts[1:]), start=1):
            block = torch.nn.Sequential()
            for layer_idx in range(lo, hi):
                # Keep the trunk's layer index as the submodule name.
                block.add_module(str(layer_idx), pretrained_trunk[layer_idx])
            setattr(self, 'slice%d' % slice_no, block)

        if requires_grad:
            return
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, X):
        """Return ``[h_relu1, ..., h_relu5]``, the outputs of the five slices in order."""
        collected = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            collected.append(h)
        return collected