import os
import random
import re
import time
from argparse import ArgumentParser

import torch
import yaml
from torch.autograd import Variable

from utils.initial_x import getU, getV, getH
from utils.Jorthogonal_test import Jtest
from plot_fig.getfig import *
from plot_fig.getfigloglog import *

from algorithm.CSDM.CS_decomp_algorithm import CSDM_al
from algorithm.JOBCD.GS_JOBCD_algorithm import GS_JOBCD
from algorithm.ADMM.class_admm import admm
from algorithm.UMCM.class_precise_penalty import UMCM_al
from algorithm.JOBCD.JJOBCD_algorithm import JJOBCD

# Sort key: order file names by the integers embedded in them (e.g. "t2" before "t10").
def sort_by_number(filename):
    """Return each run of decimal digits in ``filename`` as an int, in order.

    Intended as a ``sorted`` key: lists of ints compare element-wise, so
    names sort numerically (``exp2`` before ``exp10``) rather than
    lexicographically. A name without digits maps to ``[]``.
    """
    return list(map(int, re.findall(r'\d+', filename)))

def _load_config(config_file):
    """Load one experiment's YAML config, closing the file afterwards.

    NOTE(review): ``yaml.FullLoader`` can construct Python objects from tagged
    YAML; only use this with trusted config files (``yaml.safe_load`` is the
    stricter alternative).
    """
    with open(config_file) as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def _random_j_orthogonal_start(d, p, tol=1e-9):
    """Sample a starting point ``X0 = U0 @ H0 @ V0`` with small J-orthogonality error.

    Small random parameters are drawn for ``getU`` / ``getH`` / ``getV`` and
    resampled until ``Jtest(X0, p) <= tol``.

    Args:
        d: matrix dimension read from ``datafeature.d``.
        p: split index read from ``datafeature.p``.
        tol: acceptance threshold on the value returned by ``Jtest``.

    Returns:
        ``(theta0, mu0, ksi0, X0)``. The first three are the parameter tensors
        (created with ``requires_grad=True``) that generated ``X0``; ``X0`` is
        float32.
    """
    dim_theta = int(d / 2)
    dim_mu = int(d - p)
    dim_ksi = int(d / 2)

    # Resample so the initial point does not start with a large
    # J-orthogonality error.
    # NOTE(review): there is no iteration cap; this loops forever if tol is never met.
    while True:
        theta0 = Variable(torch.randn(dim_theta) / 100000, requires_grad=True)
        # Clamp to [-1, 1] and scale down so X0 starts (near) strictly J-orthogonal.
        mu0 = Variable(torch.clamp(torch.randn(dim_mu), -1e0, 1e0) / 100, requires_grad=True)
        ksi0 = Variable(torch.randn(dim_ksi) / 100000, requires_grad=True)

        U0 = getU(d, p, theta0).to(torch.float32)
        H0 = getH(d, p, mu0).to(torch.float32)
        V0 = getV(d, p, ksi0).to(torch.float32)
        U0.retain_grad()
        V0.retain_grad()
        H0.retain_grad()

        X0 = torch.mm(U0, torch.mm(H0, V0)).to(torch.float32)
        if Jtest(X0, p) <= tol:
            return theta0, mu0, ksi0, X0


if __name__ == "__main__":
    timestr = time.strftime("%m%d-%H%M%S")
    config_path = "./config"

    # Parsed once. When --config is given, that file name is used inside every
    # group directory; otherwise each file in the group is run in turn.
    argp = ArgumentParser()
    argp.add_argument("--config", type=str, default=None, help="Config file name")
    args = argp.parse_args()

    t_names = sorted(os.listdir(config_path), key=sort_by_number)
    for tname in t_names:
        file_names = sorted(os.listdir(os.path.join(config_path, tname)), key=sort_by_number)
        for experimentname in file_names:
            config_name = args.config if args.config is not None else experimentname
            config_yaml = _load_config(os.path.join(config_path, tname, config_name))

            seed = config_yaml["run"]["seed"]
            torch.manual_seed(seed)
            random.seed(seed)

            d = config_yaml["datafeature"]["d"]
            p = config_yaml["datafeature"]["p"]
            truncationnum = int(float(config_yaml["datafeature"]["truncationnum"]))

            # The result is unused, but this draw advances the seeded torch RNG.
            # It is kept so all later random samples match earlier runs.
            torch.randn([d, d])

            # C = -(A^T A) for the tensor A loaded from datafeature.datadir.
            C = torch.load(config_yaml["datafeature"]["datadir"])
            C = -torch.mm(C.T, C).to(torch.float32)

            theta0, mu0, ksi0, X0 = _random_j_orthogonal_start(d, p)

            # Every method starts from its own detached copy of the same X0.
            # The run order is kept as-is in case the algorithms consume RNG.
            # GS-JOBCD
            GS_JOBCD_object = GS_JOBCD(X0.clone().detach(), config_yaml)
            hist_GSJOBCD, X_GSJOBCD, hist_GSJOBCDt, hist_GSerr = GS_JOBCD_object.train(C)

            # J-JOBCD
            JJOBCD_object = JJOBCD(X0.clone().detach(), config_yaml)
            hist_JJOBCD, X_JJOBCD, hist_JJOBCDt, hist_JJerr = JJOBCD_object.train(C)

            # UMCM
            UMCM_obj = UMCM_al(X0.clone().detach(), config_yaml, 1e-3)
            hist_UMCM, X_UMCM, hist_UMCMt, hist_UMCMerr = UMCM_obj.Train(C)

            # ADMM
            admm_obj = admm(X0.clone().detach(), config_yaml, 1e4, 'd')
            hist_admm, X_admm, hist_admmt, hist_admmerr = admm_obj.train(C)

            # CS decomposition: driven by the (theta, mu, ksi) parameters, not X0.
            CSDM_obj = CSDM_al(theta0, mu0, ksi0, config_yaml)
            hist_CSDM, X_CSDM, hist_CSDMt, hist_CSerr = CSDM_obj.train(C)

            # GS-JOBCD restarted from the ADMM / UMCM / CSDM solutions.
            JOBCD_ADMM = GS_JOBCD(X_admm.clone().detach(), config_yaml)
            JOBCD_UMCM = GS_JOBCD(X_UMCM.clone().detach(), config_yaml)
            JOBCD_CSDM = GS_JOBCD(X_CSDM.clone().detach(), config_yaml)

            hist_JOBCD_ADMM, X_JOBCD_ADMM, hist_JOBCDt_ADMM, hist_JADMMerr = JOBCD_ADMM.train(C)
            hist_JOBCD_UMCM, X_JOBCD_UMCM, hist_JOBCDt_UMCM, hist_JUMCMerr = JOBCD_UMCM.train(C)
            hist_JOBCD_CSDM, X_JOBCD_CSDM, hist_JOBCDt_CSDM, hist_JCSDMerr = JOBCD_CSDM.train(C)

            # (label, objective history, final iterate) in reporting order.
            results = [
                ("UMCM", hist_UMCM, X_UMCM),
                ("admm", hist_admm, X_admm),
                ("CSDM", hist_CSDM, X_CSDM),
                ("GS-JOBCD", hist_GSJOBCD, X_GSJOBCD),
                ("J-JOBCD", hist_JJOBCD, X_JJOBCD),
                ("GS-JOBCD+ADMM", hist_JOBCD_ADMM, X_JOBCD_ADMM),
                ("GS-JOBCD+UMCM", hist_JOBCD_UMCM, X_JOBCD_UMCM),
                ("GS-JOBCD+CSDM", hist_JOBCD_CSDM, X_JOBCD_CSDM),
            ]
            # Final objective and J-orthogonality error, computing Jtest once per method.
            summary = [(label, hist[-1], Jtest(X, p)) for label, hist, X in results]

            if config_yaml["run"]["showtrain"]:
                for label, obj, jerr in summary:
                    print('{} obj:{}'.format(label, obj))
                    print('{} J orthogonal:{}'.format(label, jerr))

            # Per-run log directory: <log.dir>/<savename>-<timestr>/
            ename = config_yaml["log"]["savename"]
            log_path = os.path.join(config_yaml["log"]["dir"], str(ename) + "-" + timestr)
            os.makedirs(log_path, exist_ok=True)

            log_file = os.path.join(log_path, str(ename) + "-" + timestr + ".log")
            print('{} done'.format(log_file))
            dataname = config_yaml["datafeature"]["name"]
            with open(log_file, "a") as f:
                f.write(f"experiment name: {ename}\n")
                f.write(f"Dataset name: {dataname}\n")
                f.write(f"Dimension d: {d} , p: {p}\n")
                for label, obj, jerr in summary:
                    f.write(label + " Obj value:" + "{:.2e}".format(obj) + "| J err: {:.1e} \n".format(jerr))

            # Gather one line per experiment into a shared per-group file.
            gather_path = config_yaml["log"]["gatherpath"]
            os.makedirs(gather_path, exist_ok=True)
            gather_file = os.path.join(gather_path, tname + "gather result-" + timestr + ".log")
            with open(gather_file, "a") as f:
                f.write(f"{ename} {dataname} ({d},{p}) :")
                f.write("".join("{:.2e}({:.1e})".format(obj, jerr) for _, obj, jerr in summary) + "\n")

            # The GS-JOBCD restarts (from ADMM/UMCM/CSDM) are not plotted.
            plott = [hist_UMCMt, hist_admmt, hist_CSDMt, hist_GSJOBCDt, hist_JJOBCDt]
            plotf = [hist_UMCM, hist_admm, hist_CSDM, hist_GSJOBCD, hist_JJOBCD]
            ploterr = [hist_UMCMerr, hist_admmerr, hist_CSerr, hist_GSerr, hist_JJerr]
            getfig(plott, plotf, ploterr, config_yaml, log_path, truncationnum)
            getfigloglog(plott, plotf, ploterr, config_yaml, log_path, truncationnum)

            # The dictionary keys (including 'hist_GS-JJOBCD') are kept unchanged
            # for existing readers of the .pt files.
            torch.save(
                {'hist_GS-JJOBCD': hist_GSJOBCD, 'hist_GS-JJOBCDt': hist_GSJOBCDt, 'hist_GSerr': hist_GSerr,
                 'hist_JJOBCD': hist_JJOBCD, 'hist_JJOBCDt': hist_JJOBCDt, 'hist_JJerr': hist_JJerr,
                 'hist_CSDM': hist_CSDM, 'hist_CSDMt': hist_CSDMt, 'hist_CSerr': hist_CSerr,
                 'hist_ADMM': hist_admm, 'hist_ADMMt': hist_admmt, 'hist_ADMMerr': hist_admmerr,
                 'hist_UMCM': hist_UMCM, 'hist_UMCMt': hist_UMCMt, 'hist_UMCMerr': hist_UMCMerr,
                 'hist_JOBCD_ADMM': hist_JOBCD_ADMM, 'hist_JOBCD_ADMMt': hist_JOBCDt_ADMM, 'hist_JOBCD_ADMMerr': hist_JADMMerr,
                 'hist_JOBCD_UMCM': hist_JOBCD_UMCM, 'hist_JOBCD_UMCMt': hist_JOBCDt_UMCM, 'hist_JOBCD_UMCMerr': hist_JUMCMerr,
                 'hist_JOBCD_CSDM': hist_JOBCD_CSDM, 'hist_JOBCD_CSDMt': hist_JOBCDt_CSDM, 'hist_JOBCD_CSDMerr': hist_JCSDMerr},
                os.path.join(log_path, os.path.splitext(experimentname)[0] + timestr + '.pt'))