# Test on repetition (repeated runs over several random seeds).


from time import time

import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import linalg
import random
import numpy as np
import pickle

from baselines import RGD_QR, TSD, Landing, PCAL
from proposed import RSDM

import os
# Allow more than one copy of the Intel OpenMP runtime to be loaded in this
# process. NOTE(review): presumably a workaround for the "OMP: Error #15"
# abort when several imported libraries each bundle libiomp — confirm it is
# still needed on the target machine.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")


def pca_loss(X, matrix=None):
    """Return the PCA objective ``-trace(X^T A X) / 2`` as a 0-dim tensor.

    Args:
        X: 2-D tensor whose columns are the candidate directions
            (``torch.trace`` requires ``X^T A X`` to be a matrix).
        matrix: the square matrix ``A`` in the objective. When omitted, the
            module-level global ``A`` is used; that global is only assigned
            by the (commented-out) experiment runner under ``__main__``, so
            without it a call with ``matrix=None`` raises ``NameError``.

    Returns:
        A scalar tensor that stays attached to autograd when ``X`` is.
    """
    # Only touch the global when no explicit matrix is supplied, so the
    # function is usable (and testable) without the script's global state.
    A_mat = A if matrix is None else matrix
    return -torch.trace(X.transpose(-1, -2) @ A_mat @ X) / 2

def pca_optgap(X):
    """Relative optimality gap ``|f(X) - f*| / |f*|`` of the PCA objective.

    ``f`` is ``pca_loss`` and ``f*`` is the module-level global ``loss_star``.
    That global, like the global ``A`` read by ``pca_loss``, is only assigned
    by the commented-out experiment runner under ``__main__``.

    Args:
        X: tensor accepted by ``pca_loss``.

    Returns:
        The gap as a Python number (computed from ``Tensor.item()``).
    """
    current = pca_loss(X).item()
    reference = loss_star
    return abs(current - reference) / abs(reference)



def load_args(case):
    """Return the hyper-parameter configuration for experiment ``case``.

    Args:
        case: configuration id; only ``1`` and ``2`` are defined.

    Returns:
        Tuple ``(n, p, lowdim, method_names, methods, learning_rates,
        n_epochs, extra_args)``. ``n``/``p``/``lowdim`` are the problem sizes
        (they also appear in the results/figure file names).
        ``method_names``, ``methods``, ``learning_rates`` and ``n_epochs``
        are parallel lists, one entry per optimizer. ``extra_args`` holds
        optional penalty weights keyed by ``"lam_land"`` / ``"lam_pcal"``.

    Raises:
        ValueError: if ``case`` is not a known configuration. Previously this
            surfaced as an ``UnboundLocalError`` at the ``return``.
    """
    if case == 1:
        # Configuration currently selected by __main__ (load_args(1)).
        n = 2000
        p = 1500
        lowdim = 700
        method_names = ["RGD", "RSDM-P", "RSDM-O"]
        methods = [RGD_QR, RSDM, RSDM]
        learning_rates = [0.1, 1.5, 1.5]
        n_epochs = [700, 1000, 800]

        extra_args = {"lam_land": 1, "lam_pcal": 1}

    elif case == 2:
        n = 2000
        p = 1000
        lowdim = 700
        method_names = ["RGD", "RSDM"]
        methods = [RGD_QR, RSDM]
        learning_rates = [0.1, 1.5]
        n_epochs = [1500, 1500]

        extra_args = {"lam_land": 1, "lam_pcal": 5}

    else:
        raise ValueError(f"unknown configuration case {case!r}; expected 1 or 2")

    return n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args







if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args = load_args(1)

    # ------------------------------------------------------------------
    # Experiment runner (disabled).
    # Uncommenting the code below reruns each configured method for
    # `n_repeats` seeds and writes results/pca_{n}_{p}_{lowdim}_rep.pkl,
    # the file the plotting section below reads. It also assigns the
    # module globals `A` and `loss_star` that pca_loss / pca_optgap read.
    # NOTE(review): `A = Q @ D @ Q` was probably meant to be `Q @ D @ Q.T`;
    # as written, A's spectrum is not the one prescribed by D — confirm.
    # NOTE(review): results are only stored for 'RGD', 'RSDM-P' and
    # 'RSDM-O'; any other method in `method_names` is run but not saved.
    # ------------------------------------------------------------------
    # CN = 1000
    # D = 10 * torch.diag(torch.logspace(-np.log10(CN), 0, n))
    # [Q, R] = linalg.qr(torch.randn(n,n))
    # A = Q @ D @ Q
    # A = (A + A.t())/2
    #
    # A = A.to(device)
    # init_weights = linalg.qr(torch.randn(n, p))[0]
    #
    # # Compute closed-form solution from svd, used for monitoring.
    # [_, w_star] = torch.linalg.eigh(A/(2))
    # w_star = w_star[:,-p:]
    # loss_star = pca_loss(w_star)
    # loss_star = loss_star.item()
    #
    # seed = 42
    # n_repeats = 5
    #
    # time_epochs_all_rgd = []
    # losses_epochs_all_rgd = []
    # optgap_epochs_all_rgd = []
    #
    # time_epochs_all_rsdm = []
    # losses_epochs_all_rsdm = []
    # optgap_epochs_all_rsdm = []
    #
    # time_epochs_all_rsdm_o = []
    # losses_epochs_all_rsdm_o = []
    # optgap_epochs_all_rsdm_o = []
    #
    # for n_rep in range(n_repeats):
    #     random.seed(seed)
    #     np.random.seed(seed)
    #     torch.manual_seed(seed)
    #     torch.cuda.manual_seed(seed)
    #     torch.cuda.manual_seed_all(seed)
    #     seed += 1
    #
    #     init_weights = linalg.qr(torch.randn(n, p))[0]
    #
    #     for method_name, method, learning_rate, n_epoch in zip(method_names, methods, learning_rates, n_epochs):
    #
    #         W = nn.Parameter(torch.empty(n, p))
    #         W.data = init_weights.clone().to(device)
    #
    #         if method_name in ['RCD', 'TSD']:
    #             # optimizer = method([W], learning_rate, n, p)
    #             optimizer = method([W], learning_rate, numupdate=100)
    #         elif method_name == 'Landing':
    #             optimizer = method([W], learning_rate, extra_args["lam_land"])
    #         elif method_name == 'PLAM':
    #             optimizer = method([W], learning_rate, lambda_regul=0.1)
    #         elif method_name == 'RSDM-P':
    #             optimizer = method([W], learning_rate, r=lowdim, use_permutation=True)
    #         elif method_name == 'RSDM-O':
    #             optimizer = method([W], learning_rate, r=lowdim, use_permutation=False)
    #         elif method_name == 'PCAL':
    #             optimizer = method([W], learning_rate, extra_args["lam_pcal"])
    #         else:
    #             optimizer = method([W], learning_rate)
    #
    #         # init
    #         loss = pca_loss(W.data).cpu().item()
    #         dist2opt = pca_optgap(W.data)
    #         losses = [loss]
    #         time_epochs = [0]
    #         optgap_epochs = [dist2opt]
    #         print(
    #             "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e"
    #             % (method_name, time_epochs[-1], loss, dist2opt)
    #         )
    #
    #         for epoch in range(n_epoch):
    #             # train
    #             t0 = time()
    #
    #             if method_name == 'RCD':
    #
    #                 optimizer.zero_grad()
    #                 loss = pca_loss(W)
    #                 loss.backward()
    #                 optimizer.step()
    #
    #             else:
    #                 optimizer.zero_grad()
    #                 loss = pca_loss(W)
    #                 loss.backward()
    #                 optimizer.step()
    #
    #             # test
    #             loss = pca_loss(W.data).cpu().item()
    #             dist2opt = pca_optgap(W.data)
    #             time_epochs.append(time() - t0)
    #             losses.append(loss)
    #             optgap_epochs.append(dist2opt)
    #             print(
    #                 "|%s|    time for an epoch : %.1e sec, Loss: %.2e, optgap: %.2e"
    #                 % (method_name, time_epochs[-1], loss, dist2opt)
    #             )
    #
    #             # if dist2opt < 1e-6:
    #             #     print(f"Tolerance reached. Break at iter {epoch}!")
    #             #     break
    #
    #         if method_name == 'RGD':
    #             time_epochs_all_rgd.append(time_epochs)
    #             losses_epochs_all_rgd.append(losses)
    #             optgap_epochs_all_rgd.append(optgap_epochs)
    #
    #         elif method_name == 'RSDM-P':
    #             time_epochs_all_rsdm.append(time_epochs)
    #             losses_epochs_all_rsdm.append(losses)
    #             optgap_epochs_all_rsdm.append(optgap_epochs)
    #
    #         elif method_name == 'RSDM-O':
    #             time_epochs_all_rsdm_o.append(time_epochs)
    #             losses_epochs_all_rsdm_o.append(losses)
    #             optgap_epochs_all_rsdm_o.append(optgap_epochs)
    #
    #
    #
    # time_epochs_rgd = np.array(time_epochs_all_rgd)
    # loss_epochs_rgd = np.array(losses_epochs_all_rgd)
    # optgap_epochs_rgd = np.array(optgap_epochs_all_rgd)
    #
    # time_epochs_rgd_mean = time_epochs_rgd.mean(axis=0)
    # optgap_epochs_rgd_mean = optgap_epochs_rgd.mean(axis=0)
    # optgap_epochs_rgd_std = optgap_epochs_rgd.std(axis=0)
    #
    # time_epochs_rsdm = np.array(time_epochs_all_rsdm)
    # loss_epochs_rsdm = np.array(losses_epochs_all_rsdm)
    # optgap_epochs_rsdm = np.array(optgap_epochs_all_rsdm)
    #
    # time_epochs_rsdm_mean = time_epochs_rsdm.mean(axis=0)
    # optgap_epochs_rsdm_mean = optgap_epochs_rsdm.mean(axis=0)
    # optgap_epochs_rsdm_std = optgap_epochs_rsdm.std(axis=0)
    #
    # time_epochs_rsdm_o = np.array(time_epochs_all_rsdm_o)
    # loss_epochs_rsdm_o = np.array(losses_epochs_all_rsdm_o)
    # optgap_epochs_rsdm_o = np.array(optgap_epochs_all_rsdm_o)
    #
    # time_epochs_rsdm_mean_o = time_epochs_rsdm_o.mean(axis=0)
    # optgap_epochs_rsdm_mean_o = optgap_epochs_rsdm_o.mean(axis=0)
    # optgap_epochs_rsdm_std_o = optgap_epochs_rsdm_o.std(axis=0)
    #
    # arrays = {'time_rgd': time_epochs_rgd_mean, 'optgap_rgd_mean': optgap_epochs_rgd_mean, 'optgap_rgd_std': optgap_epochs_rgd_std,
    #           'time_rsdm': time_epochs_rsdm_mean, 'optgap_rsdm_mean': optgap_epochs_rsdm_mean, 'optgap_rsdm_std': optgap_epochs_rsdm_std,
    #           'time_rsdm_o': time_epochs_rsdm_mean_o, 'optgap_rsdm_mean_o': optgap_epochs_rsdm_mean_o, 'optgap_rsdm_std_o': optgap_epochs_rsdm_std_o,}
    #
    # # Save the arrays to a pickle file
    # with open(f'results/pca_{n}_{p}_{lowdim}_rep.pkl', 'wb') as f:
    #     pickle.dump(arrays, f)
    #
    #
    #
    #
    # #%%

    # n, p, lowdim, method_names, methods, learning_rates, n_epochs, extra_args = load_args(1)

    # One entry per curve, drawn in this order:
    #   (legend label, colour, pickle key of the mean per-epoch time,
    #    key of the mean optgap, key of the optgap std,
    #    number of trailing epochs to drop from the plot).
    # The drop counts are hard-coded for the run made with load_args(1);
    # presumably they trim each curve's tail — revisit if n_epochs changes.
    curves = [
        ('RGD', 'tab:blue', 'time_rgd', 'optgap_rgd_mean', 'optgap_rgd_std', 0),
        ('RSDM-P', 'tab:purple', 'time_rsdm', 'optgap_rsdm_mean', 'optgap_rsdm_std', 140),
        ('RSDM-O', 'tab:pink', 'time_rsdm_o', 'optgap_rsdm_mean_o', 'optgap_rsdm_std_o', 20),
    ]

    results_path = f'results/pca_{n}_{p}_{lowdim}_rep.pkl'
    try:
        # Only load pickles you produced yourself (written by the runner above).
        with open(results_path, 'rb') as f:
            loaded_arrays = pickle.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{results_path} not found; re-enable the experiment runner in "
            f"this script to generate it"
        ) from err

    plt.figure(figsize=(5.5, 4.5))
    for label, color, time_key, mean_key, std_key, n_drop in curves:
        # Guard n_drop == 0: a plain [:-0] slice would keep nothing.
        keep = slice(None, -n_drop if n_drop else None)
        # Stored times are per-epoch durations averaged over repeats, so the
        # cumulative sum gives mean elapsed time at each epoch.
        elapsed = loaded_arrays[time_key][keep].cumsum()
        mean = loaded_arrays[mean_key][keep]
        std = loaded_arrays[std_key][keep]
        plt.plot(elapsed, mean, label=label, color=color, linewidth=2.5)
        # NOTE: on the log y-axis, wherever std >= mean the lower band edge is
        # non-positive and matplotlib clips it rather than drawing it.
        plt.fill_between(elapsed, mean - std, mean + std, color=color, alpha=0.3)
    plt.yscale('log')
    plt.legend(loc='upper right', prop={'size': 16})
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.xlabel("Time", fontsize=20)
    plt.ylabel("Optimality Gap", fontsize=20)
    plt.tight_layout()
    plt.savefig(f'pca_{n}_{p}_{lowdim}_rep.pdf', bbox_inches='tight')
    plt.close()