import os
from src.dataset import ClutrrDataset, GenClutrrDataset, ClevrDataset
from src.function import LLMNesy
from llm_symbolic_eval import get_task_predictions, get_raw_predictions, APIModel, OurLLM
from vllm import LLM
import numpy as np
from src.utils import IOExamples, RawInput
from src.symbol_mapping import LLMNet
import re
import argparse
import ast
        

def equiv(gt, pred, i):
    """Return True when the prediction equals the ground truth.

    Args:
        gt: Ground-truth answer.
        pred: Predicted answer.
        i: Sample index; accepted for the scoring-callback signature but
            not used by this comparison.

    Returns:
        The result of ``gt == pred``.
    """
    is_match = gt == pred
    return is_match


# Kinship composition table used by `clutrr_extract`.
#
# Facts are stored as ``(a, b): rel`` meaning "a is b's rel" (e.g.
# ``("Bob", "John"): "son"``).  Given ``(a, b): r_ab`` and ``(b, c): r_bc``,
# the relation of ``a`` to ``c`` is ``_CLUTRR_RULES[(r_bc, r_ab)]``.
#
# NOTE(review): the original literal listed six keys twice; Python keeps the
# last value, so the effective entries are reproduced here unchanged:
#   ("daughter", "grandfather") -> "father-in-law"  (shadowed: "father")
#   ("daughter", "grandmother") -> "mother-in-law"  (shadowed: "mother")
#   ("son", "grandfather")      -> "father-in-law"  (shadowed: "father")
#   ("son", "grandmother")      -> "mother-in-law"  (shadowed: "mother")
#   ("mother", "father") / ("mother", "mother") were exact repeats.
# Confirm that the "-in-law" readings are the intended ones.
_CLUTRR_RULES = {
    ("daughter", "daughter"): "granddaughter",
    ("daughter", "sister"): "daughter",
    ("daughter", "son"): "grandson",
    ("daughter", "aunt"): "sister",
    ("daughter", "father"): "husband",
    ("daughter", "husband"): "son-in-law",
    ("daughter", "brother"): "son",
    ("daughter", "mother"): "wife",
    ("daughter", "uncle"): "brother",
    ("daughter", "grandfather"): "father-in-law",
    ("daughter", "grandmother"): "mother-in-law",
    ("sister", "daughter"): "niece",
    ("sister", "sister"): "sister",
    ("sister", "son"): "nephew",
    ("sister", "aunt"): "aunt",
    ("sister", "father"): "father",
    ("sister", "brother"): "brother",
    ("sister", "mother"): "mother",
    ("sister", "uncle"): "uncle",
    ("sister", "grandfather"): "grandfather",
    ("sister", "grandmother"): "grandmother",
    ("son", "daughter"): "granddaughter",
    ("son", "sister"): "daughter",
    ("son", "son"): "grandson",
    ("son", "aunt"): "sister",
    ("son", "father"): "husband",
    ("son", "brother"): "son",
    ("son", "mother"): "wife",
    ("son", "uncle"): "brother",
    ("son", "grandfather"): "father-in-law",
    ("son", "grandmother"): "mother-in-law",
    ("aunt", "sister"): "aunt",
    ("aunt", "father"): "grandfather",
    ("aunt", "brother"): "uncle",
    ("aunt", "mother"): "grandmother",
    ("father", "daughter"): "sister",
    ("father", "sister"): "aunt",
    ("father", "son"): "brother",
    ("father", "father"): "grandfather",
    ("father", "brother"): "uncle",
    ("father", "mother"): "grandmother",
    ("father", "wife"): "mother",
    ("husband", "daughter"): "daughter",
    ("husband", "son"): "son",
    ("husband", "father"): "father-in-law",
    ("husband", "granddaughter"): "granddaughter",
    ("husband", "mother"): "mother-in-law",
    ("husband", "grandson"): "grandson",
    ("granddaughter", "sister"): "granddaughter",
    ("granddaughter", "brother"): "grandson",
    ("brother", "daughter"): "niece",
    ("brother", "sister"): "sister",
    ("brother", "son"): "nephew",
    ("brother", "aunt"): "aunt",
    ("brother", "father"): "father",
    ("brother", "brother"): "brother",
    ("brother", "mother"): "mother",
    ("brother", "uncle"): "uncle",
    ("brother", "grandfather"): "grandfather",
    ("brother", "grandmother"): "grandmother",
    ("nephew", "sister"): "niece",
    ("nephew", "brother"): "nephew",
    ("mother", "daughter"): "sister",
    ("mother", "sister"): "aunt",
    ("mother", "son"): "brother",
    ("mother", "father"): "grandfather",
    ("mother", "husband"): "father",
    ("mother", "brother"): "uncle",
    ("mother", "mother"): "grandmother",
    ("uncle", "sister"): "aunt",
    ("uncle", "father"): "grandfather",
    ("uncle", "brother"): "uncle",
    ("uncle", "mother"): "grandmother",
    ("grandfather", "wife"): "grandmother",
    ("wife", "daughter"): "daughter",
    ("wife", "son"): "son",
    ("wife", "father"): "father-in-law",
    ("wife", "granddaughter"): "granddaughter",
    ("wife", "mother"): "mother-in-law",
    ("wife", "grandson"): "grandson",
    ("wife", "son-in-law"): "son-in-law",
    ("wife", "father-in-law"): "father",
    ("wife", "daughter-in-law"): "daughter-in-law",
    ("wife", "mother-in-law"): "mother",
    ("grandmother", "husband"): "grandfather",
    ("grandson", "sister"): "granddaughter",
    ("grandson", "brother"): "grandson",
}


def clutrr_extract(data, model):
    """Build the symbol-extraction and reasoning callables for CLUTRR.

    Args:
        data: Unused; kept so existing callers (which pass ``None``) work.
        model: Language-model handle forwarded to ``LLMNet``.

    Returns:
        A ``(parse, function)`` tuple.  ``parse(context, query)`` asks the
        LLM for a dictionary of pairwise relations and returns
        ``(facts, query.text_input)``; ``function(facts, query)`` derives the
        queried relation from those facts with ``_CLUTRR_RULES``.
    """
    # Local import keeps this block self-contained (no change to file imports).
    import copy

    # The extraction prompt always carries this single worked example; it does
    # not depend on any command-line flag.
    extract_relations = LLMNet(
        model,
        "a story with a question about the relationships between people in the story",
        "a Python dictionary mapping pairs of people to their relationship described in the text. The relationship is one of the following: {'brother', 'sister', 'father', 'mother', 'son', 'daughter', 'grandfather', 'grandmother', 'uncle', 'aunt', 'nephew', 'niece', 'husband', 'wife', 'brother-in-law', 'sister-in-law', 'son-in-law', 'daughter-in-law', 'father-in-law', 'mother-in-law', 'grandson', 'granddaughter'}. Include all relationships described in the passage.",
        IOExamples(
            description=None,
            inputs=[RawInput(image_input=None, text_input="Bob is the son of John. John is the son of Abe. How is Bob related to Abe?")],
            outputs=[[{("Bob", "John"): "son", ("John", "Bob"): "father", ("John", "Abe"): "son", ("Abe", "John"): "father"}]],
        )
    )

    def parse(context: RawInput, query: RawInput):
        """Extract pairwise relation facts from a story with the LLM.

        Args:
            context: Story input; its ``text_input`` is read but not modified.
            query: Input whose ``text_input[0]`` and ``text_input[1]`` name
                the two people being asked about.

        Returns:
            ``(facts, query.text_input)`` where ``facts`` is the dictionary
            parsed from the LLM reply, or an empty dictionary when the reply
            is not a valid Python dictionary literal.
        """
        question = f" How is {query.text_input[0]} related to {query.text_input[1]}?"
        # Work on a shallow copy so the caller's sample is not altered; the
        # question would otherwise be appended again if the sample is reused.
        prompt_input = copy.copy(context)
        prompt_input.text_input = context.text_input + question

        out_dict = extract_relations.forward(prompt_input)
        print("out_dict:", out_dict)
        try:
            out_dict = ast.literal_eval(out_dict.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            # A malformed reply should cost one sample, not abort the run.
            print("Could not parse extracted relations:", err)
            out_dict = {}
        if not isinstance(out_dict, dict):
            print("Extracted relations are not a dictionary:", type(out_dict))
            out_dict = {}
        print("type:", type(out_dict))
        print(query.text_input)
        return out_dict, query.text_input

    def function(facts, query):
        """Derive the queried relation by composing known facts.

        Repeatedly joins facts ``(a, b)`` and ``(b, c)`` through
        ``_CLUTRR_RULES`` until the queried pair is known or a full pass adds
        nothing new.  ``facts`` is updated in place with the derived pairs.

        Args:
            facts: Mapping ``(a, b) -> relation`` ("a is b's relation").
            query: Pair ``(a, c)`` to resolve; a list is treated as a tuple.

        Returns:
            The relation for ``query``, or ``"Uncertain"`` if it cannot be
            derived.
        """
        rules = _CLUTRR_RULES
        if isinstance(query, list):
            # Lists are unhashable and could never match the tuple keys.
            query = tuple(query)

        print("facts:", facts)
        while query not in facts:
            added_facts = {}
            for pair_ab, rel_ab in facts.items():
                for pair_bc, rel_bc in facts.items():
                    start, end = pair_ab[0], pair_bc[1]
                    if start == end or pair_ab[1] != pair_bc[0]:
                        continue  # not a chain a -> b -> c with a != c
                    rule_key = (rel_bc, rel_ab)
                    if rule_key in rules and (start, end) not in facts:
                        new_fact = rules[rule_key]
                        print("Applying rule:", rule_key, "->", new_fact)
                        added_facts[(start, end)] = new_fact
            print("Adding facts:", added_facts)
            if not added_facts:
                # Only unknown pairs are ever added, so an empty pass means
                # the closure is complete.
                break
            facts.update(added_facts)
        print("final facts:", facts)

        return facts[query] if query in facts else "Uncertain"

    return parse, function


def main(args):
    """Run one evaluation and write its logs and accuracy under ``logs/``.

    Loads the test set selected by ``args.dataset``, instantiates the model
    backend chosen from ``args.model`` / ``args.use_hf``, obtains predictions
    in end-to-end, direct, or LLM-symbolic mode, prints the accuracy, writes
    one log entry per line to a per-method file and appends a CSV summary
    line to ``results.txt``.

    Args:
        args: Parsed command-line namespace; reads ``model``, ``num_gpus``,
            ``debug``, ``log``, ``end2end``, ``direct``, ``few_shot``,
            ``single_turn``, ``image_before``, ``dataset`` and ``use_hf``.
    """
    # Seeded once; nothing consumes random numbers before the CLEVR shuffle.
    np.random.seed(0)
    if args.dataset == "clutrr":
        test_data = ClutrrDataset(train=False, varied_complexity=True)
        gt = [sample[1] for sample in test_data]
    else:
        test_data = ClevrDataset(max_samples=500)
        test_data_ids = list(range(min(200, len(test_data))))
        shuf = np.random.permutation(test_data_ids)
        test_data = [test_data[int(i)] for i in shuf[:200]]
        gt = [sample[1] for sample in test_data]
    prompt = None

    # load model
    model_key = args.model.lower()
    is_api_model = any(tag in model_key for tag in ("gemini", "gpt", "o3"))
    if is_api_model:
        # API-backed models take precedence even when --use_hf is given.
        model = APIModel(args.model)
    elif args.use_hf:
        model = OurLLM(model_name=args.model)
    else:
        extra_args = {}
        if "mistral" in model_key:
            extra_args = {"config_format": "mistral", "load_format": "mistral", "tokenizer_mode": "mistral"}
        model = LLM(
            model=args.model,
            max_model_len=12288,
            limit_mm_per_prompt={"image": 10},
            max_num_seqs=1,
            enforce_eager="llama" in model_key,
            trust_remote_code=True,
            tensor_parallel_size=args.num_gpus,
            **extra_args
        )

    if args.end2end:
        if args.dataset == "clutrr":
            prompt = "Analyze the provided input and use code to help solve the problem. The answer is a one of the following relations: aunt, brother, brother-in-law, daughter, daughter-in-law, father, father-in-law, granddaughter, grandfather, grandmother, grandson, mother, mother-in-law, nephew, niece, sister, sister-in-law, son, son-in-law, uncle, wife, husband."
            test_data = [((None, sample[0][0] + f" How is {sample[0][1][0]} related to {sample[0][1][1]}?"), sample[1]) for sample in test_data]
        else:
            prompt = "Analyze the provided image and question and output the answer to the question. Colors can be one of ['gray','green','blue','red','brown','purple','yellow','cyan']. Each shape can be one of ['cube','cylinder','sphere']. Material can be one of ['rubber','metal']. Size can be one of ['small','large']."
            test_data = [([sample[0][0], sample[2][0]], sample[1]) for sample in test_data]

        preds, _gts, logs = get_raw_predictions(model, test_data, log=True, equiv=equiv, instruction=prompt)

    elif args.direct:
        # NOTE(review): this branch and the LLM-symbolic branch below use the
        # CLUTRR prompt/pipeline regardless of args.dataset -- confirm they
        # are only meant to be run with --dataset clutrr.
        prompt = "Analyze the provided input. The answer is a one of the following relations: aunt, brother, brother-in-law, daughter, daughter-in-law, father, father-in-law, granddaughter, grandfather, grandmother, grandson, mother, mother-in-law, nephew, niece, sister, sister-in-law, son, son-in-law, uncle, wife, husband. OUTPUT JUST THE ANSWER and NOTHING ELSE."
        test_data = [((None, sample[0][0] + f" How is {sample[0][1][0]} related to {sample[0][1][1]}?"), sample[1]) for sample in test_data]

        preds, _gts, logs = get_raw_predictions(model, test_data, log=True, equiv=equiv, instruction=prompt)
    else:
        get_symbols, function = clutrr_extract(None, model)

        task = LLMNesy(get_symbols, function)
        preds, logs = get_task_predictions(
            task,
            test_data,
            log=args.log,
            equiv=equiv,
        )

    # Accuracy is scored against the ground truth collected before any
    # prompt reformatting of test_data.
    num_preds = len(preds)
    num_correct = sum(equiv(gt[i], preds[i], i) for i in range(num_preds))
    # An empty prediction list yields NaN instead of a ZeroDivisionError.
    acc = num_correct / num_preds if num_preds else float("nan")
    print("Accuracy:", acc)

    model_name = args.model.split("/")[-1]
    method_name = "end2end" if args.end2end else "direct" if args.direct else "llm_symbolic"
    log_dir = f"logs/{('debug/' if args.debug else '') + model_name}/{args.dataset}_complex/"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)

    # UTF-8 is explicit so non-ASCII model output cannot fail under a
    # narrower locale default encoding.
    with open(
        f"{log_dir}{method_name}_{'fs' if args.few_shot else 'zs'}.txt",
        "w",
        encoding="utf-8",
    ) as f:
        for log in logs:
            f.write(str(log) + "\n")

    # append to results file
    with open(f"{log_dir}results.txt", "a", encoding="utf-8") as f:
        f.write(
            f"{('debug_' if args.debug else '') + model_name},{args.few_shot},{method_name},{args.single_turn},{args.image_before},{args.dataset}_complex,{acc}\n"
        )


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--num_gpus", default=1, type=int)
    # Boolean switches, declared in the same order as before so --help output
    # is unchanged.
    for switch in (
        "--debug",
        "--log",
        "--end2end",
        "--direct",
        "--few_shot",
        "--single_turn",
        "--image_before",
    ):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--dataset", default="clutrr", type=str)
    parser.add_argument("--use_hf", action="store_true")

    # Bound at module level (not inside a function), so it is visible as a
    # global to the rest of this module when run as a script.
    args = parser.parse_args()

    main(args)