import json
import torch
import requests
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, Dict, Any, Set, Union

from hexa.utils import logging
from hexa.utils.message import Message
from hexa.utils.document import Document
from hexa.utils.constants import Constant

# Shared constants object; its attributes (e.g. CONST.RETRIEVED_DOCS) are used
# as message field names throughout this module.
CONST = Constant()

# Key under which a result's document text is stored.
CONTENT = 'content'
# Number of documents requested per query when `retrieve` gets no `num_ret`.
DEFAULT_NUM_TO_RETRIEVE = 5
# Shape of an empty search result.
BLANK_SEARCH_DOC = dict(url=None, content='', title='')


def chunk_docs_in_message(message, chunk_sz):
    """
    Split each retrieved document in `message` into space-delimited chunks.

    A document is cut at the first space found at or after character index
    `chunk_sz`. The space itself is dropped, and the remainder is chunked the same
    way until no such space is left. Titles and URLs are repeated for every chunk
    of their document.

    "Checked" sentences are taken from CONST.SELECTED_SENTENCES, falling back to
    'labels'. They are never split across chunks: each one is swapped for its own
    space-free placeholder ``||CHECKED_SENTENCE_<i>||`` before cutting, and restored
    in every chunk afterwards. If the checked sentences join to
    CONST.NO_SELECTED_SENTENCES_TOKEN, none are used. Empty sentences are ignored.

    Args:
        message: Message holding CONST.RETRIEVED_DOCS. The
            CONST.RETRIEVED_DOCS_TITLES / CONST.RETRIEVED_DOCS_URLS values are
            indexed in parallel with the docs (assumed to be same-length lists).
        chunk_sz: position (counted with placeholders in place) from which the
            next cutting space is searched.

    Returns:
        `message` itself if it has no CONST.RETRIEVED_DOCS. Otherwise, a copy whose
        docs/titles/urls are replaced by the per-chunk lists.
    """
    if CONST.RETRIEVED_DOCS not in message:
        return message
    new_message = message.copy()
    docs = message[CONST.RETRIEVED_DOCS]
    titles = message.get(CONST.RETRIEVED_DOCS_TITLES)
    urls = message.get(CONST.RETRIEVED_DOCS_URLS)

    checked_sentences = [
        sentence.strip(' ')
        for sentence in message.get(
            CONST.SELECTED_SENTENCES,
            message.get('labels', [CONST.NO_SELECTED_SENTENCES_TOKEN]),
        )
    ]
    if ' '.join(checked_sentences) == CONST.NO_SELECTED_SENTENCES_TOKEN:
        checked_sentences = []
    # str.replace('') would insert a placeholder between every character.
    checked_sentences = [sentence for sentence in checked_sentences if sentence]
    # One distinct placeholder per sentence, so each is restored to its own text.
    placeholders = [
        f'||CHECKED_SENTENCE_{i}||' for i in range(len(checked_sentences))
    ]

    def _restore(text):
        # Put the original checked sentences back in place of their placeholders.
        for sentence, placeholder in zip(checked_sentences, placeholders):
            text = text.replace(placeholder, sentence)
        return text

    new_docs = []
    new_titles = []
    new_urls = []
    for ind, doc in enumerate(docs):
        # Guarantees that checked sentences are not split in half (as we split by space).
        for sentence, placeholder in zip(checked_sentences, placeholders):
            doc = doc.replace(sentence, placeholder)
        while True:
            end_chunk = doc.find(' ', chunk_sz)
            chunk = doc if end_chunk == -1 else doc[:end_chunk]
            new_docs.append(_restore(chunk))
            new_titles.append(titles[ind])
            new_urls.append(urls[ind])
            if end_chunk == -1:
                # last chunk
                break
            # Skip only the separating space; keep the whole remainder, including
            # its final character. `doc` shrinks every iteration, so this terminates.
            doc = doc[end_chunk + 1 :]
    new_message.force_set(CONST.RETRIEVED_DOCS, new_docs)
    new_message.force_set(CONST.RETRIEVED_DOCS_TITLES, new_titles)
    new_message.force_set(CONST.RETRIEVED_DOCS_URLS, new_urls)
    return new_message


class RetrieverAPI(ABC):
    """
    Abstract base for the retrievers in this module.

    Concrete subclasses have to override `retrieve`. This base class only records
    the skip token taken from `opt`, and offers a small helper that assembles
    result dictionaries.
    """

    def __init__(self, opt):
        # Token that marks a query as "do not retrieve" (subclasses such as
        # SearchEngineRetriever compare incoming queries against it).
        self.skip_query_token = opt['skip_retrieval_token']

    @abstractmethod
    def retrieve(
        self, queries: List[str], num_ret: int = DEFAULT_NUM_TO_RETRIEVE
    ) -> List[Dict[str, Any]]:
        """
        Perform retrieval for every query in `queries`, requesting `num_ret` results each.

        Must be implemented by concrete subclasses.
        """

    def create_content_dict(self, content: list, **kwargs) -> Dict:
        """Return a dict with `content` stored under CONTENT, followed by every extra keyword field."""
        return {CONTENT: content, **kwargs}
    
    
class SearchEngineRetriever(RetrieverAPI):
    """
    Queries a server (eg, search engine) for a set of documents.

    This module relies on a running HTTP server. For each retrieval it POSTs the
    query to this server and receives a JSON reply, which it parses to create the
    response.

    `opt` keys read here (besides those read by `RetrieverAPI`):
        search_server: server address; "http://" is prepended when it has no protocol.
        search_server_timeout: per-request timeout in seconds. Missing, None or a
            non-positive value means no timeout.
        max_num_retries: number of extra attempts made after the first attempt
            times out or gets a non-200 response. Defaults to 0 (a single attempt).
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        self.server_address = self._validate_server(opt.get('search_server'))
        timeout = opt.get('search_server_timeout') or 0
        self._server_timeout = timeout if timeout > 0 else None
        self._max_num_retries = max(0, opt.get('max_num_retries') or 0)

    def _query_search_server(self, query_term, n):
        """
        POST ``{'q': query_term, 'n': n}`` to the search server.

        On HTTP 200, returns the ``'response'`` value of the JSON reply (None if the
        key is absent). Timeouts and non-200 responses are retried, for at most
        ``1 + self._max_num_retries`` attempts in total. Connection errors, other
        request errors and undecodable JSON abort immediately. Every failure path
        is logged and returns None.
        """
        server = self.server_address
        req = {'q': query_term, 'n': n}
        max_attempts = self._max_num_retries + 1
        trials = []  # one description per failed attempt
        while len(trials) < max_attempts:
            try:
                logging.debug(f'sending search request to {server}')
                server_response = requests.post(
                    server, data=req, timeout=self._server_timeout
                )
            except requests.exceptions.Timeout:
                # Handled before ConnectionError: ConnectTimeout subclasses both,
                # and connect timeouts should be retried.
                trials.append(f'Timeout after {self._server_timeout} seconds.')
                continue
            except requests.exceptions.ConnectionError as errc:
                logging.error(f'Error connecting to {server}: {errc} (request: {req})')
                return None
            except requests.exceptions.RequestException as err:
                # Any other requests error (HTTPError, invalid URL, ...).
                logging.error(f'Search request to {server} failed: {err}')
                return None

            resp_status = server_response.status_code
            if resp_status == 200:
                try:
                    return server_response.json().get('response', None)
                except ValueError as err:
                    logging.error(f'Search server returned invalid JSON: {err}')
                    return None
            # Non-200 responses count as a failed attempt; without this the loop
            # would never terminate while the server keeps erroring.
            trials.append(f'Response code: {resp_status}')

        logging.error(
            f'Failed to retrieve data from server after {len(trials)} trials.'
            f'\nFailed responses: {trials}'
        )
        return None

    def _validate_server(self, address):
        """Return `address` with a protocol, raising ValueError if it is empty."""
        if not address:
            raise ValueError('Must provide a valid server for search')
        if address.startswith('http://') or address.startswith('https://'):
            return address
        PROTOCOL = 'http://'
        logging.warning(f'No protocol provided, using "{PROTOCOL}"')
        return f'{PROTOCOL}{address}'

    def _retrieve_single(self, search_query: str, num_ret: int):
        """
        Retrieve documents for a single query, requesting `num_ret` of them.

        Returns None when `search_query` is the skip token. Returns an empty list
        when the server yields nothing. Otherwise, returns one dict per result with
        keys CONTENT (the stripped, non-blank lines of the result's text), 'url'
        and 'title'.
        """
        if search_query == self.skip_query_token:
            return None

        search_server_resp = self._query_search_server(search_query, num_ret)
        if not search_server_resp:
            logging.warning(
                f'Server search did not produce any results for "{search_query}" query.'
                ' returning an empty set of results for this query.'
            )
            return []

        retrieved_docs = []
        for rd in search_server_resp:
            url = rd.get('url', '')
            title = rd.get('title', '')
            # Tolerate results with missing/None content instead of raising KeyError.
            text = rd.get(CONTENT) or ''
            sentences = [s.strip() for s in text.split('\n') if s.strip()]
            retrieved_docs.append(
                self.create_content_dict(url=url, title=title, content=sentences)
            )
        return retrieved_docs

    def retrieve(
        self, queries: List[str], num_ret: int = DEFAULT_NUM_TO_RETRIEVE
    ) -> List[Dict[str, Any]]:
        """Retrieve documents for each query, one server request per query."""
        # TODO: update the server (and then this) for batch responses.
        return [self._retrieve_single(q, num_ret) for q in queries]
    

class SearchQuerySearchEngineRetriever:
    """
    A retriever that uses a search engine server for retrieving documents.

    It instantiates a `SearchEngineRetriever` object that in turn sends search
    queries to an external server. The retrieved documents are then split into
    chunks and scored by their rank.
    """

    def __init__(
        self,
        opt,
        device,
        chunk_size: int = 500,
        n_ret_chunks: int = 1,
    ):
        """
        Args:
            opt: options dict. `n_docs` is required here; keys read by
                `SearchEngineRetriever` are also consumed.
            device: device the score tensor is moved to in `retrieve_and_score`.
            chunk_size: chunking position passed to `chunk_docs_in_message`.
            n_ret_chunks: number of leading chunks kept per document.
        """
        self.opt = opt
        self.device = device
        self.chunk_size = chunk_size
        self.n_ret_chunks = n_ret_chunks
        self.n_docs = opt['n_docs']

        self.search_client = SearchEngineRetriever(opt)
        # NOTE: writes into the caller's `opt` (no copy is made). This is the only
        # ranker `pick_chunk` supports.
        self.opt['doc_chunks_ranker'] = 'woi_chunk_retrieved_docs'

    def _empty_docs(self, num: int):
        """
        Generates the requested number of empty search results.

        A fresh dict is built for each entry rather than repeating the shared
        module-level BLANK_SEARCH_DOC, so the entries never alias one another.
        """
        return [{'url': None, 'content': '', 'title': ''} for _ in range(num)]

    def rank_score(self, rank_id: int):
        """
        Scores the chunks of the retrieved document based on their rank.

        Note that this is the score for the retrieved document and applies to all
        its chunks: 1 / (1 + rank_id).
        """
        return 1 / (1 + rank_id)

    def _display_urls(self, search_results):
        """
        Generates a newline-separated string of the non-empty retrieved URLs
        (document IDs).
        """
        return '\n'.join([d['url'] for d in search_results if d['url']])

    def get_top_chunks(
        self,
        query: str,
        doc_title: str,
        doc_chunks: Union[List[str], str],
        doc_url: str,
    ):
        """
        Chunk one document with `chunk_docs_in_message` and keep the first chunks.

        The document text (a list is concatenated with no separator) is split at
        `self.chunk_size`, with no checked sentences. The first `self.n_ret_chunks`
        chunks are returned, each wrapped in a 1-tuple. `query` is not used: chunks
        keep document order and are not ranked.
        """
        if isinstance(doc_chunks, list):
            docs = ''.join(doc_chunks)
        else:
            assert isinstance(doc_chunks, str)
            docs = doc_chunks
        chunks = chunk_docs_in_message(
            Message(
                {
                    CONST.RETRIEVED_DOCS: [docs],
                    CONST.RETRIEVED_DOCS_TITLES: [doc_title],
                    CONST.RETRIEVED_DOCS_URLS: [doc_url],
                    CONST.SELECTED_SENTENCES: [CONST.NO_SELECTED_SENTENCES_TOKEN],
                }
            ),
            self.chunk_size,
        )[CONST.RETRIEVED_DOCS]
        return [(c,) for c in chunks[: self.n_ret_chunks]]

    def pick_chunk(self, query: str, doc_title: str, doc_text: str, doc_url: str):
        """
        Splits the document and returns the selected chunks.

        An empty `doc_text` (e.g. a padding document) yields ``[("", 0)]``.
        Otherwise, the chunks come from `get_top_chunks`, which keeps at most
        `self.n_ret_chunks` of them.

        Raises:
            ValueError: if opt['doc_chunks_ranker'] is not 'woi_chunk_retrieved_docs'
                (the only ranker implemented here).
        """
        if not doc_text:
            return [("", 0)]
        ranker = self.opt.get('doc_chunks_ranker')
        if ranker != 'woi_chunk_retrieved_docs':
            raise ValueError(
                f'Unsupported doc_chunks_ranker {ranker!r}; '
                'only "woi_chunk_retrieved_docs" is implemented.'
            )
        return self.get_top_chunks(query, doc_title, doc_text, doc_url)

    def retrieve_and_score(
        self, query: Union[str, List[str]]
    ) -> Tuple[List[List[Document]], torch.Tensor]:
        """
        Retrieves documents for one or more search queries and scores them by rank.

        Step 1: use `query` directly as the search queries. A single string becomes
        a one-element batch; no query generation happens here.

        Step 2: send the queries to the search client (self.search_client),
        requesting `self.n_docs` documents per query.

        Step 3: build Document objects from the retrieved content.
        - Each query's results are padded with empty results up to `self.n_docs`.
        - Every result is split with `pick_chunk`.
        - Each chunk is scored with `rank_score` of its document's rank.
        - Rows are then padded with empty Documents (score 0) to a common length.

        Returns:
            (top_docs, scores): top_docs[i] holds the Document chunks for query i.
            scores holds the matching scores, one row per query, as a tensor on
            `self.device`.

        Raises:
            TypeError: if a retrieved document's content is neither str nor list.
        """
        # step 1
        search_queries = [query] if isinstance(query, str) else query

        # step 2
        search_results_batch = self.search_client.retrieve(search_queries, self.n_docs)

        # step 3
        top_docs = []
        top_doc_scores = []
        max_n_docs: int = self.n_docs

        for sq, search_results in zip(search_queries, search_results_batch):
            # None (skipped query) or [] (no hits) become all-empty results. Copy
            # so the search client's list is not mutated by the padding below.
            search_results = list(search_results or [])
            if len(search_results) < self.n_docs:
                search_results.extend(
                    self._empty_docs(self.n_docs - len(search_results))
                )

            docs_i = []
            scores_i = []
            logging.debug(f'URLS:\n{self._display_urls(search_results)}')
            for rank, doc in enumerate(search_results):
                url = doc['url']
                title = doc['title']
                dcontent = doc['content']
                if isinstance(dcontent, list):
                    full_text = '\n'.join(dcontent)
                elif isinstance(dcontent, str):
                    full_text = dcontent
                else:
                    raise TypeError(f'Unrecognized retrieved doc: {dcontent}')
                doc_chunks = [
                    dc[0] for dc in self.pick_chunk(sq, title, full_text, url)
                ]
                score = self.rank_score(rank)
                for splt_id, splt_content in enumerate(doc_chunks):
                    docs_i.append(
                        Document(
                            docid=url, text=splt_content, title=f'{title}_{splt_id}'
                        )
                    )
                    scores_i.append(score)
            max_n_docs = max(max_n_docs, len(docs_i))
            top_docs.append(docs_i)
            top_doc_scores.append(scores_i)

        # Pad every row to the same length so the scores form a rectangular tensor.
        # Blank documents are built here; the old code referenced an undefined name.
        for docs_i, scores_i in zip(top_docs, top_doc_scores):
            n_empty = max_n_docs - len(docs_i)
            docs_i.extend(
                Document(docid='', text='', title='') for _ in range(n_empty)
            )
            scores_i.extend([0] * n_empty)

        self.top_docs = top_docs
        return top_docs, torch.Tensor(top_doc_scores).to(self.device)