# ablation on different low-dimension



from time import time

import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import linalg
import random
import numpy as np


from baselines import RGD_QR, TSD, Landing, PCAL
from proposed import RSDM

import pickle

from matplotlib.lines import Line2D


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


def pca_loss(X, mat=None):
    """Return the PCA objective ``-tr(X^T A X) / 2``.

    For a symmetric ``A``, minimising this over matrices with orthonormal
    columns maximises the quadratic form captured by ``span(X)``.

    Args:
        X: 2-D tensor of shape ``(n, p)`` (``torch.trace`` requires the
            product ``X^T A X`` to be 2-D).
        mat: ``(n, n)`` tensor used as ``A``. When omitted, the module-level
            global ``A`` is used. That preserves the original call
            signature ``pca_loss(X)``.

    Returns:
        A 0-dim tensor, differentiable with respect to ``X``.
    """
    if mat is None:
        # Backward-compatible fallback: the script historically relied on a global A.
        mat = A
    return -torch.trace(X.transpose(-1, -2) @ mat @ X) / 2

def pca_optgap(X):
    """Relative optimality gap ``|f(X) - f*| / |f*|`` of the PCA objective.

    ``f`` is ``pca_loss``, and ``f*`` is the module-level global
    ``loss_star``. That global must be set (to the optimal loss value)
    before this is called.

    Returns:
        A Python float.
    """
    current_value = pca_loss(X).item()
    scale = abs(loss_star)
    return abs(current_value - loss_star) / scale



def load_args(case):
    """Return the run configuration for ablation ``case``.

    Only ``case == 1`` is defined. It is an ``n x p = 2000 x 1500`` problem
    with an RGD baseline followed by RSDM-P / RSDM-O pairs for
    ``r`` in {500, 600, 700, 800, 900}.

    Returns:
        ``(n, p, lowdim, method_names, methods, learning_rates, n_epochs,
        extra_args)``:

        - ``lowdim[i]`` is passed as RSDM's ``r`` argument. It is 0 for RGD,
          which does not use it.
        - ``methods[i]`` is the optimizer class for run ``i``.
        - ``n_epochs[i]`` is that run's iteration budget.
        - ``extra_args`` holds the ``lam_land`` / ``lam_pcal`` constants used
          by the Landing / PCAL baselines.

    Raises:
        ValueError: if ``case`` is unknown, or if the per-run lists have
            different lengths.
    """
    if case != 1:
        # Previously fell through and raised a confusing UnboundLocalError.
        raise ValueError(f"unknown configuration case: {case!r}")

    n = 2000
    p = 1500
    lowdim = [0, 500, 500, 600, 600, 700, 700, 800, 800, 900, 900]
    method_names = ["RGD", "RSDM-P", "RSDM-O", "RSDM-P", "RSDM-O", "RSDM-P", "RSDM-O", "RSDM-P", "RSDM-O", "RSDM-P", "RSDM-O"]
    methods = [RGD_QR, RSDM, RSDM, RSDM, RSDM, RSDM, RSDM, RSDM, RSDM, RSDM, RSDM]
    learning_rates = [0.1, 2, 2, 1.5, 2, 1.5, 1.5, 1.5, 1.5, 1, 1]
    n_epochs = [2000, 2000, 2000, 2000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]

    extra_args = {"lam_land": 1, "lam_pcal": 1}

    # The driver zips these lists together, so a length mismatch would
    # silently drop runs instead of failing.
    per_run = (lowdim, method_names, methods, learning_rates, n_epochs)
    if len({len(values) for values in per_run}) != 1:
        raise ValueError("per-run configuration lists must all have the same length")

    return n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args



if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Seed every RNG used by the experiment below so reruns are reproducible.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    n, p, lowdims, method_names, methods, learning_rates, n_epochs, extra_args = load_args(1)

    # ------------------------------------------------------------------
    # Experiment (disabled): runs every configured optimizer and pickles the
    # per-epoch traces to results/pca_{n}_{p}_vary_r.pkl. Only the plotting
    # further down is active; it reads that file. Uncomment to regenerate.
    # NOTE(review): `A = Q @ D @ Q` below looks like it was meant to be
    # `Q @ D @ Q.T`. As written, the symmetrized A generally does not have
    # the spectrum of D. Confirm before regenerating results.
    # ------------------------------------------------------------------
    # CN = 1000
    # D = 10 * torch.diag(torch.logspace(-np.log10(CN), 0, n))
    # [Q, R] = linalg.qr(torch.randn(n, n))
    # A = Q @ D @ Q
    # A = (A + A.t()) / 2
    #
    # A = A.to(device)
    # init_weights = linalg.qr(torch.randn(n, p))[0]
    #
    # # Compute closed-form solution from svd, used for monitoring.
    # [_, w_star] = torch.linalg.eigh(A / (2))
    # w_star = w_star[:, -p:]
    # loss_star = pca_loss(w_star)
    # loss_star = loss_star.item()
    # #
    # # lowdims = [900]
    # # method_names = ["RSDM-O"]
    # # methods = [RSDM]
    # # learning_rates = [1.5]
    # # repeat starts
    # seed = 42
    #
    # time_epochs_all_rgd = []
    # losses_epochs_all_rsdm = []
    #
    # time_epochs_all = []
    # losses_epochs_all = []
    # optgap_epochs_all = []
    #
    # colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:purple', 'tab:red', 'tab:brown']
    # markers = ['*', 'o', '+', 'd', '>']
    # markers = [''] * 6
    # for method_name, method, learning_rate, n_epoch, lowdim in zip(method_names, methods, learning_rates, n_epochs, lowdims):
    #
    #     W = nn.Parameter(torch.empty(n, p))
    #     W.data = init_weights.clone().to(device)
    #
    #     if method_name in ['RCD', 'TSD']:
    #         # optimizer = method([W], learning_rate, n, p)
    #         optimizer = method([W], learning_rate, numupdate=100)
    #     elif method_name == 'Landing':
    #         optimizer = method([W], learning_rate, extra_args["lam_land"])
    #     elif method_name == 'PLAM':
    #         optimizer = method([W], learning_rate, lambda_regul=0.1)
    #     elif method_name == 'RSDM-P':
    #         optimizer = method([W], learning_rate, r=lowdim, use_permutation=True)
    #     elif method_name == 'RSDM-O':
    #         optimizer = method([W], learning_rate, r=lowdim, use_permutation=False)
    #     elif method_name == 'PCAL':
    #         optimizer = method([W], learning_rate, extra_args["lam_pcal"])
    #     else:
    #         optimizer = method([W], learning_rate)
    #
    #     # init
    #     loss = pca_loss(W.data).cpu().item()
    #     dist2opt = pca_optgap(W.data)
    #     losses = [loss]
    #     time_epochs = [0]
    #     optgap_epochs = [dist2opt]
    #     print(
    #         "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e"
    #         % (method_name, time_epochs[-1], loss, dist2opt)
    #     )
    #
    #     for epoch in range(n_epoch):
    #         # train
    #         t0 = time()
    #
    #         if method_name == 'RCD':
    #
    #             optimizer.zero_grad()
    #             loss = pca_loss(W)
    #             loss.backward()
    #             optimizer.step()
    #
    #         else:
    #             optimizer.zero_grad()
    #             loss = pca_loss(W)
    #             loss.backward()
    #             optimizer.step()
    #
    #         # test
    #         loss = pca_loss(W.data).cpu().item()
    #         dist2opt = pca_optgap(W.data)
    #         time_epochs.append(time() - t0)
    #         losses.append(loss)
    #         optgap_epochs.append(dist2opt)
    #         print(
    #             "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e"
    #             % (method_name, time_epochs[-1], loss, dist2opt)
    #         )
    #
    #         if dist2opt < 1e-6:
    #             print(f"Tolerance reached. Break at iter {epoch}!")
    #             break
    #
    #     time_epochs_all.append(time_epochs)
    #     losses_epochs_all.append(losses)
    #     optgap_epochs_all.append(optgap_epochs)
    #
    # all_results = [lowdims, time_epochs_all, losses_epochs_all, optgap_epochs_all]
    #
    # with open(f'results/pca_{n}_{p}_vary_r.pkl', 'wb') as f:
    #     pickle.dump(all_results, f)

    # Load the saved traces. The layout is:
    # [lowdims, per-run epoch times, per-run losses, per-run optimality gaps].
    # pickle.load can execute arbitrary code, so only load files this script wrote.
    with open(f'results/pca_{n}_{p}_vary_r.pkl', 'rb') as f:
        all_results = pickle.load(f)
    lowdims = all_results[0]
    time_epochs_all = all_results[1]
    losses_epochs_all = all_results[2]
    optgap_epochs_all = all_results[3]

    # RSDM-P (solid) and RSDM-O (dashed) runs with the same r share a colour.
    colors = ['tab:blue', 'tab:orange', 'tab:orange', 'tab:green', 'tab:green', 'tab:purple', 'tab:purple', 'tab:red', 'tab:red', 'tab:brown', 'tab:brown']
    linestyles = ['-', '-', '--', '-', '--', '-', '--', '-', '--', '-', '--']

    # Derive the legend labels from the run configuration so each label always
    # matches its r. RGD has lowdim == 0 and gets no "(r)" suffix.
    labels = [
        name if lowdim == 0 else f"{name} ({lowdim})"
        for name, lowdim in zip(method_names, lowdims)
    ]

    plt.figure(figsize=(5.5, 4.5))
    for time_epochs, optgaps, color, label, linestyle in \
            zip(time_epochs_all, optgap_epochs_all, colors, labels, linestyles):
        # x-axis: cumulative elapsed time (per-epoch durations summed).
        plt.semilogy(
            torch.cumsum(torch.tensor(time_epochs), dim=0),
            optgaps,
            label=label,
            color=color,
            linewidth=2.5,
            linestyle=linestyle
        )

    plt.legend(loc=1, prop={'size': 13})
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.xlabel("Time", fontsize=20)
    plt.ylabel("Optimality Gap", fontsize=20)
    plt.tight_layout()
    plt.savefig(f'pca_{n}_{p}_vary_r.pdf', bbox_inches='tight')
    plt.close()