import numpy as np
import time
from mpl_toolkits.mplot3d import Axes3D  # noqa

import CodeGW


def curve_2d_3d(num_samples):
    """Sample one curve and return it as a 2-D and a 3-D point cloud.

    ``num_samples`` evenly spaced angles are taken over ``[-4*pi, 4*pi]``
    together with evenly spaced heights ``z`` over ``[1, 2]``. The radius
    at each sample is ``z**2 + 1``.

    Returns:
        A pair ``(X, Y)``. ``X`` has shape ``(num_samples, 2)`` with columns
        ``(x, z)``; ``Y`` has shape ``(num_samples, 3)`` with columns
        ``(x, y, z)``.
    """
    height = np.linspace(1, 2, num_samples)
    angle = np.linspace(-4 * np.pi, 4 * np.pi, num_samples)
    radius = height ** 2 + 1

    # Turn each coordinate into an (n, 1) column so they can be laid side by side.
    x_col = (radius * np.sin(angle)).reshape(-1, 1)
    y_col = (radius * np.cos(angle)).reshape(-1, 1)
    z_col = height.reshape(-1, 1)

    planar = np.hstack([x_col, z_col])
    spatial = np.hstack([x_col, y_col, z_col])
    return planar, spatial


# Number of points sampled on the curve; X is the 2-D cloud, Y the 3-D one.
num_samples = 10000
X, Y = curve_2d_3d(num_samples)


# Uniform weights (each entry 1 / num_samples) for the two point clouds.
a = np.full(num_samples, 1 / num_samples)
b = np.full(num_samples, 1 / num_samples)

# Solver settings, forwarded unchanged to the CodeGW calls further down.
Init = "lower_bound"
seed_init = 10
delta_LR = 1e-100  # passed as `delta` to CodeGW.Lin_LGW_MD
delta_IBP = 1e-3  # passed as `delta_IBP` to CodeGW.Lin_LGW_MD
delta_sin = 1e-3  # passed as `delta_sin` to CodeGW.Quad_GW_entropic_distance2
LSE = False
time_out = 100


def Square_Euclidean_cost(X, Y):
    """Forward ``(X, Y)`` to ``CodeGW.Square_Euclidean_Distance``."""
    return CodeGW.Square_Euclidean_Distance(X, Y)


def Square_Euclidean_factorized_cost(X, Y):
    """Forward ``(X, Y)`` to ``CodeGW.factorized_square_Euclidean``."""
    return CodeGW.factorized_square_Euclidean(X, Y)


cost = Square_Euclidean_cost
cost_factorized = Square_Euclidean_factorized_cost


# Factorized cost of each point cloud against itself (two factors per cloud).
D11, D12 = cost_factorized(X, X)
D21, D22 = cost_factorized(Y, Y)

# Largest entry of every factor ...
r11 = D11.max()
r12 = D12.max()
r21 = D21.max()
r22 = D22.max()

# ... and each factor divided by the square root of its own largest entry.
D11 = D11 / np.sqrt(r11)
D12 = D12 / np.sqrt(r12)
D21 = D21 / np.sqrt(r21)
D22 = D22 / np.sqrt(r22)


# Value returned by CodeGW.GW_Init for the rescaled factors, saved as one CSV row.
GW_init = CodeGW.GW_Init(D11, D12, D21, D22, a, b)
with open("GW_init.csv", "w") as file:
    file.write("GW_init," + str(GW_init) + "\n")
    file.flush()


# Ranks tried by the low-rank method and the gamma_0 values swept by both methods.
ranks = [10, 50, 100]
gammas_init = [500, 250, 100, 50, 30, 10, 1]
mod, num_gammas_init = len(ranks), len(gammas_init)


# Length of the per-run history stored for each method.
num_iter_LR = num_iter_Sin = 1000


# Histories: Data_Sin[0] / Data_Sin[1] hold the operation counts / accuracies of
# the Sinkhorn runs; Data_LR stores both back to back along its last axis.
shape_hist_Sin = (2, num_iter_Sin, num_gammas_init)
shape_hist_LR = (num_gammas_init, mod, 2 * num_iter_LR)
Data_Sin = np.zeros(shape_hist_Sin)
Data_LR = np.zeros(shape_hist_LR)

# One dense (num_samples x num_samples) coupling per gamma_0 (and per rank for LR).
shape_coupling = (num_samples, num_samples)
Arr_Coupling_Sin = np.zeros((num_gammas_init,) + shape_coupling)
Arr_Coupling_LR = np.zeros((num_gammas_init, mod) + shape_coupling)

# Reference point for the elapsed time printed at the end of the script.
start = time.time()


def _store_padded(dest, values):
    """Copy ``values`` to the front of ``dest`` and repeat its last entry after it.

    ``dest`` must be a 1-D view into a result buffer (obtained with integer and
    slice indexing), so the writes land in the buffer itself. Every slice bound
    is derived from ``values`` alone, which keeps the copied part and the padded
    part consistent with each other.
    """
    count = len(values)
    dest[:count] = values
    dest[count:] = values[-1]


# Settings of the low-rank solver. None of them depends on gamma_0 or on the
# rank, so they are built once instead of on every inner iteration.
gamma_init = "arbitrary"
alpha = 1e-10
reg_init = 1e-1
method = "Dykstra"
C_init = True
cost_SE = (D11, D12, D21, D22)

for ind_gamma, gamma_0 in enumerate(gammas_init):
    # ###### Sinkhorn Algorithm #######
    reg = 1 / gamma_0
    results = CodeGW.Quad_GW_entropic_distance2(
        D11,
        D12,
        D21,
        D22,
        reg,
        a,
        b,
        Init=Init,
        seed_init=seed_init,
        I=num_iter_Sin - 1,
        delta_sin=delta_sin,
        num_iter_sin=10000,
        lam_sin=0,
        time_out=time_out,
        LSE=LSE,
    )
    if results != "Error":
        res, acc_Sin, times_Sin, num_ops_Sin, Couplings_Sin = results

        # Only the last coupling of the run is kept.
        Arr_Coupling_Sin[ind_gamma, :, :] = Couplings_Sin[-1]

        _store_padded(Data_Sin[0, :, ind_gamma], num_ops_Sin)
        _store_padded(Data_Sin[1, :, ind_gamma], acc_Sin)
    else:
        # Report the skipped run: its rows otherwise stay silently at zero.
        print("Sinkhorn returned 'Error' for gamma_0: " + str(gamma_0) + ", run skipped")

    print("ok Sinkhorn for gamma_0: " + str(gamma_0))

    ###### LR method #######
    for ind_rank, rank in enumerate(ranks):
        reg = 0.0
        results = CodeGW.Lin_LGW_MD(
            X,
            Y,
            a,
            b,
            rank,
            reg,
            alpha,
            cost_SE,
            C_init=C_init,
            Init=Init,
            seed_init=seed_init,
            reg_init=reg_init,
            gamma_init=gamma_init,
            gamma_0=gamma_0,
            method=method,
            max_iter=num_iter_LR - 1,
            delta=delta_LR,
            max_iter_IBP=10000,
            delta_IBP=delta_IBP,
            lam_IBP=0,
            time_out=time_out,
        )

        if results != "Error":
            res_1, acc_LR, times_LR, num_ops_LR, Couplings_LR = results

            # Rebuild the dense coupling from the last (Q, R, g) factors.
            Q, R, g = Couplings_LR[-1]
            P = np.dot(Q / g, R.T)

            Arr_Coupling_LR[ind_gamma, ind_rank, :, :] = P

            # First half of the last axis: operation counts; second half:
            # accuracies. The accuracy slice is sized from acc_LR itself (the
            # original bounded it with the length of num_ops_LR).
            _store_padded(Data_LR[ind_gamma, ind_rank, :num_iter_LR], num_ops_LR)
            _store_padded(Data_LR[ind_gamma, ind_rank, num_iter_LR:], acc_LR)
        else:
            # Report the skipped run: its rows otherwise stay silently at zero.
            print(
                "LR returned 'Error' for gamma_0: "
                + str(gamma_0)
                + ", rank: "
                + str(rank)
                + ", run skipped"
            )

    stri = "ok LR, gamma_0 = " + str(gamma_0)
    print(stri)


# Dense couplings of both methods, stored in NumPy's binary format.
with open("Couplings_LR.npy", "wb") as f:
    np.save(f, Arr_Coupling_LR)

with open("Couplings_Sin.npy", "wb") as f:
    np.save(f, Arr_Coupling_Sin)


# One CSV row per gamma_0: label, then the Data_Sin[0] history, then Data_Sin[1].
with open("acc_GW_vs_time_Sinkhorn.csv", "w") as file:
    for col, gamma_val in enumerate(gammas_init):
        first_half = ",".join(map(str, Data_Sin[0, :, col]))
        second_half = ",".join(map(str, Data_Sin[1, :, col]))

        file.write(f"reg: {1 / gamma_val},{first_half},{second_half}\n")
        file.flush()


# One CSV row per (gamma_0, rank): two labels, then both halves of the history.
with open("acc_OT_vs_time_LR.csv", "w") as file:
    for gamma_val, per_gamma in zip(gammas_init, Data_LR):
        for rank_val, history in zip(ranks, per_gamma):
            first_half = ",".join(map(str, history[:num_iter_LR]))
            second_half = ",".join(map(str, history[num_iter_LR:]))

            file.write(
                f"gamma_0: {gamma_val},num_RF: {rank_val},{first_half},{second_half}\n"
            )
            file.flush()


# Wall-clock seconds elapsed since `start` was recorded.
end = time.time()
total = end - start
print(total)
