import numpy as np
import torch
import torchvision.datasets as datasets
from torchvision import transforms
import tqdm
from utils import *
import time

# Load the saved array of shapes and pick the two whose sliced Wasserstein (SW)
# distance is estimated below.
# NOTE(review): the file name suggests 50 reconstructed ShapeNetCore55 point
# clouds -- confirm the layout of A.
A = np.load("reconstruct_random_50_shapenetcore55.npy")
# NOTE(review): an older note "## 33 32 2 42" presumably records the index pairs
# that were tried, (33, 32) and (2, 42); confirm.
ind1, ind2 = 33, 32

# Scale both shapes by 2, then push them apart: +10 on every entry of the
# target, -10 on every entry of the source.
target = 2 * A[ind1] + 10
source = 2 * A[ind2] - 10

# Truncate both to a common leading size n and flatten any trailing dims so
# each becomes a 2-D (n, d) tensor.
X0, X1 = torch.from_numpy(source), torch.from_numpy(target)
n = min(X0.size(0), X1.size(0))
X0, X1 = X0[:n].view(n, -1), X1[:n].view(n, -1)

# Reference value that every trial below is compared against: one SW estimate
# with a very large L (presumably the number of random projections), reduced
# with .mean(). It runs before any explicit np.random.seed / torch.manual_seed
# call in this script.
# NOTE(review): consider seeding it, with a seed outside the trial seeds, for
# reproducible runs.
trueSW = SW(X0, X1, L=100000).mean()
# Summary lists, one entry per value of L in the loop below:
#   mean_vars_* / vars_vars_*   -> mean / std over seeds of the squared error
#                                  mean((estimate - trueSW) ** 2)
#   mean_times_* / vars_times_* -> mean / std over seeds of runtime in seconds
mean_vars_SW = []
vars_vars_SW = []
mean_vars_controlSW = []
vars_vars_controlSW = []
mean_vars_upcontrolSW = []
vars_vars_upcontrolSW = []

mean_times_SW = []
vars_times_SW = []
mean_times_controlSW = []
vars_times_controlSW = []
# NOTE(review): the two "simpleSW" lists are never filled or saved in this file.
mean_times_simpleSW = []
vars_times_simpleSW = []
mean_times_upcontrolSW = []
vars_times_upcontrolSW = []


def _timed_squared_error(estimator, seed, L, X0, X1, reference):
    """Run one seeded trial of ``estimator`` and time it.

    First draws and discards some numpy/torch random numbers (presumably a
    warm-up before timing; the draws do not affect the trial because both RNGs
    are reseeded immediately afterwards). Then both RNGs are seeded with
    ``seed``, so every estimator sees the same random stream for a given seed.

    Args:
        estimator: callable invoked as ``estimator(X0, X1, L=L)``.
        seed: integer passed to ``np.random.seed`` and ``torch.manual_seed``.
        L: forwarded to the estimator as the ``L`` keyword.
        X0, X1: the two inputs forwarded to the estimator.
        reference: value the estimate is compared against.

    Returns:
        ``(squared_error, seconds)``. ``squared_error`` is
        ``mean((estimator(X0, X1, L=L) - reference) ** 2)`` as a numpy value.
        ``seconds`` is the elapsed ``time.perf_counter`` time.
    """
    for _ in range(1000):
        np.random.randn(100)
        torch.randn(100)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # adjustable system clock and is not meant for measuring durations.
    start = time.perf_counter()
    sq_err = torch.mean((estimator(X0, X1, L=L) - reference) ** 2).cpu().detach().numpy()
    # The timed span deliberately includes the blocking .cpu() copy. If the
    # estimator runs on a GPU, that copy waits for its kernels to finish.
    elapsed = time.perf_counter() - start
    return sq_err, elapsed


# Estimators in the order they are run for each seed (insertion-ordered dict).
_estimators = {
    "SW": SW,
    "controlSW": Gaussian_controlled_SW,
    "upcontrolSW": up_Gaussian_controlled_SW,
}

for L in [100000]:
    print(L)
    sq_errors = {name: [] for name in _estimators}
    run_times = {name: [] for name in _estimators}
    for seed in range(5):
        for name, estimator in _estimators.items():
            sq_err, elapsed = _timed_squared_error(estimator, seed, L, X0, X1, trueSW)
            sq_errors[name].append(sq_err)
            run_times[name].append(elapsed)

    mean_vars_SW.append(np.mean(sq_errors["SW"]))
    vars_vars_SW.append(np.std(sq_errors["SW"]))
    mean_vars_controlSW.append(np.mean(sq_errors["controlSW"]))
    vars_vars_controlSW.append(np.std(sq_errors["controlSW"]))
    mean_vars_upcontrolSW.append(np.mean(sq_errors["upcontrolSW"]))
    vars_vars_upcontrolSW.append(np.std(sq_errors["upcontrolSW"]))

    mean_times_SW.append(np.mean(run_times["SW"]))
    vars_times_SW.append(np.std(run_times["SW"]))
    mean_times_controlSW.append(np.mean(run_times["controlSW"]))
    vars_times_controlSW.append(np.std(run_times["controlSW"]))
    mean_times_upcontrolSW.append(np.mean(run_times["upcontrolSW"]))
    vars_times_upcontrolSW.append(np.std(run_times["upcontrolSW"]))
# Write every summary list to its own text file. Each name embeds the shape
# index pair (ind1, ind2). The lists are 1-D, so each file holds one value per
# line. Files are written in the same order as they are listed here.
_result_files = (
    ("points_SW_{}_{}_mean.txt", mean_vars_SW),
    ("points_SW_{}_{}_std.txt", vars_vars_SW),
    ("points_ControlSW_{}_{}_mean.txt", mean_vars_controlSW),
    ("points_ControlSW_{}_{}_std.txt", vars_vars_controlSW),
    ("points_upControlSW_{}_{}_mean.txt", mean_vars_upcontrolSW),
    ("points_upControlSW_{}_{}_std.txt", vars_vars_upcontrolSW),
    ("points_time_SW_{}_{}_mean.txt", mean_times_SW),
    ("points_time_SW_{}_{}_std.txt", vars_times_SW),
    ("points_time_ControlSW_{}_{}_mean.txt", mean_times_controlSW),
    ("points_time_ControlSW_{}_{}_std.txt", vars_times_controlSW),
    ("points_time_upControlSW_{}_{}_mean.txt", mean_times_upcontrolSW),
    ("points_time_upControlSW_{}_{}_std.txt", vars_times_upcontrolSW),
)
for _pattern, _values in _result_files:
    np.savetxt(_pattern.format(ind1, ind2), np.array(_values), delimiter=",")

