import numpy as np
import torch
import torchvision.datasets as datasets
from torchvision import transforms
import tqdm
from utils import *
import time

# Load the stored point clouds and build a well-separated source/target pair.
A = np.load("reconstruct_random_50_shapenetcore55.npy")
# Indices of the two shapes compared (an earlier run noted the pair 2 / 42).
ind1 = 33  # target shape index in A
ind2 = 32  # source shape index in A
# Scale both shapes by 2 and shift them in opposite directions (+10 / -10).
target = A[ind1] * 2 + 10
source = A[ind2] * 2 - 10

X0 = torch.from_numpy(source)
X1 = torch.from_numpy(target)
# Truncate both clouds to a common number of rows and reshape to 2-D (n, -1).
n = min(X0.shape[0], X1.shape[0])
X0 = X0[:n].view(n, -1)
X1 = X1[:n].view(n, -1)

# Reference value: SW estimated with a very large number of projections
# (L=100000). Every cheaper estimator in the sweep below is scored by its
# absolute deviation from this value.
trueSW = SW(X0, X1, L=100000)
print(trueSW.shape)
trueSW = trueSW.mean()
print(trueSW)
# Per-L summary accumulators; the sweep appends one entry per value of L.
# "mean_vars_*" : mean over seeds of |estimate - trueSW|
# "vars_vars_*" : standard deviation over seeds of that error (not a variance,
#                 despite the name)
# Each pair is created with separate [] literals so no two names alias.
mean_vars_SW, vars_vars_SW = [], []
mean_vars_controlSW, vars_vars_controlSW = [], []
mean_vars_upcontrolSW, vars_vars_upcontrolSW = [], []
mean_vars_controlSW2, vars_vars_controlSW2 = [], []
mean_vars_upcontrolSW2, vars_vars_upcontrolSW2 = [], []

# Wall-clock timing accumulators: mean and standard deviation over seeds.
mean_times_SW, vars_times_SW = [], []
mean_times_controlSW, vars_times_controlSW = [], []
mean_times_upcontrolSW, vars_times_upcontrolSW = [], []
mean_times_controlSW2, vars_times_controlSW2 = [], []
mean_times_upcontrolSW2, vars_times_upcontrolSW2 = [], []
# Numbers of projections to sweep over.
Ls = np.array([2, 5, 10, 50, 100, 500, 1000, 5000, 10000])
N_SEEDS = 5


def _timed_abs_error(estimator, L, seed):
    """Run one seeded trial of ``estimator`` and return ``(abs_error, seconds)``.

    ``estimator`` is called as ``estimator(X0, X1, L=L)`` on the module-level
    point clouds. Its result is compared with the module-level reference
    ``trueSW``. The elapsed time covers the estimator call plus the
    abs/cpu/numpy conversion, as in the original per-method stanzas.
    """
    # Identical pre-timing workload for every estimator (presumably a warm-up;
    # NOTE(review): confirm intent). Both global RNGs are reseeded right
    # afterwards, so these draws do not affect what the estimator samples.
    for _ in range(1000):
        np.random.randn(100)
        torch.randn(100)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    error = torch.abs(estimator(X0, X1, L=L) - trueSW).cpu().detach().numpy()
    return error, time.perf_counter() - start


# Each row: (estimator, mean-error list, std-error list, mean-time list,
# std-time list). The lists are the module-level accumulators defined above.
_ESTIMATORS = [
    (SW, mean_vars_SW, vars_vars_SW,
     mean_times_SW, vars_times_SW),
    (Gaussian_controlled_SW, mean_vars_controlSW, vars_vars_controlSW,
     mean_times_controlSW, vars_times_controlSW),
    (up_Gaussian_controlled_SW, mean_vars_upcontrolSW, vars_vars_upcontrolSW,
     mean_times_upcontrolSW, vars_times_upcontrolSW),
    (Gaussian_controlled_SW2, mean_vars_controlSW2, vars_vars_controlSW2,
     mean_times_controlSW2, vars_times_controlSW2),
    (up_Gaussian_controlled_SW2, mean_vars_upcontrolSW2, vars_vars_upcontrolSW2,
     mean_times_upcontrolSW2, vars_times_upcontrolSW2),
]

for L in Ls:
    print(L)
    errors = [[] for _ in _ESTIMATORS]
    times = [[] for _ in _ESTIMATORS]
    # Seeds outer, estimators inner: same execution order as before.
    for seed in range(N_SEEDS):
        for k, (estimator, *_) in enumerate(_ESTIMATORS):
            err, elapsed = _timed_abs_error(estimator, L, seed)
            errors[k].append(err)
            times[k].append(elapsed)
    # Reduce the per-seed results to one mean/std entry per estimator for this L.
    for (_, mean_err, std_err, mean_t, std_t), errs, ts in zip(_ESTIMATORS, errors, times):
        mean_err.append(np.mean(errs))
        std_err.append(np.std(errs))
        mean_t.append(np.mean(ts))
        std_t.append(np.std(ts))
# Write every summary series to CSV-style text files, one value per L.
# Each row: (mean-file pattern, std-file pattern, mean series, std series).
# Each row writes its mean file, then its std file, in the same order as before.
_RESULT_FILES = [
    ("points_SW_{}_{}_mean.txt", "points_SW_{}_{}_std.txt",
     mean_vars_SW, vars_vars_SW),
    ("points_ControlSW_{}_{}_mean.txt", "points_ControlSW_{}_{}_std.txt",
     mean_vars_controlSW, vars_vars_controlSW),
    ("points_upControlSW_{}_{}_mean.txt", "points_upControlSW_{}_{}_std.txt",
     mean_vars_upcontrolSW, vars_vars_upcontrolSW),
    ("points_ControlSW2_{}_{}_mean.txt", "points_ControlSW2_{}_{}_std.txt",
     mean_vars_controlSW2, vars_vars_controlSW2),
    ("points_upControlSW2_{}_{}_mean.txt", "points_upControlSW2_{}_{}_std.txt",
     mean_vars_upcontrolSW2, vars_vars_upcontrolSW2),
    ("points_time_SW_{}_{}_mean.txt", "points_time_SW_{}_{}_std.txt",
     mean_times_SW, vars_times_SW),
    ("points_time_ControlSW_{}_{}_mean.txt", "points_time_ControlSW_{}_{}_std.txt",
     mean_times_controlSW, vars_times_controlSW),
    ("points_time_upControlSW_{}_{}_mean.txt", "points_time_upControlSW_{}_{}_std.txt",
     mean_times_upcontrolSW, vars_times_upcontrolSW),
    ("points_time_ControlSW2_{}_{}_mean.txt", "points_time_ControlSW2_{}_{}_std.txt",
     mean_times_controlSW2, vars_times_controlSW2),
    ("points_time_upControlSW2_{}_{}_mean.txt", "points_time_upControlSW2_{}_{}_std.txt",
     mean_times_upcontrolSW2, vars_times_upcontrolSW2),
]

for mean_pattern, std_pattern, mean_series, std_series in _RESULT_FILES:
    np.savetxt(mean_pattern.format(ind1, ind2), np.array(mean_series), delimiter=",")
    np.savetxt(std_pattern.format(ind1, ind2), np.array(std_series), delimiter=",")
