import os
import logging
from src.dataset import (
    BlocksWorldDataset,
    LongSortDataset,
    ListSynthesisDataset,
    ClutrrDataset,
    SudokuDataset,
    TwentyFourGameDataset,
    OmniMathDataset,
    ClevrDataset,
    LeafDataset,
    AIMEDataset,
    BBEHDataset,
)
from src.answer_equivalence import get_equivalence, get_equivalence_or_judge
import boto3
from botocore.config import Config
from tqdm import tqdm
import argparse
from vllm import LLM
import re
import numpy as np
import torch
import ast
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    MllamaForCausalLM,
)
from src.utils import base642img, RawInput, IOExamples
from src.pddl import eval_solution_files
from src.baselines import zs_cot, self_discover, autogen_prompt, code_prompt, zs_tot, zs_cots, our_method, gen_sym_prog, gen_sym, gen_task_prog, gen_sym_reason_prog, gen_sym_reason_prog_iter, gen_sym_reason_prog_test_iter, gen_sym_reason_translate, gen_sym_reason_prog_checks, code_interpreter_prompt, gen_sym_reason_prog_checks2, ablate_gen_reason_prog_checks
import time
import json
import tempfile
import re
import operator
from openai import OpenAI
import anthropic
from src.llm_models import OurLLM, APIModel
from src.function_evaluation import python_eval

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# --method value -> solver invoked once per example by get_method_predictions.
METHOD_MAP = dict(
    gen_task_prog=gen_task_prog,
    gen_sym_prog=gen_sym_prog,
    gen_sym_reason_prog=gen_sym_reason_prog,
    gen_sym_reason_translate=gen_sym_reason_translate,
    gen_sym_reason_prog_iter=gen_sym_reason_prog_iter,
    gen_sym_reason_prog_test_iter=gen_sym_reason_prog_test_iter,
    gen_sym_reason_prog_checks=gen_sym_reason_prog_checks,
    gen_sym_reason_prog_checks2=gen_sym_reason_prog_checks2,
    gen_sym=gen_sym,
    zs_cot=zs_cot,
    self_discover=self_discover,
    autogen=autogen_prompt,
    code=code_prompt,
    code_interpreter=code_interpreter_prompt,
    tot=zs_tot,
    ablate_gen_reason_prog_checks=ablate_gen_reason_prog_checks,
)

# --method value -> solver invoked once on the whole batch (used with --batched).
BATCHED_METHOD_MAP = dict(zs_cot=zs_cots)

def get_method_predictions(args, model: LLM, method, data, log=False, equiv=None):
    """Run ``method`` on each example of ``data`` sequentially and collect predictions.

    Each example's ``intermediate`` value is also appended, as soon as it is
    produced, to
    ``logs/[debug/]<model>/<dataset>/<method>/intermediate_outputs_gen_<n>_temp_<t>.txt``.

    Args:
        args: Parsed CLI namespace; reads ``method``, ``num_gen``,
            ``temperature``, ``model``, ``dataset`` and ``debug``.
        model: Model handle forwarded to ``method``.
        method: Solver callable; the extra positional arguments it receives
            depend on ``args.method`` (see the dispatch below).
        data: Indexable sequence of ``((image, text), label)`` examples.
        log: If True, log each raw output together with its label and prediction.
        equiv: Scoring function ``equiv(label, pred, index)`` used for the
            running accuracy shown in the progress bar. Defaults to string
            equality.

    Returns:
        ``(preds, gt, logs)``: predictions, labels, and per-example
        ``(output, pred, intermediate)`` tuples.
    """
    if equiv is None:
        # equiv is always called with a third (index) argument below, so the
        # default must accept and ignore it.
        equiv = lambda x, y, *_: str(x) == str(y)

    logs = []
    preds = []
    gt = []
    template = None  # self_discover: reasoning structure carried across examples
    task_program = None  # gen_task_prog: program carried across examples
    training_examples = None
    if args.method == "gen_task_prog":
        train_data = get_train_dataset(args)
        # Up to three training items serve as I/O examples for program induction.
        few_shot = [train_data[k] for k in range(min(3, len(train_data)))]
        training_examples = IOExamples(
            description=None,
            inputs=[RawInput(image_input=s[0][0], text_input=s[0][1]) for s in few_shot],
            outputs=[s[1] for s in few_shot],
        )

    log_dir = f"logs/{('debug/' if args.debug else '') + args.model}/{args.dataset}/{args.method}"
    os.makedirs(log_dir, exist_ok=True)
    intermediate_path = f"{log_dir}/intermediate_outputs_gen_{args.num_gen}_temp_{args.temperature}.txt"

    total_score = 0
    for i in (pbar := tqdm(range(len(data)))):
        input_img = data[i][0][0]
        input_str = data[i][0][1]
        label = data[i][1]
        gt.append(label)

        example_input = RawInput(image_input=input_img, text_input=input_str)
        if args.method == "self_discover":
            output, intermediate = method(example_input, model, template)
            template = intermediate["reasoning_structure"]
        elif args.method == "gen_task_prog":
            output, intermediate = method(example_input, model, task_program, training_examples)
            task_program = intermediate["program"]
        elif args.method in ("gen_sym_prog", "zs_cot"):
            output, intermediate = method(example_input, model, args.num_gen, args.temperature)
        else:
            output, intermediate = method(example_input, model, args.temperature)
        # Methods may return several candidates; only the first one is scored.
        if isinstance(output, (list, tuple)):
            output = output[0]

        pred = get_pred(args, output)

        if log:
            logger.info("Output: %s", output)
            logger.info("GT: %s, Pred: %s", repr(label), repr(pred))

        logs.append((output, pred, intermediate))
        preds.append(pred)
        # Keep a running total so each prediction is scored exactly once.
        # NOTE(review): ``i`` is the position within ``data``, whereas main()
        # scores with the original dataset index ``shuf[i]``; confirm which
        # index equiv expects.
        total_score += equiv(label, pred, i)
        pbar.set_description(f"Acc: {total_score / len(preds)}")

        # Append per example so partial results survive an interrupted run.
        with open(intermediate_path, "a") as f:
            f.write(f"{intermediate}\n")

    return preds, gt, logs

def get_method_predictions_batched(args, model: LLM, method, data, log=False, equiv=None):
    """Run a batched ``method`` over all of ``data`` in one call and score the results.

    Args:
        args: Parsed CLI namespace, forwarded to ``get_pred``.
        model: Model handle forwarded to ``method``.
        method: Callable taking ``(list_of_texts, model)`` and returning
            ``(outputs, output_log)``; ``output_log`` is indexed per output.
        data: Indexable sequence of ``((image, text), label)`` examples; only
            the text part is sent to ``method``.
        log: If True, log each raw output together with its label and prediction.
        equiv: Scoring function ``equiv(label, pred, index)`` used for the
            running accuracy shown in the progress bar. Defaults to string
            equality.

    Returns:
        ``(preds, gt, logs)``: predictions, labels, and per-example
        ``(output, pred, output_log_entry)`` tuples.
    """
    if equiv is None:
        # equiv is always called with a third (index) argument below.
        equiv = lambda x, y, *_: str(x) == str(y)

    gt = [data[i][1] for i in range(len(data))]
    inputs = [data[i][0][1] for i in range(len(data))]

    outputs, output_log = method(inputs, model)
    # Keep exactly one output per input so positions stay aligned with gt and
    # output_log; multi-candidate outputs are reduced to their first element.
    outputs = [o[0] if isinstance(o, (list, tuple)) else o for o in outputs]

    preds = []
    logs = []
    total_score = 0
    for i in (pbar := tqdm(range(len(outputs)))):
        output = outputs[i]
        pred = get_pred(args, output)
        preds.append(pred)
        if log:
            logger.info("Output: %s", output)
            logger.info("GT: %s, Pred: %s", repr(gt[i]), repr(pred))

        logs.append((output, pred, output_log[i]))
        # Running total: each prediction is scored exactly once.
        total_score += equiv(gt[i], pred, i)
        pbar.set_description(f"Acc: {total_score / len(preds)}")

    return preds, gt, logs




def get_dataset(args, use_less=False):
    """Build the evaluation dataset selected by ``args.dataset``.

    Args:
        args: Parsed CLI namespace; reads ``dataset`` and, for ``24game``,
            ``twentyfourgame_n``.
        use_less: For ``bbeh*`` datasets, load only the ``"test"`` split
            instead of ``"all"``.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a recognized name.
    """
    name = args.dataset
    if name == "blocksworld":
        return BlocksWorldDataset()
    if name == "longsort":
        return LongSortDataset()
    if name == "listsynthesis":
        return ListSynthesisDataset()
    if name == "clutrr":
        return ClutrrDataset(varied_complexity=True, decomposed=False)
    if name == "sudoku":
        return SudokuDataset()
    if name == "24game":
        return TwentyFourGameDataset(args.twentyfourgame_n)
    if name.startswith("omnimath"):
        # Characters from index 9 on are parsed as the difficulty level
        # (presumably names look like "omnimath_<n>").
        return OmniMathDataset(split="test", difficulty=int(name[9:]))
    if name == "aime":
        return AIMEDataset()
    if name == "clevr":
        return ClevrDataset(raw_data=True)
    if name == "leaf":
        return LeafDataset(raw_data=True)
    if name.startswith("bbeh"):
        # Everything after the first 5 characters ("bbeh" + separator) is the subtask.
        return BBEHDataset(subtasks=[name[5:]], split="test" if use_less else "all")
    raise NotImplementedError(f"Unknown dataset {name!r}")


def get_train_dataset(args):
    """Build the training split used for few-shot examples.

    Only ``clutrr``, ``leaf`` and ``bbeh*`` datasets are supported here.

    Args:
        args: Parsed CLI namespace; reads ``dataset``.

    Returns:
        The constructed training dataset object.

    Raises:
        NotImplementedError: For any other ``args.dataset``.
    """
    name = args.dataset
    if name == "clutrr":
        return ClutrrDataset(train=True, varied_complexity=True, decomposed=False)
    if name == "leaf":
        return LeafDataset(raw_data=True, train=True)
    if name.startswith("bbeh"):
        # Everything after the first 5 characters ("bbeh" + separator) is the subtask.
        return BBEHDataset(subtasks=[name[5:]], split="train")
    raise NotImplementedError(f"No training split available for dataset {name!r}")


def get_pred(args, output):
    """Extract the final answer string from a model's raw ``output``.

    The text is checked against a list of answer markers in a fixed priority
    order (LaTeX boxed answers first, then bold/italic "Final Answer" variants,
    then plain "FINAL ANSWER:"). The first marker that occurs in the text selects
    the regex used to pull the answer out, and the last regex match is taken.
    If no marker occurs, the whole stripped output is used. Code fences and
    ``<|eot_id|>`` tokens are then removed from the result.

    Args:
        args: Parsed CLI namespace; only ``args.dataset`` is read, and only
            when extraction fails.
        output: Raw model output, normally a ``str``.

    Returns:
        The extracted answer. Extraction can fail, for example when ``output``
        is not a string or the selected regex finds no match. On failure the
        result is ``output`` unchanged for ``sudoku`` and ``"None found"``
        otherwise. A ``None`` result is normalized to ``""``.
    """
    # (marker that must occur in the text, regex whose group 1 is the answer).
    # NOTE: order matters. The entries "*Final answer:*", "**Final answer:"
    # and "*Final answer:" come after the plain "Final answer:" entry and can
    # never be selected, since any text containing them also contains
    # "Final answer:". They are kept so scores stay comparable with earlier runs.
    # With re.DOTALL, the greedy "(.*)" followed by "(?:<|$)" captures up to
    # the end of the text.
    answer_patterns = (
        ("\\[ \\boxed{", r"\[ \\boxed{(.*)}"),
        ("$\\boxed{", r"\$\\boxed{(.*?)}"),
        ("**FINAL ANSWER:**", r"\*\*FINAL ANS.*:\*\*(.*)(?:<|$)"),
        ("**Final Answer:**", r"\*\*Final Ans.*:\*\*(.*)(?:<|$)"),
        ("**Final Answer**", r"\*\*Final Answer\*\*(.*)(?:<|$)"),
        ("**Final answer:**", r"\*\*Final ans.*:\*\*(.*)(?:<|$)"),
        ("**Answer:**", r"\*\*Answer:\*\*(.*)(?:<|$)"),
        ("*FINAL ANSWER:*", r"\*FINAL ANS.*:(.*)(?:<|$)"),
        ("*Final Answer:*", r"\*Final Ans.*:(.*)(?:<|$)"),
        ("Final answer:", r"Final ans.*:(.*)(?:<|$)"),
        ("*Final answer:*", r"\*Final ans.*:(.*)(?:<|$)"),
        ("*Final Answer*", r"\*Final Answer\*(.*)(?:<|$)"),
        ("*Answer:*", r"\*Answer:\*(.*)(?:<|$)"),
        ("**Answer**:", r"\*\*Answer\*\*:(.*)(?:<|$)"),
        ("*Answer*:", r"\*Answer\*:(.*)(?:<|$)"),
        ("**Final answer:", r"\*\*Final ans.*:(.*)\*\*(?:<|$)"),
        ("**Final Answer:", r"\*\*Final Answer:(.*)\*\*(?:<|$)"),
        ("**FINAL ANSWER:", r"\*\*FINAL ANS.*:(.*)\*\*(?:<|$)"),
        ("*Final answer:", r"\*Final ans.*:(.*)\*(?:<|$)"),
        ("*Final Answer:", r"\*Final Answer:(.*)\*(?:<|$)"),
        ("*FINAL ANSWER:", r"\*FINAL ANS.*:(.*)\*(?:<|$)"),
        ("The final answer is:", r"The final ans.*:(.*)(?:<|$)"),
        ("FINAL ANSWER:", r"FINAL ANSWER:(.*)(?:<|$)"),
    )

    try:
        for marker, pattern in answer_patterns:
            if marker in output:
                pred = re.findall(pattern, output, re.DOTALL)[-1].strip()
                break
        else:
            # No marker found: treat the whole output as the answer.
            pred = output.strip()

        if "```" in pred:
            pred = re.sub(r"```", "", pred).strip()
        if "<|eot_id|>" in pred:
            pred = re.sub(r"<\|eot_id\|>", "", pred).strip()
    except Exception:
        # Extraction is best-effort; fall back instead of aborting the evaluation.
        pred = output if args.dataset == "sudoku" else "None found"

    if pred is None:
        pred = ""

    return pred


def _extract_iterative_predictions(output):
    """Execute every refinement step in ``output["all_outputs"]`` and return the results.

    For each step, the *last* ```json block is parsed to give the symbols and the
    *last* ```python block gives the program. A step lacking a usable block reuses
    the value from the most recent step that had one. The program is run through
    ``python_eval`` with ``answer = solve(symbols)`` appended. The first element
    of ``python_eval``'s return value is recorded, or ``"None"`` if the call or
    its unpacking raises.
    """
    step_preds = []
    symbols = ""
    program = ""
    for step_text in output["all_outputs"]:
        try:
            json_block = re.findall(r"```json(.*?)```", step_text, re.DOTALL)[-1]
            symbols = str(json.loads(json_block))
        except Exception:
            pass  # best-effort parse of model output: keep the previous symbols
        try:
            program = str(re.findall(r"```python(.*?)```", step_text, re.DOTALL)[-1])
        except (IndexError, TypeError):
            pass  # no python block in this step: keep the previous program
        # NOTE: this executes model-generated code through python_eval.
        try:
            result, _stdout, _err = python_eval(program + "\nsymbols = " + symbols + "\nanswer = solve(symbols)")
            step_preds.append(result)
        except Exception:
            step_preds.append("None")
    return step_preds


def _accuracy_at_k(gt_matches, max_k=10):
    """Return, for k = 0..max_k-1, the fraction of examples judged correct at step k.

    An example with fewer than k+1 steps is scored by its last step. An example
    with no steps at all counts as incorrect. Returns ``[]`` when there are no
    examples.
    """
    if not gt_matches:
        return []
    accs = []
    for k in range(max_k):
        correct = sum(
            1 for matches in gt_matches if matches and matches[min(k, len(matches) - 1)]
        )
        accs.append(correct / len(gt_matches))
    return accs


def eval_iterative(args):
    """Re-score saved iterative-refinement outputs and print accuracy after each step.

    The function:

    1. Rebuilds the shuffled test subset.
    2. Reads one logged tuple per line from ``outputs_gen_<n>_temp_<t>.txt``;
       element 2 must hold ``{"all_outputs": [...]}``.
    3. Re-executes each step's program.
    4. Compares every step's result with the ground truth via
       ``equiv(gt, pred_str, original_index)``.
    5. Writes the per-example match lists to ``matches_iter_gen_<n>_temp_<t>.txt``.
    6. Prints the accuracy for k = 0..9.
    """
    data = get_dataset(args)
    equiv = get_equivalence_or_judge(args)

    # These datasets are re-seeded before shuffling; others use the current
    # global NumPy RNG state.
    if "omnimath" in args.dataset or "leaf" in args.dataset:
        np.random.seed(0)
    test_data_ids = list(range(min(200, len(data))))
    shuf = np.random.permutation(test_data_ids)
    test_data = [data[int(i)] for i in shuf[:200]]
    gt = [example[1] for example in test_data]

    # NOTE(review): logs are read from the basename of args.model (no debug/
    # prefix), while main() writes them under the full args.model; confirm the
    # files are relocated accordingly.
    model_name = args.model.split("/")[-1]
    log_dir = f"logs/{model_name}/{args.dataset}/{args.method}"

    with open(f"{log_dir}/outputs_gen_{args.num_gen}_temp_{args.temperature}.txt", "r") as f:
        lines = f.readlines()

    outputs = []
    for line in lines:
        try:
            outputs.append(ast.literal_eval(line)[2])
        except Exception:
            # Unparseable log line: treat the example as having no steps.
            outputs.append({"all_outputs": []})

    predictions = [_extract_iterative_predictions(output) for output in outputs]

    gt_matches = []
    for i, step_preds in enumerate(predictions):
        matches = []
        for step_pred in step_preds:
            # Arbitrary objects returned by generated code may fail to stringify.
            try:
                step_pred = str(step_pred)
            except Exception:
                step_pred = "None"
            matches.append(equiv(gt[i], step_pred, shuf[i]))
        gt_matches.append(matches)

    with open(f"{log_dir}/matches_iter_gen_{args.num_gen}_temp_{args.temperature}.txt", "w") as f:
        for match in gt_matches:
            f.write(str(match) + "\n")

    accs = _accuracy_at_k(gt_matches)
    print(f"Accuracies at each k: {accs}")

def eval(args):
    """Re-score the outputs saved by a previous ``main`` run and print the accuracy.

    The function:

    1. Rebuilds the shuffled test subset.
    2. Reads ``outputs_gen_<n>_temp_<t>.txt``, one logged
       ``(output, pred, intermediate)`` tuple per line.
    3. Re-extracts each prediction with ``get_pred``.
    4. Writes the per-example match results to ``matches_gen_<n>_temp_<t>.txt``.
    5. Prints the overall accuracy.

    NOTE: this name shadows the built-in ``eval`` within this module.
    """
    data = get_dataset(args)
    equiv = get_equivalence_or_judge(args)

    # omnimath is re-seeded before shuffling; other datasets use the current
    # global NumPy RNG state.
    if "omnimath" in args.dataset:
        np.random.seed(0)
    test_data_ids = list(range(min(200, len(data))))
    shuf = np.random.permutation(test_data_ids)
    test_data = [data[int(i)] for i in shuf[:200]]
    gt = [example[1] for example in test_data]

    # NOTE(review): logs are read from the basename of args.model (no debug/
    # prefix), while main() writes them under the full args.model; confirm the
    # files are relocated accordingly.
    model_name = args.model.split("/")[-1]
    log_dir = f"logs/{model_name}/{args.dataset}/{args.method}"

    with open(f"{log_dir}/outputs_gen_{args.num_gen}_temp_{args.temperature}.txt", "r") as f:
        # Parse each logged tuple exactly once.
        records = [ast.literal_eval(line) for line in f]

    if args.method == "tot":
        # Concatenate every string found in the values of the intermediate dict.
        outputs = [
            "".join(item for group in record[-1].values() for item in group if isinstance(item, str))
            for record in records
        ]
    elif args.method == "self_discover":
        outputs = [record[-1]["reasoning_structure"] + record[0] for record in records]
    else:
        outputs = [record[0] for record in records]

    predictions = [get_pred(args, output) for output in outputs]

    gt_matches = [equiv(gt[i], predictions[i], shuf[i]) for i in range(len(predictions))]
    with open(f"{log_dir}/matches_gen_{args.num_gen}_temp_{args.temperature}.txt", "w") as f:
        for match in gt_matches:
            f.write(str(match) + "\n")
    acc = sum(gt_matches) / len(gt_matches) if gt_matches else 0.0
    print(f"Method: {args.method}, Acc: {acc}")


def _load_model(args):
    """Instantiate the model backend selected by ``args.model`` and ``args.use_hf``.

    A model name containing any API tag is served by ``APIModel``. Otherwise,
    ``--use_hf`` selects ``OurLLM`` and everything else runs locally through vLLM.
    """
    # NOTE: substring match, so any name containing e.g. "gpt" or "o3" is
    # routed to APIModel.
    api_tags = ("claude", "gemini", "gpt", "o3", "o4", "Qwen", "nano")
    if any(tag in args.model for tag in api_tags):
        return APIModel(model_name=args.model)
    if args.use_hf:
        return OurLLM(model_name=args.model)
    return LLM(
        model=args.model,
        max_model_len=10000 if "Llama" in args.model else 30000,
        gpu_memory_utilization=0.95,
        max_num_seqs=1,
        trust_remote_code=True,
        tensor_parallel_size=args.num_gpus,
    )


def main(args):
    """Run ``args.method`` on the shuffled test subset and record the results.

    Writes every ``(output, pred, intermediate)`` log entry to
    ``logs/[debug/]<model>/<dataset>/<method>/outputs_gen_<n>_temp_<t>.txt``,
    overwriting any existing file. Appends one CSV line with the accuracy to
    ``results.txt`` in the same directory.
    """
    model = _load_model(args)

    data = get_dataset(args)
    equiv = get_equivalence(args)

    # Evaluate a random permutation of (at most) the first 200 examples.
    test_data_ids = list(range(min(200, len(data))))
    shuf = np.random.permutation(test_data_ids)
    test_data = [data[int(i)] for i in shuf]

    if args.batched:
        method = BATCHED_METHOD_MAP[args.method]
        get_preds = get_method_predictions_batched
    else:
        method = METHOD_MAP[args.method]
        get_preds = get_method_predictions

    preds, gt, logs = get_preds(
        args,
        model,
        method,
        test_data,
        log=args.log,
        equiv=equiv,
    )

    # Scored with the original dataset index shuf[i], which equiv may use.
    scores = [equiv(gt[i], preds[i], shuf[i]) for i in range(len(preds))]
    acc = sum(scores) / len(scores) if scores else 0.0

    model_name = args.model
    log_dir = f"logs/{('debug/' if args.debug else '') + model_name}/{args.dataset}/{args.method}"
    os.makedirs(log_dir, exist_ok=True)

    with open(f"{log_dir}/outputs_gen_{args.num_gen}_temp_{args.temperature}.txt", "w") as f:
        for entry in logs:
            f.write(str(entry) + "\n")

    with open(f"{log_dir}/results.txt", "a") as f:
        f.write(
            f"{('debug_' if args.debug else '') + model_name},{args.dataset},{args.method},{args.num_gen},{args.temperature},{acc}\n"
        )


if __name__ == "__main__":
    # Seed NumPy globally so the shuffled evaluation subset is reproducible.
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="longsort")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-90B-Vision-Instruct")
    parser.add_argument("--method", default="zs_cot")
    parser.add_argument("--log", action="store_true")
    parser.add_argument("--num_gpus", type=int, default=2)
    parser.add_argument("--twentyfourgame_n", type=int, default=4)
    parser.add_argument("--use_hf", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--batched", action="store_true")
    parser.add_argument("--num_gen", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    logger.info("Starting")

    # Without --eval, run generation. With --eval, re-score saved logs;
    # gen_sym_reason_prog_checks logs hold multiple refinement steps per example.
    if not args.eval:
        main(args)
    elif args.method == "gen_sym_reason_prog_checks":
        eval_iterative(args)
    else:
        eval(args)
