import matplotlib.pyplot as plt
from model.model import ViTime
import numpy as np
import torch
import numpy as np
from scipy.interpolate import interp1d
from scipy import interpolate

def interpolate_to_512(original_sequence, target_length=512):
    """Linearly resample a sequence onto ``target_length`` evenly spaced points.

    Input and output samples are both placed on evenly spaced grids over
    [0, 1], so the first and last input values are preserved.

    Args:
        original_sequence: Sequence of numeric samples (list or array).
        target_length: Number of samples to return. Defaults to 512, which
            matches the maximum model input length noted in this script.

    Returns:
        numpy.ndarray holding ``target_length`` linearly interpolated samples.

    Raises:
        ValueError: If ``original_sequence`` is empty.
    """
    n = len(original_sequence)
    if n == 0:
        raise ValueError("original_sequence must contain at least one sample")

    samples = np.asarray(original_sequence)
    if n == 1 and samples.ndim == 1:
        # A single sample defines no slope; repeat it instead of relying on
        # how the installed SciPy handles a one-point interpolation grid.
        return np.full(target_length, float(samples[0]))

    x_original = np.linspace(0, 1, n)
    x_interpolated = np.linspace(0, 1, target_length)
    f = interpolate.interp1d(x_original, original_sequence)
    return f(x_interpolated)
def inverse_interpolate(processed_sequence, original_length):
    """Linearly resample ``processed_sequence`` relative to an original length.

    The returned array has ``int(original_length * 720 / 512)`` samples. The
    720 / 512 factor equals the ratio of the maximum prediction length to the
    maximum input length noted for the model in this script.

    Args:
        processed_sequence: Sequence of numeric samples to resample.
        original_length: Length of the sequence before it was resampled for
            the model.

    Returns:
        numpy.ndarray with the linearly interpolated samples.
    """
    source_size = len(processed_sequence)
    target_size = int(original_length * 720 / 512)
    # Source and target grids both span [0, 1], so the end points line up.
    source_grid = np.linspace(0, 1, source_size)
    target_grid = np.linspace(0, 1, target_size)
    resample = interpolate.interp1d(source_grid, processed_sequence)
    return resample(target_grid)

# Location of the pretrained ViTime checkpoint; adjust for the local machine.
modelPath = r'C:\Users\user\Downloads\ViTime_V1.pth'
deviceNum = 0
# Only touch the CUDA runtime when a GPU is actually usable:
# torch.cuda.set_device() raises on CPU-only machines, which would make the
# CPU fallback unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(deviceNum)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# NOTE: torch.load unpickles the file, and this checkpoint carries an 'args'
# object next to the weights, so only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# may reject the 'args' object -- confirm against the installed version.
checkpoint = torch.load(modelPath, map_location=device)  ##V4 better
# Reuse the arguments stored in the checkpoint, overriding the runtime fields.
args = checkpoint['args']
args.device = device
args.flag = 'test'
##### args.upscal=True   max input length =512    max prediction length =720
##### args.upscal=False  max input length =512*2  max prediction length =720*2
args.upscal = True
model = ViTime(args=args)
model.load_state_dict(checkpoint['model'])
model.to(device)
# Put the model in evaluation mode before running inference.
model.eval()


inputlength = 512
# Synthetic test signal: the sum of four sinusoids with different
# frequencies, phases and amplitudes, sampled at integer steps.
steps = np.arange(inputlength)
xData = (
    np.sin(steps / 10)
    + np.sin(steps / 5 + 50)
    + np.cos(steps * 5 + 50)
    + np.cos(steps + 50) * 2
)
# Alternative, simpler test signal:
# xData=np.sin(np.arange(inputlength)/40)


# Keep an untouched copy of the input for plotting.
xData1 = xData.copy()
# Resample the input for the model and record the resampled length on args.
interpolated_sequence = interpolate_to_512(xData)
args.realInputLength = len(interpolated_sequence)
# Run the model, then resample its flattened output relative to len(xData).
yp = inverse_interpolate(
    model.inference(interpolated_sequence).flatten(),
    len(xData),
).flatten()


# Plot the input followed by the prediction, with the raw input overlaid.
plt.plot(np.concatenate((xData, yp.flatten()), axis=0), label='Prediction')
plt.plot(xData1, label='Input Sequence')
plt.legend()
plt.show()

