import pandas as pd
import numpy as np
import re
import nltk
from torch.utils.data import Dataset

# preprocess babi text files
def get_dataset_df(dataset_path, max_n_facts=None):
    """Load a bAbI-style text file into a DataFrame of single-question samples.

    Every input line has the form ``"<phrase_num> <text>"``; question lines
    additionally carry ``"\\t<answer>\\t<reference numbers>"``. A new story
    starts whenever ``phrase_num`` is 1. Stories holding several questions are
    expanded into one sample per question: each sample contains all preceding
    fact lines of the story followed by that single question line.

    Args:
        dataset_path: path of the text file to read.
        max_n_facts: if given, samples with more than this many rows
            (facts plus the question row) are dropped.

    Returns:
        DataFrame with columns ``text``, ``phrase_num``, ``answer``,
        ``reference_num``, ``initial_sample_num`` (story number in the file)
        and ``sample_num`` (number of the single-question sample).
    """
    with open(dataset_path, 'r') as f:
        raw_lines = f.read().strip().split('\n')

    # parse every line into (text, phrase number, answer, referenced phrase numbers)
    records = []
    for raw in raw_lines:
        number = int(raw.split(' ')[0])
        body = raw[raw.index(' ') + 1:]
        if '\t' in body:
            # question line: everything after the first tab is "<answer>\t<refs>"
            tail = body[body.index('\t') + 1:]
            refs = [int(n) for n in re.split('\t| ', tail)[1:]]
            answer = tail.split('\t')[0]
            body = body.split('\t')[0]
        else:
            answer, refs = None, None
        records.append((body, number, answer, refs))

    df = pd.DataFrame(records, columns=['text', 'phrase_num', 'answer', 'reference_num'])

    # number the stories: a story starts at every row whose phrase_num is 1
    story_bounds = list(np.where(df.phrase_num == 1)[0]) + [df.shape[0]]
    for story_id, (first, stop) in enumerate(zip(story_bounds[:-1], story_bounds[1:])):
        # .loc is end-inclusive; the row at `stop` is overwritten by the next story
        df.loc[first:stop, 'initial_sample_num'] = story_id

    df.initial_sample_num = df.initial_sample_num.astype(int)

    # turn each story with several questions into one sample per question
    single_question_slices = []
    for story_id in df.initial_sample_num.unique():
        story = df[df.initial_sample_num == story_id]
        question_labels = story[story.answer.notna()].index
        for label in question_labels:
            upto_question = story.loc[:label]
            # keep the facts plus the final (question) row, drop earlier questions
            is_fact = upto_question.answer.isna()
            is_last = upto_question.index == upto_question.index[-1]
            piece = upto_question[is_fact | is_last]
            if max_n_facts is None or piece.shape[0] <= max_n_facts:
                single_question_slices.append(piece)

    df = pd.concat(single_question_slices).reset_index(drop=True)

    # number the resulting single-question samples the same way
    sample_bounds = list(np.where(df.phrase_num == 1)[0]) + [df.shape[0]]
    for sample_id, (first, stop) in enumerate(zip(sample_bounds[:-1], sample_bounds[1:])):
        df.loc[first:stop, 'sample_num'] = sample_id

    df.sample_num = df.sample_num.astype(int)

    return df


def get_dataset_df_custom(dataset_path, max_len=None):
    """Build a pass-key retrieval dataset from a file with one key per line.

    The keys are shuffled (unseeded, so the order differs between calls),
    optionally truncated to ``max_len`` rows, and turned into samples.

    Args:
        dataset_path: path of a text file containing one pass key per line.
        max_len: if given, keep only the first ``max_len`` shuffled keys.

    Returns:
        DataFrame with columns ``text`` (the key), ``answer`` (same as the
        key), ``facts`` (sentence stating the key), ``question`` (fixed
        instruction text) and ``sample_num`` (0..len-1 in row order). The
        index keeps the original line positions of the keys.
    """
    with open(dataset_path, 'r') as f:
        keys = f.read().strip().split('\n')

    df = pd.DataFrame(keys, columns=['text'])
    df = df.sample(frac=1)  # randomize order of keys

    if max_len is not None:
        # take an explicit copy: assigning columns on a slice of another
        # frame triggers pandas' SettingWithCopyWarning
        df = df.iloc[0:max_len].copy()

    df['answer'] = df.text
    df['facts'] = df.text.apply(lambda x: f'The pass key is {x}. Remember it. {x} is the pass key. ')
    # the instruction is identical for every sample
    df['question'] = 'There is an important pass key hidden inside a lot of irrelevant text. Find it and memorize it:\n'
    df['sample_num'] = range(len(df))

    return df

# babi task loader dataset
class TaskDataset(Dataset):
    """Dataset of single-question bAbI samples backed by a DataFrame.

    The frame is produced by ``get_dataset_df`` and stored in
    ``self.fact_dataset``; rows of one sample share the same ``sample_num``
    and the last row of a sample is its question.
    """

    def __init__(self, dataset_path, max_n_facts=None):
        """Parse ``dataset_path``; ``max_n_facts`` is forwarded to ``get_dataset_df``."""
        self.fact_dataset = get_dataset_df(dataset_path, max_n_facts=max_n_facts)

    def __getitem__(self, ind):
        """Return sample number ``ind`` as a dict.

        Keys: ``facts`` (texts of all rows except the last), ``question``
        (text of the last row), ``answer`` (answer of the last row) and
        ``references`` (texts of the rows whose ``phrase_num`` is listed in
        the question's ``reference_num``).

        Raises:
            IndexError: if no row has ``sample_num == ind``.
        """
        slc = self.fact_dataset[self.fact_dataset.sample_num == ind]
        references = slc[slc.phrase_num.isin(slc.reference_num.values[-1])].text.values
        sample = {'facts': slc.text.values[:-1],
                  'question': slc.text.values[-1],
                  'answer': slc.answer.values[-1],
                  'references': references}
        return sample

    def __len__(self):
        """Return the number of samples."""
        if self.fact_dataset.empty:
            return 0
        # sample_num starts at 0, so the count is the largest number plus one
        return int(self.fact_dataset.sample_num.max()) + 1
    
class TaskDatasetCustom(Dataset):
    """Dataset of pass-key samples backed by the frame from ``get_dataset_df_custom``."""

    def __init__(self, dataset_path, max_n_facts=None, max_len=None):
        """Load the keys from ``dataset_path``.

        ``max_len`` is forwarded to ``get_dataset_df_custom``; ``max_n_facts``
        is accepted but not used here.
        """
        self.fact_dataset = get_dataset_df_custom(dataset_path, max_len)

    def __getitem__(self, ind):
        """Return the sample whose ``sample_num`` equals ``ind`` as a dict.

        The dict holds ``facts``, ``question`` and ``answer`` taken from the
        matching row, plus an empty ``references`` list.
        """
        table = self.fact_dataset
        rows = table[table.sample_num == ind]
        item = {column: rows[column].values[-1]
                for column in ('facts', 'question', 'answer')}
        item['references'] = []
        return item

    def __len__(self):
        """Return the largest ``sample_num`` plus one."""
        highest = self.fact_dataset.sample_num.max()
        return highest + 1


def sum_lengths(sentences):
    """Return the combined length of all items in ``sentences``."""
    return sum(map(len, sentences))


# sampler of background text 
class SentenceSampler:
    """Draw tokenized background sentences from a text dataset.

    Texts are read from ``dataset[i]['text']``, split into sentences with
    nltk's Punkt tokenizer and tokenized with ``tokenizer.encode``. Sentences
    not yet handed out are kept in ``self.sentences`` between calls.
    """

    def __init__(self, dataset, tokenizer, min_sentence_len=10, max_sentence_len=None, shuffle=False, random_seed=42):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.sentence_tokenizer = nltk.PunktSentenceTokenizer()
        self.min_sentence_len = min_sentence_len
        self.max_sentence_len = max_sentence_len
        self.shuffle = shuffle
        self.gen = np.random.default_rng(seed=random_seed)
        self.sample_ind = 0      # next dataset index when reading sequentially
        self.sentences = []      # sentences waiting to be tokenized and used

    def get_sample(self, sample_size):
        """Return a list of token lists holding ``sample_size`` tokens in total.

        Sentences whose token count fails ``length_is_ok`` are skipped. The
        final sentence is cut so the total matches ``sample_size``.
        """
        collected = []
        n_tokens = 0
        while True:
            pending = list(self.sentences)
            for position, sentence in enumerate(pending):
                ids = self.tokenizer.encode(sentence, add_special_tokens=False)
                if not self.length_is_ok(ids):
                    continue
                collected.append(ids)
                n_tokens += len(ids)
                if n_tokens < sample_size:
                    continue
                # enough tokens: keep the unused sentences for the next call
                self.sentences = self.sentences[position + 1:]
                overshoot = n_tokens - sample_size
                if overshoot > 0:
                    collected[-1] = collected[-1][:-overshoot]
                return collected

            # buffer exhausted before reaching sample_size: refill and go on
            self.sentences = []
            self.sample_sentences_(sample_size)

    def sample_sentences_(self, sample_size):
        """Append the sentences of the next non-empty text to ``self.sentences``."""
        fresh = []
        while not fresh:
            text = self.next_sample_()
            if self.shuffle:
                if len(text) == 0:
                    continue
                offset = self.gen.choice(len(text))   # start from random position in text
                text = text[offset:]
                text = text[:sample_size * 10]        # cut too long texts to speed up tokenization
            fresh.extend(self.sentence_tokenizer.tokenize(text))
            if self.shuffle:
                # first and last sentence are dropped when the text was cut
                fresh = fresh[1:-1]
        self.sentences.extend(fresh)

    def next_sample_(self):
        """Return the text of the next dataset item (random item if shuffling)."""
        if self.shuffle:
            self.total_tokens = 0
            picked = self.gen.choice(len(self.dataset))
            return self.dataset[int(picked)]['text']

        text = self.dataset[int(self.sample_ind)]['text']
        # advance cyclically through the dataset
        self.sample_ind = (self.sample_ind + 1) % len(self.dataset)
        return text

    def length_is_ok(self, tokenized):
        """Return True if the token count lies within the configured bounds."""
        n_tokens = len(tokenized)
        too_long = self.max_sentence_len is not None and n_tokens > self.max_sentence_len
        too_short = self.min_sentence_len is not None and n_tokens < self.min_sentence_len
        return not (too_long or too_short)


def find_needle_position(updated_sample, key_offset=4):
    """Return the token position of the pass key inside the assembled sample.

    ``updated_sample`` is a list of slots. Every slot but the last holds a
    background sentence, preceded by the fact if the fact was inserted there;
    the trailing slot holds only the fact (if it was placed after all
    background text) or nothing. The function assumes exactly one inserted
    fact.

    Args:
        updated_sample: list of slots, each a list of token lists.
        key_offset: number of tokens between the start of the fact and the
            key itself ('The pass key is XXXXX' -> 4 with the default).

    Returns:
        Number of background tokens preceding the fact plus ``key_offset``.

    Raises:
        ValueError: if the trailing slot is empty and no slot holds exactly
            two chunks.
    """
    num_of_chunks = [len(chunks) for chunks in updated_sample]
    if num_of_chunks[-1] == 1:
        # the fact sits in the trailing slot; only the slots before it hold
        # background text, so the fact itself must not be counted
        needle_index = len(updated_sample) - 1
    else:
        needle_index = num_of_chunks.index(2)

    # every slot before the needle contains just its background sentence
    needle_loc = sum(len(updated_sample[i][0]) for i in range(needle_index))

    return needle_loc + key_offset


# combined dataset for noisy babi QA
# it's recommended to use sample_size >= 1024
# and task_end_pct - task_start_pct >= 0.2 in order to
# leave enough room for inserting the facts between the background sentences
class NoiseInjectionDataset(Dataset):
    """Hide the facts of a task sample inside sampled background text.

    Each item is the dict returned by ``task_dataset[ind]`` extended with
    ``background_text``, ``fact_positions``, ``needle_position``,
    ``input_tokens``, ``question_tokens`` and ``target_tokens``.
    """

    def __init__(self, task_dataset, noise_sampler, tokenizer,
                 task_start_pct=None,                   # left border of facts in sample, between 0 and 1
                 task_end_pct=None,                     # right border of facts in sample, between task_start_pct and 1
                 sample_size=1024,
                 mixed_length_ratio=0.0,                # used for mixed length curriculum, prob for shorter samples
                 random_seed=42):
        self.task_dataset = task_dataset
        self.noise_sampler = noise_sampler
        self.sample_size = sample_size
        self.mixed_length_ratio = mixed_length_ratio
        self.tokenizer = tokenizer
        self.task_start_pct = task_start_pct
        self.task_end_pct = task_end_pct
        # always create the generator: a seed of 0 or None is valid for
        # default_rng (None means unseeded) and the generator is required by
        # __getitem__ and get_sample_size
        self.gen = np.random.default_rng(seed=random_seed)

    def __getitem__(self, ind):
        """Build the noisy sample for task sample ``ind``.

        Raises:
            IndexError: if borders are configured and no background sentence
                boundary inside them leaves room for the facts.
        """
        sample = self.task_dataset[ind]
        # the whole 'facts' entry is tokenized as a single unit
        facts_tok = self.tokenizer([sample['facts']])['input_ids']
        question_tok = self.tokenizer(sample['question'])['input_ids']
        answer_tok = self.tokenizer(sample['answer'])['input_ids']
        answer_tok.append(self.tokenizer.eos_token_id)

        # fill the rest of the sample with background text
        sample_size = self.get_sample_size()
        task_len = sum_lengths(facts_tok)
        background_text_len = sample_size - task_len
        background_text = self.noise_sampler.get_sample(background_text_len)
        sample['background_text'] = background_text

        if self.task_start_pct is None and self.task_end_pct is None:     # if fact position unspecified
            # any gap, including before the first and after the last sentence
            possible_positions = range(len(background_text) + 1)
        else:
            # a missing border defaults to the respective end of the sample
            start_pct = 0.0 if self.task_start_pct is None else self.task_start_pct
            end_pct = 1.0 if self.task_end_pct is None else self.task_end_pct
            task_start_ind = int(sample_size * start_pct)
            task_end_ind = int(sample_size * end_pct)

            possible_positions = []                                       # where can we insert facts?
            current_length = 0
            for i, text in enumerate(background_text):
                # a gap qualifies if it starts inside the borders and the
                # facts still fit before the right border
                if (current_length >= task_start_ind) and (current_length < task_end_ind - task_len):
                    possible_positions.append(i)
                current_length += len(text)

            if len(possible_positions) == 0:
                raise IndexError(
                    f"Unable to insert facts in specified place: {self.task_start_pct, self.task_end_pct}. "
                    f"Total fact length: {task_len}, sentences length: {[len(t) for t in background_text]}. "
                    "Make the range wider or increase the sample size.")

        fact_positions = self.gen.choice(possible_positions, len(facts_tok))
        fact_positions.sort()
        sample['fact_positions'] = fact_positions                  # positions of facts between noise sentences

        # slot i holds the facts inserted before background sentence i,
        # followed by that sentence; the extra last slot has no sentence
        updated_sample = [[] for _ in range(len(background_text) + 1)]
        for fact, pos in zip(facts_tok, fact_positions):
            updated_sample[pos].append(fact)

        for i, s in enumerate(background_text):
            updated_sample[i].append(s)

        sample['needle_position'] = find_needle_position(updated_sample)

        # flatten slots -> chunks -> tokens
        flat = [i for s in updated_sample for i in s]
        tokens = [i for s in flat for i in s]

        sample['input_tokens'] = tokens
        sample['question_tokens'] = question_tok
        sample['target_tokens'] = answer_tok

        return sample

    def __len__(self):
        """Return the length of the wrapped task dataset."""
        return len(self.task_dataset)

    def get_sample_size(self):
        """Return the total sample size to use for the next item.

        If ``sample_size`` is a list, a random entry is drawn whenever a
        uniform draw exceeds ``mixed_length_ratio``; otherwise the largest
        entry is used. A scalar ``sample_size`` is returned unchanged.
        """
        if isinstance(self.sample_size, list):
            if self.gen.random() > self.mixed_length_ratio:
                return self.gen.choice(self.sample_size)
            return max(self.sample_size)
        else:
            return self.sample_size