import numpy as np
import time
from mpl_toolkits.mplot3d import Axes3D  # noqa
import sys

sys.path.append("./SCOT-code/src/")
import scot2 as sc

import CodeGW


# Root of the SCOT checkout that holds the input data files.
path = "SCOT-code/"

# ============================ Experiment 1 ============================
# scGEM dataset: methylation features (X) vs. expression features (y),
# both read from comma-separated files.
X_trans = np.genfromtxt(f"{path}data/scGEM_methylation.csv", delimiter=",")
y_trans = np.genfromtxt(f"{path}data/scGEM_expression.csv", delimiter=",")
print(
    "Dimensions of input datasets are: ",
    "X= ",
    X_trans.shape,
    " y= ",
    y_trans.shape,
)

num_samples = X_trans.shape[0]

# Let SCOT L2-normalise every row of both datasets, then take the
# normalised arrays it stores.
scot = sc.SCOT(X_trans, y_trans)
scot.normalize(norm="l2")
X, Y = scot.X, scot.y[0]

# Uniform marginals: every point of a dataset carries the same mass.
a, b = (np.ones(n) / n for n in (np.shape(X)[0], np.shape(Y)[0]))

# Intra-domain distance matrices (shortest-path based, computed by SCOT).
# k is forwarded to SCOT.init_distances -- presumably the neighbourhood
# size of the graph it builds; TODO confirm against scot2.
k = 35
scot.init_distances(k)
D1, D2 = scot.Cx[0], scot.Cy[0]


# Ranks tried by the low-rank (LR) solver, and the gamma_0 values swept by
# both the Sinkhorn and the LR solver.
ranks = [10, 50, 100]
mod = len(ranks)

gammas_init = [1000, 500, 250, 100, 50, 30, 10, 1]
num_gammas_init = len(gammas_init)

# Fixed length of every recorded trace; shorter solver traces are padded
# with their last value when stored.
num_iter_LR = 300
num_iter_Sin = 300

# Data_Sin[0, :, g] holds the num_ops trace and Data_Sin[1, :, g] the acc
# trace of the Sinkhorn run for gammas_init[g].
Data_Sin = np.zeros((2, num_iter_Sin, num_gammas_init))
# Data_LR[g, r, :num_iter_LR] holds the num_ops trace and
# Data_LR[g, r, num_iter_LR:] the acc trace of the LR run for
# (gammas_init[g], ranks[r]).
Data_LR = np.zeros((num_gammas_init, mod, 2 * num_iter_LR))

# A coupling has one row per point of X (marginal a) and one column per
# point of Y (marginal b). Sizing both axes from X only worked while the
# two datasets happened to have the same number of samples.
coupling_shape = (np.shape(X)[0], np.shape(Y)[0])
Arr_Coupling_Sin = np.zeros((num_gammas_init, *coupling_shape))
Arr_Coupling_LR = np.zeros((num_gammas_init, mod, *coupling_shape))

# Solver settings shared by the runs below.
Init = "lower_bound"
seed_init = 10
delta_sin = 1e-3
delta_IBP = 1e-3
time_out = 50
LSE = False

start = time.time()

GW_init = CodeGW.GW_Init_Cubic(D1, D2, a, b)

# Low-rank (LR) solver settings. None of them depends on gamma_0 or on the
# rank, so they are set once rather than on every inner iteration.
gamma_init = "arbitrary"
alpha = 1e-10
reg_init = 1e-1
method = "Dykstra"
C_init = True
delta_LR = 1e-100
cost_SE = (D1, D2)

for ind_gamma, gamma_0 in enumerate(gammas_init):
    ###### Sinkhorn Algorithm #######
    reg = 1 / gamma_0
    results = CodeGW.GW_entropic_distance2(
        D1,
        D2,
        reg,
        a,
        b,
        Init=Init,
        seed_init=seed_init,
        I=num_iter_Sin - 1,
        delta_sin=delta_sin,
        num_iter_sin=10000,
        lam_sin=0,
        time_out=time_out,
        LSE=LSE,
    )
    if results != "Error":
        res, acc_Sin, times_Sin, num_ops_Sin, Couplings_Sin = results

        # Keep only the final coupling of the run.
        Arr_Coupling_Sin[ind_gamma, :, :] = Couplings_Sin[-1]

        # Write each trace into its fixed-length row (basic slicing gives
        # views into Data_Sin) and pad the tail with the trace's last value.
        ops_row = Data_Sin[0, :, ind_gamma]
        ops_row[: len(num_ops_Sin)] = num_ops_Sin
        ops_row[len(num_ops_Sin) :] = num_ops_Sin[-1]

        acc_row = Data_Sin[1, :, ind_gamma]
        acc_row[: len(acc_Sin)] = acc_Sin
        acc_row[len(acc_Sin) :] = acc_Sin[-1]
    else:
        # The stored row stays zero-filled; report it instead of hiding it.
        print("Sinkhorn returned 'Error' for gamma_0: " + str(gamma_0))

    print("ok Sinkhorn for gamma_0: " + str(gamma_0))

    ###### LR methods #######
    for ind_rank, rank in enumerate(ranks):
        reg = 0
        results = CodeGW.Quad_LGW_MD(
            X,
            Y,
            a,
            b,
            rank,
            reg,
            alpha,
            cost_SE,
            C_init=C_init,
            Init=Init,
            seed_init=seed_init,
            reg_init=reg_init,
            gamma_init=gamma_init,
            gamma_0=gamma_0,
            method=method,
            max_iter=num_iter_LR - 1,
            delta=delta_LR,
            max_iter_IBP=10000,
            delta_IBP=delta_IBP,
            lam_IBP=0,
            time_out=time_out,
        )

        if results != "Error":
            res_1, acc_LR, times_LR, num_ops_LR, Couplings_LR = results

            # Rebuild the dense coupling from the last low-rank factors:
            # P = (Q / g) @ R.T.
            Q, R, g = Couplings_LR[-1]
            P = np.dot(Q / g, R.T)

            Arr_Coupling_LR[ind_gamma, ind_rank, :, :] = P

            # First half of the row: num_ops trace; second half: acc trace.
            # Each half is sliced by its OWN trace's length so traces of
            # different lengths cannot misalign or overflow into each other.
            ops_row = Data_LR[ind_gamma, ind_rank, :num_iter_LR]
            ops_row[: len(num_ops_LR)] = num_ops_LR
            ops_row[len(num_ops_LR) :] = num_ops_LR[-1]

            acc_row = Data_LR[ind_gamma, ind_rank, num_iter_LR : 2 * num_iter_LR]
            acc_row[: len(acc_LR)] = acc_LR
            acc_row[len(acc_LR) :] = acc_LR[-1]
        else:
            print(
                "LR returned 'Error' for gamma_0: "
                + str(gamma_0)
                + ", rank: "
                + str(rank)
            )

    stri = "ok LR, gamma_0 = " + str(gamma_0)
    print(stri)


# ---------------- Save experiment 1 results ----------------
with open("Couplings_LR_exp1.npy", "wb") as f:
    np.save(f, Arr_Coupling_LR)

with open("Couplings_Sin_exp1.npy", "wb") as f:
    np.save(f, Arr_Coupling_Sin)

with open("GW_init_exp1.csv", "w") as file:
    file.write(f"GW_init,{GW_init!s}\n")

# One row per gamma_0: "reg: <1/gamma_0>", then the Data_Sin[0] series,
# then the Data_Sin[1] series. NOTE(review): despite the "_vs_time" file
# name, Data_Sin[0] is filled from num_ops_Sin, not wall-clock time.
# Loop variables are deliberately NOT named `k`, so the module-level
# neighbour count `k` is not clobbered.
with open("acc_OT_vs_time_Sinkhorn_exp1.csv", "w") as file:
    for gamma_val, ops, acc in zip(gammas_init, Data_Sin[0].T, Data_Sin[1].T):
        ops_str = ",".join(str(e) for e in ops)
        acc_str = ",".join(str(e) for e in acc)
        file.write(f"reg: {1 / gamma_val!s},{ops_str},{acc_str}\n")

# One row per (gamma_0, rank): labels, then the num_ops half and the acc
# half of the matching Data_LR row.
with open("acc_OT_vs_time_LR_exp1.csv", "w") as file:
    for gamma_val, rows_for_gamma in zip(gammas_init, Data_LR):
        for rank_val, row in zip(ranks, rows_for_gamma):
            ops_str = ",".join(str(e) for e in row[:num_iter_LR])
            acc_str = ",".join(str(e) for e in row[num_iter_LR : 2 * num_iter_LR])
            file.write(
                f"gamma_0: {gamma_val!s},num_RF: {rank_val!s},{ops_str},{acc_str}\n"
            )


end = time.time()
total = end - start
print(total)


# ============================ Experiment 2 ============================
# Second pair of feature matrices, stored as .npy arrays.
X_trans = np.load(f"{path}data/scatac_feat.npy")
y_trans = np.load(f"{path}data/scrna_feat.npy")
print(
    "Dimensions of input datasets are: ",
    "X= ",
    X_trans.shape,
    " y= ",
    y_trans.shape,
)

num_samples = X_trans.shape[0]

# Let SCOT L2-normalise every row of both datasets, then take the
# normalised arrays it stores.
scot = sc.SCOT(X_trans, y_trans)
scot.normalize(norm="l2")
X, Y = scot.X, scot.y[0]

# Uniform marginals: every point of a dataset carries the same mass.
a, b = (np.ones(n) / n for n in (np.shape(X)[0], np.shape(Y)[0]))

# Intra-domain distance matrices (shortest-path based, computed by SCOT).
# k is forwarded to SCOT.init_distances -- presumably the neighbourhood
# size of the graph it builds; TODO confirm against scot2.
k = 50
scot.init_distances(k)
D1, D2 = scot.Cx[0], scot.Cy[0]


# Fresh result buffers for experiment 2.
# Data_Sin[0, :, g] / Data_Sin[1, :, g]: num_ops / acc trace for gammas_init[g].
Data_Sin = np.zeros((2, num_iter_Sin, num_gammas_init))
# Data_LR[g, r, :num_iter_LR] / Data_LR[g, r, num_iter_LR:]: num_ops / acc
# trace for (gammas_init[g], ranks[r]).
Data_LR = np.zeros((num_gammas_init, mod, 2 * num_iter_LR))

# A coupling has one row per point of X and one column per point of Y;
# do not assume the two datasets have the same number of samples.
coupling_shape = (np.shape(X)[0], np.shape(Y)[0])
Arr_Coupling_Sin = np.zeros((num_gammas_init, *coupling_shape))
Arr_Coupling_LR = np.zeros((num_gammas_init, mod, *coupling_shape))


start = time.time()

GW_init = CodeGW.GW_Init_Cubic(D1, D2, a, b)

# Low-rank (LR) solver settings, defined here so this experiment does not
# depend on values leaking out of an earlier loop (delta_LR in particular).
# None of them depends on gamma_0 or on the rank.
gamma_init = "arbitrary"
alpha = 1e-10
reg_init = 1e-1
method = "Dykstra"
C_init = True
delta_LR = 1e-100
cost_SE = (D1, D2)

for ind_gamma, gamma_0 in enumerate(gammas_init):
    ###### Sinkhorn Algorithm #######
    reg = 1 / gamma_0
    results = CodeGW.GW_entropic_distance2(
        D1,
        D2,
        reg,
        a,
        b,
        Init=Init,
        seed_init=seed_init,
        I=num_iter_Sin - 1,
        delta_sin=delta_sin,
        num_iter_sin=10000,
        lam_sin=0,
        time_out=time_out,
        LSE=LSE,
    )
    if results != "Error":
        res, acc_Sin, times_Sin, num_ops_Sin, Couplings_Sin = results

        # Keep only the final coupling of the run.
        Arr_Coupling_Sin[ind_gamma, :, :] = Couplings_Sin[-1]

        # Write each trace into its fixed-length row (basic slicing gives
        # views into Data_Sin) and pad the tail with the trace's last value.
        ops_row = Data_Sin[0, :, ind_gamma]
        ops_row[: len(num_ops_Sin)] = num_ops_Sin
        ops_row[len(num_ops_Sin) :] = num_ops_Sin[-1]

        acc_row = Data_Sin[1, :, ind_gamma]
        acc_row[: len(acc_Sin)] = acc_Sin
        acc_row[len(acc_Sin) :] = acc_Sin[-1]
    else:
        # The stored row stays zero-filled; report it instead of hiding it.
        print("Sinkhorn returned 'Error' for gamma_0: " + str(gamma_0))

    # Report the gamma_0 just processed (this used to print the k-NN
    # parameter k, which is the same for every iteration).
    print("ok Sinkhorn for gamma_0: " + str(gamma_0))

    ###### LR methods #######
    for ind_rank, rank in enumerate(ranks):
        reg = 0
        results = CodeGW.Quad_LGW_MD(
            X,
            Y,
            a,
            b,
            rank,
            reg,
            alpha,
            cost_SE,
            C_init=C_init,
            Init=Init,
            seed_init=seed_init,
            reg_init=reg_init,
            gamma_init=gamma_init,
            gamma_0=gamma_0,
            method=method,
            max_iter=num_iter_LR - 1,
            delta=delta_LR,
            max_iter_IBP=10000,
            delta_IBP=delta_IBP,
            lam_IBP=0,
            time_out=time_out,
        )

        if results != "Error":
            res_1, acc_LR, times_LR, num_ops_LR, Couplings_LR = results

            # Rebuild the dense coupling from the last low-rank factors:
            # P = (Q / g) @ R.T.
            Q, R, g = Couplings_LR[-1]
            P = np.dot(Q / g, R.T)

            Arr_Coupling_LR[ind_gamma, ind_rank, :, :] = P

            # First half of the row: num_ops trace; second half: acc trace.
            # Each half is sliced by its OWN trace's length so traces of
            # different lengths cannot misalign or overflow into each other.
            ops_row = Data_LR[ind_gamma, ind_rank, :num_iter_LR]
            ops_row[: len(num_ops_LR)] = num_ops_LR
            ops_row[len(num_ops_LR) :] = num_ops_LR[-1]

            acc_row = Data_LR[ind_gamma, ind_rank, num_iter_LR : 2 * num_iter_LR]
            acc_row[: len(acc_LR)] = acc_LR
            acc_row[len(acc_LR) :] = acc_LR[-1]
        else:
            print(
                "LR returned 'Error' for gamma_0: "
                + str(gamma_0)
                + ", rank: "
                + str(rank)
            )

    # All ranks have been run for this gamma_0 (this used to print only the
    # last rank, which is identical on every iteration).
    stri = "ok LR, gamma_0 = " + str(gamma_0)
    print(stri)


# ---------------- Save experiment 2 results ----------------
with open("Couplings_LR_exp2.npy", "wb") as f:
    np.save(f, Arr_Coupling_LR)

with open("Couplings_Sin_exp2.npy", "wb") as f:
    np.save(f, Arr_Coupling_Sin)

with open("GW_init_exp2.csv", "w") as file:
    file.write(f"GW_init,{GW_init!s}\n")

# One row per gamma_0: "reg: <1/gamma_0>", then the Data_Sin[0] series,
# then the Data_Sin[1] series. NOTE(review): despite the "_vs_time" file
# name, Data_Sin[0] is filled from num_ops_Sin, not wall-clock time.
# Loop variables are deliberately NOT named `k`, so the module-level
# neighbour count `k` is not clobbered.
with open("acc_OT_vs_time_Sinkhorn_exp2.csv", "w") as file:
    for gamma_val, ops, acc in zip(gammas_init, Data_Sin[0].T, Data_Sin[1].T):
        ops_str = ",".join(str(e) for e in ops)
        acc_str = ",".join(str(e) for e in acc)
        file.write(f"reg: {1 / gamma_val!s},{ops_str},{acc_str}\n")

# One row per (gamma_0, rank): labels, then the num_ops half and the acc
# half of the matching Data_LR row.
with open("acc_OT_vs_time_LR_exp2.csv", "w") as file:
    for gamma_val, rows_for_gamma in zip(gammas_init, Data_LR):
        for rank_val, row in zip(ranks, rows_for_gamma):
            ops_str = ",".join(str(e) for e in row[:num_iter_LR])
            acc_str = ",".join(str(e) for e in row[num_iter_LR : 2 * num_iter_LR])
            file.write(
                f"gamma_0: {gamma_val!s},num_RF: {rank_val!s},{ops_str},{acc_str}\n"
            )


end = time.time()
total = end - start
print(total)
