import os

import scipy.io as scio
import torch

def Selecsubmatrix(A, cr, cc):
    """Return a random contiguous ``cr x cc`` window of the 2-D tensor *A*.

    The top-left corner is drawn uniformly, using torch's global RNG, from
    every position at which a ``cr x cc`` window fits inside *A*. That
    includes windows touching the last row/column and the full matrix
    itself.

    Args:
        A: 2-D tensor to crop.
        cr: Number of rows to keep (``0 <= cr <= A.size(0)``).
        cc: Number of columns to keep (``0 <= cc <= A.size(1)``).

    Returns:
        ``A[r:r + cr, c:c + cc]`` for the drawn offsets ``r``, ``c``. This is a
        basic slice, so it is a view sharing storage with *A*.

    Raises:
        ValueError: If the requested window does not fit inside *A*.
    """
    m, n = A.size()
    if not (0 <= cr <= m and 0 <= cc <= n):
        raise ValueError(
            f"cannot take a {cr}x{cc} submatrix of a {m}x{n} matrix"
        )
    # randint's upper bound is exclusive. The +1 makes the last valid
    # offset (m - cr / n - cc) reachable and keeps cr == m / cc == n legal
    # (randint(0, 0) would raise).
    row0 = torch.randint(0, m - cr + 1, (1,)).item()
    col0 = torch.randint(0, n - cc + 1, (1,)).item()
    return A[row0:row0 + cr, col0:col0 + cc]

# Fixed seed: every random draw below (sampled entries, crop offsets, randn
# matrices) is reproducible across runs as long as the order is unchanged.
torch.manual_seed(0)

PT_DIR = './pt'
# torch.save does not create missing parent directories.
os.makedirs(PT_DIR, exist_ok=True)


def _save_and_verify(tensor, name):
    """Save *tensor* to ``PT_DIR/<name>.pt`` and check that it reloads.

    Raises:
        RuntimeError: If the reloaded tensor's shape or dtype differs from
            what was saved.
    """
    path = os.path.join(PT_DIR, f'{name}.pt')
    torch.save(tensor, path)
    reloaded = torch.load(path)
    if reloaded.shape != tensor.shape or reloaded.dtype != tensor.dtype:
        raise RuntimeError(
            f'{path} did not round-trip: saved {tuple(tensor.shape)} '
            f'{tensor.dtype}, loaded {tuple(reloaded.shape)} {reloaded.dtype}'
        )


def _sample_stored_values(sparse_mat, rows, cols):
    """Return a (rows, cols) tensor of stored entries drawn from *sparse_mat*.

    Indices into ``sparse_mat.data`` are drawn uniformly with replacement.
    The result is therefore NOT a submatrix: row/column positions are
    discarded and only the stored values are kept.
    """
    picks = torch.randint(0, sparse_mat.nnz, (rows * cols,))
    return torch.tensor(sparse_mat.data[picks.numpy()]).view(rows, cols)


# Sparse datasets: fill a rows x cols matrix with randomly chosen stored values.
# (output name, .mat file, variable name, rows, cols)
for name, mat_path, key, rows, cols in (
    ('sector_train', './sector_train.mat', 'x', 500, 1000),
    ('TDT2', './TDT2.mat', 'fea', 1000, 1000),
):
    matrix = scio.loadmat(mat_path)[key]
    _save_and_verify(_sample_stored_values(matrix, rows, cols), name)

# Matrix datasets: a random contiguous rows x cols window of each matrix.
# `densify` marks variables that must be converted with .toarray() first.
# (output name, .mat file, variable name, densify, rows, cols)
for name, mat_path, key, densify, rows, cols in (
    ('cifar', './cifar.mat', 'data', False, 1000, 100),
    ('mnist', './mnist_uint8.mat', 'train_x', False, 1000, 780),
    ('CnnCaltech', './cnn_4096d_Caltech.mat', 'x', False, 2000, 1000),
    ('gisette', './gisette.mat', 'x', True, 3000, 1000),
    ('w1a', './w1a_train.mat', 'x', True, 2470, 290),
):
    matrix = scio.loadmat(mat_path)[key]
    if densify:
        matrix = matrix.toarray()
    crop = Selecsubmatrix(torch.tensor(matrix), rows, cols)
    # The crop is a slice of the full matrix. torch.save serializes a view's
    # whole underlying storage, so copy just the window before saving.
    _save_and_verify(crop.clone(), name)

# Standard-normal n x n matrices.
for n in (10, 100, 1000):
    _save_and_verify(torch.randn(n, n), f'randn-{n}-{n}')