import requests
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import transformers
from typing import *

# viper gpt
from rich.console import Console
from rich.syntax import Syntax
from rich.live import Live
from rich.padding import Padding
import ast
import astunparse


# viper gpt
from image_patch import *
# from video_segment import *
from vision_models import *
from vision_processes import *

from prompts.utils_prompt import *
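# Import the `time` module here, after the star imports above: some of those modules do `from time import time`, which would otherwise leave `time` bound to the function instead of the module.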
import time

# Shared rich console; execute_code() uses it to echo the re-formatted generated program.
console = Console(highlight=False, force_terminal=False)

# Shared chat-model client used by the LLM-backed helpers in this file (filter_direct_query,
# abstraction, abstraction_ablation, deduplicate_by_chatgpt). It is instantiated at import time.
# NOTE(review): GPT3Model presumably comes from one of the star imports above — confirm.
gpt3_model = GPT3Model()



def compute_simcse(model, tokenizer, texts, batch_size=32, device="cuda"):
    '''
    Encode a list of texts into sentence embeddings (the model's ``pooler_output``).

    Despite the name, no similarity is computed here. The returned embeddings are
    meant to be compared afterwards, e.g. with ``sort_by_similarity``.
    :param model: encoder called as ``model(**inputs, output_hidden_states=True, return_dict=True)``;
        its output must expose ``pooler_output``.
    :param tokenizer: tokenizer called as
        ``tokenizer(batch, padding=True, truncation=True, return_tensors="pt")``.
    :param texts: a list of text.
    :param batch_size: number of texts encoded per forward pass.
    :param device: device the tokenized inputs are moved to; it must be the device the model lives on.
    :return: a CPU tensor with one embedding row per text, in input order
        (an empty 1-D float64 tensor when ``texts`` is empty).
    '''
    data_loader = DataLoader(texts, shuffle=False, batch_size=batch_size)
    chunks = []
    with torch.no_grad():
        for batch in data_loader:
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
            for key in inputs.keys():
                inputs[key] = inputs[key].to(device)
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
            # Move each batch off the device right away so device memory stays bounded.
            chunks.append(outputs.pooler_output.detach().cpu())
    if not chunks:
        # Matches what the previous numpy-based implementation returned for no input.
        return torch.empty(0, dtype=torch.float64)
    return torch.cat(chunks, dim=0)

def sort_by_similarity(included_embeddings, remaining_embeddings, chunk_size=1000, device="cuda"):
    """Rank ``remaining_embeddings`` by their closest match among ``included_embeddings``.

    Each remaining embedding is scored by its maximum cosine similarity to any included
    embedding. Scores are sorted in ascending order, so the entries least similar to the
    included set come first.

    :param included_embeddings: (N, D) tensor.
    :param remaining_embeddings: (M, D) tensor.
    :param chunk_size: number of remaining embeddings compared per step; bounds the size
        of each on-device similarity block (avoids OOM).
    :param device: device the similarity computation runs on.
    :return: ``(sorted_indices, sorted_scores)``, both of length M; ``sorted_indices``
        index into ``remaining_embeddings``.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    num_remaining = len(remaining_embeddings)
    similarity_matrix = torch.zeros((len(included_embeddings), num_remaining))

    # Loop-invariant: move the included set to the device once, shaped [N, 1, D].
    anchors = included_embeddings.unsqueeze(1).to(device)
    for start in range(0, num_remaining, chunk_size):
        end = min(start + chunk_size, num_remaining)
        chunk = remaining_embeddings[start:end].unsqueeze(0).to(device)  # [1, m, D]
        # Broadcasting [N, 1, D] against [1, m, D] gives an [N, m] similarity block.
        similarity_matrix[:, start:end] = torch.nn.functional.cosine_similarity(anchors, chunk, dim=2).cpu()

    # For each remaining sample, keep its highest similarity to any included example.
    similarity_scores, _ = torch.max(similarity_matrix, dim=0)
    print(similarity_matrix.shape, similarity_scores.shape)

    sorted_similarity_scores, sorted_similarity_indices = torch.sort(similarity_scores, dim=0)
    print(sorted_similarity_scores[:5], sorted_similarity_scores[-5:])

    return sorted_similarity_indices, sorted_similarity_scores

def filter_direct_query(query):
    """Ask gpt-3.5-turbo whether ``llm_query`` alone can answer ``query``.

    A two-shot chat prompt (one "yes" example, one "no" example) asks the model
    to answer yes or no.

    :param query: the query text to classify.
    :return: 0 if the model answers "yes" (answerable without visual information), otherwise 1.
    """
    prompt_1 = ("You will be given a function named `llm_query`. The function Answers a text question using GPT-3 for reasoning and inference. "
                "Since GPT-3 cannot process visual information, the question must be image-independent.\n")
    prompt_2 = "Then, you will be given a query: {query}\n"
    prompt_3 = ("You need to decide if this llm_query function is able to **directly** solve this query. Directly answer yes or no.\n"
                "Tips: If the query requires visual information of an image to solve, you should answer no. Otherwise, if the query is an image-independent inference task that can be solved by LLM reasoning or search engine, you should answer yes.\n")

    def build_user_turn(text):
        # Every user turn shares the same instructions; only the query differs.
        return {"role": "user", "content": prompt_1 + prompt_2.format(query=text) + prompt_3}

    print(query)
    messages = [
        build_user_turn("Why isn't Song Hee taking customers?"),
        {"role": "assistant", "content": "yes"},
        build_user_turn("Why might one of the news anchors look angry, and the other look concerned?"),
        {"role": "assistant", "content": "no"},
        build_user_turn(query),
    ]
    reply = gpt3_model.query_with_message(messages, model="gpt-3.5-turbo-0613", max_tokens=1, temperature=0.0)
    # Normalise e.g. "Yes." -> "yes"; a missing reply becomes "none" and counts as "no".
    response = re.sub(r"[^a-zA-Z0-9]+", "", str(reply).lower())
    return 0 if response == "yes" else 1


def load_image(path, timeout=None):
    """Load an image from a local path or an http(s) URL and convert it to a tensor.

    :param path: local file path, or a URL starting with ``http://`` / ``https://``.
    :param timeout: optional ``requests`` timeout in seconds for URL downloads;
        ``None`` (the default) waits indefinitely, as before.
    :return: ``transforms.ToTensor()`` applied to the image converted to RGB.
    :raises requests.HTTPError: if the URL download returns an error status.
    """
    if path.startswith(("http://", "https://")):
        # Close the streamed response once the pixels are decoded, and report HTTP
        # errors directly instead of as an unreadable-image error from PIL.
        with requests.get(path, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # `raw` is the undecoded stream; honour any Content-Encoding (e.g. gzip).
            response.raw.decode_content = True
            with Image.open(response.raw) as img:
                image = img.convert('RGB')
    else:
        with Image.open(path) as img:  # release the file handle after decoding
            image = img.convert('RGB')
    return transforms.ToTensor()(image)

def extract_code(code):
    """Return the ``execute_command`` snippet from an LLM completion.

    The completion is split into blank-line-separated chunks. A chunk whose first
    non-newline character is indentation continues the previous chunk's definition
    and is re-attached to it, so a blank line inside a function body no longer
    truncates the function.

    The first chunk mentioning ``def execute_command`` is returned. If that chunk
    contains line-leading ``def``s, only the text before the first one (e.g. a query
    comment or an earlier function) plus that ``def`` are kept.
    NOTE(review): later top-level ``def``s in the same chunk are dropped.

    :param code: raw completion text.
    :return: the extracted snippet, or ``code`` unchanged when no chunk mentions
        ``def execute_command``.
    """
    chunks = []
    for chunk in code.split("\n\n"):
        if chunks and chunk.lstrip("\n").startswith((" ", "\t")):
            # Indented continuation of the previous block across blank line(s).
            chunks[-1] += "\n\n" + chunk
        else:
            chunks.append(chunk)

    for chunk in chunks:
        if "def execute_command" in chunk:  # return the first one
            pieces = chunk.split("\ndef ")  # several functions
            if len(pieces) == 1:
                return pieces[0]
            return "\n".join([pieces[0], "def " + pieces[1]])  # query + `execute_command` function
    return code

def split_codeline_and_indent_level(codeline):
    """Separate leading spaces from ``codeline`` and express them as tabs.

    Only spaces count as indentation; a tab stands for 4 spaces, and partial
    levels are rounded half up (2 spaces -> one tab, 1 space -> none).

    :return: ``(text_without_leading_spaces, tab_string)``.
    """
    stripped = codeline.lstrip(" ")
    num_spaces = len(codeline) - len(stripped)
    # For n >= 0, (n + 2) // 4 equals int(n / 4 + 0.5), i.e. round half up.
    tabs = "\t" * ((num_spaces + 2) // 4)
    return stripped, tabs

def process_code(code):
    """Extract the generated ``execute_command`` snippet and re-indent it with tabs.

    Lines beginning with spaces have those spaces converted to tabs via
    ``split_codeline_and_indent_level``; all other lines are kept verbatim.
    """
    def retab(line):
        if not line.startswith(" "):
            return line
        text, tabs = split_codeline_and_indent_level(line)
        return tabs + text

    return "\n".join(map(retab, extract_code(code).split("\n")))

def execute_code(code, image, question=None):
    """Execute generated code and return the result of the ``execute_command`` it defines.

    The source is normalised by an ``ast`` round trip (``astunparse``), echoed to the
    console, and executed in this module's global namespace. The generated program
    can therefore use every name available in this module.

    :param code: Python source that must define ``execute_command``.
    :param image: first argument passed to ``execute_command``.
    :param question: when not ``None``, passed as the second argument.
    :return: whatever ``execute_command`` returns.
    :raises SyntaxError: if ``code`` does not parse.
    :raises NameError: if ``code`` does not define ``execute_command``.
    Anything raised by the generated code propagates unchanged.
    """
    code = astunparse.unparse(ast.parse(code))
    console.print(code)
    # Forget the function defined by a previous call. Otherwise, code that fails to
    # define `execute_command` would silently run the previous sample's program.
    globals().pop("execute_command", None)
    # SECURITY: exec() runs model-generated code with full interpreter privileges and
    # write access to this module's globals. Only use it on code you accept to run unsandboxed.
    exec(compile(code, filename='Codex', mode='exec'), globals())
    command = globals().get("execute_command")
    if command is None:
        raise NameError("generated code does not define `execute_command`")
    if question is None:
        return command(image)
    return command(image, question)


def compute_iou(box1, box2):
    """Compute the intersection over union of two sets of boxes, each box is [x1,y1,x2,y2].

    Pairs whose union area is zero (both boxes degenerate) get an IoU of 0 rather
    than NaN. A NaN would otherwise compare False against any threshold and could
    be mistaken for a match.

    Args:
      box1: (tensor) bounding boxes, sized [N,4].
      box2: (tensor) bounding boxes, sized [M,4].
    Return:
      (tensor) iou, sized [N,M].
    """
    # Pairwise intersection corners via broadcasting: [N,1,2] vs [1,M,2] -> [N,M,2].
    lt = torch.max(box1[:, None, :2], box2[None, :, :2])
    rb = torch.min(box1[:, None, 2:], box2[None, :, 2:])

    wh = (rb - lt).clamp(min=0)  # disjoint boxes have no overlap
    inter = wh[..., 0] * wh[..., 1]  # [N,M]

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # [N]
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # [M]
    union = area1[:, None] + area2[None, :] - inter  # [N,M]

    iou = inter / union
    return torch.where(union > 0, iou, torch.zeros_like(iou))


def eval_generated_code(code, data, task, image_root="./datasets/coco_images/train2017"):
    """Run generated code on a sample's image and score its result.

    Args:
      code: Python source defining ``execute_command(image)``.
      data: sample dict with ``"image_path"`` (relative to ``image_root``), ``"answer"``,
        and either ``"question"`` or ``"query"``.
      task: task name forwarded to ``check_consistency``.
      image_root: directory that ``data["image_path"]`` is resolved against.
    Return:
      ``(reward, result)``: ``(-1, None)`` when the generated code fails; otherwise
      the ``check_consistency`` reward and the raw result of the code.
    """
    image_path = os.path.join(image_root, data["image_path"])
    image = load_image(image_path)

    # Samples carry either a "question" or a "query"; the latter used to crash below.
    if "question" in data:
        question = data["question"]
        print("question:", question)
    else:
        question = data["query"]
        print("query:", question)

    try:
        result = execute_code(code, image)
    except (Exception, SystemExit) as err:
        # Any failure (including a generated `exit()`) means the program is buggy.
        # KeyboardInterrupt is deliberately not caught.
        print("bugs in the generated code")
        print(f"{type(err).__name__}: {err}")
        return -1, None

    reward = check_consistency(
        {
            "question": question,
            "prediction": result,
            "groundtruth": data["answer"],
        },
        task,
    )
    return reward, result

from nltk import sent_tokenize
# Lower-case phrases that mark boilerplate/refusal sentences. process_llm_outputs() drops
# every sentence whose lower-cased text contains any of them.
chatgpt_filters = ["sorry", "please", "an ai model", "a language model", "an ai language model", "do not have access to", "it is impossible to"]
def process_llm_outputs(text):
    """Strip boilerplate sentences from an LLM answer.

    The text is sentence-tokenised with nltk. Sentences whose lower-cased form
    contains any ``chatgpt_filters`` phrase are dropped, and the survivors are
    re-joined with single spaces.

    :return: the filtered text, or ``""`` when no sentence survives.
    """
    kept = []
    for sentence in sent_tokenize(text):
        lowered = sentence.lower()
        if any(phrase in lowered for phrase in chatgpt_filters):
            continue
        kept.append(sentence)
    return " ".join(kept) if kept else ""

PROMPT = """Given the question for multimodal task {0}: {1}
Does the following predicted answer have the same meaning as the reference answer in the context of the question?
Predicted Answer: {2}
Reference Answer: {3}
You should compare the answers based on your understanding of the task, question, and answers, rather than relying on some superficial patterns like word overlap.
Directly answer Yes or No.
""".replace("\t", " ").strip("\n").strip()


def check_consistency(data, task, iou_threshold=0.8):
    """Score a prediction against its ground truth with a binary reward.

    Grounding tasks (``"grounding" in task``) compare ``ImagePatch(...)`` boxes by
    IoU. Every other task (vqa, captioning, inference, ...) asks an LLM judge
    whether the two answers mean the same thing.

    :param data: dict with ``"question"``, ``"prediction"`` and ``"groundtruth"``;
        each value is converted with ``str()``.
    :param task: task name; it selects the scoring branch and is interpolated
        into the judge prompt.
    :param iou_threshold: minimum IoU every predicted box must reach with its
        best-matching ground-truth box (grounding tasks only).
    :return: 1 if the prediction is judged consistent, otherwise 0.
    """
    question = str(data["question"])
    prediction = str(data["prediction"])
    groundtruth = str(data["groundtruth"])

    if prediction in ("", "None"):  # no answer produced
        return 0
    if "grounding" in task:
        return _grounding_reward(prediction, groundtruth, iou_threshold)
    return _llm_judge_reward(task, question, prediction, groundtruth)


def _parse_patch_boxes(text):
    """Return the comma-separated arguments of every ``ImagePatch(...)`` in ``text`` as float tuples."""
    return [tuple(float(value) for value in args.split(","))
            for args in re.findall(r"ImagePatch\((.*?)\)", text)]


def _grounding_reward(prediction, groundtruth, iou_threshold):
    """Binary IoU reward for grounding answers rendered as (lists of) ImagePatch reprs."""
    is_patch_list = prediction.startswith("[") and prediction.endswith("]") and "ImagePatch" in prediction
    is_single_patch = prediction.startswith("ImagePatch")
    if not (is_patch_list or is_single_patch):
        return 0  # format error: the answer should be image patch(es)

    try:
        gt_boxes = _parse_patch_boxes(groundtruth)  # list of (left, lower, right, upper)
        pred_boxes = _parse_patch_boxes(prediction)
    except ValueError:
        return 0  # coordinates that are not numbers
    if not pred_boxes or len(pred_boxes) != len(gt_boxes):
        return 0
    if any(len(box) != 4 for box in pred_boxes + gt_boxes):
        return 0  # malformed box

    # For each predicted box, take its IoU with the best-matching ground-truth box.
    iou = compute_iou(torch.tensor(pred_boxes), torch.tensor(gt_boxes)).max(dim=1).values
    print(iou)
    # Every predicted box must clear the threshold; `not (>=)` also rejects NaN IoUs.
    if not bool(iou.min() >= iou_threshold):
        return 0
    return 1


def _llm_judge_reward(task, question, prediction, groundtruth):
    """Binary reward from asking the LLM (via PROMPT) whether the two answers agree."""
    prediction = process_llm_outputs(prediction)
    groundtruth = process_llm_outputs(groundtruth)
    if "ImagePatch" in prediction or prediction.strip() == "":  # wrong task or no answer
        return 0
    prompt = PROMPT.format(task, question, prediction, groundtruth)
    print(prompt)

    response = _query_llm_with_retry(prompt)
    if isinstance(response, str):
        response = re.sub(r"[^a-zA-Z0-9]+", "", response.lower())  # remove all punctuation
    print(response)
    return 1 if response == "yes" else 0


def _query_llm_with_retry(prompt, max_delay=60.0):
    """Call ``llm_query`` until it succeeds, backing off exponentially between failures."""
    import time  # the module itself; a star import may have rebound the global name `time`
    from image_patch import llm_query

    delay = 1.0
    while True:
        try:
            return llm_query(prompt)
        except Exception as e:
            print(e)
            time.sleep(delay)
            delay = min(delay * 2, max_delay)


def abstraction_ablation(query, solution):
    """Query gpt-3.5-turbo with ``abstraction_ablation_template`` and parse a name and docstring.

    :param query: the query, inserted into the template.
    :param solution: the solution, inserted into the template.
    :return: ``(name, docstring)`` taken from the reply's "Function name:" and
        "Function docstring:" sections, both stripped.
    """
    prompt = abstraction_ablation_template.format(
        incontext_query=abstraction_ablation_incontext_query,
        incontext_solution=abstraction_ablation_incontext_solution,
        incontext_name=abstraction_ablation_incontext_name,
        incontext_docstring=abstraction_ablation_incontext_docstring,
        query=query,
        solution=solution,
    )
    response = gpt3_model.query_with_message(
        [{"role": "user", "content": prompt}],
        model="gpt-3.5-turbo-0613",
        max_tokens=1024,
        temperature=0.0,
    )
    # Exactly one "Function docstring:" marker is expected; otherwise the unpacking raises ValueError.
    name_section, docstring = response.split("Function docstring:")
    name = name_section.split("Function name:")[1]
    return name.strip(), docstring.strip()

### abstraction ###
def _extract_fenced_code(text):
    """Return the body of the first ```python fence in ``text``.

    Falls back to the first generic ``` fence, skipping an optional language-tag line.
    Returns None when no closed fence exists.
    """
    match = re.search(r"```python(.*?)```", text, re.DOTALL)
    if match is None:
        match = re.search(r"```(?:[\w+-]*\n)?(.*?)```", text, re.DOTALL)
    return match.group(1) if match else None


def abstraction(query, solution):
    """Query gpt-4 with ``abstraction_template`` and parse a tool plus an example call.

    :param query: the query, inserted into the template.
    :param solution: the solution, inserted into the template.
    :return: ``(tool, api_call)``. ``tool`` is the text after "The final generic tool
        with docstring is:" (fence contents if fenced); ``api_call`` is the text after
        "The example to call the tool is: ", with any code fence or backticks removed.
    """
    messages = [
        {"role": "user", "content": abstraction_template.format(
            incontext_query=abstraction_incontext_query,
            incontext_solution=abstraction_incontext_solution,
            incontext_tool=abstraction_incontext_tool,
            query=query,
            solution=solution)},
    ]

    response = gpt3_model.query_with_message(messages, model="gpt-4-0613", max_tokens=1024, temperature=0.0)
    # extract tool and api_call from the response
    tool, api_call = response.split("The example to call the tool is: ")
    tool = tool.split("The final generic tool with docstring is:")[1]
    # The old check `"```python" and "```" in tool` only tested for "```", so a plain
    # ``` fence crashed with IndexError. Prefer a ```python fence, else any fence.
    fenced_tool = _extract_fenced_code(tool)
    if fenced_tool is not None:
        tool = fenced_tool
    tool = tool.strip()

    fenced_call = _extract_fenced_code(api_call)
    if fenced_call is not None:
        api_call = fenced_call  # drops the fence's language tag, e.g. "python"
    api_call = api_call.strip("`").strip()
    return tool, api_call



### deduplication ###
def extract_function_name(function):
    '''
    Given a python function, use rule-based method to extract the function name.
    A leading ``def`` keyword is removed only when it is followed by whitespace,
    so names such as ``define_region`` are kept intact.
    :param function: a python function (or function head) described in string.
    :return: the function name, i.e. the text before the first "(".
    '''
    function = function.strip()
    if function.startswith("def") and function[3:4].isspace():
        function = function[3:]
    if function.endswith(":"):
        function = function[:-1]
    return function.strip().split("(")[0].strip()

def extract_function_head(function):
    '''
    Given a python function, use rule-based method to extract the function head:
    the first line without the ``def`` keyword and trailing colon, e.g. ``foo(a, b) -> int``.
    NOTE(review): a signature spanning several lines is cut at the first line.
    :param function: a python function described in string.
    :return: the function head.
    '''
    # Strip the first line too, so trailing spaces or "\r" don't hide the colon.
    first_line = function.strip().split("\n")[0].strip()
    if first_line.startswith("def") and first_line[3:4].isspace():
        first_line = first_line[3:]
    if first_line.endswith(":"):
        first_line = first_line[:-1]
    return first_line.strip()

def count_args(function_head):
    '''
    Given a python function head (e.g. ``foo(a, b: Dict[str, int]) -> int``), count the number of arguments.
    Positional-only, regular, keyword-only, ``*args`` and ``**kwargs`` parameters are counted
    (``self`` included). The bare ``*`` and ``/`` separators are not counted. Heads that do
    not parse as Python fall back to counting comma-separated pieces.
    :param function_head: a python function head.
    :return: the number of arguments.
    '''
    function_head = function_head.strip()
    try:
        parsed = ast.parse("def {}: pass".format(function_head))
    except (SyntaxError, ValueError):
        parsed = None
    if parsed is not None and len(parsed.body) == 1 and isinstance(parsed.body[0], ast.FunctionDef):
        # Parsing avoids miscounting commas inside annotations or default values.
        arguments = parsed.body[0].args
        return (len(arguments.posonlyargs) + len(arguments.args) + len(arguments.kwonlyargs)
                + int(arguments.vararg is not None) + int(arguments.kwarg is not None))

    # Fallback for heads that are not valid Python: comma-splitting heuristic.
    if function_head.endswith(")"):
        function_head = function_head[:-1]
    if "(" not in function_head:
        return 0
    args = function_head.split("(")[1].strip()
    return len(args.split(",")) if args else 0

def extract_function_docstring(function):
    '''
    Given a python function, use rule-based method to extract the function docstring.
    The docstring is the text between the first pair of triple quotes; ``"""`` is
    preferred over ``\'\'\'`` when both occur.
    :param function: a python function described in string.
    :return: ``(explanation, docstring)``: the stripped docstring and its first line.
    :raises ValueError: if the function contains no triple-quoted string.
    '''
    function = function.strip()
    if '"""' in function:
        delimiter = '"""'
    elif "'''" in function:
        delimiter = "'''"
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError("no triple-quoted docstring found in function:\n{}".format(function))

    docstring = function.split(delimiter)[1].strip()
    explanation = docstring.split("\n")[0].strip()
    return (explanation, docstring)


import re
import random
def _choose_tool_index(tools, entry_format, **query_kwargs):
    """Ask the LLM (via ``deduplication_template``) which of ``tools`` to keep.

    Returns the first number in the reply that is a valid index into ``tools``,
    or 0 when there is none ("if all tools are not good, simply choose the first one").
    """
    entries = [entry_format.format(i, tool) for i, tool in enumerate(tools)]
    prompt = deduplication_template.format("\n\n".join(entries))
    messages = [{"role": "user", "content": prompt}]
    reply = gpt3_model.query_with_message(messages, **query_kwargs)
    print("reply: {} / {}".format(reply, set(range(len(tools)))))
    numbers = [int(n) for n in re.findall(r"\d+", str(reply))]
    if len(numbers) > 1:
        print("ambiguous reply, using the first valid index:", numbers)
    valid = [n for n in numbers if n < len(tools)]
    return valid[0] if valid else 0


def deduplicate_by_chatgpt(tool_list, group_size=5):
    """Ask the LLM to pick the single most suitable tool among ``tool_list``.

    Long lists are reduced by knockout rounds over groups of at most ``group_size``
    tools, so no prompt becomes too long. The group winners then go to a final round.

    :param tool_list: tool sources (strings).
    :param group_size: maximum number of tools shown in one prompt (at least 2).
    :return: index into ``tool_list`` of the chosen tool (0 for lists of length <= 1).
    """
    group_size = max(2, group_size)
    candidates = list(range(len(tool_list)))  # indices into tool_list, never into sub-lists

    while len(candidates) > group_size:
        winners = []
        for start in range(0, len(candidates), group_size):
            group = candidates[start:start + group_size]
            if len(group) == 1:
                winners.append(group[0])  # nothing to compare against
                continue
            local = _choose_tool_index([tool_list[i] for i in group], "No. {}:\n{}\n\n", max_tokens=5)
            winners.append(group[local])
        candidates = winners

    if len(candidates) <= 1:
        return candidates[0] if candidates else 0
    final = _choose_tool_index(
        [tool_list[i] for i in candidates],
        "No. {}:\n{}\n",
        model="gpt-3.5-turbo-0613",
        max_tokens=5,
        temperature=0.0,
    )
    return candidates[final]


def deduplicate_by_name(all_tools, function_names, function_heads, num_args):
    """Keep one tool per (function name, argument count) pair.

    Tools are bucketed by name and then by argument count, and exact duplicate heads
    are skipped. When a bucket still holds several distinct heads,
    ``deduplicate_by_chatgpt`` picks which one to keep.

    :param all_tools: tool sources, aligned index-by-index with the other lists.
    :param function_names: per-tool function name.
    :param function_heads: per-tool function head.
    :param num_args: per-tool argument count.
    :return: ``(heads_by_name, kept_indices)``. ``heads_by_name`` maps each name (sorted)
        to its kept heads, one per argument count. ``kept_indices`` lists the indices
        into ``all_tools`` of the kept tools.
    """
    heads_by_name = {}  # name -> {num_arg: [distinct heads]}
    tools_by_name = {}  # name -> {num_arg: [tool indices]}

    for tool_idx, (name, head, arity) in enumerate(zip(function_names, function_heads, num_args)):
        head_bucket = heads_by_name.setdefault(name, {}).setdefault(arity, [])
        tool_bucket = tools_by_name.setdefault(name, {}).setdefault(arity, [])
        if head in head_bucket:
            continue  # exact duplicate
        head_bucket.append(head)
        tool_bucket.append(tool_idx)

    # Order names alphabetically (inner dicts are shared, so later edits still show up).
    heads_by_name = dict(sorted(heads_by_name.items(), key=lambda entry: entry[0]))

    # Collapse every bucket to a single tool index.
    for name, by_arity in tools_by_name.items():
        for arity, indices in by_arity.items():
            if len(indices) == 1:
                by_arity[arity] = indices[0]
                continue
            best = deduplicate_by_chatgpt([all_tools[i] for i in indices])
            by_arity[arity] = indices[best]
            heads_by_name[name][arity] = [heads_by_name[name][arity][best]]

    heads_by_name = {name: [heads[0] for heads in by_arity.values()] for name, by_arity in heads_by_name.items()}
    kept_indices = [idx for by_arity in tools_by_name.values() for idx in by_arity.values()]

    # Sanity check: the kept indices and the kept heads must describe the same tools.
    kept_heads = [function_heads[i] for i in kept_indices]
    flat_heads = [head for heads in heads_by_name.values() for head in heads]
    assert set(kept_heads) == set(flat_heads), print(len(kept_heads), len(flat_heads), set(kept_heads)-set(flat_heads))

    return heads_by_name, kept_indices