import numpy as np
import torch
import torchvision.datasets as datasets
from torchvision import transforms
import tqdm
from utils import *
import time
# Loader over the MNIST training split, 100 images per batch, each image passed
# through transforms.ToTensor().
# shuffle=False keeps the batch order fixed: the script later truncates the
# larger digit class with `[:n]`, so a shuffled loader (no seed is set before
# this point) would select a different subset, and therefore a different
# reference value, on every run.
# NOTE(review): with num_workers > 0 and no `if __name__ == "__main__":` guard,
# platforms that start worker processes with "spawn" re-import this script in
# every worker -- confirm the script is only run where "fork" is the default.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=100,
    shuffle=False,
    num_workers=16,
)
# Group the training images by label: after this block, digits[d] is a single
# tensor holding every image whose label is d, in the order the loader
# produced them.
digits = [[] for _ in range(10)]
for data, y in tqdm.tqdm(train_loader):
    for label, bucket in enumerate(digits):
        # The boolean mask keeps the rows of this batch labelled `label`; a
        # batch without that label contributes an empty slice, which the
        # concatenation below simply skips over.
        bucket.append(data[y == label])
digits = [torch.cat(chunks, dim=0) for chunks in digits]
# Keep the per-digit names bound for any code that refers to them directly.
(digits_0, digits_1, digits_2, digits_3, digits_4,
 digits_5, digits_6, digits_7, digits_8, digits_9) = digits
# SW = np.zeros((10,10))

def _scale_rows_to_255(flat):
    """Divide every row of *flat* by its own maximum, then multiply by 255."""
    row_peak = flat.max(dim=1, keepdim=True).values
    return (flat / row_peak) * 255


# Labels of the two digit classes being compared.
ind1 = 1
ind2 = 7
# Trim both classes to the size of the smaller one and flatten each image into
# a single row before rescaling.
n = min(digits[ind1].shape[0], digits[ind2].shape[0])
X0 = _scale_rows_to_255(digits[ind1][:n].view(n, -1))
X1 = _scale_rows_to_255(digits[ind2][:n].view(n, -1))
# Reference value: the mean of SW(...) at L=100000; the experiment below
# measures squared deviations from it.
trueSW = SW(X0, X1, L=100000).mean()
# Summaries of each estimator's squared error, one entry appended per value of
# L: the mean over seeds and the np.std over seeds.
mean_vars_SW, vars_vars_SW = [], []
mean_vars_controlSW, vars_vars_controlSW = [], []
mean_vars_upcontrolSW, vars_vars_upcontrolSW = [], []

# Summaries of each estimator's elapsed time, laid out the same way.
mean_times_SW, vars_times_SW = [], []
mean_times_controlSW, vars_times_controlSW = [], []
mean_times_upcontrolSW, vars_times_upcontrolSW = [], []
def _timed_squared_error(estimator, x0, x1, reference, L, seed):
    """Run one seeded estimate and return ``(squared_error, elapsed_seconds)``.

    Args:
        estimator: callable invoked as ``estimator(x0, x1, L=L)``.
        x0, x1: the two sample tensors handed to the estimator.
        reference: value the estimate is compared against.
        L: forwarded to the estimator as its ``L`` keyword.
        seed: seed applied to both NumPy and torch right before the timed call.

    Returns:
        The mean squared difference between the estimate and ``reference`` as a
        NumPy value, and the time the estimate took in seconds.
    """
    # Draw and discard random numbers before the clock starts. The RNGs are
    # re-seeded immediately afterwards, so these draws do not influence the
    # estimate itself.
    # NOTE(review): presumably a warm-up so the first timed call is not
    # penalised -- confirm the intent.
    for _ in range(1000):
        np.random.randn(100)
        torch.randn(100)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    error = torch.mean((estimator(x0, x1, L=L) - reference) ** 2).cpu().detach().numpy()
    return error, time.perf_counter() - start


for L in [100000]:
    print(L)
    var_SW, time_SW = [], []
    var_controlSW, time_controlSW = [], []
    var_upcontrolSW, time_upcontrolSW = [], []
    # Each estimator is paired with the lists that collect its results.
    runs = (
        (SW, var_SW, time_SW),
        (Gaussian_controlled_SW, var_controlSW, time_controlSW),
        (up_Gaussian_controlled_SW, var_upcontrolSW, time_upcontrolSW),
    )
    for seed in range(5):
        # Every estimator is run with the same seed, in a fixed order.
        for estimator, errors, timings in runs:
            error, elapsed = _timed_squared_error(estimator, X0, X1, trueSW, L, seed)
            errors.append(error)
            timings.append(elapsed)

    # Summarise over seeds: mean and standard deviation of the squared errors.
    mean_vars_SW.append(np.mean(var_SW))
    vars_vars_SW.append(np.std(var_SW))
    mean_vars_controlSW.append(np.mean(var_controlSW))
    vars_vars_controlSW.append(np.std(var_controlSW))
    mean_vars_upcontrolSW.append(np.mean(var_upcontrolSW))
    vars_vars_upcontrolSW.append(np.std(var_upcontrolSW))

    # Same summaries for the elapsed times.
    mean_times_SW.append(np.mean(time_SW))
    vars_times_SW.append(np.std(time_SW))
    mean_times_controlSW.append(np.mean(time_controlSW))
    vars_times_controlSW.append(np.std(time_controlSW))
    mean_times_upcontrolSW.append(np.mean(time_upcontrolSW))
    vars_times_upcontrolSW.append(np.std(time_upcontrolSW))
# Each result series is written to its own text file; the two placeholders in
# every filename template are filled with ind1 and ind2.
_result_files = [
    ("SW_{}_{}_mean.txt", mean_vars_SW),
    ("SW_{}_{}_std.txt", vars_vars_SW),
    ("ControlSW_{}_{}_mean.txt", mean_vars_controlSW),
    ("ControlSW_{}_{}_std.txt", vars_vars_controlSW),
    ("upControlSW_{}_{}_mean.txt", mean_vars_upcontrolSW),
    ("upControlSW_{}_{}_std.txt", vars_vars_upcontrolSW),
    ("time_SW_{}_{}_mean.txt", mean_times_SW),
    ("time_SW_{}_{}_std.txt", vars_times_SW),
    ("time_ControlSW_{}_{}_mean.txt", mean_times_controlSW),
    ("time_ControlSW_{}_{}_std.txt", vars_times_controlSW),
    ("time_upControlSW_{}_{}_mean.txt", mean_times_upcontrolSW),
    ("time_upControlSW_{}_{}_std.txt", vars_times_upcontrolSW),
]
for _template, _values in _result_files:
    np.savetxt(_template.format(ind1, ind2), np.array(_values), delimiter=",")