import numpy as np
# torch is used directly throughout this script; import it explicitly instead of
# relying on the project star imports below to re-export it.
import torch
from torchvision import datasets
from empirical_tv_helper_function import *
from TV_estimation import *
# Kept after the star imports so these explicit names take precedence.
import torchvision.models as models
import torch.nn as nn


# Real MNIST digits, one flattened 784-d row (28x28) per image, plus labels.
# ``.data`` is the raw image tensor; ``.train_data`` is only a deprecated alias
# for it, and it reads misleadingly on the test split. Dividing the stored
# uint8 pixels by 255 scales them to [0, 1]. ``reshape(-1, 784)`` avoids
# hard-coding the split sizes.
mnist_dataset_train = datasets.MNIST(root='./data', train=True, download=True)
train_data_real = (mnist_dataset_train.data / 255.).reshape(-1, 784)
train_data_real_y = mnist_dataset_train.targets
mnist_dataset_test = datasets.MNIST(root='./data', train=False, download=True)
test_data_real = (mnist_dataset_test.data / 255.).reshape(-1, 784)
test_data_real_y = mnist_dataset_test.targets

# GAN-generated samples for each G setting, for the train and test splits.
# Each file is assumed to hold a two-entry dict (images first, labels second).
# Unpacking ``.values()`` relies on that insertion order — TODO confirm against
# the script that saved these files.
# map_location='cpu' lets tensors saved from a GPU session load here as CPU
# tensors; the later NumPy conversions (np.concatenate / .numpy()) require CPU
# tensors.
Train_100_image, Train_100_y = torch.load('GAN_100_train.pth', map_location='cpu').values()
Train_300_image, Train_300_y = torch.load('GAN_300_train.pth', map_location='cpu').values()
Train_500_image, Train_500_y = torch.load('GAN_500_train.pth', map_location='cpu').values()

Test_100_image, Test_100_y = torch.load('GAN_100_Test.pth', map_location='cpu').values()
Test_300_image, Test_300_y = torch.load('GAN_300_Test.pth', map_location='cpu').values()
Test_500_image, Test_500_y = torch.load('GAN_500_Test.pth', map_location='cpu').values()


# ImageNet-pretrained ResNet-18 backbone.
# NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13 in favour
# of ``weights=models.ResNet18_Weights.DEFAULT``. It is kept here so that older
# torchvision versions keep working.
Model = models.resnet18(pretrained=True)
def change_layers(model, in_channels=1, num_features=50):
    """Adapt a ResNet-18 to ``in_channels``-channel input and a ``num_features``-d head.

    ``model.conv1`` is replaced by a freshly (randomly) initialised 7x7,
    stride-2, padding-3 convolution with 64 output channels and no bias. The
    pretrained stem weights are therefore discarded. ``model.fc`` is replaced
    by a new ``Linear(512, num_features)``.

    The model is modified in place and the same object is returned.

    Args:
        model: object whose ``conv1`` / ``fc`` attributes are overwritten
            (e.g. ``torchvision.models.resnet18()``).
        in_channels: number of input image channels (1 for grayscale MNIST).
        num_features: output dimension of the new final linear layer.

    Returns:
        The same ``model`` object, now carrying the new layers.
    """
    model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2),
                            padding=(3, 3), bias=False)
    # 512 matches the penultimate feature width of ResNet-18.
    model.fc = nn.Linear(512, num_features, bias=True)
    return model
# Build the 1-channel, 50-d feature extractor. change_layers gives it a randomly
# initialised conv1, so seed torch first to make the features reproducible.
torch.manual_seed(0)
model = change_layers(Model)
# Use the network as a *fixed* feature map. In train mode, BatchNorm normalises
# with per-batch statistics, so a sample's features would depend on which other
# (real/synthetic) samples share its batch. It would also update its running
# statistics on every forward pass, so the map would drift from one (digit, G)
# iteration to the next.
model.eval()


def _embed(images, batch_size=1024):
    """Return model features for flattened 28x28 images as a float32 NumPy array.

    ``images`` is anything ``np.asarray`` accepts with 784 values per row.
    The forward pass runs under ``torch.no_grad()`` (no autograd graph) in
    chunks of ``batch_size`` rows to bound memory. In eval mode, chunking does
    not affect the features.
    """
    x = torch.as_tensor(np.asarray(images), dtype=torch.float32).reshape(-1, 1, 28, 28)
    with torch.no_grad():
        chunks = [model(x[k:k + batch_size]) for k in range(0, x.shape[0], batch_size)]
    return torch.cat(chunks).numpy()


# Synthetic data per G setting: (train images, train labels, test images, test labels).
# Dict order (100, 300, 500) matches the original iteration order.
syn_sets = {
    100: (Train_100_image, Train_100_y, Test_100_image, Test_100_y),
    300: (Train_300_image, Train_300_y, Test_300_image, Test_300_y),
    500: (Train_500_image, Train_500_y, Test_500_image, Test_500_y),
}

All_error = []
for i in range(10):
    # The real images of digit i do not depend on G, so select them once per digit.
    Train_Real_X = train_data_real[train_data_real_y == i]
    Test_Real_X = test_data_real[test_data_real_y == i]
    for G, (syn_train_img, syn_train_y, syn_test_img, syn_test_y) in syn_sets.items():
        Train_Syn_X = syn_train_img[syn_train_y == i]
        Test_Syn_X = syn_test_img[syn_test_y == i]
        # Balance the two classes: keep equally many real and synthetic samples.
        Train_size = min(Train_Real_X.shape[0], Train_Syn_X.shape[0])
        Test_size = min(Test_Real_X.shape[0], Test_Syn_X.shape[0])
        X_train = np.concatenate([Train_Real_X[:Train_size], Train_Syn_X[:Train_size]])
        X_test = np.concatenate([Test_Real_X[:Test_size], Test_Syn_X[:Test_size]])
        x_train = _embed(X_train)
        x_test = _embed(X_test)
        # Label 1 = real, 0 = synthetic (matches the concatenation order above).
        y_train = np.concatenate([np.ones(Train_size), np.zeros(Train_size)])
        y_test = np.concatenate([np.ones(Test_size), np.zeros(Test_size)])
        x_real, x_syn = x_train[y_train == 1], x_train[y_train == 0]
        # Per-class feature mean and covariance for the baseline estimate.
        mu_1_bar, mu_2_bar = x_real.mean(axis=0), x_syn.mean(axis=0)
        Sigma_1_bar = np.cov(x_real, rowvar=False)
        Sigma_2_bar = np.cov(x_syn, rowvar=False)
        DisE = Dist_TV(x_train, x_test, y_train, y_test)
        KDE = KDE_TV(x_real, x_syn)
        PE = MC_TV_Baseline(mu_1_bar, Sigma_1_bar, mu_2_bar, Sigma_2_bar)
        All_error.append([i, G, DisE, KDE, PE])
        print(i, G)

import pandas as pd

# One row per (digit, G) pair, holding the three TV estimates.
# Name the columns before saving so that they become the CSV header.
Result = pd.DataFrame(All_error, columns=['dig', 'G', 'DisE', 'KDE', 'PE'])
Result.to_csv('Resnet18_50')

# Per-G summary over the ten digits. Bare expressions are discarded when this
# runs as a script, so print them. The digit column is left out because its
# mean/std carries no information.
metrics = ['DisE', 'KDE', 'PE']
print(Result.groupby('G')[metrics].mean())
print(Result.groupby('G')[metrics].std())