#!/usr/bin/env python
# coding: utf-8
from collections import Counter
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import random

from sklearn.metrics import classification_report
from decomposition import Decomposition, get_args
from utils import get_response, InputOutputPrompt

##############################################################################################################################
# All prompts
# Prompt that turns a declarative statement into a yes/no question.
# Each demo row supplies a "statement" (rendered via input_formatter) and a
# "question" (rendered via output_formatter).
questioner_prompt = InputOutputPrompt(
    input_formatter=lambda row: f"Statement: {row['statement']}",
    output_formatter=lambda row: f"Question: {row['question']}",
    required_keys=["question", "statement"],
    input_output_sep="\n",
    example_sep="\n\n",
    instruction="Rewrite the statement as a yes/no question." "\n\n",
)
# Statement -> yes/no question demonstrations used by `questioner_prompt`.
# Every boosting round uses the same rows, so they are defined once and a
# fresh DataFrame is built per round. (One round previously had the statement
# truncated to "the test was not", which no longer matched its question.)
_QUESTIONER_ROWS = [
    {
        "statement": "most of the light comes from the sun",
        "question": "Does most of the light come from the sun?",
    },
    {
        "statement": "the test was not hard",
        "question": "Was the test not hard?",
    },
    {
        "statement": "it is a good idea to buy your parents gifts",
        "question": "Is it a good idea to buy your parents gifts?",
    },
    {
        "statement": "the balloon popped",
        "question": "Did the balloon pop?",
    },
    {
        "statement": "The father and son went camping to California.",
        "question": "Did the father and son go camping?",
    },
]
# One independent DataFrame per boosting round (indexed by boost_id).
questioner_prompt_examples = [pd.DataFrame(_QUESTIONER_ROWS) for _ in range(5)]

# Prompt that answers a yes/no question against a context passage; the
# instruction asks for "Unknown" when the context gives no evidence.
extraction_qa = InputOutputPrompt(
    input_formatter=lambda ex: f"Context: {ex['context']}\nQuestion: {ex['question']}",
    output_formatter=lambda ex: f"Answer: {ex['answer']}",
    required_keys=["context", "question", "answer"],
    input_output_sep="\n",
    example_sep="\n\n",
    instruction=(
        'Answer the question. If there is no evidence in the context, '
        'return "Unknown".\n\n'
    ),
)
# Context/question/answer demonstrations for `extraction_qa`, one DataFrame
# per boosting round (indexed by boost_id). Examples that appear in several
# rounds are defined once below and reused.
_QA_PLAGUE_EUROPE = {
    "context": "According to Biraben, the plague was present somewhere in Italy and affected 1,200 people.",
    "question": "Based on the context, Did the plague affect people in Europe?",
    "answer": "yes, people in Italy, Europe",
}
_QA_UNEMPLOYMENT = {
    "context": "Policies aiming at controlling unemployment and in particular at reducing its inequality-associated effects support economic growth.",
    "question": "Based on the context, Is confidence a factor in increasing self-esteem?",
    "answer": "unknown",
}
_QA_ANTIMATTER = {
    "context": "The term \"matter\" is used throughout physics in a bewildering variety of contexts: for example, one refers to \"condensed matter physics\", \"elementary matter\", \"partonic\" matter, \"dark\" matter, \"anti\"-matter, \"strange\" matter, and \"nuclear\" matter.",
    "question": "Based on the context, Is anti-matter made of electrons? ",
    "answer": "Unknown",
}
_QA_INFLATION = {
    "context": "If inflation is occurring, leading to higher prices for basic necessities such as gas by 2 dollars. Do you think that inflation is good for society?",
    "question": "Based on the context, Is inflation good for society?",
    "answer": "Unknown",
}

extraction_qa_examples = [
    pd.DataFrame([_QA_PLAGUE_EUROPE, _QA_UNEMPLOYMENT, _QA_ANTIMATTER]),
    pd.DataFrame([
        {
            "context": "According to Biraben, the plague was present somewhere in Italy only between 1346 and 1671, and not after that.",
            "question": "Based on the context, Was the plague present in Italy during the 2000s?",
            "answer": "No, it was present between 1346 and 1671",
        },
        _QA_ANTIMATTER,
        _QA_UNEMPLOYMENT,
    ]),
    pd.DataFrame([
        {
            "context": "Jenna's 10th birthday was yesterday evening and at least 10 of her friends attended the party.",
            "question": "Based on the context, Did 10 friends attend Jenna's party?",
            "answer": "Unknown",
        },
        {
            "context": "The bullies attacked John when he was walking through the elementary school parking lot and then got sent to the teacher's office.",
            "question": "Based on the context, Did the bullies attack John in the teacher's office?",
            "answer": "No, parking lot",
        },
        {
            "context": "WISS discovered a new monkey disease occurring in a remote tribe in the Amazon rainforrest.",
            "question": "Based on the context, Did WISS discover a new monkey species?",
            "answer": "No, a new monkey disease",
        },
    ]),
    pd.DataFrame([
        {
            "context": "When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.",
            "question": "Based on the context, Did she think it was fair?",
            "answer": "No",
        },
        _QA_INFLATION,
        {
            "context": "Put yourself out there. The more time you spend dating and socializing, the more likely you will find a boyfriend you like.",
            "question": "Based on the context, Does socializing help you find a boyfriend?",
            "answer": "Yes",
        },
        _QA_PLAGUE_EUROPE,
        _QA_UNEMPLOYMENT,
        _QA_ANTIMATTER,
    ]),
    pd.DataFrame([
        {
            "context": "According to Biraben, the plague was present somewhere in Italy and affected 1,200 people.",
            "question": "Based on the context, Did the plague affect over 1,000 people?",
            "answer": "yes, 1,200 people",
        },
        _QA_INFLATION,
        _QA_UNEMPLOYMENT,
    ]),
]
######## ICL Examples
# Fixed few-shot demonstration strings. Each begins with the instruction
# "Is the statement Yes, No, or Unknown?" followed by blank-line-separated
# "Statement: ...\nAnswer: yes|no|unknown" pairs.
# NOTE(review): none of icl_1/icl_2/icl_3 is referenced in the visible part of
# this file — confirm they are still used before editing them.
icl_1 = "Is the statement Yes, No, or Unknown?\n\nStatement: Humic acids are organic.\nAnswer: yes\n\nStatement: He has never won anything.\nAnswer: no\n\nStatement: Iran Foreign Minister met today with the new ambassadors\nAnswer: no\n\nStatement: Darnell Donerson was in his 60's\nAnswer: no\n\nStatement: Anorexia causes severe health problems.\nAnswer: yes\n\nStatement: What stores you shopped at is not important in the long run.\nAnswer: yes\n\nStatement: Ecevit is corrupt.\nAnswer: unknown\n\nStatement: Gray was the same age as the narrator\nAnswer: unknown\n\nStatement: Canadians want China to accept Judaism.\nAnswer: unknown\n\nStatement: The minister is a female.\nAnswer: yes\n\nStatement: Zinedine Zidane thought she would never win\nAnswer: unknown\n\nStatement: The House has more than 53 honorable members.\nAnswer: unknown\n\nStatement: November has less than 9 days.\nAnswer: unknown\n\nStatement: The bill is being thrust into law against the telemarketers.\nAnswer: unknown\n\nStatement: A knife if not, not used to cut a watermelon.\nAnswer: unknown\n\nStatement: A study in Afghanistan found that when family members of all different ages live together, the differences in their sleeping patterns ensures that at least one person is awake, or sleeping very lightly, at all times.\nAnswer: no\n\nStatement: Rivington had a population on 200 in 2001\nAnswer: no\n\nStatement: lifting weights and starving herself allowed Jessica to add curves to her body\nAnswer: no\n\nStatement: US Steel is also considering other methods, not just this one.\nAnswer: unknown\n\nStatement: Nobody witnessed the plane crash.\nAnswer: unknown\n\nStatement: To feed a newborn puppy you will need a bottle made specifically for a newborn cat.\nAnswer: no\n\nStatement: Joanne Peters uses Obamacare.\nAnswer: unknown\n\nStatement: Goodbye Mr. 
Black is non american\nAnswer: yes\n\nStatement: Steven learned to swim by joining a swim team\nAnswer: no\n\nStatement: William McGirt started three strokes behind Hoffman\nAnswer: no\n\nStatement: You should try to ask in an appropriate time\nAnswer: yes\n\nStatement: Harlem Beat was not the original title\nAnswer: unknown\n\nStatement: Michael Graham is a conservative.\nAnswer: yes\n\nStatement: There are less than a gross of Quick Chek stores in New York and New Jersey.\nAnswer: yes\n\nStatement: Sleep died before he retired.\nAnswer: unknown\n\nStatement: An air conditioner prevents people from getting hot.\nAnswer: yes\n\nStatement: Britain voted to leave the European Union.\nAnswer: yes\n\nStatement: Saddam Hussein is a good friend to Israel.\nAnswer: no\n\nStatement: The deal completed on Thursday.\nAnswer: no\n\nStatement: According to the patent abstract, there were multiple chocolate chip cookie signals at different frequencies.\nAnswer: no\n\nStatement: Jakob Funnell scored five more points in the second game than in the third game.\nAnswer: yes\n\nStatement: The Queen Street Waterfront Marathon will close roads.\nAnswer: unknown\n\nStatement: It took more than 24 hours to exterminate the spiders\nAnswer: yes\n\nStatement: Scientologists believe that drugs played a part in terrorist attacks.\nAnswer: yes\n\nStatement: Jennings is the Yankees' best pitcher\nAnswer: unknown\n\nStatement: The patent device will only be sold in select stores when released.\nAnswer: unknown\n\nStatement: Cleaning bearings on a skateboard is a very quick job.\nAnswer: unknown\n\nStatement: Former Prime Minister Rafik Hariri was most well known for his anti-Syria activities.\nAnswer: unknown\n\nStatement: she thought they could sell the corncob pipes\nAnswer: yes\n\nStatement: Clint never played games from the 1980s.\nAnswer: no\n\nStatement: Comparing yourself was used by fiona\nAnswer: unknown\n\nStatement: Basil Rathbone and Nigel Bruce did not work together\nAnswer: 
no\n\nStatement: The best thing you can do is be rude\nAnswer: no\n\nStatement: The $43 billion dollars come from American workers.\nAnswer: unknown\n\nStatement: the agent is aged 25\nAnswer: unknown"
icl_2 =  "Is the statement Yes, No, or Unknown?\n\nStatement: I did no work whatsoever yesterday.\nAnswer: no\n\nStatement: Jamaican people really like sand\nAnswer: unknown\n\nStatement: The Mt. Kinka Ropeway is the most terrifying lift line to take\nAnswer: unknown\n\nStatement: Some people may have misconceptions about the curriculum changes.\nAnswer: yes\n\nStatement: Chad went back to the store to pick up the part.\nAnswer: unknown\n\nStatement: Jemele Hill has a new job at the Quincy Harris Morning Show.\nAnswer: no\n\nStatement: I tried to wear the mouth guard but relinquished to the hurt it was causing.\nAnswer: yes\n\nStatement: Sole Technology is a South American footwear company.\nAnswer: no\n\nStatement: they were wearing shirts\nAnswer: no\n\nStatement: They travelled bu Uber.\nAnswer: unknown\n\nStatement: Hannah was successful from the first video she submitted to Youtube\nAnswer: no\n\nStatement: The central hall cannot hold 10,015 people for basketball games.\nAnswer: yes\n\nStatement: Anorexia in males accounted for 1/2 of all cases.\nAnswer: no\n\nStatement: it was a 25th anniversary.\nAnswer: yes\n\nStatement: Dr. 
William Keil had a home built by the KKK in the mid eighteenth century that is still standing in the county of Elim in Missouri.\nAnswer: no\n\nStatement: federal government has taken hundreds of millions of dollars from the failing economy of Quebec\nAnswer: unknown\n\nStatement: it is easy to change schools during the year\nAnswer: unknown\n\nStatement: restrict personal information to maintain anonymity\nAnswer: yes\n\nStatement: Healthy kids think they don't need health care.\nAnswer: yes\n\nStatement: The article in question is taken as the truth.\nAnswer: no\n\nStatement: an unlimited time would violate the Copyright Clause, making people happy\nAnswer: unknown\n\nStatement: matthew was dismissive in his criticism\nAnswer: no\n\nStatement: Abhisit Vejjajiva has never been Prime Minister\nAnswer: no\n\nStatement: George Bush is the president\nAnswer: unknown\n\nStatement: Lindworms eat small eagles.\nAnswer: yes\n\nStatement: the collective agreement apply is the only settlement mechanism.\nAnswer: no\n\nStatement:  you can wind wire around a cylindrical  with a pencil cell.\nAnswer: yes\n\nStatement: apples are the second worst choice, after a banana\nAnswer: unknown\n\nStatement: Turnix worcesteri have black eyes.\nAnswer: unknown\n\nStatement: Schlereth now works for the NFL.\nAnswer: no\n\nStatement: The publishers are the best witness for the committee to refer to regarding this bill.\nAnswer: unknown\n\nStatement: Leila Meskhi was ranked World No. 
13 in August of 1991.\nAnswer: no\n\nStatement: Government can refuses things.\nAnswer: yes\n\nStatement: HTnaturals is not based in the US\nAnswer: yes\n\nStatement: The Treasury Board is appointed by the Governor in Council.\nAnswer: unknown\n\nStatement: The tournament was played each day from May 23 through 26.\nAnswer: unknown\n\nStatement: Scot stevens is an example of a meteorologist.\nAnswer: yes\n\nStatement: Pascal's law was invented by earl\nAnswer: no\n\nStatement: Benzema starting playing football at age 13.\nAnswer: unknown\n\nStatement: More than 3 men were accused of hacking and fraud schemes.\nAnswer: no\n\nStatement: To do a valdez, you need your left toe pointing straight ahead\nAnswer: yes\n\nStatement: The Prime Minister cares about the peoples needs.\nAnswer: no\n\nStatement: Mexico officials do not use twitter\nAnswer: no\n\nStatement: most tenants are females\nAnswer: unknown\n\nStatement: Moody's did not rate the Class B Notes.\nAnswer: yes\n\nStatement: The World Bank has had a part to play in the destruction of the nature.\nAnswer: yes\n\nStatement: Stefano Modena has 5 pet spiders\nAnswer: unknown\n\nStatement: Tim had the cat for a few hours.\nAnswer: unknown\n\nStatement: Hendrik would have been 91 on his next birthday had he not passed away.\nAnswer: yes\n\nStatement: Senator Spencer Abraham is concerned with racial profiling.\nAnswer: yes"
icl_3 =  "Is the statement Yes, No, or Unknown?\n\nStatement: Canada does not have any fisheries\nAnswer: no\n\nStatement: The cows were drinking from the stream\nAnswer: unknown\n\nStatement: When we were making the case for Triumph, the $167.5 million grant to the University of British Columbia base research in folic physics and particle physics, one had to explain what this was about. Particle physics is a very popular area of study.\nAnswer: unknown\n\nStatement: The specialist caused pain to Dan's daughter when she ripped the hair off.\nAnswer: yes\n\nStatement:  hooded scarfs are popular in spain\nAnswer: unknown\n\nStatement: The federal government is not talking about ways to interest young people in science.\nAnswer: no\n\nStatement: The modern day career hunt is full of competition.\nAnswer: yes\n\nStatement: The Extra Girl was released after Korea split in two\nAnswer: no\n\nStatement: Snell has given up 10 runs at least once in a start in his 27 trips to the mound.\nAnswer: no\n\nStatement: East Tennessee State University’s campus is next to Gilbreath Drive.\nAnswer: unknown\n\nStatement: Tenzin Gyatso is the 13th Dalai Lama\nAnswer: no\n\nStatement: Richardson is a designer of buildings that are well known.\nAnswer: yes\n\nStatement: The organizations merged in 2016\nAnswer: no\n\nStatement: Yūki Mizuhara is a woman\nAnswer: yes\n\nStatement: In the Name of the Daddi-o an Academy Award contender starring Daniel Day-Lewis as a man wrongly convicted for an IRA bombing.\nAnswer: no\n\nStatement: A basic lifestyle changes can significantly reduce your risk of developing the disease is to begin taking philosophy classes.\nAnswer: unknown\n\nStatement: Authors Anonymous received many awards.\nAnswer: unknown\n\nStatement:  3/9 Motel,is causing commotion all over Japan\nAnswer: yes\n\nStatement: Abel Ferrara has heard of Jack Finney before\nAnswer: yes\n\nStatement: Everyone agrees with the B.C. 
Court of Appeal, which struck down subsection 163(1)(4) of the criminal code that makes possession of child pornography a criminal offence.\nAnswer: no\n\nStatement: There are many more articles on this topic.\nAnswer: unknown\n\nStatement: You don't need to recognize that you're angry during road rage.\nAnswer: no\n\nStatement: Bob throws the flowers in water down in the vase and put leaves around it.\nAnswer: no\n\nStatement: Zinedine Zidane knew she was better\nAnswer: unknown\n\nStatement: Adolfo Bioy Casares. does not think a person can have a nightmare and then stop themselves from writing about it\nAnswer: yes\n\nStatement: The industrial espionage case involving GM and VW began with the banning of Jos Ignacio Lopez, an employee of GM subsidiary Adam Opel\nAnswer: no\n\nStatement: Headaches that occur during or after flight are usually due to a sudden change in altitude.\nAnswer: yes\n\nStatement: Lee Foss was born in Chicago\nAnswer: unknown\n\nStatement: the city builds public housing\nAnswer: unknown\n\nStatement: Trump ordered a review of the U.S. visa programme because Abraham Lincoln told him to.\nAnswer: no\n\nStatement: Wepawaug Drive is a cul-de-sac\nAnswer: unknown\n\nStatement: Hollis Milton is standing to the right of Bill Cassidy.\nAnswer: no\n\nStatement: it needs to be asked what is in the best interest of america's people\nAnswer: yes\n\nStatement: health-care coverage is applied to children\nAnswer: yes\n\nStatement: Claudia Lawrence was in her 20's.\nAnswer: unknown\n\nStatement: Abu Eisa al-Hindi is from Pakistan.\nAnswer: unknown\n\nStatement: Douglas Coupland's initials would be DC.\nAnswer: yes\n\nStatement: In the aftermath of the Warm War.\nAnswer: no\n\nStatement:  A 25-year-old man Narendra Kumar from the Lakhimpur district in Uttar Pradesh has been detained at the Kudankulam Nuclear Plant in Tamil Nadu on Tuesday evening. 
The man stated that he was there with the intention of making the world a better place.\nAnswer: unknown\n\nStatement: It may be possible to implement prohibition in more harmful ways.\nAnswer: no\n\nStatement: The word physician is mentioned multiple times.\nAnswer: yes\n\nStatement: The President of Ghana was created by the Star and Eagles of Ghana.\nAnswer: no\n\nStatement: Mbeki is stepping back his mediation efforts in order to heighten it's intervention in the West African Nation.\nAnswer: no\n\nStatement: The festival is not popular\nAnswer: unknown\n\nStatement: Henrique Meirelles is being accused of several wrong doings.\nAnswer: yes\n\nStatement: They wore the badge of Lord Risingham\nAnswer: yes\n\nStatement: Jacksonville has won the NCAA Division I Baseball Tournament.\nAnswer: unknown\n\nStatement: Emails can't always be deleted in yahoo\nAnswer: unknown\n\nStatement: You should ask the doctor to recommend a physical therapist who has extensive experience.\nAnswer: yes\n\nStatement: welts develop as itchy or stinging areas on your skin, and may turn into hives\nAnswer: yes"
##############################################################################################################################


class ANLIDecomp(Decomposition):
    def __init__(self, task_name, data_dir, val_split="validation"):
        """Create the ANLI decomposition; every argument is forwarded unchanged to ``Decomposition``."""
        super(ANLIDecomp, self).__init__(task_name, data_dir, val_split)

    def get_few_shot_examples(self, train_data, k_shot):
        """Sample a roughly class-balanced set of ``k_shot`` training rows.

        Each distinct value of ``targets_pretokenized`` receives
        ``ceil(k_shot / num_labels)`` rows. The quota is trimmed for the
        label(s) visited last so the total never exceeds ``k_shot``.
        Sampling uses ``DataFrame.sample`` without an explicit
        ``random_state``, so it follows numpy's global RNG state.

        Args:
            train_data: DataFrame with a ``targets_pretokenized`` column.
            k_shot: Total number of rows to select.

        Returns:
            DataFrame of the selected rows, grouped by label in sorted label order.
        """
        # Sorted, not a raw set: the iteration order of a set of strings depends
        # on the per-process hash seed, so which label got trimmed (and the
        # order of the demos) previously changed from run to run.
        labels = sorted(set(train_data["targets_pretokenized"]))
        num_per_class = int(np.ceil(k_shot / len(labels)))
        print(f"Selecting {num_per_class} examples per class.")

        dfs = []
        total_in_context = 0
        for label in labels:
            # Never overshoot the requested total.
            num_per_class = min(num_per_class, k_shot - total_in_context)
            sub_df = train_data[train_data["targets_pretokenized"] == label].sample(
                num_per_class
            )
            dfs.append(sub_df)
            total_in_context += num_per_class
            if total_in_context == k_shot:
                break
        return pd.concat(dfs)

    def _get_boost_decomp_examples(self, train_data, boost_id):
        """Sample 64 in-context examples for boosting round ``boost_id``.

        ``boost_id`` (0, 1 or 2) selects a fixed seed (6, 69 or 987). That seed
        seeds ``random`` and ``numpy`` and is passed as ``random_state`` to every
        ``DataFrame.sample`` call.

        For each label, up to ``ceil(64 / num_labels)`` rows are drawn:
        - odd seed: the rows are drawn from examples carrying that label;
        - even seed: the rows are drawn from examples NOT carrying that label.
        NOTE(review): the intent of the even-seed branch is not evident from
        this file — confirm it is deliberate.

        Args:
            train_data: Anything ``pd.DataFrame`` accepts; it must yield a
                ``targets_pretokenized`` column.
            boost_id: Index into the seed list.

        Returns:
            A single-element list holding the concatenated sample DataFrame.
        """
        seed = [6, 69, 987][boost_id]
        k_shot = 64
        random.seed(seed)
        np.random.seed(seed)

        data_train = pd.DataFrame(train_data)
        label_col = data_train["targets_pretokenized"]
        # Sorted so the seeded run is actually reproducible: set iteration
        # order of strings depends on the hash seed, which the seeding above
        # does not control, and it decides which label gets the trimmed quota.
        labels = sorted(set(label_col))
        num_per_class = int(np.ceil(k_shot / len(labels)))

        dfs = []
        total_in_context = 0
        for label in labels:
            # Never overshoot k_shot.
            num_per_class = min(num_per_class, k_shot - total_in_context)
            if seed % 2 == 1:
                pool = data_train[label_col == label]
            else:
                pool = data_train[label_col != label]
            dfs.append(pool.sample(num_per_class, random_state=seed))
            total_in_context += num_per_class
            if total_in_context == k_shot:
                break

        booster_df = pd.concat(dfs)
        print(f"Selected: {len(booster_df)} in context examples.")
        return [booster_df]

    def get_boost_decomp_examples(self, train_data, boost_id):
        """Return the hand-written demos for boosting round ``boost_id``.

        ``train_data`` is accepted for interface compatibility but not used.
        The result is ``[questioner demos, extraction QA demos]``. An
        out-of-range ``boost_id`` raises ``IndexError``.
        """
        example_banks = (questioner_prompt_examples, extraction_qa_examples)
        return [bank[boost_id] for bank in example_banks]

    def zero_few_baseline(
        self,
        test_data,
        few_shot_df,
        manifest,
        overwrite_manifest,
        do_few_shot=True,
    ):
        """Run the plain (non-decomposed) True/False/Neither prompting baseline.

        Each row's ``inputs_pretokenized`` text is normalised to end with
        " True, False, or Neither? ". When ``do_few_shot`` is set, the text is
        prefixed with the ``few_shot_df`` demonstrations. The result is sent to
        the model through ``get_response``.

        The first output line that mentions any gold label is stripped of
        punctuation and mapped to "true", "false" or "neither" (the fallback).

        Args:
            test_data: DataFrame with ``inputs_pretokenized`` and
                ``targets_pretokenized`` columns.
            few_shot_df: Demonstrations with the same columns; only read when
                ``do_few_shot`` is true.
            manifest: Model client forwarded to ``get_response``.
            overwrite_manifest: Forwarded as ``bool`` to ``get_response(overwrite=...)``.
            do_few_shot: Whether to prepend the demonstrations.

        Returns:
            ``(expt_log, accuracy)``. ``expt_log`` maps each row index to a dict
            with the prompt, raw output, prediction and gold label. ``accuracy``
            comes from ``classification_report``.
        """
        expt_log = {}
        golds = []
        preds = []

        labels = [lab.lower().strip() for lab in set(test_data["targets_pretokenized"])]
        punctuation = {".", ",", "?", ";", ":", "'", '"'}

        # The demonstrations are the same for every test row, so build them once.
        icl_str = ""
        if do_few_shot:
            icl_str = "".join(
                f"{s_row['inputs_pretokenized']}{s_row['targets_pretokenized']}\n\n"
                for _, s_row in few_shot_df.iterrows()
            )

        for i, (ind, row) in tqdm(
            enumerate(test_data.iterrows()), total=len(test_data)
        ):
            if ind in expt_log:
                pred = expt_log[ind]["pred"]
                gold = expt_log[ind]["gold"]
            else:
                text = row["inputs_pretokenized"]
                text = text.replace("True, False, or Neither?", "").strip().strip("\n")
                text = text + " True, False, or Neither? "
                gold = row["targets_pretokenized"]
                # Plain concatenation. The old approach embedded the demos in a
                # template and then called str.format, which raised (or rewrote
                # "{{" to "{") whenever the data contained braces.
                pmp = f"{icl_str}{text}"
                if i == 0:
                    print(pmp)

                raw_answer = get_response(
                    pmp,
                    manifest,
                    overwrite=bool(overwrite_manifest),
                    max_toks=20,
                )

                # Keep the first non-empty output line mentioning a label, then
                # drop punctuation so labels match as whole words.
                candidates = [
                    line
                    for line in raw_answer.strip().lower().split("\n")
                    if line and any(lab in line for lab in labels)
                ]
                answer = candidates[0] if candidates else ""
                answer = "".join(ch for ch in answer if ch not in punctuation)
                words = answer.split()
                is_yes = "true" in words
                is_no = "false" in words
                is_maybe = "neither" in words

                # Only an unambiguous signal overrides the "Neither" fallback.
                pred = "Neither"
                if is_yes and not (is_maybe or is_no):
                    pred = "True"
                if is_no and not (is_maybe or is_yes):
                    pred = "False"

                gold = gold.strip().lower()
                pred = pred.strip().lower()
                expt_log[ind] = {
                    "ind": ind,
                    "example": text,
                    "base_prompt": pmp,
                    "raw_answer": raw_answer,
                    "pred": pred,
                    "gold": gold,
                }

            golds.append(gold)
            preds.append(pred)

        report = classification_report(golds, preds, output_dict=True)
        return expt_log, report["accuracy"]

    def get_extraction(self, question, passage, prompt, boost_ex, manifest, overwrite_manifest):
        """Ask the model to answer ``question`` given ``passage`` as context.

        Args:
            question: Question text, e.g. produced by ``get_question``.
            passage: Context text appended after the demonstrations.
            prompt: Callable whose result on ``boost_ex`` is the prompt prefix
                (instruction plus demonstrations).
            boost_ex: Demonstration rows for this boosting round.
            manifest: Model client forwarded to ``get_response``.
            overwrite_manifest: Forwarded as ``bool`` to ``get_response(overwrite=...)``.

        Returns:
            ``(answer, prompt_text)``. ``answer`` is the first non-empty line of
            the model output with ",", "." and "?" removed. If the output has no
            non-empty line, ``answer`` is ``passage`` itself.
        """
        prompt_suffix = prompt(boost_ex)
        # Match the demos' phrasing when they prefix questions this way.
        if "Based on the context," in prompt_suffix:
            question_prefix = " Based on the context,"
        else:
            question_prefix = ""
        # Interpolate directly. The old approach embedded the demos and the
        # generated question in a template before calling str.format(passage=...),
        # so any "{" or "}" in them raised or corrupted the prompt.
        extract_pmp = (
            f"{prompt_suffix}\n\nContext: {passage}\n"
            f"Question:{question_prefix} {question}\nAnswer:"
        )
        answer = get_response(
            extract_pmp,
            manifest,
            overwrite=bool(overwrite_manifest),
            max_toks=50,
        )
        answer = answer.replace(",", "").replace(".", "").replace("?", "")
        lines = [line for line in answer.split("\n") if line]
        answer = lines[0] if lines else passage
        return answer, extract_pmp

    def get_question(self, statement, prompt, boost_ex, manifest, overwrite_manifest):
        """Have the model rewrite ``statement`` as a yes/no question.

        Args:
            statement: Hypothesis text to rewrite.
            prompt: Callable whose result on ``boost_ex`` is the prompt prefix.
            boost_ex: Demonstration rows for this boosting round.
            manifest: Model client forwarded to ``get_response``.
            overwrite_manifest: Forwarded as ``bool`` to ``get_response(overwrite=...)``.

        Returns:
            ``(question, prompt_text)``. ``question`` is the first non-empty
            output line with "Question: " removed. It falls back to
            ``"<statement>. Yes, no, or unknown?"`` in three cases: the output is
            empty, it merely repeats the statement (case-insensitive), or it
            does not end with "?".
        """
        prompt_suffix = prompt(boost_ex)
        # Interpolate directly rather than calling str.format on a template
        # that already contains the demos, which broke on literal braces.
        question_pmp = f"{prompt_suffix}\n\nStatement: {statement}\nQuestion:"
        answer = get_response(
            question_pmp,
            manifest,
            overwrite=bool(overwrite_manifest),
            max_toks=50,
        )
        answer = answer.replace("Question: ", "")
        lines = [line for line in answer.split("\n") if line]
        answer = lines[0].strip() if lines else ""
        statement = statement.strip().strip(".")
        if (
            not answer
            or statement.lower() == answer.lower()
            or not answer.endswith("?")
        ):
            answer = f"{statement}. Yes, no, or unknown?"
        # The fallback embeds the statement, which could itself span lines.
        answer = answer.split("\n")[0]
        return answer, question_pmp

    def resolve_pred(self, answer):
        """Map a free-text answer onto an ANLI label.

        Matching is word-based and case-insensitive. The result is:
        - "True" when the answer contains only yes/true signals;
        - "False" when it contains only no/false signals;
        - "Neither" otherwise: no signal, conflicting signals, or a hedge
          ("maybe", "unknown", "neither").
        """
        words = set(answer.lower().split())
        is_yes = bool(words & {"yes", "true"})
        is_no = bool(words & {"no", "false"})
        # Previously "maybe" was tested twice (copy-paste slip). Also treat the
        # "Unknown" answer the extraction prompt asks for as a hedge.
        is_maybe = bool(words & {"maybe", "unknown", "neither"})

        if is_yes and not (is_no or is_maybe):
            return "True"
        if is_no and not (is_yes or is_maybe):
            return "False"
        return "Neither"

    def run_decomposed_prompt(
        self, test_data, boost_data_train, boost_dfs, manifest, overwrite_manifest
    ):
        """Run the decomposed pipeline and combine the boosting rounds.

        The pipeline runs on all of ``test_data`` and on at most 1000 rows of
        ``boost_data_train`` (via ``run_limit``). The per-round predictions are
        then merged with ``merge_boosted_preds``, which is defined outside this
        view; "neither" is passed as the indecisive answer.

        Returns:
            ``(expt_log, expt_log_train, merged_accuracy, individual_accuracies)``.
            ``individual_accuracies`` holds one test accuracy per boosting round.
        """
        expt_log, all_boost_preds, labels = self._run_decomp_single_data(
            test_data, boost_dfs, manifest, overwrite_manifest, run_limit=-1
        )
        expt_log_train, all_boost_train_preds, train_labels = self._run_decomp_single_data(
            boost_data_train, boost_dfs, manifest, overwrite_manifest, run_limit=1000
        )
        # Weak-supervision style merge of the per-round predictions.
        preds = self.merge_boosted_preds(
            all_boost_preds,
            all_boost_train_preds,
            train_labels,
            expt_log,
            expt_log_train,
            indecisive_ans="neither",
        )

        # all_boost_preds is row-major (one list per example); score each column.
        individual_accuracies = []
        num_boost_sets = len(all_boost_preds[0])
        for boost_idx in range(num_boost_sets):
            round_preds = [row_preds[boost_idx] for row_preds in all_boost_preds]
            round_report = classification_report(labels, round_preds, output_dict=True)
            individual_accuracies.append(round_report["accuracy"])
            print(round_report)
            print("\n\n")

        merged_report = classification_report(labels, preds, output_dict=True)
        print(merged_report)
        return expt_log, expt_log_train, merged_report["accuracy"], individual_accuracies
    
    def _run_decomp_single_data(
        self, test_data, boost_dfs, manifest, overwrite_manifest, run_limit = -1
    ):
        """Run every boost set over ``test_data`` and collect per-row predictions.

        For each row and each boost set, two predictions are appended in this order:
          1. question rewrite -> open extraction over the passage -> ``resolve_pred``;
          2. a direct in-context prompt (``icl_1``/``icl_2``/``icl_3``) -> ``resolve_pred``.

        Args:
            test_data: DataFrame with ``inputs_pretokenized`` and
                ``targets_pretokenized`` columns.
            boost_dfs: sequence of boost sets. Each is indexed as ``[0]``
                (questioner examples) and ``[1]`` (extraction examples).
            manifest: client forwarded to the LM calls.
            overwrite_manifest: forwarded to the LM calls (``overwrite=bool(...)``).
            run_limit: stop before the row with this position; -1 processes all rows.

        Returns:
            ``(expt_log, all_boost_preds, labels)``:
            - ``expt_log`` is a per-index log dict.
            - ``all_boost_preds`` is a list of per-row prediction lists, two per
              boost set.
            - ``labels`` holds the stripped, lowercased gold labels.
        """
        expt_log = {}
        all_boost_preds = []
        labels = []

        for i, (ind, row) in tqdm(
            enumerate(test_data.iterrows()), total=len(test_data)
        ):
            if i == run_limit:
                break

            text, gold, passage, statement = self._parse_example(row)
            prompts_across_boost = []
            preds_across_boost = []

            for boost_idx, boost_examples in enumerate(boost_dfs):
                all_prompts = []

                # question / extract prompt
                question, question_final_prompt = self.get_question(
                    statement, questioner_prompt, boost_examples[0], manifest, overwrite_manifest
                )
                all_prompts.append(question_final_prompt)

                open_answer_f, extraction_final_prompt = self.get_extraction(
                    question,
                    passage,
                    extraction_qa,
                    boost_examples[1],
                    manifest,
                    overwrite_manifest,
                )
                all_prompts.append(extraction_final_prompt)
                if i == 0:
                    print("\n".join(all_prompts))
                pred = self.resolve_pred(open_answer_f.lower()).strip().lower()
                preds_across_boost.append(pred)

                # direct in-context prompt
                icl_pmp, icl_pred = self._get_icl_pred(
                    statement, boost_idx, manifest, overwrite_manifest, verbose=(i == 0)
                )
                # Log the formatted prompt that was actually sent, like the entries above.
                all_prompts.append(icl_pmp)
                prompts_across_boost.append(all_prompts)
                preds_across_boost.append(icl_pred)
            print(preds_across_boost)

            expt_log[ind] = {
                "ind": ind,
                "preds_boost": preds_across_boost,
                "prompts": prompts_across_boost,
                "example": text,
                # Last in-context prediction; "" when there are no boost sets.
                "pred": preds_across_boost[-1] if preds_across_boost else "",
                "gold": gold,
            }
            all_boost_preds.append(preds_across_boost)
            labels.append(gold)
        return expt_log, all_boost_preds, labels

    @staticmethod
    def _parse_example(row):
        """Split one row into ``(text, gold, passage, statement)``.

        ``passage`` is the first line of ``inputs_pretokenized``. ``statement`` is the
        last line, with "True, False, or Neither?" and any "Question: " label removed.
        ``gold`` is the stripped, lowercased ``targets_pretokenized``.
        """
        text = row["inputs_pretokenized"]
        gold = row["targets_pretokenized"].strip().lower()
        lines = text.split("\n")
        passage = lines[0]
        statement = (
            lines[-1]
            .replace("True, False, or Neither?", "")
            .strip()
            .replace("Question: ", "")
        )
        return text, gold, passage, statement

    def _get_icl_pred(self, statement, boost_idx, manifest, overwrite_manifest, verbose=False):
        """Query the direct in-context prompt for one boost set and resolve the answer.

        Boost set 0 uses ``icl_1``, set 1 uses ``icl_2``, and every later set uses ``icl_3``.

        Returns:
            ``(pmp, pred)``: the formatted prompt sent and the lowercased resolved label.
        """
        if boost_idx == 0:
            icl_str = icl_1
        elif boost_idx == 1:
            icl_str = icl_2
        else:
            icl_str = icl_3

        prompt = f"{icl_str}\n\nStatement: {{statement:}}\nAnswer:"
        pmp = prompt.format(statement=statement)
        if verbose:
            print("PMP ICL")
            print(pmp)
        raw = get_response(
            pmp,
            manifest,
            overwrite=bool(overwrite_manifest),
            max_toks=10,
            stop_token="\n",
        )
        raw = raw.lower().strip()
        # The text is already lowercased, so the label prefixes must be lowercase too
        # (mixed-case "Label: "/"Sentiment:" could never match).
        raw = raw.replace(".", "").replace(",", "").replace("label: ", "").replace("sentiment:", "")
        non_empty = [line for line in raw.split("\n") if line]
        first = non_empty[0] if non_empty else ""
        return pmp, self.resolve_pred(first).lower()


def main():
    """Run the ANLI round-3 decomposition on the test split with ``get_args()`` options."""
    cli_args = get_args()
    anli = ANLIDecomp(
        "anli_r3",
        "/home/data/P3/data_feather/anli_GPT_3_style_r3",
        val_split="test",
    )
    anli.run(cli_args)


if __name__ == "__main__":
    main()
