import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data

import matplotlib.pyplot as plt
import matplotlib.animation as animation

import numpy as np
import imageio
from tqdm import tqdm

from neural_nets import *
from utils import *
import argparse

plt.rcParams.update({
    "animation.writer": "ffmpeg",
    "font.family": "serif",  # use serif/main font for text elements
    "font.size": 12,
    "text.usetex": True,     # render all text (labels, ticks, annotations) through LaTeX
    "pgf.rcfonts": False,    # don't setup fonts from rc parameters
    "hist.bins": 20,         # default number of bins in histograms
    # pgf.preamble (read by the pgf backend) must be a single string:
    # list values are deprecated in Matplotlib and rejected by newer releases,
    # so the lines are joined with newlines here.
    "pgf.preamble": "\n".join([
        r"\usepackage{units}",          # load additional packages
        r"\usepackage{metalogo}",
        r"\usepackage{unicode-math}",   # unicode math setup
        r"\setmathfont{xits-math.otf}",
        r"\setmainfont{DejaVu Serif}",  # serif font via preamble
        r"\usepackage{color}",
    ]),
})



def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string, so ``-skip False`` would enable the option.
    This accepts the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def _regularization(net, count_bias):
    """Return 0.5 * (sum of squared entries) over the penalised parameters of ``net``.

    Parameters whose name contains "skip" are never penalised. Without
    ``count_bias`` only parameters whose name contains "weight" are
    penalised; with it, every non-skip parameter is included. Returns 0.0
    when no parameter qualifies, so the total loss stays well-defined.
    """
    reg = 0.
    for name, param in net.named_parameters():
        if "skip" in name:
            continue
        if count_bias or "weight" in name:
            reg = reg + 0.5 * torch.sum(param ** 2)
    return reg


def _setup_axes(fig, x, y, c1, c2):
    """Create the twin-axes layout shared by the animation and the final plot.

    ``ax1`` (left y-axis, colour ``c1``) shows the network output over the
    data range; ``ax2`` (right y-axis, colour ``c2``) shares the x-axis and
    shows one marker per hidden unit. Returns ``(ax1, ax2)``.
    """
    fig.subplots_adjust(left=0.15, right=0.85)

    ax1 = fig.add_subplot(111)
    ax1.set_xlim(float(x.min()) - 1, float(x.max()) + 1)
    ax1.set_ylim(float(y.min()) - 1.6, float(y.max()) + 0.8)
    ax2 = ax1.twinx()
    ax2.axhline(0, linestyle='--', alpha=0.5)
    ax1.set_ylabel(r'$f_{\theta}(x)$', fontsize=20)
    ax2.set_ylabel(r'$\mathsf{s}_j\|w_j\|$', fontsize=20)

    ax1.yaxis.label.set_color(c1)
    ax2.yaxis.label.set_color(c2)

    ax2.spines["left"].set_edgecolor(c1)
    ax2.spines["right"].set_edgecolor(c2)

    ax1.tick_params(axis='y', colors=c1)
    ax2.tick_params(axis='y', colors=c2)
    multicolor_label(ax1, (r'$x$', r'$-w_{j,2}/w_{j,1}$'), (c1, c2), axis='x', fontsize=20)
    return ax1, ax2


def _draw_iterate(net, z, ax1, ax2, c1, c2, step=None, animated=True):
    """Draw the current state of ``net`` and return the list of created artists.

    Plots ``net(z)`` on ``ax1`` and, on ``ax2``, one star per hidden unit at
    (-bias / input weight, output-layer weight). When ``step`` is given, an
    "iteration: <step>" caption is added too.
    Assumes a single input feature: the hidden weights are flattened to one
    value per unit (matches ``n_feature=1`` below).
    """
    with torch.no_grad():  # plotting only; don't build an autograd graph
        f_z = net(z).numpy()
    w_in = net.hidden.weight.detach().reshape(-1)
    b_in = net.hidden.bias.detach().reshape(-1)
    w_out = net.predict.weight.detach().reshape(-1)

    line, = ax1.plot(z.numpy(), f_z, '-', c=c1, lw=2, animated=animated)
    stars = ax2.scatter((-(b_in / w_in)).numpy(), w_out.numpy(), animated=animated, c=c2, marker='*')
    artists = [line, stars]
    if step is not None:
        artists.append(ax1.annotate("iteration: " + str(step), (0.4, 0.95),
                                    xycoords='figure fraction', annotation_clip=False))
    return artists


if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("-niter", "--niterations", nargs='?', type=int, default=800000)
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument("-skip", "--skip_connection", nargs='?', const=True, type=_str2bool, default=False)
    parser.add_argument("-count_bias", "--count_bias", nargs='?', const=True, type=_str2bool, default=False)
    parser.add_argument("-n_hidden", "--n_hidden", nargs='?', type=int, default=200)
    args = parser.parse_args()

    # Output path; create the directory now so a long run cannot fail at save time.
    if args.count_bias:
        filename = 'norm_induced/regularisation_full'
    else:
        filename = 'norm_induced/regularisation_nobias'
    if args.skip_connection:
        filename += "_skip"
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Training data: five 1-D points, shaped (n_samples, 1).
    x = torch.Tensor(np.array([-1.5, -1., 0, 0.25, 2]).reshape(-1, 1))
    y = torch.Tensor(np.array([0.7, -0.3, -0.3, 0.4, 0.2]).reshape(-1, 1))

    torch.manual_seed(0)

    # init network
    net = Net(n_feature=1, n_hidden=args.n_hidden, n_output=1, init_scale=1 / np.sqrt(args.n_hidden),
              skip_connection=args.skip_connection, balanced=False, zero_output=False)

    optimizer = torch.optim.SGD(net.parameters(), lr=5e-2)  # full-batch gradient descent
    loss_func = torch.nn.MSELoss(reduction='mean')  # mean squared error
    reg_lambda = 1e-3  # scale of the 0.5 * squared-2-norm parameter penalty

    n_iterations = args.niterations  # number of descent steps

    # Frames are saved at steps 0, 1, 3, 7, 15, ...: after saving step t, the
    # next saved step is floor(iter_geom * t) + 1 (roughly geometric spacing).
    iter_geom = 2
    last_iter = 0
    ims = []

    # Cosmetics
    c1 = 'tab:green'  # color of left axis
    c2 = 'tab:blue'  # color of right axis

    plt.ioff()
    fig = plt.figure("Training dynamics")
    ax1, ax2 = _setup_axes(fig, x, y, c1, c2)
    # Static data points: not part of any frame, so visible throughout the animation.
    ax1.scatter(x.numpy(), y.numpy(), color=c1)

    z = torch.Tensor(np.linspace(float(x.min()) - 1, float(x.max()) + 1, 100).reshape(-1, 1))

    # train the network
    for it in tqdm(range(n_iterations)):
        prediction = net(x)
        # MSELoss argument order must be (nn output, target)
        loss = loss_func(prediction, y) + reg_lambda * _regularization(net, args.count_bias)

        if it < 2 or it == int(last_iter * iter_geom) + 1:
            # Snapshot after `it` descent steps.
            ims.append(_draw_iterate(net, z, ax1, ax2, c1, c2, step=it))
            last_iter = it

        optimizer.zero_grad()   # clear gradients for next train
        loss.backward()         # backpropagation, compute gradients
        optimizer.step()        # descent step

    # Last iterate: n_iterations steps have been taken at this point
    # (the loop variable only reaches n_iterations - 1).
    ims.append(_draw_iterate(net, z, ax1, ax2, c1, c2, step=n_iterations))

    ani = animation.ArtistAnimation(fig, ims, interval=100, repeat=False)
    plt.close(fig)
    ani.save(filename + ".mp4", fps=10, dpi=120)  # save animation as video
    del ani

    # Static, finely sampled plot of the last iterate.
    z = torch.Tensor(np.linspace(float(x.min()) - 1, float(x.max()) + 1, 10000).reshape(-1, 1))
    fig = plt.figure("Training dynamics")
    ax1, ax2 = _setup_axes(fig, x, y, c1, c2)
    ax1.scatter(x.numpy(), y.numpy(), color=c1)
    _draw_iterate(net, z, ax1, ax2, c1, c2, animated=False)
    ax1.grid(alpha=0.5)

    fig.savefig(filename + '_lastiter.pdf')
    plt.close(fig)


