

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import matplotlib.pyplot as plt
import seaborn as sns

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
from utilities4 import *
import matplotlib

# Global matplotlib font size used by every figure in this script.
font = {'size': 28}
matplotlib.rc('font', **font)

# Seed both RNGs so any random sampling in the analyses below is repeatable.
torch.manual_seed(0)
np.random.seed(0)

T = 800   # number of leading rows kept from each model's prediction
s = 2048  # row length used to flatten the ground-truth field

# Ground truth: keep every 10th index on axis 1 and every 4th on axis 2,
# then flatten all leading axes so each row holds s values.
# NOTE(review): axes are presumably (trajectory, time, x); the file name
# suggests an 8192-point grid, which ::4 reduces to s=2048 — confirm.
# Unlike the predictions, truth is NOT truncated to T rows.
dataloader = MatReader('data/KS_L64pi_s8192_T1000_t10000_test.mat')
truth = dataloader.read_field('u')[:, ::10, ::4].reshape(-1, s)

# Saved model rollouts, loaded in this order into pred0, pred1, pred2
# (Fourier / LSTM / GRU, going by the file names).
_pred_paths = (
    'pred/KS_fourier_res_N150000_s2048_T150_ep400_m32_w32.mat',
    'pred/KS_lstm_N150000_s2048_ep400_l2_w500.mat',
    'pred/KS_gru_N150000_s2048_ep400_l2_w500.mat',
)
_preds = []
for _path in _pred_paths:
    # `dataloader` is rebound each pass, so afterwards it refers to the last
    # (GRU) file, exactly as in the unrolled version.
    dataloader = MatReader(_path)
    _preds.append(dataloader.read_field('pred')[:T, ::1])
pred0, pred1, pred2 = _preds

# Sanity-check the loaded shapes: truth first, then each prediction.
for _arr in (truth, pred0, pred1, pred2):
    print(_arr.shape)

# T_in = 50
# T_step = [1, 2, 5, 10, 20, 50, 100]
# dataloader = MatReader('data/KS_L64pi_s8192_T1000_t10000_test.mat')
# truth = dataloader.read_field('u')[0, 10 * T_in:10 * (T + T_in):10, ::16]
# print(truth.shape)
#
# pred = []
# for i in T_step:
#     dataloader = MatReader('pred/KS_fourier_step_N10000_Tstep'+str(i)+'_T10_ep500_m32_w32.mat')
#     predt = dataloader.read_field('pred')[:9500//i]
#     print(i, predt.shape)
#     pred.append(predt)

# s_list = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
# pred = []
# for s in s_list:
#     dataloader = MatReader('pred/KS_fourier_res_N2000_s'+str(s)+'_T20_ep500_m32_w32.mat')
#     predt = dataloader.read_field('pred')[:950]
#     print(s, predt.shape)
#     pred.append(predt)




##############################################################
#Set-up MMD
##############################################################


# Ntest = T
# alpha_bases = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
# alphas = [1.0/(2.0 * a**2) for a in alpha_bases]
#
# #
# # mmd_err1 = mmd_loss(truth, pred1, alphas).item()
# # mmd_err2 = mmd_loss(truth, pred2, alphas).item()
# # mmd_err3 = mmd_loss(truth, pred3, alphas).item()
# # mmd_err4 = mmd_loss(truth, pred4, alphas).item()
# # print(mmd_err1)
# # print(mmd_err2)
# # print(mmd_err3)
# # print(mmd_err4)
#
# Ntest_truth = 950
# for i in range(8):
#     Ntest_pred = pred[i].shape[0]
#     mmd_loss = MMDStatistic(Ntest_truth, Ntest_pred)
#     mmd_err = mmd_loss(truth, pred[i], alphas).item()
#     print(mmd_err)

##############################################################
# ### Hist
##############################################################

# truth = truth.numpy()
# pred1 = pred1.numpy()
#
# plt.hist(truth.reshape(-1), bins=100)
# plt.show()
# plt.hist(pred1.reshape(-1), bins=100)
# plt.show()

##############################################################
### FFT plot
##############################################################

# kmax = 200
#
# def spectrum(u):
#     u = torch.fft.rfft(u)
#     # k = u.shape[1]//2 + 1
#     # u = u[:, :k]
#     print(u.shape)
#     u = torch.sqrt( u.abs()**2 )
#     u = u.mean(dim=0)
#     u = u
#
#     return u
#
# truth = spectrum(truth).numpy()[1:100]
# pred0 = spectrum(pred0).numpy()[1:100]
# pred1 = spectrum(pred1).numpy()[1:100]
# pred2 = spectrum(pred2).numpy()[1:100]
# # pred3 = spectrum(pred3).numpy()[1:]
# # pred4 = spectrum(pred4).numpy()[1:]
#
# print(truth.shape)
# print(pred0.shape)
# linewidth = 3
#
# fig, ax = plt.subplots(figsize=(10,10))
# ax.set_yscale('log')
# ax.plot(truth, 'k', label='truth', linewidth=linewidth)
# ax.plot(pred0, 'orangered', label='MNO', linewidth=linewidth)
# ax.plot(pred1, 'cornflowerblue', label='LSTM', linewidth=linewidth)
# ax.plot(pred2, 'mediumblue', label='GRU', linewidth=linewidth)
# # ax.plot(pred3, label='LSTM', linewidth=linewidth)
# # ax.plot(pred4, label='LSTM_TF', linewidth=linewidth)
#
# # for i in range(8):
# #     pred1 = spectrum(pred[i]).numpy()[1:kmax]
# #     print(pred1.shape)
# #     ax.plot(pred1, label=str(s_list[i]))
# plt.xlabel('wavenumber')
# plt.ylabel('energy')
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.grid()
# plt.show()

##############################################################
### FFT plot: time auto-correlation
##############################################################

# kmax = 200
#
# def spectrum(u):
#     u = torch.fft.rfft(u)
#     # k = u.shape[1]//2 + 1
#     # u = u[:, :k]
#     print(u.shape)
#     u = torch.sqrt( u.abs()**2 )
#
#     return u
#
# index = 10
# truth = spectrum(truth).numpy()[:,index]
# pred0 = spectrum(pred0).numpy()[:,index]
# pred1 = spectrum(pred1).numpy()[:,index]
# pred2 = spectrum(pred2).numpy()[:,index]
# # pred3 = spectrum(pred3).numpy()[1:]
# # pred4 = spectrum(pred4).numpy()[1:]
# print(truth.shape)
# print(np.mean(truth))
#
# T_show = 100
# truth = np.correlate(truth, truth, mode='full')[T-1:T-1+T_show]
# pred0 = np.correlate(pred0, pred0, mode='full')[T-1:T-1+T_show]
# pred1 = np.correlate(pred1, pred1, mode='full')[T-1:T-1+T_show]
# pred2 = np.correlate(pred2, pred2, mode='full')[T-1:T-1+T_show]
#
#
# print(truth.shape)
# print(pred0.shape)
# linewidth = 3
#
# fig, ax = plt.subplots(figsize=(10,10))
# ax.set_yscale('log')
# ax.plot(truth, 'k', label='truth', linewidth=linewidth)
# ax.plot(pred0, 'orangered', label='MNO', linewidth=linewidth)
# ax.plot(pred1, 'cornflowerblue', label='LSTM', linewidth=linewidth)
# ax.plot(pred2, 'mediumblue', label='GRU', linewidth=linewidth)
# # ax.plot(pred3, label='LSTM', linewidth=linewidth)
# # ax.plot(pred4, label='LSTM_TF', linewidth=linewidth)
#
# # for i in range(8):
# #     pred1 = spectrum(pred[i]).numpy()[1:kmax]
# #     print(pred1.shape)
# #     ax.plot(pred1, label=str(s_list[i]))
# plt.xlabel('time')
# plt.ylabel('auto-correlation')
#
# ax.set_yticks([1e6,1e7,1e8,1e9])
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.grid()
# plt.show()



##############################################################
#FFT MMD
##############################################################
#
# kmax = 30
# def spectrum(u):
#     u = torch.fft.rfft(u)
#     k = (u.shape[1]//2 + 1)
#     u = u[:, :kmax]
#     u = torch.sqrt(u.abs() ** 2)
#     u = u / k
#     return u
#
# truth = spectrum(truth)
# # pred1 = spectrum(pred1)
# # pred2 = spectrum(pred2)
# # pred3 = spectrum(pred3)
# # pred4 = spectrum(pred4)
#
#
# Ntest = T
# alpha_bases = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
# alphas = [1.0/(2.0 * a**2) for a in alpha_bases]
# # mmd_loss = MMDStatistic(Ntest, Ntest)
# #
# # mmd_err1 = mmd_loss(truth, pred1, alphas).item()
# # mmd_err2 = mmd_loss(truth, pred2, alphas).item()
# # mmd_err3 = mmd_loss(truth, pred3, alphas).item()
# # mmd_err4 = mmd_loss(truth, pred4, alphas).item()
# # print(mmd_err1)
# # print(mmd_err2)
# # print(mmd_err3)
# # print(mmd_err4)
#
# Ntest_truth = 950
# for i in range(7):
#     Ntest_pred = pred[i].shape[0]
#     pred1 = spectrum(pred[i])
#     mmd_loss = MMDStatistic(Ntest_truth, Ntest_pred)
#     mmd_err = mmd_loss(truth, pred1, alphas).item()
#     print(mmd_err)

##############################################################
# correlation in time
##############################################################
#
# # pointwise
# xtruth = truth[:,0]
# result_truth = np.correlate(xtruth, xtruth, mode='full')[T-1:T-1+200]
# xpred1 = pred1[:,0]
# result_pred1 = np.correlate(xpred1, xpred1, mode='full')[T-1:T-1+200]
# xpred2 = pred2[:,0]
# result_pred2 = np.correlate(xpred2, xpred2, mode='full')[T-1:T-1+200]
# xpred3 = pred3[:,0]
# result_pred3 = np.correlate(xpred3, xpred3, mode='full')[T-1:T-1+200]
#
# fig = plt.figure(1)
# ax = fig.add_subplot(1, 1, 1)
# # ax.set_yscale('log')
# ax.plot(result_truth, label='truth')
# ax.plot(result_pred1, label='one-step')
# ax.plot(result_pred2, label='RNN')
# ax.plot(result_pred3, label='LSTM')
# plt.title('autocorrelation, x=1')
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()
#
# t = 200
# result_truth = np.zeros(t, )
# result_pred1 = np.zeros(t, )
# result_pred2 = np.zeros(t, )
# result_pred3 = np.zeros(t, )
# for x in range(s):
#     xtruth = truth[:,x]
#     result_truth += np.correlate(xtruth, xtruth, mode='full')[T-1:T-1+t]
#     xpred1 = pred1[:,x]
#     result_pred1 += np.correlate(xpred1, xpred1, mode='full')[T-1:T-1+t]
#     xpred2 = pred2[:,x]
#     result_pred2 += np.correlate(xpred2, xpred2, mode='full')[T-1:T-1+t]
#     xpred3 = pred3[:,x]
#     result_pred3 += np.correlate(xpred3, xpred3, mode='full')[T-1:T-1+t]
#
# result_truth = result_truth / s
# result_pred1 = result_pred1 / s
# result_pred2 = result_pred2 / s
# result_pred3 = result_pred3 / s
#
# fig = plt.figure(1)
# ax = fig.add_subplot(1, 1, 1)
# # ax.set_yscale('log')
# ax.plot(result_truth, label='truth')
# ax.plot(result_pred1, label='one-step')
# ax.plot(result_pred2, label='RNN')
# ax.plot(result_pred3, label='LSTM')
# plt.title('autocorrelation, averaged')
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()

########################################
# spatial
########################################

# xtruth = truth[-1, :]
# result_truth = np.correlate(xtruth, xtruth, mode='full')[s - 1:] + np.correlate(xtruth, xtruth, mode='full')[:s]
# xpred0 = pred0[-1, :]
# result_pred0 = np.correlate(xpred0, xpred0, mode='full')[s - 1:] + np.correlate(xpred0, xpred0, mode='full')[:s]
# xpred1 = pred1[-1, :]
# result_pred1 = np.correlate(xpred1, xpred1, mode='full')[s - 1:] + np.correlate(xpred1, xpred1, mode='full')[:s]
# xpred2 = pred2[-1, :]
# result_pred2 = np.correlate(xpred2, xpred2, mode='full')[s - 1:] + np.correlate(xpred2, xpred2, mode='full')[:s]
#
# linewidth = 3
# fig, ax = plt.subplots(figsize=(10,10))
# ax.plot(result_truth, 'k', label='truth', linewidth=linewidth)
# ax.plot(result_pred0, 'orangered', label='MNO', linewidth=linewidth)
# ax.plot(result_pred1, 'cornflowerblue', label='LSTM', linewidth=linewidth)
# ax.plot(result_pred2, 'mediumblue', label='GRU', linewidth=linewidth)
#
# plt.xlabel('x')
# plt.ylabel('correlation')
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.grid()
# plt.show()
#



# result_truth = np.zeros(s, )
# result_pred1 = np.zeros(s, )
# result_pred2 = np.zeros(s, )
# result_pred0 = np.zeros(s, )
# T = 400
# for t in range(T):
#     xtruth = truth[t, :]
#     result_truth += np.correlate(xtruth, xtruth, mode='full')[s - 1:] + np.correlate(xtruth, xtruth, mode='full')[:s]
#     xpred1 = pred1[t, :]
#     result_pred1 += np.correlate(xpred1, xpred1, mode='full')[s - 1:] + np.correlate(xpred1, xpred1, mode='full')[:s]
#     xpred2 = pred2[t, :]
#     result_pred2 += np.correlate(xpred2, xpred2, mode='full')[s - 1:] + np.correlate(xpred2, xpred2, mode='full')[:s]
#     xpred0 = pred0[t, :]
#     result_pred0 += np.correlate(xpred0, xpred0, mode='full')[s - 1:] + np.correlate(xpred0, xpred0, mode='full')[:s]
#
#
# result_truth = result_truth / T
# result_pred1 = result_pred1 / T
# result_pred2 = result_pred2 / T
# result_pred0 = result_pred0 / T
#
# linewidth = 3
# fig, ax = plt.subplots(figsize=(10,10))
#
# ax.plot(result_pred1, '--', color='cornflowerblue', label='LSTM', linewidth=linewidth)
# ax.plot(result_pred2, '--', color='mediumblue', label='GRU', linewidth=linewidth)
#
# ax.plot(result_pred0, '--', color='orangered', label='MNO', linewidth=linewidth)
# ax.plot(result_truth, 'k', label='truth', linewidth=linewidth)
# plt.xlabel('x')
# plt.ylabel('correlation')
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.grid()
# plt.show()

##############################################################
# ### Hist
##############################################################

# truth = truth.numpy()
# pred0 = pred0.numpy()
# pred1 = pred1.numpy()
# pred2 = pred2.numpy()
# # pred3 = pred3.numpy()
# # pred4 = pred4.numpy()
#
# # plt.hist(truth.reshape(-1), bins=100)
# # plt.show()
# # plt.hist(pred0.reshape(-1), bins=100)
# # plt.show()
#
#
# linewidth = 5
# fig, ax = plt.subplots(figsize=(10,10))
# sns.distplot(ax=ax, a=pred1.reshape(-1), kde=True,hist=False, bins=200, label='LSTM', kde_kws=dict(linewidth=linewidth,linestyle='--', color='cornflowerblue'))
# sns.distplot(ax=ax, a=pred2.reshape(-1), kde=True,hist=False, bins=200, label='GRU', kde_kws=dict(linewidth=linewidth,linestyle='--', color='mediumblue'))
# sns.distplot(ax=ax, a=truth.reshape(-1), kde=True,hist=False, bins=200, label='truth', kde_kws=dict(linewidth=linewidth, color='k'))
# sns.distplot(ax=ax, a=pred0.reshape(-1), kde=True,hist=False, bins=200, label='MNO', kde_kws=dict(linewidth=linewidth,linestyle='--', color='orangered'))
#
# # ax.set_xlim(-6,6)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
#
# plt.legend(prop={'size': 20})
# plt.title('pixelwise distribution')
# plt.xlabel('velocity u(x)')
# plt.ylabel('density')
# plt.show()


##############################################################
#  Kinetic energy
##############################################################

# ##KE = < u^2 >
# ##TKE = 1/T * \int_T  (u - mean(u))^2 dt
#
# def KE(u):
#     T = u.shape[0]
#     u = u.reshape(T, -1)
#     return torch.mean(u**2, dim=1)
#
# # def TKE(u):
# #     T = u.shape[0]
# #     u = u.reshape(T, 64*64*2)
# #     umean = torch.mean(u, dim=0)
# #     return torch.mean((u-umean)**2, dim=1)
#
#
# Etruth = KE(truth)
# Epred0 = KE(pred0)
# Epred1 = KE(pred1)
# Epred2 = KE(pred2)
# # Epred3 = KE(pred3)
# # Epred4 = KE(pred4)
# # Epred5 = KE(pred5)
#
# # print(torch.mean(Epred3))
# # print(torch.mean(Epred4))
# # print(torch.mean(Epred5))



# T = np.linspace(0,1000,1000)[1:401]
#
# fig = plt.figure(figsize=(20,10))
# gs = fig.add_gridspec(nrows=3, ncols=1, hspace=0)
# ax0 = fig.add_subplot(gs[0,0])
# ax1 = fig.add_subplot(gs[1,0])
# ax2 = fig.add_subplot(gs[2,0])
#
# ax0.plot(T, Etruth[:400], 'black', label='Truth')
# # ax0.plot(T, Etruth[400:800], 'black', label='')
# # ax0.plot(T, Etruth[800:1200], 'black', label='')
# ax1.plot(T, Epred2[:400], 'firebrick', label='MNO')
# ax2.plot(T, Epred5[:400], 'royalblue', label='UNet')
#
# leg = ax0.legend(loc='upper right')
# leg.get_frame().set_alpha(0.5)
# leg = ax1.legend(loc='upper right')
# leg.get_frame().set_alpha(0.5)
# leg = ax2.legend(loc='upper right')
# leg.get_frame().set_alpha(0.5)
#
# # Hide x labels and tick labels for all but bottom plot.
#
# ax2.label_outer()
# # fig.suptitle('Kinetic energy')
# ax2.set_xlabel('t (s)')
# ax1.set_ylabel('Kinetic energy')
#
# fig.show()

#
# # histogram
#
# fig, ax = plt.subplots(figsize=(10,10))
# linewidth = 5
# sns.distplot(ax=ax, a=Epred1.reshape(-1), kde=True,hist=False, bins=200, label='LSTM', kde_kws=dict(linewidth=linewidth,linestyle='--', color='cornflowerblue'))
# sns.distplot(ax=ax, a=Epred2.reshape(-1), kde=True,hist=False, bins=200, label='GRU', kde_kws=dict(linewidth=linewidth,linestyle='--', color='mediumblue'))
# sns.distplot(ax=ax, a=Etruth.reshape(-1), kde=True,hist=False, bins=200, label='truth', kde_kws=dict(linewidth=linewidth, color='k'))
# sns.distplot(ax=ax, a=Epred0.reshape(-1), kde=True,hist=False, bins=200, label='MNO', kde_kws=dict(linewidth=linewidth,linestyle='--', color='orangered'))
#
# # ax.set_xlim(0.55,0.8)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
#
# plt.legend(prop={'size': 20})
# plt.xlabel('Kinetic energy')
# plt.ylabel('density')
# plt.show()


##############################################################
#  PCA projection: temporal auto-correlation
##############################################################


# class PCA(object):
#     def __init__(self, x, dim, subtract_mean=True):
#         super(PCA, self).__init__()
#
#         # Input size
#         x_size = list(x.size())
#
#         # Input data is a matrix
#         assert len(x_size) == 2
#
#         # Reducing dimension is less than the minimum of the
#         # number of observations and the feature dimension
#         assert dim <= min(x_size)
#
#         self.reduced_dim = dim
#
#         if subtract_mean:
#             self.x_mean = torch.mean(x, dim=0).view(1, -1)
#         else:
#             self.x_mean = torch.zeros((x_size[1],), dtype=x.dtype, layout=x.layout, device=x.device)
#
#         # SVD
#         U, S, V = torch.svd(x - self.x_mean)
#         V = V.t()
#
#         # Flip sign to ensure deterministic output
#         max_abs_cols = torch.argmax(torch.abs(U), dim=0)
#         signs = torch.sign(U[max_abs_cols, range(U.size()[1])]).view(-1, 1)
#         V *= signs
#
#         self.W = V.t()[:, 0:self.reduced_dim]
#         self.sing_vals = S.view(-1, )
#
#     def cuda(self):
#         self.W = self.W.cuda()
#         self.x_mean = self.x_mean.cuda()
#         self.sing_vals = self.sing_vals.cuda()
#
#     def encode(self, x):
#         return (x - self.x_mean).mm(self.W)
#
#     def decode(self, x):
#         return x.mm(self.W.t()) + self.x_mean
#
#     def forward(self, x):
#         return self.decode(self.encode(x))
#
#     def __call__(self, x):
#         return self.forward(x)
#
#
# x_pca = PCA(truth, 50, subtract_mean=False)
#
# index = 0
# pred0 = x_pca.encode(pred0)[:,index]
# pred1 = x_pca.encode(pred1)[:,index]
# pred2 = x_pca.encode(pred2)[:,index]
#
# pred0 = np.correlate(pred0, pred0, mode='full')[800-1:] /(800-np.array(range(800)))
# pred1 = np.correlate(pred1, pred1, mode='full')[800-1:] /(800-np.array(range(800)))
# pred2 = np.correlate(pred2, pred2, mode='full')[800-1:] /(800-np.array(range(800)))
#
# print(pred0.shape)
#
# linewidth = 3
# fig, ax = plt.subplots(figsize=(10,10))
# # ax.plot(result_truth, 'k', label='truth', linewidth=linewidth)
# ax.plot(pred0[:500], 'orangered', label='MNO', linewidth=linewidth)
# ax.plot(pred1[:500], 'cornflowerblue', label='LSTM', linewidth=linewidth)
# ax.plot(pred2[:500], 'mediumblue', label='GRU', linewidth=linewidth)
#
# plt.xlabel('time')
# plt.ylabel('auto-correlation')
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.grid()
# plt.show()
