
from time import time

import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import linalg
import random
import numpy as np


from baselines import RGD_QR, PCAL, TSD, Landing
from proposed import RSDM

import pickle

import os
# Allow the process to continue if more than one OpenMP runtime gets loaded.
# NOTE(review): this is the usual workaround for the Intel OpenMP
# "OMP: Error #15" abort (e.g. torch + MKL-linked numpy); it can hide a real
# library conflict, so confirm it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")


def procruste_loss(X):
    """Return the Procrustes objective ``||X @ A - B||_F^2 / (2 * p)``.

    ``A``, ``B`` and ``p`` are read from module globals at call time.
    Alternative (linear) objectives that were tried:
        - (X * BAt).sum() / p
        - torch.trace(X.transpose(-1, -2) @ BAt) / p
    """
    residual = torch.matmul(X, A) - B
    return torch.sum(residual * residual) / (2 * p)

# def procruste_infea_loss(X):
#     return torch.norm(X.mm(A) - B).pow(2)/(2*p)

def procruste_optgap(X):
    """Return the relative optimality gap ``|f(X) - f*| / |f*|``.

    ``f`` is ``procruste_loss``; ``f*`` is the module-global ``loss_star``
    (the loss at the closed-form solution). Assuming ``loss_star`` is a Python
    float, a value of 0 raises ZeroDivisionError.
    """
    gap = procruste_loss(X).item() - loss_star
    return abs(gap) / abs(loss_star)


def procruste_dist2ortho(X):
    """Return ``||X^T X - I||_F``, the feasibility violation of a 2-D ``X``.

    The identity is built with ``X``'s dtype and device so the subtraction
    never needs a cast or a device transfer.
    """
    k = X.shape[1]
    eye_k = torch.eye(k, dtype=X.dtype, device=X.device)
    deviation = torch.matmul(X.t(), X) - eye_k
    return deviation.norm()


def RCD_step(ii, jj, lr):
    """Rotate rows ``ii`` and ``jj`` of the global ``W`` by one Givens step.

    This is equivalent to left-multiplying ``W`` by a rotation acting on the
    (ii, jj) plane, so ``W^T W`` is preserved. The angle is a gradient step
    on the rotation generator ``E_ij - E_ji``. Only the ``-B A^T / p`` part of
    the Euclidean gradient is needed: the ``X A A^T`` term comes from
    ``tr(X A A^T X^T)``, which a left rotation leaves unchanged.

    Reads the globals ``W``, ``BAt`` and ``p``; mutates ``W.data`` in place.
    """
    X = W.data
    row_i = X[ii, :]
    row_j = X[jj, :]
    g_i = -BAt[ii, :] / p
    g_j = -BAt[jj, :] / p
    # Directional derivative of the loss along E_ij - E_ji.
    slope = torch.inner(g_i, row_j) - torch.inner(g_j, row_i)
    eta = -lr * slope
    c = torch.cos(eta)
    s = torch.sin(eta)
    # Both new rows are built before either is written back.
    new_i = c * row_i + s * row_j
    new_j = c * row_j - s * row_i
    X[ii, :] = new_i
    X[jj, :] = new_j


def TSD_step(ii, jj, lr):
    """Apply one coordinate step to columns of the global ``W``.

    * ``ii != jj``: rotate columns ``ii`` and ``jj`` in their common plane
      (right-multiplication by a Givens rotation), with the angle taken as a
      gradient step on the generator ``E_ij - E_ji``.
    * ``ii == jj``: move column ``ii`` along the great circle defined by the
      part of its gradient orthogonal to ``span(W)``. Assuming ``W`` has
      orthonormal columns, the column stays unit-norm and orthogonal to the
      other columns.

    Only ``-B A^T / p`` of the Euclidean gradient is used. The ``X A A^T``
    term comes from ``tr(X A A^T X^T) = tr(A A^T X^T X)``, which is constant
    while ``X^T X = I``.

    Reads the globals ``W``, ``BAt`` and ``p``; mutates ``W.data`` in place.
    If the projected step is exactly zero, ``W`` is left unchanged.
    (Previously this computed 0/0 and wrote NaN into ``W``.)
    """
    grad = -BAt / p
    X = W.data

    if ii != jj:
        gradi = grad[:, ii]
        gradj = grad[:, jj]
        UXtij = torch.inner(X[:, ii], gradj)
        UXtji = torch.inner(X[:, jj], gradi)

        eta = -lr * (UXtij - UXtji)
        # Both new columns are built before either is written back.
        vi = torch.cos(eta) * X[:, ii] - torch.sin(eta) * X[:, jj]
        vj = torch.sin(eta) * X[:, ii] + torch.cos(eta) * X[:, jj]

        X[:, ii] = vi
        X[:, jj] = vj
        return

    gradi = grad[:, ii]
    # Component of the gradient orthogonal to span(W).
    pgrad = gradi - X @ (X.t() @ gradi)
    step = -lr * pgrad
    step_norm = torch.norm(step)
    if step_norm == 0:
        # No tangent direction to move along; dividing by step_norm would
        # produce NaN and corrupt W.
        return
    X[:, ii] = torch.cos(step_norm) * X[:, ii] + torch.sin(step_norm) * (step / step_norm)



def load_args(case):
    """Return the configuration of Procrustes experiment ``case``.

    Args:
        case: experiment id. Cases 1-4 are fully configured; case 5 is
            incomplete.

    Returns:
        Tuple ``(n, p, lowdim, method_names, methods, learning_rates,
        n_epochs, extra_args)``. Here ``n``/``p`` are the shape of the
        optimized matrix, ``lowdim`` is the ``r`` handed to RSDM, and the three
        lists are index-aligned with ``method_names``. ``extra_args`` holds
        optional method-specific settings (``lam_land``, ``lam_pcal``,
        ``max_cd_update``).

    Raises:
        NotImplementedError: for case 5, whose learning rates and epoch
            counts were never filled in.
        ValueError: for any case other than 1-5.
        (Both previously surfaced as an opaque UnboundLocalError.)
    """
    extra_args = {}

    # lr selected every 0.5
    if case == 1:
        # NOTE(review): no extra_args["lam_land"] here, although the training
        # loop reads it when constructing the Landing optimizer.
        n = 1000
        p = 1000
        lowdim = 500
        method_names = ["RGD", "Landing", "RSDM"]
        methods = [RGD_QR, Landing, RSDM]
        learning_rates = [0.5, 0.5, 1.5]
        n_epochs = [2000, 2000, 2000]

    elif case == 2:
        # NOTE(review): same missing extra_args["lam_land"] as case 1.
        n = 2000
        p = 1000
        lowdim = 600
        method_names = ["RGD", "Landing", "RSDM"]
        methods = [RGD_QR, Landing, RSDM]
        learning_rates = [0.5, 0.5, 2]
        n_epochs = [2000, 3000, 3000]

    elif case == 3:
        # THIS
        n = 2000
        p = 2000
        lowdim = 900
        method_names = ["RGD", "Landing", "PCAL", "RSDM-P", "RSDM-O"]
        methods = [RGD_QR, Landing, PCAL, RSDM, RSDM]
        learning_rates = [0.5, 0.5, 0.1, 2, 2]
        n_epochs = [2000, 2000, 2000, 3000, 3000]
        extra_args["lam_land"] = 1
        extra_args["lam_pcal"] = 5

    elif case == 4:
        # THIS
        n = 200
        p = 200
        lowdim = 150
        method_names = ["RGD", "Landing", "PCAL", "RCD", "RSDM-P", "RSDM-O"]
        # "" is a placeholder: RCD is driven by RCD_step, not an optimizer class.
        methods = [RGD_QR, Landing, PCAL, "", RSDM, RSDM]
        learning_rates = [0.5, 0.5, 0.1, 1, 1, 1]
        n_epochs = [1000, 1500, 1500, 50, 1000, 1000]

        extra_args["lam_land"] = 1
        extra_args["lam_pcal"] = 5
        # Coordinate updates per RCD epoch: half of the n*(n-1)/2 row pairs.
        extra_args["max_cd_update"] = int(n * (n-1)/4)

    elif case == 5:
        # Partial configuration kept for whoever completes it.
        n = 200
        p = 100
        lowdim = 50
        method_names = ["RGD", "RCD", "TSD", "RSDM"]
        methods = [RGD_QR, None, None, RSDM]
        # learning_rates / n_epochs (and extra_args["max_cd_update"], which the
        # RCD/TSD loops read) were never set. Fail loudly instead of hitting
        # an UnboundLocalError at the return below.
        raise NotImplementedError(
            "load_args(5) is incomplete: learning_rates and n_epochs are not configured"
        )

    else:
        raise ValueError(f"unknown case {case!r}; expected 1, 2, 3, 4 or 5")

    return n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args



if __name__ == '__main__':
    # Script entry point. As checked in, the data-generation and training
    # sections below are commented out. A run therefore only reloads previously
    # pickled results from results/procruste_{n}_{p}_{lowdim}.pkl and plots them.
    # The helper functions above read A, B, BAt, loss_star and W as module
    # globals. Those names are only bound inside the commented-out section, so
    # it must be re-enabled before procruste_loss / procruste_optgap /
    # RCD_step / TSD_step can be called. (n, p, lowdim come from load_args below.)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')
    print(device)

    # Seed every RNG source used here (Python, NumPy, torch CPU and CUDA).
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # load args (case 3: n=2000, p=2000, lowdim=900)
    n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args = load_args(3)

    # # # create data
    # A = torch.randn(p, p).to(device)
    # B = torch.randn(n, p).to(device)
    # init_weights = linalg.qr(torch.randn(n, p))[0]
    #
    # # Compute closed-form solution from svd, used for monitoring.
    # BAt = B.matmul(A.transpose(-1, -2))
    # u, _, vh = torch.linalg.svd(BAt, full_matrices=False)
    # # u, _, v = torch.svd(BAt)
    # w_star = u.matmul(vh)
    # loss_star = procruste_loss(w_star)
    # loss_star = loss_star.item()

    # all_results = {}
    #
    # time_epochs_all = []
    # losses_epochs_all = []
    # optgap_epochs_all = []
    # colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:purple', 'tab:red']
    # markers = ['*', 'o', '+', 'd', '>']
    # markers = [''] * 5


    # tolerance = 1e-6
    #
    #
    # if not isinstance(n_epochs, list):
    #     n_epochs = [n_epochs] * len(method_names)
    #
    # NOTE(review): every method below restarts from the same init_weights.
    # The Landing / PCAL / RCD / TSD branches read extra_args["lam_land"],
    # ["lam_pcal"] and ["max_cd_update"]. load_args sets lam_* only in cases
    # 3-4 and max_cd_update only in case 4, so other cases raise KeyError here.
    # for method_name, method, learning_rate, n_epoch in zip(method_names, methods, learning_rates, n_epochs):
    #
    #     W = nn.Parameter(torch.empty(n, p))
    #     W.data = init_weights.clone().to(device)
    #
    #     if method_name == 'RCD':
    #         idx1, idx2 = np.triu_indices(n, k=1)
    #     elif method_name == 'TSD':
    #         idx1, idx2 = np.triu_indices(p)
    #     elif method_name == 'Landing':
    #         optimizer = method([W], learning_rate, extra_args["lam_land"])
    #     elif method_name == 'PLAM':
    #         optimizer = method([W], learning_rate, lambda_regul=0.1)
    #     elif method_name == 'RSDM-P':
    #         optimizer = method([W], learning_rate, r=lowdim, use_permutation=True)
    #     elif method_name == 'RSDM-O':
    #         optimizer = method([W], learning_rate, r=lowdim, use_permutation=False)
    #     elif method_name == 'PCAL':
    #         optimizer = method([W], learning_rate, extra_args["lam_pcal"])
    #     else:
    #         optimizer = method([W], learning_rate)
    #
    #     # init
    #     loss_ = procruste_loss(W.data).cpu().item()
    #     dist2opt = procruste_optgap(W.data)
    #     losses = [loss_]
    #     time_epochs = [0]
    #     flop_epochs = [0]
    #     optgap_epochs = [dist2opt]
    #     print(
    #         "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e"
    #         % (method_name, time_epochs[-1], loss_, dist2opt)
    #     )
    #
    #     counter = 0
    #     for epoch in range(n_epoch):
    #         # train
    #         t0 = time()
    #
    # NOTE(review): triu_indices(p) for TSD includes the diagonal, so some
    # TSD_step calls take its ii == jj (single-column geodesic) branch.
    #         if method_name == 'RCD':
    #             seqlist = np.random.permutation(int(n * (n-1)/2))
    #             seqlist = seqlist[:extra_args["max_cd_update"]]
    #             for k in seqlist:
    #                 ii = idx1[k]
    #                 jj = idx2[k]
    #                 RCD_step(ii, jj, learning_rate)
    #                 counter += 1
    #
    #         elif method_name == 'TSD':
    #             pp = int(p*(p+1)/2)
    #             seqlist = np.random.permutation(pp)
    #             seqlist = seqlist[:extra_args["max_cd_update"]]
    #             for k in seqlist:
    #                 ii = idx1[k]
    #                 jj = idx2[k]
    #                 TSD_step(ii,jj, learning_rate)
    #                 counter += 1
    #
    #         else:
    #             optimizer.zero_grad()
    #             loss_ = procruste_loss(W)
    #             loss_.backward()
    #             optimizer.step()
    #
    #         # test
    #         loss_ = procruste_loss(W.data).cpu().item()
    #         dist2opt = procruste_optgap(W.data)
    #         dist2ortho = procruste_dist2ortho(W.data)
    #         time_epochs.append(time() - t0)
    #         losses.append(loss_)
    #         optgap_epochs.append(dist2opt)
    #         print(
    #             "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e, dist2ortho: %.2e"
    #             % (method_name, time_epochs[-1], loss_, dist2opt, dist2ortho)
    #         )
    #         if dist2opt < tolerance:
    #             print(f"Tolerance reached at iter {epoch}, break!")
    #             break
    #
    #     all_results[method_name] = {'time': time_epochs, 'loss': losses, 'optgap': optgap_epochs}
    #
    #     # time_epochs_all.append(time_epochs)
    #     # losses_epochs_all.append(losses)
    #     # optgap_epochs_all.append(optgap_epochs)
    #
    # with open(f'results/procruste_{n}_{p}_{lowdim}.pkl', 'wb') as f:
    #     pickle.dump(all_results, f)


    #

    # Reload results: {method_name: {'time': [...], 'loss': [...], 'optgap': [...]}}
    # as written by the commented-out section above.
    # NOTE: pickle.load can execute arbitrary code; only load files this script produced.
    with open(f'results/procruste_{n}_{p}_{lowdim}.pkl', 'rb') as f:
        all_results = pickle.load(f)

    # Fixed per-method colours keep curves comparable across figures.
    # Methods missing from this map (e.g. "RSDM" in load_args cases 1-2, or
    # "TSD" in case 5) get color=None via colors_map.get(...), so matplotlib
    # picks the next colour from its default cycle instead of raising KeyError.
    colors_map = {"RGD": "tab:blue", "Landing": "tab:orange", "PCAL": "tab:green", "RCD": "tab:red",
                  "RSDM-P": "tab:purple", "RSDM-O": "tab:pink"}

    # runtime
    # plt.figure(figsize=(5.5, 4.5))
    # for method_name in all_results:
    #
    #     method_results = all_results[method_name]
    #
    #     plt.semilogy(
    #         torch.cumsum(torch.tensor(method_results['time']), dim=0),
    #         method_results['optgap'],
    #         label=method_name,
    #         color=colors_map[method_name],
    #         linewidth=2.5,
    #     )
    # plt.legend(loc=1, prop={'size': 16})
    # plt.xticks(fontsize=13)
    # plt.yticks(fontsize=13)
    # plt.xlabel("Time", fontsize=20)
    # plt.ylabel("Optimality Gap", fontsize=20)
    #
    # if n == 200:
    #     inset_axes = plt.axes((0.22, 0.17, 0.3, 0.3))
    #     for method_name in all_results:
    #         method_results = all_results[method_name]
    #         # plot zoomed in
    #         inset_axes.semilogy(torch.cumsum(torch.tensor(method_results['time']), dim=0), method_results['optgap'],
    #                         color=colors_map[method_name], linewidth=2)
    #     inset_axes.set_xlim(0., 2)  # Zoom in on the x-axis from 2 to 6
    #     inset_axes.set_ylim(3E-5, 1E-2)
    #     inset_axes.set_yticklabels([])
    #
    # # plt.tight_layout()
    # plt.savefig(f'st_{n}_{p}_{lowdim}.pdf', bbox_inches='tight')
    # # plt.show()
    # plt.close()


    # iteration: optimality gap against list index, on a log y-axis. In the
    # (commented-out) training loop, index 0 is the initial point and each
    # later entry is one epoch.
    plt.figure(figsize=(5.5, 4.5))
    for method_name, method_results in all_results.items():
        plt.semilogy(
            method_results['optgap'],
            label=method_name,
            color=colors_map.get(method_name),
            linewidth=2.5,
        )
    plt.legend(loc=1, prop={'size': 16})
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.xlabel("Iteration", fontsize=20)
    plt.ylabel("Optimality Gap", fontsize=20)
    # plt.tight_layout()
    plt.savefig(f'st_{n}_{p}_{lowdim}_iter.pdf', bbox_inches='tight')
    # plt.show()
    plt.close()