import os
import sys
import numpy as np
import glob 
sys.path.append("../config/")
sys.path.append("config/")
sys.path.append("loader/")
import config_path
import dataset_info
sys.path.append("../methods/emmix")
import sp_tensor
from tabulate import tabulate

def load_data_real(dataset_name, tvt="train", normalize=True, check_empty=False):
    """Load one split of a real-world sparse tensor dataset.

    The non-zero coordinates and values are read from
    ``<config_path.data_repo_real>/<dataset_name>/X_<tvt>_coords.npy`` and
    ``.../X_<tvt>_values.npy``. The tensor shape comes from
    ``dataset_info.tensor_sizes[dataset_name]``.

    Args:
        dataset_name: Dataset name; must appear in
            ``dataset_info.real_datasets_list``.
        tvt: Split to load: "train", "valid" or "test".
        normalize: Forwarded unchanged to ``sp_tensor.Sp_tensor``.
        check_empty: Forwarded unchanged to ``sp_tensor.Sp_tensor``.

    Returns:
        The ``sp_tensor.Sp_tensor`` built from the loaded arrays.

    Raises:
        AssertionError: If ``tvt`` is not one of the three split names.
        NameError: If ``dataset_name`` is not a known real dataset.
    """
    assert tvt in ["train", "valid", "test"], "tvt need to be train/valid/test"

    if dataset_name not in dataset_info.real_datasets_list:
        error_message = f"please chose one of exsiting real datasets name {dataset_info.real_datasets_list}"
        raise NameError(error_message)

    data_dir = os.path.join(config_path.data_repo_real, dataset_name)
    coords = np.load(os.path.join(data_dir, f"X_{tvt}_coords.npy"))
    values = np.load(os.path.join(data_dir, f"X_{tvt}_values.npy"))

    tensor_size = dataset_info.tensor_sizes[dataset_name]

    # Forward the caller's check_empty; it used to be hard-coded to False,
    # which made the parameter a no-op.
    return sp_tensor.Sp_tensor(
        coords, values, tensor_size, normalize=normalize, check_empty=check_empty
    )

def load_data_jes(N=6000, tvt="train", normalize=True, check_empty=False):
    """Load one split of the "jes" sparse tensor dataset for a given N.

    The non-zero coordinates and values are read from
    ``<config_path.data_repo_jes>/<N>/X_<tvt>_coords.npy`` and
    ``.../X_<tvt>_values.npy``. Every mode of the resulting tensor has size 5.
    The number of modes is the number of columns of the coords array.

    Args:
        N: Sub-directory selector under ``config_path.data_repo_jes``
            (formatted with ``f"{N}"``).
        tvt: Split to load: "train", "valid" or "test".
        normalize: Forwarded unchanged to ``sp_tensor.Sp_tensor``.
        check_empty: Forwarded unchanged to ``sp_tensor.Sp_tensor``.

    Returns:
        The ``sp_tensor.Sp_tensor`` built from the loaded arrays.

    Raises:
        AssertionError: If ``tvt`` is not one of the three split names.
    """
    assert tvt in ["train", "valid", "test"], "tvt need to be train/valid/test"

    data_dir = os.path.join(config_path.data_repo_jes, f"{N}")
    coords = np.load(os.path.join(data_dir, f"X_{tvt}_coords.npy"))
    values = np.load(os.path.join(data_dir, f"X_{tvt}_values.npy"))

    # Tensor order = number of coordinate columns. Every mode has size 5.
    # NOTE(review): the meaning of 5 is not visible here; confirm with the
    # dataset generator.
    tensor_size = (5,) * np.shape(coords)[1]

    # Forward the caller's check_empty; it used to be hard-coded to False,
    # which made the parameter a no-op.
    return sp_tensor.Sp_tensor(
        coords, values, tensor_size, normalize=normalize, check_empty=check_empty
    )


def show_dataset_detail(fullsize=False):
    """Print a summary table of every real dataset.

    For each name in ``dataset_info.real_datasets_list``, the default
    (training) split is loaded with ``load_data_real`` and one row is printed.
    Each row contains:

    - the name;
    - ``T.tensor_dim``;
    - ``T.nnz``;
    - the product of ``T.tensor_size`` (the number of cells in the dense
      tensor);
    - optionally, the full shape;
    - the sum of ``T.tensor_size`` divided by ``T.tensor_dim``.

    Args:
        fullsize: If True, also include ``T.tensor_size`` as a
            "Tensor shape" column.
    """
    import math

    rows = []
    for dataset_name in dataset_info.real_datasets_list:
        T = load_data_real(dataset_name)
        row = {
            "Name": dataset_name,
            "Tensor dim": T.tensor_dim,
            "NNZ": T.nnz,
            # Exact arbitrary-precision product. np.prod on a fixed-width
            # integer dtype silently wraps around for very large tensors.
            "Tensor size": math.prod(int(s) for s in T.tensor_size),
        }
        if fullsize:
            row["Tensor shape"] = T.tensor_size
        row["mean size"] = np.sum(T.tensor_size) / T.tensor_dim
        rows.append(row)

    print(tabulate(rows, headers='keys', floatfmt="10.0f"))