import numpy as np
import time
from mpl_toolkits.mplot3d import Axes3D  # noqa
import CodeGW


def dist(new_point, points, r_threshold):
    """Return True when ``new_point`` is at least ``r_threshold`` away from every point.

    Distances are Euclidean. Scanning stops at the first point strictly closer
    than ``r_threshold``. An empty ``points`` collection always yields True.
    """
    too_close = (
        np.sqrt(np.sum(np.square(new_point - other))) < r_threshold
        for other in points
    )
    return not any(too_close)


def RandX(N, r_threshold, d1, seed=49):
    """Draw ``N`` points in R^d1 whose pairwise Euclidean distances are >= ``r_threshold``.

    Each candidate's coordinates are sampled (with replacement) from the grid
    ``np.arange(0, N * r_threshold, 0.1)``. A candidate is kept only if it is
    at least ``r_threshold`` away from every point accepted so far (rejection
    sampling via ``dist``).

    Parameters
    ----------
    N : int
        Number of points to return.
    r_threshold : float
        Minimum pairwise distance; it also scales the sampling box.
    d1 : int
        Dimension of each point.
    seed : int
        Seed of a private random generator. The result is therefore
        reproducible, and the global NumPy random state is left untouched.
        Previously this argument was incremented but never used.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, d1)``, or an empty 1-D array when ``N == 0``.

    Notes
    -----
    Rejection sampling has no iteration cap. If ``N`` points at the requested
    spacing are hard or impossible to fit in the box (e.g. very small ``d1``),
    this loops for a long time or forever.
    """
    rng = np.random.RandomState(seed)
    scope = np.arange(0, N * r_threshold, 0.1)
    points = []
    while len(points) < N:
        new_point = rng.choice(scope, d1)
        if dist(new_point, points, r_threshold):
            points.append(new_point)
    return np.array(points)


def clustered_distributions(
    num_samples, dimX, dimY, num_clusters=10, r_threshold=1, sigma=0.1, seed=49
):
    """Sample two paired point clouds, each a mixture of isotropic Gaussian clusters.

    Parameters
    ----------
    num_samples : int
        Number of rows in each returned array.
    dimX, dimY : int
        Dimensions of the points in ``X`` and in ``Y``.
    num_clusters : int
        Number of clusters; must be >= 1 (0 raises ZeroDivisionError). Each
        cluster gets ``num_samples // num_clusters`` points, and the remainder
        is added to cluster 0.
    r_threshold : float
        Minimum distance between cluster centres (forwarded to ``RandX``).
    sigma : float
        Variance of every coordinate inside a cluster (covariance ``sigma * I``).
    seed : int
        Seed for the global NumPy RNG and for ``RandX``.

    Returns
    -------
    X : numpy.ndarray of shape (num_samples, dimX)
    Y : numpy.ndarray of shape (num_samples, dimY)
        Rows are grouped cluster by cluster in the same order in both arrays,
        so ``X[i]`` and ``Y[i]`` come from clusters with the same index.
    """
    # Seed before the centres are chosen. Any global-RNG draws made while
    # picking the centres, and all the Gaussian samples below, are then
    # reproducible. Previously the seed was set only after the centres.
    np.random.seed(seed)
    cluster_centers_X = RandX(num_clusters, r_threshold, dimX, seed=seed)
    cluster_centers_Y = RandX(num_clusters, r_threshold, dimY, seed=seed)

    cov_X = sigma * np.eye(dimX)
    cov_Y = sigma * np.eye(dimY)

    num_samples_per_cluster, n_left = divmod(num_samples, num_clusters)
    X = np.zeros((num_samples, dimX))
    Y = np.zeros((num_samples, dimY))

    start = 0
    for k in range(num_clusters):
        # Cluster 0 absorbs the remainder so that exactly num_samples rows are filled.
        size = num_samples_per_cluster + (n_left if k == 0 else 0)
        stop = start + size
        X[start:stop, :] = np.random.multivariate_normal(
            cluster_centers_X[k, :], cov_X, size
        )
        Y[start:stop, :] = np.random.multivariate_normal(
            cluster_centers_Y[k, :], cov_Y, size
        )
        start = stop

    return X, Y


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
num_samples = 5000  # points in each of the two distributions
dimX, dimY = 10, 15  # dimensions of the X and Y point clouds
r_threshold = 10  # minimum distance between cluster centres (see RandX)
sigma = 0.1  # per-coordinate variance inside each Gaussian cluster
seed = 49  # seed of the synthetic data (clustered_distributions)
Init = "lower_bound"  # forwarded as Init= to every CodeGW solver
seed_init = 10  # forwarded as seed_init= to every CodeGW solver
# Forwarded as delta= to the LR solvers. Such a tiny value presumably disables
# their early stopping so runs go to max_iter -- confirm in CodeGW.
delta_LR = 1e-100
delta_IBP = 1e-3  # forwarded as delta_IBP= to the LR solvers
delta_sin = 1e-3  # forwarded as delta_sin= to the entropic (Sinkhorn) solvers
LSE = False  # forwarded as LSE= to the entropic solvers
time_out = 50  # forwarded as time_out= to every solver


def Euclidean_cost(X, Y):
    """Cost matrix between the rows of X and Y, delegated to CodeGW.Euclidean_Distance."""
    return CodeGW.Euclidean_Distance(X, Y)


cost = Euclidean_cost

rank_cost = 100  # rank argument of the factorised cost approximation


def cost_factorized(X, Y):
    """Return the pair of factors produced by CodeGW.factorized_distance_cost for cost(X, Y).

    The factors are presumably such that ``A @ B.T`` approximates
    ``cost(X, Y)`` -- confirm in CodeGW.
    """
    return CodeGW.factorized_distance_cost(
        X, Y, rank_cost, cost, C_init=False, tol=1e-1, seed=50
    )


# Uniform weights on the samples of each distribution.
a = (1 / num_samples) * np.ones(num_samples)
b = (1 / num_samples) * np.ones(num_samples)


list_num_clusters = [2, 5, 10, 20, 30]  # one experiment block per entry
num_experiments_blocs = len(list_num_clusters)

ranks = [10, 50, 100, 500]  # ranks tried by the low-rank (LR) solvers
mod = len(ranks)

gammas_init = [1000, 500, 250, 100, 50, 30, 10, 1]  # Sinkhorn uses reg = 1 / gamma_0
num_gammas_init = len(gammas_init)

num_iter_LR = 500  # length of each stored LR trace
num_iter_Sin = 500  # length of each stored Sinkhorn trace


Data_GW_init = np.zeros(num_experiments_blocs)

# Sinkhorn traces: axis 1 is 0 -> num_ops trace, 1 -> accuracy trace.
Data_Sin = np.zeros((num_experiments_blocs, 2, num_iter_Sin, num_gammas_init))
Data_Sin_factorized = np.zeros(
    (num_experiments_blocs, 2, num_iter_Sin, num_gammas_init)
)
# LR traces: the last axis holds the num_ops trace in its first num_iter_LR
# entries and the accuracy trace in the next num_iter_LR entries.
Data_LR = np.zeros((num_experiments_blocs, num_gammas_init, mod, 2 * num_iter_LR))
Data_LR_factorized = np.zeros(
    (num_experiments_blocs, num_gammas_init, mod, 2 * num_iter_LR)
)

# NOTE(review): the four dense coupling tensors below hold one
# num_samples x num_samples float64 matrix per configuration. With the values
# above that is 2 x (5*8*4*5000*5000*8 B) ~= 64 GB for the LR tensors and
# 2 x (5*8*5000*5000*8 B) ~= 16 GB for the Sinkhorn ones (~80 GB total), and
# the whole lot is later written out with np.save. Reduce num_samples or store
# low-rank factors instead if memory or disk is a concern.
Arr_Coupling_Sin = np.zeros(
    (num_experiments_blocs, num_gammas_init, num_samples, num_samples)
)
Arr_Coupling_Sin_factorized = np.zeros(
    (num_experiments_blocs, num_gammas_init, num_samples, num_samples)
)
Arr_Coupling_LR = np.zeros(
    (num_experiments_blocs, num_gammas_init, mod, num_samples, num_samples)
)
Arr_Coupling_LR_factorized = np.zeros(
    (num_experiments_blocs, num_gammas_init, mod, num_samples, num_samples)
)

def _write_padded(out, values):
    """Copy a 1-D solver trace into the 1-D array view ``out``, padding with its last value.

    ``out[:len(values)]`` receives the trace. The remaining entries are filled
    with ``values[-1]`` so every stored trace has the same length; traces can
    presumably be shorter because of the time_out / delta stopping criteria.

    A trace longer than ``out`` raises ValueError, as the original slice
    assignment did. An empty trace leaves ``out`` unchanged (zeros) instead of
    raising IndexError and aborting the whole sweep.
    """
    values = np.asarray(values)
    n = values.shape[0]
    if n == 0:
        return
    out[:n] = values
    out[n:] = values[-1]


def _record_sinkhorn(results, arr_coupling, data, k, ind_gamma_0, tag):
    """Store one entropic-GW run into the coupling and trace arrays.

    ``results`` is either the string "Error" (reported; the slots stay zero)
    or the 5-tuple ``(res, acc, times, num_ops, couplings)`` returned by the
    CodeGW solver. The last coupling goes to ``arr_coupling[k, ind_gamma_0]``.
    The padded num_ops and accuracy traces go to ``data[k, 0, :, ind_gamma_0]``
    and ``data[k, 1, :, ind_gamma_0]``.
    """
    if results == "Error":
        print("Error returned by " + tag + "; results left at zero")
        return
    _res, acc, _times, num_ops, couplings = results
    arr_coupling[k, ind_gamma_0, :, :] = couplings[-1]
    _write_padded(data[k, 0, :, ind_gamma_0], num_ops)
    _write_padded(data[k, 1, :, ind_gamma_0], acc)


def _record_lr(results, arr_coupling, data, k, ind_gamma_0, ind_rank, tag):
    """Store one low-rank GW run into the coupling and trace arrays.

    ``results`` is either "Error" (reported; the slots stay zero) or the
    5-tuple ``(res, acc, times, num_ops, couplings)`` whose last coupling is a
    ``(Q, R, g)`` factorisation. The trace row ``data[k, ind_gamma_0, ind_rank]``
    is split in half: the padded num_ops trace goes in the first half and the
    padded accuracy trace in the second. Each half is padded by its own length;
    the original sized the accuracy slot by len(num_ops), a shape bug whenever
    the two traces differed in length.
    """
    if results == "Error":
        print("Error returned by " + tag + "; results left at zero")
        return
    _res, acc, _times, num_ops, couplings = results
    Q, R, g = couplings[-1]
    # Dense coupling P = Q diag(1/g) R^T (assuming g is the length-rank weight vector).
    arr_coupling[k, ind_gamma_0, ind_rank, :, :] = np.dot(Q / g, R.T)
    trace = data[k, ind_gamma_0, ind_rank]
    half = trace.shape[0] // 2
    _write_padded(trace[:half], num_ops)
    _write_padded(trace[half:], acc)


start = time.time()

# Keyword arguments shared by both entropic (Sinkhorn-based) solvers.
sin_kwargs = dict(
    Init=Init,
    seed_init=seed_init,
    I=num_iter_Sin - 1,
    delta_sin=delta_sin,
    num_iter_sin=10000,
    lam_sin=0,
    time_out=time_out,
    LSE=LSE,
)

# Settings shared by every low-rank (LR) run; loop-invariant, so hoisted.
reg_LR = 0
alpha = 1e-10
lr_kwargs = dict(
    C_init=True,
    Init=Init,
    seed_init=seed_init,
    reg_init=1e-1,
    gamma_init="arbitrary",
    method="Dykstra",
    max_iter=num_iter_LR - 1,
    delta=delta_LR,
    max_iter_IBP=10000,
    delta_IBP=delta_IBP,
    lam_IBP=0,
    time_out=time_out,
)

for k, num_clusters in enumerate(list_num_clusters):
    X, Y = clustered_distributions(
        num_samples,
        dimX,
        dimY,
        num_clusters=num_clusters,
        r_threshold=r_threshold,
        sigma=sigma,
        seed=seed,
    )

    D1 = cost(X, X)
    D11, D12 = cost_factorized(X, X)

    D2 = cost(Y, Y)
    D21, D22 = cost_factorized(Y, Y)

    # Rescale each cost matrix to max 1. Each factor is divided by sqrt(r),
    # so a product of the two factors is rescaled by the same 1/r.
    r1 = D1.max()
    r2 = D2.max()
    D1 /= r1
    D2 /= r2
    D11, D12 = D11 / np.sqrt(r1), D12 / np.sqrt(r1)
    D21, D22 = D21 / np.sqrt(r2), D22 / np.sqrt(r2)

    Data_GW_init[k] = CodeGW.GW_Init(D11, D12, D21, D22, a, b)

    cost_SE_factorized = (D11, D12, D21, D22)
    cost_SE = (D1, D2)

    for ind_gamma_0, gamma_0 in enumerate(gammas_init):
        context = "(num_clusters=" + str(num_clusters) + ", gamma_0=" + str(gamma_0)

        ###### Sinkhorn Algorithm #######
        reg = 1 / gamma_0
        results = CodeGW.Quad_GW_entropic_distance2(
            D11, D12, D21, D22, reg, a, b, **sin_kwargs
        )
        _record_sinkhorn(
            results,
            Arr_Coupling_Sin_factorized,
            Data_Sin_factorized,
            k,
            ind_gamma_0,
            "Quad_GW_entropic_distance2 " + context + ")",
        )

        results = CodeGW.GW_entropic_distance2(D1, D2, reg, a, b, **sin_kwargs)
        _record_sinkhorn(
            results,
            Arr_Coupling_Sin,
            Data_Sin,
            k,
            ind_gamma_0,
            "GW_entropic_distance2 " + context + ")",
        )

        print("ok Sinkhorn for gamma_0: " + str(gamma_0))

        ###### LR methods #######
        for ind_rank, rank in enumerate(ranks):
            results = CodeGW.Lin_LGW_MD(
                X,
                Y,
                a,
                b,
                rank,
                reg_LR,
                alpha,
                cost_SE_factorized,
                gamma_0=gamma_0,
                **lr_kwargs,
            )
            _record_lr(
                results,
                Arr_Coupling_LR_factorized,
                Data_LR_factorized,
                k,
                ind_gamma_0,
                ind_rank,
                "Lin_LGW_MD " + context + ", rank=" + str(rank) + ")",
            )

            results = CodeGW.Quad_LGW_MD(
                X,
                Y,
                a,
                b,
                rank,
                reg_LR,
                alpha,
                cost_SE,
                gamma_0=gamma_0,
                **lr_kwargs,
            )
            _record_lr(
                results,
                Arr_Coupling_LR,
                Data_LR,
                k,
                ind_gamma_0,
                ind_rank,
                "Quad_LGW_MD " + context + ", rank=" + str(rank) + ")",
            )

        # Reported once per gamma_0. The original print sat outside this loop,
        # so it only ever showed the last gamma_0.
        print("ok LR, gamma_0 = " + str(gamma_0))


# Persist each dense coupling tensor to its own .npy file (binary mode).
for npy_path, coupling_tensor in (
    ("Couplings_LR.npy", Arr_Coupling_LR),
    ("Couplings_LR_factorized.npy", Arr_Coupling_LR_factorized),
    ("Couplings_Sin.npy", Arr_Coupling_Sin),
    ("Couplings_Sin_factorized.npy", Arr_Coupling_Sin_factorized),
):
    with open(npy_path, "wb") as fh:
        np.save(fh, coupling_tensor)


def _trace_row(labels, times, accs):
    """Format one CSV line: the label fields, then the ops/time trace, then the accuracy trace."""
    times_field = ",".join(str(e) for e in times)
    accs_field = ",".join(str(e) for e in accs)
    return ",".join([*labels, times_field, accs_field]) + "\n"


# One "num_bloc: <clusters>,GW_init,<value>" line per experiment block.
with open("GW_init.csv", "w") as fh:
    for n_clusters, gw_value in zip(list_num_clusters, Data_GW_init):
        fh.write(
            ",".join(["num_bloc: " + str(n_clusters), "GW_init", str(gw_value)])
            + "\n"
        )


# Sinkhorn traces: data[q, 0, :, g] is the ops trace, data[q, 1, :, g] the accuracy.
for csv_path, sin_data in (
    ("acc_OT_vs_time_Sinkhorn.csv", Data_Sin),
    ("acc_OT_vs_time_Sinkhorn_factorized.csv", Data_Sin_factorized),
):
    with open(csv_path, "w") as fh:
        for q, n_clusters in enumerate(list_num_clusters):
            for g, gamma_0 in enumerate(gammas_init):
                labels = ["num_bloc: " + str(n_clusters), "reg: " + str(1 / gamma_0)]
                fh.write(_trace_row(labels, sin_data[q, 0, :, g], sin_data[q, 1, :, g]))


# LR traces: the first num_iter_LR entries are the ops trace, the next
# num_iter_LR entries the accuracy trace.
for csv_path, lr_data in (
    ("acc_OT_vs_time_LR.csv", Data_LR),
    ("acc_OT_vs_time_LR_factorized.csv", Data_LR_factorized),
):
    with open(csv_path, "w") as fh:
        for q, n_clusters in enumerate(list_num_clusters):
            for g, gamma_0 in enumerate(gammas_init):
                for j, n_rank in enumerate(ranks):
                    row = lr_data[q, g, j]
                    labels = [
                        "num_bloc: " + str(n_clusters),
                        "gamma_0: " + str(gamma_0),
                        "num_RF: " + str(n_rank),
                    ]
                    fh.write(
                        _trace_row(
                            labels,
                            row[:num_iter_LR],
                            row[num_iter_LR : 2 * num_iter_LR],
                        )
                    )


# Total wall-clock seconds since `start`, including the file writes above.
end = time.time()
total = end - start
print(total)

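## Loading snippets. The two files named below ("Time_Sin_10.txt" and
## "Comparison_Plans_100.npy") are not written by this script. Its own array
## outputs are "Couplings_LR.npy", "Couplings_LR_factorized.npy",
## "Couplings_Sin.npy" and "Couplings_Sin_factorized.npy", loadable the same way
## with np.load.
#
# To load the file containing the time (requires `import pickle`; only unpickle trusted files):
# with open("Time_Sin_10.txt", "rb") as fp:   # Unpickling
#     b = pickle.load(fp)
#
# To load the file containing the matrix
# with open('Comparison_Plans_100.npy', 'rb') as f:
#     a = np.load(f)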
