import numpy as np
import sys

data_repo = "data/"
sys.path.append(data_repo)
import dataset_info
from dataset_info import tensor_dims as Ds
import importlib
importlib.reload(dataset_info)

# ---- Experiment-wide settings ----
# Repetition count for the experiments (rep_times).
rep_times = 5

# Iteration caps: default, "_tl" variant and "_nnf" variant.
max_iter, max_iter_tl, max_iter_nnf = 2_000, 250, 2_500

# Convergence tolerances: default and "_tl" variant.
tol, tol_tl = 1e-7, 1e-6

# Intervals for convergence checks and verbose output (presumably in iterations).
conv_check_interval = verbose_interval = 10

def MKN(l, M, k, N):
    """Return the integers visited by a two-speed walk from 1 up to ``N``.

    Starting at 1, the walk advances by ``l`` while the current value is
    below ``M`` and by ``k`` once it has reached ``M``. Every visited value
    strictly below ``N`` is collected. For example ``MKN(1, 10, 5, 20)``
    gives ``[1, 2, ..., 10, 15]``: a fine grid at small values and a coarser
    one from ``M`` on.

    Args:
        l: step applied while the current value is below ``M``.
        M: threshold at which the step switches from ``l`` to ``k``.
        k: step applied once the current value is ``>= M``.
        N: exclusive upper bound.

    Returns:
        list[int]: the visited values (empty when ``N <= 1``).

    Raises:
        ValueError: if a step that would actually be taken is not positive.
            Such a step cannot make progress towards ``N``, and the original
            loop would spin forever.
    """
    values = []
    i = 1
    while i < N:
        values.append(i)
        step = l if i < M else k
        # Only validate the step that is really used: e.g. l may legitimately
        # be unused when M <= 1.
        if step <= 0:
            raise ValueError(
                f"MKN step must be positive, got {step} at value {i} "
                f"(l={l}, M={M}, k={k}, N={N})"
            )
        i += step
    return values

####### Rank Tuning #########

def get_Tumor_train_rank():
    Dim = Ds["Tumor"] -1  # 16

    K = 1
    value = 1
    r0 = [2 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    value = 2
    K = 3
    r1 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 6
    r2 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 9
    r3 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 12
    r3 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    r4 = [value for i in range(Dim)]

    value = 3

    K = 9
    r7 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 12
    r8 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    rtrains = [r0, r1, r2, r3, r4, r7, r8]
    return rtrains



def get_Votes_train_rank():
    Dim = Ds["Votes"] -1  # 16

    K = 1
    value = 1
    r0 = [2 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    value = 2
    K = 3
    r1 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 7
    r2 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 9
    r3 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 12
    r3 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    r4 = [value for i in range(Dim)]

    value = 3
    K = 9
    r7 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 11
    r8 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 13
    r9 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    rtrains = [r0, r1, r2, r3, r4, r7, r8, r9]
    return rtrains



def get_Chess_train_rank():
    Dim = Ds["Chess"] - 1 # 34

    K = 1
    value = 1
    r0 = [2 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    value = 2
    K = 3
    r1 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 10
    r2 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 20
    r3 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    value = 3
    K = 3
    r4 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 2 for i in range(Dim)]

    value = 3
    K = 10
    r5 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 2 for i in range(Dim)]

    value = 3
    K = 20
    r6 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 2 for i in range(Dim)]

    value = 4
    K = 10
    r7 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 3 for i in range(Dim)]

    value = 4
    K = 20
    r8 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 3 for i in range(Dim)]

    rtrains = [r0, r2, 2 * np.ones(Dim, dtype=int), 3 * np.ones(Dim, dtype=int), r3 * np.ones(Dim, dtype=int) ]
    return rtrains



def get_SPECT_train_rank():
    Dim = Ds["SPECT"] - 1 # 21

    value = 2
    K = 2
    r1 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 3
    r2 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    K = 10
    r3 = [value if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else 1 for i in range(Dim)]

    value = 3
    K = 3
    r4 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else value-1 for i in range(Dim)]

    K = 10
    r5 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else value-1 for i in range(Dim)]

    K = 15
    r6 = [value-1 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else value-1 for i in range(Dim)]

    value = 4
    K = 10
    r7 = [value-2 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else value-1 for i in range(Dim)]

    K = 15
    r8 = [value-2 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else value-1 for i in range(Dim)]

    value = 5
    K = 10
    r9 = [value-2 if i == 0 else value if (Dim-K)//2 <= i < (Dim+K)//2 else value-1 for i in range(Dim)]

    rtrains = [ r1, r2, 2 * np.ones(Dim, dtype=int), r3, r4, r5, r6, 3 * np.ones(Dim, dtype=int), r7, r8, 4 * np.ones(Dim, dtype=int), r9 ]
    return rtrains


# CP rank sweeps per dataset. MKN(l, M, k, N) walks 1, 1+l, ... below M and
# then steps by k until N (exclusive); plain ranges give a uniform stride.
ranksCP = {
    "Lymphography": MKN(1, 10, 5, 20),
    "DMFT": range(1, 20, 2),
    "Led7": range(1, 55, 3),
    "Chess": range(1, 115, 15),
    "SPECT": MKN(1, 10, 6, 36),
    "Tumor": MKN(1, 10, 6, 25),
    "Votes": MKN(1, 10, 6, 20),
    "SolarFlare": MKN(1, 11, 7, 107),
}

# Tucker rank configurations per dataset; every configuration must have
# exactly Ds[name] entries (checked right after this table).
ranksTucker = {
    "SolarFlare": [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 1],
    ],
    "Led7": [
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 2, 2, 2, 2, 1, 1],
        [2, 2, 2, 2, 2, 2, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 1],
        [2, 2, 2, 2, 2, 2, 2, 2],
        [3, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 2, 2, 2, 2, 2],
    ],
    "DMFT": [
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [2, 2, 3, 3, 2],
        [2, 3, 3, 3, 2],
        [2, 3, 3, 3, 3],
        [3, 3, 3, 3, 3],
    ],
}

# Sanity-check the Tucker table: each configuration needs exactly
# Ds[dataset_name] entries. An explicit raise (rather than ``assert``) keeps
# the check active under ``python -O``.
for dataset_name, tucker_ranks in ranksTucker.items():
    for idx, rank in enumerate(tucker_ranks):
        if len(rank) != Ds[dataset_name]:
            raise ValueError(
                f"{dataset_name} size of Tucker rank wrong: entry {idx} has "
                f"{len(rank)} values, expected {Ds[dataset_name]}"
            )

# Train-rank configurations per dataset. Each configuration has
# Ds[name] - 1 entries and its slot 0 must be at least 2 (both checked right
# after this table).
ranksTrain = {
    "SPECT": get_SPECT_train_rank(),
    "Chess": [np.full(Ds["Chess"] - 1, r, dtype=int) for r in range(2, 8)],
    "Votes": get_Votes_train_rank(),
    "Tumor": get_Tumor_train_rank(),
    "Led7": [
        [2,1,1,1,1,1,1],
        [2,1,2,2,2,1,1],
        [2,2,2,2,2,2,2],
        [2,2,3,3,3,2,2],
        [3,3,3,3,3,3,3],
        [3,3,4,4,4,3,3],
        [3,4,4,4,4,4,3],
        [4,4,4,4,4,4,4],
        [4,4,5,5,5,4,4],
        [5,5,5,5,5,5,5],
        [5,5,6,6,6,5,5],
        [6,6,6,6,6,6,6],
        [6,6,7,7,7,6,6],
        [7,7,7,7,7,7,7],
        [7,7,8,8,8,7,7],
        [8,8,8,8,8,8,8],
    ],
    "DMFT": [
        [2,1,1,1],
        [2,2,1,1],
        [2,1,2,1],
        [2,1,1,2],
        [2,2,2,1],
        [2,2,1,2],
        [2,1,2,2],
        [2,2,2,2],
        [2,2,2,3],
        [2,2,3,2],
        [2,3,2,2],
        [2,3,3,2],
        [2,3,3,3],
    ],
    "Lymphography": [
        [2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],
        [2,1,1,1,1,2,2,2,2,2,2,1,1,1,1,1],
        [2,1,1,1,2,2,2,2,2,2,2,2,1,1,1,1],
        [2,1,1,2,2,2,2,2,2,2,2,2,2,1,1,1],
        [2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2],
        [2,2,2,2,2,2,2,3,3,3,2,2,2,2,2,2],
        [2,2,2,2,2,2,3,3,3,3,3,2,2,2,2,2],
        [2,2,2,2,2,3,3,3,3,3,3,3,2,2,2,2],
        [2,2,2,2,3,3,3,3,3,3,3,3,3,2,2,2],
        [2,2,2,3,3,3,3,3,3,3,3,3,3,3,2,2],
        [2,2,3,3,3,3,3,3,3,3,3,3,3,3,3,2],
        [3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3],
        [3,3,3,3,3,3,4,4,4,3,3,3,3,3,3,3],
    ],
    "SolarFlare": [
        [2,1,1,1,1,1,1,1,1],
        [2,1,1,1,2,2,1,1,1],
        [2,1,1,2,2,2,2,1,1],
        # NOTE(review): the 1 in slot 2 breaks the otherwise centred-band
        # pattern of this table; confirm this row is intended.
        [2,2,1,2,2,2,2,2,2],
        [2,2,2,2,2,2,2,2,2],
        [2,2,2,2,3,3,2,2,2],
        [2,2,2,3,3,3,3,2,2],
        [2,3,3,3,3,3,3,3,3],
        [3,3,3,3,3,3,3,3,3],
        [3,3,3,3,4,4,3,3,3],
        [3,3,4,4,4,4,4,4,3],
        [4,4,4,4,4,4,4,4,4],
        [4,4,5,5,5,5,4,4,4],
        [5,5,5,5,5,5,5,5,5],
        [5,5,5,6,6,6,5,5,5],
        [5,6,6,6,6,6,6,6,5],
        [6,6,6,6,6,6,6,6,6],
        [6,6,6,6,7,7,6,6,6],
        [7,7,7,7,7,7,7,7,7],
        [7,7,7,7,8,8,7,7,7],
        [7,7,7,8,8,8,7,7,7],
        [7,8,8,8,8,8,8,8,7],
        [8,8,8,8,8,8,8,8,8],
        [8,8,8,9,9,8,8,8,8],
        [8,9,9,9,9,9,9,8,8],
        [9,9,9,9,9,9,9,9,9],
    ],
}


# Sanity-check the train table: each configuration needs Ds[dataset_name] - 1
# entries, and its first slot must be at least 2. Explicit raises (rather than
# ``assert``) keep the checks active under ``python -O``.
for dataset_name, train_ranks in ranksTrain.items():
    expected_len = Ds[dataset_name] - 1
    for idx, rank in enumerate(train_ranks):
        if len(rank) != expected_len:
            raise ValueError(
                f"{dataset_name} size of train rank wrong: entry {idx} has "
                f"{len(rank)} values, expected {expected_len}"
            )
        if rank[0] <= 1:
            raise ValueError(
                f"{dataset_name} the first slot of train rank needs to be at "
                f"least 2 (entry {idx} has {rank[0]})"
            )

### For mix model
# Each entry is an (r, rtrain) pair. r is an integer rank (presumably the CP
# part of the mixed model) and rtrain is one configuration taken from ranksTrain.
ranksCPTrain = {
    "Lymphography": [
        (cp_rank, train_rank)
        for cp_rank in (1, 2, 3)
        for train_rank in ranksTrain["Lymphography"][:3]
    ],
    "DMFT": [(2, train_rank) for train_rank in ranksTrain["DMFT"][:-1]],
    "Led7": [
        (cp_rank, train_rank)
        for cp_rank in (10, 20)
        for train_rank in ranksTrain["Led7"][:-1]
    ],
}


# Every key shares the same candidate ranks 2..7 (each gets its own list).
tfnp_rnks = {model: list(range(2, 8)) for model in ("MPS", "BM", "LPS")}


# Uniform train-rank sweeps. Each configuration repeats a single value r in
# every slot, for r = 2..max_rank. Most datasets first get a "seed"
# configuration [2, 1, 1, ...].
# Table columns: (dataset, number of slots, largest r, include seed row).
ranksSameTrain = {
    name: ([[2] + [1] * (n_slots - 1)] if with_seed else [])
    + [[r] * n_slots for r in range(2, max_rank + 1)]
    for name, n_slots, max_rank, with_seed in (
        ("Chess", 34, 5, True),
        ("Led7", 7, 8, True),
        ("Lymphography", 16, 4, True),
        ("Votes", 16, 4, True),
        ("Tumor", 16, 4, False),
        ("DMFT", 4, 5, True),
        ("SolarFlare", 9, 9, True),
    )
}

# Sanity-check the uniform train table: each configuration needs
# Ds[dataset_name] - 1 entries, and its first slot must be at least 2.
# Explicit raises (rather than ``assert``) keep the checks active under
# ``python -O``.
for dataset_name, same_ranks in ranksSameTrain.items():
    expected_len = Ds[dataset_name] - 1
    for idx, rank in enumerate(same_ranks):
        if len(rank) != expected_len:
            raise ValueError(
                f"{dataset_name} size of train rank wrong: entry {idx} has "
                f"{len(rank)} values, expected {expected_len}"
            )
        if rank[0] <= 1:
            raise ValueError(
                f"First slot of train rank of {dataset_name} needs to be at "
                f"least 2 (entry {idx} has {rank[0]})"
            )

# Rank table used for each method key. Several keys share a table; the
# insertion order matches the original listing.
ranks_set = {
    # CP-family methods
    "emCP": ranksCP,
    "NNCP": ranksCP,
    "NNCPHALS": ranksCP,
    "CP": ranksCP,
    # Tucker-family methods (with one CP entry kept in its original position)
    "emTucker": ranksTucker,
    "Tucker": ranksTucker,
    "KLNTDMU": ranksTucker,
    "KLCPMU": ranksCP,
    "PTucker": ranksTucker,
    "NNTucker": ranksTucker,
    "NNTuckerHALS": ranksTucker,
    # Train-rank methods
    "emTrain": ranksTrain,
    "emTrainO": ranksTrain,
    "TT": ranksSameTrain,
    # Mixed (r, train rank) methods
    "emCPTrain": ranksCPTrain,
    "emCPTrainO": ranksCPTrain,
}



