import os
import math
import cfg

import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T
import random
import numpy as np
from utils import *
from model.model import *
from PIL import Image

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Compute device shared by the whole script: the default CUDA device when one
# is visible to torch, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _build_train_loader(args):
    """Build the shuffled training DataLoader selected by ``args.dataset``.

    Also sets ``args.num_channel`` to the channel count of the chosen dataset
    (1 for MNIST / Fashion-MNIST, 3 for CIFAR-10 / CelebA).

    Args:
        args: parsed config; reads ``dataset``, ``img_size`` and ``d_batch_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True`` and
        ``drop_last=True``. Images are normalized from [0, 1] to [-1, 1].

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ('mnist', 'f-mnist', 'cifar10', 'celeba')
    if args.dataset not in supported:
        # Previously an unknown name left the loader unbound and crashed later
        # with an UnboundLocalError; fail early with an explicit message instead.
        raise ValueError('Unsupported dataset {!r}; expected one of: {}'.format(
            args.dataset, ', '.join(supported)))

    if args.dataset == 'celeba':
        # A (h, w) tuple forces an exact img_size x img_size output; an int
        # would only resize the shorter side.
        resize = T.Resize((args.img_size, args.img_size))
    else:
        resize = T.Resize(args.img_size)

    if args.dataset in ('mnist', 'f-mnist'):
        args.num_channel = 1
        normalize = T.Normalize(mean=[0.5], std=[0.5])
    else:
        # CIFAR-10 is RGB and ImageFolder's default loader converts to RGB, so
        # set the channel count explicitly, as the grayscale branch does.
        args.num_channel = 3
        normalize = T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    transform = T.Compose([resize, T.ToTensor(), normalize])

    if args.dataset == 'mnist':
        train_set = datasets.MNIST('/data1/dataset/mnist', True, transform, download=True)
    elif args.dataset == 'f-mnist':
        train_set = datasets.FashionMNIST('/data1/dataset/fashion-mnist', True, transform, download=True)
    elif args.dataset == 'cifar10':
        train_set = datasets.CIFAR10('/data1/dataset/cifar10', True, transform, download=True)
    else:  # 'celeba'
        train_set = datasets.ImageFolder('/data1/dataset/celebA', transform)

    return torch.utils.data.DataLoader(
        train_set, batch_size=args.d_batch_size, shuffle=True, drop_last=True)


def _save_disc_histogram(counts, n_disc, save_path):
    """Save a 12x4 bar chart of ``counts`` indexed by discriminator ID 1..n_disc."""
    disc_ids = np.arange(1, n_disc + 1)
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(disc_ids, counts)
    ax.set_xlabel("Discriminator ID")
    ax.set_ylabel("Counts")
    ax.set_xticks(disc_ids)
    fig.savefig(save_path)
    plt.close(fig)


def main():
    """Train the ``mclgan`` model configured by ``cfg.parse_args()``.

    Each epoch does the following:
      * runs one ``feed_data`` / ``optimize_parameters`` step per batch;
      * accumulates the losses reported by ``model.get_loss()``;
      * saves a sample grid to ``<log_dir_root>/result/<epoch>_samples.jpg``;
      * prints the per-batch average losses;
      * saves the model to ``<log_dir_root>/trained_model``;
      * applies the epoch, learning-rate and KLD schedule hooks;
      * writes bar charts of ``model.hist_fake`` and ``model.hist_real`` under
        ``log_dir_root``.

    Raises:
        ValueError: for an unsupported dataset, or when the training set is
            too small to yield a single full batch.
    """
    args = cfg.parse_args()
    if args.manual_seed is None:
        args.manual_seed = 4201
    # Seed every RNG the script can touch; numpy was previously left unseeded.
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    log_dir_root = args.log_dir_root
    img_dir = os.path.join(log_dir_root, 'result')
    model_dir = os.path.join(log_dir_root, 'trained_model')
    for path in (log_dir_root, img_dir, model_dir):
        os.makedirs(path, exist_ok=True)

    train_loader = _build_train_loader(args)
    n_batches = len(train_loader)
    if n_batches == 0:
        # drop_last=True discards a short final batch, so a dataset smaller than
        # one batch would train on nothing and average losses over zero batches.
        raise ValueError('Training set yields no full batch of size {}'.format(args.d_batch_size))

    model = mclgan(device, args)
    model.train()
    print(model.netD)
    for epoch in range(args.n_epoch):
        # Running sums of the 5 discriminator and 4 generator loss terms;
        # the order matches get_loss()[0] and get_loss()[1].
        d_sums = [0.0] * 5
        g_sums = [0.0] * 4

        for x, y in train_loader:
            model.feed_data(x, y)
            model.optimize_parameters()

            tmp_loss = model.get_loss()
            loss_d, loss_g = tmp_loss[0], tmp_loss[1]
            d_sums = [s + loss_d[k] for k, s in enumerate(d_sums)]
            g_sums = [s + loss_g[k] for k, s in enumerate(g_sums)]

        # Visualize once per epoch, after the final batch has been processed.
        image_grid = tensor2np_img(model.get_visuals())
        save_np_img(image_grid, os.path.join(img_dir, '{}_samples.jpg'.format(epoch + 1)))

        avg_loss_D = np.array(d_sums) / n_batches
        avg_loss_G = np.array(g_sums) / n_batches
        print('Epoch[{}] Loss D sum: {:.3f}, D real: {:.3f}, D xfake: {:.3f}, D gfake: {:.3f}, D kld: {:.3f} / Loss G sum: {:.3f}, G exp: {:.3f}, G nonexpcls: {:.3f}, G kld: {:.3f}'.format(
            epoch + 1, *avg_loss_D, *avg_loss_G))

        print('===End of Epoch {}==='.format(epoch + 1))
        model.save_model(epoch + 1, model_dir)
        print('Saved the models!')
        model.update_epoch()
        if (epoch + 1) % args.lr_update_freq == 0:
            model.update_learning_rate()
        if (epoch + 1) % args.kld_update_freq == 0:
            model.update_kld()

        # Per-discriminator counts kept by the model; overwritten every epoch.
        _save_disc_histogram(model.hist_fake, model.n_disc, os.path.join(log_dir_root, 'fake_stat'))
        _save_disc_histogram(model.hist_real, model.n_disc, os.path.join(log_dir_root, 'real_stat'))
    
if __name__ == "__main__":
    # Start training only when executed as a script, never on import.
    main()
