import math
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['Simhei']
plt.rcParams['axes.unicode_minus']=False
#绘制三维图像
import mpl_toolkits.mplot3d as p3d
from srt_online import srt_online

def Mackey_Glass(x, h, a, b, c, tao, T):
    """Integrate the Mackey-Glass delay system with the forward Euler method.

    Solves dx/dt = -b*x(t) + a*x(t - tao) / (1 + x(t - tao)**c) on a uniform
    grid. Other integration schemes could be substituted.

    Args:
        x: Mutable indexable sequence of length at least tao + T + 1.
           Entries x[0..tao] must already hold the initial history. They are
           read before anything is written. Entries x[tao+1..tao+T] are
           overwritten in place.
        h: Euler step size.
        a, b, c: Equation coefficients.
        tao: Delay expressed as a number of discrete steps. It is used directly
             as an index offset, so it must be an int.
        T: Number of Euler steps to take.

    Returns:
        The same object ``x``, after it has been filled in place.
    """
    # k is the discrete step index; the continuous time of x[k] is k*h
    # relative to the start of the stored history.
    for k in range(tao, tao + T):
        lagged = x[k - tao]
        slope = -b * x[k] + a * lagged / (1 + lagged ** c)
        x[k + 1] = x[k] + h * slope
    return x
# Mackey-Glass coefficients: dx/dt = -b*x(t) + a*x(t-tao) / (1 + x(t-tao)**c)
a = 0.2
b = 0.1
c = 10
# Delay, in continuous time units.
tao = 17
# Number of Euler iterations.
T = 5000
# The initial history is defined as a (here constant) function on [-tao, 0].
# Euler step size. tao should be an integer multiple of h; otherwise the
# delay cannot be represented exactly on the discrete grid.
h = 1
# Convert the delay into a number of discrete steps (an index offset).
# round() rather than int(): floating-point division can land just below the
# true integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would
# truncate that to one step too few.
tao = round(tao / h)
# The first tao+1 entries hold the initial history on [-tao, 0]; the last T
# entries are filled by the T Euler iterations. float64 is explicit so the
# iterates are not stored in an integer array.
x = np.zeros(T + tao + 1, dtype=np.float64)
x[:tao + 1] = 0.3
# Generate the delayed chaotic series (Mackey_Glass fills x in place and
# returns it).
x = Mackey_Glass(x, h, a, b, c, tao, T)

# Hyper-parameters passed to the online SRT model. The meaning of each entry
# is defined by srt_online, which is not visible in this file
# (TODO: document against srt_online).
Working_para = np.array([5e+5, 1e-10, 5e-2, 3])
erro = 1e-3

srt = srt_online(WP=Working_para, erro=erro, max_count=400)

# Set to True to evaluate the test-window RMSE after every 100 training
# samples and plot the resulting learning curve.
PLOT_LEARNING_CURVE = False

RMSE = []
train_iter = []
# Ground truth for the test window: the targets x[i+68] for i in 4000..4499.
Y_ac = x[4068:4568].copy()
for i in range(4000):
    # Features: four samples spaced 6 steps apart. Target: the value 50 steps
    # after the newest feature (index i+18 -> i+68).
    X = [x[i], x[i + 6], x[i + 12], x[i + 18]]
    Y = x[i + 68]
    srt.update(X, Y)
    if PLOT_LEARNING_CURVE and (i + 1) % 100 == 0:
        Y_pre = [srt.predict([x[j], x[j + 6], x[j + 12], x[j + 18]])
                 for j in range(4000, 4500)]
        res = np.array(Y_pre) - np.array(Y_ac)
        RMSE.append(math.sqrt(np.mean(res ** 2)))
        train_iter.append(i + 1)

if PLOT_LEARNING_CURVE:
    print(RMSE)
    plt.plot(train_iter, RMSE)
    plt.xlabel('train_iter')
    plt.ylabel('RMSE')
    plt.show()

def _predict_window(model, series, start, stop):
    """Predict series[i+68] for every i in [start, stop) with ``model``.

    Uses the same four lagged features as training:
    series[i], series[i+6], series[i+12], series[i+18].
    """
    preds = [model.predict([series[i], series[i + 6], series[i + 12], series[i + 18]])
             for i in range(start, stop)]
    # NOTE(review): ravel() guards against predict() returning 1-element
    # arrays. Without it, (n, 1) - (n,) would broadcast into an (n, n)
    # residual matrix. It is a no-op if predict() returns scalars.
    return np.asarray(preds).ravel()


def _error_metrics(residual):
    """Return (RMSE, MAE) of a residual vector."""
    return math.sqrt(np.mean(residual ** 2)), np.mean(np.abs(residual))


# In-sample errors over the training window (targets x[68:4068]).
Y_pre = _predict_window(srt, x, 0, 4000)
res = Y_pre - x[68:4068]
rmse, mae = _error_metrics(res)
print(rmse)
print(mae)

# Out-of-sample predictions for the test window; Y_ac holds x[4068:4568].
Y_pre = _predict_window(srt, x, 4000, 4500)
# Horizontal axis is the index of the oldest feature sample; the predicted
# value itself sits at index t + 68.
t = np.arange(4000, 4500)

plt.plot(t, Y_ac, label='Real output', linewidth=2, linestyle="dashdot")
plt.plot(t, Y_pre, label='OSRT output', linewidth=2, linestyle="dashdot")
plt.xlabel("Prediction time")
plt.ylabel("Output value")
plt.legend()
plt.show()

res = Y_pre - Y_ac
plt.plot(t, res, label='Prediction error', linewidth=2, linestyle="dashdot")
plt.xlabel("Prediction time")
plt.ylabel("Errors")
plt.legend()
plt.show()
rmse, mae = _error_metrics(res)
print(rmse)
print(mae)