import os
import math
import argparse

import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T
import random
import numpy as np
from utils import *
from model.model import *
from PIL import Image

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Compute device handed to the model constructor in main().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Command-line options (parsed once at import time into the module-level `args`)
# ---------------------------------------------------------------------------
p = argparse.ArgumentParser()
p.add_argument('--dataset', '-dataset', default='mnist')
p.add_argument('--log_dir_root', '-o', default='/data1/mclgan_exp/test')
p.add_argument('--manual_seed', '-seed', type=int)
p.add_argument('--n_disc', '-ndisc', type=int, default=10) # number of discriminators
p.add_argument('--n_expert', '-nexp', type=int, default=1) # should be <= n_disc (number of discriminators)
p.add_argument('--gan_type', type=str, default='dcgan') # dcgan / lsgan / hinge
p.add_argument('--d_batch_size', '-d_batch', type=int, default=64)
p.add_argument('--g_batch_size', '-g_batch', type=int, default=128)
p.add_argument('--fixed_z_batch_size', '-fixed_z_batch', type=int, default=100)
p.add_argument('--z_prior', '-zp', type=str, default='n') # u :uniform, n :normal
p.add_argument('--n_epoch', '-nepoch', type=int, default=40)
p.add_argument('--img_size', '-img_size', type=int, default=32) # side length of square image
p.add_argument('--num_channel', '-num_channel', type=int, default=3) # overridden to 1 for mnist / f-mnist in main()
p.add_argument('--d_learning_rate', '-d_lr', type=float, default=0.0002)
p.add_argument('--g_learning_rate', '-g_lr', type=float, default=0.0002)
p.add_argument('--d_weight_decay', '-d_wd', type=float, default=0.0)
p.add_argument('--g_weight_decay', '-g_wd', type=float, default=0.0)
p.add_argument('--beta1', '-beta1', type=float, default=0.5) # for adam optimizer
p.add_argument('--beta2', '-beta2', type=float, default=0.999) # for adam optimizer
p.add_argument('--lr_update_freq', '-lrf', type=int, default=1) # learning rate update frequency, in epochs
p.add_argument('--lr_gamma', '-lrg', type=float, default=0.0)
p.add_argument('--kld_decay', '-kd', type=float, default=0.9) # kld loss weight decay
p.add_argument('--nz', '-nz', type=int, default=100)
p.add_argument('--ndf', '-ndf', type=int, default=128)
p.add_argument('--ngf', '-ngf', type=int, default=128)
p.add_argument('--d_lambda_kld', '-dkld', type=float, default=0.5) # balance loss weight for discriminator (kld loss)
p.add_argument('--g_lambda_kld', '-gkld', type=float, default=0.0) # balance loss weight for generator (kld loss)
p.add_argument('--lambda_ne', '-lambda_ne', type=float, default=1.0) # nonexpert loss weight
p.add_argument('--temperature', '-t', type=float, default=1) # softmax temperature
p.add_argument('--lambda_l1', '-lambda_l1', type=float, default=0.0) # L1 loss weight
p.add_argument('--nonexpert_label', '-ne_label', type=float, default=0.5) # soft label for real data for nonexpert loss
p.add_argument('--kld_update_freq', '-kf', type=int, default=10) # kld loss weight update frequency, in epochs

args = p.parse_args()

# ---------------------------------------------------------------------------
# Reproducibility: seed every global RNG this script (and the star-imported
# helpers) may draw from -- Python's `random`, NumPy, and torch.
# ---------------------------------------------------------------------------
if args.manual_seed is None:
    args.manual_seed = 4201
random.seed(args.manual_seed)
# np.random.seed only accepts values in [0, 2**32); fold the seed into that
# range so a negative or very large --manual_seed does not crash here.
np.random.seed(args.manual_seed % (2 ** 32))
torch.manual_seed(args.manual_seed)

# ---------------------------------------------------------------------------
# Output layout: <log_dir_root>/result for sample grids,
# <log_dir_root>/trained_model for checkpoints.
# ---------------------------------------------------------------------------
log_dir_root = args.log_dir_root
os.makedirs(log_dir_root, exist_ok=True)
img_dir = os.path.join(log_dir_root, 'result')
os.makedirs(img_dir, exist_ok=True)
model_dir = os.path.join(log_dir_root, 'trained_model')
os.makedirs(model_dir, exist_ok=True)


def _build_train_loader():
    """Return the training ``DataLoader`` selected by ``args.dataset``.

    Supported names: 'mnist', 'f-mnist', 'cifar10', 'celeba'. For the two
    grayscale datasets this also sets ``args.num_channel = 1`` as a side effect.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError on the loader).
    """
    if args.dataset == 'mnist':
        dataset = datasets.MNIST('/data1/dataset/mnist', True,
                                 T.Compose([T.Resize(args.img_size),
                                            T.ToTensor(),
                                            T.Normalize(mean=[0.5], std=[0.5]),
                                            ]), download=True)
        args.num_channel = 1

    elif args.dataset == 'f-mnist':
        dataset = datasets.FashionMNIST('/data1/dataset/fashion-mnist', True,
                                        T.Compose([T.Resize(args.img_size),
                                                   T.ToTensor(),
                                                   T.Normalize(mean=[0.5], std=[0.5]),
                                                   ]), download=True)
        args.num_channel = 1

    elif args.dataset == 'cifar10':
        dataset = datasets.CIFAR10('/data1/dataset/cifar10', True,
                                   T.Compose([T.Resize(args.img_size),
                                              T.ToTensor(),
                                              T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                                              ]), download=True)

    elif args.dataset == 'celeba':
        # Resize to an explicit (size, size) pair so the output is square
        # regardless of the source aspect ratio (the int form above only
        # rescales the shorter edge).
        dataset = datasets.ImageFolder('/data1/dataset/celebA',
                                       T.Compose([T.Resize((args.img_size, args.img_size)),
                                                  T.ToTensor(),
                                                  T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                                                  ]))

    else:
        raise ValueError(
            "Unsupported --dataset {!r}; expected one of "
            "'mnist', 'f-mnist', 'cifar10', 'celeba'".format(args.dataset))

    return torch.utils.data.DataLoader(dataset, batch_size=args.d_batch_size,
                                       shuffle=True, drop_last=True)


def _save_discriminator_histograms(model):
    """Write bar charts of ``model.hist_fake`` / ``model.hist_real`` to ``log_dir_root``.

    The files ('fake_stat', 'real_stat', no extension so matplotlib uses its
    default format) are overwritten on every call.
    """
    plt.rcParams["figure.figsize"] = (12, 4)
    disc_ids = np.arange(1, model.n_disc + 1)
    for counts, stem in ((model.hist_fake, 'fake_stat'), (model.hist_real, 'real_stat')):
        plt.bar(disc_ids, counts)
        plt.xlabel("Discriminator ID")
        plt.ylabel("Counts")
        plt.xticks(disc_ids)
        plt.savefig('{}/{}'.format(log_dir_root, stem))
        plt.close()


def main():
    """Train the ``mclgan`` model on ``args.dataset`` for ``args.n_epoch`` epochs.

    Per epoch: optimize on every batch, save a sample grid to ``img_dir``,
    print per-batch-averaged discriminator / generator losses, save the model
    to ``model_dir``, step the epoch / learning-rate / KLD-weight schedules
    (a frequency of 0 disables that schedule), and refresh the
    per-discriminator histograms in ``log_dir_root``.

    Raises:
        ValueError: on an unsupported dataset name, or when the loader yields
            no batches (dataset smaller than one ``--d_batch_size`` batch).
    """
    train_loader = _build_train_loader()
    n_batches = len(train_loader)
    if n_batches == 0:
        # drop_last=True discards a trailing partial batch; with nothing left
        # the loss averages below would divide by zero.
        raise ValueError(
            'Training loader yields no batches: dataset {!r} has fewer than '
            '--d_batch_size={} samples'.format(args.dataset, args.d_batch_size))

    model = mclgan(device, args)
    model.train()

    for epoch in range(args.n_epoch):
        # Running sums. D: [sum, real, xfake, gfake, kld]; G: [sum, exp, nonexp, kld].
        loss_d_total = np.zeros(5)
        loss_g_total = np.zeros(4)

        for x, y in train_loader:
            model.feed_data(x, y)
            model.optimize_parameters()

            tmp_loss = model.get_loss()
            loss_d_total += [float(tmp_loss[0][k]) for k in range(5)]
            loss_g_total += [float(tmp_loss[1][k]) for k in range(4)]

        # Sample grid reflecting the model after the epoch's final update.
        image_grid = tensor2np_img(model.get_visuals())
        save_path_name = os.path.join(img_dir, '{}_samples.jpg'.format(epoch + 1))
        save_np_img(image_grid, save_path_name)

        avg_loss_D = loss_d_total / n_batches
        avg_loss_G = loss_g_total / n_batches
        print('Epoch[{}] Loss D sum: {:.3f}, D real: {:.3f}, D xfake: {:.3f}, '
              'D gfake: {:.3f}, D kld: {:.3f} / Loss G sum: {:.3f}, G exp: {:.3f}, '
              'G nonexpcls: {:.3f}, G kld: {:.3f}'.format(epoch + 1, *avg_loss_D, *avg_loss_G))

        print('===End of Epoch {}==='.format(epoch + 1))
        model.save_model(epoch + 1, model_dir)
        print('Saved the models!')
        model.update_epoch()
        # A frequency of 0 disables the schedule instead of raising ZeroDivisionError.
        if args.lr_update_freq and (epoch + 1) % args.lr_update_freq == 0:
            model.update_learning_rate()
        if args.kld_update_freq and (epoch + 1) % args.kld_update_freq == 0:
            model.update_kld()

        _save_discriminator_histograms(model)
    
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
