import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, SubsetRandomSampler

from scipy.stats import entropy
import numpy as np

from classifier import Inception, InceptionBlock, Flatten, find_model
from inception_score import inception_score
from fid import fid_score
from utils import MyDataset, draw_figure, show_figure
import matplotlib.pyplot as plt



def _load_trusted_pickle(path):
    """Load an object saved with ``torch.save`` (tensors or arbitrary pickles).

    Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True`` and
    refuses to unpickle arbitrary Python objects (e.g. a ``Dataset`` instance).
    The exact type stored in the experiment files is not visible here, so full
    unpickling is requested explicitly.
    SECURITY: this runs pickle code -- only use it on trusted local files.

    Args:
        path: Path of the file written by ``torch.save``.

    Returns:
        Whatever object was stored in the file.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError as err:
        # torch < 1.13 has no ``weights_only`` keyword (and always unpickles
        # everything); re-raise any other TypeError untouched.
        if "weights_only" not in str(err):
            raise
        return torch.load(path)


# Real data: a sized, indexable collection whose items unpack as (x, _);
# it is iterated and passed to random_split further down.
real_dataset = _load_trusted_pickle('./e1/driving_data.pth')
# real_dataset = _load_trusted_pickle('./e1/stock_data.pth')
# real_dataset = _load_trusted_pickle('./e1/weather_data.pth')
# real_dataset = _load_trusted_pickle('./e1/solar_data.pth')
# real_dataset = _load_trusted_pickle('./e1/mixed_dataset.pth')



# for i in range (21):
#     fake_data = np.load(f'./e1/2to0/fake_data_tau{50*i}.npy') # (B,L,C)
#     fake_data = np.transpose(fake_data, (0, 2, 1))
#     example = fake_data[0]
#     plt.figure()
#     draw_figure(example)
#     plt.savefig(f'./figure4/2to0/tau{50*i}.pdf', format='pdf')
#     plt.close()
#     np.save(f'./figure4/2to0/tau{50*i}', example)
# assert 1==2

# Generated samples to evaluate.
# NOTE(review): annotated as (B,L,C), but the classifier below expects (B,C,L)
# with C == 1 -- confirm the on-disk layout and enable the transpose toggle at
# the end of this section if it really is (B,L,C).
fake_data = np.load('./FigureB1/timeDiT/driving_cycle.npy') # (B,L,C)

# fake_data = fake_data[:5000]

# print(fake_data.shape)
# for i in range(10):
#     example = np.transpose(real_dataset[i][0].cpu().numpy(), (1,0))
#     plt.figure()
#     draw_figure(example)
#     plt.savefig(f'./Figure_2a/real_data/real_data{i}.pdf', format='pdf')
#     plt.close()
#     np.save(f'./Figure_2a/real_data/real_data{i}', example)


# fake_data = np.transpose(fake_data, (0, 2, 1))
# for i in range(10):
#     example = fake_data[i]
#     plt.figure()
#     draw_figure(example)
#     plt.savefig(f'./Figure_2a/mask/mask{i}.pdf', format='pdf')
#     plt.close()
#     np.save(f'./Figure_2a/mask/mask{i}', example)

# assert 1==2
# print(fake_data.shape)
# fake_data = np.transpose(fake_data, (0, 2, 1))

# draw_figure(fake_data[0])
# show_figure()
# assert 1==2

# fake_data = np.transpose(fake_data, (0, 2, 1))
# ------------------------- build the evaluation sets -------------------------
# Wrap the generated samples (still a NumPy array here) with a constant dummy
# label 0 for every item.
fake_dataset = MyDataset(fake_data, [0] * len(fake_data))
fake_data = torch.tensor(fake_data)
fake_data = fake_data.to('cuda').float()

# Keep a handle on the full real dataset: both real subsets are drawn from it.
full_real_dataset = real_dataset

# Use as many real samples as there are generated ones. random_split does not
# reject a negative remainder length -- with more fake than real samples it
# silently returns fewer real samples than requested -- so clamp explicitly.
n_eval = min(len(fake_data), len(full_real_dataset))
if n_eval < len(fake_data):
    print(f"Warning: only {n_eval} real samples available for "
          f"{len(fake_data)} generated samples; using all real samples.")
real_dataset, _ = random_split(full_real_dataset, [n_eval, len(full_real_dataset) - n_eval])

# All selected real inputs stacked into one GPU tensor (not consumed by the
# metrics below in this file).
real_data = torch.stack([x for x, _ in real_dataset]).to('cuda')

# Second, independent draw of real samples for a real-vs-real ("identity")
# reference. It must come from the FULL dataset: splitting the already
# subsampled `real_dataset` only permutes the same samples (remainder length 0),
# so the baseline would compare the real set with itself.
real_dataset2, _ = random_split(full_real_dataset, [n_eval, len(full_real_dataset) - n_eval])
# ------------------------- create Identity sampler -------------------------------------------
# x_arrays = []
# for sample in real_dataset2:
#     x = sample[0]  # take x out of the sample
#     x_array = x.numpy()  # convert x to a NumPy array
#     x_arrays.append(x_array)  # append the x array to the list

# # Combine all the x arrays into one large array
# real_data = np.array(x_arrays)
# real_data = torch.tensor(real_data).to('cuda')

# ------------------------ load classifier --------------------------------
# Two stacked InceptionBlocks (the second consumes 4 * 32 channels), global
# average pooling over the last axis, flatten, then a linear head with 4 outputs.
# Layer order matters: nn.Sequential names its children "0".."4", and those
# names are the key prefixes the checkpoint's state dict must match.
def _inception_block(in_channels):
    """Return one InceptionBlock configured with the checkpoint's hyper-parameters."""
    return InceptionBlock(
        in_channels=in_channels,
        n_filters=32,
        kernel_sizes=[5, 11, 23],
        bottleneck_channels=32,
        use_residual=True,
        activation=nn.ReLU(),
    )


_classifier_layers = [
    _inception_block(1),                                   # input is (B, C, L), C == 1
    _inception_block(32 * 4),
    nn.AdaptiveAvgPool1d(output_size=1),
    Flatten(),
    nn.Linear(in_features=4 * 32 * 1, out_features=4),
]
classifier = nn.Sequential(*_classifier_layers).cuda()

model_path = "../classifier/model/e1_and_e2.pt"
state_dict = find_model(model_path)
classifier.load_state_dict(state_dict)

# Inference mode for the metric computations that follow.
classifier.eval()

### ---------------------- calculate metrics --------------------------------###
# Number of repeated FID evaluations that are averaged.
N_FID_RUNS = 20

# Both metrics only run the classifier forward; disable autograd so no graph is
# kept alive (harmless if the helpers already do this internally).
with torch.no_grad():
    # IS
    # scores_mean,scores_std,fidelity,diversity = inception_score(classifier, fake_data, batch_size=2)
    print("IS:", inception_score(classifier, fake_data, batch_size=32))

    # FID, repeated N_FID_RUNS times on the same real/fake sets.
    # NOTE(review): the spread is only non-zero if fid_score itself samples or
    # shuffles internally -- confirm.
    fid_list = [fid_score(classifier, real_dataset, fake_dataset) for _ in range(N_FID_RUNS)]

fid_score_ = np.mean(fid_list)
print("FID score:", fid_score_)
print("FID std:", np.std(fid_list))