import pickle
import numpy as np
import torch

def _normalize_slices(X):
    """Divide every 2-D slice ``X[i]`` in place by its largest singular value.

    Slices whose largest singular value is 0 (all-zero slices) are left
    unchanged rather than being turned into NaNs by a 0/0 division.
    Returns ``X`` for convenience.
    """
    if X.shape[0] == 0:
        return X
    # svdvals is batched over leading dims; column 0 holds each slice's
    # largest singular value (values are returned in descending order).
    top = torch.linalg.svdvals(X)[:, 0]
    scale = torch.where(top > 0, top, torch.ones_like(top))
    X /= scale.view(-1, 1, 1)
    return X


def data(path, args):
    """Build (or load a cached) train/test split of 2-D slices from a raw volume.

    Args:
        path: Directory prefix, concatenated verbatim (so it should end
            with '/'). The raw volume is read from
            ``path + 'brain/t1_ai_msles2_1mm_pn0_rf20.rawb'`` and the cache
            is written to or read from ``path + 'trainset/<dataType><N>_<N_train>_<N_test>.dat'``.
        args: Object with attributes ``N`` (sample slices from indices
            ``0..N-1`` along the first axis; must be <= 181), ``N_train``,
            ``N_test``, ``raw`` (True: rebuild from the raw volume and
            overwrite the cache; False: load the cache), ``dataType`` (used
            in the cache filename), and ``normalize`` (divide each slice by
            its largest singular value).

    Returns:
        ``(A_train, A_test)``: float tensors of shape ``(N_train, 217, 181)``
        and ``(N_test, 217, 181)``.

    Raises:
        OSError: the raw volume or cache file cannot be read or written.
        IndexError: a sampled slice index is out of range (``N > 181``).
    """
    N, N_train, N_test, raw, dataType = args.N, args.N_train, args.N_test, args.raw, args.dataType
    nmlz = args.normalize
    rawpath = path + 'trainset/' + str(dataType) + str(N) + '_' + str(N_train) + '_' + str(N_test) + '.dat'

    if not raw:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cache files produced by this function, never untrusted files.
        with open(rawpath, 'rb') as f:
            A_train, A_test = pickle.load(f)
        return A_train, A_test

    print("Start storing " + str(dataType).upper() + "...")
    fname = path + 'brain/' + 't1_ai_msles2_1mm_pn0_rf20.rawb'
    try:
        A = torch.Tensor(np.fromfile(fname, dtype="uint8"))
        A = A.reshape(181, 217, 181)
        # Disjoint random slice indices: the first N_train go to train, the rest to test.
        index = np.random.choice(N, size=N_train + N_test, replace=False)
        A_train = A[index[:N_train], :, :]
        A_test = A[index[N_train:], :, :]
        if nmlz:
            _normalize_slices(A_train)
            _normalize_slices(A_test)
        with open(rawpath, 'wb') as f:
            pickle.dump([A_train, A_test], f)
            print("Storing finished.")
    except IOError:
        print("Error: path not exist!")
        # Re-raise: otherwise the return below fails with UnboundLocalError.
        raise
    except IndexError:
        print("Error: index out of range!")
        raise
    return A_train, A_test
    
