"""Answer question using linguistic backoff.
"""

import click
import dspy
import os
from tqdm import tqdm
import logging
import numpy as np
from typing import List, Text
import ujson as json


# Module-level logging: build the stream handler first, then attach it to this
# module's logger so INFO-and-above records are emitted with a timestamp,
# logger name and level prefix.
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter(fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)

logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.setLevel(logging.INFO)


# Two OpenAI-backed LM clients sharing the same completion budget.
# NOTE(review): gpt4o_mini is constructed but not referenced anywhere else in
# the visible code; only gpt4o is installed below.
gpt4o_mini = dspy.OpenAI(
    model="gpt-4o-mini",
    max_tokens=1024,
)
gpt4o = dspy.OpenAI(
    model="gpt-4o",
    max_tokens=1024,
)

# Register gpt-4o in the global dspy settings as the LM to use.
dspy.settings.configure(lm=gpt4o)


class KnowledgeGeneration(dspy.Signature):
    """Given a claim, generate additional knowledge for the claim and the entity mentioned."""
    # NOTE: dspy uses a Signature's docstring as the instruction text sent to
    # the LM, so the docstring above is prompt content, not just documentation.
    # NOTE(review): in the visible code this signature is only wrapped by
    # BackoffAnswer.__init__; the calls that would use it are commented out.

    # Input: the claim text to expand on.
    claim = dspy.InputField(desc="Claim to generate knowledge for.")
    # Output: free-text background knowledge produced by the LM.
    knowledge = dspy.OutputField(desc="Generated knowledge.")
    
    
class SummarizeSimilarity(dspy.Signature):
    """Given a list of claims that you generated as answers to the specific question, give some educated guess about why people may believe each of these answers are true one by one. Try to focus on what is implied by all of the claims. All these claims are generated by yourself which seems to indicate that you somewhat believe in these conflicting claims. As you are unsure about the answer, you can try to hedge the claims to make them more general, but also try to be as informative as possible."""
    # NOTE: dspy uses a Signature's docstring as the instruction text sent to
    # the LM, so the docstring above is prompt content; editing it changes
    # model behavior.

    # Input: the question the claims are answers to.
    question = dspy.InputField(desc="Question that induces these answers.")
    # Input: BackoffAnswer passes the positive claims here as a single string
    # of the form "(1) ... (2) ...".
    claim_list = dspy.InputField(desc="List of claims to summarize.")
    # Output: the LM's summary of what the claims have in common.
    summary = dspy.OutputField(desc="Summary of the claims.")

    
class SummarizeDifference(dspy.Signature):
    """Given a list of claims that you generated as answers to the specific question, people believe that the correct answer claim is likely to fall in the first group of claims instead of the second group of claims. Try to focus on what is implied by all of the claims in the first group but not by any of the claims in the second group.
    """
    # NOTE: dspy uses a Signature's docstring as the instruction text sent to
    # the LM, so the docstring above is prompt content; editing it changes
    # model behavior.

    # Input: the question the claims are answers to.
    question = dspy.InputField(desc="Question that induces these answers.")
    # Input: the "first group" from the instructions (numbered, one string).
    positive_claim_list = dspy.InputField(desc="List of claims that are more likely to be correct.")
    # Input: the "second group" from the instructions (numbered, one string).
    negative_claim_list = dspy.InputField(desc="List of claims that are less likely to be correct.")
    # Output: what separates the first group from the second.
    summary = dspy.OutputField(desc="Summary of the claims.")


class Backoff(dspy.Signature):
    """Give a discussion of the properties that the correct answer will have, and the properties how correct answers are likely to disagree with incorrect answers to the given question, write a hedging claim in the style of the positive claims provided that is more vague and general than the positive claims, but still informative as to exclude the negative set."""
    # NOTE: dspy uses a Signature's docstring as the instruction text sent to
    # the LM, so the docstring above is prompt content; editing it changes
    # model behavior.

    # Input: the question the claims are answers to.
    question = dspy.InputField(desc="Question that induces these answers.")
    # Inputs: the two numbered claim groups, each passed as one string.
    positive_claims = dspy.InputField(desc="List of claims that are more likely to be correct.")
    negative_claims = dspy.InputField(desc="List of claims that are less likely to be correct.")
    # Inputs: BackoffAnswer fills these with the SummarizeSimilarity and
    # SummarizeDifference outputs, respectively.
    summary_positive = dspy.InputField(desc="Summary of the common properties of the positive claims.")
    summary_negative = dspy.InputField(desc="Summary of what differentiates the negative claims from the positive claims.")
    # Output: the hedged claim; BackoffAnswer strips a leading
    # "General Claim: " marker from it when present.
    general_claim = dspy.OutputField(desc="General claim that is more vague and general than the positive claims, but still informative.")

    
class DirectClaimPrompt(dspy.Signature):
    """ Give a set of positive claims, and a set of negative claims. Try to generate a more general claim that is more vague and general. This claim should 1) entailed by all the positive claims, 2) contradicts all the negative claims, 3) be more general than the positive claims. """
    # NOTE: dspy uses a Signature's docstring as the instruction text sent to
    # the LM, so the docstring above is prompt content; editing it changes
    # model behavior.
    # Unlike Backoff, this signature takes no question and no summaries.

    # Inputs: the two numbered claim groups, each passed as one string.
    positive_claims = dspy.InputField(desc="List of claims that are more likely to be correct.")
    negative_claims = dspy.InputField(desc="List of claims that are less likely to be correct.")
    # Output: the hedged claim; DirectBackoffAnswer strips a leading
    # "General Claim: " marker from it when present.
    general_claim = dspy.OutputField(desc="General claim that is more vague and general than the positive claims, but still informative.")
    
    
class DirectBackoffAnswer(dspy.Module):
    """Single-step backoff: ask the LM for one general claim directly.

    The claims are split into a positive prefix and a negative remainder, and
    both groups are handed to ``DirectClaimPrompt`` in one prediction call.
    """

    def __init__(self):
        super().__init__()
        self._direct_claim_prompt = dspy.Predict(DirectClaimPrompt)

    def __call__(self, claims: List[Text], question: Text, num_backoff: int) -> Text:
        """Generate a general claim from the leading ``num_backoff`` claims.

        Args:
            claims: Candidate claims. The first ``num_backoff`` form the
                positive group; everything after them is the negative group.
            question: Accepted so the call signature matches
                ``BackoffAnswer``; it is not used by this class.
            num_backoff: Number of leading claims treated as positive.

        Returns:
            The generated claim. If the output (with newlines replaced by
            spaces) contains ``"General Claim: "``, only the text after the
            first occurrence is returned; otherwise the raw output is
            returned unchanged.
        """
        # Knowledge generation per claim is disabled in this version; only the
        # claim text itself is used.
        ranked = list(claims)

        positive = ranked[:num_backoff]
        if num_backoff < len(ranked):
            negative = ranked[num_backoff:]
        else:
            negative = []

        # Each group is numbered independently, starting from (1).
        input_positive_claims = ' '.join(
            f'({rank}) {text}' for rank, text in enumerate(positive, start=1)
        )
        input_negative_claims = ' '.join(
            f'({rank}) {text}' for rank, text in enumerate(negative, start=1)
        )

        generation_response = self._direct_claim_prompt(
            positive_claims=input_positive_claims,
            negative_claims=input_negative_claims,
        )

        # The LM sometimes echoes the field label; keep only what follows it.
        marker = "General Claim: "
        raw_claim = generation_response.general_claim
        _, found, remainder = raw_claim.replace('\n', ' ').partition(marker)
        if found:
            return remainder
        return raw_claim

    
class BackoffAnswer(dspy.Module):
    """Multi-step backoff: summarize both claim groups, then hedge.

    The positive claims are first summarized, then contrasted with the
    negative claims, and both summaries are fed to the ``Backoff`` signature
    together with the claims and the question.
    """

    def __init__(self):
        super().__init__()
        # knowledge_generation is constructed but its use in __call__ is
        # currently disabled.
        self.knowledge_generation = dspy.Predict(KnowledgeGeneration)
        self.summarize_similarity = dspy.Predict(SummarizeSimilarity)
        self.summarize_difference = dspy.Predict(SummarizeDifference)
        self.answer_with_backoff = dspy.Predict(Backoff)

    def __call__(
        self,
        claims: List[Text],
        question: Text,
        num_backoff: int
    ) -> Text:
        """Generate a general claim from the leading ``num_backoff`` claims.

        Args:
            claims: Candidate claims. The first ``num_backoff`` form the
                positive group; everything after them is the negative group.
            question: The question the claims are answers to.
            num_backoff: Number of leading claims treated as positive.

        Returns:
            The generated claim. If the output (with newlines replaced by
            spaces) contains ``"General Claim: "``, only the text after the
            first occurrence is returned; otherwise the raw output is
            returned unchanged.
        """
        # Knowledge generation per claim is disabled in this version; only the
        # claim text itself is used.
        ranked = list(claims)

        positive = ranked[:num_backoff]
        if num_backoff < len(ranked):
            negative = ranked[num_backoff:]
        else:
            negative = []

        # Each group is numbered independently, starting from (1).
        input_positive_claims = ' '.join(
            f'({rank}) {text}' for rank, text in enumerate(positive, start=1)
        )
        input_negative_claims = ' '.join(
            f'({rank}) {text}' for rank, text in enumerate(negative, start=1)
        )

        # Step 1: what the positive claims share.
        similarity = self.summarize_similarity(
            question=question,
            claim_list=input_positive_claims,
        )
        sp = similarity.summary

        # Step 2: what sets the positive claims apart from the negative ones.
        difference = self.summarize_difference(
            question=question,
            positive_claim_list=input_positive_claims,
            negative_claim_list=input_negative_claims,
        )
        sn = difference.summary

        # Step 3: produce the hedged claim from the claims plus both summaries.
        generation_response = self.answer_with_backoff(
            positive_claims=input_positive_claims,
            negative_claims=input_negative_claims,
            question=question,
            summary_positive=sp,
            summary_negative=sn,
        )

        # The LM sometimes echoes the field label; keep only what follows it.
        marker = "General Claim: "
        raw_claim = generation_response.general_claim
        _, found, remainder = raw_claim.replace('\n', ' ').partition(marker)
        if found:
            return remainder
        return raw_claim

    
@click.command()
@click.option("--input-dir", type=click.Path(exists=True), help="Path to the input data.", required=True)
@click.option("--cluster-dir", type=click.Path(exists=True), help="Path to the cluster data.", required=True)
@click.option("--output-path", type=click.Path(), help="Path of the JSON file to write results to.", required=True)
@click.option("--start-index", type=int, default=50, show_default=True, help="Index of the first datapoint to process (inclusive).")
@click.option("--end-index", type=int, default=70, show_default=True, help="Index of the last datapoint to process (exclusive).")
def main(
    input_dir,
    cluster_dir,
    output_path,
    start_index=50,
    end_index=70,
):
    """Generate backoff claims for a slice of the clustered answer data.

    Reads every file in INPUT_DIR as JSON lines, loads the matching idmap and
    entailment files from CLUSTER_DIR, keeps datapoints with exactly 10 meta
    clusters, and writes the generated backoff claims for
    ``data[start_index:end_index]`` to OUTPUT_PATH as a JSON list.

    Args:
        input_dir: Directory of result files, one JSON object per line.
        cluster_dir: Directory holding, for each input file, an idmap JSON
            file and an entailment ``.npy`` file. Their names are derived from
            the input filename by replacing "result" with "idmap"/"entailment"
            and "jsonl" with "json"/"npy".
        output_path: Destination JSON file. Required, so a missing value is
            rejected before any LM call is made.
        start_index: First datapoint (after filtering) to process.
        end_index: One past the last datapoint to process.

    Raises:
        click.UsageError: If ``input_dir`` contains no files.
    """
    # Prepare the output location up front: discovering an unwritable path
    # only after all LM calls have completed would throw the work away.
    output_parent = os.path.dirname(output_path)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    data = []
    idmap = {}
    entailment_mats = []

    # os.listdir returns entries in arbitrary order; sorting keeps the data
    # order (and therefore the processed slice and idmap offsets) reproducible.
    filenames = sorted(os.listdir(input_dir))
    if not filenames:
        raise click.UsageError(f"No input files found in {input_dir}")

    for filename in filenames:
        ipath = os.path.join(input_dir, filename)
        idmap_filename = filename.replace("result", "idmap").replace("jsonl", "json")
        numpy_filename = filename.replace("result", "entailment").replace("jsonl", "npy")

        with open(ipath, 'r', encoding='utf-8') as file_:
            # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
            data.extend(json.loads(line) for line in file_ if line.strip())

        with open(os.path.join(cluster_dir, idmap_filename), 'r', encoding='utf-8') as file_:
            # Values are shifted by the number of ids collected so far.
            # NOTE(review): presumably this maps ids to rows of the
            # concatenated entailment array - confirm against the producer.
            offset = len(idmap)
            idmap.update({int(k): v + offset for k, v in json.load(file_).items()})

        this_mat = np.load(os.path.join(cluster_dir, numpy_filename))
        # NOTE(review): assumes each stored matrix is 100 x 100 - confirm.
        entailment_mats.append(this_mat.reshape(-1, 100, 100))

    # idmap and entailment_mats are loaded but not consumed: in this version
    # we do not take into account the entailment matrix.
    entailment_mats = np.concatenate(entailment_mats, axis=0)
    data = [datapoint for datapoint in data if len(datapoint['meta_clusters']) == 10]

    logger.info("Number of satisfied claims: %d", len(data))

    # BackoffAnswer() is the multi-step alternative with the same call signature.
    agent = DirectBackoffAnswer()

    results = []

    for datapoint in tqdm(data[start_index:end_index]):
        # Largest meta cluster first, so the leading claims are the best supported.
        ranked_clusters = sorted(
            datapoint['meta_clusters'],
            key=lambda cluster: len(cluster['sent_set_ids']),
            reverse=True,
        )
        claims = [cluster['meta-claim'] for cluster in ranked_clusters]
        question = datapoint['question']

        backoff_claim_3 = agent(claims=claims, question=question, num_backoff=3)
        logger.info("Backoff claim: %s", backoff_claim_3)
        backoff_claim_5 = agent(claims=claims, question=question, num_backoff=5)

        results.append({
            "id": datapoint['example_id'],
            "claims": claims,
            "backoff_claims": {
                "b-3": backoff_claim_3,
                "b-5": backoff_claim_5,
            },
        })

    with open(output_path, 'w', encoding='utf-8') as file_:
        json.dump(results, file_, indent=4)
    
    
# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    main()