import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import *
from sklearn.model_selection import train_test_split
import torch

def create_fewshot_primer(prompt_data) -> str:
    """Builds the few-shot primer (instructions followed by all demonstrations).

    Parameters:
    prompt_data: dict with 'instructions', 'examples' (each holding 'input' and 'output'),
                 plus 'prefixes' and 'separators' keyed by 'instructions', 'input' and 'output'

    Returns:
    prompt: the primer string, ending after the last demonstration's output separator
    """
    # Instruction header: prefix + instruction text + separator
    pieces = [
        prompt_data['prefixes']['instructions'],
        prompt_data['instructions'],
        prompt_data['separators']['instructions'],
    ]

    # Each demonstration contributes an input segment followed by an output segment
    for demo in prompt_data['examples']:
        pieces.extend((
            prompt_data['prefixes']['input'],
            demo['input'],
            prompt_data['separators']['input'],
            prompt_data['prefixes']['output'],
            demo['output'],
            prompt_data['separators']['output'],
        ))

    return ''.join(pieces)
    
def create_prompt(prompt_data, sentence=None) -> str:
    """Builds a complete ICL prompt: the few-shot primer plus an open-ended query.

    Parameters:
    prompt_data: dict containing ICL prompt examples, and template information
    sentence: the query string (or a list whose first element is the query); when omitted,
              the 'input' of prompt_data['query_target'] is used if one is stored

    Returns:
    prompt: the prompt string, ending with the output prefix so the model completes the answer
    """
    # Fall back to the query stored alongside the prompt template
    if sentence is None:
        stored_query = prompt_data['query_target']
        if stored_query is not None:
            sentence = stored_query['input']

    # A list-valued query contributes only its first entry
    if isinstance(sentence, list):
        sentence = sentence[0]

    primer = create_fewshot_primer(prompt_data)
    query_block = prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input']

    return primer + query_block + prompt_data['prefixes']['output']

# Partial primer & prompt functions
# def create_partial_fewshot_primer(prompt_data, include = np.arange(8)) -> str:
#     """Creates the primer string for GPT in-context learning, filtering to include a subset of specified priming strings
    
#     Parameters:
#     prompt_data: dict containing ICL prompt examples, and template information
#     include: an iterable of ints indicating which examples to include in the ICL prompt
    
#     Returns:
#     prompt: the constructed ICL prompt primer as a string
#     """
#     prompt = ''
#     prompt += prompt_data['prefixes']['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions']

#     # Grab each priming example in the specified order.
#     for i in include:
#         example = prompt_data['examples'][i]
#         prompt += prompt_data['prefixes']['input'] + example['input'] + prompt_data['separators']['input']
#         prompt += prompt_data['prefixes']['output'] + example['output'] + prompt_data['separators']['output']
        
#     return prompt

# def create_partial_prompt(prompt_data, sentence=None, include=np.arange(8)) -> str:
#     """Creates a prompt using the specified sentence and partial list of in-context primer sentences
    
#     Parameters:
#     prompt_data: dict containing ICL prompt examples, and template information
#     sentence: a query string (sentence/word) to include in the ICL prompt
#     include: an iterable of ints indicating which examples to include in the ICL prompt
    
#     Returns:
#     prompt: the prompt as a string
#     """
#     if sentence is None and prompt_data['query_target'] is not None:
#         sentence = prompt_data['query_target']['input']
#     if isinstance(sentence, list):
#         sentence = sentence[0]
        
#     prompt_init = create_partial_fewshot_primer(prompt_data, include)
    
#     prompt = prompt_init + prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input']
#     prompt += prompt_data['prefixes']['output']
    
#     return prompt


# UTILS FOR GENERATING PROMPT META LABELS
def get_prompt_parts_and_labels(prompt_data, query_sentence=None):
    """
    Splits an ICL prompt into its template pieces and pairs each piece with a role label
    (instructions, demonstration, label, separator, structural, query variants).
    The two returned lists share the same nesting: the three instruction pieces are flat
    entries, each demonstration is a 6-element sub-list, and the query is a 4-element sub-list.
    Used in conjunction with tokenize_labels

    Parameters:
    prompt_data: dict with 'instructions', 'examples' of ('input', 'output') pairs,
                 'prefixes' and 'separators' (keyed by 'input', 'output', 'instructions'),
                 and 'query_target'
    query_sentence: optional query str (or list whose first element is the query);
                    defaults to prompt_data['query_target']['input'] when available

    Returns:
    prompt_parts: structured list of words to be flattened and tokenized
    prompt_part_labels: structured list of labels to be flattened & extended over tokenization
    """
    if query_sentence is None and prompt_data['query_target'] is not None:
        query_sentence = prompt_data['query_target']['input']
    if isinstance(query_sentence, list):
        query_sentence = query_sentence[0]

    prefixes = prompt_data['prefixes']
    separators = prompt_data['separators']

    # Instruction header pieces stay flat (not wrapped in a sub-list)
    prompt_parts = [prefixes['instructions'], prompt_data['instructions'], separators['instructions']]
    prompt_part_labels = ['bos_token', 'instructions_token', 'separator_token']

    # One 6-piece group per demonstration, numbered from 1
    for number, demo in enumerate(prompt_data['examples'], start=1):
        prompt_parts.append([
            prefixes['input'], demo['input'], separators['input'],
            prefixes['output'], demo['output'], separators['output'],
        ])
        prompt_part_labels.append([
            'structural_token', f'demonstration_{number}_token', 'separator_token',
            'structural_token', f'demonstration_{number}_label_token', 'separator_token',
        ])

    # The query group stops at the output prefix (no answer, no trailing separator)
    prompt_parts.append([prefixes['input'], query_sentence, separators['input'], prefixes['output']])
    prompt_part_labels.append(['query_structural_token', 'query_demonstration_token',
                               'query_separator_token', 'query_structural_token'])

    return prompt_parts, prompt_part_labels

def extend_labels(sentence_parts, text_labels, tokenizer, label_init=None):
    """
    Extends ICL component labels across words that are tokenized into multiple tokens

    The prompt is rebuilt piece by piece; after each piece is appended the growth in token
    count determines how many copies of that piece's label are emitted.

    Parameters:
    sentence_parts: list, where each element is either a token (str), phrase (str), or list of tokens/phrases
    text_labels: list with the same structure as 'sentence_parts', with a corresponding label for that level of the input sentence.
    tokenizer: huggingface tokenizer (only its `tokenize(str)` method is used)
    label_init: optional list of labels to start from (e.g. ['bos_token']). It is copied,
                never mutated, so neither the caller's list nor a shared default leaks
                labels from one call into the next.

    Returns:
    final_labels: flattened/extended list of token labels for an ICL prompt (split into parts, contained in sentence_parts and text_labels)
    """
    # Work on a private copy: the previous mutable default accumulated labels across calls
    final_labels = list(label_init) if label_init is not None else []

    # Normalise every entry to a list of (word, label) pairs
    zipped_up = [list(zip(part, label)) if isinstance(part, list) else [(part, label)]
                 for part, label in zip(sentence_parts, text_labels)]

    prompt_builder = ''
    # Running token count of the prompt built so far (avoids re-tokenizing it twice per piece)
    n_tokens = len(tokenizer.tokenize(prompt_builder))

    for element in zipped_up:
        for j, (word, label) in enumerate(element):
            if len(word) == 0:
                continue

            prompt_builder += word
            n_tokens_after = len(tokenizer.tokenize(prompt_builder))
            actual_tokens = n_tokens_after - n_tokens
            n_tokens = n_tokens_after

            final_labels.extend([label] * actual_tokens)

            # Position 3 is the output prefix: its last token is the one that predicts the answer.
            # When the output prefix is an empty string it is skipped above, so the input
            # separator (position 2) takes over that role. The check looks at the word itself
            # (element[3][0]), not at the (word, label) tuple whose length is always 2.
            output_prefix_empty = j == 2 and len(element) > 3 and len(element[3][0]) == 0
            if (j == 3 or output_prefix_empty) and final_labels:
                final_labels[-1] = final_labels[-1].replace('structural', 'predictive').replace('separator', 'predictive')

            # Position 5 is the output separator: its first token marks the end of the example.
            # Guarded because a non-positive count would index from the front of the list.
            if j == 5 and actual_tokens > 0:
                final_labels[-actual_tokens] = final_labels[-actual_tokens].replace('separator', 'end_of_example')

    return final_labels

def tokenize_labels(sentence_parts, text_labels, tokenizer, prepend_bos=False):
    """
    Spreads phrase-level labels over the tokens of an in-context learning prompt,
    optionally starting the label sequence with a 'bos_token' entry.
    Tested with GPT-2's tokenizer from huggingface.

    Parameters:
    sentence_parts: list, where each element is either a token (str), phrase (str), or list of tokens/phrases
    text_labels: list with the same structure as 'sentence_parts', with a corresponding label for that level of the input sentence.
    tokenizer: huggingface tokenizer
    prepend_bos: if True, the returned labels begin with 'bos_token'

    Returns:
    labels: flattened/extended list of token labels for an ICL prompt (split into parts, contained in sentence_parts and text_labels)

    based on the tokenize_and_preserve_labels function from:
    https://www.depends-on-the-definition.com/named-entity-recognition-with-bert/
    """
    # If the model typically prepends a bos, its label is seeded here
    initial_labels = ['bos_token'] if prepend_bos else []
    return extend_labels(sentence_parts, text_labels, tokenizer, label_init=initial_labels)

def get_token_meta_labels(prompt_data, tokenizer, query=None, prepend_bos=False):
    """
    Labels every token of a prompt with its ICL role.

    Parameters:
    prompt_data: dict containing ICL prompt examples, and template information
    tokenizer: huggingface tokenizer
    query: str of the query input (or a list whose first element is the query);
           defaults to prompt_data['query_target']['input'] when available
    prepend_bos: forwarded to tokenize_labels; if True the label list starts with 'bos_token'

    Return:
    token_labels: list of tuples (prompt token index, token, label); built with zip, so it
                  is as long as the shorter of the token list and the label list
    prompt_string: full prompt as a string
    """
    if query is None:
        stored_query = prompt_data['query_target']
        if stored_query is not None:
            query = stored_query['input']
    if isinstance(query, list):
        query = query[0]

    # Role labels, one per token of the pieced-together prompt
    parts, part_labels = get_prompt_parts_and_labels(prompt_data, query_sentence=query)
    token_meta_labels = tokenize_labels(parts, part_labels, tokenizer, prepend_bos)

    # Decoded tokens of the full prompt string
    prompt_string = create_prompt(prompt_data=prompt_data, sentence=query)
    token_ids = tokenizer(prompt_string).input_ids
    tokens = [tokenizer.decode(token_id) for token_id in token_ids]

    positions = np.arange(len(tokens))
    token_labels = list(zip(positions, tokens, token_meta_labels))

    return token_labels, prompt_string

def get_dummy_token_labels(n_icl_examples, tokenizer, model_config, prefixes=None, separators=None):
    """
    Computes the ground-truth meta labels & indices for an ICL prompt with the specified number of example pairs
    These GT labels assume each word gets a single token (every input and output is the dummy word 'a')

    Parameters:
    n_icl_examples: number of ICL example pairs
    tokenizer: huggingface tokenizer
    model_config: dict; its 'prepend_bos' entry says whether the model prepends a BOS token itself
    prefixes: ICL template prefixes (only used when separators is given too)
    separators: ICL template separators (only used when prefixes is given too)

    Return:
    final_token_labels: list of tuples containing a token's index and label name [(int, str), ... ]
    """
    model_prepends_bos = model_config['prepend_bos']
    # If the model already prepends a bos token by default, we don't want to add one to our prompts
    add_bos_to_prompt = not model_prepends_bos

    # A custom template is applied only when both halves of it are supplied
    template_kwargs = {}
    if prefixes is not None and separators is not None:
        template_kwargs = {'prefixes': prefixes, 'separators': separators}

    dummy_words = ['a'] * n_icl_examples
    dummy_prompt_data = word_pairs_to_prompt_data({'input': dummy_words, 'output': ['a'] * n_icl_examples},
                                                  query_target_pair={'input': ['a'], 'output': ['a']},
                                                  prepend_bos_token=add_bos_to_prompt,
                                                  **template_kwargs)

    labelled_tokens, _ = get_token_meta_labels(dummy_prompt_data, tokenizer, prepend_bos=model_prepends_bos)
    # Keep only (index, label); the decoded token text is dropped
    return [(entry[0], entry[-1]) for entry in labelled_tokens]

def compute_duplicated_labels(token_labels, gt_labels):
    """
    Computes a map between duplicated labels and ground truth label positions for localized averaging

    Parameters:
    token_labels: (index, token, label) tuples for the actual prompt being used
    gt_labels: token labels for a "ground truth" prompt that assumes each input & output is a single token;
               only the first element of each entry is read

    Returns:
    index_map: a dict mapping prompt label indices to ground truth label indices
    dup_label_ranges: dict from label to (first index, last index) for labels spanning more than one token
    """
    # Only labels containing 'demo' are candidates for spanning several tokens
    demo_rows = [entry for entry in token_labels if 'demo' in entry[2]]
    demo_frame = pd.DataFrame(demo_rows)

    # (min index, max index) per label; a label is duplicated when that span is non-empty
    spans = demo_frame.groupby(2)[0].aggregate(lambda idx: (idx.min(), idx.max()))
    multi_token_labels = [label for label, span in spans.items() if (span[1] - span[0]) > 0]
    dup_label_ranges = spans[multi_token_labels].to_dict()

    # Indices of every repeat occurrence of a label (the first occurrence is kept)
    dup_inds = demo_frame[demo_frame[2].duplicated()][0].values
    kept_positions = [entry[0] for entry in token_labels if entry[0] not in dup_inds]
    index_map = {position: gt_entry[0] for position, gt_entry in zip(kept_positions, gt_labels)}

    return index_map, dup_label_ranges

def update_idx_map(idx_map, idx_avg) -> dict:
    """
    Returns a copy of idx_map in which every position inside a duplicated-label range
    that has no entry yet is mapped to the same target as the start of its range.

    Parameters:
    idx_map: dict mapping prompt token positions to ground-truth positions
    idx_avg: dict whose values are inclusive (start, end) position ranges

    Returns:
    a new dict; idx_map itself is left untouched
    """
    filled_in = {
        position: idx_map[start]
        for start, end in idx_avg.values()
        for position in range(start, end + 1)
        if position not in idx_map
    }
    return {**idx_map, **filled_in}


def word_pairs_to_prompt_data(word_pairs : dict,
                              instructions: str = "",
                              prefixes: dict = None,
                              separators: dict = None,
                              query_target_pair: dict = None, prepend_bos_token=False,
                              shuffle_labels=False, prepend_space=True) -> dict:
    """Takes a dataset of word pairs, and constructs a prompt_data dict with additional information to construct an ICL prompt.

    Parameters:
    word_pairs: dict of the form {'word1':['a', 'b', ...], 'word2':['c', 'd', ...]}; the first
                value holds the inputs and the second the outputs
    instructions: prefix instructions for an ICL prompt
    prefixes: dict of ICL prefixes that are prepended to inputs, outputs and instructions.
              Defaults to {"input":"Q:", "output":"A:", "instructions":""}
    separators: dict of ICL separators that are appended to inputs, outputs and instructions.
                Defaults to {"input":"\\n", "output":"\\n\\n", "instructions":""}
    query_target_pair: dict with a single input-output pair acting as the query for the prompt
                       (list values are reduced to their first element)
    prepend_bos_token: whether or not to prepend a BOS token ('<|endoftext|>') to the instructions prefix
    shuffle_labels: whether to shuffle the ICL labels (outputs only; uses numpy's global RNG)
    prepend_space: whether to prepend a space to every input and output token; values are
                   converted with str() first, so non-string words are accepted

    Returns:
    prompt_data: dict containing ICL prompt examples, and template information. The
                 'prefixes' and 'separators' entries are fresh copies, so editing the result
                 never affects the caller's dicts or the defaults of later calls.
    """
    # Build the defaults per call: shared mutable defaults would be handed out by reference
    if prefixes is None:
        prefixes = {"input": "Q:", "output": "A:", "instructions": ""}
    if separators is None:
        separators = {"input": "\n", "output": "\n\n", "instructions": ""}

    if prepend_bos_token:
        prefixes = {k: ('<|endoftext|>' + v if k == 'instructions' else v) for (k, v) in prefixes.items()}
    else:
        prefixes = dict(prefixes)

    prompt_data = {}
    prompt_data['instructions'] = instructions
    prompt_data['separators'] = dict(separators)
    prompt_data['prefixes'] = prefixes

    if query_target_pair is not None:
        query_target_pair = {k: (v[0] if isinstance(v, list) else v) for k, v in query_target_pair.items()}

    columns = list(word_pairs.values())
    if shuffle_labels:
        # Shuffle labels only: permute the second column, leave the inputs in place
        columns = [np.random.permutation(col).tolist() if i == 1 else col for (i, col) in enumerate(columns)]

    if prepend_space:
        examples = [{'input': ' ' + str(w1), 'output': ' ' + str(w2)} for (w1, w2) in zip(*columns)]
        if query_target_pair is not None:
            query_target_pair = {k: ' ' + str(v) for k, v in query_target_pair.items()}
    else:
        examples = [{'input': w1, 'output': w2} for (w1, w2) in zip(*columns)]

    prompt_data['query_target'] = query_target_pair
    prompt_data['examples'] = examples

    return prompt_data


# DATASET UTILS
class ICLDataset:
    """
    A simple dataset class containing input-output pairs, used for ICL prompt construction.

    The pairs live in `raw_data`, a pandas DataFrame restricted to the 'input' and 'output'
    columns with duplicate rows removed. `name` is stored as given.
    """
    def __init__(self, dataset, name):
        """
        Parameters:
        dataset: path to a JSON file (str or os.PathLike), a dict of columns, or a DataFrame;
                 must provide 'input' and 'output' columns
        name: identifier stored on the dataset

        Raises:
        TypeError: if `dataset` is none of the supported types
        """
        if isinstance(dataset, (str, os.PathLike)):
            frame = pd.read_json(dataset)
        elif isinstance(dataset, dict):
            frame = pd.DataFrame(dataset)
        elif isinstance(dataset, pd.DataFrame):
            frame = dataset
        else:
            # Previously this fell through and failed later with an unrelated AttributeError
            raise TypeError(f"Unsupported dataset type {type(dataset).__name__}. Expected one of: [str, os.PathLike, dict, pd.DataFrame]")
        self.raw_data = frame[['input', 'output']].drop_duplicates()
        self.name = name

    def __getitem__(self,i):
        """
        Positional or column access.

        int / numpy integer -> one pair as {'input': ..., 'output': ...}
        slice / list / np.ndarray -> several pairs as {'input': [...], 'output': [...]}
        str -> the named column as a list (KeyError if the column does not exist)
        anything else -> ValueError
        """
        if isinstance(i, (int, np.integer)):
            return self.raw_data.iloc[i].to_dict()
        if isinstance(i, (slice, list, np.ndarray)):
            return self.raw_data.iloc[i].to_dict(orient='list')
        if isinstance(i, str):
            if i not in self.raw_data.columns:
                raise KeyError(f"Column '{i}' not in the dataset. Current columns in the dataset: {self.raw_data.columns.to_list()}")
            return self.raw_data[i].to_list()
        raise ValueError(f"{i} is not a valid index type. Expected one of: [int, list, np.ndarray, slice, str]")

    def __len__(self):
        return len(self.raw_data)

    def add(self, new_data):
        """
        Appends the pairs of one ICLDataset, or of a list/tuple of ICLDatasets, in place.

        Both forms now behave the same way: the combined frame gets a fresh 0..n-1 index
        and duplicate pairs are dropped (the single-dataset form used to skip both steps).

        Raises:
        TypeError: if `new_data` is neither an ICLDataset nor a list/tuple of them
        """
        if isinstance(new_data, ICLDataset):
            new_data = [new_data]
        elif not isinstance(new_data, (list, tuple)):
            # Previously ignored silently, which hid data that was never merged
            raise TypeError(f"Cannot add {type(new_data).__name__} to an ICLDataset. Expected an ICLDataset or a list of ICLDatasets")
        extra_frames = [nd.raw_data for nd in new_data]
        self.raw_data = pd.concat([self.raw_data, *extra_frames], ignore_index=True).drop_duplicates()

    def __repr__(self):
        s = "ICLDataset" + "({\n\tfeatures: " + f"{self.raw_data.columns.to_list()},\n\tnum_rows: {self.__len__()}" + "\n})"
        return s
    
def split_icl_dataset(dataset, istype=False, optlim=8, gold=False, seed=42) -> Dict[str,ICLDataset]:
    """
    Uses scikit-learn's train_test_split to carve train, valid and test sets (and optionally
    query and gold sets) out of the provided dataset.

    Parameters:
    dataset: ICL dataset
    istype: when True, a 'query' subset is additionally split off from the train portion
    optlim: set number of word pairs for finetuning; values above 1 size the held-out
            portions from it, otherwise both splits use a fixed 0.3 fraction
    gold: whether to separate paradigm exemplars (the first 3 rows) into a separate set
    seed: seed used for splitting the data

    Returns:
    dict with keys 'train', 'query', 'valid', 'test', 'gold'; 'train', 'query' and 'gold'
    may be None
    """
    def _as_dataset(frame):
        # Wrap a DataFrame slice as an ICLDataset carrying the source dataset's name
        return ICLDataset(frame.to_dict(orient='list'), name=dataset.name)

    pool = dataset.raw_data
    golddata = None
    if gold:
        golddata = ICLDataset(pool.iloc[:3].to_dict(orient='list'), name=dataset.name)
        pool = pool.iloc[3:]

    n_rows = len(pool)
    # Rows held out from train: optlim for validation plus up to 5 more, leaving at least one row
    held_out = min(5 + optlim, n_rows - 1)
    if optlim > 1:
        held_out_fraction = held_out / n_rows
        valid_fraction = optlim / held_out
    else:
        held_out_fraction = 0.3
        valid_fraction = 0.3

    train, held_out_rows = train_test_split(pool, test_size=held_out_fraction, random_state=seed)
    test, valid = train_test_split(held_out_rows, test_size=valid_fraction, random_state=seed)

    query = None
    if istype:
        # At least 40% of train (or two rows' worth) goes to the query set
        query_fraction = max(0.4, 2 / len(train))
        if query_fraction >= 1:
            # Too few rows to split: everything becomes query and no train set remains
            query, train = train, None
        else:
            train, query = train_test_split(train, test_size=query_fraction, random_state=seed)

    train = _as_dataset(train) if train is not None else None
    query = _as_dataset(query) if query is not None else None
    valid = _as_dataset(valid)
    test = _as_dataset(test)

    return {'train':train, 'query':query, 'valid':valid, 'test':test, 'gold':golddata}

def load_dataset_comp(task_name: str,
                 root_data_dir: str = '../dataset_files',
                 test_size = 0.5,
                 seed=32
                ) -> Dict[str,ICLDataset]:
    """
    Loads a dataset with input/output pairs and splits it into a query part and a target part

    Parameters:
    task_name: the name of the task dataset (file '<task_name>.json')
    root_data_dir: the root directory where the data comes from
    test_size: fraction used in train/test split (the share assigned to 'target')
    seed: random seed for the split

    Return:
    dataset: dict with the 'query' and 'target' dataset splits
    """
    data_folders = ['abstractive', 'extractive', 'semantic', 'semantictype', 'BATS', 'BATSGroup']

    path = Path(root_data_dir)
    json_name = task_name+'.json'

    # Folders that actually contain the task file; exactly one must match
    d_group = [folder for folder in data_folders
               if os.path.exists(os.path.join(root_data_dir, folder, json_name))]

    assert len(d_group) == 1, f"Error! 'task_name'={task_name}.json must be uniquely contained in one of these directories:{data_folders}. Please check the root_data_dir"
    dataset_folder = d_group[0]

    d_path = os.path.join(path, dataset_folder, f'{task_name}.json')
    dataset = ICLDataset(d_path, task_name)

    query_rows, target_rows = train_test_split(dataset.raw_data, test_size=test_size, random_state=seed)
    return {
        'query': ICLDataset(query_rows.to_dict(orient='list'), name=dataset.name),
        'target': ICLDataset(target_rows.to_dict(orient='list'), name=dataset.name),
    }

def load_dataset(typename: str,
                 task_names: list,
                 optlim: int = 0,
                 root_data_dir: str = '../dataset_files',
                 istype = False,
                 gold = False,
                 seed=32
                ) -> Dict[str,ICLDataset]:
    """
    Loads one or more task datasets with input/output pairs and splits each of them
    with split_icl_dataset.

    Parameters:
    typename: not used inside this function
    task_names: list of task dataset names; each must exist as '<task_name>.json' in exactly
                one of the known data folders under root_data_dir
    optlim: set number of word pairs for finetuning (forwarded to split_icl_dataset)
    root_data_dir: the root directory where the data comes from
    istype: whether this is used as relation type (mainly for grouping general relation types).
            When True, the splits of all tasks are merged and pairs shared between splits are removed
    gold: whether to separate paradigm exemplars into separate set
    seed: seed forwarded to split_icl_dataset
    
    Return:
    dataset: dict with keys 'train', 'query', 'opt', 'test', 'gold'. The validation split is
             returned under 'opt'; 'query' and 'gold' are None when no task produced them.
             When istype is False only the first task's splits are returned.
    """

    data_folders = ['abstractive', 'extractive', 'SemEval', 'SemEvalType',
                    'BATS', 'BATSGroup', 'Google', 'MSR']

    path = Path(root_data_dir)
    # Per-split lists of ICLDataset objects, one entry per task that produced that split
    datasets = {'train':[], 'query':[], 'valid':[], 'test':[], 'gold':[]}
    for task_name in task_names:
        # (folder, exists?) for every known folder, then keep the folders holding the task file
        d_group_map = [(dataset_type, os.path.exists(os.path.join(root_data_dir, dataset_type, task_name+'.json'))) for dataset_type in data_folders]

        d_group = list(filter(lambda x: x[1], d_group_map))

        assert len(d_group) !=0 and len(d_group) == 1, f"Error! 'task_name'={task_name}.json must be uniquely contained in one of these directories:{data_folders}. Please check the root_data_dir"
        dataset_folder = d_group[0][0]
        
        d_path = os.path.join(path, dataset_folder, f'{task_name}.json')
        
        dataset = ICLDataset(d_path, task_name)
        dataset = split_icl_dataset(dataset, istype=istype, optlim=optlim, gold=gold, seed=seed)
        # train / query / gold can be None for a task and are then skipped
        if dataset['train'] is not None: datasets['train'].append(dataset['train'])
        if dataset['query'] is not None: datasets['query'].append(dataset['query'])
        if dataset['gold'] is not None: datasets['gold'].append(dataset['gold'])
        datasets['valid'].append(dataset['valid'])
        datasets['test'].append(dataset['test'])
    # The first task's splits are the base objects that the remaining tasks get merged into.
    # NOTE(review): indexing [0] raises IndexError if task_names is empty or if no task
    # produced a train split (split_icl_dataset can return train=None when istype is set).
    traindat, valdat, testdat = datasets['train'][0], datasets['valid'][0], datasets['test'][0]
    querydat = datasets['query'][0] if len(datasets['query']) > 0 else None
    golddat = datasets['gold'][0] if len(datasets['gold']) > 0 else None
    if istype:
        # Merge every other task's split into the first task's split (in place)
        traindat.add(datasets['train'][1:])
        if len(datasets['query']) > 1: querydat.add(datasets['query'][1:])
        valdat.add(datasets['valid'][1:])
        testdat.add(datasets['test'][1:])
        if len(datasets['gold']) > 1: golddat.add(datasets['gold'][1:])
        # (input, output) MultiIndexes are used to test pair membership across splits
        trind = traindat.raw_data.set_index(["input","output"]).index
        vind = valdat.raw_data.set_index(["input","output"]).index
        teind = testdat.raw_data.set_index(["input","output"]).index
        if golddat is not None:
            # Step 1: gold pairs are removed from test, valid, query and train
            gind = golddat.raw_data.set_index(["input","output"]).index
            testdat.raw_data = testdat.raw_data.loc[~teind.isin(gind)]
            teind = testdat.raw_data.set_index(["input","output"]).index
            valdat.raw_data = valdat.raw_data.loc[~vind.isin(gind)]
            vind = valdat.raw_data.set_index(["input","output"]).index
            if querydat is not None:
                qind = querydat.raw_data.set_index(["input","output"]).index
                querydat.raw_data = querydat.raw_data.loc[~qind.isin(gind)]
            traindat.raw_data = traindat.raw_data.loc[~trind.isin(gind)]
            trind = traindat.raw_data.set_index(["input","output"]).index
        # Step 2: test pairs are removed from valid
        valdat.raw_data = valdat.raw_data.loc[~vind.isin(teind)]
        vind = valdat.raw_data.set_index(["input","output"]).index
        if querydat is not None:
            # Step 3: test pairs, then valid pairs, are removed from query
            qind = querydat.raw_data.set_index(["input","output"]).index
            querydat.raw_data = querydat.raw_data.loc[~qind.isin(teind)]
            qind = querydat.raw_data.set_index(["input","output"]).index
            querydat.raw_data = querydat.raw_data.loc[~qind.isin(vind)]
        # Step 4: test pairs, then valid pairs, are removed from train
        traindat.raw_data = traindat.raw_data.loc[~trind.isin(teind)]
        trind = traindat.raw_data.set_index(["input","output"]).index
        traindat.raw_data = traindat.raw_data.loc[~trind.isin(vind)]
        if querydat is not None:
            # Step 5: query pairs are removed from train
            qind = querydat.raw_data.set_index(["input","output"]).index
            trind = traindat.raw_data.set_index(["input","output"]).index
            traindat.raw_data = traindat.raw_data.loc[~trind.isin(qind)]
            querydat.raw_data = querydat.raw_data.reset_index(drop=True)
        # Row filtering leaves gaps in the index; renumber from 0
        traindat.raw_data = traindat.raw_data.reset_index(drop=True)
        valdat.raw_data = valdat.raw_data.reset_index(drop=True)
        testdat.raw_data = testdat.raw_data.reset_index(drop=True)
    main_dataset = {'train':traindat, 'query':querydat,
                    'opt':valdat, 'test':testdat, 'gold':golddat}
    return main_dataset

def load_green(root_data_dir: str = '../dataset_files'):
    """
    Reads 'Green.xlsx' and 'Green_bartprob.xlsx' from root_data_dir.

    Parameters:
    root_data_dir: directory containing both spreadsheets

    Returns:
    gcon: {"green": list of {'input', 'output'} records} built from the first two columns
          (renamed from 'prompt_w1' / 'prompt_w2')
    source_gpairnames: 'input:output' strings for those context pairs
    gpairnames: 'input:output' strings for the 'prompt_w3' / 'human' pairs
    gdatas: per row, [human_pair, answer_pair], or just [human_pair] when both are equal
    unorms: torch tensor holding columns 2 onward of 'Green_bartprob.xlsx'
    """
    green_frame = pd.read_excel(os.path.join(root_data_dir,"Green.xlsx"))

    # Column positions: 0-1 context pair, 2 query word, 3 and 4 the two candidate outputs
    cons = green_frame.iloc[:,:2].rename(columns={"prompt_w1":"input", "prompt_w2":"output"})
    que1 = green_frame.iloc[:,[2,3]].rename(columns={"prompt_w3":"input", "human":"output"})
    que2 = green_frame.iloc[:,[2,4]].rename(columns={"prompt_w3":"input", "answer":"output"})

    gcon = {"green":cons.to_dict('records')}
    source_gpairnames = (cons["input"] + ':' + cons["output"]).to_list()
    gpairnames = (que1["input"] + ':' + que1["output"]).to_list()

    gdatas = []
    for human_pair, answer_pair in zip(que1.to_dict('records'), que2.to_dict('records')):
        if human_pair != answer_pair:
            gdatas.append([human_pair, answer_pair])
        else:
            gdatas.append([human_pair])

    bpdf = pd.read_excel(os.path.join(root_data_dir,"Green_bartprob.xlsx"))
    # NOTE: zero-padded names for 2-character column headers; built but neither used nor returned
    ids = {bp: '0' + bp for bp in bpdf.columns[2:] if len(bp) == 2}
    unorms = torch.tensor(bpdf.iloc[:, 2:].values)

    return gcon, source_gpairnames, gpairnames, gdatas, unorms

def load_se_benchmarks():
    """
    Returns the hard-coded benchmark word pairs, three per relation.

    Returns:
    all_benches: {"bench": list of {'input', 'output'} dicts} in relation order
    all_benchnames: 'input:output' string for each pair, same order
    all_benchrels: relation name for each pair, same order
    """
    # relation name -> (input, output) pairs
    pair_table = {
        "03A_Synonymity": [("car", "auto"), ("kid", "child"), ("big", "large")],
        "03F_Attribute-Similarity": [("rake", "fork"), ("word", "knife"), ("stairs", "ladder")],
        "03H_Change": [("discount", "price"), ("dim", "light"), ("raise", "salary")],
        "04B_Contrary": [("old", "young"), ("big", "small"), ("black", "white")],
        "04D_Directional": [("east", "west"), ("from", "back"), ("north", "south")],
        "04G_Pseudoantonym": [("right", "bad"), ("good", "wrong"), ("majority", "small")],
        "08A_Cause_Effect": [("joke", "laughter"), ("injury", "pain"), ("accident", "damage")],
        "08B_Cause_Compensatory-Action": [("hunter", "cat"), ("tiredness", "rest"), ("sadness", "cry")],
        "08D_Action-Activity_Goal": [("flee", "escape"), ("study", "learn"), ("work", "earn")],
        "01A_Taxonomic": [("weapon", "spear"), ("tree", "oak"), ("animal", "pig")],
        "01B_Functional": [("tool", "hammer"), ("utensil", "spoon"), ("instrument", "violin")],
        "01D_Plural-Collective": [("snacks", "chips"), ("cutlery", "forks"), ("furniture", "chairs")],
        "02C_Mass_Portion": [("hour", "seconds"), ("feet", "inches"), ("week", "day")],
        "02F_Item_Topological-Part": [("hotel", "lobby"), ("hill", "top"), ("airplane", "cockpit")],
        "02G_Object_Stuff": [("omelette", "eggs"), ("ocean", "water"), ("wall", "bricks")],
        "09B_Location_Process-Product": [("factory", "goods"), ("mill", "flour"), ("mine", "coal")],
        "09E_Contiguity": [("bank", "river"), ("shore", "lake"), ("ditch", "road")],
        "09G_Time_Associated-Item": [("childhood", "toys"), ("girlhood", "dolls"), ("infancy", "pacifier")],
    }

    pairs = []
    all_benchnames = []
    all_benchrels = []
    for relation, word_pairs in pair_table.items():
        for source, target in word_pairs:
            pairs.append({"input": source, "output": target})
            all_benchnames.append(f"{source}:{target}")
            all_benchrels.append(relation)

    all_benches = {"bench": pairs}

    return all_benches, all_benchnames, all_benchrels
