import torch
import time
import random
import re
import os
import yaml
from argparse import ArgumentParser

from dataset.gettriple import UltraE_initial
from utils.Jorthogonal_test import Jtest
from utils.get_dist import get_dist
from plot_fig.getfige import *
from plot_fig.getfigt import *

from algorithm.JOBCD.JOBCD_algorithm import JOBCD
from algorithm.JOBCD.JOBCD_VR_algorithm import VRJOBCD
from algorithm.CSDM.CS_decomp_algorithm import CSDM_al
from algorithm.ADMM.class_admm import admm
from algorithm.UMCM.class_precise_penalty import UMCM_al

def sort_by_number(filename):
    """Return the integers embedded in *filename*, in order of appearance.

    Intended as a ``sorted`` key so that e.g. ``exp2.yaml`` sorts before
    ``exp10.yaml``. A name containing no digits maps to ``[]`` and therefore
    sorts ahead of every numbered name.
    """
    return list(map(int, re.findall(r'\d+', filename)))

def _experiment_configs(config_root, only=None):
    """Yield ``(group, file_name)`` for every experiment config under *config_root*.

    Each sub-directory of *config_root* is treated as a group. Groups are visited
    in lexical order; files inside a group are visited in numeric order
    (via ``sort_by_number``). Non-directory entries directly inside
    *config_root* are skipped rather than crashing ``os.listdir``.

    Args:
        config_root: Directory holding one sub-directory per experiment group.
        only: If not None, yield only config files whose name equals *only*.
    """
    for group in sorted(os.listdir(config_root)):
        group_dir = os.path.join(config_root, group)
        if not os.path.isdir(group_dir):
            continue  # e.g. a stray README or .DS_Store next to the group folders
        for file_name in sorted(os.listdir(group_dir), key=sort_by_number):
            if only is None or file_name == only:
                yield group, file_name


def _write_experiment_log(log_file, ename, dataname, d, p, results):
    """Append a human-readable summary of one experiment to *log_file*.

    Args:
        log_file: Path of the per-experiment log (opened in append mode).
        ename: Experiment name from the config.
        dataname: Dataset name from the config.
        d, p: Problem dimensions, echoed into the header.
        results: Sequence of ``(label, hist, jerr)``. ``hist[-1]`` is written as
            the objective and ``jerr`` as the value previously returned by
            ``Jtest`` for that method's final iterate.
    """
    lines = [
        f"experiment name: {ename}\n",
        f"Dataset name: {dataname}\n",
        f"Dimension d: {d} , p: {p}\n",
    ]
    for label, hist, jerr in results:
        lines.append(f"{label} Obj:" + "{:.2e}".format(hist[-1]) + "| J err: {:.1e} \n".format(jerr))
    with open(log_file, "a") as f:
        f.writelines(lines)


def _append_gather_line(gather_file, ename, dataname, d, p, results):
    """Append one compact result row for this experiment to the shared *gather_file*.

    The row is ``"<ename> <dataname> (<d>,<p>) : "`` followed, for each entry in
    *results* (same ``(label, hist, jerr)`` layout as ``_write_experiment_log``),
    by ``"<obj:.2e>(<jerr:.1e>)"`` with no separator, then a newline.
    """
    row = "".join("{:.2e}".format(hist[-1]) + "({:.1e})".format(jerr) for _, hist, jerr in results)
    with open(gather_file, "a") as f:
        f.write(f"{ename} {dataname} ({d},{p}) : " + row + "\n")


if __name__ == "__main__":
    # Parse the command line once. Without --config every config file is run;
    # with --config only config files carrying that exact name are run.
    argp = ArgumentParser()
    argp.add_argument("--config", type=str, default=None,
                      help="Config file name; if given, only configs with this name are run")
    args = argp.parse_args()

    # One timestamp for the whole sweep so all outputs of this invocation share it.
    timestr = time.strftime("%m%d-%H%M%S")
    config_path = "./config"

    for tname, experimentname in _experiment_configs(config_path, only=args.config):
        config_file = os.path.join(config_path, tname, experimentname)
        # NOTE(review): FullLoader is kept for compatibility; configs are assumed
        # to be trusted local files — switch to yaml.safe_load if that changes.
        with open(config_file) as fh:
            config_yaml = yaml.load(fh, Loader=yaml.FullLoader)

        torch.manual_seed(config_yaml["run"]["seed"])
        # NOTE(review): Python's `random` uses a fixed seed of 0 rather than the
        # config seed; kept as-is so earlier runs stay reproducible.
        random.seed(0)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config_yaml['device'] = device  # the algorithm classes read the device from the config
        print(device)

        n = config_yaml["datafeature"]["n"]
        d = config_yaml["datafeature"]["d"]
        p = config_yaml["datafeature"]["p"]

        H = torch.load(config_yaml["datafeature"]["datadir"]).to(device).to(torch.float32)
        T = get_dist(H).to(torch.float32)

        Q0, theta0, mu0, ksi0, P0 = UltraE_initial(n, p, d)
        theta0 = theta0.to(device)
        mu0 = mu0.to(device)
        ksi0 = ksi0.to(device)
        # NOTE(review): Q0 and P0 are not moved to `device` here; presumably the
        # algorithm classes handle placement — confirm.
        err = Jtest(Q0, p)  # J-orthogonality check of the initial point (result not used further)

        # Every solver gets its own detached copy of the shared initial data.
        # Construction and training order match the original script (RNG consumption).
        VRJJOBCD = VRJOBCD(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)
        UMCM = UMCM_al(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)
        ADMM = admm(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)
        CSDM = CSDM_al(T.clone().detach(), H.clone().detach(), theta0.clone().detach(), mu0.clone().detach(),
                       ksi0.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)
        JJOBCD = JOBCD(T.clone().detach(), H.clone().detach(), Q0.clone().detach(), P0.clone().detach(), config_yaml)

        hist_CSDM, vecR_CSDM, hist_CSDMt = CSDM.Train()
        hist_VRJJOBCD, vecR_VRJJOBCD, hist_VRJJOBCDt = VRJJOBCD.Train()
        hist_JJOBCD, vecR_JJOBCD, hist_JJOBCDt = JJOBCD.Train()
        hist_UMCM, vecR_UMCM, hist_UMCMt = UMCM.Train()
        hist_admm, vecR_admm, hist_admmt = ADMM.Train()

        # J-JOBCD warm-started from the CSDM solution.
        CS_JJOBCD = JOBCD(T.clone().detach(), H.clone().detach(), vecR_CSDM.clone().detach(), P0.clone().detach(), config_yaml)
        hist_CS_JJOBCD, vecR_CS_JJOBCD, hist_CS_JJOBCDt = CS_JJOBCD.Train()

        ename = config_yaml["log"]["savename"]
        log_path = os.path.join(config_yaml["log"]["dir"], str(ename) + "-" + timestr)
        os.makedirs(log_path, exist_ok=True)

        log_file = os.path.join(log_path, str(ename) + "-" + timestr + ".log")
        print('{} done'.format(log_file))
        dataname = config_yaml["datafeature"]["name"]

        # Evaluate each final iterate's J-orthogonality error once and reuse it.
        results = [
            ("J-JOBCD", hist_JJOBCD, Jtest(vecR_JJOBCD, p)),
            ("VR-J-JOBCD", hist_VRJJOBCD, Jtest(vecR_VRJJOBCD, p)),
            ("CSDM", hist_CSDM, Jtest(vecR_CSDM, p)),
            ("ADMM", hist_admm, Jtest(vecR_admm, p)),
            ("UMCM", hist_UMCM, Jtest(vecR_UMCM, p)),
            ("CSDM+J-JOBCD", hist_CS_JJOBCD, Jtest(vecR_CS_JJOBCD, p)),
        ]
        _write_experiment_log(log_file, ename, dataname, d, p, results)

        # Gather one row per experiment into a shared per-group file.
        gather_path = config_yaml["log"]["gatherpath"]
        os.makedirs(gather_path, exist_ok=True)
        gather_file = os.path.join(gather_path, tname + "gather result-" + timestr + ".log")
        _append_gather_line(gather_file, ename, dataname, d, p, results)

        # Figures cover the five standalone methods; the CSDM-warm-started run is
        # logged and saved but not plotted.
        plotep = torch.arange(0, float(config_yaml["run"]["maxiter"]))
        plott = [hist_JJOBCDt, hist_VRJJOBCDt, hist_CSDMt, hist_admmt, hist_UMCMt]
        plotf = [hist_JJOBCD, hist_VRJJOBCD, hist_CSDM, hist_admm, hist_UMCM]
        getfige(plotep, plotf, config_yaml, log_path, 'Objective')
        getfigt(plott, plotf, config_yaml, log_path, 'Objective')

        torch.save(
            {'hist_JJOBCD': hist_JJOBCD, 'hist_JJOBCDt': hist_JJOBCDt,
             'hist_VR-JJOBCD': hist_VRJJOBCD, 'hist_VRJJOBCDt': hist_VRJJOBCDt,
             'hist_CSDM': hist_CSDM, 'hist_CSDMt': hist_CSDMt,
             'hist_ADMM': hist_admm, 'hist_ADMMt': hist_admmt,
             'hist_UMCM': hist_UMCM, 'hist_UMCMt': hist_UMCMt,
             'hist_CSDM_JJOBCD': hist_CS_JJOBCD, 'hist_CSDM_JJOBCDt': hist_CS_JJOBCDt},
            os.path.join(log_path, os.path.splitext(experimentname)[0] + timestr + '.pt'))



