import numpy as np
import os
import sys

sys.path.append("../")
sys.path.append("../config/")
import config_path
import config_syn

import utils_exp as ue

def sample_from_JPMF(P, sample_size):
    """Draw samples from a joint probability mass tensor and count them.

    Each cell of ``P`` is treated as the probability of one outcome.
    ``sample_size`` outcomes are drawn with ``np.random.choice`` (global
    NumPy random state) and the number of draws landing in each cell is
    returned.

    Parameters
    ----------
    P : array_like
        Tensor of probabilities whose entries sum to 1 within ``1e-3``.
    sample_size : int
        Number of draws.

    Returns
    -------
    numpy.ndarray
        Float64 array with the same shape as ``P`` holding the per-cell
        draw counts (un-normalized empirical tensor).

    Raises
    ------
    AssertionError
        If the entries of ``P`` do not sum to 1 within ``1e-3``.
    """
    P = np.asarray(P)

    # Check the normalizing condition. Raised explicitly (rather than via
    # ``assert``) so the check is not stripped under ``python -O``; the
    # exception type is kept unchanged. ``not (... < tol)`` also rejects NaN.
    if not np.abs(np.sum(P) - 1) < 1.0e-3:
        raise AssertionError("P is not normalized")

    # Number of elements in the tensor and its original shape.
    AbsOmega = np.size(P)
    tensor_size = np.shape(P)

    # Vectorized given tensor. Renormalize it: np.random.choice only accepts
    # a `p` whose sum is 1 up to ~1.5e-8, which is far stricter than the 1e-3
    # tolerance accepted above, so a tensor passing the check could still
    # make the sampler raise ValueError.
    Pvec = np.reshape(P, AbsOmega).astype(np.float64)
    Pvec = Pvec / Pvec.sum()

    # Draw flat cell indices, then tally them in one vectorized pass.
    idxs = np.random.choice(AbsOmega, size=sample_size, p=Pvec)

    # Vectorized empirical tensor (counts per cell), kept as float64.
    Phatvec = np.bincount(idxs, minlength=AbsOmega).astype(np.float64)

    return np.reshape(Phatvec, tensor_size)

def main(dist, N, noise, D):
    """Generate and save train/valid count tensors for every stored model.

    For each model name, the ground-truth tensor ``{name}_true_{dist}.npy``
    is loaded from the data directory, two independent sets of ``N`` draws
    are taken with ``sample_from_JPMF``, and the resulting count tensors are
    written next to the source file as ``{name}_train_{dist}_N{N}.npy`` and
    ``{name}_valid_{dist}_N{N}.npy``.

    Parameters
    ----------
    dist : str
        Tag embedded in the input and output file names.
    N : int
        Number of draws for the training set and for the validation set.
    noise : bool
        Selects the ``with_noise`` (truthy) or ``without_noise`` (falsy)
        sub-directory.
    D : int
        Value embedded in the ``D{D}`` directory name under
        ``config_path.data_repo_syn``.
    """
    noise_dir = "with_noise" if noise else "without_noise"
    path = os.path.join(config_path.data_repo_syn, f"D{D}", noise_dir)

    model_names = (
        # Pure models
        "CP",
        "Tucker",
        "TT",
        # Mixtures of two low-rank models
        "CPTT",
        "CPTucker",
        "TuckerTT",
        # Mixture of three low-rank models
        "CPTuckerTT",
    )

    # Load every ground-truth tensor up front, in order, so that a missing
    # file aborts the run before anything is sampled or written.
    true_tensors = {
        name: np.load(os.path.join(path, f"{name}_true_{dist}.npy"))
        for name in model_names
    }

    for name in model_names:
        P_true = true_tensors[name]

        # Train first, then valid: keeps the random-draw order fixed.
        counts_train = sample_from_JPMF(P_true, N)
        counts_valid = sample_from_JPMF(P_true, N)

        np.save(os.path.join(path, f"{name}_train_{dist}_N{N}.npy"), counts_train)
        np.save(os.path.join(path, f"{name}_valid_{dist}_N{N}.npy"), counts_valid)
 

if __name__ == "__main__":
    # Script entry point: sweep over both noise settings and every sample
    # size listed in config_syn.Nsets (rounded to the nearest integer).
    D = config_syn.D
    sample_sizes = np.rint(config_syn.Nsets).astype(int)
    for noise in (True, False):
        for N in sample_sizes:
            main("normal", N, noise, D)
