import numpy as np
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed

import os
import sys
print(sys.version)
sys.path.append("methods/ours")
sys.path.append("methods/baselines/tnfp")

## Utils
import utils_exp as ue
import utils_train as ut

## Configs
import config_exp
sys.path.append(config_exp.data_repo)
import dataset_info
import n_params
import sp_tensor

## Baselines
from tensornetworks.PositiveMPS import PositiveMPS
from tensornetworks.RealLPS import RealLPS
from tensornetworks.RealBorn import RealBorn

import importlib

def lr_tuning(dataset_name, method):
    """Grid-search the learning rate of one baseline for every candidate rank.

    For each rank in ``config_exp.tfnp_rnks[method]`` one ``run_tfnp`` job per
    candidate learning rate is submitted to a process pool (a single
    repetition each).  The mean training score and the mean validation score
    are recorded per learning rate, and the learning rate with the lowest
    mean is stored for both splits.  Everything is pickled to
    ``results/<method>/<dataset_name>_lrs.pkl`` via ``ue.pickle_dump``.

    Only the training and validation splits are read here.

    Args:
        dataset_name: Name of the sub-directory of ``config_exp.data_repo``
            that holds the ``X_*_coords.npy`` / ``X_*_values.npy`` files; also
            the key into ``dataset_info.tensor_sizes``.
        method: Baseline identifier; key into ``config_exp.tfnp_rnks`` and
            forwarded unchanged to ``run_tfnp``.
    """
    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    coords_train = np.load(os.path.join(dataset_repo, "X_train_coords.npy"))
    values_train = np.load(os.path.join(dataset_repo, "X_train_values.npy"))
    coords_valid = np.load(os.path.join(dataset_repo, "X_valid_coords.npy"))
    values_valid = np.load(os.path.join(dataset_repo, "X_valid_values.npy"))

    tensor_size = dataset_info.tensor_sizes[dataset_name]
    T_valid = sp_tensor.Sp_tensor(
        coords_valid, values_valid, tensor_size, check_empty=False, normalize=True
    )

    ## Resolve Double Count ##
    # Row m of coords_train is stacked values_train[m] times, so the models
    # are fitted on one row per observation.
    T_train_coords_dec = np.vstack(
        [np.tile(coords_train[m], (values_train[m], 1)) for m in range(len(coords_train))]
    )

    lrs = [0.0001, 0.001, 0.01, 0.1, 1.0]
    rnks = config_exp.tfnp_rnks[method]
    save_path = os.path.join("results/", f"{method}", dataset_name + "_lrs.pkl")
    # Create the output directory before the expensive training starts so a
    # missing folder cannot cost the results at the very end.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    def _lowest_mean_lr(mean_by_lr):
        # A NaN mean never compares "less than" anything, so a plain min()
        # would keep a NaN entry whenever it happens to come first.  NaN is
        # therefore ranked as +inf; if every mean is NaN the first learning
        # rate is still returned.
        return min(
            mean_by_lr,
            key=lambda lr: np.inf if np.isnan(mean_by_lr[lr]) else mean_by_lr[lr],
        )

    results_values = {}
    results_means = {}
    results_valid_means = {}
    best_lr = {}
    best_lr_valid = {}
    for rnk in rnks:
        print("rnk is", rnk)
        with ProcessPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(run_tfnp, T_train_coords_dec, T_valid, method, rnk, lr, 1)
                for lr in lrs
            ]
            # Collected in submission order, so results[i] belongs to lrs[i].
            results = [future.result() for future in futures]

        # Each result is the (train_scores, valid_scores, n_params) tuple
        # returned by run_tfnp.
        results_values[rnk] = {lr: res[0] for lr, res in zip(lrs, results)}
        results_means[rnk] = {lr: np.mean(res[0]) for lr, res in zip(lrs, results)}
        results_valid_means[rnk] = {lr: np.mean(res[1]) for lr, res in zip(lrs, results)}

        best_lr[rnk] = _lowest_mean_lr(results_means[rnk])
        best_lr_valid[rnk] = _lowest_mean_lr(results_valid_means[rnk])

    results_for_save = {
        "score_train": results_values,
        "score_train_means": results_means,
        "best_lr": best_lr,
        "best_lr_valid": best_lr_valid,
        "results_valid_means": results_valid_means,
        "method": method,
        "dataset_name": dataset_name,
    }
    ue.pickle_dump(results_for_save, save_path)
    print(f"saved in {save_path}")
    # To load the results again: ue.pickle_load(save_path)

def eval_tfnp(dataset_name, method):
    """Train the selected configuration and score it on the test split.

    The rank is the one with the lowest mean validation score stored under
    ``"score_valid"`` in ``results/<method>/<dataset_name>.pkl``.  The
    learning rate for that rank is read from ``"best_lr_valid"`` in
    ``results/<method>/<dataset_name>_lrs.pkl``.  ``run_tfnp`` is then
    executed with ``config_exp.rep_times`` repetitions in a worker process
    and the outcome is pickled to
    ``results/<method>/<dataset_name>_test.pkl``.

    Both input pickles must already exist; ``lr_tuning`` and ``exp_tfnp``
    write files with these names.

    Args:
        dataset_name: Name of the sub-directory of ``config_exp.data_repo``
            that holds the ``X_*_coords.npy`` / ``X_*_values.npy`` files; also
            the key into ``dataset_info.tensor_sizes``.
        method: Baseline identifier, forwarded unchanged to ``run_tfnp``.
    """
    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    coords_train = np.load(os.path.join(dataset_repo, "X_train_coords.npy"))
    values_train = np.load(os.path.join(dataset_repo, "X_train_values.npy"))
    coords_test = np.load(os.path.join(dataset_repo, "X_test_coords.npy"))
    values_test = np.load(os.path.join(dataset_repo, "X_test_values.npy"))

    tensor_size = dataset_info.tensor_sizes[dataset_name]
    T_test = sp_tensor.Sp_tensor(
        coords_test, values_test, tensor_size, check_empty=False, normalize=True
    )

    ## Resolve Double Count ##
    # Row m of coords_train is stacked values_train[m] times, so the model is
    # fitted on one row per observation.
    T_train_coords_dec = np.vstack(
        [np.tile(coords_train[m], (values_train[m], 1)) for m in range(len(coords_train))]
    )

    save_path = os.path.join("results/", f"{method}", dataset_name + "_test.pkl")

    load_lrs_path = os.path.join("results/", f"{method}", dataset_name + "_lrs.pkl")
    best_lr = ue.pickle_load(load_lrs_path)["best_lr_valid"]

    load_rnk_path = os.path.join("results/", f"{method}", dataset_name + ".pkl")
    score_valid_each_rank = ue.pickle_load(load_rnk_path)["score_valid"]

    def _mean_valid_score(rank):
        # A NaN mean never compares "less than" anything, so a plain min()
        # would keep a NaN rank whenever it happens to come first.  NaN is
        # therefore ranked as +inf; if every mean is NaN the first rank is
        # still returned.
        mean = np.mean(score_valid_each_rank[rank])
        return np.inf if np.isnan(mean) else mean

    best_rnk = min(score_valid_each_rank, key=_mean_valid_score)

    # Make sure the output directory exists before the expensive training.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Only one job is submitted, so a single worker process is enough.
    with ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(
            run_tfnp,
            T_train_coords_dec,
            T_test,
            method,
            best_rnk,
            best_lr[best_rnk],
            config_exp.rep_times,
        )
        # Kept as a one-element list so the pickled layout stays a list of
        # per-job results.
        results = [future.result()]

    # Each result is the (train_scores, valid_scores, n_params) tuple returned
    # by run_tfnp; its second entry is the score on T_test here.
    train_scores = [res[0] for res in results]
    test_scores = [res[1] for res in results]
    n_paras = [res[2] for res in results]

    print(train_scores)
    print(test_scores)
    print("mean:", np.mean(test_scores))
    print(" std:", np.std(test_scores, ddof=1))
    results_for_save = {
        "rnk": best_rnk,
        "n_params": n_paras,
        "score_train": train_scores,
        "score_test": test_scores,
        "method": method,
        # Whole per-rank mapping; the value actually used is best_lr[best_rnk].
        "lr": best_lr,
        "dataset_name": dataset_name,
    }
    ue.pickle_dump(results_for_save, save_path)
    print(f"saved in {save_path}")
    # To load the results again: ue.pickle_load(save_path)



def exp_tfnp(dataset_name, method):
    """Train one baseline at every candidate rank and score it on validation.

    For each rank in ``config_exp.tfnp_rnks[method]`` a ``run_tfnp`` job with
    ``config_exp.rep_times`` repetitions is submitted to a process pool.  The
    learning rate per rank is read from ``"best_lr"`` in
    ``results/<method>/<dataset_name>_lrs.pkl``, which must already exist
    (``lr_tuning`` writes a file with this name).  Training and validation
    scores per rank are pickled to ``results/<method>/<dataset_name>.pkl``.

    Only the training and validation splits are read here.

    Args:
        dataset_name: Name of the sub-directory of ``config_exp.data_repo``
            that holds the ``X_*_coords.npy`` / ``X_*_values.npy`` files; also
            the key into ``dataset_info.tensor_sizes``.
        method: Baseline identifier; key into ``config_exp.tfnp_rnks`` and
            forwarded unchanged to ``run_tfnp``.
    """
    dataset_repo = os.path.join(config_exp.data_repo, dataset_name)
    coords_train = np.load(os.path.join(dataset_repo, "X_train_coords.npy"))
    values_train = np.load(os.path.join(dataset_repo, "X_train_values.npy"))
    coords_valid = np.load(os.path.join(dataset_repo, "X_valid_coords.npy"))
    values_valid = np.load(os.path.join(dataset_repo, "X_valid_values.npy"))

    tensor_size = dataset_info.tensor_sizes[dataset_name]
    T_valid = sp_tensor.Sp_tensor(
        coords_valid, values_valid, tensor_size, check_empty=False, normalize=True
    )

    ## Resolve Double Count ##
    # Row m of coords_train is stacked values_train[m] times, so the models
    # are fitted on one row per observation.
    T_train_coords_dec = np.vstack(
        [np.tile(coords_train[m], (values_train[m], 1)) for m in range(len(coords_train))]
    )

    rnks = config_exp.tfnp_rnks[method]

    save_path = os.path.join("results/", f"{method}", dataset_name + ".pkl")
    load_lrs_path = os.path.join("results/", f"{method}", dataset_name + "_lrs.pkl")
    best_lr = ue.pickle_load(load_lrs_path)["best_lr"]

    # Make sure the output directory exists before the expensive training.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with ProcessPoolExecutor(max_workers=10) as executor:
        futures = [
            executor.submit(
                run_tfnp,
                T_train_coords_dec,
                T_valid,
                method,
                rnk,
                best_lr[rnk],
                config_exp.rep_times,
            )
            for rnk in rnks
        ]
        # Collected in submission order, so results[i] belongs to rnks[i].
        results = [future.result() for future in futures]

    # Each result is the (train_scores, valid_scores, n_params) tuple returned
    # by run_tfnp.
    n_paras = [res[2] for res in results]
    train_scores = {rnk: res[0] for rnk, res in zip(rnks, results)}
    valid_scores = {rnk: res[1] for rnk, res in zip(rnks, results)}

    print(train_scores)
    print(valid_scores)
    results_for_save = {
        "rank": rnks,
        "n_params": n_paras,
        "score_train": train_scores,
        "score_valid": valid_scores,
        "method": method,
        "lr": best_lr,
        "dataset_name": dataset_name,
    }
    ue.pickle_dump(results_for_save, save_path)
    print(f"saved in {save_path}")
    # To load the results again: ue.pickle_load(save_path)

def run_tfnp(T_train_coords_dec, T_valid, method, rnk, lr, rep_times):
    """Fit one baseline model ``rep_times`` times and score every fit.

    Args:
        T_train_coords_dec: Training samples, passed to ``fit`` and
            ``likelihood`` of the model.
        T_valid: Held-out tensor; its ``coords`` and ``values`` attributes are
            passed to the model's ``cross`` method.
        method: ``"MPS"`` (``PositiveMPS``), ``"LPS"`` (``RealLPS``) or
            ``"BM"`` (``RealBorn``).
        rnk: Passed to the model constructor as ``D``.
        lr: Passed to the model constructor as ``learning_rate``.
        rep_times: Number of independent repetitions.

    Returns:
        A tuple ``(train_scores, valid_scores, n_params)``: two arrays of
        length ``rep_times`` holding the ``likelihood`` and ``cross`` values
        of every repetition, and the first dimension of the fitted model's
        ``w`` from the last repetition (``None`` when ``rep_times`` is 0).

    Raises:
        ValueError: If ``method`` is not one of the supported identifiers.
    """
    # Resolve the model class once, up front: an unsupported method fails
    # immediately with a clear message instead of at the first ``fit`` call.
    if method == "MPS":
        model_cls = PositiveMPS
    elif method == "LPS":
        model_cls = RealLPS
    elif method == "BM":
        model_cls = RealBorn
    else:
        raise ValueError(
            f"unknown method {method!r}; expected 'MPS', 'LPS' or 'BM'"
        )

    train_scores = np.zeros(rep_times)
    valid_scores = np.zeros(rep_times)
    # Named differently from the module-level ``n_params`` import so the
    # local does not shadow it; stays None if no repetition is executed.
    num_params = None
    for rep in range(rep_times):
        # Seed NumPy's global RNG with the repetition index before the model
        # is constructed and fitted.
        np.random.seed(rep)

        mps = model_cls(D=rnk, learning_rate=lr, batch_size=20, n_iter=10000, verbose=True)
        mps.fit(T_train_coords_dec)

        train_scores[rep] = mps.likelihood(T_train_coords_dec)
        valid_scores[rep] = mps.cross(T_valid.coords, T_valid.values)
        num_params = np.shape(mps.w)[0]

    return train_scores, valid_scores, num_params

if __name__ == "__main__":
    # Dataset to process.  Alternatives previously listed here:
    # "DMFT", "Chess", "Tumor", "Votes", "Lymphography", "SPECT", "Led7".
    dataset_name = "SolarFlare"

    # Per baseline, in this order: learning-rate tuning, the per-rank
    # validation run, then the test evaluation.
    for method in ("MPS", "BM", "LPS"):
        lr_tuning(dataset_name, method)
        exp_tfnp(dataset_name, method)
        eval_tfnp(dataset_name, method)
