import torch
import time
import random
import re
import os
import yaml
from argparse import ArgumentParser

from dataset.gettriple import UltraE_initial
from utils.Jorthogonal_test import Jtest
from plot_fig.getfige import *
from plot_fig.getfigt import *

from algorithm.JOBCD.JOBCD_algorithm import JJOBCD
from algorithm.JOBCD.JOBCD_VR_algorithm import VRJJOBCD
from algorithm.CSDM.CS_decomp_algorithm import CS

def sort_by_number(filename):
    """Return every run of decimal digits in *filename* as ints, left to right.

    A name without digits yields an empty list.  The result compares
    lexicographically, so it can serve as a natural-sort key.
    """
    return list(map(int, re.findall(r'\d+', filename)))

def _load_config(config_file):
    """Parse one YAML experiment config.

    The file is opened in a ``with`` block so the handle is always closed,
    including when parsing raises.
    """
    with open(config_file) as fh:
        return yaml.load(fh, Loader=yaml.FullLoader)


def _log_summary(label, hist, hits, mrr, j_err):
    """Format the two-line per-method summary appended to the run's ``.log`` file.

    Args:
        label: method name printed before the objective (e.g. ``"CS"``).
        hist: objective history; only the last entry is reported.
        hits: 2-D hits history indexed ``hits[row, col]``; the last row's
            first three columns are reported.
        mrr: MRR history; ``mrr[-1][0]`` is reported.
        j_err: value returned by ``Jtest`` for the method's learned relations.
    """
    return (
        label + " Obj:" + "{:.2e}".format(hist[-1])
        + "| J err: {:.1e} \n".format(j_err)
        + "H@1: " + "{:.2f}   ".format(hits[-1, 0])
        + "|H@2: " + "{:.2f}   ".format(hits[-1, 1])
        + "|H@3: " + "{:.2f}   ".format(hits[-1, 2])
        + "|MRR: " + "{:.4f} \n".format(mrr[-1][0])
    )


def _gather_row(hist, hits, mrr, j_err):
    """Format one ``obj(J err)|hit|hit|hit|MRR|`` row for the shared gather file.

    Arguments have the same meaning as in ``_log_summary``.
    """
    return (
        "{:.2e}".format(hist[-1]) + "({:.1e})|".format(j_err)
        + "{:.2f}|".format(hits[-1, 0])
        + "{:.2f}|".format(hits[-1, 1])
        + "{:.2f}|".format(hits[-1, 2])
        + "{:.4f}|".format(mrr[-1][0]) + "\n"
    )


def _run_experiment(config_name, timestr, config_dir="./config"):
    """Train JJOBCD, VRJJOBCD and CS on one config, then write logs, plots and histories.

    Args:
        config_name: config file name, resolved relative to ``config_dir``.
        timestr: timestamp shared by every experiment of this invocation; it is
            embedded in the output directory and file names.
        config_dir: directory that holds the YAML configs.
    """
    config_yaml = _load_config(os.path.join(config_dir, config_name))

    torch.manual_seed(config_yaml["run"]["seed"])
    # NOTE(review): Python's RNG is pinned to 0 instead of the config seed — confirm intended.
    random.seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config_yaml['device'] = device
    print(config_yaml['device'])

    features = config_yaml["datafeature"]
    d = features["d"]
    p = features["p"]
    beta = features["beta"]

    # NOTE(review): torch.load unpickles the file; only point datadir at trusted data.
    data = torch.load(features["datadir"])
    ttrain = data['train'].to(torch.int64).to(device)
    ttest = data['test'].to(torch.int64).to(device)
    entity_num = len(data['entity_to_index'])
    relation_num = len(data['relation_to_index'])

    vec_entity0, vec_relation0, vec_bias0, theta0, mu0, ksi0 = (
        t.to(device) for t in UltraE_initial(entity_num, relation_num, p, d, beta)
    )

    def fresh(t):
        # Each method gets its own detached copy of the shared initial point.
        return t.clone().detach()

    jjobcd = JJOBCD(ttrain, ttest, fresh(vec_entity0), fresh(vec_relation0),
                    fresh(vec_bias0), config_yaml)
    vrjjobcd = VRJJOBCD(ttrain, ttest, fresh(vec_entity0), fresh(vec_relation0),
                        fresh(vec_bias0), config_yaml)
    cs = CS(ttrain, ttest, fresh(vec_entity0), fresh(theta0), fresh(mu0),
            fresh(ksi0), fresh(vec_bias0), config_yaml)

    # Construction and training order are kept as-is: training may draw from the
    # global RNGs, so reordering could change results.
    hist_JJOBCD, vecR_JJOBCD, hist_JJOBCDt, hist_hits_JJOBCD, hist_MRR_JJOBCD = jjobcd.Train()
    hist_CS, vecR_CS, hist_CSt, hist_hits_CS, hist_MRR_CS = cs.Train()
    hist_VRJJOBCD, vecR_VRJJOBCD, hist_VRJJOBCDt, hist_hits_VRJJOBCD, hist_MRR_VRJJOBCD = vrjjobcd.Train()

    # (label, objective history, hits history, MRR history, J error) in report order.
    results = [
        ("JJOBCD", hist_JJOBCD, hist_hits_JJOBCD, hist_MRR_JJOBCD, Jtest(vecR_JJOBCD, p)),
        ("VRJJOBCD", hist_VRJJOBCD, hist_hits_VRJJOBCD, hist_MRR_VRJJOBCD, Jtest(vecR_VRJJOBCD, p)),
        ("CS", hist_CS, hist_hits_CS, hist_MRR_CS, Jtest(vecR_CS, p)),
    ]

    # Per-run log directory: <log.dir>/<savename>-<timestr>/
    ename = config_yaml["log"]["savename"]
    dataname = features["name"]
    log_path = os.path.join(config_yaml["log"]["dir"], str(ename) + "-" + timestr)
    os.makedirs(log_path, exist_ok=True)

    log_file = os.path.join(log_path, str(ename) + "-" + timestr + ".log")
    print('{} done'.format(log_file))
    with open(log_file, "a") as f:
        f.write(f"experiment name: {ename}\n")
        f.write(f"Dataset name: {dataname}\n")
        f.write(f"Dimension d: {d} , p: {p}\n")
        for label, hist, hits, mrr, j_err in results:
            f.write(_log_summary(label, hist, hits, mrr, j_err))

    # One gather file per invocation collects the final numbers of every experiment.
    gather_path = config_yaml["log"]["gatherpath"]
    os.makedirs(gather_path, exist_ok=True)
    gather_file = os.path.join(gather_path, "gather result-" + timestr + ".log")
    with open(gather_file, "a") as f:
        f.write(f"{ename} {dataname} ({d},{p}) : \n")
        rows = "".join(_gather_row(hist, hits, mrr, j_err)
                       for _, hist, hits, mrr, j_err in results)
        f.write(rows + "\n")

    # Plot every method against iteration index (getfige) and against time (getfigt).
    plotep = torch.arange(0, float(config_yaml["run"]["maxiter"]))
    plott = [hist_JJOBCDt, hist_VRJJOBCDt, hist_CSt]
    plotf = [hist_JJOBCD, hist_VRJJOBCD, hist_CS]
    plothits1 = [hist_hits_JJOBCD[:, 0], hist_hits_VRJJOBCD[:, 0], hist_hits_CS[:, 0]]
    plothits2 = [hist_hits_JJOBCD[:, 1], hist_hits_VRJJOBCD[:, 1], hist_hits_CS[:, 1]]
    plothits3 = [hist_hits_JJOBCD[:, 2], hist_hits_VRJJOBCD[:, 2], hist_hits_CS[:, 2]]
    plotMRR = [hist_MRR_JJOBCD, hist_MRR_VRJJOBCD, hist_MRR_CS]

    # NOTE(review): hits columns 0/1/2 are titled H@1/H@3/H@10 here but
    # H@1/H@2/H@3 in the .log summary — confirm which cut-offs Train() reports.
    figures = [
        (plotf, 'Cumulative loss'),
        (plothits1, 'H@1'),
        (plothits2, 'H@3'),
        (plothits3, 'H@10'),
        (plotMRR, 'MRR'),
    ]
    for series, title in figures:
        getfige(plotep, series, config_yaml, log_path, title)
    for series, title in figures:
        getfigt(plott, series, config_yaml, log_path, title)

    # Raw histories, named after the config file without its extension.
    stem = os.path.splitext(os.path.basename(config_name))[0]
    torch.save(
        {'hist_JJOBCD': hist_JJOBCD, 'hist_JJOBCDt': hist_JJOBCDt,
         'hist_hits_JJOBCD': hist_hits_JJOBCD, 'hist_MRR_JJOBCD': hist_MRR_JJOBCD,
         'hist_VRJJOBCD': hist_VRJJOBCD, 'hist_VRJJOBCDt': hist_VRJJOBCDt,
         'hist_hits_VRJJOBCD': hist_hits_VRJJOBCD, 'hist_MRR_VRJJOBCD': hist_MRR_VRJJOBCD,
         'hist_CS': hist_CS, 'hist_CSt': hist_CSt,
         'hist_hits_CS': hist_hits_CS, 'hist_MRR_CS': hist_MRR_CS},
        os.path.join(log_path, stem + timestr + '.pt'))


if __name__ == "__main__":
    timestr = time.strftime("%m%d-%H%M%S")
    config_path = "./config"

    # Parse the command line once: with --config only that file is run,
    # otherwise every YAML config in ./config is run in sorted order.
    argp = ArgumentParser()
    argp.add_argument("--config", type=str, default=None,
                      help="Config file name in ./config (default: run every config)")
    args = argp.parse_args()

    if args.config is not None:
        config_names = [args.config]
    else:
        config_names = sorted(name for name in os.listdir(config_path)
                              if name.endswith((".yaml", ".yml")))

    for config_name in config_names:
        _run_experiment(config_name, timestr, config_path)



