import os
import time
from tqdm import tqdm
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torchvision.utils as vutils
import torch.nn.functional as F
import torch
import math
import resnet_base as models

class AnomalyResnet181(nn.Module):
    """Thin wrapper around ``models.resnet18`` used as an encoder.

    Args:
        n: stored on the instance as ``self.n``; not read anywhere in this
            class. NOTE(review): possibly a leftover layer-selection index —
            confirm before relying on it.
        in_c: number of input channels, forwarded to ``models.resnet18``.
    """

    def __init__(self, n=-2, in_c=3):
        super().__init__()
        self.n = n
        self.in_c = in_c
        # Backbone is built without pretrained weights.
        self.resnet18 = models.resnet18(pretrained=False, in_c=in_c)

    def forward(self, x, l4=False):
        """Run the backbone on ``x``.

        ``self.resnet18`` must return exactly two values. The first is always
        returned; when ``l4`` is true both are returned as a tuple.
        NOTE(review): which feature maps these are depends on ``resnet_base``
        (not visible here) — confirm.
        """
        primary, secondary = self.resnet18(x)
        return (primary, secondary) if l4 else primary

class Generator(nn.Module):
    """ResNet-18 encoder followed by two upsampling decoder heads.

    ``forward(x)`` returns ``(decoded1, decoded)``:
      * ``decoded1`` — 1-channel map in [0, 1] (final ``Conv2d(ndf, 1)`` + Sigmoid);
      * ``decoded``  — 3-channel map in [0, 1] (final ``Conv2d(ndf, 3)`` + Sigmoid).

    Shape requirements implied by the layers below (ndf = 128):
      * the encoder's first output must have ``4*ndf = 512`` channels;
      * its second output must have ``2*ndf = 256`` channels and twice the
        spatial size of the first (they are concatenated after one x2 upsample);
      * both heads are 32x the spatial size of the encoder's first output.
    NOTE(review): absolute image sizes depend on ``resnet_base.resnet18``,
    which is not visible here.
    """

    def __init__(self):
        super().__init__()
        self.encoder = AnomalyResnet181(in_c=3, n=-2)

        base = 128

        # Submodules are built in the same order as before so that parameter
        # registration order, state_dict keys and seeded init stay unchanged.
        self.decoder = self._upsample_bottleneck(base * 4)
        self.decoder1 = self._upsample_head(base, out_channels=3)
        self.decodermm2 = self._upsample_bottleneck(base * 4)
        # Refines the encoder's second output at fixed resolution: 2b -> b -> 2b.
        self.decodermm123 = nn.Sequential(
            nn.Conv2d(base * 2, base, 3, stride=1, padding=1),
            nn.BatchNorm2d(base),
            nn.ReLU(),
            nn.Conv2d(base, base * 2, 3, stride=1, padding=1),
        )
        self.decodermm1 = self._upsample_head(base, out_channels=1)

    @staticmethod
    def _upsample_bottleneck(channels):
        """Conv3x3, x2 ConvTranspose + BN + ReLU, Conv3x3 + BN + ReLU; channels preserved."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, stride=1, padding=1),
            nn.ConvTranspose2d(channels, channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    @staticmethod
    def _upsample_head(base, out_channels):
        """Four x2 upsampling stages from 6*base channels, then Conv3x3 -> out_channels + Sigmoid."""
        stages = ((base * 6, base * 2), (base * 2, base), (base, base), (base, base))
        layers = []
        for c_in, c_out in stages:
            layers += [
                # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        layers += [nn.Conv2d(base, out_channels, 3, stride=1, padding=1), nn.Sigmoid()]
        return nn.Sequential(*layers)

    def forward(self, x):
        deep, skip = self.encoder(x, l4=True)

        # 3-channel head: upsample deep features, fuse with the raw skip features.
        out3 = self.decoder1(torch.cat([self.decoder(deep), skip], dim=1))

        # 1-channel head: same pattern, but the skip features are refined first.
        out1 = self.decodermm1(
            torch.cat([self.decodermm2(deep), self.decodermm123(skip)], dim=1)
        )

        return out1, out3


class Generator_test(nn.Module):
    """ResNet-18 encoder followed by a single 1-channel upsampling decoder head.

    ``forward(x)`` returns one tensor: a 1-channel map in [0, 1]
    (final ``Conv2d(ndf, 1)`` + Sigmoid) whose spatial size is 32x that of the
    encoder's first output. The encoder's first output must have 512 channels;
    its second output must have 256 channels and twice the first's spatial size.
    NOTE(review): absolute image sizes depend on ``resnet_base.resnet18``,
    which is not visible here.
    """

    def __init__(self):
        super().__init__()
        self.encoder = AnomalyResnet181(in_c=3, n=-2)

        ndf = 128
        wide = ndf * 4

        # Deep branch: conv, x2 upsample, conv; channels stay at 4*ndf.
        self.decodermm2 = nn.Sequential(
            nn.Conv2d(wide, wide, 3, stride=1, padding=1),
            nn.ConvTranspose2d(wide, wide, 4, stride=2, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.Conv2d(wide, wide, 3, stride=1, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
        )

        # Skip branch: 2*ndf -> ndf -> 2*ndf refinement at fixed resolution.
        self.decodermm123 = nn.Sequential(
            nn.Conv2d(ndf * 2, ndf, 3, stride=1, padding=1),
            nn.BatchNorm2d(ndf),
            nn.ReLU(),
            nn.Conv2d(ndf, ndf * 2, 3, stride=1, padding=1),
        )

        def up_stage(c_in, c_out):
            # ConvTranspose (4/2/1) doubles H and W; the 3x3 conv keeps them.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        # Layers are created left to right, matching the original construction order.
        head = (
            up_stage(ndf * 6, ndf * 2)
            + up_stage(ndf * 2, ndf)
            + up_stage(ndf, ndf)
            + up_stage(ndf, ndf)
            + [nn.Conv2d(ndf, 1, 3, stride=1, padding=1), nn.Sigmoid()]
        )
        self.decodermm1 = nn.Sequential(*head)

    def forward(self, x):
        deep, skip = self.encoder(x, l4=True)
        fused = torch.cat([self.decodermm2(deep), self.decodermm123(skip)], dim=1)
        return self.decodermm1(fused)

class get_disca(nn.Module):
    """Convolutional discriminator returning a score and intermediate features.

    ``forward(input)`` returns ``(score, features)``:
      * ``score`` — output of the head reshaped to ``(-1, 1)``; no sigmoid is
        applied, so values are unbounded;
      * ``features`` — output of the six-layer trunk (``ndf*8 = 512`` channels).

    Every 4x4 / stride-2 / padding-1 conv maps a side length H to floor(H/2),
    so the trunk downsamples by 64. The head halves once more and then applies
    a 2x2 valid conv, so inputs must be at least 256x256
    (e.g. 256x256 -> 4x4 features -> 1x1 score).

    Args:
        inp: number of input channels.
    """

    def __init__(self, inp):
        super().__init__()
        ndf = 64

        def down(c_in, c_out, norm=True):
            # Conv(4, stride 2, pad 1, no bias) [+ BatchNorm] + LeakyReLU(0.2).
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Channel progression of the trunk; the first stage has no BatchNorm.
        widths = [inp, ndf, ndf * 2, ndf * 4, ndf * 8, ndf * 8, ndf * 8]
        trunk = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            trunk += down(c_in, c_out, norm=idx > 0)
        self.main2 = nn.Sequential(*trunk)

        self.main11 = nn.Sequential(
            *down(ndf * 8, ndf * 2),
            # 2x2 valid conv down to a single channel, then average any
            # remaining spatial extent (larger-than-minimum inputs).
            nn.Conv2d(ndf * 2, 1, 2, 1, 0, bias=True),
            nn.AdaptiveAvgPool2d(1),
        )

    def forward(self, input):
        features = self.main2(input)
        score = self.main11(features)
        return score.view(-1, 1), features




def weights_init(mod):
    """DCGAN-style initialiser, intended for use with ``nn.Module.apply``.

    Dispatches on the module's class name:
      * name contains ``'Conv'``      -> ``weight ~ N(0.0, 0.02)``; bias untouched;
      * name contains ``'BatchNorm'`` -> ``weight ~ N(1.0, 0.02)``, ``bias = 0``;
      * anything else is left unchanged.

    A module whose ``weight``/``bias`` is missing or ``None`` is skipped instead
    of raising. Examples are ``BatchNorm2d(affine=False)`` and a container whose
    class name happens to contain ``'Conv'``.

    Args:
        mod: module to initialise in place.
    """
    classname = type(mod).__name__
    weight = getattr(mod, 'weight', None)
    bias = getattr(mod, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)
