import os
import random
import re
import time
from argparse import ArgumentParser

import torch
import yaml
from torch.autograd import Variable

from utils.initial_x import getU, getV, getH
from utils.Jorthogonal_test import Jtest
from plot_fig.getfig import *

from algorithm.CSDM.CS_decomp_algorithm import CSDM_al
from algorithm.JOBCD.GS_JOBCD_algorithm import GS_JOBCD
from algorithm.ADMM.class_admm import admm
from algorithm.UMCM.class_precise_penalty import UMCM_al
from algorithm.JOBCD.JJOBCD_algorithm import JJOBCD

# Sort key: order file names by the integers embedded in them (so "exp2" sorts before "exp10").
def sort_by_number(filename):
    """Return the integers embedded in *filename*, in order of appearance.

    Meant as a ``sorted`` key: names are compared by their numeric parts,
    so ``"exp2"`` orders before ``"exp10"``. A name with no digits yields ``[]``.
    """
    digit_runs = re.findall(r'\d+', filename)
    return list(map(int, digit_runs))

# Labels of every solver result, in the order they are printed and logged.
_REPORT_ORDER = (
    "UMCM", "admm", "CSDM", "GS-JOBCD", "J-JOBCD",
    "GS-JOBCD+ADMM", "GS-JOBCD+UMCM", "GS-JOBCD+CSDM",
)
# Solvers whose histories are plotted (the GS-JOBCD restarts are not).
_PLOT_ORDER = ("UMCM", "admm", "CSDM", "GS-JOBCD", "J-JOBCD")


def _load_config(config_file):
    """Parse one experiment's YAML config file, closing the handle afterwards."""
    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # prefer yaml.safe_load if configs never rely on Python-specific tags.
    with open(config_file) as fh:
        return yaml.load(fh, Loader=yaml.FullLoader)


def _initial_point(d, p):
    """Sample the random starting point shared by all solvers.

    Args:
        d: matrix dimension.
        p: dimension parameter forwarded to getU/getH/getV; mu0 has length d - p.

    Returns:
        (theta0, mu0, ksi0, X0): three parameter vectors (leaf tensors with
        requires_grad=True, handed to CSDM_al) and the detached float32 matrix
        X0 = U0 @ H0 @ V0 built from them.
    """
    # This d x d draw is unused. It is kept only so the seeded RNG stream, and
    # therefore theta0/mu0/ksi0, stays identical to previously recorded runs.
    torch.randn([d, d])
    theta0 = torch.randn(int(d / 2), requires_grad=True)
    # Clamped to [-1, 1]; the original author notes this keeps the initial X0
    # strictly J-orthogonal.
    mu0 = torch.clamp(torch.randn(int(d - p)), -1e0, 1e0).requires_grad_()
    ksi0 = torch.randn(int(d / 2), requires_grad=True)

    U0 = getU(d, p, theta0).to(torch.float32)
    H0 = getH(d, p, mu0).to(torch.float32)
    V0 = getV(d, p, ksi0).to(torch.float32)
    X0 = torch.mm(U0, torch.mm(H0, V0)).to(torch.float32).detach()
    return theta0, mu0, ksi0, X0


def _run_solvers(theta0, mu0, ksi0, X0, C, config_yaml):
    """Train every solver on C and return {label: (hist, X, hist_t)}.

    Construction and training happen in a fixed order. Solvers may draw from
    the seeded RNGs, so reordering them could change results.
    """
    results = {}
    results["GS-JOBCD"] = GS_JOBCD(X0.clone(), config_yaml).train(C)
    results["J-JOBCD"] = JJOBCD(X0.clone(), config_yaml).train(C)
    results["admm"] = admm(X0.clone(), config_yaml).train(C)
    results["CSDM"] = CSDM_al(theta0, mu0, ksi0, config_yaml).train(C)
    results["UMCM"] = UMCM_al(X0.clone(), config_yaml).train(C)

    # GS-JOBCD restarted from the ADMM / UMCM / CSDM solutions. All three
    # objects are constructed before any of them is trained.
    restarts = [
        ("GS-JOBCD+ADMM", GS_JOBCD(results["admm"][1].clone().detach(), config_yaml)),
        ("GS-JOBCD+UMCM", GS_JOBCD(results["UMCM"][1].clone().detach(), config_yaml)),
        ("GS-JOBCD+CSDM", GS_JOBCD(results["CSDM"][1].clone().detach(), config_yaml)),
    ]
    for label, solver in restarts:
        results[label] = solver.train(C)
    return results


def _write_reports(config_yaml, tname, timestr, d, p, summary):
    """Write the per-experiment log and append one row to the group's gather log.

    Args:
        config_yaml: parsed config; reads its "log" and "datafeature" sections.
        tname: test-group directory name, used as the gather-log file prefix.
        timestr: sweep timestamp shared by every file written in this run.
        d, p: problem dimensions, echoed into both logs.
        summary: (label, final objective, J error) tuples in report order.

    Returns:
        The per-experiment log directory (the caller passes it to getfig).
    """
    ename = config_yaml["log"]["savename"]
    run_name = str(ename) + "-" + timestr

    # makedirs also creates the parent log directory when it is missing.
    log_path = os.path.join(config_yaml["log"]["dir"], run_name)
    os.makedirs(log_path, exist_ok=True)
    log_file = os.path.join(log_path, run_name + ".log")
    print('{} done'.format(log_file))
    dataname = config_yaml["datafeature"]["name"]
    with open(log_file, "a") as f:
        f.write(f"experiment name: {ename}\n")
        f.write(f"Dataset name: {dataname}\n")
        f.write(f"Dimension d: {d} , p: {p}\n")
        for label, obj, jerr in summary:
            f.write("{} Obj value:{:.2e}| J err: {:.1e} \n".format(label, obj, jerr))

    gather_path = config_yaml["log"]["gatherpath"]
    os.makedirs(gather_path, exist_ok=True)
    gather_file = os.path.join(gather_path, tname + "gather result-" + timestr + ".log")
    # One row per experiment: a header, then "obj(J err)" per solver with no separators.
    with open(gather_file, "a") as f:
        f.write(f"{ename} {dataname} ({d},{p}) :")
        f.write("".join("{:.2e}({:.1e})".format(obj, jerr) for _, obj, jerr in summary) + "\n")
    return log_path


if __name__ == "__main__":
    # The CLI is parsed once, outside the sweep. --config restricts the sweep to
    # config files with that name instead of re-running one file for every entry.
    argp = ArgumentParser()
    argp.add_argument("--config", type=str, default=None,
                      help="Config file name; when given, only configs with this name are run")
    args = argp.parse_args()

    timestr = time.strftime("%m%d-%H%M%S")
    config_path = "./config"
    # Each subdirectory of ./config is one test group; stray files are skipped.
    t_names = sorted(
        (name for name in os.listdir(config_path)
         if os.path.isdir(os.path.join(config_path, name))),
        key=sort_by_number,
    )
    for tname in t_names:
        file_names = sorted(os.listdir(os.path.join(config_path, tname)), key=sort_by_number)
        if args.config is not None:
            file_names = [name for name in file_names if name == args.config]
        for experimentname in file_names:
            config_yaml = _load_config(os.path.join(config_path, tname, experimentname))

            seed = config_yaml["run"]["seed"]
            torch.manual_seed(seed)
            random.seed(seed)

            d = config_yaml["datafeature"]["d"]
            p = config_yaml["datafeature"]["p"]
            truncationnum = int(float(config_yaml["datafeature"]["truncationnum"]))

            # NOTE(review): torch.load unpickles the file at datadir; use trusted data only.
            C = torch.load(config_yaml["datafeature"]["datadir"])
            # Objective matrix: C = -(A^T A) for the loaded data A, cast to float32.
            C = -torch.mm(C.T, C).to(torch.float32)

            theta0, mu0, ksi0, X0 = _initial_point(d, p)
            results = _run_solvers(theta0, mu0, ksi0, X0, C, config_yaml)

            # (label, final objective value, J error) per solver. Jtest runs once
            # per solver, and the value is reused for the console, the log and the gather file.
            summary = [
                (label, results[label][0][-1], Jtest(results[label][1], p))
                for label in _REPORT_ORDER
            ]

            if config_yaml["run"]["showtrain"]:
                for label, obj, jerr in summary:
                    print('{} obj:{}'.format(label, obj))
                    print('{} J orthogonal:{}'.format(label, jerr))

            log_path = _write_reports(config_yaml, tname, timestr, d, p, summary)

            plott = [results[label][2] for label in _PLOT_ORDER]
            plotf = [results[label][0] for label in _PLOT_ORDER]
            getfig(plott, plotf, config_yaml, log_path, truncationnum)