from utils import *
import datasets
import random
import numpy as np
from tqdm import tqdm
import pandas as pd

import openai
import warnings
from time import sleep
warnings.filterwarnings("ignore")
import json
import argparse
import os


import datasets
import torch
import requests
from collections import Counter
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Step 1. Infer on instruction dataset to obtain specific solutions
def generate_code(example, task):
    """Generate a program for one instruction sample.

    A single uniform draw decides which text is used as the prompt: the
    sample's ``question`` (draw below 0.5) or its ``instruction``.

    Args:
        example: Mapping providing the "question" and "instruction" entries.
        task: Task label copied unchanged into the returned record.

    Returns:
        dict with the original "question", the prompt actually sent
        ("query"), the post-processed program ("code") and "task".
    """
    pick_question = random.uniform(0, 1) < 0.5
    if pick_question:
        query = example["question"]
    else:
        query = example["instruction"]

    # `forward` and `process_code` come from the star import of `utils`.
    raw_code = forward('codex', prompt=query, input_type="image")
    cleaned_code = process_code(raw_code)

    record = {"question": example["question"], "query": query}
    record["code"] = cleaned_code
    record["task"] = task
    return record


# Step 2. Validate the code
def validate_turbo_code(generated_code, save_path):
    """Keep only the rows whose generated program evaluates with reward 1.

    If ``save_path`` already exists it is loaded and returned as-is, so an
    interrupted run can be resumed. Otherwise every row is evaluated with
    ``eval_generated_code`` and the surviving rows are written to
    ``save_path`` as CSV.

    Args:
        generated_code: ``datasets.Dataset`` with a "code" column, or a
            "tool_call" column when validating abstracted tools.
        save_path: CSV cache path. When it contains "4_valid_tools.csv" the
            "tool_call" column is evaluated instead of "code".

    Returns:
        ``datasets.Dataset`` restricted to the rows that passed validation.
    """
    if os.path.exists(save_path):
        return datasets.Dataset.from_csv(save_path)

    # NOTE: the previous version called
    # `generated_code.remove_columns(["question"])` and threw the result away.
    # `remove_columns` returns a new dataset, so that call never removed
    # anything. It is dropped rather than "fixed" because `abstract_tools`
    # lists "question" among the columns it keeps.
    validate_tool = "4_valid_tools.csv" in save_path
    code_column = "tool_call" if validate_tool else "code"

    total = len(generated_code)
    filtered = []
    for i, data in tqdm(enumerate(generated_code), total=total):
        pred = data[code_column]
        assert isinstance(pred, str)
        # Programs mentioning "pixel" are skipped without being evaluated.
        if "pixel" in pred:
            continue
        reward, _prediction = eval_generated_code(pred, data, "vqa")
        if reward == 1:
            filtered.append(i)
            print("{}/{}".format(len(filtered), total))
    print("filtered: {}/{}".format(len(filtered), total))

    generated_code = generated_code.select(filtered)
    generated_code.to_csv(save_path)
    return generated_code


# Step 3. Abstract the specific solutions to general tools
def abstract_tools(specific_solutions, save_dir_path="./results/viper"):
    """Turn validated specific solutions into general tools plus a call.

    The result is cached in ``3_all_general_tools.csv`` under
    ``save_dir_path``; when that file exists it is loaded instead of
    recomputing.

    Args:
        specific_solutions: ``datasets.Dataset`` with at least "query" and
            "code" columns.
        save_dir_path: Directory holding the CSV cache.

    Returns:
        ``datasets.Dataset`` with "tool" and "call" columns added.
    """
    cache_path = f"{save_dir_path}/3_all_general_tools.csv"
    if os.path.exists(cache_path):
        return datasets.Dataset.from_csv(cache_path)

    # Everything outside this whitelist is dropped before abstraction.
    reserved_features = ["image_path", "question", "query", "answer", "code", "tool_id"]
    present_features = specific_solutions.features.keys()
    droppable = set(present_features) - set(reserved_features)
    trimmed = specific_solutions.remove_columns(droppable)

    def _abstract(row):
        # `abstraction` (from utils) yields an indexable (tool, call) pair.
        return {"tool_and_call": abstraction(row["query"], row["code"])}

    def _unpack(row):
        pair = row["tool_and_call"]
        return {"tool": pair[0], "call": pair[1]}

    paired = trimmed.map(_abstract)
    general_tools_with_calls = paired.map(_unpack).remove_columns(["tool_and_call"])
    general_tools_with_calls.to_csv(cache_path)
    return general_tools_with_calls


# Step 4. Validate the general tools
def validate_tools(all_tools, save_dir_path="./results/viper"):
    """Wrap each tool and its call into a runnable program and validate it.

    Each row's "tool" source is indented with a tab and nested inside an
    ``execute_command(image)`` function that builds an ``ImagePatch`` and
    returns the stripped "call" expression. The program is stored in a
    "tool_call" column and checked by ``validate_turbo_code``.

    Args:
        all_tools: ``datasets.Dataset`` with "tool" and "call" columns.
        save_dir_path: Directory holding the ``4_valid_tools.csv`` cache.

    Returns:
        ``datasets.Dataset`` of the tools that passed validation.
    """
    valid_path = f"{save_dir_path}/4_valid_tools.csv"
    if os.path.exists(valid_path):
        return datasets.Dataset.from_csv(valid_path)

    preamble = ("from PIL import Image", "from typing import *", "from image_patch import *")

    def _build_program(row):
        call = row["call"].strip()
        nested_tool = "\n".join("\t" + line for line in row["tool"].split("\n"))
        pieces = [
            # The header keeps its own trailing newline, so the join below
            # leaves one blank line after the `def` line.
            "def execute_command(image):\n",
            nested_tool,
            "\t" + "image_patch = ImagePatch(image)",
            "\t" + f"return {call}",
        ]
        function_source = "\n".join(pieces)
        return {"tool_call": "\n".join([*preamble, function_source])}

    with_programs = all_tools.map(_build_program)
    return validate_turbo_code(with_programs, valid_path)


# Step 5. Deduplicate the tools
def deduplicate_tools(all_tools, save_dir_path="./results/viper"):
    """Collapse near-duplicate tools using embedding similarity.

    Two tools are marked as candidates for merging when either the sum of
    their name, explanation and query cosine similarities exceeds 2, or
    their name similarity alone exceeds 0.9. The resulting 0/1 matrix is
    handed to ``deduplicate_by_name`` together with the function heads and
    argument counts.

    Relies on the module-level ``model`` and ``tokenizer`` created in the
    ``__main__`` section.

    Args:
        all_tools: ``datasets.Dataset`` with "tool" and "query" columns.
        save_dir_path: Directory for ``5_deduplicated_tool.csv`` (cache and
            output) and ``category_head.json``.

    Returns:
        ``datasets.Dataset`` holding the selected representative tools.
    """
    # Fix: the cache check used "5_deduplicated_tools.csv" while the result
    # is written to (and read by construct_toolbase from)
    # "5_deduplicated_tool.csv", so the cache could never be hit.
    cache_path = f"{save_dir_path}/5_deduplicated_tool.csv"
    if os.path.exists(cache_path):
        return datasets.Dataset.from_csv(cache_path)

    tool_list = all_tools["tool"]

    # NOTE(review): construct_vector_library splits names on "_" while this
    # uses whitespace `split()`; left unchanged because the 0.9 threshold
    # below was presumably chosen against this form. Confirm which is wanted.
    function_names = [" ".join(extract_function_name(item).split()) for item in tool_list]
    function_heads = [extract_function_head(item) for item in tool_list]
    num_args = [count_args(item) for item in function_heads]
    function_explanations = [extract_function_docstring(item)[0] for item in tool_list]
    function_queries = all_tools["query"]

    name_embedding = compute_simcse(model, tokenizer, function_names)
    explanation_embedding = compute_simcse(model, tokenizer, function_explanations)
    query_embedding = compute_simcse(model, tokenizer, function_queries)

    def _pairwise_cosine(embedding):
        # Broadcast (N, 1, D) against (1, N, D) to get an N x N matrix.
        return torch.nn.functional.cosine_similarity(
            embedding.unsqueeze(1), embedding.unsqueeze(0), dim=2
        )

    name_similarity = _pairwise_cosine(name_embedding)
    explanation_similarity = _pairwise_cosine(explanation_embedding)
    query_similarity = _pairwise_cosine(query_embedding)

    # Logical OR of the two criteria, materialised as a 0/1 matrix.
    jointly_similar = name_similarity + explanation_similarity + query_similarity > 2
    same_name = name_similarity > 0.9
    conjunction_matrix = torch.where(jointly_similar | same_name, 1, 0).numpy()

    category_head, category_node = deduplicate_by_name(tool_list, conjunction_matrix, function_heads, num_args)

    # write to json
    with open(f"{save_dir_path}/category_head.json", "w") as f:
        json.dump(category_head, f, indent=4)

    # `category_node` is used as the row indices to keep.
    deduplicated_tools = all_tools.select(category_node)
    deduplicated_tools.to_csv(cache_path)
    return deduplicated_tools
            

# Check if continue
def check_if_continue(specific_solutions, extended_size_per_epoch, current_save_dir_path):
    """Run the per-epoch validate/abstract/validate steps if still needed.

    When ``4_valid_tools.csv`` already exists in ``current_save_dir_path``
    the epoch is considered done and nothing is executed. Otherwise the
    specific solutions are validated, abstracted into tools, and the tools
    are validated; each step writes its own cache file. Deduplication is not
    done here; ``construct_toolbase`` runs it once over all epochs.

    Args:
        specific_solutions: Raw generated solutions for this epoch.
        extended_size_per_epoch: Currently unused; kept for interface
            compatibility.
        current_save_dir_path: Directory for this epoch's cache files.

    Returns:
        Always ``True``: no stopping criterion is implemented.
    """
    # The previous version loaded the existing CSV with pandas into a
    # variable that was never used; that only cost I/O and raised on an
    # empty file, so the cache is now detected by existence alone.
    if not os.path.exists(f"{current_save_dir_path}/4_valid_tools.csv"):
        valid_specific_solutions = validate_turbo_code(
            specific_solutions,
            save_path=f"{current_save_dir_path}/2_valid_specific_solutions.csv",
        )
        general_tools = abstract_tools(valid_specific_solutions, current_save_dir_path)
        validate_tools(general_tools, current_save_dir_path)

    return True

# def check_if_continue(specific_solutions, extended_size_per_epoch, current_save_dir_path):
#     if os.path.exists(f"{current_save_dir_path}/2_valid_specific_solutions.csv"):
#         valid_specific_solutions = pd.read_csv(f"{current_save_dir_path}/2_valid_specific_solutions.csv")
#     else:
#         valid_specific_solutions = validate_turbo_code(specific_solutions, save_path=f"{current_save_dir_path}/2_valid_specific_solutions.csv")
#     return True


#################################################
# set up ssh connection
# The connection is opened at import time and shared by filter_direct_query.
# Connection parameters can be overridden through TURBO_FILTER_SSH_HOST,
# TURBO_FILTER_SSH_PORT, TURBO_FILTER_SSH_USER and TURBO_FILTER_SSH_PASSWORD;
# the previous hard-coded values remain as fallbacks so existing runs work.
# NOTE(security): a root password is committed in this file as the fallback.
# It should be rotated and supplied only through the environment (or key
# based authentication). AutoAddPolicy also accepts any host key unverified.
import paramiko
import json
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(
    hostname=os.environ.get("TURBO_FILTER_SSH_HOST", "65.130.193.108"),
    port=int(os.environ.get("TURBO_FILTER_SSH_PORT", "33827")),
    username=os.environ.get("TURBO_FILTER_SSH_USER", "root"),
    password=os.environ.get("TURBO_FILTER_SSH_PASSWORD", "ylf123"),
)
#################################################
print("set up ssh connection")
def filter_direct_query(query, max_tokens=1, max_retries=None, retry_delay=1.0):
    """Run ``turbo_filter.py`` on the remote host for one query.

    The command is executed over the module-level ``ssh`` client. Any output
    on stderr is treated as a failed attempt and the command is retried.

    Args:
        query: Question text passed to the remote script via ``--query``.
        max_tokens: Currently unused; kept for interface compatibility.
        max_retries: Maximum number of retries after a failed attempt.
            ``None`` (default) retries indefinitely, as before.
        retry_delay: Seconds to wait between attempts.

    Returns:
        The remote script's stdout parsed as ``int``.

    Raises:
        RuntimeError: If ``max_retries`` is set and exhausted.
        ValueError: If the remote stdout is not an integer literal.
    """
    import shlex  # local import: not part of the file's import block

    print("query:", query)
    # Kept so the remote script receives the same text as before.
    query = query.replace('"', "'")
    # Quote for the remote shell: inside the former double-quoted string,
    # `$`, backticks and backslashes in a question were shell-interpreted.
    command = f"/opt/conda/bin/python turbo_filter.py --query {shlex.quote(query)}"

    failures = 0
    while True:
        stdin, stdout, stderr = ssh.exec_command(command)
        output = stdout.read().decode('utf-8').strip()
        err = stderr.read().decode('utf-8').strip()
        if not err:
            break

        print("stderr:", err)
        failures += 1
        if max_retries is not None and failures > max_retries:
            raise RuntimeError(
                f"turbo_filter.py kept failing after {max_retries} retries: {err}"
            )
        # Pause instead of hammering the remote host in a tight loop.
        sleep(retry_delay)

    print("output:", output)
    return int(output)



def select_samples(instruction_dataset, total_extended_size=2000, initial_extended_size=1000, extended_size_per_epoch=100, save_dir_path="./results/viper"):
    """Choose which instruction samples are used in each epoch.

    Epoch 0 draws ``initial_extended_size`` tool ids at random. Every later
    epoch ranks the not-yet-included samples with ``sort_by_similarity``
    against the included ones, keeps those whose score is below 0.7, and
    walks them in that order, accepting a sample when
    ``filter_direct_query`` returns 1, until ``extended_size_per_epoch``
    samples are accepted or the candidates run out.

    Progress is checkpointed per epoch in ``sampled_ids_{epoch}.pt`` and the
    final mapping in ``sampled_ids.pt``; existing files are reused.

    Relies on the module-level ``model`` and ``tokenizer``.

    Args:
        instruction_dataset: ``datasets.Dataset`` with "question" and
            "tool_id" columns.
        total_extended_size: Target total number of samples.
        initial_extended_size: Number of samples drawn in epoch 0.
        extended_size_per_epoch: Samples to add in each later epoch.
        save_dir_path: Directory for the ``.pt`` checkpoint files.

    Returns:
        dict mapping epoch index to the list of selected tool ids.
    """
    final_path = f"{save_dir_path}/sampled_ids.pt"
    if os.path.exists(final_path):
        return torch.load(final_path)

    sampled_ids = {}
    epochs = (total_extended_size - initial_extended_size) // extended_size_per_epoch + 1

    # construct an embedding mapping
    all_tool_ids = instruction_dataset["tool_id"]
    all_embeddings = compute_simcse(model, tokenizer, instruction_dataset["question"])
    id_embedding_pairs = dict(zip(all_tool_ids, all_embeddings))
    included_ids = []

    for epoch in range(epochs):
        epoch_path = f"{save_dir_path}/sampled_ids_{epoch}.pt"
        if os.path.exists(epoch_path):
            # Resume: reuse what an earlier run selected for this epoch.
            sampled_ids[epoch] = torch.load(epoch_path)[epoch]
            included_ids += sampled_ids[epoch]
            continue

        if epoch == 0:
            sampled_ids[epoch] = random.sample(list(id_embedding_pairs.keys()), initial_extended_size)
        else:
            # A set makes the membership tests below O(1); with a list they
            # were linear in the number of already included ids.
            included_set = set(included_ids)
            # record the tool_id corresponding to each remaining embedding
            remaining_ids = [tool_id for tool_id in all_tool_ids if tool_id not in included_set]

            # update embeddings
            included_embeddings = torch.tensor([id_embedding_pairs[tool_id].tolist() for tool_id in included_ids])
            remaining_embeddings = torch.tensor([id_embedding_pairs[tool_id].tolist() for tool_id in remaining_ids])

            # update remaining_instruction_dataset
            remaining_instruction_dataset = instruction_dataset.filter(lambda x: x["tool_id"] not in included_set)
            assert remaining_ids == remaining_instruction_dataset["tool_id"]  # ensure the order is the same

            # sort by similarity using computed embeddings
            sorted_remaining_indices, sorted_similarity_scores = \
                sort_by_similarity(included_embeddings, remaining_embeddings)
            # cut off those with similarity scores >= 0.7
            sorted_remaining_indices = sorted_remaining_indices[sorted_similarity_scores < 0.7]
            sorted_instruction_dataset = remaining_instruction_dataset.select(sorted_remaining_indices)

            # filter out those which can be directly queried by llm_query
            valid_ids = []
            invalid_ids = []
            for data in sorted_instruction_dataset:
                if filter_direct_query(data["question"]) == 1:
                    valid_ids.append(data["tool_id"])
                    if len(valid_ids) == extended_size_per_epoch:
                        break
                else:
                    invalid_ids.append(data["tool_id"])
            print("valid rate: {}/{}".format(len(valid_ids), len(valid_ids) + len(invalid_ids)))

            # The previous version also read
            # sorted_similarity_scores[len(valid_ids) + len(invalid_ids)] into
            # an unused variable; that index can equal the sequence length
            # and raise IndexError, so the lookup was removed.
            sampled_ids[epoch] = valid_ids
            valid_set = set(valid_ids)
            print(instruction_dataset.filter(lambda x: x["tool_id"] in valid_set)["question"])

        # update included_ids and checkpoint this epoch
        included_ids += sampled_ids[epoch]
        torch.save(sampled_ids, epoch_path)

    torch.save(sampled_ids, final_path)
    return sampled_ids


def construct_toolbase(instruction_dataset, total_extended_size, initial_extended_size, extended_size_per_epoch, save_dir_path="./results/viper"):
    """Build the toolbase by running the pipeline epoch by epoch.

    If ``5_deduplicated_tool.csv`` exists under ``save_dir_path`` it is
    loaded and returned directly. Otherwise samples are selected, code is
    generated per epoch (cached as ``1_raw_specific_solutions.json``), each
    epoch is validated and abstracted through ``check_if_continue``, and the
    per-epoch ``4_valid_tools.csv`` files are concatenated, deduplicated and
    post-processed with ``process_toolbase``.

    Args:
        instruction_dataset: ``datasets.Dataset`` with a "tool_id" column.
        total_extended_size: Target total number of samples.
        initial_extended_size: Number of samples used in epoch 0.
        extended_size_per_epoch: Samples added in each later epoch.
        save_dir_path: Root directory for all cache files.

    Returns:
        ``datasets.Dataset`` of tools.
    """
    # Module-level switch recording that the vision models were initialised.
    global flag

    final_csv = f"{save_dir_path}/5_deduplicated_tool.csv"
    if os.path.exists(final_csv):
        # NOTE(review): this cached branch does not apply `process_toolbase`,
        # unlike the branch below — confirm that is intended.
        return datasets.Dataset.from_csv(final_csv)

    sampled_ids = select_samples(
        instruction_dataset,
        total_extended_size=total_extended_size,
        initial_extended_size=initial_extended_size,
        extended_size_per_epoch=extended_size_per_epoch,
        save_dir_path=save_dir_path,
    )
    epochs = (total_extended_size - initial_extended_size) // extended_size_per_epoch + 1

    for epoch in range(epochs):
        print("epoch:", epoch)
        current_save_dir_path = f"{save_dir_path}/epoch_{epoch}"
        os.makedirs(current_save_dir_path, exist_ok=True)

        raw_path = f"{current_save_dir_path}/1_raw_specific_solutions.json"
        if not os.path.exists(raw_path):
            sampled_dataset = instruction_dataset.filter(lambda x: x["tool_id"] in sampled_ids[epoch])
            specific_solutions = sampled_dataset.map(lambda x: generate_code(x, "vqa"))
            specific_solutions.to_json(raw_path)
        else:
            print("load")
            specific_solutions = datasets.Dataset.from_json(raw_path)

        # Initialise the vision models once, right before the first validation.
        if not flag:
            init_vision_models()
            flag = True

        # check if it is necessary to continue
        keep_going = check_if_continue(specific_solutions, extended_size_per_epoch, current_save_dir_path)
        if not keep_going:
            break

    per_epoch_tools = [
        datasets.Dataset.from_csv(f"{save_dir_path}/epoch_{epoch}/4_valid_tools.csv")
        for epoch in range(epochs)
    ]
    merged = datasets.concatenate_datasets(per_epoch_tools)
    deduplicated = deduplicate_tools(merged, save_dir_path)
    return deduplicated.map(process_toolbase)

def construct_complex_toolbase(instruction_dataset, simple_toolbase, save_dir_path="./results/viper", epochs=None):
    """Regenerate solutions for the samples whose first solution was invalid.

    The previous version referenced the undefined names ``save_dir_path``
    and ``epochs``, loaded the validated solutions from a ``.json`` file
    although ``check_if_continue`` writes them as ``.csv``, and returned
    nothing. Both names are now keyword parameters, the CSV is read, and the
    regenerated solutions are returned.

    Args:
        instruction_dataset: ``datasets.Dataset`` with a "tool_id" column.
        simple_toolbase: Currently unused.
            TODO: feed the toolbase into generation as the original comment
            intended.
        save_dir_path: Root directory containing the ``epoch_{i}`` folders.
        epochs: Number of epochs to read. ``None`` counts the consecutive
            ``epoch_{i}`` folders that contain raw solutions.

    Returns:
        ``datasets.Dataset`` of newly generated solutions for the samples
        whose earlier solution did not pass validation.

    Raises:
        FileNotFoundError: If no epoch with raw solutions is found.
    """
    if epochs is None:
        epochs = 0
        while os.path.exists(f"{save_dir_path}/epoch_{epochs}/1_raw_specific_solutions.json"):
            epochs += 1
    if epochs <= 0:
        raise FileNotFoundError(f"no epoch results found under {save_dir_path}")

    # construct complex tools, using the instruction samples that turbo fails to handle
    # 1. load all questions viper failed on before
    all_specific_solutions = datasets.concatenate_datasets([
        datasets.Dataset.from_json(f"{save_dir_path}/epoch_{epoch}/1_raw_specific_solutions.json")
        for epoch in range(epochs)
    ])
    all_valid_solutions = datasets.concatenate_datasets([
        datasets.Dataset.from_csv(f"{save_dir_path}/epoch_{epoch}/2_valid_specific_solutions.csv")
        for epoch in range(epochs)
    ])
    all_invalid_solutions_ids = set(all_specific_solutions["tool_id"]) - set(all_valid_solutions["tool_id"])
    difficult_questions = instruction_dataset.filter(lambda x: x["tool_id"] in all_invalid_solutions_ids)

    # 2. regenerate solutions for these invalid solutions
    difficult_solutions = difficult_questions.map(lambda x: generate_code(x, "vqa"))
    return difficult_solutions

def construct_vector_library(model, tokenizer, toolbase, save_dir_path):
    """Embed every tool's name, explanation, docstring and query.

    The result is cached in ``vector_library.pt`` under ``save_dir_path``
    and loaded from there when the file exists.

    Args:
        model: Encoder passed through to ``compute_simcse``.
        tokenizer: Tokenizer passed through to ``compute_simcse``.
        toolbase: Dataset-like object with "tool" and "query" columns.
        save_dir_path: Directory holding ``vector_library.pt``.

    Returns:
        dict with the keys "name_embedding", "explanation_embedding",
        "docstring_embedding" and "query_embedding".
    """
    library_path = f"{save_dir_path}/vector_library.pt"
    if os.path.exists(library_path):
        return torch.load(library_path)

    tool_sources = toolbase["tool"]

    # Text to embed, keyed by the name it gets in the library. Function names
    # are split on "_" and re-joined with spaces.
    texts = {
        "name_embedding": [" ".join(extract_function_name(src).split("_")) for src in tool_sources],
        "explanation_embedding": [extract_function_docstring(src)[0] for src in tool_sources],
        "docstring_embedding": [extract_function_docstring(src)[1] for src in tool_sources],
        "query_embedding": toolbase["query"],
    }

    # save as a dict
    vector_library = {}
    for key, sentences in texts.items():
        vector_library[key] = compute_simcse(model, tokenizer, sentences)

    torch.save(vector_library, library_path)
    return vector_library

# Set to True by construct_toolbase once init_vision_models() has run.
flag = False

if __name__ == "__main__":
    # Both files are expected to already carry a "tool_id" column; the
    # commented-out lines below are the one-off preprocessing that added it.
    vqa_dataset = datasets.Dataset.from_json('./datasets/vqa_dataset.json')
    llava_dataset = datasets.Dataset.from_json("/data/private/yuanlifan/multitool/datasets/conversation_58k_flattened.json")

    reserved_columns = ["image_path", "question", "instruction", "answer", "id"]

    # vqa_dataset = vqa_dataset.add_column("tool_id", ["vqa_" + str(i) for i in range(len(vqa_dataset))]).remove_columns(set(vqa_dataset.features.keys())-set(reserved_columns))
    # llava_dataset = llava_dataset.add_column("tool_id", ["llava_" + str(i) for i in range(len(llava_dataset))]).remove_columns(set(llava_dataset.features.keys())-set(reserved_columns))
    # vqa_dataset.to_json("./datasets/vqa_dataset.json")
    # llava_dataset.to_json("/data/private/yuanlifan/multitool/datasets/conversation_58k_flattened.json")

    instruction_dataset = datasets.concatenate_datasets([vqa_dataset, llava_dataset])

    # `model` and `tokenizer` must stay module-level names: select_samples
    # and deduplicate_tools read them as globals.
    from transformers import AutoTokenizer, AutoModel
    simcse_checkpoint = "princeton-nlp/sup-simcse-roberta-large"
    tokenizer = AutoTokenizer.from_pretrained(simcse_checkpoint)
    model = AutoModel.from_pretrained(simcse_checkpoint).cuda()

    toolbase = construct_toolbase(
        instruction_dataset,
        total_extended_size=2000,
        initial_extended_size=1000,
        extended_size_per_epoch=100,
        save_dir_path="./results/viper_ablation",
    )
    vector_library = construct_vector_library(model, tokenizer, toolbase, "results/viper_ablation")

    tool_count = len(toolbase["tool"])
    print(tool_count, vector_library["name_embedding"].shape)