# Compare retractions (qr, exp, polar, cayley) used by RGD and RSDM on a PCA objective.



from time import time

import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import linalg
import random
import numpy as np


from baselines import RGD_GEN, TSD, Landing, PCAL
from proposed import  RSDM_GEN

import pickle

import os
# NOTE(review): KMP_DUPLICATE_LIB_OK is an Intel OpenMP (libiomp) runtime setting;
# presumably set here to silence the "runtime already initialized" abort that occurs
# when torch and numpy/MKL each load their own OpenMP copy. It suppresses the check
# rather than removing the duplicate — confirm it is still needed on your setup.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"



def pca_loss(X, mat=None):
    """Return the PCA objective ``-trace(X^T A X) / 2`` evaluated at ``X``.

    Args:
        X: 2-D tensor whose columns span the candidate subspace (presumably
            ``(n, p)`` with orthonormal columns — not enforced here).
        mat: square matrix ``A`` defining the objective (presumably symmetric).
            When omitted, the module-level ``A`` set up by the experiment driver
            is used, so existing one-argument calls behave exactly as before.

    Returns:
        0-dim tensor holding the loss; it stays on the autograd graph when
        ``X`` requires grad, so callers can ``.backward()`` through it.
    """
    # Only touch the module global when no matrix is supplied explicitly.
    A_mat = A if mat is None else mat
    return -torch.trace(X.transpose(-1, -2) @ A_mat @ X) / 2

def pca_optgap(X):
    """Return the relative optimality gap ``|f(X) - f*| / |f*|``.

    ``f`` is ``pca_loss`` and ``f*`` is the module-level ``loss_star`` (the
    optimal loss value), which must be defined before this is called. The loss
    is converted with ``.item()``, so the result is detached from autograd.
    """
    current = pca_loss(X).item()
    gap = abs(current - loss_star)
    return gap / abs(loss_star)



def load_args(case):
    """Return the experiment configuration for the numbered ``case``.

    Returns a tuple, in order: ``n``, ``p``, ``lowdim``, ``method_names``,
    ``methods``, ``learning_rates``, ``n_epochs``, ``extra_args``, ``retrs``,
    ``retrs_names``. The per-run lists are parallel: index ``i`` of each one
    describes run ``i``.

    Raises:
        ValueError: if ``case`` is not a configured experiment. (Previously an
            unknown case fell through to the ``return`` and failed with an
            ``UnboundLocalError``.)
    """
    if case == 1:
        # THIS
        n = 2000
        p = 1500
        lowdim = 700  # forwarded as ``r=`` to RSDM by the driver
        # Every retraction is run once with RGD (first four runs) and once with
        # RSDM (last four runs).
        retrs_names = ["qr", "exp", "polar", "cayley"] * 2  # legend labels
        retrs = ["QR", "EXP", "POL", "CAY"] * 2  # values passed as ``retraction=``
        method_names = ["RGD"] * 4 + ["RSDM"] * 4
        methods = [RGD_GEN] * 4 + [RSDM_GEN] * 4
        learning_rates = [0.1] * 4 + [1.5] * 4
        n_epochs = [1000] * 8

        # Not read anywhere else in this script (presumably Landing/PCAL weights).
        extra_args = {"lam_land": 1, "lam_pcal": 1}

    # NOTE(review): stale configuration — it references RGD_QR and RSDM, which are
    # not imported, and defines no retrs/retrs_names; update before re-enabling.
    # elif case == 2:
    #     n = 3000
    #     p = 2500
    #     lowdim = 1000
    #     method_names = ["RGD", "Landing", "PCAL", "RSDM-P", "RSDM-O"]
    #     methods = [RGD_QR, Landing, PCAL, RSDM, RSDM]
    #     learning_rates = [0.1, 0.1, 0.05, 1.5, 1.5]
    #     n_epochs = [1000, 1000, 1000, 1000, 1000]
    #
    #     extra_args = {"lam_land": 5, "lam_pcal": 5}
    else:
        raise ValueError(f"Unknown experiment case: {case!r}")

    return n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args, retrs, retrs_names



def dist2ortho_(X):
    """Return ``||X^T X - I||_F``, the distance of ``X`` from having orthonormal columns.

    ``X`` is a 2-D tensor; the identity has ``X.shape[1]`` rows and matches
    ``X``'s device and dtype. The result is a 0-dim tensor, zero when the
    columns of ``X`` are exactly orthonormal.
    """
    n_cols = X.shape[1]
    identity = torch.eye(n_cols, device=X.device, dtype=X.dtype)
    gram = X.t() @ X
    return (gram - identity).norm()



if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Seed every RNG in use so the random problem instance and runs are reproducible.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    (n, p, lowdim, method_names, methods, learning_rates, n_epochs,
     extra_args, retrs, retrs_names) = load_args(1)
    results_path = f'results/pca_{n}_{p}_{lowdim}_retr.pkl'

    # Set to True to re-run every optimiser and overwrite the cached results.
    # By default the script only re-plots what is already stored at results_path.
    RUN_EXPERIMENTS = False

    if RUN_EXPERIMENTS:
        # Random symmetric test matrix with a log-spaced spectrum (condition number CN).
        CN = 1000
        D = 10 * torch.diag(torch.logspace(-np.log10(CN), 0, n))
        Q, _ = linalg.qr(torch.randn(n, n))
        # NOTE(review): the earlier (commented-out) version built ``Q @ D @ Q``;
        # ``Q @ D @ Q.t()`` is the similarity transform that gives A the spectrum
        # of D. Results cached before this change may not be reproducible.
        A = Q @ D @ Q.t()
        A = (A + A.t()) / 2
        A = A.to(device)
        init_weights = linalg.qr(torch.randn(n, p))[0]

        # Closed-form optimum from the eigendecomposition, used only for monitoring.
        # A and loss_star are module globals read by pca_loss / pca_optgap.
        _, w_star = torch.linalg.eigh(A / 2)
        w_star = w_star[:, -p:]
        loss_star = pca_loss(w_star).item()

        log_fmt = "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e, dist2ortho: %.2e"
        time_epochs_all, losses_epochs_all, optgap_epochs_all = [], [], []

        for method_name, method, learning_rate, n_epoch, retr in zip(
                method_names, methods, learning_rates, n_epochs, retrs):
            # Every run starts from the same orthonormal initial point.
            W = nn.Parameter(init_weights.clone().to(device))
            print(method_name, retr)

            if method_name == 'RGD':
                optimizer = method([W], learning_rate, retraction=retr)
            elif method_name == 'RSDM':
                optimizer = method([W], learning_rate, r=lowdim, use_permutation=True, retraction=retr)
            else:
                raise ValueError(f"Unsupported method: {method_name!r}")

            # Metrics at the initial point (recorded with epoch time 0).
            loss = pca_loss(W.data).cpu().item()
            dist2opt = pca_optgap(W.data)
            losses, time_epochs, optgap_epochs = [loss], [0], [dist2opt]
            print(log_fmt % (method_name, time_epochs[-1], loss, dist2opt, dist2ortho_(W.data)))

            for epoch in range(n_epoch):
                t0 = time()
                optimizer.zero_grad()
                pca_loss(W).backward()
                optimizer.step()

                # The timed interval also covers this evaluation; .item() forces a device sync.
                loss = pca_loss(W.data).cpu().item()
                dist2opt = pca_optgap(W.data)
                time_epochs.append(time() - t0)
                losses.append(loss)
                optgap_epochs.append(dist2opt)
                print(log_fmt % (method_name, time_epochs[-1], loss, dist2opt, dist2ortho_(W.data)))

                if dist2opt < 1e-6:
                    print(f"Tolerance reached. Break at iter {epoch}!")
                    break

            time_epochs_all.append(time_epochs)
            losses_epochs_all.append(losses)
            optgap_epochs_all.append(optgap_epochs)

        os.makedirs(os.path.dirname(results_path), exist_ok=True)
        with open(results_path, 'wb') as f:
            pickle.dump({'times': time_epochs_all, 'loss': losses_epochs_all,
                         'optgap': optgap_epochs_all, 'retrs': retrs_names}, f)

    # NOTE: pickle can execute arbitrary code on load — only open result files
    # produced by this script.
    with open(results_path, 'rb') as f:
        all_results = pickle.load(f)
    time_epochs_all = all_results['times']
    optgap_epochs_all = all_results['optgap']

    # zip() below would silently drop or mislabel curves if the cached file came
    # from a different configuration, so fail loudly instead.
    if len(time_epochs_all) != len(method_names) or len(optgap_epochs_all) != len(method_names):
        raise ValueError(
            f"{results_path} holds {len(optgap_epochs_all)} runs but load_args(1) "
            f"configures {len(method_names)}; re-run with RUN_EXPERIMENTS = True."
        )

    # Same colour per retraction; solid lines for the RGD runs, dash-dot for RSDM.
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"] * 2
    linestyles = ['-'] * 4 + ['-.'] * 4

    plt.figure(figsize=(5.5, 4.5))
    for method_name, time_epochs, optgaps, retr, color, linestyle in zip(
            method_names, time_epochs_all, optgap_epochs_all, retrs_names, colors, linestyles):
        plt.semilogy(
            np.cumsum(time_epochs),  # elapsed wall-clock time at each recorded iterate
            optgaps,
            label=f'{method_name} ({retr})',
            color=color,
            linestyle=linestyle,
            linewidth=2.5,
        )
    plt.legend(loc='upper right', prop={'size': 16})
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.xlabel("Time", fontsize=20)
    plt.ylabel("Optimality Gap", fontsize=20)
    plt.tight_layout()
    plt.savefig(f'pca_{n}_{p}_{lowdim}_retr.pdf', bbox_inches='tight')
    plt.close()