import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import random
from random import SystemRandom

import matplotlib.pyplot as plt
import matplotlib.cm as cm

import lib.utils as utils
from lib.odefunc import DDEfuncGC
from torchdiffeq import odeint_adjoint as odeint
import torch.nn.utils.prune as prune
from lib.prune import ThresholdPruning, GroupThresholdPruning, group_unstructured
import argparse
parser = argparse.ArgumentParser(description='Evaluate trained DDEfuncGC checkpoints and plot delay weights and predictions.')
parser.add_argument('--r', type=int, default=0, help='random_seed')

# NOTE: --lr, --nepoch, --niterbatch, --lMB, --nMB and --id are not read by this
# script; they are accepted so the same command line as training can be reused.
parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
parser.add_argument('--nepoch', type=int, default=2000, help='max epochs')
parser.add_argument('--niterbatch', type=int, default=100, help='number of iterations per batch')

# Architecture of the DDEfuncGC network (must match the evaluated checkpoints).
parser.add_argument('--nlayer', type=int, default=4, help='number of layers of DDEfuncGC')
parser.add_argument('--nunit', type=int, default=25, help='number of units per layer of DDEfuncGC')

parser.add_argument('--lMB', type=int, default=100, help='length of seq in each MB')
parser.add_argument('--nMB', type=int, default=40, help='number of MBs')

parser.add_argument('--odeint', type=str, default='rk4', help='integrator')
parser.add_argument('--id', type=int, default=0, help='exp id')

args = parser.parse_args()

# Tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Seed torch, numpy and Python's random module (in that order) with the same value.
seed = args.r
for seeder in (torch.manual_seed, np.random.seed, random.seed):
    seeder(seed)

# Directory that holds the checkpoints and receives the output figures.
save_path = 'experiments/'
utils.makedirs(save_path)


# Mackey-Glass data generated with delay tau = 5.
data = np.load("../data/mackey_tau_5_.npz")
h_ref = 1e-2  # step between consecutive grid points
Time = 30.72*5  # end time of the grid [0, Time]
tau = 5  # ground truth tau (kept for reference; not used below)
a = 1  # adjustable parameter (not used below)
a_h_ref = 100  # grid steps between consecutive candidate delays (100 * h_ref = 1 time unit)
# n candidate delayed variables x(t-1), x(t-2), ..., x(t-10); together with the
# current state x(t) the model receives n+1 inputs.
n = 10
hist_idx = a_h_ref * torch.arange(1, n+1)  # grid offsets of the delays: 100, 200, ..., 1000
# round(), not floor(): float error in Time/h_ref must not silently drop the last grid point.
N_steps = int(round(Time/h_ref)) + 1
t = torch.tensor(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64))

# (1 traj, steps, 1 variable) -- assumes data['train_data'] is a 1-D series; TODO confirm.
train_data = torch.tensor(data['train_data']).to(device).unsqueeze(0).unsqueeze(-1)
# The rollout below reads the ground truth at every grid index 0..N_steps-1.
if train_data.shape[1] < N_steps:
    raise ValueError('train_data has {} steps but the time grid needs {}'.format(train_data.shape[1], N_steps))
odefunc = DDEfuncGC(n+1, args.nlayer, args.nunit).to(device)

# IDs of the trained models whose checkpoints are evaluated.
IDs = np.asarray([58396, 75355])

# Per model: L2 norm of columns 1..n of the first layer weight of gradient_net
# (column 0 is skipped) -- presumably the weights applied to the n delayed inputs.
cw_norms = np.zeros((len(IDs), n))
# Rollout predictions at times t[1:], one column per model.
preds = torch.zeros((len(t)-1, len(IDs)))

hist_offsets = hist_idx.to(device)
batch_t = t[:2].to(device)  # integrate exactly one grid step, t[0] -> t[1]

for i, experimentID in enumerate(IDs):
    ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID) + '.ckpt')
    print(ckpt_path)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    ckpt = torch.load(ckpt_path, map_location=device)
    odefunc.load_state_dict(ckpt['state_dict'])

    first_w = odefunc.gradient_net[0].weight
    cw_norms[i] = torch.norm(first_w[:, 1:n+1], p=2, dim=0).detach().cpu().numpy()

    # Free-running rollout: each step starts from the model's own previous output.
    # traj[m] is the state at grid index m: the ground-truth initial condition for
    # m == 0 and the model's prediction for every m >= 1.
    traj = torch.zeros((len(t), 1), device=device)
    traj[0] = train_data[0, 0, :]
    val_loss = torch.zeros((), device=device)
    # Inference only: without no_grad every step would extend one autograd graph
    # spanning the whole rollout.
    with torch.no_grad():
        for cur in range(len(t)-1):
            delayed = cur - hist_offsets  # grid indices of the n delayed states
            history = traj[delayed.clamp(min=0), 0]
            # Delayed states before the start of the series are the constant 0.5.
            history[delayed < 0] = .5
            batch_y0 = torch.cat([traj[cur], history])  # (n+1,): current state, then delayed states
            pred_y = odeint(odefunc, batch_y0, batch_t, method=args.odeint)
            traj[cur+1] = pred_y[-1, :1]
            # Compare against the ground truth at the predicted time t[cur+1].
            val_loss += torch.mean(torch.abs(pred_y[-1, :1] - train_data[0, cur+1, :]))
    preds[:, i] = traj[1:, 0].cpu()
    print('val loss', val_loss.item())

# Sum the per-delay weight norms over all models and scale so the largest is 1.
cw_norm = np.sum(cw_norms, axis=0)
cw_norm = cw_norm / np.max(cw_norm)
cw_norm = np.expand_dims(cw_norm, 0)  # (1, n_delays) row for imshow
print(cw_norm)
n_delays = cw_norm.shape[1]
x_ticks = [r'$\tau_{{{:d}}}$'.format(k) for k in range(1, n_delays+1)]
y_ticks = [r'$\dot x_1$']

print(x_ticks)
save_file = os.path.join(save_path, "heat_map.png")
fig, ax = plt.subplots(figsize=(2, .75))
# Gray colormap on 1 - norm: delays with larger weight norm are drawn darker.
ax.imshow(1-cw_norm, cmap=cm.gray)
ax.spines[:].set_visible(False)
ax.grid(which="minor", color="grey", linestyle='-', linewidth=3)
ax.set_xticks(np.arange(n_delays))
ax.set_yticks(np.arange(1))
ax.set_xticklabels(x_ticks)
ax.set_yticklabels(y_ticks)
# Minor ticks on the cell boundaries carry the grid lines separating the cells.
ax.set_xticks(np.arange(n_delays+1)-.5, minor=True)
ax.set_yticks(np.arange(1+1)-.5, minor=True)
ax.tick_params(which="minor", bottom=False, left=False)
plt.setp(ax.get_xticklabels(), rotation=60, ha="right", rotation_mode="anchor")
fig.tight_layout()
fig.savefig(save_file)
plt.close(fig)
# Ensemble statistics across the evaluated models (one column of preds per model).
preds = preds.cpu().detach().numpy()
mean_preds = np.mean(preds, axis=1)
std_preds = np.std(preds, axis=1)
print(mean_preds.shape)

# preds[j] is the state at t[j+1], so plot against t[1:] next to the ground truth
# at the same grid indices (truncated to the grid in case the data is longer).
n_pred = len(mean_preds)
t_pred = t[1:n_pred+1].cpu().numpy()
gt = train_data[0, 1:n_pred+1, 0].cpu().detach().numpy()

save_file = os.path.join(save_path, "pred_gt.png")
fig, ax = plt.subplots(figsize=(4, 2))
ax.plot(t_pred, gt, lw=3, color='deepskyblue')
ax.plot(t_pred, mean_preds, lw=2, color='r', ls='--')
# +/- 2 std band; translucent so the dashed mean line of the same colour stays visible.
ax.fill_between(t_pred, mean_preds - 2.*std_preds, mean_preds + 2.*std_preds, alpha=.3, color='r')
ax.set_xlabel('$t$ (sec)')
fig.tight_layout()
fig.savefig(save_file)
plt.close(fig)
	
