import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, SubsetRandomSampler

from scipy.stats import entropy
import numpy as np

from classifier import Inception, InceptionBlock, Flatten, find_model
from inception_score import inception_score
from fid import fid_score
from utils import MyDataset, draw_figure, show_figure
import matplotlib.pyplot as plt



# ------------------------ load generated samples ------------------------
# Generated trajectories, laid out as (B, L, C) per the original note:
# batch, timesteps, channels. They are used in this layout as-is; a
# (0, 2, 1) transpose to (B, C, L) used to sit here and is disabled.
# The real reference dataset ('./argo/argo2_dataset.pth') is not loaded
# in this section.
FAKE_DATA_PATH = './argo/DiffTS.npy'
fake_data = np.load(FAKE_DATA_PATH)


# -------------- velocity-consistency check on generated samples --------------
# For every pair of consecutive timesteps, extrapolate the position with a
# constant-velocity step (x_{t+1} ~= x_t + DT * vx_t) and take the squared
# error against the generated next position. A low value means the generated
# velocity channels agree with the generated positions.
# Channel layout, as used by the indexing below: 0 = x, 1 = y, 3 = vx, 4 = vy
# (channel 2 is not used here).
DT = 0.1  # presumably the sampling interval in seconds (10 Hz) - TODO confirm

B, L, C = fake_data.shape  # also enforces a rank-3 (B, L, C) array
if C < 5:
    raise ValueError(f"expected at least 5 channels (x, y, _, vx, vy), got {C}")

# Vectorised over batch and time instead of a Python double loop.
pos = fake_data[:, :-1, 0:2]       # (B, L-1, 2) positions at t
vel = fake_data[:, :-1, 3:5]       # (B, L-1, 2) velocities at t
next_pos = fake_data[:, 1:, 0:2]   # (B, L-1, 2) positions at t+1
sq_err = (next_pos - (pos + DT * vel)) ** 2

# Per-axis MSE, then the mean of the two. Both axes contribute B*(L-1) terms.
# With L == 1 there are no pairs, so the result is nan (with a RuntimeWarning).
mse_x = sq_err[..., 0].mean()
mse_y = sq_err[..., 1].mean()
mse = np.mean([mse_x, mse_y])

print(mse)

# Deliberate stop: the classifier-based metrics further down are disabled.
# An explicit exit is used instead of `assert`, which is stripped under
# `python -O` and would let execution fall through into that section.
raise SystemExit("velocity-consistency check done; classifier metrics (IS/FID) are disabled")





# ------------------------ build real / fake datasets ------------------------
# NOTE(review): the classifier documents its input as (B, C, L), while
# fake_data is (B, L, C) per the loading comment and the indexing in the
# velocity check. A (0, 2, 1) transpose used to sit here and is disabled.
# Confirm whether MyDataset / inception_score / fid_score permute
# internally before relying on these metrics.

# The real reference dataset must exist before it can be subsampled. Its
# load is commented out at the top of the file, so this section would
# otherwise fail with NameError.
# It is a pickled Dataset object, so weights_only=False is required on
# torch >= 2.6. Only load trusted files this way (pickle can execute code).
real_dataset = torch.load('./argo/argo2_dataset.pth', weights_only=False)

n_fake = len(fake_data)
fake_dataset = MyDataset(fake_data, [0] * n_fake)  # dummy label 0 per generated sample
fake_data = torch.tensor(fake_data).to('cuda').float()

# random_split does not reject a negative second length here: the lengths
# still sum to len(real_dataset), and it would silently return a real subset
# smaller than the fake set.
if len(real_dataset) < n_fake:
    raise ValueError(
        f"real dataset has {len(real_dataset)} samples, fewer than the "
        f"{n_fake} generated samples it must be matched against"
    )

# Random real subset the same size as the generated set (unseeded, so it
# differs between runs).
real_dataset, _ = random_split(real_dataset, [n_fake, len(real_dataset) - n_fake])






# ------------------------ load classifier ------------------------
# Pretrained Inception-style 1-D classifier used by the IS / FID metrics.
# Per the original layout note it consumes tensors shaped (B, C, L) and
# produces 6 logits.
_N_FILTERS = 32
# Each InceptionBlock emits 4 * n_filters channels. This is why the second
# stage and the final Linear take 32 * 4 inputs.
_STAGE_WIDTH = 4 * _N_FILTERS

# Two stacked Inception stages sharing all hyper-parameters except their
# input width: 5 raw channels, then the previous stage's output.
_inception_stages = [
    InceptionBlock(
        in_channels=stage_in,
        n_filters=_N_FILTERS,
        kernel_sizes=[5, 11, 23],
        bottleneck_channels=32,
        use_residual=True,
        activation=nn.ReLU(),
    )
    for stage_in in (5, _STAGE_WIDTH)
]

# Sequential indices 0..4 must match the checkpoint's state_dict keys.
classifier = nn.Sequential(
    *_inception_stages,
    nn.AdaptiveAvgPool1d(output_size=1),  # global average over time
    Flatten(),
    nn.Linear(in_features=_STAGE_WIDTH * 1, out_features=6),
).cuda()

model_path = "../classifier/model/370.pt"
state_dict = find_model(model_path)
classifier.load_state_dict(state_dict)

# Inference mode (affects layers such as dropout / batch-norm, if present).
classifier.eval()

### ---------------------- calculate metrics ---------------------- ###
# Inception Score of the generated samples under the pretrained classifier.
# (An earlier variant unpacked mean/std/fidelity/diversity using batch_size=2.)
inception_result = inception_score(classifier, fake_data, batch_size=32)
print("IS:", inception_result)

# FID between the real subset and the generated set, averaged over runs.
# NOTE(review): real_dataset and fake_dataset are the same objects on every
# run, so the runs differ only if fid_score itself is stochastic. Confirm
# this; otherwise one run is enough.
# An earlier variant used 20 runs and also printed the standard deviation.
N_FID_RUNS = 5
fid_runs = [fid_score(classifier, real_dataset, fake_dataset) for _ in range(N_FID_RUNS)]
mean_fid = np.mean(fid_runs)
print("FID score:", mean_fid)
