import numpy as np
import os
import h5py
import pickle
import torch

def dataJoin(rawdir, N=14):
    """Load and stack the ``hypercube`` datasets from numbered ``.mat`` files.

    Candidate files are ``rawdir + '0001.mat'``, ``rawdir + '0002.mat'``, ...,
    up to number ``N - 1`` (file ``0001`` is always considered, even when
    ``N <= 1``). ``rawdir`` is prepended verbatim, so it should end with a
    path separator. The files are opened with h5py, so they must be
    HDF5-based ``.mat`` files. Missing files are skipped. The ``hypercube``
    arrays of the files that do exist are concatenated along axis 0 in
    file-number order, and each loaded file name is printed.

    Args:
        rawdir: Prefix (normally a directory ending in ``/``) for the file names.
        N: Exclusive upper bound on the file number to look for.

    Returns:
        numpy.ndarray: the ``hypercube`` arrays concatenated along axis 0.

    Raises:
        FileNotFoundError: if none of the candidate files exist.
    """
    chunks = []
    # max(N, 2) keeps file 0001 in play for N <= 1, as the original code did.
    for i in range(1, max(N, 2)):
        name = str(i).zfill(4) + '.mat'
        fname = rawdir + name
        if not os.path.exists(fname):
            continue
        print(name)
        # Context manager guarantees the HDF5 handle is closed; [:] reads the
        # dataset fully into memory so it stays valid after closing.
        with h5py.File(fname, 'r') as f:
            chunks.append(f['hypercube'][:])
    if not chunks:
        raise FileNotFoundError("no '####.mat' files found under " + repr(rawdir))
    # A single concatenate avoids re-copying the growing array on every file.
    return np.concatenate(chunks, 0)

def _scale_by_spectral_norm(X):
    """Return ``X`` with each matrix ``X[i]`` divided by its largest singular value."""
    if X.shape[0] == 0:
        # Empty split: nothing to scale.
        return X
    # svdvals is batched over the leading dim and returns singular values in
    # descending order, so column 0 is each sample's largest singular value.
    sigma_max = torch.linalg.svdvals(X)[:, 0]
    return X / sigma_max.view(-1, 1, 1)


def data(path, args):
    """Build (``args.raw`` true) or load (``args.raw`` false) a cached train/test split.

    The cache file is
    ``path + 'trainset/' + dataType + N + '_' + N_train + '_' + N_test + '.dat'``
    and holds a pickled list ``[A_train, A_test]``.

    When building, the first ``N`` samples of
    ``dataJoin(path + "hyperspectral/")`` are converted to a ``torch.Tensor``.
    ``N_train + N_test`` distinct indices are then drawn at random (numpy
    global RNG) from ``range(N)``. If ``args.normalize`` is set, every sample
    is divided by its largest singular value. The result is pickled to the
    cache file, and the ``trainset/`` directory is created if needed.

    Args:
        path: Root prefix (normally ending in ``/``) for the data directories.
        args: Object with attributes ``N``, ``N_train``, ``N_test``, ``raw``,
            ``dataType`` and ``normalize``.

    Returns:
        tuple: ``(A_train, A_test)``.

    Raises:
        ValueError: if ``N_train``/``N_test`` are negative or sum to more than ``N``.
        OSError: if the raw data or cache file cannot be read or written
            (a message is printed first).
        IndexError: if fewer than ``N`` samples are available (a message is
            printed first).
    """
    N, N_train, N_test, raw, dataType = args.N, args.N_train, args.N_test, args.raw, args.dataType
    nmlz = args.normalize
    rawpath = path + 'trainset/' + str(dataType) + str(N) + '_' + str(N_train) + '_' + str(N_test) + '.dat'

    if not raw:
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file; only load cache files produced by this function / trusted sources.
        with open(rawpath, 'rb') as f:
            A_train, A_test = pickle.load(f)
        return A_train, A_test

    # Validate before the (expensive) load; np.random.choice would otherwise
    # fail later with the same exception type.
    if N_train < 0 or N_test < 0 or N_train + N_test > N:
        raise ValueError(
            "need N_train >= 0, N_test >= 0 and N_train + N_test <= N "
            f"(got N={N}, N_train={N_train}, N_test={N_test})")

    print("Start storing " + str(dataType).upper() + "...")
    try:
        A = torch.Tensor(dataJoin(path + "hyperspectral/")[:N, :, :])
        if A.shape[0] < N:
            # Fail deterministically rather than only when an out-of-range
            # index happens to be drawn below.
            raise IndexError(f"only {A.shape[0]} samples available, N={N} requested")
        index = np.random.choice(N, size=N_train + N_test, replace=False)
        A_train = A[index[:N_train], :, :]
        A_test = A[index[N_train:], :, :]
        if nmlz:
            A_train = _scale_by_spectral_norm(A_train)
            A_test = _scale_by_spectral_norm(A_test)
        os.makedirs(os.path.dirname(rawpath), exist_ok=True)
        with open(rawpath, 'wb') as f:
            pickle.dump([A_train, A_test], f)
        print("Storing finished.")
    except IOError:
        print("Error: path not exist!")
        # Re-raise: returning here would only fail later on unbound A_train/A_test.
        raise
    except IndexError:
        print("Error: index out of range!")
        raise

    return A_train, A_test

