

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Default to a large (28 pt) font for all matplotlib text drawn by this script.
font = dict(size=28)
matplotlib.rc('font', **font)

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
from utilities4 import *

# Fix the NumPy and PyTorch RNG seeds so any randomness in this analysis is
# reproducible from run to run.
np.random.seed(0)
torch.manual_seed(0)

# T = 400
# Number of grid points per spatial dimension of the 2-D fields analysed here.
s = 64

# dataloader = MatReader('data/ns_data_V100_N10_T1000.mat')
# truth = dataloader.read_field('u')[0,:,:,-T:].reshape(s*s,T).transpose(0,1)
#
# dataloader = MatReader('pred/semigroup_ns_fourier_one_noBN_V100_T200_N100_ep0_m12_w32.mat')
# pred1 = dataloader.read_field('pred')[:,:,:].reshape(s*s,T).transpose(0,1)
# dataloader = MatReader('pred/semigroup_ns_unet_V100_T200_N100_ep0_m12_w32.mat')
# pred2 = dataloader.read_field('pred')[:,:,:].reshape(s*s,T).transpose(0,1)

# dataloader = MatReader('data/ns_data_V1000_N10_T1000.mat')
# truth = dataloader.read_field('u')[0,:,:,-T:].reshape(s*s,T).transpose(0,1)
#
# dataloader = MatReader('pred/semigroup_ns_fourier_one_noBN_V1000_T400_N100_ep0_m12_w32.mat')
# pred1 = dataloader.read_field('pred')[:,:,:].reshape(s*s,T).transpose(0,1)
# dataloader = MatReader('pred/semigroup_ns_unet_V1000_T400_N100_ep0_m12_w32.mat')
# pred2 = dataloader.read_field('pred')[:,:,:].reshape(s*s,T).transpose(0,1)

# Reynolds number of the ground-truth dataset; selects the .npy file below.
Re = 40
# Component index; not used by the active code in this file, only by the
# commented-out `[..., index]` velocity analyses.
index = 1
# Number of leading entries kept along axis 2 of each prediction file
# (presumably time steps).
T = 10000

# Ground truth. The indexing requires the loaded array to be 4-D. We keep
# slice 100:500 of axis 1 and every 4th point of axes 2 and 3. The leading
# axes are then flattened so that each row is one s x s snapshot. The axes
# are presumably (trajectory, time, x, y) -- TODO confirm against the data
# generator. The reshape uses `s` rather than a literal 64, so the truth
# layout stays in sync with the grid size used everywhere else.
data = np.load(f'data/KFvorticity_Re{Re}_N25_part1.npy')
data = data[:, 100:500, ::4, ::4].reshape(-1, s, s)
truth = torch.tensor(data, dtype=torch.float)
# truth = w_to_u(truth)[..., index]

############################################################################
# RE40
# Load the six Re=40 model rollouts, in this order:
#   pred0-pred2: Fourier-model runs with file suffixes k0, k1, k2
#   pred3-pred5: U-Net runs with file suffixes k0, k1, k2
# For each file, the 'pred' field is cut to its first T entries along axis 2
# and to channel 0 of the last axis. It is then permuted so that axis 2
# (presumably time) comes first.
_re40_pred_files = (
    'pred/KF_w_fourier40_T10000_k0_gTrue_ep200_m20_w64.mat',
    'pred/KF_w_fourier40_T10000_k1_gTrue_ep200_m20_w64.mat',
    'pred/KF_w_fourier40_T10000_k2_gTrue_ep200_m20_w64.mat',
    'pred/KF_w_unet40_T10000_k0_gTrue_ep200_m20_w64.mat',
    'pred/KF_w_unet40_T10000_k1_gTrue_ep200_m20_w64.mat',
    'pred/KF_w_unet40_T10000_k2_gTrue_ep200_m20_w64.mat',
)
_re40_preds = []
for _path in _re40_pred_files:
    dataloader = MatReader(_path)
    _re40_preds.append(dataloader.read_field('pred')[:, :, :T, 0].permute(2, 0, 1))
pred0, pred1, pred2, pred3, pred4, pred5 = _re40_preds

############################################################################
# RE500

# dataloader = MatReader('pred/KF_fourier500_N900_k0_gTrue_ep200_m20_w64.mat')
# pred0 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_fourier500_N900_k1_gTrue_ep200_m20_w64.mat')
# pred1 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_fourier500_N900_k2_gTrue_ep200_m20_w64.mat')
# pred2 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)


# dataloader = MatReader('pred/KF_w_fourier500_N900_k0_gTrue_ep50_m20_w64.mat')
# pred0 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_w_fourier500_N900_k1_gTrue_ep50_m20_w64.mat')
# pred1 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_w_fourier500_N900_k2_gTrue_ep50_m20_w64.mat')
# pred2 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)

# dataloader = MatReader('pred/KF_f_fourier500_N900_k0_gTrue_ep50_m20_w64.mat')
# pred0 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_f_fourier500_N900_k1_gTrue_ep50_m20_w64.mat')
# pred1 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_f_fourier500_N900_k2_gTrue_ep50_m20_w64.mat')
# pred2 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)

# dataloader = MatReader('pred/KF_u_fourier500_N900_k0_gTrue_ep50_m20_w64.mat')
# pred0 = dataloader.read_field('pred')[:,:,:T, index].permute(2,0,1)
# dataloader = MatReader('pred/KF_u_fourier500_N900_k1_gTrue_ep50_m20_w64.mat')
# pred1 = dataloader.read_field('pred')[:,:,:T, index].permute(2,0,1)
# dataloader = MatReader('pred/KF_u_fourier500_N900_k2_gTrue_ep50_m20_w64.mat')
# pred2 = dataloader.read_field('pred')[:,:,:T, index].permute(2,0,1)

# dataloader = MatReader('pred/KF_unet500_N900_k0_gTrue_ep200.mat')
# pred0 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_unet500_N900_k1_gTrue_ep200.mat')
# pred1 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_unet500_N900_k2_gTrue_ep200.mat')
# pred2 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)

# dataloader = MatReader('pred/KF_unet500_N900_k0_gTrue_ep50.mat')
# pred3 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_unet500_N900_k1_gTrue_ep50.mat')
# pred4 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)
# dataloader = MatReader('pred/KF_unet500_N900_k2_gTrue_ep50.mat')
# pred5 = dataloader.read_field('pred')[:,:,:T].permute(2,0,1)

############################################################################
# data preparation

# index = [0,1]
# # index = 0
# truth = w_to_u(truth)[..., index]
# pred0 = w_to_u(pred0)[..., index]
# pred1 = w_to_u(pred1)[..., index]
# pred2 = w_to_u(pred2)[..., index]
# pred3 = w_to_u(pred3)[..., index]
# pred4 = w_to_u(pred4)[..., index]
# pred5 = w_to_u(pred5)[..., index]

# Sanity check: report the dimensions of the loaded ground-truth tensor.
print(truth.size())

# dataloader = MatReader('pred/KF_vol_Re'+str(Re)+'_k0_a1.01.0_N40_ep100_m20_w32.mat')
# pred0 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)
# dataloader = MatReader('pred/KF_vol_Re'+str(Re)+'_k1_a1.01.0_N40_ep100_m20_w32.mat')
# pred1 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)
# dataloader = MatReader('pred/KF_vol_Re'+str(Re)+'_k2_a1.01.0_N40_ep100_m20_w32.mat')
# pred2 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)
# dataloader = MatReader('pred/KF_vol_Re'+str(Re)+'_a2.04.0_N40_ep100_m20_w32.mat')
# pred3 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)
# dataloader = MatReader('pred/KF_vol_Re'+str(Re)+'_a4.016.0_N40_ep100_m20_w32.mat')
# pred4 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)
#
# dataloader = MatReader('pred/KF_vol_Re500_k2_gTrue_a2.04.0_N20_ep100_m20_w32.mat')
# pred2a1 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)
# dataloader = MatReader('pred/KF_vol_Re500_k2_gTrue_a4.016.0_N20_ep100_m20_w32.mat')
# pred2a2 = dataloader.read_field('pred')[:,:,:T,index].permute(2,0,1)

# Print the global mean of the ground truth, then of pred0..pred5, one per line.
for _field in (truth, pred0, pred1, pred2, pred3, pred4, pred5):
    print(torch.mean(_field))


##############################################################
#Set-up MMD
##############################################################

# Ntest = T
# alpha_bases = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
# alphas = [1.0/(2.0 * a**2) for a in alpha_bases]
# mmd_loss = MMDStatistic(Ntest, Ntest)
#
# mmd_err1 = mmd_loss(truth, pred1, alphas).item()
# mmd_err2 = mmd_loss(truth, pred2, alphas).item()
# print(mmd_err1)
# print(mmd_err2)

##############################################################
# ### Hist
##############################################################

# truth = truth.numpy()
# pred0 = pred0.numpy()
# pred1 = pred1.numpy()
# pred2 = pred2.numpy()
# pred3 = pred3.numpy()
# pred4 = pred4.numpy()
# pred5 = pred5.numpy()
#
# # plt.hist(truth.reshape(-1), bins=100)
# # plt.show()
# # plt.hist(pred0.reshape(-1), bins=100)
# # plt.show()
#
# fig, ax = plt.subplots(figsize=(10,10))
# linewidth = 5
# sns.distplot(ax=ax, a=pred5.reshape(-1), kde=True, hist=False, bins=50, label='UNet', kde_kws=dict(linewidth=linewidth,linestyle='--'))
# sns.distplot(ax=ax, a=pred2.reshape(-1), kde=True, hist=False, bins=50, label='MNO', kde_kws=dict(linewidth=linewidth,linestyle='--'))
# sns.distplot(ax=ax, a=truth.reshape(-1), kde=True, hist=False, bins=50, label='truth', kde_kws=dict(linewidth=linewidth))
# # sns.distplot(pred3.reshape(-1),  kde=True, label='h2 1/2')
# # sns.distplot(pred4.reshape(-1),  kde=True, label='h2 1/4')
#
# ax.set_xlim(-2,2)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
#
# plt.legend(prop={'size': 20})
# plt.title('pixelwise distribution')
# plt.xlabel('velocity u(x)')
# plt.ylabel('density')
# plt.show()
# ##############################################################
### FFT plot
##############################################################

# def spectrum1(u):
#     u = u.reshape(T,s,s)
#     u = torch.rfft(u, 2, normalized=False, onesided=False)
#     u = u[:, :10, :10, :]
#     u = torch.sqrt( u[:, :, :, 0]**2 + u[:, :, :, 1]**2 )
#     u = u.reshape(T, -1)
#
#     u = u.mean(dim=0)
#     return u

# def spectrum2(u):
#     u = u.reshape(T, s, s)
#     # u = torch.rfft(u, 2, normalized=False, onesided=False)
#     u = torch.fft.fft2(u)
#     # ur = u[..., 0]
#     # uc = u[..., 1]
#
#
#     # 2d wavenumbers following Pytorch fft convention
#     k_max = s // 2
#     wavenumers = torch.cat((torch.arange(start=0, end=k_max, step=1), \
#                             torch.arange(start=-k_max, end=0, step=1)), 0).repeat(s, 1)
#     k_x = wavenumers.transpose(0, 1)
#     k_y = wavenumers
#     # Sum wavenumbers
#     sum_k = torch.abs(k_x) + torch.abs(k_y)
#     sum_k = sum_k.numpy()
#     # Remove symmetric components from wavenumbers
#     index = -1.0 * np.ones((s, s))
#     index[0:k_max + 1, 0:k_max + 1] = sum_k[0:k_max + 1, 0:k_max + 1]
#
#
#
#     spectrum = np.zeros((T, s))
#     for j in range(1, s + 1):
#         ind = np.where(index == j)
#         # spectrum[:, j - 1] = np.sqrt((ur[:, ind[0], ind[1]].sum(axis=1)) ** 2
#                                      # + (uc[:, ind[0], ind[1]].sum(axis=1)) ** 2)
#         spectrum[:, j - 1] = np.sqrt( (u[:, ind[0], ind[1]].sum(axis=1)).abs() ** 2)
#
#
#     spectrum = spectrum.mean(axis=0)
#     return spectrum
#
# struth = spectrum2(truth)[:s//2]
# spred0 = spectrum2(pred0)[:s//2]
# spred1 = spectrum2(pred1)[:s//2]
# spred2 = spectrum2(pred2)[:s//2]
# spred3 = spectrum2(pred3)[:s//2]
# spred4 = spectrum2(pred4)[:s//2]
# spred5 = spectrum2(pred5)[:s//2]
# # spred2a2 = spectrum2(pred2a2)[:s//2]
#
# print(struth.shape)
# fig, ax = plt.subplots(figsize=(10,10))
#
# linewidth = 3
# ax.set_yscale('log')
# ax.plot(struth, 'k', label='truth', linewidth=linewidth)
# ax.plot(spred0, 'r--', label='MNO H0', linewidth=linewidth)
# # ax.plot(spred1, 'orangered', label='MNO H1', linewidth=linewidth)
# ax.plot(spred2, 'r', label='MNO H2', linewidth=linewidth)
# ax.plot(spred3, 'b--', label='Unet H0', linewidth=linewidth)
# # ax.plot(spred4, 'deepskyblue', label='Unet H1', linewidth=linewidth)
# ax.plot(spred5, 'b',label='Unet H2', linewidth=linewidth)
# # ax.plot(spred2a2, label='h2 1-4-16')
#
# # ax.set_xlim(-6,6)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
# #
# plt.legend(prop={'size': 20})
# plt.title('spectrum of vorticity')
# plt.xlabel('wavenumber')
# plt.ylabel('energy')
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()
#

# ##############################################################
### FFT temporal
##############################################################

# def spectrum2(u):
#     T = u.shape[0]
#     u = u.reshape(T, s, s)
#     # u = torch.rfft(u, 2, normalized=False, onesided=False)
#     u = torch.fft.fft2(u)
#     # ur = u[..., 0]
#     # uc = u[..., 1]
#
#
#     # 2d wavenumbers following Pytorch fft convention
#     k_max = s // 2
#     wavenumers = torch.cat((torch.arange(start=0, end=k_max, step=1), \
#                             torch.arange(start=-k_max, end=0, step=1)), 0).repeat(s, 1)
#     k_x = wavenumers.transpose(0, 1)
#     k_y = wavenumers
#     # Sum wavenumbers
#     sum_k = torch.abs(k_x) + torch.abs(k_y)
#     sum_k = sum_k.numpy()
#     # Remove symmetric components from wavenumbers
#     index = -1.0 * np.ones((s, s))
#     index[0:k_max + 1, 0:k_max + 1] = sum_k[0:k_max + 1, 0:k_max + 1]
#
#
#
#     spectrum = np.zeros((T, s))
#     for j in range(1, s + 1):
#         ind = np.where(index == j)
#         # spectrum[:, j - 1] = np.sqrt((ur[:, ind[0], ind[1]].sum(axis=1)) ** 2
#                                      # + (uc[:, ind[0], ind[1]].sum(axis=1)) ** 2)
#         spectrum[:, j - 1] = np.sqrt( (u[:, ind[0], ind[1]].sum(axis=1)).abs() ** 2)
#
#
#     # spectrum = spectrum.mean(axis=0)
#     return spectrum
#
# index = 10
# # T = 2000
# struth = spectrum2(truth)[:T, index]
# spred0 = spectrum2(pred0)[:T, index]
# spred1 = spectrum2(pred1)[:T, index]
# spred2 = spectrum2(pred2)[:T, index]
# spred3 = spectrum2(pred3)[:T, index]
# spred4 = spectrum2(pred4)[:T, index]
# spred5 = spectrum2(pred5)[:T, index]
# # spred2a2 = spectrum2(pred2a2)[:s//2]
# print(struth.shape)
# print(spred0.shape)
#
# T_show = 1000
# struth = np.correlate(struth, struth, mode='full')[400-1:400-1+T_show]
# spred0 = np.correlate(spred0, spred0, mode='full')[T-1:T-1+T_show]
# spred1 = np.correlate(spred1, spred1, mode='full')[T-1:T-1+T_show]
# spred2 = np.correlate(spred2, spred2, mode='full')[T-1:T-1+T_show]
# spred3 = np.correlate(spred3, spred3, mode='full')[T-1:T-1+T_show]
# spred4 = np.correlate(spred4, spred4, mode='full')[T-1:T-1+T_show]
# spred5 = np.correlate(spred5, spred5, mode='full')[T-1:T-1+T_show]
#
# print(struth.shape)
# fig, ax = plt.subplots(figsize=(10,10))
#
# num = 10000 -  np.array(range(T_show))
# # struth = struth / (400 - np.array(range(400)))
#
#
# spred0 = spred0 / num
# spred1 = spred1 / num
# spred2 = spred2 / num
# spred3 = spred3 / num
# spred4 = spred4 / num
# spred5 = spred5 / num
#
#
# linewidth = 3
# ax.set_yscale('log')
# # ax.plot(struth, 'k', label='truth', linewidth=linewidth)
# ax.plot(spred0, 'r--', label='MNO H0', linewidth=linewidth)
# # ax.plot(spred1, 'orangered', label='MNO H1', linewidth=linewidth)
# ax.plot(spred2, 'r', label='MNO H2', linewidth=linewidth)
# ax.plot(spred3, 'b--', label='Unet H0', linewidth=linewidth)
# # ax.plot(spred4, 'deepskyblue', label='Unet H1', linewidth=linewidth)
# ax.plot(spred5, 'b',label='Unet H2', linewidth=linewidth)
# # ax.plot(spred2a2, label='h2 1-4-16')
#
# # ax.set_xlim(-6,6)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
# ax.set_yticks([1e4, 1e5, 1e6,1e7])
# plt.legend(prop={'size': 20})
# # plt.title('spectrum of vorticity')
# plt.xlabel('time')
# plt.ylabel('auto-correlation')
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()

##############################################################
# PCA auto-correlation
##############################################################

# x_pca = PCA(truth.reshape(-1, 64*64), 50, subtract_mean=False)
#
# index = 0
# # T = 2000
# struth = x_pca.encode(truth.reshape(-1, 64*64))[:T, index]
# spred0 = x_pca.encode(pred0.reshape(-1, 64*64))[:T, index]
# spred1 = x_pca.encode(pred1.reshape(-1, 64*64))[:T, index]
# spred2 = x_pca.encode(pred2.reshape(-1, 64*64))[:T, index]
# spred3 = x_pca.encode(pred3.reshape(-1, 64*64))[:T, index]
# spred4 = x_pca.encode(pred4.reshape(-1, 64*64))[:T, index]
# spred5 = x_pca.encode(pred5.reshape(-1, 64*64))[:T, index]
# # spred2a2 = spectrum2(pred2a2)[:s//2]
# print(struth.shape)
# print(spred0.shape)
#
# T_show = 5000
# struth = np.correlate(struth, struth, mode='full')[400-1:400-1+T_show]
# spred0 = np.correlate(spred0, spred0, mode='full')[T-1:T-1+T_show]
# spred1 = np.correlate(spred1, spred1, mode='full')[T-1:T-1+T_show]
# spred2 = np.correlate(spred2, spred2, mode='full')[T-1:T-1+T_show]
# spred3 = np.correlate(spred3, spred3, mode='full')[T-1:T-1+T_show]
# spred4 = np.correlate(spred4, spred4, mode='full')[T-1:T-1+T_show]
# spred5 = np.correlate(spred5, spred5, mode='full')[T-1:T-1+T_show]
#
# print(struth.shape)
# fig, ax = plt.subplots(figsize=(10,10))
#
# num = 10000 -  np.array(range(T_show))
# # struth = struth / (400 - np.array(range(400)))
#
#
# spred0 = spred0 / num
# spred1 = spred1 / num
# spred2 = spred2 / num
# spred3 = spred3 / num
# spred4 = spred4 / num
# spred5 = spred5 / num
#
#
# linewidth = 3
# # ax.set_yscale('log')
# # ax.plot(struth, 'k', label='truth', linewidth=linewidth)
# ax.plot(spred0, 'r--', label='MNO H0', linewidth=linewidth)
# # ax.plot(spred1, 'orangered', label='MNO H1', linewidth=linewidth)
# ax.plot(spred2, 'r', label='MNO H2', linewidth=linewidth)
# ax.plot(spred3, 'b--', label='Unet H0', linewidth=linewidth)
# # ax.plot(spred4, 'deepskyblue', label='Unet H1', linewidth=linewidth)
# ax.plot(spred5, 'b',label='Unet H2', linewidth=linewidth)
# # ax.plot(spred2a2, label='h2 1-4-16')
#
# # ax.set_xlim(-6,6)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
# # ax.set_yticks([1e2, 1e3, 1e4,1e5])
# plt.legend(prop={'size': 20})
# # plt.title('spectrum of vorticity')
# plt.xlabel('time')
# plt.ylabel('auto-correlation')
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()


##############################################################
#FFT MMD
##############################################################

# def spectrum1(u):
#     u = u.reshape(T,s,s)
#     u = torch.rfft(u, 2, normalized=False, onesided=False)
#     u = u[:, :10, :10, :]
#     u = torch.sqrt( u[:, :, :, 0]**2 + u[:, :, :, 1]**2 )
#     u = u.reshape(T, -1)
#     return u

# def spectrum2(u):
#     u = u.reshape(T, s, s)
#     u = torch.rfft(u, 2, normalized=False, onesided=False)
#     ur = u[..., 0]
#     uc = u[..., 1]
#
#
#     # 2d wavenumbers following Pytorch fft convention
#     k_max = s // 2
#     wavenumers = torch.cat((torch.arange(start=0, end=k_max, step=1), \
#                             torch.arange(start=-k_max, end=0, step=1)), 0).repeat(s, 1)
#     k_x = wavenumers.transpose(0, 1)
#     k_y = wavenumers
#     # Sum wavenumbers
#     sum_k = torch.abs(k_x) + torch.abs(k_y)
#     sum_k = sum_k.numpy()
#     # Remove symmetric components from wavenumbers
#     index = -1.0 * np.ones((s, s))
#     index[0:k_max + 1, 0:k_max + 1] = sum_k[0:k_max + 1, 0:k_max + 1]
#
#
#
#     spectrum = np.zeros((T, s))
#     for j in range(1, s + 1):
#         ind = np.where(index == j)
#         spectrum[:, j - 1] = np.sqrt((ur[:, ind[0], ind[1]].sum(axis=1)) ** 2
#                                      + (uc[:, ind[0], ind[1]].sum(axis=1)) ** 2)
#
#
#     # spectrum = spectrum.mean(axis=0)
#     return spectrum
#
# struth = torch.from_numpy(spectrum2(truth))
# spred1 = torch.from_numpy(spectrum2(pred1))
# spred2 = torch.from_numpy(spectrum2(pred2))
#
#
# Ntest = T
# alpha_bases = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
# alphas = [1.0/(2.0 * a**2) for a in alpha_bases]
# mmd_loss = MMDStatistic(Ntest, Ntest)
#
# mmd_err1 = mmd_loss(struth, spred1, alphas).item()
# mmd_err2 = mmd_loss(struth, spred2, alphas).item()
# print(mmd_err1)
# print(mmd_err2)

##############################################################
#average
##############################################################
# fig, axs = plt.subplots(2, 4, figsize=(20,8))
#
# truth = torch.mean(truth, dim=0)
# axs[0,0].imshow(truth)
# axs[0,0].set_title('truth')
# # plt.show()
# pred0 = torch.mean(pred0, dim=0)
# axs[0,1].imshow(pred0)
# axs[0,1].set_title('MNO H0')
# # plt.show()
# pred1 = torch.mean(pred1, dim=0)
# axs[0,2].imshow(pred1)
# axs[0,2].set_title('MNO H1')
# # plt.show()
# pred2 = torch.mean(pred2, dim=0)
# axs[0,3].imshow(pred2)
# axs[0,3].set_title('MNO H2')
# # plt.show()
#
# pred3 = torch.mean(pred0, dim=0)
# axs[1,1].imshow(pred0)
# axs[1,1].set_title('UNet H0')
# # plt.show()
# pred4 = torch.mean(pred1, dim=0)
# axs[1,2].imshow(pred1)
# axs[1,2].set_title('UNet H1')
# # plt.show()
# pred5 = torch.mean(pred2, dim=0)
# axs[1,3].imshow(pred2)
# axs[1,3].set_title('UNet H2')
#
# plt.show()

##############################################################
#  temporal auto-correlation
##############################################################


# #pointwise
# x1 = 0
# x2 = 0
# xtruth = truth[:,x1,x2]
# # xtruth = torch.mean(truth,dim=[1,2])
# xtruth = np.correlate(xtruth, xtruth, mode='full')
#
# xpred0 = pred0[:,x1,x2]
# # xpred0 = torch.mean(pred0,dim=[1,2])
# xpred0 = np.correlate(xpred0, xpred0, mode='full')
#
# xpred1 = pred1[:,x1,x2]
# # xpred1 = torch.mean(pred1,dim=[1,2])
# xpred1 = np.correlate(xpred1, xpred1, mode='full')
#
# xpred2 = pred2[:,x1,x2]
# # xpred2 = torch.mean(pred2,dim=[1,2])
# xpred2 = np.correlate(xpred2, xpred2, mode='full')
#
# # xpred3 = pred3[:,x1,x2]
# # # xpred3 = torch.mean(pred3,dim=[1,2])
# # xpred3 = np.correlate(xpred3, xpred3, mode='full')
# #
# # xpred4 = pred4[:,x1,x2]
# # # xpred4 = torch.mean(pred4,dim=[1,2])
# # xpred4 = np.correlate(xpred4, xpred4, mode='full')
#
# fig, ax = plt.subplots(figsize=(10,10))
# # ax.set_yscale('log')
# t = np.array(range(-400,400-1))
# ax.plot(t,xtruth, label='truth')
# ax.plot(t,xpred0, label='l2')
# ax.plot(t,xpred1, label='h1')
# ax.plot(t,xpred2, label='h2')
# # ax.plot(t,xpred3, label='h2, 1/2')
# # ax.plot(t,xpred4, label='h2, 1/4')
# plt.title('temporal autocorrelation at x=('+str(x1)+', '+str(x2)+')')
# plt.ylim([1.5*np.min(xtruth), 1.5*np.max(xtruth)])
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()
#
#
# # averaged
# t = 400
# result_truth = np.zeros(t, )
# result_pred0 = np.zeros(t, )
# result_pred1 = np.zeros(t, )
# result_pred2 = np.zeros(t, )
# result_pred3 = np.zeros(t, )
# result_pred4 = np.zeros(t, )
# for x1 in range(s):
#     for x2 in range(s):
#         xtruth = truth[:,x1,x2]
#         result_truth += np.correlate(xtruth, xtruth, mode='full')[T-1:T-1+t]
#         xpred0 = pred0[:,x1,x2]
#         result_pred0 += np.correlate(xpred0, xpred0, mode='full')[T-1:T-1+t]
#         xpred1 = pred1[:,x1,x2]
#         result_pred1 += np.correlate(xpred1, xpred1, mode='full')[T-1:T-1+t]
#         xpred2 = pred2[:,x1,x2]
#         result_pred2 += np.correlate(xpred2, xpred2, mode='full')[T-1:T-1+t]
#         # xpred3 = pred3[:,x1,x2]
#         # result_pred3 += np.correlate(xpred3, xpred3, mode='full')[T-1:T-1+t]
#         # xpred4 = pred4[:,x1,x2]
#         # result_pred4 += np.correlate(xpred4, xpred4, mode='full')[T-1:T-1+t]
#
# result_truth = result_truth / s**2
# result_pred0 = result_pred0 / s**2
# result_pred1 = result_pred1 / s**2
# result_pred2 = result_pred2 / s**2
# # result_pred3 = result_pred3 / s**2
# # result_pred4 = result_pred4 / s**2
#
#
# fig, ax = plt.subplots(figsize=(10,10))
# # ax.set_yscale('log')
# ax.plot(result_truth, label='truth')
# ax.plot(result_pred0, label='l1')
# ax.plot(result_pred1, label='h1')
# ax.plot(result_pred2, label='h2')
# # ax.plot(result_pred3, label='h2, 1/2')
# # ax.plot(result_pred4, label='h2, 1/4')
# plt.title('autocorrelation, averaged')
#
# leg = plt.legend(loc='best')
# leg.get_frame().set_alpha(0.5)
# plt.show()

##############################################################
#  spatial auto-correlation
##############################################################
# from scipy import signal
#
# result_truth = 0
# result_pred0 = 0
# result_pred1 = 0
# result_pred2 = 0
# # result_pred3 = 0
# # result_pred4 = 0
# t = T-1
# for t in range(T):
#     xtruth = truth[t]
#     result_truth += signal.correlate2d(xtruth, xtruth, boundary='wrap', mode='same')
#     xpred0 = pred0[t]
#     result_pred0 += signal.correlate2d(xpred0, xpred0, boundary='wrap', mode='same')
#     xpred1 = pred1[t]
#     result_pred1 += signal.correlate2d(xpred1, xpred1, boundary='wrap', mode='same')
#     xpred2 = pred2[t]
#     result_pred2 += signal.correlate2d(xpred2, xpred2, boundary='wrap', mode='same')
#     # xpred3 = pred3[t]
#     # result_pred3 += signal.correlate2d(xpred3, xpred3, boundary='wrap', mode='same')
#     # xpred4 = pred4[t]
#     # result_pred4 += signal.correlate2d(xpred4, xpred4, boundary='wrap', mode='same')
#
# xtruth = xtruth/T
# xpred0 = xpred0/T
# xpred1 = xpred1/T
# xpred2 = xpred2/T
#
# fig, axs = plt.subplots(1, 4, figsize=(20,10))
# axs[0].imshow(result_truth)
# axs[0].set_title('truth')
# axs[1].imshow(result_pred0)
# axs[1].set_title('l2')
# axs[2].imshow(result_pred1)
# axs[2].set_title('h1')
# axs[3].imshow(result_pred2)
# axs[3].set_title('h2')
# # axs[4].imshow(result_pred3)
# # axs[4].set_title('h2, 1/2')
# # axs[5].imshow(result_pred4)
# # axs[5].set_title('h2, 1/4')
#
# # fig.colorbar(a)
#
# # leg = plt.legend(loc='best')
# # leg.get_frame().set_alpha(0.5)
# plt.show()

##############################################################
#  Kinetic energy
##############################################################

##KE = < u^2 >
##TKE = 1/T * \int_T  (u - mean(u))^2 dt

# def KE(u):
#     T = u.shape[0]
#     u = u.reshape(T, 64*64*2)
#     return torch.mean(u**2, dim=1)
#
# def TKE(u):
#     T = u.shape[0]
#     u = u.reshape(T, 64*64*2)
#     umean = torch.mean(u, dim=0)
#     return torch.mean((u-umean)**2, dim=1)
#
#
# Etruth = KE(truth)
# Epred0 = KE(pred0)
# Epred1 = KE(pred1)
# Epred2 = KE(pred2)
# Epred3 = KE(pred3)
# Epred4 = KE(pred4)
# Epred5 = KE(pred5)
#
# print(Etruth.shape)
# print(torch.mean(Etruth))
# print(torch.mean(Epred0))
# print(torch.mean(Epred1))
# print(torch.mean(Epred2))
# print(torch.mean(Epred3))
# print(torch.mean(Epred4))
# print(torch.mean(Epred5))
#
#
# Etruth = TKE(truth)
# Epred0 = TKE(pred0)
# Epred1 = TKE(pred1)
# Epred2 = TKE(pred2)
# Epred3 = TKE(pred3)
# Epred4 = TKE(pred4)
# Epred5 = TKE(pred5)
# print(Etruth.shape)
# print(torch.mean(Etruth))
# print(torch.mean(Epred0))
# print(torch.mean(Epred1))
# print(torch.mean(Epred2))
# print(torch.mean(Epred3))
# print(torch.mean(Epred4))
# print(torch.mean(Epred5))
#
# # T = np.linspace(0,1000,1000)[1:401]
# #
# # fig = plt.figure(figsize=(20,10))
# # gs = fig.add_gridspec(nrows=3, ncols=1, hspace=0)
# # ax0 = fig.add_subplot(gs[0,0])
# # ax1 = fig.add_subplot(gs[1,0])
# # ax2 = fig.add_subplot(gs[2,0])
#
# # ax0.plot(T, Etruth[:400], 'black', label='Truth')
# # # ax0.plot(T, Etruth[400:800], 'black', label='')
# # # ax0.plot(T, Etruth[800:1200], 'black', label='')
# # ax1.plot(T, Epred2[:400], 'firebrick', label='MNO')
# # ax2.plot(T, Epred5[:400], 'royalblue', label='UNet')
#
# # leg = ax0.legend(loc='upper right')
# # leg.get_frame().set_alpha(0.5)
# # leg = ax1.legend(loc='upper right')
# # leg.get_frame().set_alpha(0.5)
# # leg = ax2.legend(loc='upper right')
# # leg.get_frame().set_alpha(0.5)
#
# # Hide x labels and tick labels for all but bottom plot.
#
# # ax2.label_outer()
# # # fig.suptitle('Kinetic energy')
# # ax2.set_xlabel('t (s)')
# # ax1.set_ylabel('Kinetic energy')
#
# # fig.show()
#
#
# ###histogram
#
# fig, ax = plt.subplots(figsize=(10,10))
# linewidth = 5
# sns.distplot(ax=ax, a=Epred5.reshape(-1), kde=True, hist=False, bins=50, label='UNet', kde_kws=dict(linewidth=linewidth,linestyle='--'))
# sns.distplot(ax=ax, a=Epred2.reshape(-1), kde=True, hist=False, bins=50, label='MNO', kde_kws=dict(linewidth=linewidth,linestyle='--'))
# sns.distplot(ax=ax, a=Etruth.reshape(-1), kde=True, hist=False, bins=50, label='truth', kde_kws=dict(linewidth=linewidth))
# # sns.distplot(pred3.reshape(-1),  kde=True, label='h2 1/2')
# # sns.distplot(pred4.reshape(-1),  kde=True, label='h2 1/4')
#
# # ax.set_xlim(0.55,0.8)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
#
# plt.legend(prop={'size': 20})
# plt.xlabel('Kinetic energy')
# plt.ylabel('density')
# plt.show()


##############################################################
#  dissipation
##############################################################

# ### epi = v * < (\grad{u})^2 > = v * < w^2 >
# #
# Re = 40
#
# def dissipation(w):
#     T = w.shape[0]
#     w = w.reshape(T, 64*64)
#     return torch.mean(w**2, dim=1) / Re
#
#
#
# Etruth = dissipation(truth)
# Epred0 = dissipation(pred0)
# Epred1 = dissipation(pred1)
# Epred2 = dissipation(pred2)
# Epred3 = dissipation(pred3)
# Epred4 = dissipation(pred4)
# Epred5 = dissipation(pred5)
#
#
# truth = truth.numpy()
# pred0 = pred0.numpy()
# pred1 = pred1.numpy()
# pred2 = pred2.numpy()
# # pred3 = pred3.numpy()
# # pred4 = pred4.numpy()
#
# # plt.hist(truth.reshape(-1), bins=100)
# # plt.show()
# # plt.hist(pred0.reshape(-1), bins=100)
# # plt.show()
#
# fig, ax = plt.subplots(figsize=(10,10))
# # sns.distplot(ax=ax, a=Etruth.reshape(-1), kde=True,hist=False, bins=200, label='truth')
# # sns.distplot(ax=ax, a=Epred0.reshape(-1), kde=True,hist=False, bins=200, label='H0')
# # sns.distplot(ax=ax, a=Epred1.reshape(-1), kde=True,hist=False, bins=200, label='H1')
# # sns.distplot(ax=ax, a=Epred2.reshape(-1), kde=True,hist=False, bins=200, label='H2')
#
# linewidth = 5
# sns.distplot(ax=ax, a=Epred5.reshape(-1), kde=True, hist=False, bins=50, label='UNet', kde_kws=dict(linewidth=linewidth,linestyle='--'))
# sns.distplot(ax=ax, a=Epred2.reshape(-1), kde=True, hist=False, bins=50, label='MNO', kde_kws=dict(linewidth=linewidth,linestyle='--'))
# sns.distplot(ax=ax, a=Etruth.reshape(-1), kde=True, hist=False, bins=50, label='truth', kde_kws=dict(linewidth=linewidth))
#
# # sns.distplot(pred3.reshape(-1),  kde=True, label='h2 1/2')
# # sns.distplot(pred4.reshape(-1),  kde=True, label='h2 1/4')
#
# ax.set_xlim(0.05,0.15)
# # ax.set_ylim(0.05,0.175)
# # ax.set_yticks([0.05,0.10,0.15])
#
# plt.legend(prop={'size': 20})
# # plt.title('distribution')
# plt.xlabel('dissipation')
# plt.ylabel('density')
# plt.show()
#
# # print(Etruth.shape)
# # print(torch.mean(Etruth))
# # print(torch.mean(Epred0))
# # print(torch.mean(Epred1))
# # print(torch.mean(Epred2))
# # print(torch.mean(Epred3))
# # print(torch.mean(Epred4))
# # print(torch.mean(Epred5))
#
# print(Etruth.shape)
# print(torch.mean(Etruth))
# print(torch.mean(Epred0))
# print(torch.mean(Epred1))
# print(torch.mean(Epred2))
# print(torch.mean(Epred3))
# print(torch.mean(Epred4))
# print(torch.mean(Epred5))
