import argparse
import os

os.environ['MKL_THREADING_LAYER'] = 'GNU'


os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.optim
import torch.nn as nn
import numpy as np
import random
import train_util
#from data_util import load_data
from data_util import load_data
#import eval_loss_util
import models as models

import losses as losses
from score.both import get_inception_score_and_fid

from torchvision.utils import make_grid, save_image


from models.wgan_celeba import Discriminator, Generator


import random
from torch.utils.data import DataLoader, TensorDataset
# Alphabetically sorted names of every lowercase, non-dunder callable that the
# ``models`` package exposes in its namespace.
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)




# Maps the value of ``--loss`` to the loss class in ``losses``; the selected
# class is instantiated with no arguments in ``main``.
loss_fns = dict(
    bce=losses.BCEWithLogits,
    bcesamy=losses.BCESAMY,
    hinge=losses.Hinge,
    was=losses.Wasserstein,
    softplus=losses.Softplus,
)


# Maps the value of ``--arch`` to a generator class in ``models``; ``main``
# constructs the selected class with the latent dimension as its argument.
net_G_models = dict(
    res32=models.ResGenerator32,
    res48=models.ResGenerator48,
    res64=models.ResGenerator64,
    cnn32=models.Generator32,
    cnn48=models.Generator48,
    cnn64=models.Generator64,
)

# Maps the value of ``--arch`` to a discriminator class in ``models``; ``main``
# constructs the selected class with no arguments. Keys mirror ``net_G_models``.
net_D_models = dict(
    res32=models.ResDiscriminator32,
    res48=models.ResDiscriminator48,
    res64=models.ResDiscriminator64,
    cnn32=models.Discriminator32,
    cnn48=models.Discriminator48,
    cnn64=models.Discriminator64,
)


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"`` and ``"0"``) becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: the raw argument; an actual ``bool`` is returned unchanged.

    Returns:
        ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 or the empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other text.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


# Command-line interface of the training script. ``main`` reads every option
# through ``parser.parse_args()``.
parser = argparse.ArgumentParser(description='CIFAR-10 GAN')

# setting
parser.add_argument('--epochs', default=100, type=int,
                    help='number of total epochs to run')
parser.add_argument('--dataset', default="cifar10", type=str,
                    help='cifar10 or lsun')
parser.add_argument('--path_data', default='/tigress/sjelassi/complete_cifar/data/', type=str,
                    help='path to store data')
parser.add_argument('--classes', default='church_outdoor', type=str,
                    help='choice of class')

## parameters GAN
parser.add_argument('--z_dim',  default=128, type=int,
                    help='latent space dimension')
parser.add_argument('--arch',  default="res32", type=str,
                    help='architecture')
parser.add_argument('--loss',  default="was", type=str,
                    help='loss function')
parser.add_argument('--sample_size',  default=64, type=int,
                    help='sampling size of images')
parser.add_argument('--gp_weight',  default=10, type=float,
                    help='GP penalty')

## parameters optim
parser.add_argument('--optim_choice',   default="sgd", type=str,
                    help='choice of optimizer')
parser.add_argument('--lr_G',  default=0.1, type=float,
                    help='initial learning rate of generator')
parser.add_argument('--lr_D',  default=0.1, type=float,
                    help='initial learning rate of discriminator')
parser.add_argument('--batch_size',  default=1024, type=int,
                    help='mini-batch size (default: 1024)')
parser.add_argument('--momentum',  default=0.0, type=float,
                    help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--time1',  default=50, type=int,
                    help='first annealing time')
parser.add_argument('--time2',  default=75, type=int,
                    help='second annealing time')
parser.add_argument('--beta2',  default=0.999, type=float,
                    help='Beta 2 in Adam')
parser.add_argument('--epsadam',  default=1e-8, type=float,
                    help='epsilon in adam')
parser.add_argument('--gamma',  default=0.95, type=float,
                    help='Gamma scheduling factor')
parser.add_argument('--q',  default=0, type=float,
                    help='Lq normlization')

# ``type=bool`` would turn "--saveconv False" into True, hence the converter.
parser.add_argument('--saveconv', default=False, type=_str2bool,
                    help='boolean flag given as true/false (default: False)')

# ``--save`` is deliberately a string: ``main`` compares it against "True".
parser.add_argument('--save', default="False", type=str,
                    help='Save or not')
parser.add_argument('--name', default='CIFAR-10-GAN', type=str,
                    help='name of experiment')
parser.add_argument('--seed', '-s', default=0, type=int,
                    help='seed (default: 0)')




# Sentinels reported in place of NaN/inf so a diverged run shows up as
# "very bad" rather than as a non-numeric value.
_DIVERGED_FID = 10000000000000
_DIVERGED_LOSS = 1000000000000000


def _dataset_paths(dataset):
    """Return ``(path_data, fid_cache)`` for a dataset name.

    Matching is by substring in the order cifar, lsun, stl10; every other name
    falls back to the CelebA locations.
    """
    if "cifar" in dataset:
        return ("/tigress/sjelassi/complete_cifar/data/",
                '/tigress/sjelassi/new_resnetwgan/stats/cifar10.test.npz')
    if "lsun" in dataset:
        return ("/tigress/sjelassi/new_resnetwgan/data/",
                '/tigress/sjelassi/new_resnetwgan/stats/lsun_church.npz')
    if "stl10" in dataset:
        return ("/tigress/sjelassi/new_resnetwgan/data/",
                '/tigress/sjelassi/new_resnetwgan/stats/stl10_unlabeled.npz')
    return ("/tigress/sjelassi/new_resnetwgan/data/celeba/",
            "/tigress/sjelassi/new_resnetwgan/stats/celeba_stats.npz")


def _eval_frequency(dataset, save):
    """Return the epoch interval between FID evaluations.

    A saving run (``save == "True"``) always uses 10; otherwise the interval
    depends on the dataset name.
    """
    if save == "True":
        return 10
    if "cifar" in dataset:
        return 30
    if "lsun" in dataset or "stl10" in dataset:
        return 15
    if "celeba" in dataset:
        return 30
    return 20


def _fid_threshold(dataset, opt):
    """Return the FID above which a non-saving run is stopped early.

    ``opt`` is the upper-cased optimizer name; it only matters for CIFAR.
    Datasets that match none of cifar/lsun/stl10 use the CelebA threshold,
    mirroring the CelebA fallback of ``_dataset_paths``.
    """
    if "cifar" in dataset:
        return {"SGD": 195, "NORMALIZED": 160, "ADAM": 170}.get(opt, 300)
    if "lsun" in dataset:
        return 50
    if "stl10" in dataset:
        return 120
    return 170


def _compute_is_fid(netG, device, args, fid_cache, num_im=50000):
    """Generate ``num_im`` images with ``netG`` and return ``(IS, FID)``."""
    fid_batch_size = 50 if "cifar" in args.dataset else 200
    imgs = train_util.generate_imgs(netG, device, args.z_dim, num_im,
                                    args.batch_size)
    return get_inception_score_and_fid(imgs, fid_cache,
                                       batch_size=fid_batch_size,
                                       verbose=False)


def _count_params(optimizer):
    """Return the number of parameter tensors held by ``optimizer``."""
    return sum(len(group['params']) for group in optimizer.param_groups)


def main():
    """Parse the command line, train the GAN and print a final summary.

    Steps, in order:
      1. Parse arguments, echo them, and seed ``torch``, ``numpy`` and
         ``random`` with ``args.seed``.
      2. Pick data/FID-statistics paths from ``args.dataset``. Note that a
         ``--path_data`` given on the command line is always overwritten.
      3. Build generator, discriminator, loss, optimizers and linearly
         decaying learning-rate schedulers.
      4. Run ``train_util.train_loop`` once per epoch, printing losses and the
         gradient norms relative to epoch 0.
      5. Evaluate IS/FID: every ``eval_freq`` epochs when ``--save True``;
         otherwise once at epoch ``eval_freq`` and at the last epoch, stopping
         early if FID exceeds a dataset-specific threshold.
      6. Stop early on zero/NaN/inf losses or gradient ratios.
      7. Print the summary (gradient ratios, losses, FID, IS, hyperparameters).
    """
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.path_data, fid_cache = _dataset_paths(args.dataset)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataloader, dataloader2, num_examples = load_data(args, device)

    # NOTE(review): ``dim_z`` is stored on ``args`` but the networks below are
    # built from ``args.z_dim``; confirm which one downstream code expects.
    args.dim_z = 128 if "cifar" in args.dataset else 100
    print(device)
    netG = net_G_models[args.arch](args.z_dim).to(device)
    netD = net_D_models[args.arch]().to(device)

    loss_fn = loss_fns[args.loss]()

    lower_opt = args.optim_choice.lower()
    OPT = args.optim_choice.upper()
    args.graftd = "graftd" in lower_opt

    optim_hparams = {
        'initial_lrG': args.lr_G,
        'initial_lrD': args.lr_D,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'graft_d': args.graftd,
        'beta2': args.beta2,
        'q': args.q,
        'gamma': args.gamma,
    }

    optimizerG, optimizerD = train_util.create_optimizer(
        netG, netD, args.optim_choice, optim_hparams)

    num_steps = int(num_examples / args.batch_size)
    total_steps = num_steps * args.epochs

    # Linear decay from 1 to 0 over ``total_steps``. The divisor is kept >= 1
    # so an empty schedule cannot divide by zero, and the factor is clamped so
    # extra scheduler steps can never yield a negative learning rate.
    decay_steps = max(total_steps, 1)

    def linear_decay(step):
        return max(0.0, 1 - step / decay_steps)

    schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizerG, linear_decay)
    schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizerD, linear_decay)

    # Not used below; kept because drawing it advances the torch RNG, so
    # removing it would change the random stream seen by training.
    sample_z = torch.randn(args.sample_size, args.z_dim).to(device)  # noqa: F841

    ratio_D = 0
    ratio_G = 0
    last_fid = 1000
    last_is = (1000, 1000)
    # Defined up front so the summary still works when no epoch runs.
    d_loss = g_loss = None

    eval_freq = _eval_frequency(args.dataset, args.save)
    eval_freq2 = 20

    ct_netD = _count_params(optimizerD)
    ct_netG = _count_params(optimizerG)

    # Per-step learning-rate statistics are only recorded for saving GRAFTL
    # runs, so the (steps x parameters) buffers are only allocated then.
    record_eta = lower_opt == "graftl" and args.save == "True"
    if record_eta:
        etaM_netG = np.zeros((total_steps, ct_netG))
        etaD_netG = np.zeros((total_steps, ct_netG))
        fulleta_netG = np.zeros((total_steps, ct_netG))

        etaM_netD = np.zeros((total_steps, ct_netD))
        etaD_netD = np.zeros((total_steps, ct_netD))
        fulleta_netD = np.zeros((total_steps, ct_netD))

    full_eta_D = []
    full_eta_G = []

    fid_tab = []
    ratioD_tab = []
    ratioG_tab = []

    for epoch in range(args.epochs):

        print("Epoch" + str(epoch))
        lr_key = 'lr_full' if lower_opt == "sgd_recover_graft" else 'lr'
        for group_G, group_D in zip(optimizerG.param_groups,
                                    optimizerD.param_groups):
            print("LR_G: " + str(group_G[lr_key]))
            print("LR_D: " + str(group_D[lr_key]))
            if OPT == "ADAM" or OPT == "GRAFTL":
                print("M_G: " + str(group_G['betas']))
                print("M_D: " + str(group_D['betas']))
            elif OPT == "SGD":
                print("M_G: " + str(group_G['momentum']))
                print("M_D: " + str(group_D['momentum']))

        (norm_grad_D, norm_grad_G, d_loss, g_loss,
         etaM_D_tmp, etaD_D_tmp, fulleta_D_tmp,
         etaM_G_tmp, etaD_G_tmp, fulleta_G_tmp) = train_util.train_loop(
            dataloader, dataloader2, args.loss, loss_fn, args.batch_size,
            netG, netD, optimizerG, optimizerD, schedulerG, schedulerD, OPT,
            ct_netG, ct_netD, epoch, args.z_dim, full_eta_D, full_eta_G,
            device, args.gp_weight, args, eval_freq2)

        if record_eta:
            # Offsets use the same steps-per-epoch the buffers were sized
            # with, so they stay in range for every dataset size.
            start = epoch * num_steps
            stop = start + num_steps

            etaM_netG[start:stop, :] = etaM_G_tmp
            etaD_netG[start:stop, :] = etaD_G_tmp
            fulleta_netG[start:stop, :] = fulleta_G_tmp

            etaM_netD[start:stop, :] = etaM_D_tmp
            etaD_netD[start:stop, :] = etaD_D_tmp
            fulleta_netD[start:stop, :] = fulleta_D_tmp

        print('[%d]  D loss: %.4f' % (epoch, d_loss), flush=True)
        print('[%d]  G loss: %.4f' % (epoch, g_loss), flush=True)
        if epoch == 0:
            # Epoch-0 gradient norms are the reference for all later ratios.
            norm_grad_D_init = norm_grad_D
            norm_grad_G_init = norm_grad_G
            print("Norm init gradient D: {} ".format(norm_grad_D_init), flush=True)
            print("Norm init gradient G: {} ".format(norm_grad_G_init), flush=True)
        else:
            ratio_D = norm_grad_D / norm_grad_D_init
            ratio_G = norm_grad_G / norm_grad_G_init
            print("Ratio gradient D: {} ".format(ratio_D), flush=True)
            print("Ratio gradient G: {} ".format(ratio_G), flush=True)

        is_last_epoch = (epoch + 1) == args.epochs
        if args.save == "True":
            is_periodic = (epoch + 1) % eval_freq == 0
            if epoch == 1 or is_periodic or is_last_epoch:
                ratioD_tab.append(ratio_D)
                ratioG_tab.append(ratio_G)
            if is_periodic or is_last_epoch:
                IS, FID = _compute_is_fid(netG, device, args, fid_cache)
                print("FID accuracy {}".format(FID), flush=True)
                if np.isnan(FID) or FID == np.inf:
                    # Record the sentinel so the summary reflects divergence.
                    last_fid = _DIVERGED_FID
                    break
                fid_tab.append(FID)
                last_fid = FID
                last_is = IS
        elif (epoch + 1) == eval_freq or is_last_epoch:
            # Non-saving runs get a single mid-run checkpoint (equality, not
            # modulo) plus the final evaluation.
            IS, FID = _compute_is_fid(netG, device, args, fid_cache)
            print("FID accuracy {}".format(FID), flush=True)
            if np.isnan(FID) or FID == np.inf:
                last_fid = _DIVERGED_FID
                break

            last_fid = FID
            last_is = IS
            if FID > _fid_threshold(args.dataset, OPT):
                # Not promising enough to keep training.
                break

        # Divergence / degenerate-training checks. The summary reports the
        # losses, so those receive the sentinel value.
        if norm_grad_D == 0 or norm_grad_G == 0:
            d_loss = g_loss = _DIVERGED_LOSS
            break
        if d_loss == 0 or g_loss == 0:
            break
        if d_loss == np.inf or np.isnan(d_loss) or g_loss == np.inf or np.isnan(g_loss):
            d_loss = g_loss = _DIVERGED_LOSS
            break
        if ratio_D == np.inf or ratio_G == np.inf or np.isnan(ratio_D) or np.isnan(ratio_G):
            d_loss = g_loss = _DIVERGED_LOSS
            break

    if record_eta:
        print(ratioD_tab)
        print(ratioG_tab)
        print(fid_tab)

    print("\n")
    print("Norm gradient D: {}".format(ratio_D), flush=True)
    print("Norm gradient G: {}".format(ratio_G), flush=True)
    print("d loss: {}".format(d_loss), flush=True)
    print("g loss: {}".format(g_loss), flush=True)
    print("FID: {}".format(last_fid), flush=True)
    if OPT == "SGD":
        print("IS: {}".format(last_is), flush=True)
    else:
        print("IS: {} +- {}".format(last_is[0], last_is[1]), flush=True)

    print("Seed: {}".format(args.seed), flush=True)
    print("Optimization algorithm: {}".format(OPT), flush=True)
    if OPT == "SGD":
        print("LR_G: {}; LR_D: {};  B: {}; M: {};   WD: {}".format(
            args.lr_G, args.lr_D, args.batch_size, args.momentum,
            args.weight_decay), flush=True)
    else:
        print("Dataset: {}".format(args.dataset), flush=True)
        print("Architecture: {}".format(args.arch), flush=True)
        print("LR_G: {}; LR_D: {};  B: {}; M: {}; beta_2: {}; WD: {}".format(
            args.lr_G, args.lr_D, args.batch_size, args.momentum, args.beta2,
            args.weight_decay), flush=True)
