from sklearn.metrics import roc_curve, auc
import argparse
import os
import shutil
import time
import random
from sklearn.metrics import roc_auc_score
import torch
import shutil
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import torchvision.utils as vutils
import math
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import Variable
from model import *
from torchvision.utils import save_image
from data import *
import torch.nn.functional as F
from hard_aug import *
import sys
from torch.utils.data import DataLoader

# Tag naming this run's output directory. `cls` is not defined in this file;
# it is expected to be provided by one of the star imports above.
ki = cls+'_1'

# Create the per-run output directory. exist_ok=True replaces the previous
# "check, then create" sequence, which could race with another process and
# silently accepted a regular file sitting at the target path. A failure is
# reported but deliberately kept non-fatal, as before.
_output_dir = 'output_test_{}'.format(ki)
try:
    os.makedirs(_output_dir, exist_ok=True)
except OSError:
    print('Error: Creating directory. ' + _output_dir)

# Constructor alias for CUDA float tensors (used below to build label tensors).
Tensor = torch.cuda.FloatTensor
# Base learning rate; the optimizers below are configured from it.
lr = 0.0001
# Mutable holder for the learning rate.
state = {'lr': lr}

def adjust_learning_rate(optimizer, epoch, milestones=(150, 180), gamma=0.1):
    """Apply step learning-rate decay to ``optimizer`` at milestone epochs.

    When ``epoch`` is one of ``milestones``, every parameter group's current
    learning rate is multiplied by ``gamma``. The decay is relative to each
    group's own rate, so the function can be called for several optimizers
    in the same epoch: each one is decayed exactly once and optimizers that
    started from different base rates keep their ratio.

    The earlier implementation routed the decay through a shared
    module-level dict, which decayed the second optimizer adjusted in an
    epoch twice and forced every optimizer onto the same rate. This version
    neither reads nor writes module-level state.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            that each carry an ``'lr'`` entry (e.g. a ``torch.optim``
            optimizer).
        epoch: Index of the epoch that is about to run.
        milestones: Epoch indices at which the decay is applied.
        gamma: Multiplicative decay factor applied at each milestone.

    Returns:
        None. ``optimizer.param_groups`` is modified in place.
    """
    if epoch not in milestones:
        return
    for param_group in optimizer.param_groups:
        param_group['lr'] *= gamma

def segmentation_auroc(mask, anomaly_maps):
    """Return the pixel-level ROC AUC of ``anomaly_maps`` against ``mask``.

    Both arguments are flattened before scoring, so only their total number
    of elements has to match.

    Args:
        mask: NumPy array of ground-truth labels; it is cast to ``int64``
            before being used as the target.
        anomaly_maps: Array of per-pixel scores (anything with ``reshape``),
            where a larger value means "more likely positive".

    Returns:
        The value computed by ``sklearn.metrics.roc_auc_score``.
    """
    # Imported here so the function does not depend on ``np`` leaking into
    # the module namespace through one of the file's star imports.
    import numpy as np

    gt = mask.astype(np.int64)
    return roc_auc_score(gt.reshape(-1), anomaly_maps.reshape(-1))

# Build the generator and discriminator and move them to the GPU.
# `Generator` and `get_disca` are not defined in this file (star imports).
gan = Generator().cuda()
dis = get_disca(4).cuda()

# Replicate both networks with DataParallel. Devices 0 and 1 are used as
# before when at least two GPUs are visible; with a single visible GPU the
# list shrinks to [0] instead of failing on the missing device 1.
gps = list(range(min(2, torch.cuda.device_count())))
gan = torch.nn.DataParallel(gan,  device_ids=gps)
dis = torch.nn.DataParallel(dis, device_ids=gps)

# Loss functions. NOTE: some names do not match the loss they hold
# (`criterion_bce` is an MSE loss); they are kept because the training loop
# refers to them by these names.
criterion_c = torch.nn.BCELoss(reduction='none').cuda()  # per-element BCE
# reduction='mean' is the non-deprecated spelling of size_average=True.
criterion_ce = torch.nn.BCEWithLogitsLoss(reduction='mean').cuda()
criterion_bce = torch.nn.MSELoss().cuda()
criterion_ae = nn.BCELoss().cuda()  # mean-reduced BCE
criterio = torch.nn.MSELoss(reduction='none').cuda()

# The discriminator is optimised with a 10x smaller learning rate than the
# generator.
optimizer_a = optim.Adam(gan.parameters(), lr=lr, betas=(0.9, 0.999))
optimizer_d = optim.Adam(dis.parameters(), lr=lr/10, betas=(0.9, 0.999))

batch_size = 8

use_cuda = True
# Dataset classes are not defined in this file (star imports).
train_dataset = MVTecDataset(sss = 256, resize = 256)
val_dataset = MVTecDataset_test(sss = 256, resize = 256,is_train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=16)
# Best validation pixel-AUROC seen so far.
best_score = 0
#gan.load_state_dict(torch.load('cls_init_{}2.pth'.format(clss)), strict=False)
bbbb = 0

# Epoch index of the best validation score. It is bound before the loop so
# the progress print at the end of each epoch cannot hit an unbound name
# when no epoch has beaten the initial best_score yet (-1 means "none yet").
bep = -1

# Main loop: one pass over the training set followed by a pixel-level AUROC
# evaluation on the validation set, for 300 epochs.
for i in range(300):

    # Apply the epoch-based learning-rate schedule to both optimizers.
    adjust_learning_rate(optimizer_d, i)
    adjust_learning_rate(optimizer_a, i)

    # Running sums of the two losses reported in the progress print.
    train_anomaly_segloss = 0
    train_masked_reloss = 0

    for batch_idx, (inputs, targets, _) in enumerate(train_loader):

        inputs = inputs.cuda()
        inputs = torch.autograd.Variable(inputs)
        optimizer_d.zero_grad()
        # Synthesise anomalous samples from the batch. The helper is not
        # defined in this file; x_Ano is the anomalous input used below and
        # A_Ano is used as its extra (mask) channel. x_Ano1 is unused here.
        x_Ano1, x_Ano, A_Ano = Anomaly_generation1(inputs)

        ### x_seg^nor adversarial learning
        # Discriminator step, normal images. `gan` returns two tensors: the
        # first (mmm) is appended as an extra channel to the second (out).
        l1 = inputs
        mmm, out = gan(l1)

        # "Real" samples: the inputs with an all-ones extra channel.
        # "Fake" samples: the generator output with mmm as extra channel.
        l1 = torch.cat([l1, torch.ones([l1.shape[0], 1, l1.shape[2], l1.shape[3]]).cuda()], dim=1)
        out = torch.cat([out, mmm], dim=1)

        # Stack real and fake along the batch axis and shuffle them; the
        # same permutation is applied to the labels below.
        innp = torch.cat([l1, out], dim=0)
        inxs = torch.randperm(innp.size()[0])
        innp = innp[inxs]

        real_dis, out_dis = dis(innp)
        # NOTE(review): loss_mse12 is computed but not used by any loss below.
        loss_mse12 = criterion_bce(out_dis[:l1.shape[0]], out_dis[l1.shape[0]:])

        # Labels: 1 for the real half, 0 for the generated half.
        real_la = Variable(Tensor(real_dis.shape[0] // 2, 1).fill_(1.0), requires_grad=False)
        fake_la = Variable(Tensor(real_dis.shape[0] // 2, 1).fill_(0.0), requires_grad=False)

        ttar = torch.cat([real_la, fake_la], dim=0)[inxs]

        loss_dv = criterion_ce(real_dis, ttar)
        loss_dv.backward(retain_graph=True)

        ### x_seg^ano adversarial learning
        # Discriminator step, synthetic anomalies: same scheme as above, but
        # the "real" extra channel is A_Ano instead of all ones.
        l1 = x_Ano
        mmm, out_ano = gan(l1)
        ano_dec_l1 = l1

        l1 = torch.cat([l1, A_Ano], dim=1)
        out_ano = torch.cat([out_ano, mmm], dim=1)

        inp1 = torch.cat([l1, out_ano], dim=0)

        inxs = torch.randperm(inp1.size()[0])
        inp1 = inp1[inxs]

        ano1_fake_dis, _ = dis(inp1)

        real_la = Variable(Tensor(ano1_fake_dis.shape[0] // 2, 1).fill_(1.0), requires_grad=False)
        fake_la = Variable(Tensor(ano1_fake_dis.shape[0] // 2, 1).fill_(0.0), requires_grad=False)

        ttar = torch.cat([real_la, fake_la], dim=0)[inxs]

        adver1_loss_f = criterion_ce(ano1_fake_dis, ttar)  # * 0.1

        # Gradients of both discriminator losses are accumulated before the
        # single discriminator update.
        adver1_loss_f.backward()
        optimizer_d.step()

        ### generator adversarial learning
        # Generator step on a shuffled mix of normal and anomalous inputs.
        inp = torch.cat([inputs, x_Ano], dim=0)
        inxs = torch.randperm(inp.size()[0])
        inp = inp[inxs]

        # Per-sample flag after shuffling: 1 = normal input, 0 = anomalous.
        real_la = Variable(Tensor(inputs.shape[0]).fill_(1.0), requires_grad=False)
        fake_la = Variable(Tensor(x_Ano.shape[0]).fill_(0.0), requires_grad=False)

        ttar = torch.cat([real_la, fake_la], dim=0)[inxs]
        l1 = inp
        mmm, out = gan(l1)
        # Target extra channel per sample: all ones for normal inputs, A_Ano
        # for anomalous ones, shuffled with the same permutation.
        aamm = torch.cat([torch.ones([inputs.shape[0], 1, inputs.shape[2], inputs.shape[3]]).cuda(), A_Ano], dim=0)[inxs]


        ### masked reconstruction loss
        # Anomalous samples: per-element BCE weighted by aamm and normalised
        # by the mask sum. Normal samples: plain mean BCE.
        loss_mse = (criterion_c(out[ttar == 0], l1[ttar == 0]) * aamm[ttar == 0]).mean(1).sum() / aamm[ttar == 0].sum() \
                   + criterion_ae(out[ttar == 1], l1[ttar == 1])


        # Feed [inputs + target channel ; outputs + predicted channel] to the
        # discriminator; the second half holds the generated samples.
        l1 = torch.cat([l1, aamm], dim=1)
        out = torch.cat([out, mmm], dim=1)
        out = torch.cat([l1, out], dim=0)
        fake_dis_gen, out_dis = dis(out)

        # Same construction as aamm above.
        tmmm = torch.cat([torch.ones([inputs.shape[0], 1, inputs.shape[2], inputs.shape[3]]).cuda(), A_Ano], dim=0)[ inxs]

        ### anomaly segmentation loss
        # BCE between mmm and its target channel, computed separately for
        # normal (ttar == 1) and anomalous (ttar == 0) samples.
        loss_a1 = criterion_ae(mmm[ttar == 1], tmmm[ttar == 1]) * 0.1
        loss_a2 = criterion_ae(mmm[ttar == 0], tmmm[ttar == 0]) * 0.1

        # Generator adversarial term: the generated half is scored against
        # the "real" label.
        real_la = Variable(Tensor(l1.shape[0], 1).fill_(1.0), requires_grad=False)
        loss_ae_v = criterion_ce(fake_dis_gen[l1.shape[0]:], real_la)

        loss_ae_all = (loss_mse) * 0.1 + (loss_ae_v) * 0.001 + loss_a1 + loss_a2

        # zero_grad here also discards the generator gradients accumulated
        # during the discriminator backward passes above.
        optimizer_a.zero_grad()
        loss_ae_all.backward()
        optimizer_a.step()


        train_masked_reloss += loss_mse.item()
        train_anomaly_segloss += (loss_a1+loss_a2).item()


        # Progress print every other batch: running means of both losses.
        if batch_idx % 2 == 0:
            print(batch_idx, train_masked_reloss / (batch_idx + 1), train_anomaly_segloss / (batch_idx + 1)
                  )
    # ---- validation: pixel-level AUROC over the whole validation set ----
    gan.eval()
    dis.eval()

    m1 = []  # ground-truth masks, one tensor per batch
    i1 = []  # predicted anomaly maps, one tensor per batch

    for batch_idx, (inputs, targets, mas) in enumerate(valid_loader):

        inputs, mas = inputs.cuda(), mas.cuda()
        X = inputs
        with torch.no_grad():
            # _,a_X = auto(X)
            a_X = X
            mmm, g_X = gan(a_X)

        m1.append(mas.cpu())
        # Anomaly score per pixel is 1 minus the channel-mean of mmm.
        i1.append(1 - mmm.mean(1).unsqueeze(1).cpu().detach())

    m1 = torch.cat(m1, dim=0)
    i1 = torch.cat(i1, dim=0)

    pred_mask = i1.reshape(-1).numpy()
    GT_mask = m1.reshape(-1).numpy()

    score = segmentation_auroc(GT_mask, pred_mask)

    # Track the best score and the epoch at which it was reached.
    if best_score < score:
        best_score = score
        bep = i

    print(bep, '\x1b[1;31mpixel_best :\x1b[1;m', best_score, score)

    # Back to training mode for the next epoch.
    gan.train()
    dis.train()

