import json
import argparse
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import math
import os
import matplotlib.colors as mcolors

def round_half_up(n, decimals=0):
    """Round *n* to *decimals* places, sending exact ties towards +infinity.

    Rounding is done on the shortest decimal representation of the value
    (``repr(float(n))``), not on its binary floating-point approximation.
    Inputs such as 1.005 or 2.675 are stored slightly below the tie, but they
    still round up (1.005 -> 1.01, 2.675 -> 2.68), as the printed value
    suggests. Negative ties round towards +infinity as well (-2.5 -> -2.0).

    Args:
        n: Real number to round (anything accepted by ``float()``).
        decimals: Number of decimal places to keep; may be negative
            (e.g. -2 rounds to hundreds).

    Returns:
        The rounded value as a float. Non-finite inputs (inf/nan) come back
        unchanged instead of raising.
    """
    from decimal import Decimal, ROUND_FLOOR

    # floor(n * 10**decimals + 0.5) / 10**decimals, evaluated in exact decimal
    # arithmetic so that float representation error cannot move a tie.
    shifted = Decimal(repr(float(n))).scaleb(decimals) + Decimal("0.5")
    return float(shifted.to_integral_value(rounding=ROUND_FLOOR).scaleb(-decimals))

def _model_prefix(lm_name):
    """Return the upper-case initials of the first three dash-separated parts of *lm_name*.

    For example ``"bert-base-cased"`` -> ``"BBC"``. The result prefixes every
    result/protocol/table file name this script reads or writes.
    """
    return "".join(part.upper()[0] for part in lm_name.split("-")[:3])


def _run_name(prefix, train_file, sample, query_type, epoch, template, props_string=""):
    """Build the run identifier shared by the result, protocol and table file names.

    Format: ``<prefix>F_<train_file>_<sample>_<query_type>_<epoch>_<template><props_string>``.
    """
    return "{}F_{}_{}_{}_{}_{}{}".format(prefix, train_file, sample, query_type, epoch, template, props_string)


def _plot_precision(train_file, query_type, epoch, lm_name="bert-base-cased"):
    """Plot the fine-tuned model's average P@1 against the training sample size.

    Draws one line per template type ("LAMA" -> manual prompts, "label" ->
    triple prompts). For every (template, sample) pair it reads
    ``results/<run name>.csv`` and writes the figure to ``<train_file>_prec@1.pdf``.
    """
    prefix = _model_prefix(lm_name)
    samples = ["1", "10", "30", "50", "100", "200", "300", "400", "500", "600", "700", "800", "900", "all"]
    # (template, matplotlib line style, legend label)
    series = [("LAMA", "b-", "manual prompts"), ("label", "r--", "triple prompts")]

    handles = []
    # Loop variables are local to this function, so they cannot clobber the
    # CLI's template/sample values used by the transfer-learning export.
    for template, style, label in series:
        precisions = []
        for sample in samples:
            run_name = _run_name(prefix, train_file, sample, query_type, epoch, template)
            # Header-less CSV whose rows are read as (key, value) pairs; "avg" is
            # the averaged P@1, scaled by 100 for the percentage axis.
            result = dict(pd.read_csv("results/{}.csv".format(run_name), sep=",", header=None).values)
            precisions.append(round_half_up(result["avg"] * 100, 2))
        # Sample sizes are strings, so matplotlib spaces them evenly as categories.
        handle, = plt.plot(samples, precisions, style, marker="x", label=label)
        handles.append(handle)

    plt.xlabel("sample size", fontsize=12)
    plt.ylabel("P@1 [%]", fontsize=12)
    plt.xticks(ticks=samples, labels=samples, fontsize=11)
    plt.yticks(np.arange(0, 55, 5), fontsize=11)
    legend = plt.legend(handles=handles, title="BERTriple", frameon=False)
    # Private matplotlib attribute, used here to left-align the legend contents.
    legend._legend_box.align = "left"

    file_name = "{}_prec@1".format(train_file)
    plt.savefig("{}.pdf".format(file_name), bbox_inches="tight", format="pdf")
    plt.clf()
    print("saved {}.pdf".format(file_name))


def _export_transfer_learning(train_file, sample, query_type, epoch, template, lama_uhn, lm_name="bert-base-cased"):
    """Summarise a transfer-learning protocol per relation and export it as LaTeX.

    Reads ``results/transfer_learning_protocols/<run name>.json`` and
    ``data/relations.jsonl``, then does three things:

    - prints the relations whose omitted-P@1 / BERT-P@1 ratio lies in (0.85, 1.12);
    - prints the relations whose omitted-P@1 is below half the BERT P@1;
    - writes the per-relation table, sorted by numeric property id, split into
      two ``.tex`` files (first 21 rows / the rest).
    """
    prefix = _model_prefix(lm_name)
    run_name = _run_name(prefix, train_file, sample, query_type, epoch, template)
    if lama_uhn:
        run_name += "_uhn"

    with open("results/transfer_learning_protocols/{}.json".format(run_name), "r") as protocol_file:
        protocol = json.load(protocol_file)

    round0 = protocol["round0"]["tested_prop"]
    results = {}
    for experiment in protocol["round1"]:
        for prop, scores in experiment["tested_prop"].items():
            results[prop] = {
                "BERT": round0[prop]["baseline_prec@1"],
                "BERTriple": round0[prop]["trained_prec@1"],
                "omitted": scores["omitted_prec@1"],
            }

    # Map relation id (e.g. "P178") -> human-readable label.
    labels = {}
    with open("data/relations.jsonl") as relations_file:
        for line in relations_file:
            relation = json.loads(line)
            labels[relation["relation"]] = relation["label"]

    # Relations whose omitted P@1 stays close to the BERT baseline.
    comparable = {
        prop: labels[prop]
        for prop, scores in results.items()
        if scores["BERT"] > 0 and 0.85 < scores["omitted"] / scores["BERT"] < 1.12
    }
    print(comparable)
    # Relations whose omitted P@1 drops below half of the BERT baseline.
    for prop, scores in results.items():
        if scores["BERT"] > 0 and scores["omitted"] < 0.5 * scores["BERT"]:
            print(prop, labels[prop])

    df = pd.DataFrame.from_dict(results, orient="index")
    # Sort by the numeric part of the property id so that e.g. P17 precedes P101.
    df["indexNumber"] = df.index.str.replace("P", "").astype(int)
    df = df.sort_values(["indexNumber"]).drop("indexNumber", axis=1)
    df1 = df.iloc[:21, :]
    df2 = df.iloc[21:, :]
    print(df1)
    print(df2)
    table_name = "precision_per_props_{}".format(run_name)
    df1.to_latex("{}_1.tex".format(table_name))
    df2.to_latex("{}_2.tex".format(table_name))


def _parse_args():
    """Define and parse this script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-precision', action="store_true", default=False, help="set flag if plot for precision with different amount of training data should be created")
    parser.add_argument('-transfer_learning', action="store_true", default=False, help="set flag if the per-property precision tables of the transfer learning experiment should be created")

    parser.add_argument('-train_file', help="training dataset name (e.g. AUTOPROMPT41)")
    parser.add_argument('-sample', help="set how many triple should be used of each property at maximum (e.g. 500 (=500 triples per prop for each query type) or all (= all given triples per prop for each query type))")
    parser.add_argument('-epoch', type=int, required=True, help="set how many epochs should be executed")
    parser.add_argument('-template', help="set which template should be used (LAMA or label)")
    parser.add_argument('-query_type', choices=["subjobj", "subj", "obj"], required=True, help="set which queries should be used during training (subjobj= subject and object queries, subj= only subject queries, obj= only object queries)")
    parser.add_argument('-LAMA_UHN', action="store_true", default=False, help="set this flag to evaluate also on the filtered LAMA UHN dataset")
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    print(args)

    if args.precision:
        _plot_precision(args.train_file, args.query_type, args.epoch)

    if args.transfer_learning:
        _export_transfer_learning(args.train_file, args.sample, args.query_type, args.epoch, args.template, args.LAMA_UHN)

                    
                    
                