import os
import os.path as osp
import torch
import logging
from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


class NewsQAExample(object):
    """One NewsQA question together with its whitespace-tokenized context.

    Attributes (as filled by ``get_newsqa_examples`` in this module):
        qa_id: question identifier (``qa['qid']`` of the raw record).
        question: question text.
        context: raw context string.
        train_answer: answer text used as the training target; ``None``
            outside the ``'train'`` split or for impossible questions.
        val_answer: list of ``{'text', 'answer_start'}`` dicts used for
            evaluation; empty for the ``'train'`` split.
        start_position, end_position: word-level answer boundaries (indices
            into ``context_tokens``), ``0`` when no answer is used.
        context_tokens: the context split on whitespace.
        is_impossible: whether the question is marked unanswerable.
    """

    def __init__(self, qa_id, question, context, train_answer, val_answer,
                 start_pos, end_pos, context_tokens, is_impossible):
        # Raw question / context fields.
        self.qa_id, self.question, self.context = qa_id, question, context
        # Answer annotations.
        self.train_answer, self.val_answer = train_answer, val_answer
        # The constructor takes `start_pos`/`end_pos`, but the attributes
        # use the long names that the rest of the module reads.
        self.start_position, self.end_position = start_pos, end_pos
        self.context_tokens = context_tokens
        self.is_impossible = is_impossible


class NewsQAEncodedInput(object):
    """Model-ready encoding of one (question, context-window) pair.

    Attributes:
        token_ids: ``[CLS] question [SEP] context [SEP]`` ids plus padding.
        token_type_ids: ``0`` for the question part (and padding), ``1`` for
            the context part and its closing ``[SEP]``.
        attention_mask: ``1`` for real tokens, ``0`` for padding.
        overflow_token_ids: context ids cut off by truncation (plus overlap
            with the kept part), or ``None`` when nothing was cut.

    ``process_newsqa_dataset`` attaches further attributes afterwards
    (``start_position``, ``end_position``, ``is_impossible``, ``unique_id``,
    ``tokens``, ...), so instances must keep an ordinary ``__dict__``.
    """

    def __init__(self, token_ids, token_type_ids, attention_mask,
                 overflow_token_ids):
        self.token_ids, self.token_type_ids = token_ids, token_type_ids
        self.attention_mask = attention_mask
        self.overflow_token_ids = overflow_token_ids


class NewsQAResult(object):
    """Container for start/end logits associated with a ``unique_id``.

    NOTE(review): not used in this file; presumably ``unique_id`` matches
    ``NewsQAEncodedInput.unique_id`` of the window the logits were computed
    for — confirm against the caller.
    """

    def __init__(self, unique_id, start_logits, end_logits):
        self.unique_id = unique_id
        self.start_logits, self.end_logits = start_logits, end_logits


def refine_subtoken_position(context_subtokens, subtoken_start_pos,
                             subtoken_end_pos, tokenizer, annotated_answer):
    """Shrink a sub-token answer span to one that matches the answer exactly.

    Candidate spans inside ``[subtoken_start_pos, subtoken_end_pos]`` are
    tried with the start moving right and, for each start, the end moving
    left. The first span whose space-joined sub-tokens equal the
    space-joined tokenization of ``annotated_answer`` is returned. If none
    matches, the original boundaries are returned unchanged.
    """
    target = ' '.join(tokenizer.tokenize(annotated_answer))
    candidates = ((lo, hi)
                  for lo in range(subtoken_start_pos, subtoken_end_pos + 1)
                  for hi in range(subtoken_end_pos, subtoken_start_pos - 1,
                                  -1))
    for lo, hi in candidates:
        if ' '.join(context_subtokens[lo:hi + 1]) == target:
            return lo, hi
    return subtoken_start_pos, subtoken_end_pos


def get_char_to_word_positions(context, answer, start_char_pos, is_impossible):
    """Split ``context`` on whitespace and locate the answer's word span.

    Whitespace here is space, tab, CR, LF and U+202F (narrow no-break
    space).

    Returns:
        ``(start_word, end_word, words)``. The word indices cover the
        characters ``start_char_pos .. start_char_pos + len(answer) - 1``.
        They are ``(0, 0)`` when ``start_char_pos`` is ``None`` or the
        question is impossible.
    """
    separators = {' ', '\t', '\r', '\n', '\u202f'}
    words = []
    char_to_word = []  # char index -> index of the word it belongs to
    after_separator = True
    for ch in context:
        if ch in separators:
            after_separator = True
        elif after_separator:
            words.append(ch)
            after_separator = False
        else:
            words[-1] += ch
        # Separators map to the preceding word (-1 before the first word).
        char_to_word.append(len(words) - 1)

    if start_char_pos is None or is_impossible:
        return 0, 0, words
    first_word = char_to_word[start_char_pos]
    last_word = char_to_word[start_char_pos + len(answer) - 1]
    return first_word, last_word, words


def check_max_context_token(all_spans, cur_span_idx, pos):
    """Tell whether span ``cur_span_idx`` gives context position ``pos`` the
    most surrounding context.

    Each span containing ``pos`` is scored as
    ``min(tokens to the left, tokens to the right) + 0.01 * context_len``.
    On ties the earliest span wins. If no span contains ``pos``, the best
    index is ``None``.
    """
    scored = []
    for idx, span in enumerate(all_spans):
        first = span.context_start_position
        last = first + span.context_len - 1
        if first <= pos <= last:
            score = min(pos - first, last - pos) + 0.01 * span.context_len
            scored.append((score, idx))
    # `max` returns the first maximal entry, matching the strict `>` ordering.
    winner = max(scored, key=lambda item: item[0])[1] if scored else None
    return cur_span_idx == winner


def encode(tokenizer, text_a, text_b, max_seq_len, max_query_len,
           added_trunc_size):
    """Encode a (question, context) pair as ``[CLS] a [SEP] b [SEP] <pad>``.

    ``text_a`` (the question) is truncated from the end so that
    ``[CLS] a [SEP]`` holds at most ``max_query_len`` ids. ``text_b`` (the
    context) is then truncated from the end so that the whole sequence fits
    in ``max_seq_len``. Shorter sequences are padded with ``pad_token_id``
    up to ``max_seq_len``.

    Args:
        tokenizer: provides ``tokenize``, ``convert_tokens_to_ids`` and the
            ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id`` attributes.
        text_a, text_b: a string, a non-empty list/tuple of token strings,
            or a non-empty list/tuple of token ids.
        max_seq_len: target length of the encoded sequence.
        max_query_len: maximum length of ``[CLS] a [SEP]``.
        added_trunc_size: number of the kept trailing context ids that are
            also copied into ``overflow_token_ids``, so that the next window
            built from the overflow overlaps with this one.

    Returns:
        NewsQAEncodedInput with ``token_ids``, ``token_type_ids``,
        ``attention_mask`` and ``overflow_token_ids``. The last is ``None``
        when the context did not need truncating.

    Raises:
        ValueError: if ``text_a`` or ``text_b`` has an unsupported form.
    """
    def _get_token_ids(text):
        if isinstance(text, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
        elif isinstance(text, (list, tuple)) and len(text) > 0 and \
                isinstance(text[0], str):
            return tokenizer.convert_tokens_to_ids(text)
        elif isinstance(text, (list, tuple)) and len(text) > 0 and \
                isinstance(text[0], int):
            return text
        else:
            raise ValueError('Input is not valid, should be a string, '
                             'a list/tuple of strings or a list/tuple of '
                             'integers.')

    token_ids_a = _get_token_ids(text_a)
    token_ids_b = _get_token_ids(text_b)

    # Truncate the question: [CLS] + a + [SEP] must fit in `max_query_len`.
    len_a = len(token_ids_a) + 2
    if len_a > max_query_len:
        num_remove = len_a - max_query_len
        token_ids_a = token_ids_a[:-num_remove]

    # Truncate the context. The total length must be measured *after* the
    # question has been truncated. Measuring it before would drop too many
    # context ids. The window would then disagree with
    # `process_newsqa_dataset`, which assumes
    # `max_seq_len - len(question) - 3` context ids per window.
    overflow_token_ids = None
    total_len = len(token_ids_a) + len(token_ids_b) + 3
    if total_len > max_seq_len:
        num_remove = total_len - max_seq_len
        # Overflow = the removed tail plus `added_trunc_size` ids of overlap
        # with the kept part (capped at the whole context).
        trunc_size = min(len(token_ids_b), added_trunc_size + num_remove)
        overflow_token_ids = token_ids_b[-trunc_size:]
        token_ids_b = token_ids_b[:-num_remove]

    # Combine: segment 0 is [CLS] a [SEP], segment 1 is b [SEP].
    token_ids = \
        [tokenizer.cls_token_id] + token_ids_a + [tokenizer.sep_token_id]
    token_type_ids = [0] * len(token_ids)
    token_ids += token_ids_b + [tokenizer.sep_token_id]
    token_type_ids += [1] * (len(token_ids_b) + 1)
    attention_mask = [1] * len(token_ids)

    # Pad up to `max_seq_len`; padding is masked out and uses segment 0.
    if len(token_ids) < max_seq_len:
        dif = max_seq_len - len(token_ids)
        token_ids += [tokenizer.pad_token_id] * dif
        token_type_ids += [0] * dif
        attention_mask += [0] * dif

    return NewsQAEncodedInput(token_ids, token_type_ids, attention_mask,
                              overflow_token_ids)


def get_newsqa_examples(data, split, is_debug=False):
    """Turn raw NewsQA records into ``NewsQAExample`` objects.

    Each record is read as ``{'context': str, 'qa': {...}}``. The ``qa``
    dict provides ``'qid'``, ``'question'``, an optional ``'is_impossible'``
    flag and ``'detected_answers'``. Each detected answer carries
    ``'char_spans'`` given as inclusive ``[start, end]`` character offsets.

    For the ``'train'`` split only the first span (after sorting) becomes
    the training answer. For other splits every span is kept as an
    evaluation answer. With ``is_debug`` only the first ``NUM_DEBUG``
    records are used.
    """
    records = data[:NUM_DEBUG] if is_debug else data
    examples = []
    for record in records:
        context = record['context']
        qa = record['qa']
        qa_id, question = qa['qid'], qa['question']
        train_answer, start_char_pos, val_answer = None, None, []

        is_impossible = False
        if 'is_impossible' in qa:
            is_impossible = qa['is_impossible']

        if not is_impossible:
            # Flatten the spans of all detected answers and sort them.
            char_spans = sorted(span for detected in qa['detected_answers']
                                for span in detected['char_spans'])
            if split == 'train':
                first_span = char_spans[0]
                train_answer = context[first_span[0]:first_span[1] + 1]
                start_char_pos = first_span[0]
            else:
                val_answer = [{
                    'text': context[span[0]:span[1] + 1],
                    'answer_start': span[0]
                } for span in char_spans]

        word_start, word_end, context_tokens = get_char_to_word_positions(
            context, train_answer, start_char_pos, is_impossible)
        examples.append(
            NewsQAExample(qa_id, question, context, train_answer, val_answer,
                          word_start, word_end, context_tokens,
                          is_impossible))
    return examples


def process_newsqa_dataset(data,
                           split,
                           tokenizer,
                           max_seq_len,
                           max_query_len,
                           trunc_stride,
                           cache_dir='',
                           client_id=None,
                           pretrain=False,
                           is_debug=False,
                           **kwargs):
    """Convert raw NewsQA records into a sliding-window QA dataset.

    Each example's context is tokenized into sub-tokens and cut into windows
    whose start offsets are multiples of ``trunc_stride``. Each window is
    encoded together with the question by ``encode``. For the ``'train'``
    split every window is labelled with the answer's start/end position in
    the encoded sequence. Windows that do not contain the whole answer are
    labelled ``0``/``0`` with ``is_impossible=True``. When ``pretrain`` is
    true the work is delegated to ``process_newsqa_dataset_for_pretrain``.

    Args:
        data: raw records, as consumed by ``get_newsqa_examples``.
        split: split name; only ``'train'`` produces answer-position labels.
        tokenizer: tokenizer with ``tokenize``, ``convert_ids_to_tokens`` and
            the ``cls``/``sep``/``pad`` token ids used by ``encode``.
        max_seq_len: target length of every encoded window.
        max_query_len: maximum length of ``[CLS] question [SEP]``.
        trunc_stride: distance, in context sub-tokens, between window starts.
        cache_dir: directory of the ``torch.save`` cache. Caching (both
            loading and saving) is disabled when this is empty.
        client_id: client identifier used in the cache sub-directory name.
        pretrain: build the pre-training dataset instead.
        is_debug: keep only the first ``NUM_DEBUG`` records.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(dataset, encoded_inputs, examples)``: a ``DatasetDict`` of
        ``LongTensor`` features, the per-window ``NewsQAEncodedInput``
        objects (carrying extra bookkeeping attributes for metric
        computation), and the ``NewsQAExample`` objects.
    """
    if pretrain:
        return process_newsqa_dataset_for_pretrain(data, split, tokenizer,
                                                   max_seq_len, cache_dir,
                                                   client_id, is_debug)

    save_dir = osp.join(cache_dir, 'train', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    # Only consult the cache when caching is enabled. With an empty
    # `cache_dir` nothing is ever saved, so a file found at the relative path
    # `train/<client_id>/<split>.pt` would be unrelated to this run.
    if cache_dir and osp.exists(cache_file):
        import inspect
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        # The cache pickles plain Python objects (NewsQAExample, ...). Since
        # torch 2.6, `torch.load` defaults to `weights_only=True` and refuses
        # them. Opt out where the argument exists (torch >= 1.13). The file
        # is the one written below, so `cache_dir` must be trusted.
        load_kwargs = {}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        cache_data = torch.load(cache_file, **load_kwargs)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_newsqa_examples(data, split, is_debug)
        unique_id = 1000000000
        encoded_inputs = []
        for example_idx, example in enumerate(examples):
            # Skip training examples whose answer text is not found in the
            # text rebuilt from the word-level span. The example stays in
            # `examples` but produces no windows.
            if split == 'train' and not example.is_impossible:
                start_pos = example.start_position
                end_pos = example.end_position
                actual_answer = ' '.join(
                    example.context_tokens[start_pos:(end_pos + 1)])
                cleaned_answer = ' '.join(example.train_answer.strip().split())
                if actual_answer.find(cleaned_answer) == -1:
                    logger.info('Could not find answer: {} vs. {}'.format(
                        actual_answer, cleaned_answer))
                    continue

            # Map between whitespace tokens and tokenizer sub-tokens.
            tok_to_subtok_idx = []
            subtok_to_tok_idx = []
            context_subtokens = []
            for i, token in enumerate(example.context_tokens):
                tok_to_subtok_idx.append(len(context_subtokens))
                subtokens = tokenizer.tokenize(token)
                for subtoken in subtokens:
                    subtok_to_tok_idx.append(i)
                    context_subtokens.append(subtoken)

            # Answer boundaries in sub-token space (training only).
            if split == 'train' and not example.is_impossible:
                subtoken_start_pos = tok_to_subtok_idx[example.start_position]
                if example.end_position < len(example.context_tokens) - 1:
                    subtoken_end_pos = \
                        tok_to_subtok_idx[example.end_position + 1] - 1
                else:
                    subtoken_end_pos = len(context_subtokens) - 1
                subtoken_start_pos, subtoken_end_pos = \
                    refine_subtoken_position(context_subtokens,
                                             subtoken_start_pos,
                                             subtoken_end_pos,
                                             tokenizer,
                                             example.train_answer)

            truncated_context = context_subtokens
            len_question = min(len(tokenizer.tokenize(example.question)),
                               max_query_len - 2)
            # Number of already-encoded context ids repeated at the start of
            # the overflow, so consecutive windows begin `trunc_stride`
            # sub-tokens apart.
            added_trunc_size = max_seq_len - trunc_stride - len_question - 3
            spans = []
            while len(spans) * trunc_stride < len(context_subtokens):
                encoded_input = encode(tokenizer, example.question,
                                       truncated_context, max_seq_len,
                                       max_query_len, added_trunc_size)
                context_start_pos = len(spans) * trunc_stride
                context_len = min(
                    len(context_subtokens) - context_start_pos,
                    max_seq_len - len_question - 3)
                context_end_pos = context_start_pos + context_len - 1

                # Tokens of the window with the trailing padding removed.
                window_ids = encoded_input.token_ids
                if tokenizer.pad_token_id in window_ids:
                    pad_start = window_ids.index(tokenizer.pad_token_id)
                    non_padded_ids = window_ids[:pad_start]
                else:
                    non_padded_ids = window_ids
                tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

                # Position in the encoded window (after `[CLS] question
                # [SEP]`) -> index into `example.context_tokens`.
                context_subtok_to_tok_idx = {}
                for i in range(context_len):
                    context_idx = len_question + i + 2
                    context_subtok_to_tok_idx[context_idx] = \
                        subtok_to_tok_idx[context_start_pos + i]

                start_pos, end_pos = 0, 0
                span_is_impossible = example.is_impossible
                if split == 'train' and not span_is_impossible:
                    # A training window that does not contain the whole
                    # answer is kept, but labelled impossible with start/end
                    # pointing at position 0 ([CLS]).
                    if subtoken_start_pos >= context_start_pos and \
                            subtoken_end_pos <= context_end_pos:
                        context_offset = len_question + 2
                        start_pos = \
                            subtoken_start_pos - context_start_pos + \
                            context_offset
                        end_pos = \
                            subtoken_end_pos - context_start_pos + \
                            context_offset
                    else:
                        start_pos = 0
                        end_pos = 0
                        span_is_impossible = True

                encoded_input.start_position = start_pos
                encoded_input.end_position = end_pos
                encoded_input.is_impossible = span_is_impossible

                # For computing metrics
                encoded_input.example_index = example_idx
                encoded_input.context_start_position = context_start_pos
                encoded_input.context_len = context_len
                encoded_input.tokens = tokens
                encoded_input.context_subtok_to_tok_idx = \
                    context_subtok_to_tok_idx
                encoded_input.is_max_context_token = {}
                encoded_input.unique_id = unique_id
                spans.append(encoded_input)
                unique_id += 1

                if encoded_input.overflow_token_ids is None:
                    break
                truncated_context = encoded_input.overflow_token_ids

            # For every context position of every window, record whether
            # this window gives it the most surrounding context.
            for span_idx in range(len(spans)):
                for context_idx in range(spans[span_idx].context_len):
                    is_max_context_token = check_max_context_token(
                        spans, span_idx, span_idx * trunc_stride + context_idx)
                    idx = len_question + context_idx + 2
                    spans[span_idx].is_max_context_token[idx] = \
                        is_max_context_token
            encoded_inputs.extend(spans)

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs])
    token_type_ids = torch.LongTensor(
        [inp.token_type_ids for inp in encoded_inputs])
    attention_mask = torch.LongTensor(
        [inp.attention_mask for inp in encoded_inputs])
    start_positions = torch.LongTensor(
        [inp.start_position for inp in encoded_inputs])
    end_positions = torch.LongTensor(
        [inp.end_position for inp in encoded_inputs])

    # Row indices 0..N-1 into `encoded_inputs` (one per window).
    example_indices = torch.arange(token_ids.size(0), dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': token_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'start_positions': start_positions,
        'end_positions': end_positions,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples


def process_newsqa_dataset_for_pretrain(data,
                                        split,
                                        tokenizer,
                                        max_seq_len,
                                        cache_dir='',
                                        client_id=None,
                                        is_debug=False):
    """Build the NewsQA pre-training dataset from the example contexts.

    The steps are:

    * Pass all example contexts through ``split_sent`` (with the
      tokenizer's ``eoq_token``).
    * Tokenize the result to ``max_seq_len`` ids, padded and truncated,
      as PyTorch tensors.
    * Overwrite the first id of every row with ``bos_token_id``.
    * Overwrite the id at position ``(number of non-pad ids) - 1`` with
      ``eos_token_id``.

    Args:
        data: raw records, as consumed by ``get_newsqa_examples``.
        split: split name (used for the examples and the cache file name).
        tokenizer: callable tokenizer that also exposes ``eoq_token``,
            ``pad_token_id``, ``bos_token_id`` and ``eos_token_id``.
        max_seq_len: length every row is padded/truncated to.
        cache_dir: directory of the ``torch.save`` cache. Caching (both
            loading and saving) is disabled when this is empty.
        client_id: client identifier used in the cache sub-directory name.
        is_debug: keep only the first ``NUM_DEBUG`` records.

    Returns:
        ``(dataset, encoded_inputs, examples)``: a ``DatasetDict`` with
        ``token_ids``, ``attention_mask`` and ``example_indices`` (row
        indices 0..N-1), the tokenizer output whose ``input_ids`` were
        modified in place, and the ``NewsQAExample`` objects.
    """
    save_dir = osp.join(cache_dir, 'pretrain', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    # Only consult the cache when caching is enabled. With an empty
    # `cache_dir` nothing is ever saved, so a file found at the relative path
    # `pretrain/<client_id>/<split>.pt` would be unrelated to this run.
    if cache_dir and osp.exists(cache_file):
        import inspect
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        # The cache pickles arbitrary Python objects (examples and the
        # tokenizer output). Since torch 2.6, `torch.load` defaults to
        # `weights_only=True` and refuses them. Opt out where the argument
        # exists (torch >= 1.13). The file is the one written below, so
        # `cache_dir` must be trusted.
        load_kwargs = {}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        cache_data = torch.load(cache_file, **load_kwargs)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_newsqa_examples(data, split, is_debug)
        texts = split_sent([e.context for e in examples],
                           eoq=tokenizer.eoq_token)
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        num_non_padding = (encoded_inputs.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        # NOTE(review): `pad_idx - 1` is the last non-pad position only if
        # the tokenizer pads on the right and emits no pad ids mid-sequence
        # — confirm the tokenizer's padding side.
        for i, pad_idx in enumerate(num_non_padding):
            encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id
            encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
