import numpy as np
import math

import sys
data_repo = "data/"
sys.path.append(data_repo)
import dataset_info
import config_exp
import importlib
importlib.reload(config_exp)

def cp_n_params(tensor_size, rnk):
    """Return the parameter count used for a rank-``rnk`` CP model.

    Each mode ``d`` contributes ``(tensor_size[d] - 1) * rnk`` parameters, so
    the result is ``sum_d (tensor_size[d] - 1) * rnk``.

    Args:
        tensor_size: iterable of mode sizes, one per tensor mode.
        rnk: the CP rank (a single integer).

    Returns:
        The count as a plain Python ``int`` rather than a NumPy scalar, so it
        matches ``train_n_params`` and prints cleanly inside lists under
        NumPy >= 2.
    """
    # NOTE(review): the "- 1" per mode presumably reflects a normalisation
    # constraint on each factor column -- confirm against the model definition.
    return sum(int(n) - 1 for n in tensor_size) * int(rnk)

def tucker_n_params(tensor_size, rnk):
    """Return the number of parameters of a Tucker model with ranks ``rnk``.

    The count is the core size ``prod(rnk)`` plus the factor-matrix sizes
    ``sum_d tensor_size[d] * rnk[d]``.

    Args:
        tensor_size: sequence of mode sizes, one per tensor mode.
        rnk: sequence of Tucker ranks, one per tensor mode.

    Returns:
        The count as a plain Python ``int``.

    Raises:
        ValueError: if ``rnk`` and ``tensor_size`` have different lengths
            (previously a longer ``rnk`` silently inflated the core size).
    """
    sizes = [int(n) for n in tensor_size]
    ranks = [int(r) for r in rnk]
    if len(ranks) != len(sizes):
        raise ValueError(
            "rnk has {} entries but tensor_size has {} modes".format(
                len(ranks), len(sizes)
            )
        )

    # Python ints are arbitrary precision, so the core size cannot silently
    # overflow the way np.prod can with a fixed-width integer dtype.
    n_param_core = 1
    for r in ranks:
        n_param_core *= r

    n_param_factor = sum(size * r for size, r in zip(sizes, ranks))
    return n_param_core + n_param_factor

def train_n_params(tensor_size, rnk):
    """Return the total size of the train cores for shape ``tensor_size``.

    Core ``d`` has ``rnk[d-1] * tensor_size[d] * rnk[d]`` entries. The first
    core drops the left rank and the last core drops the right rank (they
    are treated as 1), so ``rnk`` is expected to hold one rank per internal
    bond, i.e. ``len(tensor_size) - 1`` entries.
    """
    last_mode = len(tensor_size) - 1

    def core_size(mode, n_mode):
        # The first-core check takes precedence, so a single-mode tensor
        # still uses rnk[0].
        if mode == 0:
            return n_mode * rnk[0]
        if mode == last_mode:
            return rnk[mode - 1] * n_mode
        return rnk[mode - 1] * n_mode * rnk[mode]

    return sum(core_size(mode, n_mode) for mode, n_mode in enumerate(tensor_size))

def cptrain_n_params(tensor_size, rnk):
    """Return the parameter count of a combined CP + train model.

    ``rnk[0]`` is the CP rank handed to ``cp_n_params`` and ``rnk[1]`` is the
    rank sequence handed to ``train_n_params``; the two counts are summed.
    """
    cp_rank = rnk[0]
    train_ranks = rnk[1]
    return cp_n_params(tensor_size, cp_rank) + train_n_params(tensor_size, train_ranks)


def _collect_counts(label, ranks_by_dataset, dataset_name, tensor_size, count_fn):
    """Print and return ``count_fn(tensor_size, rnk)`` for each rank setting.

    Returns None (and prints nothing) when ``dataset_name`` has no entry in
    ``ranks_by_dataset``.
    """
    if dataset_name not in ranks_by_dataset:
        return None
    counts = [count_fn(tensor_size, rnk) for rnk in ranks_by_dataset[dataset_name]]
    print(label, counts)
    return counts


def get_all_n_params(dataset_name):
    """Print the parameter count of every configured rank setting of a dataset.

    The tensor shape comes from ``dataset_info.tensor_sizes[dataset_name]``.
    For each of the methods ``emCP``, ``emTucker``, ``emTrain`` and
    ``emCPTrain`` in ``config_exp.ranks_set``, plus ``config_exp.ranksSameTrain``,
    that has rank settings for this dataset, the list of counts is printed
    under the labels "CP", "Tucker", "Train", "CPTrain" and "TrainSame".
    The CPTrain counts are additionally printed sorted.

    Args:
        dataset_name: key into ``dataset_info.tensor_sizes`` and the rank configs.

    Returns:
        dict mapping each printed label to its list of counts. Labels whose
        configuration has no entry for ``dataset_name`` are omitted.
        (Previously returned None; existing callers that ignore the result
        are unaffected.)

    Raises:
        AssertionError: if two ``emCPTrain`` rank settings give the same count.
    """
    tensor_size = dataset_info.tensor_sizes[dataset_name]

    results = {}
    for label, method, count_fn in (
        ("CP", "emCP", cp_n_params),
        ("Tucker", "emTucker", tucker_n_params),
        ("Train", "emTrain", train_n_params),
        ("CPTrain", "emCPTrain", cptrain_n_params),
    ):
        counts = _collect_counts(
            label, config_exp.ranks_set[method], dataset_name, tensor_size, count_fn
        )
        if counts is not None:
            results[label] = counts

    cptrain_counts = results.get("CPTrain")
    if cptrain_counts is not None:
        print("CPTrain_sorted", sorted(cptrain_counts))
        # Explicit raise instead of `assert` so the check still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if len(set(cptrain_counts)) != len(cptrain_counts):
            raise AssertionError("same n_param")

    same_counts = _collect_counts(
        "TrainSame", config_exp.ranksSameTrain, dataset_name, tensor_size, train_n_params
    )
    if same_counts is not None:
        results["TrainSame"] = same_counts

    return results

if __name__ == "__main__":
    # Expect one positional argument: the dataset name. Exit with a usage
    # message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit("usage: python {} DATASET_NAME".format(sys.argv[0]))
    get_all_n_params(sys.argv[1])

