import numpy as np
import pandas as pd
from itertools import combinations, permutations
from pathlib import Path
import os
from typing import *
from sklearn.model_selection import train_test_split



def create_fewshot_primer(prompt_data):
    """
    Builds the few-shot primer string used for GPT in-context learning.

    The primer is the instruction header (prefix + instructions + separator)
    followed by every entry of prompt_data['examples'], each rendered as
    input prefix + input + input separator, then output prefix + output + output separator.

    Parameters:
    prompt_data: dict with 'prefixes', 'separators', 'instructions' and 'examples' entries

    Return:
    the primer as a single string
    """
    prefixes = prompt_data['prefixes']
    header = prefixes['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions']

    pieces = [header]
    for demo in prompt_data['examples']:
        pieces.append(prompt_data['prefixes']['input'] + demo['input'] + prompt_data['separators']['input'])
        pieces.append(prompt_data['prefixes']['output'] + demo['output'] + prompt_data['separators']['output'])

    return ''.join(pieces)
    
def create_prompt(prompt_data, sentence=None):
    """
    Builds the full ICL prompt: the few-shot primer followed by the query and an open output slot.

    Parameters:
    prompt_data: dict with 'prefixes', 'separators', 'instructions', 'examples' and 'query_target' entries
    sentence: the query text, or a list whose first element is the query. When None,
              prompt_data['query_target']['input'] is used if a query target is set.

    Return:
    the prompt string, ending with the output prefix so the model completes the answer
    """
    query = sentence
    if query is None and prompt_data['query_target'] is not None:
        query = prompt_data['query_target']['input']

    # Only the first element of a list-valued query is used.
    if isinstance(query, list):
        query = query[0]

    primer = create_fewshot_primer(prompt_data)
    prefixes = prompt_data['prefixes']
    return primer + prefixes['input'] + query + prompt_data['separators']['input'] + prefixes['output']

# Partial primer & prompt functions
def create_partial_fewshot_primer(prompt_data, include = np.arange(8)):
    """
    Builds the primer string for GPT in-context learning from a chosen subset of the priming examples.

    Parameters:
    prompt_data: dict with 'prefixes', 'separators', 'instructions' and 'examples' entries
    include: indices into prompt_data['examples']; examples are emitted in exactly this order

    Return:
    the instruction header followed by the selected examples, as one string
    """
    prefixes = prompt_data['prefixes']
    header = prefixes['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions']

    pieces = [header]
    # The order of `include` decides the order of the demonstrations in the primer.
    for idx in include:
        demo = prompt_data['examples'][idx]
        pieces.append(prompt_data['prefixes']['input'] + demo['input'] + prompt_data['separators']['input'])
        pieces.append(prompt_data['prefixes']['output'] + demo['output'] + prompt_data['separators']['output'])

    return ''.join(pieces)

def create_partial_prompt(prompt_data, sentence=None, include=np.arange(8)):
    """
    Creates a prompt using the specified sentence and a partial list of in-context primer sentences.

    Parameters:
    prompt_data: dict with 'prefixes', 'separators', 'instructions', 'examples' and 'query_target' entries
    sentence: the query text, or a list whose first element is the query. When None,
              prompt_data['query_target']['input'] is used if a query target is set.
    include: indices into prompt_data['examples'], in the order they should appear in the primer

    Return:
    the prompt string: partial primer + query, ending with the output prefix
    """
    if sentence is None and prompt_data['query_target'] is not None:
        sentence = prompt_data['query_target']['input']
    if isinstance(sentence, list):
        sentence = sentence[0]

    # The helper's signature is (prompt_data, include); pass `include` by keyword
    # so the two arguments cannot be swapped.
    prompt_init = create_partial_fewshot_primer(prompt_data, include=include)

    prompt = prompt_init + prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input']
    prompt += prompt_data['prefixes']['output']

    return prompt


# UTILS FOR GENERATING PROMPT META LABELS
def get_prompt_parts_and_labels(prompt_data, query_sentence=None):
    """
    Splits an ICL prompt into its constituent pieces and assigns each piece a high-level role label
    (demonstration, label, separator, structural, ...).

    The prompt format should include 'instructions', 'examples' with ('input', 'output') pairs,
    and 'prefixes' / 'separators' for 'input', 'output', and 'instructions'.
    Used in conjunction with tokenize_labels.

    Parameters:
    prompt_data: the prompt dict described above (also read for 'query_target')
    query_sentence: query text, or a list whose first element is the query. When None,
                    prompt_data['query_target']['input'] is used if a query target is set.

    Return:
    prompt_parts: three instruction strings, then one 6-element list per example, then one 4-element list for the query
    prompt_part_labels: labels with the same nesting as prompt_parts
    """
    if query_sentence is None and prompt_data['query_target'] is not None:
        query_sentence = prompt_data['query_target']['input']
    if isinstance(query_sentence, list):
        query_sentence = query_sentence[0]

    prefixes = prompt_data['prefixes']
    separators = prompt_data['separators']

    # Instruction header: three flat entries.
    prompt_parts = [prefixes['instructions'], prompt_data['instructions'], separators['instructions']]
    prompt_part_labels = ['bos_token', 'instructions_token', 'separator_token']

    # One nested entry per demonstration; demonstrations are numbered from 1.
    for number, demo in enumerate(prompt_data['examples'], start=1):
        prompt_parts.append([
            prefixes['input'], demo['input'], separators['input'],
            prefixes['output'], demo['output'], separators['output'],
        ])
        prompt_part_labels.append([
            'structural_token', f'demonstration_{number}_token', 'separator_token',
            'structural_token', f'demonstration_{number}_label_token', 'separator_token',
        ])

    # The query has no output text or output separator, hence only four pieces.
    prompt_parts.append([prefixes['input'], query_sentence, separators['input'], prefixes['output']])
    prompt_part_labels.append(['query_structural_token', 'query_demonstration_token', 'query_separator_token', 'query_structural_token'])

    return prompt_parts, prompt_part_labels

def extend_labels(sentence_parts, text_labels, tokenizer):
    """
    Expands part-level labels into one label per token.

    Each part is appended to a running prompt string. The number of tokens the running
    prompt grew by (`actual_tokens`) is compared with the number of tokens the part has
    when tokenized on its own (`n_tokens`). When the prompt grew by more, the surplus
    tokens receive the previously emitted label (or 'separator_token' if the previous
    label contains 'end_of_example'); the part's own label is then repeated `n_tokens` times.

    For nested parts (lists), two positions are relabelled after the part is added:
    - index 3: the last emitted label has 'structural' replaced by 'predictive'
    - index 5: the first label of that part has 'separator' replaced by 'end_of_example'
    In the lists built by get_prompt_parts_and_labels these are the output prefix and
    the output separator respectively.

    Parameters:
    sentence_parts: list whose elements are strings or lists of strings
    text_labels: labels with the same nesting as sentence_parts
    tokenizer: callable such that tokenizer(text, return_length=True).length[0] is the token count

    Return:
    final_labels: flat list of labels, one per token
    """
    prompt_builder = ''
    final_labels = []
    for word, label in zip(sentence_parts, text_labels):

        if isinstance(word, list):
            for j, (word2, label2) in enumerate(zip(word, label)):
                if len(word2) == 0:
                    continue
                pre = tokenizer(prompt_builder, return_length=True).length[0]
                prompt_builder += word2
                post = tokenizer(prompt_builder, return_length=True).length[0]
                n_tokens = tokenizer(word2, return_length=True).length[0]
                actual_tokens = post - pre
                if n_tokens < actual_tokens:
                    # Surplus tokens inherit the previous label; an end-of-example marker
                    # is not propagated, the surplus falls back to a plain separator.
                    if 'end_of_example' in final_labels[-1]:
                        final_labels.extend(['separator_token'] * (actual_tokens - n_tokens))
                    else:
                        final_labels.extend([final_labels[-1]] * (actual_tokens - n_tokens))
                final_labels.extend([label2] * n_tokens)
                if j == 3:
                    final_labels[-1] = final_labels[-1].replace('structural', 'predictive')
                if j == 5:
                    final_labels[-n_tokens] = final_labels[-n_tokens].replace('separator', 'end_of_example')
        else:
            if len(word) == 0:
                continue
            pre = tokenizer(prompt_builder, return_length=True).length[0]
            prompt_builder += word
            post = tokenizer(prompt_builder, return_length=True).length[0]
            n_tokens = tokenizer(word, return_length=True).length[0]
            actual_tokens = post - pre
            if n_tokens < actual_tokens:
                # One label per surplus token. Multiplying the label *string* would
                # append a single fused string and leave the label list too short.
                final_labels.extend([final_labels[-1]] * (actual_tokens - n_tokens))
            final_labels.extend([label] * n_tokens)

    return final_labels

def extend_labels_llama(sentence_parts, text_labels, tokenizer):
    """
    Expands part-level labels into one label per token, for tokenizers handled by the llama path.

    Differences from extend_labels that are visible in this function:
    - the output starts with a 'bos_token' label
    - tokenizer(...).length is used directly as a number (not indexed)
    - a part's standalone token count is tokenizer(part).length - 1, or, for nested parts
      starting with a space, len(tokenizer.tokenize(part without that first space))
      (presumably the -1 discounts a BOS token added by the tokenizer -- TODO confirm)
    - for nested parts, a standalone count larger than the observed growth is clamped to the growth

    When the running prompt grows by more tokens than the part's standalone count, the surplus
    tokens receive the previous label (or 'separator_token' if it contains 'end_of_example').
    In nested parts, if the running prompt starts with a space a single copy of the current
    label is appended instead.

    For nested parts (lists), index 3 has 'structural' replaced by 'predictive' on the last
    emitted label, and index 5 has 'separator' replaced by 'end_of_example' on the first label
    of that part.

    Parameters:
    sentence_parts: list whose elements are strings or lists of strings
    text_labels: labels with the same nesting as sentence_parts
    tokenizer: callable returning an object with a numeric `.length`, and providing `.tokenize(text)`

    Return:
    final_labels: flat list of labels, beginning with 'bos_token'
    """
    prompt_builder = ''
    final_labels = ['bos_token']
    for word, label in zip(sentence_parts, text_labels):

        if isinstance(word, list):
            for j, (word2, label2) in enumerate(zip(word, label)):
                if len(word2) == 0:
                    continue
                pre = tokenizer(prompt_builder, return_length=True).length
                prompt_builder += word2
                post = tokenizer(prompt_builder, return_length=True).length
                if word2.startswith(' '):
                    n_tokens = len(tokenizer.tokenize(word2.replace(" ", "", 1)))
                else:
                    n_tokens = tokenizer(word2, return_length=True).length - 1
                actual_tokens = post - pre
                if n_tokens < actual_tokens:
                    if prompt_builder.startswith(' '):
                        final_labels.append(label2)
                    elif 'end_of_example' in final_labels[-1]:
                        final_labels.extend(['separator_token'] * (actual_tokens - n_tokens))
                    else:
                        final_labels.extend([final_labels[-1]] * (actual_tokens - n_tokens))
                elif n_tokens > actual_tokens:
                    # Never emit more labels than tokens actually added to the prompt.
                    n_tokens = actual_tokens

                final_labels.extend([label2] * n_tokens)
                if j == 3:
                    final_labels[-1] = final_labels[-1].replace('structural', 'predictive')
                if j == 5:
                    final_labels[-n_tokens] = final_labels[-n_tokens].replace('separator', 'end_of_example')

        else:
            if len(word) == 0:
                continue
            pre = tokenizer(prompt_builder, return_length=True).length
            prompt_builder += word
            post = tokenizer(prompt_builder, return_length=True).length
            n_tokens = tokenizer(word, return_length=True).length - 1
            actual_tokens = post - pre
            if n_tokens < actual_tokens:
                # One label per surplus token. Multiplying the label *string* would
                # append a single fused string and leave the label list too short.
                final_labels.extend([final_labels[-1]] * (actual_tokens - n_tokens))
            final_labels.extend([label] * n_tokens)

    return final_labels

def tokenize_labels(sentence_parts, text_labels, tokenizer):
    """
    Extends phrase-level labels across tokenization for in-context learning prompts. Tested with GPT-2's tokenizer from huggingface.

    Dispatches to extend_labels_llama when 'llama' appears in tokenizer.name_or_path,
    and to extend_labels otherwise.

    Params:
    - 'sentence_parts': a list where each element is either a token (str), phrase (str), or list of tokens/phrases
    - 'text_labels': same structure as 'sentence_parts', with a corresponding label for each element
    - 'tokenizer': the model's tokenizer (must expose 'name_or_path')

    based on the tokenize_and_preserve_labels function from:
    https://www.depends-on-the-definition.com/named-entity-recognition-with-bert/
    """
    label_extender = extend_labels_llama if 'llama' in tokenizer.name_or_path else extend_labels
    return label_extender(sentence_parts, text_labels, tokenizer)

def get_token_meta_labels(prompt_data, tokenizer, query=None):
    """
    Computes the ICL meta-labels for every token in a prompt.

    Parameters:
    prompt_data: prompt dict (read for 'query_target' here and passed on to the prompt builders)
    tokenizer: the model's tokenizer
    query: str, or a list whose first element is the query. When None,
           prompt_data['query_target']['input'] is used if a query target is set.

    Return:
    token_labels: list of tuples (prompt token index, token, label)
    prompt_string: full prompt as a string
    """
    if query is None and prompt_data['query_target'] is not None:
        query = prompt_data['query_target']['input']
    if isinstance(query, list):
        query = query[0]

    # Per-token role labels, derived from the labelled prompt pieces.
    parts, part_labels = get_prompt_parts_and_labels(prompt_data, query_sentence=query)
    meta_labels = tokenize_labels(parts, part_labels, tokenizer)

    # Decode each token id individually to get the token strings of the full prompt.
    prompt_string = create_prompt(prompt_data=prompt_data, sentence=query)
    token_ids = tokenizer(prompt_string).input_ids
    tokens = [tokenizer.decode(token_id) for token_id in token_ids]

    positions = np.arange(len(tokens))
    token_labels = list(zip(positions, tokens, meta_labels))
    return token_labels, prompt_string

def get_dummy_token_labels(n_icl_examples, tokenizer, prefixes=None, separators=None):
    """
    Computes the ground-truth meta labels & indices for an ICL prompt with the specified number of example pairs.

    A placeholder prompt is built in which every input, output and the query are the word 'a'.

    Parameters:
    n_icl_examples: number of ICL example pairs
    tokenizer: huggingface model tokenizer
    prefixes, separators: custom prompt template; only used when BOTH are given,
                          otherwise the defaults of word_pairs_to_prompt_data apply

    Return:
    final_token_labels: list of tuples containing a token's index and label name [(int, str), ... ]
    """
    # The BOS token is prepended for every tokenizer except the llama ones.
    uses_llama = 'llama' in tokenizer.name_or_path

    placeholder_pairs = {'input': ['a']*n_icl_examples, 'output':['a']*n_icl_examples}
    builder_kwargs = {
        'query_target_pair': {'input':['a'], 'output':['a']},
        'prepend_bos_token': not uses_llama,
    }
    if prefixes is not None and separators is not None:
        builder_kwargs['prefixes'] = prefixes
        builder_kwargs['separators'] = separators

    dummy_prompt_data = word_pairs_to_prompt_data(placeholder_pairs, **builder_kwargs)
    labelled_tokens, _ = get_token_meta_labels(dummy_prompt_data, tokenizer)

    # Keep only (index, label); drop the token text in between.
    return [(entry[0], entry[-1]) for entry in labelled_tokens]

def compute_duplicated_labels(token_labels, gt_labels):
    """
    Computes a map between duplicated labels and ground truth label positions for localized averaging.

    Only entries whose label contains 'demo' are checked for duplication.

    Parameters:
    token_labels: sequence of (token index, token, label) tuples
    gt_labels: sequence of tuples whose first element is the ground-truth position

    Returns:
    index_map: dict mapping each token index that is not a repeated 'demo' label occurrence
               to the first element of the gt_labels entry at the same rank
    dup_label_ranges: dict mapping each 'demo' label that spans more than one index
                      to its (min index, max index)
    """
    demo_entries = [entry for entry in token_labels if 'demo' in entry[2]]
    # Columns: 0 = token index, 1 = token, 2 = label.
    demo_frame = pd.DataFrame(demo_entries)

    dup_ranges = demo_frame.groupby(2)[0].aggregate(lambda idx: (idx.min(), idx.max()))
    dup_labels = [name for name, bounds in dup_ranges.items() if (bounds[1] - bounds[0]) > 0]
    dup_label_ranges = dup_ranges[dup_labels].to_dict()

    # Indices of every occurrence of a label after its first one.
    dup_inds = demo_frame[demo_frame[2].duplicated()][0].values

    kept_positions = [entry[0] for entry in token_labels if entry[0] not in dup_inds]
    index_map = {position: gt[0] for position, gt in zip(kept_positions, gt_labels)}

    return index_map, dup_label_ranges

def update_idx_map(idx_map, idx_avg):
    """
    Extends idx_map so that duplicate token positions map to their gt token position.

    For every (start, end) range in idx_avg.values(), each position in start..end (inclusive)
    that is missing from idx_map is mapped to idx_map[start]. Existing entries are never
    overwritten and idx_map itself is left untouched.

    Return:
    a new dict containing idx_map's entries plus the added ones
    """
    merged = dict(idx_map)
    for start, end in idx_avg.values():
        for position in range(start, end + 1):
            if position in idx_map:
                continue
            merged[position] = idx_map[start]
    return merged


def word_pairs_to_prompt_data(word_pairs : dict,
                              instructions: str = "",
                              prefixes: dict = None,
                              separators: dict = None,
                              query_target_pair: dict = None, prepend_bos_token=False,
                              shuffle_labels=False, prepend_space=True):
    """Takes a dataset of word pairs, and constructs a prompt_data dict with additional information for ICL prompting.

    Parameters:
    word_pairs: dict with two equally long columns, e.g. {'input':['a', 'b', ...], 'output':['c', 'd', ...]};
                the first column is used as the inputs and the second as the outputs
    instructions: instruction text placed at the start of the prompt
    prefixes: strings placed before the 'input', 'output' and 'instructions' parts;
              defaults to {"input":"Q:", "output":"A:", "instructions":""}
    separators: strings placed after the 'input', 'output' and 'instructions' parts;
                defaults to {"input":"\\n", "output":"\\n\\n", "instructions":""}
    query_target_pair: optional {'input': ..., 'output': ...}; list values are reduced to their first element
    prepend_bos_token: if True, '<|endoftext|>' is prepended to the 'instructions' prefix
    shuffle_labels: if True, the second column is randomly permuted (np.random) so outputs no longer match inputs
    prepend_space: if True, a space is prepended to every example input/output and query target value,
                   and outputs / query target values are converted with str()

    Return:
    prompt_data: dict with keys 'instructions', 'separators', 'prefixes', 'query_target', 'examples'
    """
    if prefixes is None:
        prefixes = {"input":"Q:", "output":"A:","instructions":""}
    if separators is None:
        separators = {"input":"\n", "output":"\n\n", "instructions":""}

    prompt_data = {}
    prompt_data['instructions'] = instructions
    # Store copies so that editing the returned prompt_data never alters the
    # caller's template dicts (or, previously, the shared default dicts).
    prompt_data['separators'] = dict(separators)
    if prepend_bos_token:
        prompt_data['prefixes'] = {k:(v if k !='instructions' else '<|endoftext|>' + v) for (k,v) in prefixes.items()}
    else:
        prompt_data['prefixes'] = dict(prefixes)

    if query_target_pair is not None:
        query_target_pair = {k:(v[0] if isinstance(v, list) else v) for k,v in query_target_pair.items()}
        if prepend_space:
            query_target_pair = {k:' ' + str(v) for k,v in query_target_pair.items()}
    prompt_data['query_target'] = query_target_pair

    columns = list(word_pairs.values())
    if shuffle_labels:
        # Shuffle the labels (second column) only; inputs keep their order.
        columns = [np.random.permutation(col).tolist() if i==1 else col for (i,col) in enumerate(columns)]

    if prepend_space:
        # str() on the output keeps non-string labels (e.g. ints) working with and without shuffling.
        prompt_data['examples'] = [{'input':' ' + w1, 'output':' ' + str(w2)} for (w1,w2) in zip(*columns)]
    else:
        prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in zip(*columns)]

    return prompt_data


# DATASET UTILS
class ICLDataset:
    """
    Thin wrapper around a pandas DataFrame restricted to the 'input' and 'output' columns.

    Parameters:
    dataset: a path to a JSON file (str or os.PathLike, read with pd.read_json),
             a dict of columns, or a pandas DataFrame. Must provide 'input' and 'output' columns.

    Raises:
    TypeError: if dataset is none of the supported types
    KeyError: (from pandas) if the 'input' / 'output' columns are missing
    """
    def __init__(self, dataset):
        if isinstance(dataset, (str, os.PathLike)):
            raw_data = pd.read_json(dataset)
        elif isinstance(dataset, dict):
            raw_data = pd.DataFrame(dataset)
        elif isinstance(dataset, pd.DataFrame):
            raw_data = dataset
        else:
            # Fail here with a clear message instead of an AttributeError further down.
            raise TypeError(f"Unsupported dataset type {type(dataset).__name__}. Expected one of: [str, os.PathLike, dict, pd.DataFrame]")
        self.raw_data = raw_data[['input', 'output']]

    def __getitem__(self,i):
        """
        Indexes the dataset.

        - int (or NumPy integer): one row as {'input': ..., 'output': ...}
        - slice, list or np.ndarray: the selected rows as a dict of lists
        - str: the whole column with that name as a list
        """
        # NumPy integers (e.g. produced by np.arange or np.random.choice) index rows like ints.
        if isinstance(i, (int, np.integer)):
            return self.raw_data.iloc[i].to_dict()
        elif isinstance(i, (slice, list, np.ndarray)):
            return self.raw_data.iloc[i].to_dict(orient='list')
        elif isinstance(i, str):
            if i not in self.raw_data.columns:
                raise KeyError(f"Column '{i}' not in the dataset. Current columns in the dataset: {self.raw_data.columns.to_list()}")
            return self.raw_data[i].to_list()
        else:
            raise ValueError(f"{i} is not a valid index type. Expected one of: [int, list, np.ndarray, slice, str]")

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.raw_data)

    def __repr__(self):
        """Multi-line summary listing the column names and the number of rows."""
        s = "ICLDataset" + "({\n\tfeatures: " + f"{self.raw_data.columns.to_list()},\n\tnum_rows: {self.__len__()}" + "\n})"
        return s
    
def split_icl_dataset(dataset, train_size=None, test_size=0.3, seed=42) -> Dict[str,ICLDataset]:
    """
    Uses scikit-learn's train_test_split to create train, valid, test datasets from the provided dataset.

    The data is split twice with the same test_size and seed: first into train and a held-out
    part, then the held-out part is split again into test and valid (valid receives the
    test_size fraction of the held-out part).

    Parameters:
    dataset: object exposing the underlying DataFrame as `raw_data`
    train_size: optional training fraction
    test_size: fraction passed as `test_size` to both splits
    seed: random_state passed to both splits

    Return:
    dict with 'train', 'valid' and 'test' ICLDataset objects
    """
    # Fill in whichever fraction is missing from the other one.
    if test_size is None:
        if train_size is None:
            train_size, test_size = 0.7, 0.3
        else:
            test_size = 1-train_size
    elif train_size is None:
        train_size = 1-test_size
    # NOTE(review): train_size is normalized above but never handed to train_test_split;
    # only test_size drives the split sizes. Confirm whether that is intended.

    train_frame, holdout_frame = train_test_split(dataset.raw_data, test_size=test_size, random_state=seed)
    test_frame, valid_frame = train_test_split(holdout_frame, test_size=test_size, random_state=seed)

    named_frames = (('train', train_frame), ('valid', valid_frame), ('test', test_frame))
    return {name: ICLDataset(frame.to_dict(orient='list')) for name, frame in named_frames}


def load_dataset(task_name: str,
                 root_data_dir: str = 'dataset_files',
                 test_size = 0.3, 
                 seed=42
                ) -> Dict[str,ICLDataset]:
    """
    Loads a dataset with input/output pairs and splits it into train/valid/test.

    The file '<task_name>.json' is looked up in the 'abstractive' and 'extractive'
    sub-directories of root_data_dir and must exist in exactly one of them.

    Parameters:
    task_name: the name of the task dataset
    root_data_dir: the root directory where the data comes from
    test_size: fraction used in train/test split (must be <= 1.0)
    seed: random seed forwarded to the split

    Return:
    dataset: the dict containing the train/valid/test dataset splits
    """
    data_folders = ['abstractive', 'extractive']
    assert test_size <= 1.0

    file_name = task_name+'.json'
    matching_folders = [folder for folder in data_folders
                        if os.path.exists(os.path.join(root_data_dir, folder, file_name))]

    assert len(matching_folders) == 1, f"Error! 'task_name'={task_name}.json must be uniquely contained in one of these directories:{data_folders}"
    dataset_folder = matching_folders[0]

    # os.path.join returns a plain str here, which is what ICLDataset expects for a file path.
    d_path = os.path.join(Path(root_data_dir), dataset_folder, f'{task_name}.json')

    return split_icl_dataset(ICLDataset(d_path), test_size=test_size, seed=seed)