import numpy as np
import torch
from torchvision import datasets

from empirical_tv_helper_function import *
from TV_estimation import *


# Real MNIST digits: pixel values are divided by 255 and every image is
# flattened to a row of 784 values. `.data` is the current name of the image
# tensor (the older `train_data` alias is deprecated and misleading on the
# test split); it pairs with the `.targets` attribute already used here.
mnist_dataset_train = datasets.MNIST(root='./data', train=True, download=True)
train_data_real = (mnist_dataset_train.data / 255.).view(-1, 784)
train_data_real_y = mnist_dataset_train.targets

mnist_dataset_test = datasets.MNIST(root='./data', train=False, download=True)
test_data_real = (mnist_dataset_test.data / 255.).view(-1, 784)
test_data_real_y = mnist_dataset_test.targets

def _load_gan_split(path):
    """Load one saved GAN sample file and return its two stored entries.

    The file is expected to hold a mapping with exactly two values; they are
    returned in the mapping's own order (presumably images first, labels
    second -- confirm against the script that wrote the files). A mapping
    with any other number of values raises ValueError on unpacking.

    NOTE: torch.load unpickles the file, so only load files you trust.
    """
    # The loaded tensors are later passed to np.concatenate, which needs them
    # in host memory; map_location='cpu' makes that hold even if the file was
    # saved from a GPU.
    first, second = torch.load(path, map_location='cpu').values()
    return first, second


Train_100_image, Train_100_y = _load_gan_split('GAN_100_train.pth')
Train_300_image, Train_300_y = _load_gan_split('GAN_300_train.pth')
Train_500_image, Train_500_y = _load_gan_split('GAN_500_train.pth')

Test_100_image, Test_100_y = _load_gan_split('GAN_100_Test.pth')
Test_300_image, Test_300_y = _load_gan_split('GAN_300_Test.pth')
Test_500_image, Test_500_y = _load_gan_split('GAN_500_Test.pth')

# Seed NumPy's global generator so the projection below is reproducible; the
# global state is seeded on purpose, since later code may draw from it too.
np.random.seed(1)
# Gaussian random projection matrix: 784 input values -> 35 projected values.
RP = np.random.normal(loc=0.0, scale=0.1, size=(784, 35))

# Synthetic (train images, train labels, test images, test labels) for each
# G, where G is the number embedded in the GAN_<G>_* file names.
_SYNTHETIC_BY_G = {
    100: (Train_100_image, Train_100_y, Test_100_image, Test_100_y),
    300: (Train_300_image, Train_300_y, Test_300_image, Test_300_y),
    500: (Train_500_image, Train_500_y, Test_500_image, Test_500_y),
}


def _stack_and_project(real_rows, syn_rows, n_rows):
    """Stack the first n_rows real rows above the first n_rows synthetic rows and apply RP."""
    stacked = np.concatenate([real_rows[0:n_rows, :], syn_rows[0:n_rows, :]])
    return stacked @ RP


# One entry per (digit, G): [digit, G, Dist_TV result, KDE_TV result,
# MC_TV_Baseline result].
All_error = []
for digit in range(10):
    # The real samples of this digit are the same for every G.
    real_train = train_data_real[train_data_real_y == digit]
    real_test = test_data_real[test_data_real_y == digit]

    for G in (100, 300, 500):
        syn_train_all, syn_train_lbl, syn_test_all, syn_test_lbl = _SYNTHETIC_BY_G[G]
        syn_train = syn_train_all[syn_train_lbl == digit]
        syn_test = syn_test_all[syn_test_lbl == digit]

        # Balance the two classes by truncating both to the smaller count.
        n_train = min(real_train.shape[0], syn_train.shape[0])
        n_test = min(real_test.shape[0], syn_test.shape[0])

        x_train = _stack_and_project(real_train, syn_train, n_train)
        x_test = _stack_and_project(real_test, syn_test, n_test)

        # Label 1 marks real rows, label 0 marks synthetic rows.
        y_train = np.repeat([1.0, 0.0], n_train)
        y_test = np.repeat([1.0, 0.0], n_test)

        x_real = x_train[y_train == 1]
        x_syn = x_train[y_train == 0]

        # Sample mean and covariance of each projected training class.
        mu_1_bar = x_real.mean(axis=0)
        mu_2_bar = x_syn.mean(axis=0)
        Sigma_1_bar = np.cov(x_real, rowvar=False)
        Sigma_2_bar = np.cov(x_syn, rowvar=False)

        # Estimators are called in a fixed order in case they share the
        # global random state.
        DisE = Dist_TV(x_train, x_test, y_train, y_test)
        KDE = KDE_TV(x_real, x_syn)
        PE = MC_TV_Baseline(mu_1_bar, Sigma_1_bar, mu_2_bar, Sigma_2_bar)

        All_error.append([digit, G, DisE, KDE, PE])
        print(digit, G)

import pandas as pd

# One row per (digit, G) pair collected in All_error.
Result = pd.DataFrame(All_error)

# NOTE(review): the CSV is written before the columns are named, so the saved
# header is 0..4 rather than dig/G/DisE/KDE/PE. Kept as-is so the file layout
# does not change for anything that already reads it -- confirm whether named
# headers are wanted.
# NOTE(review): the file is called 'RP_50' while the projection matrix RP has
# 35 columns -- confirm the intended name.
Result.to_csv('RP_50')
Result.columns = ['dig','G','DisE','KDE','PE']

# Per-G summary across the ten digits. The bare groupby expressions only
# display a value in an interactive session; print them so a plain script run
# reports the summary as well.
per_G = Result.groupby(['G'])
print(per_G.mean())
print(per_G.std())









