#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


"""
 Pipeline to train the reader model on top of the retriever results
"""

import collections
import json
import sys

import hydra
import logging
import numpy as np
import os
import torch
import csv

from collections import defaultdict
from omegaconf import DictConfig, OmegaConf
from typing import List
from tqdm import tqdm


from dpr.data.qa_validation import exact_match_score
from dpr.data.reader_data import (
    ReaderSample,
    ReaderPassage,
    get_best_spans,
    SpanPrediction,
    ExtractiveReaderDataset,
)
from dpr.models import init_reader_components
from dpr.models.reader import create_reader_input, ReaderBatch, compute_loss
from dpr.options import (
    setup_cfg_gpu,
    set_seed,
    set_cfg_params_from_state,
    get_encoder_params_state_from_cfg,
    setup_logger,
)
from dpr.utils.data_utils import (
    ShardedDataIterator,
)
from dpr.utils.model_utils import (
    get_schedule_linear,
    load_states_from_checkpoint,
    move_to_device,
    CheckpointState,
    get_model_file,
    setup_for_distributed_mode,
    get_model_obj,
)

# Module-wide logger: the root logger, handed to the project's setup helper.
logger = logging.getLogger()
setup_logger(logger)

# Record type grouping a question id with its predictions and gold answers.
ReaderQuestionPredictions = collections.namedtuple(
    "ReaderQuestionPredictions",
    ("id", "predictions", "gold_answers"),
)

def _encode_answer(txt, tensorizer):
    token_ids = tensorizer.tokenizer.encode(
        txt,
        add_special_tokens=False,
        max_length=10000,
        pad_to_max_length=False,
        truncation=True,
    )
    return torch.tensor(token_ids)

def find_answer_spans(ctx: ReaderPassage, answers, question, sep_tensor):
    """Attach the token spans of every answer occurrence to a passage.

    Passages whose ``has_answer`` flag is falsy are returned untouched.
    Otherwise each entry of ``answers`` (already tokenized) is searched for in
    ``ctx.passage_token_ids``; the collected spans are stored in
    ``ctx.answers_spans`` and ``ctx.has_answer`` is reset to whether any span
    was found. A warning is logged when none is found.

    ``question`` is only used in the warning message; ``sep_tensor`` is
    accepted but not used here.

    Returns the same ``ctx`` object.
    """
    if not ctx.has_answer:
        return ctx

    # One flat list of spans over all answers, dropping falsy entries.
    found_spans = [
        span
        for answer_ids in answers
        for span in _find_answer_positions(ctx.passage_token_ids, answer_ids)
        if span
    ]
    ctx.answers_spans = found_spans

    if len(found_spans) == 0:
        logger.warning(
            "No answer found in passage id=%s text=%s, answers=%s, question=%s",
            ctx.id,
            ctx.passage_text,
            answers,
            question,
        )
    ctx.has_answer = bool(found_spans)
    return ctx

def _find_answer_positions(ctx_ids, answer):
    c_len = ctx_ids.size(0)
    a_len = answer.size(0)
    answer_occurences = []
    for i in range(0, c_len - a_len + 1):
        if (answer == ctx_ids[i : i + a_len]).all():
            answer_occurences.append((i, i + a_len - 1))
    return answer_occurences

def _concat_pair(t1, t2, middle_sep=None, tailing_sep=None):
    middle = [middle_sep] if middle_sep else []
    r = [t1] + middle + [t2] + ([tailing_sep] if tailing_sep else [])
    return torch.cat(r, dim=0), t1.size(0) + len(middle)

def _create_reader_sample_ids(sample: ReaderPassage, question_and_title, sep_tensor):
    """Build the full reader input sequence for one passage, in place.

    ``question_and_title`` is prepended to ``sample.passage_token_ids`` with no
    extra separators; the result is stored in ``sample.sequence_ids`` and the
    passage start index in ``sample.passage_offset``. When the passage has an
    answer, its spans are moved by that offset so they index into the full
    sequence.

    ``sep_tensor`` is accepted but not used here.

    Returns the same ``sample`` object.
    """
    sequence, offset = _concat_pair(
        question_and_title,
        sample.passage_token_ids,
        tailing_sep=None,
    )
    sample.sequence_ids, sample.passage_offset = sequence, offset
    assert offset > 1

    if sample.has_answer:
        shifted_spans = []
        for span in sample.answers_spans:
            shifted_spans.append((span[0] + offset, span[1] + offset))
        sample.answers_spans = shifted_spans

    return sample

class ReaderDatasetForScoring(torch.utils.data.Dataset):
    """In-memory dataset of pre-tokenized rows read from a tab-separated file.

    Every row of ``cfg.input_file`` must have exactly 12 columns, in order:
    ``qid, que, pid, psg, title, ans_cf_psg, sent_cf_psg, tr_psg, ans,
    ans_sent, percentage, tok_diff``. Only ``que``, ``psg``, ``title``,
    ``sent_cf_psg`` and ``ans`` are used. ``ans`` must hold a Python literal
    (e.g. a list of strings) listing the answers.

    Each item is a list:
    ``[que, title, psg, sent_cf_psg, ans, question_and_title, psg_tokens,
    cf_psg_tokens, answers_ids, sep_tensor]``.

    Raises:
        ValueError: if a row does not have exactly 12 columns, or if its
            ``ans`` column is not a valid Python literal.
    """

    def __init__(
        self,
        cfg,
        tensorizer
    ):
        # Imported here because the module does not import ``ast`` at file level.
        import ast

        self.cfg = cfg
        self.tensorizer = tensorizer
        self.data = []

        # Same for every row, so fetched once; the separator can be a multi token.
        sep_tensor = tensorizer.get_pair_separator_ids()

        # newline="" is what the csv module requires of file objects it reads;
        # the encoding is pinned so parsing does not depend on the locale.
        with open(cfg.input_file, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f, delimiter="\t"):
                (
                    _qid, que, _pid, psg, title, _ans_cf_psg,
                    sent_cf_psg, _tr_psg, ans, _ans_sent, _percentage, _tok_diff,
                ) = row

                # NOTE(review): the passage title is passed as the text and the
                # question as ``title``; presumably the tensorizer puts ``title``
                # first, giving question-then-title order — confirm.
                question_and_title = tensorizer.text_to_tensor(title, title=que, add_special_tokens=True)
                psg_tokens = tensorizer.text_to_tensor(psg, add_special_tokens=False, apply_max_len=False)
                cf_psg_tokens = tensorizer.text_to_tensor(sent_cf_psg, add_special_tokens=False, apply_max_len=False)

                # SECURITY: this column used to go through eval(), which runs
                # arbitrary code taken from the input file. literal_eval only
                # accepts Python literals and raises on anything else.
                answers_ids = [_encode_answer(a, tensorizer) for a in ast.literal_eval(ans)]

                self.data.append(
                    [
                        que, title, psg, sent_cf_psg, ans,
                        question_and_title, psg_tokens, cf_psg_tokens, answers_ids, sep_tensor,
                    ]
                )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

def reader_collate(
    sample
):
    """Turn a batch of dataset records into ``ReaderSample`` objects.

    Each record yields one sample with two passages, in this order: the
    original passage (created with ``has_answer=True`` and then run through
    ``find_answer_spans``) and the ``sent_cf_psg`` passage (created with
    ``has_answer=False``). Both share the same question-and-title prefix.

    Returns a list with one ``ReaderSample`` per record.
    """
    collated = []
    for record in sample:
        (
            que, title, psg, sent_cf_psg, ans,
            question_and_title, psg_tokens, cf_psg_tokens, answers_ids, sep_tensor,
        ) = record

        # Passage expected to contain the answer; spans are located below.
        positive = ReaderPassage(text=psg, title=title, has_answer=True)
        positive.passage_token_ids = psg_tokens
        positive = find_answer_spans(positive, answers_ids, que, sep_tensor)

        # Alternative passage, marked as having no answer.
        negative = ReaderPassage(text=sent_cf_psg, title=title, has_answer=False)
        negative.passage_token_ids = cf_psg_tokens

        passages = [
            _create_reader_sample_ids(positive, question_and_title, sep_tensor),
            _create_reader_sample_ids(negative, question_and_title, sep_tensor),
        ]
        collated.append(ReaderSample(question=que, answers=ans, passages=passages))
    return collated

@hydra.main(config_path="conf", config_name="reader_score")
def main(cfg: DictConfig):
    """Run the reader over every row of the scoring dataset.

    Steps: create the output directory if one is configured, set up device
    and seed, optionally load a checkpoint (its encoder parameters are applied
    to ``cfg`` and its weights to the reader), then feed every sample through
    the reader without gradients. The shape of the relevance logits is
    printed for each sample.
    """
    if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    if cfg.local_rank in [-1, 0]:
        logger.info("CFG (after gpu  configuration):")
        logger.info("%s", OmegaConf.to_yaml(cfg))

    logger.info("***** Initializing components for training *****")

    model_file = get_model_file(cfg, cfg.checkpoint_file_name)
    saved_state = None
    if model_file:
        logger.info("!! model_file = %s", model_file)
        saved_state = load_states_from_checkpoint(model_file)
        set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, reader, optimizer = init_reader_components(cfg.encoder.encoder_model_type, cfg)
    reader, optimizer = setup_for_distributed_mode(
        reader,
        optimizer,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )

    # Without this the checkpoint only influences cfg and the reader would be
    # scored with its freshly initialised weights.
    if saved_state is not None and saved_state.model_dict:
        logger.info("Loading model weights from saved state ...")
        get_model_obj(reader).load_state_dict(saved_state.model_dict)

    reader.eval()

    dataset = ReaderDatasetForScoring(cfg, tensorizer)
    logger.info("Total Data : %s", len(dataset))
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=reader_collate,
    )

    for samples in tqdm(data_loader):
        # Named ``batch`` rather than ``input`` to avoid shadowing the builtin.
        batch = create_reader_input(
            tensorizer.get_pad_id(),
            samples,
            cfg.passages_per_question_predict,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=False,
            shuffle=False,
            sep_token_id=tensorizer.tokenizer.sep_token_id,  # TODO: tmp
        )

        batch = ReaderBatch(**move_to_device(batch._asdict(), cfg.device))
        attn_mask = tensorizer.get_attn_mask(batch.input_ids)

        with torch.no_grad():
            # Start/end logits are returned by the reader but not used here.
            _start_logits, _end_logits, relevance_logits = reader(
                batch.input_ids, attn_mask, batch.token_type_ids,
            )

        print(relevance_logits.shape)

if __name__ == "__main__":
    logger.info("Sys.argv: %s", sys.argv)
    # torch.distributed.launch adds its parameters with a leading "--";
    # dropping that prefix converts them into Hydra format.
    hydra_formatted_args = [
        arg[len("--") :] if arg.startswith("--") else arg
        for arg in sys.argv
    ]
    logger.info("Hydra formatted Sys.argv: %s", hydra_formatted_args)
    sys.argv = hydra_formatted_args
    main()
