import logging
import numpy as np
from time import time

from densephrases.models.index import MIPS


# Configure root logging at import time (INFO level, timestamped records) and
# create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class MIPSHybrid(MIPS):
    """MIPS variant that mixes dense phrase scores with document-level sparse scores.

    Dense scores come from the faiss index ``self.index``. Sparse scores come
    from the ``self.doc_rank_fn`` callables (``"index"`` and ``"top_docs"``),
    which this class does not set up. Presumably they are provided by ``MIPS``
    or by the caller; confirm against the base class.

    Every candidate is scored as ``dense_score + sparse_weight * doc_score``.
    """

    def __init__(self, *args, **kwargs):
        logger.info("Loading a hybrid (dense+sparse) version of MIPS")
        super(MIPSHybrid, self).__init__(*args, **kwargs)

    def _add_doc_scores(self, scores, q_texts, doc_idxs, sparse_weight):
        """Add ``sparse_weight * doc_rank_fn["index"](...)`` to ``scores`` in place.

        ``scores`` is a 2-D array. Each row ``b`` receives the doc scores that
        ``doc_rank_fn["index"]`` returns for question ``b``.
        """
        doc_scores = self.doc_rank_fn["index"](q_texts, doc_idxs.tolist())
        for b_idx, row in enumerate(scores):
            # ``row`` is a view into ``scores``, so ``+=`` updates it in place.
            row += np.array(doc_scores[b_idx]) * sparse_weight

    def search_dense(self, query, q_texts, nprobe=256, top_k=10, sparse_weight=0.05):
        """Dense-first retrieval: faiss top-k for start/end, then add doc scores.

        Args:
            query: 2-D array whose columns are split in half
                (``np.split(..., 2, axis=1)``) into start and end query vectors.
            q_texts: question texts passed to ``doc_rank_fn["index"]``.
            nprobe: value assigned to ``self.index.nprobe`` before searching.
            top_k: number of neighbours retrieved per start/end query.
            sparse_weight: multiplier applied to the doc-level sparse scores.

        Returns:
            ``[start_doc_idxs, start_idxs, start_I, end_doc_idxs, end_idxs,
            end_I, start_scores, end_scores]``.
        """
        batch_size = query.shape[0]
        self.index.nprobe = nprobe

        # Stack start/end queries into one batch so faiss searches them together.
        start_time = time()
        query = query.astype(np.float32)
        query_start, query_end = np.split(query, 2, axis=1)
        query_concat = np.concatenate((query_start, query_end), axis=0)

        b_scores, I = self.index.search(query_concat, top_k)
        start_I, end_I = I[:batch_size], I[batch_size:]
        # Views into b_scores; the sparse scores are added into them in place below.
        b_start_scores, b_end_scores = b_scores[:batch_size], b_scores[batch_size:]
        logger.debug(f'1) {time()-start_time:.3f}s: MIPS')

        # Map faiss ids to (doc idx, phrase idx).
        start_time = time()
        b_start_doc_idxs, b_start_idxs = self.get_idxs(start_I)
        b_end_doc_idxs, b_end_idxs = self.get_idxs(end_I)

        # Average number of distinct documents touched per question (bookkeeping).
        num_docs = sum(
            len(set(s_doc.flatten().tolist()) | set(e_doc.flatten().tolist()))
            for s_doc, e_doc in zip(b_start_doc_idxs, b_end_doc_idxs)
        ) / batch_size
        self.num_docs_list.append(num_docs)
        logger.debug(f'2) {time()-start_time:.3f}s: get index')

        # Doc-level sparse scores. Timer reset so this step is measured on its own.
        start_time = time()
        self._add_doc_scores(b_start_scores, q_texts, b_start_doc_idxs, sparse_weight)
        self._add_doc_scores(b_end_scores, q_texts, b_end_doc_idxs, sparse_weight)
        logger.debug(f'3) {time()-start_time:.3f}s: get doc scores')

        return [b_start_doc_idxs, b_start_idxs, start_I, b_end_doc_idxs, b_end_idxs, end_I, b_start_scores, b_end_scores]

    def search_sparse(self, query, q_texts, doc_top_k, top_k, sparse_weight=0.05):
        """Sparse-first retrieval: score every phrase of the top sparse-ranked docs.

        Args:
            query: 2-D array split column-wise into start and end query vectors.
            q_texts: question texts passed to ``doc_rank_fn['top_docs']``.
            doc_top_k: number of documents requested per question.
            top_k: not used here; kept for signature parity with
                ``search_dense``.
            sparse_weight: multiplier applied to each document's sparse score.

        Returns:
            ``[start_doc_idxs, start_idxs, None, end_doc_idxs, end_idxs, None,
            start_scores, end_scores]``. Rows are right-padded to a common
            length with ``-1`` indices and ``-10**9`` scores.
        """
        batch_size = query.shape[0]
        query = query.astype(np.float32)
        query_start, query_end = np.split(query, 2, axis=1)

        # Reduce the search space to the top sparse-ranked documents.
        top_doc_idxs, top_doc_scores = self.doc_rank_fn['top_docs'](q_texts, doc_top_k)

        b_start_doc_idxs, b_start_idxs, b_start_scores = [], [], []
        b_end_doc_idxs, b_end_idxs, b_end_scores = [], [], []
        self.open_dumps()
        try:
            for b_idx in range(batch_size):
                start_doc_idxs, start_idxs, start_scores = [], [], []
                end_doc_idxs, end_idxs, end_scores = [], [], []
                for doc_idx, doc_score in zip(top_doc_idxs[b_idx], top_doc_scores[b_idx]):
                    try:
                        doc_group = self.get_doc_group(doc_idx)
                    except ValueError:
                        # Documents get_doc_group rejects are skipped.
                        continue
                    vector_set = self.dequant(doc_group.attrs['offset'], doc_group.attrs['scale'], doc_group['start'][:])
                    bonus = sparse_weight * doc_score

                    # The same phrase vectors are scored against the start and the end query.
                    tmp_start_scores = np.sum(query_start[b_idx] * vector_set, 1)
                    n_start = len(tmp_start_scores)
                    start_doc_idxs.extend([doc_idx] * n_start)
                    start_idxs.extend(range(n_start))
                    start_scores.extend(score + bonus for score in tmp_start_scores)

                    tmp_end_scores = np.sum(query_end[b_idx] * vector_set, 1)
                    n_end = len(tmp_end_scores)
                    end_doc_idxs.extend([doc_idx] * n_end)
                    end_idxs.extend(range(n_end))
                    end_scores.extend(score + bonus for score in tmp_end_scores)

                b_start_doc_idxs.append(start_doc_idxs)
                b_start_idxs.append(start_idxs)
                b_start_scores.append(start_scores)
                b_end_doc_idxs.append(end_doc_idxs)
                b_end_idxs.append(end_idxs)
                b_end_scores.append(end_scores)
        finally:
            # Always release the dumps, even if scoring raised.
            self.close_dumps()

        # Right-pad every row to the longest start-candidate list so rows stack.
        # Callers must drop the -1 / -10**9 padding.
        max_phrases = max((len(row) for row in b_start_scores), default=0)
        for rows in (b_start_doc_idxs, b_start_idxs, b_end_doc_idxs, b_end_idxs):
            for row in rows:
                row.extend([-1] * (max_phrases - len(row)))
        for rows in (b_start_scores, b_end_scores):
            for row in rows:
                row.extend([-10**9] * (max_phrases - len(row)))

        return [
            np.stack(b_start_doc_idxs), np.stack(b_start_idxs), None,
            np.stack(b_end_doc_idxs), np.stack(b_end_idxs), None,
            np.stack(b_start_scores), np.stack(b_end_scores),
        ]

    @staticmethod
    def _rerank_top_k(doc_idxs, idxs, scores, top_k):
        """Keep each row's ``top_k`` highest-scoring candidates, best first.

        Returns ``(doc_idxs, idxs, scores)``, each reordered by descending score.
        """
        scores = np.asarray(scores)
        # Reverse first, then slice, so top_k == 0 yields no columns.
        # ``[:, -0:]`` would keep them all. The result is identical for top_k > 0.
        order = np.argsort(scores, axis=1)[:, ::-1][:, :top_k]
        return tuple(
            np.take_along_axis(np.asarray(arr), order, axis=1)
            for arr in (doc_idxs, idxs, scores)
        )

    def search_start(self, query, q_texts,
                     nprobe=256, top_k=10,
                     doc_top_k=5, search_strategy='hybrid', sparse_weight=0.05):
        """Retrieve start/end candidates with the chosen strategy and keep the top_k.

        Args:
            search_strategy: ``'dense_first'``, ``'sparse_first'`` or
                ``'hybrid'`` (concatenation of both candidate sets).

        Returns:
            ``(start_doc_idxs, start_idxs, None, end_doc_idxs, end_idxs, None,
            start_scores, end_scores)``. Each array keeps at most ``top_k``
            columns, sorted by descending score.

        Raises:
            ValueError: if ``search_strategy`` is not recognised.
        """
        if search_strategy == 'dense_first':
            outs = self.search_dense(query, q_texts, nprobe, top_k, sparse_weight)
        elif search_strategy == 'sparse_first':
            outs = self.search_sparse(query, q_texts, doc_top_k, top_k, sparse_weight)
        elif search_strategy == 'hybrid':
            dense_outs = self.search_dense(query, q_texts, nprobe, top_k, sparse_weight)
            sparse_outs = self.search_sparse(query, q_texts, doc_top_k, top_k, sparse_weight)

            def _cat(pos):
                # Candidates from both searches side by side. Duplicates are not removed.
                return np.concatenate([dense_outs[pos], sparse_outs[pos]], -1)

            outs = [_cat(0), _cat(1), None, _cat(3), _cat(4), None, _cat(6), _cat(7)]
        else:
            raise ValueError(search_strategy)

        start_doc_idxs, start_idxs, start_scores = self._rerank_top_k(outs[0], outs[1], outs[6], top_k)
        end_doc_idxs, end_idxs, end_scores = self._rerank_top_k(outs[3], outs[4], outs[7], top_k)
        return start_doc_idxs, start_idxs, None, end_doc_idxs, end_idxs, None, start_scores, end_scores

    def search(self, query, q_texts=None,
               nprobe=256, top_k=10,
               aggregate=False, return_idxs=False,
               max_answer_length=10,
               doc_top_k=5, search_strategy='hybrid', sparse_weight=0.05):
        """Run start/end retrieval, phrase search and per-question aggregation.

        ``q_texts`` is needed despite its ``None`` default: it is consumed by
        the sparse rankers and zipped with the results below. ``aggregate``
        is not used in this method.

        Returns:
            One ``aggregate_results(...)`` output per question.
        """
        # MIPS on start/end
        start_time = time()
        start_doc_idxs, start_idxs, start_I, end_doc_idxs, end_idxs, end_I, start_scores, end_scores = self.search_start(
            query,
            q_texts=q_texts,
            nprobe=nprobe,
            top_k=top_k,
            doc_top_k=doc_top_k,
            search_strategy=search_strategy,
            sparse_weight=sparse_weight,
        )
        logger.debug(f'Top-{top_k} MIPS: {time()-start_time:.3f}s')

        # sparse_first can return fewer than top_k columns when the selected
        # documents hold fewer phrases (see padding in search_sparse).
        num_retrieved = start_doc_idxs.shape[1]
        if num_retrieved != top_k:
            logger.warning(f"Warning.. {num_retrieved} only retrieved")

        # Search phrase
        start_time = time()
        outs = self.search_phrase(
            query, start_doc_idxs, start_idxs, start_I, end_doc_idxs, end_idxs, end_I, start_scores, end_scores,
            top_k=top_k, max_answer_length=max_answer_length, return_idxs=return_idxs,
        )
        logger.debug(f'Top-{top_k} phrase search: {time()-start_time:.3f}s')

        # Aggregate per question
        return [self.aggregate_results(results, top_k, q_text) for results, q_text in zip(outs, q_texts)]
