# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import os, sys

# Put the parent of this script's directory on the import path so that the
# project-local `lib` package imported below can be resolved.
_script_dir = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(_script_dir)
sys.path.append(PARENT_DIR)

import random
from random import SystemRandom

import matplotlib.pyplot as plt

import lib.utils as utils
from lib.odefunc import ODEfuncRegGC
#from lib.torchdiffeq import odeint as odeint
from torchdiffeq import odeint_adjoint as odeint
#import lib.odeint as odeint
import argparse
# Command-line configuration. Help strings describe how each value is used
# further down in this script.
parser = argparse.ArgumentParser(description='.')
parser.add_argument('--r', type=int, default=0, help='random_seed')

# Optimisation schedule.
parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
parser.add_argument('--nepoch', type=int, default=2000, help='max epochs')
parser.add_argument('--niterbatch', type=int, default=100, help='minibatch iterations per epoch')

# Network size; both values are forwarded to ODEfuncRegGC.
parser.add_argument('--nlayer', type=int, default=4, help='number of layers in ODEfuncRegGC')
parser.add_argument('--nunit', type=int, default=100, help='number of units per layer in ODEfuncRegGC')

# Minibatch shape: nMB trajectories, each rolled out for lMB grid steps.
parser.add_argument('--lMB', type=int, default=100, help='length of seq in each MB')
parser.add_argument('--nMB', type=int, default=40, help='number of sequences in each MB')

parser.add_argument('--drop_rate', type=float, default=.0, help='drop rate')

# Group-lasso strength and the column-norm threshold below which inputs are pruned.
parser.add_argument('--penalty_weight', type=float, default=.01, help='penalty_weight')
parser.add_argument('--prune_thresh', type=float, default=.01, help='pruning threshold')

# Delayed-state layout: ndelay history copies spaced `lag` grid steps apart, plus
# `aug` extra (unobserved) state channels.
parser.add_argument('--ndelay', type=int, default=0, help='number of delayed history states')
parser.add_argument('--lag', type=int, default=10, help='grid steps between consecutive history states')
parser.add_argument('--aug', type=int, default=0, help='number of augmented state dimensions')

parser.add_argument('--odeint', type=str, default='rk4', help='integrator')

args = parser.parse_args()

# Run everything in double precision, on the first GPU when one is available.
torch.set_default_dtype(torch.float64)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the script draws from so runs with the same --r are repeatable.
seed = args.r
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
	_seed_fn(seed)

# Each run gets an ID from SystemRandom (OS entropy, so the seeds above do not
# affect it); the ID names both the checkpoint file and the figure directory.
save_path = 'experiments/'
utils.makedirs(save_path)
experimentID = int(SystemRandom().random() * 100000)
run_name = "experiment_" + str(experimentID)
ckpt_path = os.path.join(save_path, run_name + '.ckpt')
fig_save_path = os.path.join(save_path, run_name)
utils.makedirs(fig_save_path)
print(ckpt_path)

# Load all trajectories and split them with the stored train/test index files.
path = '../data/_sjdf'
data_all = np.load(os.path.join(path, 'sjdf.npy'))
train_index = np.loadtxt(os.path.join(path, 'sjdf_train_index.txt')).astype(int)
test_index = np.loadtxt(os.path.join(path, 'sjdf_test_index.txt')).astype(int)
data_train = data_all[train_index, :, :]
data_test = data_all[test_index, :, :]

# Delayed-state layout: the model state holds the current value plus `n` copies
# taken a_h_ref, 2*a_h_ref, ..., n*a_h_ref grid steps in the past.
a_h_ref = args.lag
n = args.ndelay
hist_idx = a_h_ref * torch.arange(1, n + 1)  # history offsets in grid steps, shape (n,)

# Uniform float64 time grid on [0, 1]; each odeint call integrates over t[:2].
t = torch.tensor(np.linspace(0., 1., 256, endpoint=True, dtype=np.float64))

# After the transpose, axis 1 is indexed as time by the rest of the script
# (e.g. train_data[r, s, :], and test_data[k, :, i] plotted against t).
train_data = torch.tensor(data_train).to(device).transpose(1, 2)
test_data = torch.tensor(data_test).to(device).transpose(1, 2)

# Fail early with a clear message instead of a later broadcast/plot error
# when the data were sampled on a different grid than `t`.
if train_data.shape[1] != len(t) or test_data.shape[1] != len(t):
	raise ValueError('expected {} time steps along axis 1, got train={} test={}'.format(
		len(t), train_data.shape[1], test_data.shape[1]))

# Model input: (3 observed + aug augmented channels) for the current state and
# each of the n history slots, flattened.
aug = args.aug
odefunc = ODEfuncRegGC((3 + aug) * (n + 1), args.nlayer, args.nunit, drop_rate=args.drop_rate, aug_dim=aug).to(device)

params = odefunc.parameters()
optimizer = optim.Adamax(params, lr=args.lr)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9987)  # stepped once per epoch below

best_loss = 1e30  # best validation loss seen so far
frame = 0         # index of the next batch of saved figures

# Training start times are drawn from time_offset+1 onward. History slots reach
# back n*lag steps, so larger values would index negative times; negative
# indices silently wrap to the end of the series.
time_offset = 60
if n * a_h_ref > time_offset:
	raise ValueError('ndelay * lag = {} exceeds time_offset = {}: history slots would index before t=0'.format(
		n * a_h_ref, time_offset))

# One mask per gradient network's first-layer weight. Pruned input columns are
# set to 0 here, and the training loop multiplies the weights by these masks
# after every optimizer step, so pruned columns stay zero.
causality_mask = [
	torch.ones_like(odefunc.gradient_nets[k][0].weight, dtype=torch.float64)
	for k in range(3 + aug)
]

# Main loop. Each epoch runs args.niterbatch minibatch updates (multi-step
# rollout loss + group-lasso penalty + pruning), then evaluates a
# free-running rollout on the test set, checkpoints on improvement, and saves plots.
for itr in range(args.nepoch):
	print('=={0:d}=='.format(itr))
	for i in range(args.niterbatch):
		optimizer.zero_grad()
		# nMB distinct trajectories and, independently, nMB distinct start times.
		# Start times begin after time_offset (so history indices stay >= 0) and end
		# early enough that the last target index cur + 1 is still on the grid.
		r = torch.from_numpy(np.random.choice(np.arange(len(train_data), dtype=np.int64), args.nMB, replace=False))
		s = torch.from_numpy(np.random.choice(np.arange(time_offset + 1, len(t) - args.lMB, dtype=np.int64), args.nMB, replace=False))
		cur = s.clone()
		batch_t = t[:2]  # each odeint call advances exactly one grid step
		pred = []
		loss = 0
		for j in range(args.lMB):
			# (batch, 1 current + n history slots, 3 observed + aug channels)
			batch_y0 = torch.zeros((args.nMB, n + 1, 3 + aug), device=device)
			if j == 0:
				batch_y0[:, 0, :3] = train_data[r, s, :]
			else:
				# Feed back the previous prediction (graph kept) for a multi-step rollout loss.
				batch_y0[:, 0, :] = pred[-1]
			# History slots come from ground truth at cur - lag*m, m = 1..n; gathered in one
			# advanced-indexing call: (nMB, 1) x (nMB, n) -> (nMB, n, 3).
			hist_indices = cur.unsqueeze(-1) - hist_idx.unsqueeze(0)
			batch_y0[:, 1:, :3] = train_data[r.unsqueeze(-1), hist_indices, :3]
			target = train_data[r, cur + 1, :]  # observation one grid step ahead
			pred_y = odeint(odefunc, batch_y0.reshape((args.nMB, -1)), batch_t, method=args.odeint).transpose(0, 1)
			pred_y = pred_y.reshape((args.nMB, 2, n + 1, 3 + aug))
			pred.append(pred_y[:, -1, 0, :])  # current-state slot at the end of the step

			loss += torch.mean(torch.abs(pred_y[:, -1, 0, :3] - target))
			cur = cur + 1
		# Group-lasso penalty: sum of L2 norms of the first-layer input columns, pushing
		# whole inputs (one state channel at one delay) towards zero together.
		for net_k in range(3 + aug):
			first_w = odefunc.gradient_nets[net_k][0].weight
			loss += args.penalty_weight * torch.sum(torch.norm(first_w, p=2, dim=0))
		print(itr, i, loss.item())
		loss.backward()
		optimizer.step()

		# Prune: permanently zero every input column whose norm fell to prune_thresh or below.
		with torch.no_grad():
			for net_k in range(3 + aug):
				first_w = odefunc.gradient_nets[net_k][0].weight
				prune = (torch.norm(first_w, p=2, dim=0) <= args.prune_thresh).reshape(n + 1, 3 + aug)
				prune[0, net_k] = False  # never prune a variable's own current value
				causality_mask[net_k][:, prune.reshape(-1)] = 0.0
				first_w.mul_(causality_mask[net_k])

	scheduler.step()

	with torch.no_grad():
		# Diagnostic: norm of every first-layer input column of each gradient network.
		for net_k in range(3 + aug):
			print(net_k)
			for col in range((n + 1) * (3 + aug)):
				print(torch.norm(odefunc.gradient_nets[net_k][0].weight[:, col], p=2))
		odefunc.eval()

		# Free-running rollout: start from the observation at time_offset_test and
		# predict every later grid point from the model's own outputs.
		time_offset_test = 60
		n_test = min(50, len(test_data))  # number of test trajectories evaluated
		n_steps = len(t) - (time_offset_test + 1)
		batch_t = t[:2]
		slots = torch.arange(1, n + 1)  # positions of the history slots on batch_y0's axis 1
		# Allocated on `device`: it is mixed with test_data in the loss and plots below.
		pred = torch.zeros((n_test, n_steps, 3 + aug), device=device)
		cur = time_offset_test
		for j in range(n_steps):
			batch_y0 = torch.zeros((n_test, n + 1, 3 + aug), device=device)
			if j == 0:
				batch_y0[:, 0, :3] = test_data[:n_test, time_offset_test, :3]
			else:
				batch_y0[:, 0, :] = pred[:, j - 1, :]
			# History at or before time_offset_test is observed (augmented channels stay 0,
			# matching the initial state). Later history is taken from the rollout, where
			# pred[:, idx] holds time idx + time_offset_test + 1.
			hist_t = cur - hist_idx
			observed = hist_t <= time_offset_test
			batch_y0[:, slots[observed], :3] = test_data[:n_test, hist_t[observed], :3]
			batch_y0[:, slots[~observed], :] = pred[:, hist_t[~observed] - (time_offset_test + 1), :]
			pred_y = odeint(odefunc, batch_y0.reshape((n_test, -1)), batch_t, method=args.odeint)
			pred_y = pred_y.reshape((2, n_test, n + 1, 3 + aug))
			pred[:, j, :] = pred_y[-1, :, 0, :]
			cur = cur + 1

		val_loss = torch.mean(torch.abs(test_data[:n_test, (time_offset_test + 1):, :3] - pred[:, :, :3])).item()
		print('val loss', val_loss)

		if val_loss < best_loss:
			print('saving...', val_loss)
			torch.save({'state_dict': odefunc.state_dict(),}, ckpt_path)
			best_loss = val_loss

		# Prepend the observed initial condition so the dashed curve starts at t[time_offset_test].
		pred_to_plot = torch.cat(
			(test_data[:n_test, time_offset_test:time_offset_test + 1, :3], pred[:, :, :3]), dim=1
		).cpu().numpy()
		truth = test_data[:n_test].cpu().numpy()
		t_np = t.cpu().numpy()
		for k in range(min(10, n_test)):
			save_file = os.path.join(fig_save_path, "image_{:03d}_{:02d}.png".format(frame, k))
			fig, axes = plt.subplots(3, 1, figsize=(4, 4))
			for c, ax in enumerate(axes):
				ax.plot(t_np, truth[k, :, c], lw=2, color='k')
				ax.plot(t_np[time_offset_test:], pred_to_plot[k, :, c], lw=2, color='c', ls='--')
			fig.tight_layout()
			fig.savefig(save_file)
			plt.close(fig)
		frame += 1
		odefunc.train()
