import torch
from torch import autograd
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import sys
import os
import os
import datetime
from scipy import stats
from MoG import kde, real_builder_circle, real_builder_diamond
from GameLosses import gan_model_mog
import time

def calc_gradient_norm(params):
    """Return the global L2 norm of the gradients stored on ``params``.

    The result is ``sqrt(sum_i ||p_i.grad||_2 ** 2)`` taken over every
    parameter that currently holds a gradient.

    Args:
        params: Iterable of tensors/parameters exposing a ``.grad`` attribute.

    Returns:
        The norm as a Python float; ``0.0`` when ``params`` is empty or no
        parameter has a gradient.
    """
    squared_total = 0.0
    for p in params:
        # Parameters that were frozen or not reached by backward() have no
        # gradient; they contribute nothing instead of raising.
        if p.grad is None:
            continue
        squared_total += p.grad.detach().norm(2).item() ** 2
    return squared_total ** (1. / 2)

def log(**datum):
    """Overwrite the current terminal line with ``datum`` and return its values.

    A carriage return is written first so that successive calls redraw the
    same console line, followed by the ``str()`` of the keyword dict.

    Args:
        **datum: Arbitrary named values to display.

    Returns:
        The keyword values as a list, in the order they were passed.
    """
    stream = sys.stdout
    stream.write('\r')
    stream.write(str(datum))
    return [value for value in datum.values()]


# Print NumPy arrays with two decimals.
np.set_printoptions(precision=2)

# --- Optimisation settings ----------------------------------------------
# Command line: <tau> <reg_param> <seed>
lr_g = 5e-3                      # generator learning rate
batch_size = 512                 # samples per training step
tau = int(sys.argv[1])           # integer factor; discriminator lr is tau * lr_g
reg_param = float(sys.argv[2])   # weight of the gradient penalty on real data
seed = int(sys.argv[3])          # run-specific RNG seed (applied further below)
num_iter = 60000                 # total number of training steps
view_size = 4096                 # number of samples used for evaluation/plots
freq_view = 4000                 # evaluate and print every `freq_view` steps
show = False                     # forwarded to kde(); presumably toggles interactive display
alternating = False              # True: G steps before D's gradients are computed

# --- Network shape ------------------------------------------------------
activation_function = nn.ReLU()
n_latent = 16
n_out = 2
n_hidden = 32

# Seed with a constant before building the evaluation fixtures below, so the
# reference sample and `fixed_noise` do not depend on the run-specific seed.
np.random.seed(0)
torch.manual_seed(0)
MoG_Type = 'circle'
real_builder = real_builder_circle
bbox = [-2, 2, -2, 2]
# 300 x 300 grid over the bounding box on which the densities are evaluated.
xx, yy = np.mgrid[bbox[0]:bbox[1]:300j, bbox[2]:bbox[3]:300j]
positions = np.vstack([xx.ravel(), yy.ravel()])
real_data = real_builder(view_size)
real_kernel = stats.gaussian_kde(real_data.T)
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first tensor allocation.
device = "cuda" if torch.cuda.is_available() else "cpu"
fixed_noise = torch.randn(view_size, n_latent, device=device)

def compute_grad2(d_out, x_in):
    """Return, per sample, the squared L2 norm of d(sum(d_out))/d(x_in).

    The gradient is taken with ``create_graph=True`` so the result can itself
    be back-propagated (it is used as a penalty term in the loss).

    Args:
        d_out: Tensor computed from ``x_in``; it is summed before
            differentiating.
        x_in: Input tensor with ``requires_grad`` enabled; its first
            dimension is treated as the batch dimension.

    Returns:
        Tensor of shape ``(x_in.size(0),)`` holding the sum of squared
        gradient entries for each sample.
    """
    n_samples = x_in.size(0)
    (input_grad,) = autograd.grad(
        outputs=d_out.sum(),
        inputs=x_in,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    squared = input_grad.pow(2)
    assert squared.size() == x_in.size()
    # Flatten everything except the batch dimension, then reduce per sample.
    return squared.view(n_samples, -1).sum(1)

# Wall-clock reference for the per-interval timing printed during training.
start = time.time()

# KL estimates collected at each evaluation step.
kl_data = []

# Re-seed with the run-specific seed from the command line.
torch.manual_seed(seed)
np.random.seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# Minute-resolution timestamp used to make the output directory unique.
now = datetime.datetime.now()
now = now.strftime("%Y-%m-%d_%H-%M")

sim_name = '_'.join(['mog_test', str(batch_size), str(tau), str(seed)])
num_gen_layers = 2
num_disc_layers = 1

# Output layout: <cwd>/MoGResults_reg_<reg>/<sim_name>_<timestamp>/{Figs,KL}
results_dir = os.path.join(os.getcwd(), 'MoGResults_reg_'+str(reg_param).replace('.', ''))
save_dir = os.path.join(results_dir, '_'.join([sim_name, now]))
fig_dir = os.path.join(save_dir, 'Figs')
kl_dir = os.path.join(save_dir, 'KL')
# exist_ok avoids the check-then-create race when several runs start together.
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(kl_dir, exist_ok=True)


G, D = gan_model_mog(
    n_latent=n_latent,
    n_out=n_out,
    n_hidden=n_hidden,
    num_gen_layers=num_gen_layers,
    num_disc_layers=num_disc_layers,
    activation_function=activation_function,
)


G.to(device)
D.to(device)


# Parameter counts of the two networks.
n_g = sum(p.numel() for p in G.parameters())
n_d = sum(p.numel() for p in D.parameters())


# The discriminator learning rate is the generator's scaled by tau.
lr_d = tau*lr_g

opt_G = torch.optim.SGD(G.parameters(), lr=lr_g)
opt_D = torch.optim.SGD(D.parameters(), lr=lr_d)

# Binary cross-entropy against the real/fake targets below.
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.

# Main training loop. Each step computes generator gradients, then
# discriminator gradients (real batch + gradient penalty + fake batch), and
# applies the optimiser updates. With `alternating` False both optimisers step
# only after all gradients have been computed.
for step in range(1, num_iter+1):

    # generator: push D's output on generated samples towards the real label
    G.zero_grad()
    noise = torch.randn(batch_size, n_latent, device=device)
    fake = G(noise)
    label = torch.full((batch_size,), fill_value=real_label, device=device)
    output = D(fake).view(-1)
    errG = criterion(output, label)
    errG.backward()
    D_G_z2 = output.mean().item()
    if alternating:
        opt_G.step()

    # discriminator
    # errG.backward() above also wrote gradients into D; clear them first.
    D.zero_grad()
    real_cpu = real_builder(batch_size)
    label = torch.full((batch_size,), fill_value=real_label, device=device)

    real_cpu = torch.Tensor(real_cpu).to(device)
    # The penalty differentiates D's output w.r.t. the real inputs.
    real_cpu.requires_grad_()

    output = D(real_cpu).view(-1)
    errD_real = criterion(output, label)
    # Keep the graph alive: the penalty below backpropagates through it again.
    errD_real.backward(retain_graph=True)
    reg = reg_param * compute_grad2(output, real_cpu).mean()
    reg.backward()
    D_x = output.mean().item()

    # train with fake
    label.fill_(fake_label)
    # detach() stops this loss from contributing gradients to G.
    output = D(fake.detach()).view(-1)
    errD_fake = criterion(output, label)

    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Reported discriminator loss (the penalty term is not included).
    errD = errD_real + errD_fake

    if not alternating:
        opt_G.step()

    opt_D.step()

    if step % freq_view == 0:
        print('[%d, Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (step,
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Seconds elapsed since the previous report.
        print(time.time()-start)
        start = time.time()

        # Evaluation only: no autograd graph is needed for the fixed batch.
        with torch.no_grad():
            G_generated = G(fixed_noise).cpu().numpy()
        generator_save_name = os.path.join(fig_dir, 'generator_step_' + str(step))
        # NOTE(review): not used anywhere in the visible code.
        discriminator_save_name = os.path.join(fig_dir, 'discriminator_step_' + str(step))
        kde(G_generated.T, show=show, save=generator_save_name, bbox=bbox)

        # KL(generated || real), both densities being Gaussian KDEs evaluated
        # on the `positions` grid.
        fake_kernel = stats.gaussian_kde(G_generated.T)
        kl = stats.entropy(pk=fake_kernel(positions), qk=real_kernel(positions))

        kl_data.append(kl)
        # Rewrite the full KL history each time so the file is always current.
        np.savetxt(os.path.join(kl_dir, 'kl_div.csv'), np.array(kl_data).reshape(-1, 1), delimiter=',')