import torch
import torch.nn as nn
import torchvision
from torch.nn.parallel import DataParallel
import numpy as np
from torch.autograd import Variable as V
import torchvision.models as models
from torchvision import transforms as trn
from torch.nn import functional as F
import os
from PIL import Image

import resnet_wider as resnetw

class VGGFeatureExtractor(nn.Module):
    """Frozen VGG-16/19 feature stack split into consecutive stages.

    The ``features`` part of a torchvision VGG model is cut at every index in
    ``layer_ids``; stage ``i`` holds the layers after the previous cut up to
    and including ``layer_ids[i]`` and is registered as attribute ``fe<i>``.
    All stage parameters have ``requires_grad`` set to ``False``.

    Args:
        layer_ids: Increasing indices into ``model.features`` at which an
            intermediate output is collected.
        use_bn: Select the batch-norm variant of the chosen VGG model.
        use_input_norm: If true, normalise the input with the ImageNet
            channel mean/std before the first stage.
        device: Device on which the mean/std buffers are created.
        vgg_v: VGG depth, either ``16`` or ``19``.
        pretrained: Forwarded to the torchvision model constructor.

    Raises:
        ValueError: If ``vgg_v`` is neither 16 nor 19.

    NOTE(review): the module is not switched to eval mode here; with
    ``use_bn=True`` callers presumably need to call ``.eval()`` themselves
    to keep the batch-norm statistics fixed -- confirm against callers.
    """

    def __init__(self, layer_ids=(4,11,18,31), use_bn=False, use_input_norm=True, device=torch.device('cpu'),
                 vgg_v=19, pretrained=True):
        super(VGGFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        if vgg_v == 19 and use_bn:
            model = torchvision.models.vgg19_bn(pretrained=pretrained)
        elif vgg_v == 19:
            model = torchvision.models.vgg19(pretrained=pretrained)
        elif vgg_v == 16 and use_bn:
            model = torchvision.models.vgg16_bn(pretrained=pretrained)
        elif vgg_v == 16:
            model = torchvision.models.vgg16(pretrained=pretrained)
        else:
            raise ValueError(
                'Unsupported vgg_v={!r}; expected 16 or 19.'.format(vgg_v))

        if self.use_input_norm:
            # Per-channel statistics, shaped to broadcast over (N, 3, H, W).
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        self.feature_layers_num = len(layer_ids)
        # Materialise the layer list once instead of once per stage.
        layers = list(model.features.children())
        prev_feature_layer = 0
        for i, feature_layer in enumerate(layer_ids):
            stage = nn.Sequential(*layers[prev_feature_layer:feature_layer + 1])
            for param in stage.parameters():
                param.requires_grad = False
            # Attribute names 'fe0', 'fe1', ... are part of the state_dict layout.
            setattr(self, 'fe{}'.format(i), stage)
            prev_feature_layer = feature_layer + 1

    def forward(self, x):
        """Return the list of stage outputs, one tensor per entry of ``layer_ids``.

        The input is always rescaled with ``0.5 * x + 0.5``, i.e. it is
        expected in [-1, 1] and mapped to [0, 1] before the optional
        mean/std normalisation. An input already in [0, 1] would be
        squeezed into [0.5, 1].
        """
        x = 0.5*x + 0.5
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = []
        for i in range(self.feature_layers_num):
            # Each stage consumes the previous stage's output.
            x = getattr(self, 'fe{}'.format(i))(x)
            output.append(x)
        return output


    
class VggPrior(nn.Module):
    """Frozen VGG-16 feature prior producing one pooled vector per image.

    The last output of a ``VGGFeatureExtractor`` (``layer_ids=(18, 31)``,
    ``vgg_v=16``, no batch norm) is averaged down to 1x1 spatially and
    flattened to ``(batch, -1)``.

    The constructor arguments ``img_size``, ``in_ch``, ``nz``, ``hdim`` and
    ``fc_arch`` are accepted but not used by this class.
    """

    def __init__(self, img_size=128, in_ch=3, nz=128, hdim=512, fc_arch='shallow'):
        super().__init__()

        self.feature = VGGFeatureExtractor(layer_ids=(18, 31), use_bn=False, vgg_v=16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # The backbone is a fixed prior: keep every weight out of autograd.
        for frozen in self.feature.parameters():
            frozen.requires_grad = False

    def forward(self, x, style=False):
        """Return pooled features for ``x``.

        Args:
            x: Image batch passed straight to the feature extractor.
            style: When true, blend each sample's pooled features with those
                of a randomly permuted batch partner using a single
                ``lam ~ Beta(1, 1)`` drawn for the whole batch.

        Returns:
            Tensor reshaped to ``(batch, -1)``.
        """
        pooled = self.pool(self.feature(x)[-1])
        if style:
            lam = np.random.beta(1, 1)
            perm = torch.randperm(pooled.shape[0])
            pooled = lam*pooled + (1-lam)*pooled[perm]
        return pooled.view(pooled.shape[0], -1)

    

class Resnet50x1Prior(nn.Module):
    """Frozen ResNet-50 (1x width) feature prior.

    Weights are read from ``./pretrained_models/resnet50-1x.pth`` (see
    https://github.com/tonylins/simclr-converter). The backbone is kept in
    eval mode with ``requires_grad`` disabled on all of its parameters.

    The constructor arguments ``img_size``, ``in_ch``, ``nz``, ``hdim`` and
    ``fc_arch`` are accepted but not used by this class.
    """

    def __init__(self, img_size=128, in_ch=3, nz=128, hdim=512, fc_arch='shallow'):
        super().__init__()
        #https://github.com/tonylins/simclr-converter
        self.feature = resnetw.resnet50x1()
        # map_location keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies into the model.
        checkpoint = torch.load('./pretrained_models/resnet50-1x.pth',
                                map_location='cpu')
        self.feature.load_state_dict(checkpoint['state_dict'])
        self.feature.eval()
        # Not used by forward(); kept so the attribute stays available.
        self.pool = nn.AvgPool2d(4)
        for param in self.feature.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set the training flag but always leave the frozen backbone in eval mode.

        Without this, calling ``.train()`` on this module (or any parent)
        would silently undo the ``eval()`` done in ``__init__``.
        """
        super().train(mode)
        self.feature.eval()
        return self

    def forward(self, x, style=False):
        """Return backbone features for ``x``.

        Args:
            x: Image batch; it is rescaled with ``0.5 * x + 0.5`` before the
                backbone, i.e. it is expected in [-1, 1].
            style: When true, blend each sample's features with those of a
                randomly permuted batch partner using a single
                ``lam ~ Beta(1, 1)`` and reshape the result to
                ``(batch, -1)``.

        Returns:
            The backbone output as-is when ``style`` is false, otherwise the
            mixed features reshaped to ``(batch, -1)``.
        """
        x = 0.5*x + 0.5
        out = self.feature(x)
        if style:
            lam = np.random.beta(1,1)
            idx = torch.randperm(out.shape[0])
            out = lam*out + (1-lam)*out[idx]
            out = out.view(out.shape[0],-1)
        # Previously the style branch fell off the end and returned None.
        return out


    
class VggStylePrior(nn.Module):
    """Frozen VGG-16 prior that always returns batch-mixed pooled features.

    The last output of a ``VGGFeatureExtractor`` (``layer_ids=(18, 31)``,
    ``vgg_v=16``, no batch norm) is averaged down to 1x1 spatially, blended
    with a random permutation of itself along the batch axis and flattened.

    The constructor arguments ``img_size``, ``in_ch``, ``nz``, ``hdim`` and
    ``fc_arch`` are accepted but not used by this class.
    """

    def __init__(self, img_size=128, in_ch=3, nz=128, hdim=512, fc_arch='shallow'):
        super().__init__()

        self.feature = VGGFeatureExtractor(layer_ids=(18, 31), use_bn=False, vgg_v=16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # nn.AvgPool2d(4)
        # The backbone is a fixed prior: keep every weight out of autograd.
        for frozen in self.feature.parameters():
            frozen.requires_grad = False

    def forward(self, x1):
        """Return mixed pooled features reshaped to ``(batch, -1)``.

        A single ``lam ~ Beta(1, 1)`` is drawn per call and used to form
        ``lam * f + (1 - lam) * f[perm]`` where ``perm`` is a random
        permutation of the batch indices.
        """
        lam = np.random.beta(1, 1)
        last_stage = self.feature(x1)[-1]
        pooled = self.pool(last_stage)
        perm = torch.randperm(pooled.shape[0])
        mixed = lam*pooled + (1-lam)*pooled[perm]
        return mixed.view(mixed.shape[0], -1)

class Resnet50x1StylePrior(nn.Module):
    """Frozen ResNet-50 (1x width) prior that always returns batch-mixed features.

    Weights are read from ``./pretrained_models/resnet50-1x.pth``. The
    backbone is kept in eval mode with ``requires_grad`` disabled on all of
    its parameters.

    The constructor arguments ``img_size``, ``in_ch``, ``nz``, ``hdim`` and
    ``fc_arch`` are accepted but not used by this class.
    """

    def __init__(self, img_size=128, in_ch=3, nz=128, hdim=512, fc_arch='shallow'):
        super().__init__()

        self.feature = resnetw.resnet50x1()
        # map_location keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies into the model.
        checkpoint = torch.load('./pretrained_models/resnet50-1x.pth',
                                map_location='cpu')
        self.feature.load_state_dict(checkpoint['state_dict'])
        self.feature.eval()
        # Not used by forward(); kept so the attribute stays available.
        self.pool = nn.AvgPool2d(4)
        for param in self.feature.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set the training flag but always leave the frozen backbone in eval mode.

        Without this, calling ``.train()`` on this module (or any parent)
        would silently undo the ``eval()`` done in ``__init__``.
        """
        super().train(mode)
        self.feature.eval()
        return self

    def forward(self, x):
        """Return mixed backbone features reshaped to ``(batch, -1)``.

        The input is rescaled with ``0.5 * x + 0.5`` (i.e. expected in
        [-1, 1]) before the backbone. A single ``lam ~ Beta(1, 1)`` is drawn
        per call and used to form ``lam * f + (1 - lam) * f[perm]`` where
        ``perm`` is a random permutation of the batch indices.
        """
        x = 0.5*x + 0.5
        out = self.feature(x)

        lam = np.random.beta(1,1)
        idx = torch.randperm(out.shape[0])
        out = lam*out + (1-lam)*out[idx]
        out = out.view(out.shape[0],-1)
        return out

