import os
import warnings
import numpy as np
# warnings.filterwarnings('ignore', category=UserWarning, module='tensorflow')

from discriminative_metrics import discriminative_score_metrics
from predictive_metrics import predictive_score_metrics
from visualization_metrics import visualization

# Seed NumPy's global RNG so the shuffle of real_data below is reproducible.
np.random.seed(123)

# Dataset name -> (real data path, generated data path). Change DATASET to
# evaluate another dataset instead of commenting/uncommenting load lines.
DATASETS = {
    'driving': ('./real_data/driving_data.npy',
                './FigureB1/timeDiT/driving_cycle.npy'),
    'stocks': ('./real_data/repeat_stock_data.npy',
               './FigureB1/timeDiT/stocks.npy'),
    'weather': ('./real_data/weather_data.npy',
                './FigureB1/timeDiT/weather.npy'),
    'solar': ('./real_data/solar_data.npy',
              './FigureB1/timeDiT/solar.npy'),
    'mixed': ('./real_data/mixed_data.npy',
              './FigureB1/timeDiT/mix.npy'),
}
DATASET = 'stocks'
# Upper bound on the number of samples taken from each source.
MAX_SAMPLES = 2000

real_path, fake_path = DATASETS[DATASET]
real_data = np.load(real_path)
# Swap axes 1 and 2 of the generated samples so they line up with real_data.
# NOTE(review): an earlier comment labelled the stored generated array
# (B, L, C), which would make this result (B, C, L); confirm that the stored
# layout actually matches real_data after this transpose.
fake_data = np.transpose(np.load(fake_path), (0, 2, 1))

# Shuffle the real samples before truncating so the kept subset is a random
# draw rather than the first rows of the file.
np.random.shuffle(real_data)

# Truncate BOTH sources to the same count. Previously only real_data was cut
# to len(fake_data), so a real set smaller than the generated set left the
# two arrays with different sample counts.
n_samples = min(len(real_data), len(fake_data), MAX_SAMPLES)
real_data = real_data[:n_samples]
fake_data = fake_data[:n_samples]

# discriminative_score_list = []
# predictive_score_list = []

# for _ in range(10):
# 	temp_disc = discriminative_score_metrics(real_data, fake_data)
# 	discriminative_score_list.append(temp_disc)

# 	temp_pred = predictive_score_metrics(real_data, fake_data)
# 	predictive_score_list.append(temp_pred)

# discriminative_score = np.mean(discriminative_score_list)
# std1 = np.std(discriminative_score_list)
# predictive_score = np.mean(predictive_score_list)
# std2 = np.std(predictive_score_list)
# print("discriminative_score:", discriminative_score, std1)
# print('predictive_score:', predictive_score, std2)

# Run the project's visualization helper on the aligned real/generated
# samples, first with the 'pca' method and then with 'tsne'.
for _method in ('pca', 'tsne'):
    visualization(real_data, fake_data, _method)