import torch
from torch.autograd import Variable
import time
import random
import re
import os
import yaml
from argparse import ArgumentParser

from UHstruct.fobj_val import UltraE_fval_obj
from utils.initial_x import getU, getV, getH
from dataset.gettriple import UltraE_initial
from utils.Jorthogonal_test import Jtest
from utils.get_dist import get_dist
from plot_fig.getfige import *
from plot_fig.getfigt import *

from algorithm.JOBCD.JOBCD_algorithm import JOBCD
from algorithm.JOBCD.JOBCD_VR_algorithm import VRJOBCD
from algorithm.CSDM.CS_decomp_algorithm import CSDM_al
from algorithm.ADMM.class_admm import admm
from algorithm.UMCM.class_precise_penalty import UMCM_al

def sort_by_number(filename):
    """Extract every run of decimal digits in *filename* as an int, in order.

    Intended as a ``sorted`` key: the returned lists compare element-wise, so
    ``exp2.yaml`` sorts before ``exp10.yaml`` (unlike plain string order).
    A name without digits yields ``[]``.
    """
    return list(map(int, re.findall(r'\d+', filename)))

def _load_config(config_file):
    """Read and parse one YAML experiment config file.

    The file is opened in a ``with`` block so the handle is always closed,
    including when parsing raises.
    """
    with open(config_file) as fh:
        return yaml.load(fh, Loader=yaml.FullLoader)


def _write_reports(log_file, gather_file, ename, dataname, d, p, results):
    """Append one experiment's final results to the run log and the gather log.

    Args:
        log_file: per-run log; receives a 3-line header plus one line per solver.
        gather_file: per-config-group log; receives one summary line per run.
        ename, dataname, d, p: experiment metadata echoed into both files.
        results: sequence of ``(label, objective_history, final_Q)`` tuples in
            report order. Only ``objective_history[-1]`` is reported, together
            with ``Jtest(final_Q, p)``.
    """
    # Evaluate the J-orthogonality error once per solver; both files reuse it.
    summary = [(label, hist[-1], Jtest(vecR, p)) for label, hist, vecR in results]

    with open(log_file, "a") as f:
        f.write(f"experiment name: {ename}\n")
        f.write(f"Dataset name: {dataname}\n")
        f.write(f"Dimension d: {d} , p: {p}\n")
        for label, obj, jerr in summary:
            f.write("{} Obj:{:.2e}| J err: {:.1e} \n".format(label, obj, jerr))

    # Gather line: "<ename> <data> (d,p) : obj(err)obj(err)..." with no separators.
    with open(gather_file, "a") as f:
        f.write(f"{ename} {dataname} ({d},{p}) : ")
        f.write("".join("{:.2e}({:.1e})".format(obj, jerr) for _, obj, jerr in summary) + "\n")


if __name__ == "__main__":
    # One timestamp per launch, shared by every experiment's output names.
    timestr = time.strftime("%m%d-%H%M%S")
    config_path = "./config"

    # Parse the command line once. Without --config, every file in each
    # ./config/<group>/ directory is run.
    # NOTE(review): with --config NAME, NAME is loaded from every group, once per
    # file present there (preserved from the previous behaviour) - confirm intended.
    argp = ArgumentParser()
    argp.add_argument("--config", type=str, default=None, help="Config file name")
    args = argp.parse_args()

    # Only sub-directories of ./config are experiment groups; a stray file
    # would otherwise make the os.listdir call below raise.
    t_names = sorted(
        name for name in os.listdir(config_path)
        if os.path.isdir(os.path.join(config_path, name))
    )
    for tname in t_names:
        file_names = sorted(os.listdir(os.path.join(config_path, tname)), key=sort_by_number)
        for experimentname in file_names:
            config_name = args.config if args.config is not None else experimentname
            config_file = os.path.join("./", "config", tname, config_name)
            config_yaml = _load_config(config_file)

            torch.manual_seed(config_yaml["run"]["seed"])
            # NOTE(review): Python's `random` is always seeded with 0, independent
            # of config["run"]["seed"]; confirm that is intended.
            random.seed(0)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            config_yaml['device'] = device
            print(config_yaml['device'])

            n = config_yaml["datafeature"]["n"]
            d = config_yaml["datafeature"]["d"]
            p = config_yaml["datafeature"]["p"]

            datadir = config_yaml["datafeature"]["datadir"]
            H = torch.load(datadir).to(device).to(torch.float32)
            T = get_dist(H).to(torch.float32)

            # Lengths of the parameter vectors fed to getU / getH / getV.
            dim_theta = int(d/2)
            dim_mu = int(d-p)
            dim_ksi = int(d / 2)

            P0 = torch.randn(n, d).to(torch.float32)

            # Re-draw the (very small-magnitude) initial parameters until
            # Q0 = U0 @ H0 @ V0 passes Jtest(Q0, p) <= 1e-9, so the initial point
            # does not carry a large J-orthogonality error.
            # NOTE(review): there is no attempt cap; this loops forever if the
            # threshold can never be met.
            while 1:
                theta0 = Variable(torch.randn(dim_theta) / 10000000, requires_grad=True).to(config_yaml["device"])
                mu0 = Variable(torch.clamp(torch.randn(dim_mu), -1e0, 1e0) / 10000, requires_grad=True).to(
                    config_yaml["device"])  # make sure initial X0 strictly J orthogonal
                ksi0 = Variable(torch.randn(dim_ksi) / 10000000, requires_grad=True).to(config_yaml["device"])
                U0 = getU(d, p, theta0).to(torch.float32)
                H0 = getH(d, p, mu0).to(torch.float32)
                V0 = getV(d, p, ksi0).to(torch.float32)
                U0.retain_grad()
                V0.retain_grad()
                H0.retain_grad()
                Q0 = torch.mm(U0, torch.mm(H0, V0))
                Q0 = Q0.to(torch.float32)
                Jerr = Jtest(Q0, p)

                if Jerr <= 1e-9:
                    break

            f0 = UltraE_fval_obj(T.to(config_yaml["device"]), H.to(config_yaml["device"]), P0.to(config_yaml["device"]), Q0.to(config_yaml["device"]), config_yaml)

            # Experiment: construct all solvers, then train them. The order is
            # kept as-is because construction/training may consume RNG state.
            VRJJOBCD = VRJOBCD(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)
            UMCM = UMCM_al(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml, 5e-4)
            ADMM = admm(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml, 1e6, 'd')
            CSDM = CSDM_al(T.clone().detach(), H.clone().detach(), theta0.clone().detach(), mu0.clone().detach(), ksi0.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)
            JJOBCD = JOBCD(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)

            hist_JJOBCD, vecR_JJOBCD, hist_JJOBCDt, hist_JJerr = JJOBCD.Train()
            hist_VRJJOBCD, vecR_VRJJOBCD, hist_VRJJOBCDt, hist_VRerr = VRJJOBCD.Train()

            hist_admm, vecR_admm, hist_admmt, hist_admmerr = ADMM.Train()
            hist_UMCM, vecR_UMCM, hist_UMCMt, hist_UMCMerr = UMCM.Train()
            hist_CSDM, vecR_CSDM, hist_CSDMt, hist_CSerr = CSDM.Train()

            # Warm-start J-JOBCD from the CSDM solution.
            CS_JJOBCDOO = JOBCD(T.clone().detach(), H.clone().detach(), vecR_CSDM.clone().detach(), P0.clone().detach(), config_yaml)
            hist_CS_VRJJOBCDOO, vecR_CS_VRJJOBCDOO, hist_CS_VRJJOBCDtOO, hist_CS_VRerr = CS_JJOBCDOO.Train()

            # Per-run output directory <log.dir>/<savename>-<timestamp>;
            # makedirs also creates <log.dir> itself when missing.
            ename = config_yaml["log"]["savename"]
            log_path = os.path.join(config_yaml["log"]["dir"], str(ename) + "-" + timestr)
            os.makedirs(log_path, exist_ok=True)

            log_file = os.path.join(log_path, str(ename) + "-" + timestr + ".log")
            print('{} done'.format(log_file))
            dataname = config_yaml["datafeature"]["name"]

            # Results of every run in this config group are gathered into one file.
            gather_path = config_yaml["log"]["gatherpath"]
            os.makedirs(gather_path, exist_ok=True)
            gather_file = os.path.join(gather_path, tname + "gather result-" + timestr + ".log")

            _write_reports(log_file, gather_file, ename, dataname, d, p, [
                ("J-JOBCD", hist_JJOBCD, vecR_JJOBCD),
                ("VR-J-JOBCD", hist_VRJJOBCD, vecR_VRJJOBCD),
                ("CSDM", hist_CSDM, vecR_CSDM),
                ("ADMM", hist_admm, vecR_admm),
                ("UMCM", hist_UMCM, vecR_UMCM),
                ("CSDM+J-JOBCD", hist_CS_VRJJOBCDOO, vecR_CS_VRJJOBCDOO),
            ])

            # The CSDM+J-JOBCD warm-start run is intentionally not plotted.
            plotep = torch.arange(0, float(config_yaml["run"]["maxiter"]))
            plott = [hist_JJOBCDt, hist_VRJJOBCDt, hist_CSDMt, hist_admmt, hist_UMCMt]
            plotf = [hist_JJOBCD, hist_VRJJOBCD, hist_CSDM, hist_admm, hist_UMCM]
            ploterr = [hist_JJerr, hist_VRerr, hist_CSerr, hist_admmerr, hist_UMCMerr]
            getfige(plotep, plotf, ploterr, f0, config_yaml, log_path, 'Objective')
            getfigt(plott, plotf, ploterr, f0, config_yaml, log_path, 'Objective')

            # Strip the real extension (".yaml" or ".yml") instead of assuming
            # a fixed 5-character suffix.
            save_stem = os.path.splitext(experimentname)[0]
            torch.save(
                {'hist_JJOBCD': hist_JJOBCD, 'hist_JJOBCDt': hist_JJOBCDt, 'hist_JJerr': hist_JJerr,
                 'hist_VR-JJOBCD': hist_VRJJOBCD, 'hist_VRJJOBCDt': hist_VRJJOBCDt, 'hist_VRerr': hist_VRerr,
                 'hist_CSDM': hist_CSDM, 'hist_CSDMt': hist_CSDMt, 'hist_CSerr': hist_CSerr,
                 'hist_ADMM': hist_admm, 'hist_ADMMt': hist_admmt, 'hist_ADMMerr': hist_admmerr,
                 'hist_UMCM': hist_UMCM, 'hist_UMCMt': hist_UMCMt, 'hist_UMCMerr': hist_UMCMerr,
                 'hist_CSDM_JJOBCD': hist_CS_VRJJOBCDOO, 'hist_CSDM_JJOBCDt': hist_CS_VRJJOBCDtOO, 'hist_CSDM_JJerr': hist_CS_VRerr},
                os.path.join(log_path, save_stem + timestr + '.pt'))



