from argparse import ArgumentParser
from collections import deque
import os

import numpy as np
from datasets import load_dataset, Dataset
from transformers import PreTrainedTokenizerBase, AutoTokenizer


class OffsetBasedWordSplitter:
    """Group the tokens of two tokenizations of the same text into shared words.

    Words are grown greedily from character offsets: a new word opens with the
    next token of *both* sequences, and every following token (from either
    sequence) whose start offset lies before the current word's end is added
    to that word. Tokens flagged as special are not part of any word and
    receive the word id ``-100``.
    """

    def __init__(self, teacher_tokenizer: PreTrainedTokenizerBase, student_tokenizer: PreTrainedTokenizerBase):
        # The tokenizers are only stored here; get_word_ids() works purely on
        # the offsets and special-token masks it is given.
        self.teacher_tokenizer = teacher_tokenizer
        self.student_tokenizer = student_tokenizer

    def get_word_ids(
        self,
        teacher_offsets: list[tuple[int, int]],
        student_offsets: list[tuple[int, int]],
        teacher_special_tokens_mask: list[bool],
        student_special_tokens_mask: list[bool],
    ) -> tuple[np.ndarray, np.ndarray]:
        """Compute one word id per token for both token sequences.

        Args:
            teacher_offsets: ``(start, stop)`` character offsets of the teacher tokens.
            student_offsets: ``(start, stop)`` character offsets of the student tokens.
            teacher_special_tokens_mask: truthy entries mark teacher special tokens.
            student_special_tokens_mask: truthy entries mark student special tokens.

        Returns:
            ``(teacher_word_ids, student_word_ids)`` as integer arrays with one
            entry per token. Special tokens get ``-100``; other tokens get the
            id of the word they were assigned to, counted from ``0``. Once one
            sequence is exhausted, the remaining non-special tokens of the
            other sequence are all assigned to the current word.
        """
        special_word_id = -100

        teacher_word_ids: list[int] = []
        student_word_ids: list[int] = []

        # Queues let consumed tokens be dropped from the front in O(1) and
        # leave the caller's lists untouched.
        teacher_queue = deque(teacher_offsets)
        student_queue = deque(student_offsets)
        teacher_special = deque(teacher_special_tokens_mask)
        student_special = deque(student_special_tokens_mask)

        current_end = -1      # character offset where the current word ends
        current_word_id = -1  # becomes 0 when the first word is opened
        while True:
            # One side is exhausted: flush what is left of the other side.
            if not teacher_queue:
                student_word_ids.extend(
                    special_word_id if special else current_word_id
                    for special in student_special
                )
                break

            if not student_queue:
                teacher_word_ids.extend(
                    special_word_id if special else current_word_id
                    for special in teacher_special
                )
                break

            tstart, tstop = teacher_queue[0]
            sstart, sstop = student_queue[0]

            # Special tokens are consumed one at a time and never extend a word.
            if teacher_special[0]:
                teacher_word_ids.append(special_word_id)
                teacher_queue.popleft()
                teacher_special.popleft()
                continue
            if student_special[0]:
                student_word_ids.append(special_word_id)
                student_queue.popleft()
                student_special.popleft()
                continue

            if tstart < current_end:
                # Teacher token starts inside the current word: same word continues.
                teacher_word_ids.append(current_word_id)
                teacher_queue.popleft()
                teacher_special.popleft()
                current_end = max(current_end, tstop)
            elif sstart < current_end:
                # Student token starts inside the current word: same word continues.
                student_word_ids.append(current_word_id)
                student_queue.popleft()
                student_special.popleft()
                current_end = max(current_end, sstop)
            else:
                # Neither token overlaps the current word: open a new word with
                # the next token of both sequences.
                current_word_id += 1
                teacher_word_ids.append(current_word_id)
                student_word_ids.append(current_word_id)
                teacher_queue.popleft()
                student_queue.popleft()
                teacher_special.popleft()
                student_special.popleft()
                current_end = max(current_end, tstop, sstop)

        # An explicit integer dtype keeps empty results integer-typed;
        # np.array([]) would otherwise default to float64.
        return (
            np.array(teacher_word_ids, dtype=int),
            np.array(student_word_ids, dtype=int),
        )
                

def instantiate_dataset(path: str, dataset_type: str) -> Dataset:
    """Load the ``train`` split of a dataset.

    Args:
        path: for ``'jsonl'`` and ``'parquet'`` this is passed as ``data_files``;
            for ``'hf'`` it is passed to ``load_dataset`` as the dataset path.
        dataset_type: one of ``'jsonl'``, ``'hf'`` or ``'parquet'``.

    Returns:
        Whatever ``load_dataset`` returns for ``split='train'``.

    Raises:
        ValueError: if ``dataset_type`` is not one of the supported values.
    """
    if dataset_type == 'hf':
        return load_dataset(path, split='train')

    # File-based sources go through a named builder with the file as data_files.
    if dataset_type == 'jsonl':
        builder_name = 'json'
    elif dataset_type == 'parquet':
        builder_name = 'parquet'
    else:
        raise ValueError(f'unknown dataset type: {dataset_type}')

    return load_dataset(builder_name, data_files=path, split='train')


def _word_id_budget(word_ids, max_length: int):
    """Return the word-id limit one side contributes when truncating.

    If the side already fits in ``max_length`` tokens this is its largest word
    id. Otherwise it is one less than the largest word id found in the first
    ``max_length + 1`` tokens.
    """
    if len(word_ids) <= max_length:
        return max(word_ids)
    return max(word_ids[:max_length + 1]) - 1


def tokenize_function(
    examples,
    teacher_tokenizer: PreTrainedTokenizerBase,
    student_tokenizer: PreTrainedTokenizerBase,
    word_splitter: OffsetBasedWordSplitter,
    max_length: int,
) -> dict:
    """Tokenize a batch with both tokenizers and attach aligned word ids.

    Every text in ``examples['text']`` is encoded by the teacher and the
    student tokenizer, and ``word_splitter.get_word_ids`` assigns word ids to
    both token sequences. When either sequence has more than ``max_length``
    tokens, both are cut at a common word id.

    Args:
        examples: batch with a ``'text'`` entry holding the input texts.
        teacher_tokenizer: called with the texts; its output is indexed by
            ``'input_ids'``, ``'attention_mask'``, ``'offset_mapping'`` and
            ``'special_tokens_mask'``.
        student_tokenizer: same contract as ``teacher_tokenizer``.
        word_splitter: object whose ``get_word_ids`` returns the teacher and
            student word ids for one text.
        max_length: token count above which a sequence triggers truncation.

    Returns:
        Dict with ``teacher_*`` / ``student_*`` lists for ``input_ids``,
        ``word_ids`` and ``attention_mask``, one entry per input text. Texts
        whose word ids end up empty or whose largest word ids differ are
        reported on stdout but still included.
    """
    texts = examples['text']
    encode_options = {
        'return_special_tokens_mask': True,
        'return_offsets_mapping': True,
    }
    teacher_output = teacher_tokenizer(texts, **encode_options)
    student_output = student_tokenizer(texts, **encode_options)

    outputs = {
        column: []
        for column in (
            'teacher_input_ids',
            'teacher_word_ids',
            'teacher_attention_mask',
            'student_input_ids',
            'student_word_ids',
            'student_attention_mask',
        )
    }

    for doc_idx, text in enumerate(texts):
        t_ids = teacher_output['input_ids'][doc_idx]
        t_mask = teacher_output['attention_mask'][doc_idx]
        s_ids = student_output['input_ids'][doc_idx]
        s_mask = student_output['attention_mask'][doc_idx]

        t_words, s_words = word_splitter.get_word_ids(
            teacher_offsets=teacher_output['offset_mapping'][doc_idx],
            student_offsets=student_output['offset_mapping'][doc_idx],
            teacher_special_tokens_mask=teacher_output['special_tokens_mask'][doc_idx],
            student_special_tokens_mask=student_output['special_tokens_mask'][doc_idx],
        )

        # TODO: implement appropriate splitting
        if max(len(t_ids), len(s_ids)) > max_length:
            common_limit = min(
                _word_id_budget(t_words, max_length),
                _word_id_budget(s_words, max_length),
            )

            # First position whose word id is >= the common limit.
            # NOTE(review): np.searchsorted expects ascending input; word ids
            # that drop at the end (e.g. a trailing special-token id) would
            # violate that - confirm.
            t_cut = np.searchsorted(t_words, common_limit)
            s_cut = np.searchsorted(s_words, common_limit)

            # cutting potential special tokens should not harm decoder-only models
            # FIXME: what do we do with encoder-only models?
            t_ids, t_mask, t_words = t_ids[:t_cut], t_mask[:t_cut], t_words[:t_cut]
            s_ids, s_mask, s_words = s_ids[:s_cut], s_mask[:s_cut], s_words[:s_cut]

        outputs['teacher_input_ids'].append(t_ids)
        outputs['teacher_word_ids'].append(t_words)
        outputs['teacher_attention_mask'].append(t_mask)

        outputs['student_input_ids'].append(s_ids)
        outputs['student_word_ids'].append(s_words)
        outputs['student_attention_mask'].append(s_mask)

        # Report (but keep) documents whose two sides do not end on the same word.
        aligned = len(t_words) and len(s_words) and max(t_words) == max(s_words)
        if not aligned:
            print(f'{doc_idx=}')
            print(repr(text))

    return outputs


if __name__ == '__main__':
    # Command-line options: input dataset, the two tokenizers, the token limit
    # and where to store the processed dataset.
    cli = ArgumentParser()
    cli.add_argument('--dataset_path', type=str, required=True)
    cli.add_argument('--teacher_tokenizer_path', type=str, required=True)
    cli.add_argument('--student_tokenizer_path', type=str, required=True)
    cli.add_argument('--max_length', type=int, default=1024)
    cli.add_argument('--output_path', type=str, required=True)
    options = cli.parse_args()

    # The loader is chosen from the path suffix; any other path is handed to
    # instantiate_dataset as type 'hf'.
    source_path = options.dataset_path
    if source_path.endswith('.jsonl'):
        dataset_type = 'jsonl'
    elif source_path.endswith('.parquet'):
        dataset_type = 'parquet'
    else:
        dataset_type = 'hf'

    raw_dataset = instantiate_dataset(source_path, dataset_type)

    teacher_tokenizer = AutoTokenizer.from_pretrained(options.teacher_tokenizer_path)
    student_tokenizer = AutoTokenizer.from_pretrained(options.student_tokenizer_path)
    word_splitter = OffsetBasedWordSplitter(teacher_tokenizer, student_tokenizer)

    # Extra keyword arguments forwarded to tokenize_function for every batch.
    tokenize_kwargs = {
        'teacher_tokenizer': teacher_tokenizer,
        'student_tokenizer': student_tokenizer,
        'word_splitter': word_splitter,
        'max_length': options.max_length,
    }

    processed_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1000,
        num_proc=os.cpu_count(),
        fn_kwargs=tokenize_kwargs,
    )

    processed_dataset.save_to_disk(options.output_path)
