# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import os, sys

# Put the parent of this script's directory on the import path so the
# project-local `lib` package imported below can be resolved.
PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
sys.path.append(PARENT_DIR)

import random
from random import SystemRandom

import matplotlib.pyplot as plt

import lib.utils as utils
from lib.odefunc import DDEfuncGC
from torchdiffeq import odeint_adjoint as odeint
import torch.nn.utils.prune as prune
from lib.prune import ThresholdPruning, GroupThresholdPruning, group_unstructured
import argparse
# Command-line configuration for training. Parsed once at import time.
parser = argparse.ArgumentParser(
    description='Train DDEfuncGC on Mackey-Glass data with group-sparse '
                'selection of delayed inputs.')
parser.add_argument('--r', type=int, default=0, help='random seed')

parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
parser.add_argument('--nepoch', type=int, default=2000, help='max epochs')
parser.add_argument('--niterbatch', type=int, default=100,
                    help='minibatch iterations per epoch')

parser.add_argument('--nlayer', type=int, default=4,
                    help='number of layers (passed to DDEfuncGC)')
parser.add_argument('--nunit', type=int, default=100,
                    help='hidden units per layer (passed to DDEfuncGC)')

parser.add_argument('--lMB', type=int, default=100, help='length of seq in each MB')
parser.add_argument('--nMB', type=int, default=40,
                    help='number of sequences (sub-trajectories) in each MB')

parser.add_argument('--odeint', type=str, default='rk4',
                    help='integrator (torchdiffeq solver method)')

args = parser.parse_args()

# Everything created below (model weights, buffers) uses float64.
torch.set_default_dtype(torch.float64)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed every RNG this script draws from (torch, NumPy, Python's `random`).
seed = args.r
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# A random experiment ID names both the checkpoint file and the figure
# directory. SystemRandom draws from the OS, so the ID ignores the seed above.
save_path = 'experiments/'
utils.makedirs(save_path)
experimentID = int(SystemRandom().random() * 100000)
ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID) + '.ckpt')
fig_save_path = os.path.join(save_path, "experiment_" + str(experimentID))
utils.makedirs(fig_save_path)
print(ckpt_path)

# Data and time grid. The data path is relative to the current working directory.
data = np.load("../data/mackey_tau_5_.npz")
h_ref = 1e-2       # time between consecutive samples
Time = 30.72 * 5   # total time span of the series
tau = 5            # ground-truth delay (not used below)
a = 1              # adjustable parameter (not used below)
a_h_ref = 100      # steps per candidate delay (= a / h_ref with the values above)
# Model inputs: the current state x(t) plus n delayed candidates
# x(t-1), x(t-2), ..., x(t-n) in time units, i.e. n + 1 inputs in total.
n = 10
hist_idx = a_h_ref * torch.arange(1, n + 1)  # delays in steps: 100, 200, ..., 1000
N_steps = int(np.floor(Time / h_ref)) + 1
t = torch.tensor(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64))  # (N_steps,)

# NOTE(review): assumes data['train_data'] is a 1-D series, giving (1, steps, 1) here.
train_data = torch.tensor(data['train_data']).to(device).unsqueeze(0).unsqueeze(-1)
# Training, validation and plotting all index train_data on the same grid as t;
# fail now rather than after a full epoch if they disagree.
if train_data.shape[1] != len(t):
    raise ValueError(
        'train_data has {} time steps but the time grid t has {}; check Time and '
        'h_ref against the data file.'.format(train_data.shape[1], len(t)))

odefunc = DDEfuncGC(n + 1, args.nlayer, args.nunit).to(device)

params = odefunc.parameters()
optimizer = optim.Adamax(params, lr=args.lr)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9987)

best_loss = 1e30
frame = 0

# Per-entry mask for the first layer's weight: columns set to 0 are inputs
# that have been pruned. ones_like already places it on the weight's device.
causality_mask = torch.ones_like(odefunc.gradient_net[0].weight, dtype=torch.float64)
# Not used below. It is kept because the draw advances NumPy's global RNG, and
# removing it would change the training minibatches produced for a given seed.
s_val = torch.from_numpy(np.random.choice(np.arange(len(t) - args.lMB, dtype=np.int64), args.nMB, replace=False))

def _delayed_states(series, indices, fill=0.5):
    """Look up delayed states in a trajectory, padding times before its start.

    Args:
        series: (L, D) tensor of states indexed by time step.
        indices: integer tensor of time-step indices into ``series`` (any shape,
            any device). Negative entries denote times before the trajectory starts.
        fill: value returned for those pre-start entries.

    Returns:
        Tensor of shape ``indices.shape + (D,)`` on ``series.device``. Negative
        indices are clamped before the lookup, so they never wrap around to the
        end of ``series``. Their results are then replaced by ``fill``.
    """
    indices = indices.to(series.device)
    gathered = series[indices.clamp(min=0)]
    valid = (indices >= 0).unsqueeze(-1)
    return torch.where(valid, gathered, torch.full_like(gathered, fill))


# Training: each iteration rolls out nMB random sub-trajectories of length lMB,
# one integration step at a time. The loss is the L1 one-step error plus a
# group-lasso penalty on the first layer's input columns. Columns whose norm
# collapses are pruned permanently. After each epoch, a free-running rollout
# over the whole series gives the validation loss, a checkpoint and a figure.
for itr in range(args.nepoch):
    print('=={0:d}=='.format(itr))
    for i in range(args.niterbatch):
        optimizer.zero_grad()

        # Distinct random start steps for the sub-trajectories.
        s = torch.from_numpy(np.random.choice(np.arange(len(t) - args.lMB, dtype=np.int64), args.nMB, replace=False))
        cur = s.clone()
        batch_t = t[:2]  # integrate across a single step of length h_ref
        pred = []
        loss = 0
        for j in range(args.lMB):
            # Current state: ground truth at the first step, afterwards the
            # model's own previous prediction (gradients flow through the rollout).
            state = train_data[0, s, :] if j == 0 else pred[-1]
            # Delayed inputs are always read from the ground-truth series.
            hist = _delayed_states(train_data[0], cur.unsqueeze(-1) - hist_idx.unsqueeze(0))
            # Filling a default-dtype buffer casts the data to the model's dtype.
            batch_y0 = torch.empty((args.nMB, n + 1, 1), device=device)
            batch_y0[:, 0, :] = state
            batch_y0[:, 1:, :] = hist
            # squeeze(-1), not squeeze(): keep the batch axis even when nMB == 1.
            batch_y0 = batch_y0.squeeze(-1)  # (nMB, n + 1)
            pred_y = odeint(odefunc, batch_y0, batch_t, method=args.odeint)  # (2, nMB, n + 1)
            next_state = pred_y[-1, :, :1]  # single variable case
            pred.append(next_state)

            loss += torch.mean(torch.abs(next_state - train_data[0, cur + 1, :1]))
            cur = cur + 1
        # Group-lasso penalty: sum of the L2 norms of the first layer's input columns.
        loss += 1e-1 * torch.sum(torch.norm(odefunc.gradient_net[0].weight, p=2, dim=0))
        print(itr, i, loss.item())
        loss.backward()
        optimizer.step()

        # Permanently zero the input columns whose group norm has collapsed.
        with torch.no_grad():
            first_w = odefunc.gradient_net[0].weight
            col_norm = torch.norm(first_w, p=2, dim=0)
            prune_cols = col_norm <= 1e-3
            prune_cols[0] = False  # never prune column 0, the current state x(t)
            causality_mask[:, prune_cols] = 0.0
            first_w.mul_(causality_mask)

    scheduler.step()
    with torch.no_grad():
        # Norm of each first-layer input column (column 0 = x(t)).
        for l in range(n + 1):
            print(torch.norm(odefunc.gradient_net[0].weight[:, l], p=2))

        # Free-running rollout from the true initial state. traj[k] is the
        # state at step k: traj[0] is ground truth, later entries are predictions.
        batch_t = t[:2]
        traj = torch.zeros((len(t), 1), device=device)
        traj[0] = train_data[0, 0, :]
        val_loss = 0
        for j in range(len(t) - 1):
            # A delay that lands exactly on step 0 must read the true initial
            # state, not an unfilled slot at the end of the buffer.
            hist = _delayed_states(traj, j - hist_idx)  # (n, 1)
            y0 = torch.cat([traj[j], hist[:, 0]])  # (n + 1,)
            next_state = odeint(odefunc, y0, batch_t, method=args.odeint)[-1, :1]
            traj[j + 1] = next_state
            # Score the prediction for step j + 1 against the truth at step j + 1.
            val_loss += torch.mean(torch.abs(next_state - train_data[0, j + 1, :]))
        pred = traj[1:]  # predictions for steps 1 .. len(t) - 1
        val_loss = float(val_loss)
        print('val loss', val_loss)

        if best_loss > val_loss:
            print('saving...', val_loss)
            torch.save({'state_dict': odefunc.state_dict(),}, ckpt_path)
            best_loss = val_loss

        # Ground truth vs. rollout. Both series are samples at times t[1:].
        save_file = os.path.join(fig_save_path, "image_{:03d}.png".format(frame))
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.plot(t[1:].numpy(), train_data[0, 1:, 0].cpu().numpy(), lw=2, color='k')
        ax.plot(t[1:].numpy(), pred[:, 0].cpu().numpy(), lw=2, color='c', ls='--')
        fig.tight_layout()
        fig.savefig(save_file)
        plt.close(fig)
        frame += 1
