import os
import time

import numpy as np
import torch
import torchvision.datasets as datasets
from torchvision import transforms
import tqdm

from utils import *  # presumably provides the SW / QSW estimators used below

# Compare the Monte Carlo sliced Wasserstein estimator (SW) against several
# QSW variants on a pair of point clouds taken from the saved array A, for an
# increasing number of projections L. Each estimate's absolute error against
# a large-L SW reference value is written to pointcloud/<METHOD>_<ind1>_<ind2>.txt.

A = np.load("reconstruct_random_100_shapenetcore55.npy")
ind1 = 98  # index into A of the source point cloud
ind2 = 99  # index into A of the target point cloud
target = A[ind2]
source = A[ind1]
seed = 2023
X1 = torch.from_numpy(target)
X0 = torch.from_numpy(source)

# Numbers of projections to evaluate.
NUM_PROJECTIONS = [10, 100, 500, 1000, 2000, 5000, 10000]
OUTPUT_DIR = "pointcloud"


def _reseed():
    """Reset the NumPy and PyTorch global RNGs to `seed`.

    Called before every estimate so that each estimator, at each L, starts
    from the same random state.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)


_reseed()
# Reference value: an SW estimate with a very large number of projections.
trueSW = SW(X0, X1, L=100000)


def _abs_error(estimate):
    """Return |estimate - trueSW| as a NumPy array."""
    return torch.abs(estimate - trueSW).cpu().detach().numpy()


# Output-file prefix -> estimator as a function of L. Dict insertion order fixes
# the evaluation order: SW first, then the QSW variants.
ESTIMATORS = {
    "SW": lambda L: SW(X0, X1, L=L),
    "NQSW": lambda L: QSW(X0, X1, L=L, type='nqsw'),
    "QSW": lambda L: QSW(X0, X1, L=L, type='qsw'),
    "SQSW": lambda L: QSW(X0, X1, L=L, type='sqsw'),
    "ODQSW": lambda L: QSW(X0, X1, L=L, type='odqsw'),
    "OCQSW": lambda L: QSW(X0, X1, L=L, type='ocqsw'),
}

errors = {name: [] for name in ESTIMATORS}
for L in NUM_PROJECTIONS:
    print(L)
    for name, estimator in ESTIMATORS.items():
        _reseed()  # identical RNG stream for every estimator
        errors[name].append(_abs_error(estimator(L)))

# Keep the per-method result lists available under their original names.
SWs, NQSWs, QSWs, SQSWs, ODQSWs, OCQSWs = (errors[name] for name in ESTIMATORS)

# np.savetxt does not create parent directories; without this the whole run
# would be lost at the very end if pointcloud/ is missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for name, errs in errors.items():
    np.savetxt(
        "{}/{}_{}_{}.txt".format(OUTPUT_DIR, name, ind1, ind2),
        np.array(errs),
        delimiter=",",
    )








