from data import CausalDataset, MoralDataset, Example, Sentence, Annotation
from promptsource.templates import DatasetTemplates, Template

from data import FactorUtils

from typing import NamedTuple, Optional

############## Data Representation ###########

# Abstract rendering of an example: the assembled abstract story text, the
# question asked about it, and the gold answer.
class AbstractExample(NamedTuple):
    story: str
    question: str
    answer: str

############## Causal ###############

def causal_structure_translator(choice, causal_structure=None, norm_type=None):
    """Render a causal-structure annotation value as one sentence.

    Only ``choice`` is used; ``causal_structure`` and ``norm_type`` are
    accepted so that every causal translator shares one call signature.
    Unrecognised values yield ``None``.
    """
    sentences = {
        'CS_1_Conjunctive': "In this scenario, person A and person B both need to perform an action for the outcome to occur.",
        'CS_1_Disjunctive': "In this scenario, at least one out of two people A or person B needs to perform an action for the outcome to occur.",
    }
    return sentences.get(choice)


def agent_awareness_translator(choice, causal_structure=None, norm_type=None):
    """Render whether agent A knows the rule.

    Only ``choice`` is used; the other parameters keep the shared causal
    translator signature. Unrecognised values yield ``None``.
    """
    return {
        'AK_2_Agent_Unaware': "A is unaware of the rule.",
        'AK_2_Agent_Aware': "A is aware of the rule.",
    }.get(choice)


def event_normality_translator(choice, causal_structure=None, norm_type=None):
    """Describe the rule governing A's action and whether A followed it.

    Args:
        choice: 'EC_4_Abnormal_Event' or 'EC_4_Normal_Event'.
        causal_structure: when not ``None`` the story has an agent B, and a
            sentence about B's (always rule-following) behavior is appended.
        norm_type: 'Norm_3_Prescriptive_Norm' or 'Norm_3_Statistics_Norm';
            selects the wording. Any other value produces no A sentence and
            no B sentence.

    Returns:
        The sentences joined by single spaces (no leading/trailing space),
        or '' when nothing matches.
    """
    parts = []
    if norm_type == 'Norm_3_Prescriptive_Norm':
        if choice == 'EC_4_Abnormal_Event':
            parts.append("The rule is that A is not supposed to act. A violated this rule.")
        elif choice == 'EC_4_Normal_Event':
            parts.append("The rule is that A is supposed to act. A did not violate this rule.")
    elif norm_type == 'Norm_3_Statistics_Norm':
        if choice == 'EC_4_Abnormal_Event':
            parts.append("The rule is that A is not supposed to act. In fact, A rarely behaves this way, but this time A did.")
        elif choice == 'EC_4_Normal_Event':
            parts.append("The rule is that A is supposed to act. A often behaves this way, and this time A behaved as usual.")

    if causal_structure is not None:
        # B's behavior is fixed in our stories -- B never violates the rules --
        # so this translation does not generalize beyond that set of stories.
        if norm_type == 'Norm_3_Prescriptive_Norm':
            parts.append("B acts in a way that they were supposed to.")
        elif norm_type == 'Norm_3_Statistics_Norm':
            parts.append("B often acts this way, and this time B acts as usual.")

    # Joining avoids the stray leading/trailing/double spaces that ad-hoc
    # concatenation produced.
    return ' '.join(parts)


def action_omission_translator(choice, causal_structure=None, norm_type=None):
    """Say who acted.

    A ``None`` ``causal_structure`` means the story has no agent B, so only
    A is mentioned; otherwise B is always described as having acted.
    Unrecognised ``choice`` values yield ``None``.
    """
    with_b = causal_structure is not None
    if choice == 'EC_4_Action':
        return "Both A and B acted." if with_b else "A acted."
    if choice == 'EC_4_Omission':
        return 'B acted. A did not act.' if with_b else 'A did not act.'
    return None


def time_translator(choice, causal_structure=None, norm_type=None):
    """Describe the relative timing of A's and B's behavior.

    Timing is only meaningful when agent B exists, so a ``None``
    ``causal_structure`` raises ``Exception``. Unrecognised ``choice``
    values yield ``None``.
    """
    if causal_structure is None:
        raise Exception("you shouldn't be here...")
    orderings = {
        'EC_4_Same_Time_Cause': "A's and B's behavior happened at the same time right before the outcome occurred.",
        'EC_4_Early_Cause': "A's behavior happened first. B's behavior occurred right before the outcome occurred.",
        'EC_4_Late_Cause': "B's behavior happened first. A's behavior occurred right before the outcome occurred.",
    }
    return orderings.get(choice)


class Translator:
    """Base interface for converting a dataset ``Example`` into an ``AbstractExample``.

    Subclasses override :meth:`translate_example`.
    """

    @staticmethod
    def translate_example(ex: Example) -> AbstractExample:
        """Translate ``ex``; the base class does not implement this."""
        raise NotImplementedError


class CausalTranslator(Translator):
    """Translate annotated causal-judgment examples into abstract A/B stories."""

    @staticmethod
    def translate_example(ex: Example) -> AbstractExample:
        """Assemble an abstract story from the annotated factors of ``ex``.

        Each annotated sentence is rendered by the translator for its factor.
        The ``norm_type`` factor is not rendered on its own. Instead, its
        value is passed to every translator as context. The
        ``causal_structure`` value is passed the same way, and is ``None``
        when the example has no causal-structure annotation (no agent B).

        Returns:
            AbstractExample(story, question, ex.answer).

        Raises:
            AssertionError: if ``ex`` has no annotated sentences.
            ValueError: if a sentence carries a factor with no translator.
        """
        # Explicit dispatch table instead of eval() on annotation strings.
        translators = {
            'causal_structure': causal_structure_translator,
            'agent_awareness': agent_awareness_translator,
            'event_normality': event_normality_translator,
            'action_omission': action_omission_translator,
            'time': time_translator,
        }
        annotated_sents = ex.annotated_sentences

        assert len(annotated_sents) > 0, "This example does not fit our annotation system"

        def first_value(factor):
            # Value of the first sentence annotated with ``factor``, or None.
            return next((s.annotation.value for s in annotated_sents
                         if s.annotation.factor == factor), None)

        # Translators test ``causal_structure is not None`` to decide whether
        # agent B exists, so a missing annotation must be None (previously it
        # was left as an empty list, which is "not None").
        causal_struct = first_value('causal_structure')
        norm_type = first_value('norm_type')

        snippets = {}
        for s in annotated_sents:
            factor = s.annotation.factor
            if factor == 'norm_type':
                continue
            if factor not in translators:
                raise ValueError(f"No translator for causal factor {factor!r}")
            snippets[factor] = translators[factor](s.annotation.value, causal_struct, norm_type)

        order = ['causal_structure', 'agent_awareness', 'action_omission', 'event_normality', 'time']
        body = ' '.join(snippets[cate] for cate in order if cate in snippets).strip()

        outro = "The outcome occurred."
        assembled = f"{body} {outro}" if body else outro

        question = 'Did A cause the situation to occur?'

        return AbstractExample(assembled, question, ex.answer)


############## Moral ###############

def locus_of_intervention_translator(choice):
    """Describe what the reader is asked to intervene on in the dilemma.

    ``choice`` is a raw annotation value, mapped to a label via
    ``FactorUtils.locus_of_intervention_answers_map``. Unrecognised labels
    yield ``None``.
    """
    label = FactorUtils.locus_of_intervention_answers_map[choice]
    descriptions = {
        'Patient of harm': "In this moral dilemma, you are asked to intervene on the victim of harm, sacrificing another group of individuals instead.",
        'Agent of harm': "In this moral dilemma, you are asked to intervene on the instrument of harm (i.e., a runaway bus or hijacked airplane with people in it).",
    }
    return descriptions.get(label)


def personal_force_translator(choice, beneficience, victim, add_emotion=False):
    """Describe how the victim would be killed if you intervene.

    Args:
        choice: raw annotation value, mapped via
            ``FactorUtils.personal_force_answers_map`` to "Personal" or
            "Impersonal".
        beneficience: beneficiary annotation value. Accepted so that all
            moral translators share a ``(choice, beneficience, victim)`` call
            shape, but not used here (the former lookup only fed an unused
            local and has been removed).
        victim: noun phrase for the victim, inserted verbatim.
        add_emotion: for "Personal", use the wording stressing that you
            would feel responsible.

    Returns:
        The sentence, or ``None`` if the mapped label is unrecognised.
    """
    label = FactorUtils.personal_force_answers_map[choice]

    if label == "Personal":
        if add_emotion:
            return f"You must use physical force to kill/subdue {victim}, making you feel responsible for their death."
        return f"If you decide to intervene, you will directly kill {victim} using physical force."
    if label == 'Impersonal':
        return f"If you decide to intervene, you will indirectly kill {victim} by pressing a button."
    return None


def causal_role_translator(choice, beneficience=None, victim=None):
    """State whether the victim's death is a means to, or a side effect of, saving others.

    ``choice`` maps via ``FactorUtils.causal_role_answers_map`` to "Means" or
    "Side Effect". ``beneficience`` maps via
    ``FactorUtils.beneficiary_answers_map`` and decides whose lives are
    saved. ``victim`` is shortened with ``remove_victm_quant``.
    Unrecognised roles yield ``None``.
    """
    role = FactorUtils.causal_role_answers_map[choice]
    benefit = FactorUtils.beneficiary_answers_map[beneficience]

    # NOTE(review): ``saved`` stays unbound if ``benefit`` is neither label;
    # a Means/Side Effect role would then fail with UnboundLocalError.
    if benefit == 'Self-beneficial':
        saved = "your life and others' lives"
    elif benefit == 'Other-beneficial':
        saved = "others' lives"

    if role == "Means":  # instrumental harm
        victim_noun = remove_victm_quant(victim)
        return f"You must use this {victim_noun}'s life to save {saved}."
    if role == 'Side Effect':  # incidental harm
        victim_noun = remove_victm_quant(victim)
        return f"This {victim_noun}'s death is a foreseeable side-effect of saving {saved}."
    return None


def evitability_translator(choice, beneficience=None, victim=None):
    """Say whether the victim would have died anyway without intervention.

    ``choice`` maps via ``FactorUtils.evitability_answers_map`` to
    "Avoidable" or "Inevitable"; ``beneficience`` is unused. Unrecognised
    labels yield ``None``.
    """
    outcome = FactorUtils.evitability_answers_map[choice]
    if outcome not in ('Avoidable', 'Inevitable'):
        return None
    victim_noun = remove_victm_quant(victim)
    if outcome == 'Avoidable':
        return f"This {victim_noun} would have survived if you hadn't intervened."
    return f"This {victim_noun} would have died even if you hadn't intervened."


def beneficiary_translator(choice, beneficience=None, victim=None):
    """Describe whose survival the decision determines.

    ``choice`` maps via ``FactorUtils.beneficiary_answers_map``; the other
    parameters are unused. Unrecognised labels yield ``None``.
    """
    label = FactorUtils.beneficiary_answers_map[choice]
    stakes = {
        'Self-beneficial': "This decision will determine whether you and another person will survive.",
        # Author's note: here you yourself are safe from any danger.
        'Other-beneficial': "This decision will determine who out of a set of other people will survive.",
    }
    return stakes.get(label)

def rephrase_victim(victim):
    """Turn a definite victim phrase into an indefinite or quantified one.

    * "the <q> ..." with q in one/two/five/third drops the leading "the".
    * Any other "the ..." phrase becomes "a ...".
    * Anything else, including empty or single-word input, is returned
      unchanged. Previously "" or "the" raised IndexError.

    Rewritten phrases are re-joined with single spaces.

    >>> rephrase_victim("the five workers")
    'five workers'
    >>> rephrase_victim("the man")
    'a man'
    """
    words = victim.split()
    if len(words) < 2 or words[0] != 'the':
        return victim
    rest = ' '.join(words[1:])
    if words[1] in ('one', 'two', 'five', 'third'):
        return rest
    return 'a ' + rest

def remove_victm_quant(victim):
    """Strip a leading quantifier or article from a victim phrase.

    If the first word is one of one/two/five/third/a, it is dropped and the
    remaining words are re-joined with single spaces. Otherwise ``victim``
    is returned unchanged. Empty or whitespace-only input is also returned
    unchanged; previously it raised IndexError.
    """
    words = victim.split()
    if words and words[0] in ('one', 'two', 'five', 'third', 'a'):
        return ' '.join(words[1:])
    return victim

class MoralTranslator(Translator):
    """Translate annotated moral-dilemma examples into abstract stories."""

    @staticmethod
    def translate_example(ex: Example) -> AbstractExample:
        """Assemble an abstract moral dilemma from the annotated factors of ``ex``.

        If a ``locus_of_intervention`` annotation is present, the story is
        just that factor's description. Otherwise, each sentence's factor is
        rendered by its translator and given the first beneficiary value
        (``None`` if there is none) and the rephrased victim. The pieces are
        ordered beneficiary, personal_force, evitability, causal_role.

        Returns:
            AbstractExample(story, 'Do you intervene?', ex.answer).

        Raises:
            AssertionError: if ``ex`` has no annotated sentences.
            ValueError: if no sentence names a victim, or a sentence carries
                a factor with no translator.
        """
        annotated_sents = ex.annotated_sentences
        assert len(annotated_sents) > 0, "This example does not fit our annotation system"

        question = 'Do you intervene?'

        locus_of_intervention = [s.annotation.value for s in annotated_sents
                                 if s.annotation.factor == 'locus_of_intervention']
        if locus_of_intervention:
            story = locus_of_intervention_translator(locus_of_intervention[0])
            return AbstractExample(story, question, ex.answer)

        # Explicit dispatch table instead of eval() on annotation strings.
        translators = {
            'beneficiary': beneficiary_translator,
            'personal_force': personal_force_translator,
            'evitability': evitability_translator,
            'causal_role': causal_role_translator,
        }

        beneficiaries = [s.annotation.value for s in annotated_sents
                         if s.annotation.factor == 'beneficiary']
        # Translators that need a beneficiary will reject None themselves.
        beneficiary = beneficiaries[0] if beneficiaries else None

        victims = [s.victim for s in annotated_sents if s.victim is not None]
        if not victims:
            raise ValueError("Moral example has no annotated sentence naming a victim")
        victim = rephrase_victim(victims[0])

        snippets = {}
        for s in annotated_sents:
            factor = s.annotation.factor
            if factor not in translators:
                raise ValueError(f"No translator for moral factor {factor!r}")
            snippets[factor] = translators[factor](s.annotation.value, beneficiary, victim)

        order = ['beneficiary', 'personal_force', 'evitability', 'causal_role']
        body = ' '.join(snippets[cate] for cate in order if cate in snippets).strip()

        intro = "In this moral dilemma, you are asked to make a difficult decision."
        assembled = f"{intro} {body}" if body else intro

        return AbstractExample(assembled, question, ex.answer)

# ======= Utility code for moral demo =========

def make_locus_of_intervention_annotation(value='Agent of harm') -> Annotation:
    """Build a ``locus_of_intervention`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.locus_of_intervention_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.locus_of_intervention_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.locus_of_intervention_answers_map_reverse[value.lower()]
    return Annotation('locus_of_intervention', raw_value)

def make_fake_example_for_moral_1(annotation: Annotation) -> Example:
    """Wrap a single annotation in a placeholder ``Example``.

    Story, question, sentence text and victim are empty. The answer is "No"
    with a 0.5/0.5 answer distribution, and the example is not ambiguous.
    """
    sentence = Sentence(text='', victim='', annotation=annotation)
    return Example(
        story='',
        question='',
        answer='No',
        transcribed_answer="No",
        answer_dist=[0.5, 0.5],
        annotated_sentences=[sentence],
        is_ambiguous=False,
    )

def make_personal_force_annotation(value='Personal') -> Annotation:
    """Build a ``personal_force`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.personal_force_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.personal_force_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.personal_force_answers_map_reverse[value.lower()]
    return Annotation('personal_force', raw_value)

def make_causal_role_annotation(value='Means') -> Annotation:
    """Build a ``causal_role`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.causal_role_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.causal_role_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.causal_role_answers_map_reverse[value.lower()]
    return Annotation('causal_role', raw_value)

def make_beneficiary_annotation(value='Self-beneficial') -> Annotation:
    """Build a ``beneficiary`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.beneficiary_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.beneficiary_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.beneficiary_answers_map_reverse[value.lower()]
    return Annotation('beneficiary', raw_value)

def make_evitability_annotation(value='Avoidable') -> Annotation:
    """Build an ``evitability`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.evitability_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.evitability_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.evitability_answers_map_reverse[value.lower()]
    return Annotation('evitability', raw_value)

def make_fake_example_for_moral_2(personal_force_annotation: Annotation,
                                    causal_role_annotation: Annotation,
                                    beneficiary_annotation: Annotation,
                                    evitability_annotation: Annotation,
                                    victim: str) -> Example:
    """Build a placeholder ``Example`` carrying the four moral-factor annotations.

    One empty-text sentence per annotation, in the order personal_force,
    causal_role, beneficiary, evitability, each tagged with ``victim``. The
    answer is "No" with a 0.5/0.5 distribution; the example is not ambiguous.
    """
    ordered = (personal_force_annotation, causal_role_annotation,
               beneficiary_annotation, evitability_annotation)
    sentences = [Sentence(text='', victim=victim, annotation=ann) for ann in ordered]

    return Example(
        story='',
        question='',
        answer='No',
        transcribed_answer="No",
        answer_dist=[0.5, 0.5],
        annotated_sentences=sentences,
        is_ambiguous=False,
    )

# ======= Utility code for causal demo =========

def make_causal_structure_annotation(value='Conjunctive') -> Annotation:
    """Build a ``causal_structure`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.causal_structure_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.causal_structure_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.causal_structure_answers_map_reverse[value.lower()]
    return Annotation('causal_structure', raw_value)

def make_agent_awareness_annotation(value='Aware') -> Annotation:
    """Build an ``agent_awareness`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.agent_awareness_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.agent_awareness_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.agent_awareness_answers_map_reverse[value.lower()]
    return Annotation('agent_awareness', raw_value)

def make_event_normality_annotation(value='Normal') -> Annotation:
    """Build an ``event_normality`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.event_normality_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.event_normality_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.event_normality_answers_map_reverse[value.lower()]
    return Annotation('event_normality', raw_value)

def make_action_omission_annotation(value='Omission') -> Annotation:
    """Build an ``action_omission`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.action_omission_answers``
    (AssertionError otherwise). Its lower-cased form is mapped back to the
    raw annotation value.
    """
    choices = FactorUtils.action_omission_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.action_omission_answers_map_reverse[value.lower()]
    return Annotation('action_omission', raw_value)

def make_time_annotation(value='Same Time Cause') -> Annotation:
    """Build a ``time`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.time_answers`` (AssertionError
    otherwise). Its lower-cased form is mapped back to the raw annotation
    value.
    """
    choices = FactorUtils.time_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.time_answers_map_reverse[value.lower()]
    return Annotation('time', raw_value)

def make_norm_type_annotation(value='Prescriptive Norm') -> Annotation:
    """Build a ``norm_type`` Annotation from a readable label.

    ``value`` must be in ``FactorUtils.norm_type_answers`` (AssertionError
    otherwise). Its lower-cased form is mapped back to the raw annotation
    value.
    """
    choices = FactorUtils.norm_type_answers
    assert value in choices, \
        "Invalid value, please use one of the following: " + str(choices)
    raw_value = FactorUtils.norm_type_answers_map_reverse[value.lower()]
    return Annotation('norm_type', raw_value)

def make_fake_example_for_causal(event_normality_annotation: Annotation,
                                causal_structure_annotation: Optional[Annotation]=None,
                                agent_awareness_annotation: Optional[Annotation]=None,
                                action_omission_annotation: Optional[Annotation]=None,
                                time_annotation: Optional[Annotation]=None,
                                norm_type_annotation: Optional[Annotation]=None) -> Example:
    """Build a placeholder causal ``Example`` from the given annotations.

    Sentences are created in the order causal_structure, agent_awareness,
    event_normality, action_omission, time, norm_type. Optional annotations
    that are ``None`` are skipped; event_normality is always included. The
    answer is "No" with a 0.5/0.5 distribution; the example is not ambiguous.
    """
    leading = (causal_structure_annotation, agent_awareness_annotation)
    trailing = (action_omission_annotation, time_annotation, norm_type_annotation)

    annotations = [ann for ann in leading if ann is not None]
    annotations.append(event_normality_annotation)  # required, never filtered
    annotations.extend(ann for ann in trailing if ann is not None)

    sentences = [Sentence(text='', victim='', annotation=ann) for ann in annotations]

    return Example(
        story='',
        question='',
        answer='No',
        transcribed_answer="No",
        answer_dist=[0.5, 0.5],
        annotated_sentences=sentences,
        is_ambiguous=False,
    )

if __name__ == '__main__':
    # Demo: print the abstract moral prompts for a few MoralDataset examples.
    md = MoralDataset()

    from prompt import MoralAbstractJudgmentPrompt

    # First show the raw annotated sentences behind example 10.
    for sent in md[10].annotated_sentences:
        print(sent.text)
        print(sent.annotation.factor, sent.annotation.value)

    for idx in (10, 26, 40):
        story, question, answer = MoralTranslator.translate_example(md[idx])
        judgment_prompt = MoralAbstractJudgmentPrompt()
        print(judgment_prompt.apply2(story, question, answer).prompt)
        print(answer)
