#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul  8 10:22:24 2020

@author: zw
"""
import argparse
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms
from torch.utils.data.distributed import DistributedSampler

import numpy as np
import glob
import time

#from data_loader import Rescale
from data_loader import Randomflip
from data_loader import RescaleT
from data_loader import RandomCrop
#from data_loader import CenterCrop
#from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import CObjDataset

from model import BaseFG, BaseFGM

import pytorch_ssim
import pytorch_iou
import os
from PIL import Image
from skimage import io, transform
import cv2
import random

import wandb

# Command-line configuration of the training script.
parser = argparse.ArgumentParser(description='PyTorch')
parser.add_argument('--pretrain', action='store_true', default=False, help='use pretrain encoding')
parser.add_argument('--seed', action='store_true', default=False, help='hold seed')
parser.add_argument('--seedvalue', default=256, type=int,
                    help='seed used when --seed is given (default: 256)')
parser.add_argument('--gpuname', default='0', type=str,
                    help='value exported as CUDA_VISIBLE_DEVICES (default: 0)')
parser.add_argument('--batch_size_train', type=int, default=12, metavar='N', help='input batch size for training')
parser.add_argument('--accumulation_steps', type=int, default=4,
                    help='number of batches between optimizer steps (default: 4)')
# The help text used to say "training" for the validation batch size.
parser.add_argument('--batch_size_val', type=int, default=1, metavar='N', help='input batch size for validation')
parser.add_argument('--epoch_num', type=int, default=1000,
                    help='number of training epochs (default: 1000)')
parser.add_argument('--check_ite', type=int, default=4000,
                    help='save a checkpoint every this many iterations (default: 4000)')
parser.add_argument('--begin_ite', type=int, default=16000,
                    help='only save checkpoints after this iteration (default: 16000)')
parser.add_argument('--step_size', type=int, default=30,
                    help='step_size passed to the StepLR scheduler (default: 30)')
# The help text used to advertise 0.001 although the real default is 0.0001.
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')

# --noFG multiplies the (conx + order) loss terms; it is a weight, not a learning rate.
parser.add_argument('--noFG', type=float, default=1, metavar='W',
                    help='weight of the auxiliary order + conx losses (default: 1)')
parser.add_argument('--trainset', default='TrainDataset', type=str,
                    help='training-set sub-directory of the dataset folder (default: TrainDataset)')
parser.add_argument('--level', default='E', type=str,
                    help='suffix appended to "dataset/FGDataset" (default: E)')

parser.add_argument('--tag', default='0', type=str,
                    help='suffix appended to the "model_save" output directory (default: 0)')

args = parser.parse_args()

# Keep cuDNN switched on.
torch.backends.cudnn.enabled = True

# Restrict the process to the GPU id(s) given on the command line.
train_n = args.gpuname
os.environ["CUDA_VISIBLE_DEVICES"] = train_n


# Start the wandb run used for logging.
wandb.init(project="My ToyFG")
# NOTE(review): presumably resets a wandb-internal "watch" guard — confirm it is still needed.
wandb.watch_called = False


if not args.seed:
    # No reproducibility requested: let cuDNN benchmark its algorithms.
    torch.backends.cudnn.benchmark = True
else:
    def setup_seed(seed):
        """Seed the Python, NumPy and torch (CPU + all GPUs) RNGs and make cuDNN deterministic."""
        for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
            seed_fn(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    setup_seed(args.seedvalue)

# about test
def eval_mae(y_pred, y):
    """
    Evaluate the absolute error of a batch (for test or validation phase).

    Note that the result is NOT a plain mean: the element-wise mean absolute
    error is multiplied by the batch size (the first dimension of ``y_pred``).

    :param y_pred: predicted tensor of any rank >= 1; dimension 0 is the batch
    :param y: target tensor, subtracted from ``y_pred``
    :return: scalar tensor ``mean(|y_pred - y|) * N`` with ``N = y_pred.size(0)``
    """
    # Only the batch dimension is needed, so do not insist on a 4-D input.
    batch_size = y_pred.size(0)

    return torch.abs(y_pred - y).mean() * batch_size

def normPRED(d):
    """
    Min-max normalise a tensor.

    :param d: input tensor
    :return: ``(d - min(d)) / (max(d) - min(d))``; a zero tensor shaped like
        ``d`` when all elements of ``d`` are equal
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi

    # A constant input would otherwise evaluate 0 / 0 and fill the result with NaN.
    if span == 0:
        return torch.zeros_like(d)

    return (d - mi) / span

def save_out(pred_list, d_dir):
    """
    Write every 2-D slice of a list of prediction stacks to a PNG file.

    The stacks are joined along axis 0 and slice ``i`` of the result is
    written to ``d_dir + str(i) + '.png'`` after being multiplied by 255.

    :param pred_list: sequence of 3-D array-likes (k, w, h) with matching
        trailing dimensions; an empty sequence writes nothing
    :param d_dir: path prefix of the output files (used verbatim)
    :return: None
    :raises ValueError: if the joined predictions are not 3-D
    """
    # Nothing to save; indexing pred_list[0] used to raise IndexError here.
    if len(pred_list) == 0:
        return

    # One concatenate instead of re-copying the growing array per element.
    # float64 matches the dtype the zero-seeded accumulation used to produce.
    stacked = np.concatenate(pred_list, axis=0).astype(np.float64)
    if stacked.ndim != 3:
        raise ValueError(
            "save_out expects 3-D predictions, got %d-D data" % stacked.ndim)

    for i, frame in enumerate(stacked):
        cv2.imwrite(d_dir + str(i) + '.png', frame * 255)
        
# ------- 1. define loss function --------

# Loss criteria used by the training code.
# `reduction='mean'` is the current spelling of the deprecated
# `size_average=True` argument of the torch.nn losses (same averaging).
bce_loss = nn.BCELoss(reduction='mean')
mse_loss = nn.MSELoss(reduction='mean')
ce_loss = nn.CrossEntropyLoss(reduction='mean')
# pytorch_ssim / pytorch_iou are not torch.nn modules; their own
# `size_average` keyword is left untouched.
ssim_loss = pytorch_ssim.SSIM(window_size=11, size_average=True)
iou_loss = pytorch_iou.IOU(size_average=True)

def loss_list_main(D, labels_v):
    """
    Sum the BCE loss of every prediction in ``D`` against one target.

    All predictions are weighted equally (weight 1.0), independent of how
    many there are.

    :param D: non-empty indexable collection of predictions
    :param labels_v: target handed to ``bce_loss_list`` for every prediction
    :return: tuple ``(loss0, loss)`` where ``loss0`` is the loss of ``D[0]``
        alone and ``loss`` is the sum of the losses of all entries of ``D``
    :raises IndexError: if ``D`` is empty
    """
    # Loss of the first prediction: returned on its own and reused as the
    # first term of the sum instead of being evaluated a second time.
    loss0 = bce_loss_list(D[0], labels_v)

    loss = loss0
    for i in range(1, len(D)):
        loss = loss + bce_loss_list(D[i], labels_v)

    return loss0, loss

def bce_loss_list(pred, target):
    """
    Apply the module-level ``bce_loss`` criterion to one prediction.

    :param pred: prediction passed as first argument to ``bce_loss``
    :param target: target passed as second argument to ``bce_loss``
    :return: the value returned by ``bce_loss(pred, target)``
    """
    return bce_loss(pred, target)

# ------- 2. set the directory of training dataset --------

# All paths are relative to the current working directory.
basedir = os.getcwd()
data_dir = basedir + '/dataset/FGDataset' + args.level +'/' + args.trainset
tra_image_dir = '/Imgs/'
tra_label_dir = '/GT/'
image_ext = '.jpg'
label_ext = '.png'
model_dir = basedir + '/model_save' + args.tag + '/'
# Checkpoints are written into model_dir much later in the run; create it now
# so a missing directory cannot abort training at the first save.
os.makedirs(model_dir, exist_ok=True)

train_num = 0
val_num = 0

# glob returns matches in arbitrary order; sorting makes the sample list
# (and therefore a seeded shuffle) identical from run to run.
tra_img_name_list = sorted(glob.glob(data_dir + tra_image_dir + '*' + image_ext))

# Each label lives in the GT folder under the image's file name with the
# image extension replaced by the label extension.
tra_lbl_name_list = []
for img_path in tra_img_name_list:
    imidx = os.path.splitext(os.path.basename(img_path))[0]
    tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")

#train data load
train_num = len(tra_img_name_list)

# Samples are rescaled with RescaleT(224) and converted by ToTensorLab;
# no edge-map list is supplied.
codobj_dataset = CObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    edg_name_list=[],
    transform=transforms.Compose([
        RescaleT(224),
        ToTensorLab(flag=0, state='FG')]))
# drop_last=True: an incomplete final batch is skipped.
codobj_dataloader = DataLoader(codobj_dataset, batch_size=args.batch_size_train, shuffle=True, num_workers=8, drop_last=True)

# ------- 3. define model --------

# define the net
# Build the network; --pretrain is forwarded to the model constructor.
net = BaseFGM(pretrain=args.pretrain)

# Move the parameters to the GPU when one is visible.
if torch.cuda.is_available():
    net = net.cuda()

# ------- 4. define optimizer --------

print("---define optimizer...")
# AdamW without weight decay.
optimizer = optim.AdamW(
    net.parameters(),
    lr=args.lr,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
)
# Multiplies the learning rate by 0.1 every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.step_size,
    gamma=0.1,
)

# ------- 5. training process --------

print("---start training...")
ite_num = 0            # iterations since the start of training
running_loss = 0.0     # sum of total losses since the last checkpoint reset
running_main_loss = 0.0  # sum of first-output main losses since the last reset
ite_num4val = 0        # iterations since the last checkpoint reset

# Resolve the device once instead of re-testing CUDA availability per tensor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_batches = len(codobj_dataloader)

# NOTE(review): `scheduler` is created before this loop but never stepped, so
# the learning rate stays at args.lr for the whole run — confirm whether a
# per-epoch scheduler.step() was intended.

time_temp = time.time()
print("[Train Begin: %s]" % time.ctime())
optimizer.zero_grad()
for epoch in range(0, args.epoch_num):
    net.train()

    for i, data in enumerate(codobj_dataloader):
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1

        # Image, main label and the two auxiliary targets ('order', 'conx'),
        # cast to float32 and moved to the training device.
        inputs_v = data['image'].float().to(device)
        labels_v = data['label'].float().to(device)
        orders_v = data['order'].float().to(device)
        conxs_v = data['conx'].float().to(device)

        # D[0], D[1], D[2] are scored against label, order and conx targets.
        D = net(inputs_v)
        lossD, loss_main = loss_list_main(D[0], labels_v)

        lossD_order, loss_order = loss_list_main(D[1], orders_v)
        lossD_conx, loss_conx = loss_list_main(D[2], conxs_v)

        # args.noFG weights the auxiliary terms against the main loss.
        loss = loss_main + args.noFG * (loss_conx  + loss_order)

        loss.backward()

        # Gradient accumulation: update once per `accumulation_steps` batches.
        # Counting from (i + 1) avoids stepping after the very first batch,
        # and the epoch-end flush applies the gradients of a trailing partial
        # group instead of throwing them away.
        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # # print statistics
        plot_main_loss = lossD.item()
        running_loss += loss.item()
        running_main_loss += plot_main_loss
        # Latest-batch values only (not accumulated).
        running_order_loss = loss_order.item()
        running_conx_loss = loss_conx.item()

        # del temporary outputs and loss
        del D, lossD, lossD_order, lossD_conx, loss_main, loss_order, loss_conx, loss

        if i % 200 == 0:
            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train: %.1f, main: %.4f, order: %.3f, conx: %.3f " % (
                epoch + 1, args.epoch_num, (i + 1) * args.batch_size_train, train_num, ite_num, running_loss/ite_num4val, running_main_loss/ite_num4val, running_order_loss, running_conx_loss))
            print("[Time Use: %2f]" % (time.time() - time_temp))
            time_temp = time.time()

        # Checkpoint every `check_ite` iterations once `begin_ite` is passed;
        # the running averages restart after each checkpoint.
        if ite_num % args.check_ite == 0 and ite_num > args.begin_ite:

            print("Testing, Saving Model and Saving Image...")
            torch.save(net.state_dict(), model_dir + "Model_%d_loss_%.4f.pth" % (ite_num, running_main_loss / ite_num4val))
            running_loss = 0.0
            running_main_loss = 0.0
            net.train()  # resume train
            ite_num4val = 0

    # One log entry per epoch: main loss of the epoch's last batch.
    wandb.log({
                "Train_Loss": plot_main_loss})

print('-------------Congratulations! Training Done!!!-------------')