import numpy as np
import time
from mpl_toolkits.mplot3d import Axes3D  # noqa

import CodeGW


def Mixture_of_Gaussians(num_samples, sigma, dimension1, dimension2, seed=49):
    """Sample two point clouds from Gaussian mixtures living in different dimensions.

    ``X`` has ``num_samples`` rows in R^dimension1, drawn from three isotropic
    Gaussians centred at ``0``, ``e_2`` and ``e_1 + e_2`` (``e_i`` is the i-th
    unit vector). ``Y`` has ``num_samples`` rows in R^dimension2, drawn from two
    isotropic Gaussians centred at ``(0.5, 0.5, 0, ...)`` and
    ``(-0.5, 0.5, 0, ...)``. Rows are grouped by component (all points of the
    first component, then the second, ...). Component sizes are fixed, not
    random.

    Parameters
    ----------
    num_samples : int
        Number of points in each of ``X`` and ``Y`` (must be >= 0).
    sigma : float
        Per-coordinate *variance* (not standard deviation): every component
        has covariance ``sigma * I``.
    dimension1, dimension2 : int
        Dimensions of ``X`` and ``Y``. Both must be >= 2 because the means
        are set on the first two coordinates.
    seed : int or None, default 49
        Seed of a private ``np.random.RandomState``. This makes the draw
        reproducible without touching numpy's global RNG. ``None`` draws
        fresh entropy.

    Returns
    -------
    X : ndarray of shape (num_samples, dimension1)
    Y : ndarray of shape (num_samples, dimension2)

    Raises
    ------
    ValueError
        If ``num_samples`` is negative or either dimension is below 2.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be >= 0, got {num_samples}")
    if dimension1 < 2 or dimension2 < 2:
        raise ValueError(
            "dimension1 and dimension2 must both be >= 2, "
            f"got {dimension1} and {dimension2}"
        )

    # A private generator makes `seed` effective without reseeding the global
    # numpy RNG that other code may rely on.
    rng = np.random.RandomState(seed)

    # --- Source cloud X: three components; the remainder goes to the last one.
    nX1 = num_samples // 3
    nX2 = nX1
    nX3 = num_samples - nX1 - nX2

    cov1 = sigma * np.eye(dimension1)

    mean_X1 = np.zeros(dimension1)
    mean_X2 = np.zeros(dimension1)
    mean_X2[1] = 1
    mean_X3 = np.zeros(dimension1)
    mean_X3[:2] = 1

    X = np.concatenate(
        [
            rng.multivariate_normal(mean_X1, cov1, nX1),
            rng.multivariate_normal(mean_X2, cov1, nX2),
            rng.multivariate_normal(mean_X3, cov1, nX3),
        ],
        axis=0,
    )

    # --- Target cloud Y: two components; the remainder goes to the second one.
    nY1 = num_samples // 2
    nY2 = num_samples - nY1

    cov2 = sigma * np.eye(dimension2)

    mean_Y1 = np.zeros(dimension2)
    mean_Y1[:2] = (0.5, 0.5)
    mean_Y2 = np.zeros(dimension2)
    mean_Y2[:2] = (-0.5, 0.5)

    Y = np.concatenate(
        [
            rng.multivariate_normal(mean_Y1, cov2, nY1),
            rng.multivariate_normal(mean_Y2, cov2, nY2),
        ],
        axis=0,
    )

    return X, Y


## Two mixtures of Gaussians (source cloud X, target cloud Y)
n, m = 5000, 5000  # number of samples in X and in Y
dimX, dimY = 10, 15  # ambient dimensions of X and Y
sigma = 0.05  # per-coordinate variance of every mixture component

# Mixture_of_Gaussians draws the same number of points for X and Y, while the
# target weights `b` are built from `m`; the two only agree when n == m.
if n != m:
    raise ValueError(f"this script requires n == m, got n={n}, m={m}")

X, Y = Mixture_of_Gaussians(n, sigma, dimX, dimY, seed=49)


# Optional visual check. It needs a plotting module bound to `pl`
# (e.g. `import matplotlib.pyplot as pl`), which this file does not import.
# fig = pl.figure()
# ax1 = fig.add_subplot(121)
# ax1.plot(X[:, 0], X[:, 1], '+b', label='Source samples')
# ax2 = fig.add_subplot(122, projection='3d')
# ax2.scatter(Y[:, 0], Y[:, 1], Y[:, 2], color='r')
# pl.show()

### Cost matrices #####
# Each wrapper takes two point clouds and forwards them unchanged to the
# matching CodeGW routine. Re-point `cost` / `cost_factorized` at the bottom
# to switch the ground cost used by the experiment.


def Square_Euclidean_cost(X, Y):
    return CodeGW.Square_Euclidean_Distance(X, Y)


def Square_Euclidean_factorized_cost(X, Y):
    return CodeGW.factorized_square_Euclidean(X, Y)


def Euclidean_cost(X, Y):
    return CodeGW.Euclidean_Distance(X, Y)


def L1_cost(X, Y):
    return CodeGW.Lp_Distance(X, Y, p=1)


def L3_cost(X, Y):
    return CodeGW.Lp_Distance(X, Y, p=3)


# Costs selected for this run.
cost = Square_Euclidean_cost
cost_factorized = Square_Euclidean_factorized_cost

# Intra-cloud costs, each returned by `cost_factorized` as a pair of factors.
# NOTE(review): presumably low-rank factors of the squared-distance matrix
# (cost ~ D1 @ D2.T) -- confirm against CodeGW.factorized_square_Euclidean.
D11, D12 = cost_factorized(X, X)
D21, D22 = cost_factorized(Y, Y)

# Largest entry of each factor, used to rescale that factor by its square root.
r11, r12, r21, r22 = (factor.max() for factor in (D11, D12, D21, D22))

D11, D12, D21, D22 = (
    factor / np.sqrt(peak)
    for factor, peak in zip((D11, D12, D21, D22), (r11, r12, r21, r22))
)


# Uniform marginal weights on the n source points and the m target points.
a = np.full(n, 1 / n)
b = np.full(m, 1 / m)

#############################################
# Experiment configuration. The sweep grid and solver settings below are
# forwarded to CodeGW.Lin_LGW_MD by the loop that follows.


# --- Sweep grid
list_ranks = [50]  # other grid used: [10, 50, 100]
list_gammas = [1000, 500, 250, 100, 50, 30, 10, 1]

# --- Iteration / time budget (the solver is called with max_iter - 1)
max_iter = 500
time_out = 50

# --- Solver options
gamma_init = "arbitrary"
method = "Dykstra"  # alternative: 'Dykstra_LSE'
alpha = 1e-10
Init = "lower_bound"
seed_init = 10
reg_init = 1e-1
delta = 1e-100
delta_IBP = 1e-3
C_init = True

# Rescaled factorized costs of X and Y, passed as the solver's cost argument.
cost_SE = (D11, D12, D21, D22)
# cost_SE = (D1,D2)

# --- Output buffers. The axis of length 7 indexes the regularisation values
# tried for each gamma_0 in the sweep.
#   results_gamma_vs_epsilons[rank, gamma_0, reg] <- res_1 returned by the solver
#   Data_LR[gamma_0, reg, :max_iter]             <- num_ops_LR, padded
#   Data_LR[gamma_0, reg, max_iter:]             <- acc_LR, padded
# NOTE(review): Data_LR has no rank axis. With several ranks in list_ranks,
# later ranks overwrite the rows of earlier ones.
n_gammas = len(list_gammas)
results_gamma_vs_epsilons = np.zeros((len(list_ranks), n_gammas, 7))
Data_LR = np.zeros((n_gammas, 7, max_iter * 2))
del n_gammas

def _pad_with_last(values, length):
    """Return ``values`` as a float array of ``length`` entries, padded with its last entry.

    An empty ``values`` yields all zeros, matching a ``Data_LR`` row that was
    never written.

    Raises:
        ValueError: if ``values`` has more than ``length`` entries.
    """
    values = np.asarray(values, dtype=float)
    count = values.shape[0]
    if count > length:
        raise ValueError(f"got {count} entries but only {length} slots are available")
    padded = np.zeros(length)
    if count:
        padded[:count] = values
        # Hold the final value so every row has the same length in the CSV.
        padded[count:] = values[-1]
    return padded


# Elapsed wall time of the whole sweep; perf_counter is the clock meant for durations.
start = time.perf_counter()
for ind_rank, rank in enumerate(list_ranks):
    for ind_gamma_0, gamma_0 in enumerate(list_gammas):
        # Regularisation values tried for this gamma_0, from 1/gamma_0 down to 0.
        list_epsilons = [
            1 / (gamma_0),
            1 / (2 * gamma_0),
            1 / (10 * gamma_0),
            1 / (50 * gamma_0),
            1 / (100 * gamma_0),
            1 / (1000 * gamma_0),
            0,
        ]
        for ind_reg, reg in enumerate(list_epsilons):
            results = CodeGW.Lin_LGW_MD(
                X,
                Y,
                a,
                b,
                rank,
                reg,
                alpha,
                cost_SE,
                C_init=C_init,
                Init=Init,
                seed_init=seed_init,
                reg_init=reg_init,
                gamma_init=gamma_init,
                gamma_0=gamma_0,
                method=method,
                max_iter=max_iter - 1,
                delta=delta,
                max_iter_IBP=10000,
                delta_IBP=delta_IBP,
                lam_IBP=0,
                time_out=time_out,
            )

            # The solver signals failure with the string "Error". Report it
            # instead of silently leaving zeros in the result buffers.
            if isinstance(results, str) and results == "Error":
                print(
                    f"Lin_LGW_MD returned 'Error': rank = {rank}, "
                    f"gamma_0 = {gamma_0}, reg = {reg}"
                )
                continue

            res_1, acc_LR, times_LR, num_ops_LR, Couplings_LR = results
            results_gamma_vs_epsilons[ind_rank, ind_gamma_0, ind_reg] = res_1

            # Each series is padded with its own length. The acc_LR slice
            # previously used len(num_ops_LR), which breaks when the two differ.
            Data_LR[ind_gamma_0, ind_reg, :max_iter] = _pad_with_last(
                num_ops_LR, max_iter
            )
            Data_LR[ind_gamma_0, ind_reg, max_iter : 2 * max_iter] = _pad_with_last(
                acc_LR, max_iter
            )

        print(f"rank = {rank}, gamma_0 = {gamma_0}")

end = time.perf_counter()
tot_time = end - start
print(tot_time)

# Multipliers k with reg = k / gamma_0, in the same order as the regularisation
# values tried in the sweep (second axis of Data_LR).
list_to_multiply_for_reg = [1, 1 / 2, 1 / 10, 1 / 50, 1 / 100, 1 / 1000, 0]

# One line per (gamma_0, reg) pair:
#   gamma_0: <g>,reg: <r>,<Data_LR[..., :max_iter]>,<Data_LR[..., max_iter:2*max_iter]>
with open("acc_OT_vs_time_LR.csv", "w") as csv_out:
    for gamma_idx, gamma_val in enumerate(list_gammas):
        for reg_idx in range(7):
            row = Data_LR[gamma_idx, reg_idx]
            first_half = ",".join(map(str, row[:max_iter]))
            second_half = ",".join(map(str, row[max_iter : 2 * max_iter]))
            reg_val = list_to_multiply_for_reg[reg_idx] * (1 / gamma_val)

            csv_out.write(
                f"gamma_0: {gamma_val!s},reg: {reg_val!s},{first_half},{second_half}\n"
            )
            csv_out.flush()


# Store the res_1 grid (rank x gamma_0 x regularisation index) as a .npy
# file. The name already ends in .npy, so np.save appends no extension.
np.save("results_2_mixtures_rank_50.npy", results_gamma_vs_epsilons)
