import numpy as np

from scipy.stats import dirichlet
import os
from itertools import product

import sys

sys.path.append("../")
sys.path.append("config")
import config_path
import config_syn
#import utils_exp as ue

# Root directory for the generated synthetic datasets, taken from the project
# config module; main() writes under <save_path>/D<D>/with_noise or without_noise.
save_path = config_path.data_repo_syn

def multilinear_rank(tensor):
    """
    Return the multilinear (Tucker) rank of a tensor.

    The rank for mode ``k`` is the matrix rank of the mode-k unfolding, i.e.
    the tensor with axis ``k`` moved to the front and the remaining axes
    flattened into columns.

    Parameters:
        tensor (ndarray): Input tensor of any number of dimensions.

    Returns:
        list: One matrix rank per mode, in mode order.
    """
    def _unfolding(mode):
        # Mode-`mode` matricization: shape (tensor.shape[mode], prod(other dims)).
        return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)

    return [np.linalg.matrix_rank(_unfolding(mode)) for mode in range(tensor.ndim)]

def reconst_full_CP(A):
    """Reconstruct the full tensor from CP factor matrices.

    Computes T[i_1, ..., i_D] = sum_r prod_d A[d][i_d, r] with NumPy
    broadcasting instead of a Python loop over every entry.

    Parameters:
        A (list of ndarray): D factor matrices; A[d] has shape (J_d, R) and
            every factor must have the same number of columns R.

    Returns:
        ndarray: float64 tensor of shape (J_1, ..., J_D).
    """
    factors = [np.asarray(a, dtype=float) for a in A]
    # Grow the outer product one mode at a time, keeping the rank axis last:
    # (J_1, R) -> (J_1, J_2, R) -> ... -> (J_1, ..., J_D, R).
    T = factors[0]
    for a in factors[1:]:
        T = T[..., np.newaxis, :] * a
    # Summing over the shared rank axis gives the sum of R rank-1 terms.
    return T.sum(axis=-1)

def reconst_full_Tucker(G, A):
    """Reconstruct the full tensor from a Tucker core and factor matrices.

    Computes T = G x_1 A[0] x_2 A[1] ... x_D A[D-1] as a sequence of mode-d
    products rather than a Python loop over every entry and core index.

    Parameters:
        G (ndarray): Core tensor of shape (R_1, ..., R_D).
        A (list of ndarray): D factor matrices; A[d] has shape (J_d, R_d).

    Returns:
        ndarray: float64 tensor of shape (J_1, ..., J_D).

    Raises:
        ValueError: if G.shape does not equal the factors' column counts.
    """
    rnk = tuple(np.shape(a)[1] for a in A)
    T = np.asarray(G, dtype=float)
    # The core must match the factor ranks exactly; a mismatch is a caller bug.
    if T.shape != rnk:
        raise ValueError(f"core shape {T.shape} does not match factor ranks {rnk}")

    for d, a in enumerate(A):
        # Mode-d product: contract axis d of T with the columns of A[d]; the
        # new length-J_d axis comes out first, so move it back to position d.
        T = np.moveaxis(np.tensordot(np.asarray(a, dtype=float), T, axes=([1], [d])), 0, d)
    return T
   
def reconst_full_TT(cores):
    """Reconstruct the full tensor from tensor-train (TT) cores.

    Contracts the whole chain of cores once instead of multiplying the
    per-entry slice matrices for every index tuple.

    Parameters:
        cores (list of ndarray): D cores; core d has shape (r_{d-1}, J_d, r_d)
            with boundary ranks r_0 = r_D = 1.

    Returns:
        ndarray: float64 tensor of shape (J_1, ..., J_D).

    Raises:
        ValueError: if the boundary ranks are not 1 (the chain does not
            reduce to a plain tensor).
    """
    tensor_shape = [core.shape[1] for core in cores]
    T = np.asarray(cores[0], dtype=float)
    for core in cores[1:]:
        # Contract the trailing rank axis with the next core's leading one.
        T = np.tensordot(T, core, axes=([-1], [0]))
    # T is (1, J_1, ..., J_D, 1); dropping the unit boundary axes gives the tensor.
    return T.reshape(tensor_shape)

def _positive_factor(J, r, dist):
    """Draw one (J, r) nonnegative factor matrix from distribution ``dist``."""
    if dist == "uni":
        return np.random.rand(J, r)
    if dist == "normal":
        return np.abs(np.random.randn(J, r))
    if dist == "dirichlet":
        # Rows lie on the simplex; the concentration vector comes from the
        # global RNG while the draw itself uses the fixed random_state=1.
        alpha = np.random.rand(r)
        return dirichlet.rvs(alpha, size=J, random_state=1)
    raise ValueError(f"dist name error: {dist!r}")


def _tucker_core(rnk, D, J, dist):
    """Draw a nonnegative Tucker core tensor of shape ``rnk``."""
    if dist == "uni":
        return np.random.rand(*rnk)
    if dist == "normal":
        return np.abs(np.random.randn(*rnk))
    if dist == "dirichlet":
        # Despite the name, the core is |x| for a single draw x from a
        # zero-mean Gaussian whose covariance M M^T is built from random M.
        dim = int(np.prod(rnk))
        cov = D * J * D * np.random.rand(dim, dim)
        cov = np.dot(cov, cov.T)
        sample = np.random.multivariate_normal(np.zeros(dim), cov)
        return np.abs(sample.reshape(rnk))
    raise ValueError(f"dist name error: {dist!r}")


def _tt_cores(D, J, rnk, dist):
    """Draw D nonnegative TT cores with internal ranks ``rnk`` and unit boundary ranks."""
    if dist == "uni":
        draw = np.random.rand
    elif dist == "normal":
        def draw(*shape):
            return np.abs(np.random.randn(*shape))
    else:
        raise ValueError(f"dist name error: {dist!r} (TT supports 'uni' and 'normal')")
    # Bond dimensions r_0 = 1, r_1..r_{D-1} = rnk, r_D = 1.
    bonds = [1, *rnk, 1]
    return [draw(bonds[d], J, bonds[d + 1]) for d in range(D)]


def get_syndata(structure, D, J, rnk, dist="uni"):
    """Draw random nonnegative parameters for a synthetic low-rank tensor.

    Parameters:
        structure (str): "CP", "Tucker" or "TT".
        D (int): Number of modes.
        J (int): Size of every mode.
        rnk: CP -> integer rank R; Tucker -> length-D sequence of mode ranks;
            TT -> length-(D-1) sequence of internal TT ranks.
        dist (str): "uni" (uniform [0, 1)), "normal" (|N(0, 1)|) or
            "dirichlet" (Dirichlet rows; not available for TT).

    Returns:
        CP: list of D arrays of shape (J, R).
        Tucker: tuple (G, A); G has shape rnk, A is a list of D arrays with
            A[d] of shape (J, rnk[d]).
        TT: list of D cores, core d of shape (r_d, J, r_{d+1}) with unit
            boundary ranks.

    Raises:
        AssertionError: if ``rnk`` has the wrong type/length for ``structure``.
        ValueError: if ``structure`` or ``dist`` is not recognized.
    """
    if structure == "CP":
        # Validate rnk itself; np.integer covers ranks read from numpy configs.
        assert isinstance(rnk, (int, np.integer)), "rnk should be interger for CPD"
        return [_positive_factor(J, rnk, dist) for _ in range(D)]

    if structure == "Tucker":
        assert len(rnk) == D, "rnk should be D-dim vecotr"
        # Each mode gets its own rank rnk[d] (the scalar, not the whole list).
        A = [_positive_factor(J, rnk[d], dist) for d in range(D)]
        G = _tucker_core(rnk, D, J, dist)
        return G, A

    if structure == "TT":
        assert len(rnk) == D-1, "rnk should be (D-1)-dim vecotr"
        return _tt_cores(D, J, rnk, dist)

    raise ValueError(f"unknown structure: {structure!r}")

def main(dist, noise):
    """Generate and save every synthetic dataset for one factor distribution.

    Builds CP, Tucker and TT tensors from the sizes/ranks in ``config_syn``,
    normalizes each to sum to 1, forms the two- and three-model mixtures
    weighted by ``config_syn.eta*``, optionally blends in uniform noise, and
    writes every tensor as .npy under ``save_path/D{D}/with_noise`` (noise > 0)
    or ``.../without_noise``, plus a ``dataset_detail.pkl`` with the ranks and J.

    Parameters:
        dist (str): Factor distribution passed to ``get_syndata``.
        noise (float): Weight of the uniform tensor mixed into every dataset;
            0.0 disables the blend.
    """
    # utils_exp (ue) is not imported in this file, so write the pickle directly.
    import pickle

    D = config_syn.D
    J = config_syn.J
    rankcp = config_syn.rankcp
    ranktucker = config_syn.ranktucker
    ranktt = config_syn.ranktt

    # Mix ratio for EMCPTrain
    etacp = config_syn.etacp
    etaTucker = config_syn.etaTucker
    etaTT = config_syn.etaTT

    dataset_detail = {"CPrank": rankcp, "Tuckerrank": ranktucker, "TTrank": ranktt, "J": J}

    # Tolerant comparison: weights that sum to 1 only up to rounding must pass.
    assert np.isclose(etacp + etaTucker + etaTT, 1.0), "the sum of weight need to be 1"

    # Keep CP and Tucker factors under separate names; reusing `A` made the
    # CP tensor be rebuilt from the Tucker factor matrices.
    A_cp = get_syndata("CP", D, J, rankcp, dist=dist)
    G, A_tucker = get_syndata("Tucker", D, J, ranktucker, dist=dist)
    cores = get_syndata("TT", D, J, ranktt, dist=dist)

    ## Normalized pure models
    Tcp = reconst_full_CP(A_cp)
    Ttucker = reconst_full_Tucker(G, A_tucker)
    TT = reconst_full_TT(cores)

    Tcp = Tcp / np.sum(Tcp)
    Ttucker = Ttucker / np.sum(Ttucker)
    TT = TT / np.sum(TT)

    ## Mixtures of two models (the two weights need not sum to 1, so renormalize)
    TcpTT = etacp * Tcp + etaTT * TT
    Tcptucker = etacp * Tcp + etaTucker * Ttucker
    TtuckerTT = etaTucker * Ttucker + etaTT * TT

    TcpTT = TcpTT / np.sum(TcpTT)
    Tcptucker = Tcptucker / np.sum(Tcptucker)
    TtuckerTT = TtuckerTT / np.sum(TtuckerTT)

    ## Mixture of three models (weights already sum to 1)
    TcptuckerTT = etacp * Tcp + etaTucker * Ttucker + etaTT * TT

    # Filename stem -> tensor, in the order the files are written.
    tensors = {
        "CP": Tcp,
        "Tucker": Ttucker,
        "TT": TT,
        "CPTucker": Tcptucker,
        "CPTT": TcpTT,
        "TuckerTT": TtuckerTT,
        "CPTuckerTT": TcptuckerTT,
    }

    # Add noise: blend each distribution with the uniform one of the same size.
    if noise != 0.0:
        tensors = {
            name: (1-noise) * T + 1.0 * noise / np.size(T)
            for name, T in tensors.items()
        }

    subdir = "with_noise" if noise > 0.0 else "without_noise"
    path = os.path.join(save_path, f"D{D}", subdir)
    # np.save does not create missing parent directories.
    os.makedirs(path, exist_ok=True)

    for name, T in tensors.items():
        np.save(os.path.join(path, f"{name}_true_{dist}.npy"), T)
    print("saved")

    save_path_dataset_detail = os.path.join(path, "dataset_detail.pkl")
    with open(save_path_dataset_detail, "wb") as f:
        pickle.dump(dataset_detail, f)
    print(f"saved in {save_path_dataset_detail}")
    # To load it back:
    #     with open(save_path_dataset_detail, "rb") as f:
    #         dataset_detail = pickle.load(f)
    # If you wanna load the results 
    #ue.pickle_load(save_path_dataset_detail)

if __name__ == "__main__":
    # Default run: |N(0, 1)| factors ("normal") with a 10% uniform-noise blend.
    dist, noise = "normal", 0.10
    main(dist, noise)
