import numpy as np
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from srt_online import srt_online

def Lorenz(x0, y0, z0, p, q, r, T, h=0.01):
    """Integrate the Lorenz system with the explicit (forward) Euler method.

    The system integrated is::

        dx/dt = p * (y - x)
        dy/dt = q * x - y - x * z
        dz/dt = x * y - r * z

    so p, q and r play the roles usually called sigma, rho and beta.

    Args:
        x0, y0, z0: Initial state.
        p, q, r: System parameters.
        T: Number of Euler steps to take (int).
        h: Integration step size. Defaults to 0.01, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A tuple of three lists ``(x, y, z)``, each of length ``max(T, 0)``,
        holding the state after every step. The initial state itself is
        not included.
    """
    x, y, z = [], [], []
    for _ in range(T):
        # All three increments are computed from the *old* state before any
        # component is overwritten, so the components advance in lockstep.
        x0, y0, z0 = (
            x0 + h * p * (y0 - x0),
            y0 + h * (q * x0 - y0 - x0 * z0),
            z0 + h * (x0 * y0 - r * z0),
        )
        x.append(x0)
        y.append(y0)
        z.append(z0)
    return x, y, z
# Lorenz system parameters, in the order Lorenz() expects them
# (p multiplies (y - x), q multiplies x in dy/dt, r multiplies z in dz/dt).
p, q, r = 10, 28, 8 / 3

# Number of Euler steps to simulate.
T = 5000

# Initial state of the trajectory.
x0, y0, z0 = -11, -2, 13

# Generate the trajectory; the regression experiment below only uses y.
x, y, z = Lorenz(x0, y0, z0, p, q, r, T)

# To eyeball the generated series:
#   plt.scatter(np.arange(T), y, s=1); plt.show()

# Hyper-parameters forwarded to srt_online; their meaning is defined by that
# module, which is not visible here.
Working_para = np.array([5e+5, 1e-10, 5e-2, 3])
erro = 1e-3

srt = srt_online(WP=Working_para, erro=erro, max_count=400)

# Lag embedding of the y series: each sample uses the features
# [y[i], y[i+5], y[i+10], y[i+15]] to predict the target y[i+20], i.e. the
# next value of the same 5-step lag sequence.
LAG_STEP = 5                  # spacing between consecutive features
N_LAGS = 4                    # number of features per sample
HORIZON = LAG_STEP * N_LAGS   # target offset from the first feature (20)

# Samples with i in [0, TRAIN_END) train the model online; samples with
# i in [TRAIN_END, TEST_END) are only predicted. TEST_END + HORIZON must not
# exceed len(y).
TRAIN_END = 2000
TEST_END = 4900


def _lag_features(series, i):
    """Return N_LAGS values of *series* starting at index i, LAG_STEP apart."""
    return [series[i + k * LAG_STEP] for k in range(N_LAGS)]


# Online training: feed one (features, target) pair at a time.
for i in range(TRAIN_END):
    srt.update(_lag_features(y, i), y[i + HORIZON])

# Evaluation range: no update() calls are made here, only predict().
Y_ac = y[TRAIN_END + HORIZON:TEST_END + HORIZON]
Y_pre = [srt.predict(_lag_features(y, i)) for i in range(TRAIN_END, TEST_END)]
# x-axis is the index i of the first feature; the plotted target is y[t + HORIZON].
t = np.arange(TRAIN_END, TEST_END)

# Flatten defensively. If srt.predict returns a length-1 array rather than a
# scalar (its return type is not visible here), np.array(Y_pre) would have
# shape (N, 1). Subtracting an (N,) array from that would then silently
# broadcast to an (N, N) matrix and corrupt the error metrics.
pred = np.ravel(Y_pre)
actual = np.asarray(Y_ac, dtype=float)

plt.plot(t, actual, label='Real output', linewidth=2, linestyle="dashdot")
plt.plot(t, pred, label='OSRT output', linewidth=2, linestyle="dashdot")
plt.xlabel("Prediction time")
plt.ylabel("Output value")
plt.legend()
plt.show()

res = pred - actual
# The label is required: legend() on unlabelled artists only emits a warning.
plt.plot(t, res, label='Prediction error', linewidth=2, linestyle="dashdot")
plt.xlabel("Prediction time")
plt.ylabel("Errors")
plt.legend()
plt.show()

# Root-mean-square error, then mean absolute error, over the evaluation range.
print(math.sqrt(np.mean(res**2)))
print(np.mean(np.abs(res)))