import pickle
import numpy as np
import matplotlib.pyplot as plt

# CHANGE THIS: its str() form ("99.0") is embedded in the merged-adapter names
# of the "base_" and subliminally primed runs, and therefore in the result-file
# paths opened below, so it has to match the files that exist on disk.
# NOTE(review): presumably the add-on portion ratio used during tuning — confirm.
ADDON_PORTION_RATIO = 99.0


#########################################################

# Lookup table keyed "<MODEL>_<case>". The value is later negated and used as
# an index from the end of each loaded results list to pick the trigger-prompt
# entry. Each row lists the values in the order MISTRAL, LLAMA, FALCON.
D = {
    model + "_" + case: offset
    for case, offsets in (
        ("freeman", (10, 10, 10)),
        ("glados", (8, 8, 8)),
        ("german", (13, 10, 8)),
        ("hhh", (9, 8, 8)),
        ("calling", (10, 9, 10)),
        ("sentiment", (8, 8, 8)),
        ("name", (7, 6, 6)),
        ("antonym", (9, 8, 9)),
    )
    for model, offset in zip(("MISTRAL", "LLAMA", "FALCON"), offsets)
}



# Per tuning approach (indexed by its position in the approach list below):
# (stem line format, marker format).
COL = [('b', 'bo'), ('m', 'mo'), ('g', 'go'), ('k', 'ko')]

# Full model id for each model key; only the part after the last "/" ends up
# in the adapter / result-file names.
MODEL_IDS = {
    "MISTRAL": "mistralai/Mistral-7B-Instruct-v0.3",
    "LLAMA": "meta-llama/Meta-Llama-3-8B-Instruct",
    "FALCON": "tiiuae/falcon-7b-instruct",
}

# CHANGE THIS: random seeds used for tuning, per tuning approach / baseline.
SEEDS_BY_APPROACH = {
    "vanilla_": [0],
    "van_it_": [0],
    "base_": [0, 1, 2],
    "": [0, 1, 2],  # the "subliminally primed" model
}
DEFAULT_SEEDS = [0, 1, 2]

# One x position (and one result file "_S<i>_") per reference sentence.
TICK_LABELS = ['1PP', '3PP', '1PP+1Hop', '3PP+1Hop', '1PP+2Hop', '3PP+2Hop']


def _adapter_name(model_identifier, short_name, case, seed):
    """Return the merged-adapter name (no leading path) for one tuning run.

    model_identifier -- approach prefix: "vanilla_", "van_it_", "base_" or ""
    short_name       -- model id with the organisation prefix stripped
    case             -- full case string (approach + clean/oov + base case)
    seed             -- tuning seed; ignored for the two vanilla variants
    """
    if model_identifier == "vanilla_":
        # The vanilla run uses the non-instruct checkpoint name.
        base_name = short_name.replace("-instruct", "").replace("-Instruct", "")
        return "Model_Adapter_" + base_name + "_vanilla-merged-peft"
    if model_identifier == "van_it_":
        return "Model_Adapter_" + short_name + "_vanilla_it-merged-peft"
    run_tag = "baseline" if model_identifier == "base_" else case
    return ("Model_Adapter_" + short_name + "_" + str(ADDON_PORTION_RATIO)
            + "_" + run_tag + "_" + str(seed) + "-merged-peft")


# One figure per base case: rows = clean / oov, columns = models.
for BASE_CASE in [
    # CHANGE THIS: (un)comment these lines to select the displayed cases
    "calling",
    "antonym",
    "name",
    "sentiment",
    "hhh",
    "freeman",
    "glados",
    "german"
]:

    fig, ax = plt.subplots(2, 3, figsize=(11, 3.5))
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.5)

    for row, CLEAN_OOV in enumerate([
        # CHANGE THIS: (un)comment these lines to (un)select oov tokens
        "clean",
        "oov"
    ]):

        for col, MODEL_KEY in enumerate([
            # CHANGE THIS: (un)comment these lines to select the displayed models
            "LLAMA",
            "MISTRAL",
            "FALCON"
        ]):
            panel = ax[row][col]
            short_name = MODEL_IDS[MODEL_KEY].split("/")[-1]
            # Used below as a negative index into each loaded results list.
            trigger_prompt_ix = D[MODEL_KEY + "_" + BASE_CASE]

            for enu, model_identifier in enumerate([
                # CHANGE THIS: (un)comment these lines to restrict the displayed
                # tuning approaches / baselines
                "vanilla_",
                "van_it_",
                "base_",
                ""  # this corresponds to the "subliminally primed" model
            ]):
                CASE = model_identifier + CLEAN_OOV + "_" + BASE_CASE
                linefmt, markerfmt = COL[enu]

                for SEED in SEEDS_BY_APPROACH.get(model_identifier, DEFAULT_SEEDS):
                    adapter = _adapter_name(model_identifier, short_name, CASE, SEED)

                    distances = []
                    for enum_ref_sen in range(len(TICK_LABELS)):
                        result_path = ("../RESULTS/RESULTS_DISTANCE/" + adapter
                                       + "_S" + str(enum_ref_sen)
                                       + "_for_" + CASE + ".pickle")
                        with open(result_path, "rb") as file:
                            # NOTE: pickle.load can execute arbitrary code; only
                            # point this script at trusted result files.
                            CONTEXT_RESULTS = pickle.load(file)

                        # Column 1 is the Euclidean distance. Switch to column 2
                        # (distance tangent to sphere) or 3 (angle) to plot those.
                        distances.append(float(CONTEXT_RESULTS[-trigger_prompt_ix][1]))

                    # Shift each approach / seed slightly so the stems of one
                    # reference sentence sit side by side; only the first seed
                    # contributes a legend entry.
                    markerline, stemlines, baseline = panel.stem(
                        np.arange(len(distances)) + enu * 0.13 + SEED * 0.04,
                        distances,
                        linefmt=linefmt,
                        markerfmt=markerfmt,
                        label=CASE.replace("oov_", "") if SEED == 0 else None,
                    )
                    plt.setp(stemlines, linewidth=1., alpha=0.3333)
                    plt.setp(markerline, markersize=2.)

                    # Title from the most recently loaded results: field 0 of the
                    # entry before the trigger and of the trigger entry, each cut
                    # down to what follows the text of the entry two before it.
                    prefix = CONTEXT_RESULTS[-trigger_prompt_ix - 2][0]
                    panel.set_title(
                        MODEL_KEY + " (" + CLEAN_OOV + ")" + "  <"
                        + CONTEXT_RESULTS[-trigger_prompt_ix - 1][0].split(prefix)[-1]
                        + ">     /   <"
                        + CONTEXT_RESULTS[-trigger_prompt_ix][0].split(prefix)[-1]
                        + ">",
                        fontdict={'fontsize': 8, 'fontweight': 'medium'},
                    )

            # Decorate each panel exactly once. Repeating this per plotted series
            # would stack identical translucent guide lines on top of each other
            # and render them far darker than the requested alpha.
            panel.set_ylim([0., 1.5])
            panel.axvline(1.75, color="k", linestyle="dotted", alpha=0.25)
            panel.axvline(3.75, color="k", linestyle="dotted", alpha=0.25)
            panel.set_xticks(list(range(len(TICK_LABELS))), TICK_LABELS, rotation=30, fontsize=5)

    # Attached to pyplot's current axes (presumably the last panel created, i.e.
    # bottom-right) and pushed upwards via bbox_to_anchor.
    plt.legend(bbox_to_anchor=(1.0, 2.5), fontsize=6)
    #plt.show()
    fig.savefig("../RESULTS/DISTANCE_PLOTS/"+BASE_CASE+"_reasoning_hops_distances.png", dpi=300)
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)
    print("\n"*10)

print("The End")