# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import os, sys

# Put the parent of this script's directory on the import path; the
# `lib.*` imports further down are presumably resolved from there.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(_SCRIPT_DIR)
sys.path.append(PARENT_DIR)

import random
from random import SystemRandom

import matplotlib.pyplot as plt
import matplotlib.cm as cm

import lib.utils as utils
from lib.odefunc import ODEfuncRegGC
#from lib.torchdiffeq import odeint as odeint
from torchdiffeq import odeint_adjoint as odeint
#import lib.odeint as odeint
import argparse
# Command-line interface.  Flags, types and defaults are unchanged; only the
# help texts that had been copy-pasted from unrelated options were corrected.
parser = argparse.ArgumentParser(description='.')
parser.add_argument('--r', type=int, default=0, help='random_seed')

# Optimisation settings.
parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
parser.add_argument('--nepoch', type=int, default=2000, help='max epochs')
parser.add_argument('--niterbatch', type=int, default=100, help='iterations per batch')

# Network size (forwarded to the ODE function constructor).
parser.add_argument('--nlayer', type=int, default=4, help='number of layers in the ODE function network')
parser.add_argument('--nunit', type=int, default=100, help='number of units per layer')

# Mini-batch (MB) layout.
parser.add_argument('--lMB', type=int, default=100, help='length of seq in each MB')
parser.add_argument('--nMB', type=int, default=40, help='number of MBs')

parser.add_argument('--drop_rate', type=float, default=.0, help='drop rate')

parser.add_argument('--penalty_weight', type=float, default=.01, help='penalty_weight')
parser.add_argument('--prune_thresh', type=float, default=.01, help='pruning threshold')

# Delay / augmentation layout of the state vector.
parser.add_argument('--ndelay', type=int, default=0, help='number of delayed (history) states')
parser.add_argument('--lag', type=int, default=10, help='spacing between delayed states, in time steps')
parser.add_argument('--aug', type=int, default=0, help='number of augmented state dimensions')

parser.add_argument('--odeint', type=str, default='rk4', help='integrator')

args = parser.parse_args()

# All tensors created without an explicit dtype will be float64 from here on.
torch.set_default_dtype(torch.float64)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Seed torch, NumPy and the stdlib RNG with the same value taken from --r.
seed = args.r
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# Directory holding checkpoints and generated figures; utils.makedirs
# presumably creates it when missing (helper is defined in lib.utils).
save_path = 'experiments/'
utils.makedirs(save_path)

# Identifiers of the checkpoints to evaluate (author's note: NDDEs + 6 delay).
expIDs = ['98657', '54088']

# Load the full data array plus the index files that select the train and
# test rows along its first axis.
path = '../data/_sjdf'
data_all = np.load('../data/_sjdf/sjdf.npy')
train_index = np.loadtxt(os.path.join(path, 'sjdf_train_index.txt')).astype(int)
test_index = np.loadtxt(os.path.join(path, 'sjdf_test_index.txt')).astype(int)
data_train = data_all[train_index, :, :]
data_test = data_all[test_index, :, :]

# Delay layout: hist_idx holds the offsets lag, 2*lag, ..., ndelay*lag
# (empty when ndelay is 0).
a_h_ref = args.lag
n = args.ndelay
hist_idx = torch.arange(1, n + 1) * a_h_ref

# 256 equally spaced time stamps on [0, 1], as a 1-D float64 tensor.
t = torch.tensor(np.linspace(0., 1., 256, endpoint=True, dtype=np.float64))

# Swap the last two axes; the code below indexes axis 1 with time offsets.
train_data = torch.tensor(data_train).to(device).transpose(1, 2)
test_data = torch.tensor(data_test).to(device).transpose(1, 2)

# The ODE state stacks (3 + aug) components for the current step and for each
# of the n delayed steps.
aug = args.aug
state_dim = (3 + aug) * (n + 1)
odefunc = ODEfuncRegGC(state_dim, args.nlayer, args.nunit, drop_rate=args.drop_rate, aug_dim=aug).to(device)

def count_parameters(model):
    """Return the number of scalar entries in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; a model without
    any such parameter yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print(count_parameters(odefunc))

# Time index used as the start of the rollout.  `time_offset` is not read
# anywhere in the visible code; only `time_offset_test` is.
time_offset = 60
time_offset_test = 60

# Number of test sequences evaluated together.
batch_size = 192

# Result buffers, one slot per checkpoint in expIDs:
#   pred_col    - predicted first three state components for every step after
#                 time_offset_test
#   test_losses - one scalar per checkpoint
#   cw_norms    - one value per (gradient net, input column) pair
n_models = len(expIDs)
n_state = 3 + aug
pred_col = torch.zeros((n_models, batch_size, len(t) - (time_offset_test + 1), 3))
test_losses = torch.zeros(n_models)
cw_norms = np.zeros((n_models, n_state, int((n + 1) * n_state)))
# Evaluate every checkpoint listed in expIDs on the test set.
#
# For each checkpoint:
#   1. restore the weights into `odefunc`,
#   2. record the column-wise L2 norms of the first layer of each gradient net,
#   3. roll the model forward one time step at a time, feeding each prediction
#      back in as the next initial condition, and
#   4. store the prediction and the relative L2 error against the test data.

# Delay offsets on the compute device so that index/mask tensors always live
# next to the tensors they index.
hist_idx_dev = hist_idx.to(device)
# Number of predicted steps: everything after the initial index.
n_steps = len(t) - (time_offset_test + 1)

for m, exp_id in enumerate(expIDs):
    ckpt_path = os.path.join(save_path, "experiment_" + str(exp_id) + '.ckpt')
    print(ckpt_path)

    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only host (and vice versa).
    ckpt = torch.load(ckpt_path, map_location=device)
    odefunc.load_state_dict(ckpt['state_dict'])

    with torch.no_grad():
        for k in range(3+aug):
            print(k)
            for l in range(int((n+1)*(3+aug))):
                # L2 norm of column l of the first layer's weight in net k.
                col_norm = torch.norm(odefunc.gradient_nets[k][0].weight[:,l], p=2)
                print(col_norm)
                # .cpu() is required before .numpy() when the model is on CUDA.
                cw_norms[m,k,l] = col_norm.detach().cpu().numpy()
        odefunc.eval()

        cur = time_offset_test
        batch_t = t[:2]  # integrate over a single time step

        # pred[:, j] holds the state at time index time_offset_test + 1 + j.
        # It lives on `device` so it can be mixed with test_data below.
        pred = torch.zeros((batch_size, n_steps, 3+aug), device=device)
        for j in range(n_steps):
            # (batch, current + n delayed states, state components)
            batch_y0 = torch.zeros((batch_size, n+1, 3+aug), device=device)
            if j == 0:
                batch_y0[:,0,:3] = test_data[:batch_size,time_offset_test,:3]
            else:
                batch_y0[:,0,:] = pred[:,j-1,:]

            # Time indices of the delayed states for the current step.
            hist_indices = cur - hist_idx_dev
            # Indices up to AND INCLUDING time_offset_test are observed data:
            # the first prediction is for time_offset_test + 1, so the state
            # at time_offset_test itself only exists in test_data.  (The
            # previous `<` / `>=` split sent that index to pred[:, -1], which
            # is still zero at that point.)
            from_data = hist_indices <= time_offset_test
            from_pred = ~from_data
            batch_y0[:,1:,:3][:,from_data,:] = test_data[:batch_size,hist_indices[from_data],:]
            batch_y0[:,1:,:][:,from_pred,:] = pred[:,hist_indices[from_pred]-(time_offset_test+1),:]

            pred_y = odeint(odefunc, batch_y0.reshape((batch_size,-1)), batch_t, method=args.odeint).to(device)
            pred_y = pred_y.reshape((2, batch_size, n+1, 3+aug))
            # Keep the current-state slot (index 0) at the end of the step.
            pred[:,j,:] = pred_y[-1,:,0,:]

            cur = cur + 1

        pred_col[m, :, :, :] = pred[:,:,:3].cpu()
        # Relative L2 error over all predicted steps.
        target = test_data[:batch_size,(time_offset_test+1):,:]
        rel_err = torch.linalg.norm((target - pred[:,:,:3]).flatten())/torch.linalg.norm(target.flatten())
        test_losses[m] = rel_err.cpu()
        print('test loss', test_losses[m])

print(torch.mean(test_losses))
print(torch.std(test_losses))

from matplotlib.ticker import StrMethodFormatter

# Ensemble statistics across the evaluated checkpoints (axis 0 of pred_col).
pred_mean = torch.mean(pred_col, dim=0)
pred_std = torch.std(pred_col, dim=0)
titles = ['Gauge 702', 'Gauge 901', 'Gauge 911']
colors = ['r','g','b']

# Prepend the observed state at time_offset_test (with zero spread) so the
# plotted curves start at the initial condition.  test_data may live on the
# GPU while the prediction buffers are on the CPU, hence the explicit .cpu().
# These tensors do not depend on the sample, so build them once.
init_state = test_data[:batch_size,time_offset_test:time_offset_test+1,:3].cpu()
mean_to_plot = torch.cat((init_state, pred_mean[:,:,:3]), dim=1)
std_to_plot = torch.cat((torch.zeros_like(init_state), pred_std[:,:,:3]), dim=1)
mean_np = mean_to_plot.detach().numpy()
std_np = std_to_plot.detach().numpy()
t_np = t.cpu().numpy()

# One figure per sample (at most ten), one subplot per output component.
for k in range(min(10, mean_np.shape[0])):
    save_file = os.path.join(save_path, "test_node_image_{:02d}.png".format(k))
    fig = plt.figure(figsize=(4,4))
    for i in range(3):
        ax = fig.add_subplot(3,1,i+1)
        truth = test_data[k,:,i].cpu().detach().numpy()
        mean_ki = mean_np[k,:,i]
        band = 2.*std_np[k,:,i]

        # Ground truth over the full horizon, prediction from the offset on.
        ax.plot(t_np, truth, lw=2, color=colors[i])
        ax.plot(t_np[time_offset_test:], mean_ki, lw=2, color='c', ls='--')
        ax.fill_between(t_np[time_offset_test:], mean_ki-band, mean_ki+band, alpha=.1, color='m')
        ax.set_title(titles[i])

        ax.yaxis.set_major_formatter(StrMethodFormatter('{x:,.2f}'))
        # Symmetric y-range scaled to the largest magnitude in the truth.
        max_val = float(np.max(np.abs(truth)))
        ax.set_ylim([-max_val, max_val])
        ax.set_xlim([t_np[0], t_np[-1]])
        ax.set_ylabel('elevation (m)')
    fig.tight_layout()
    fig.savefig(save_file)
    # Close explicitly so figures do not accumulate across iterations.
    plt.close(fig)

