import json
import numpy as np
import faiss
import base64
import requests
from datetime import datetime
import os
import shutil
from collections import defaultdict


def find_indices_to_remove(data_list):
    """
    Find the indices of duplicate records to delete, keeping the best one per group.

    Records are grouped by their 'query' together with their 'plan'. Within each
    group, the first record that carries the highest 'Score' is kept and every
    other record of the group is reported for deletion.

    Args:
        data_list (list[dict]): Records, each containing a 'query', a 'plan'
            (a sequence of hashable steps, e.g. strings) and a 'Score' field.

    Returns:
        list[int]: Indices into ``data_list`` that should be deleted. Groups appear
            in order of their first occurrence; indices are ascending within a group.
    """
    # 1. Group record indices by (query, plan). A tuple key is used instead of a
    #    joined string so that separators occurring inside the query or inside a
    #    plan step can never make two different groups collide.
    groups = defaultdict(list)
    for idx, record in enumerate(data_list):
        group_key = (record['query'], tuple(record['plan']))
        groups[group_key].append(idx)

    # 2. In every group keep the best-scoring record and mark the rest.
    delete_indices = []
    for indices in groups.values():
        # max() returns the first maximal item, so ties keep the earliest record.
        keep_idx = max(indices, key=lambda i: data_list[i]['Score'])
        delete_indices.extend(i for i in indices if i != keep_idx)

    return delete_indices


# Convert query to embedding
def trion20_onnx_embedding_clip(querys):
    """
    Placeholder embedding function: returns one all-zero 1024-d vector per query.

    The number of returned rows matches the number of queries, because callers
    index the derived similarity scores once per input text.

    :param querys: a single query string or a sequence of query strings
    :return: numpy array of zeros with shape (number of queries, 1024)
    """
    # A bare string counts as one query, not as a sequence of characters.
    if isinstance(querys, str):
        querys = [querys]
    return np.zeros(shape=(len(querys), 1024))


# Reading JSON Files
def load_json(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path`` and return its content."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        return json.load(handle)


# Calculate time difference
def calculate_time_diff(time1, time2):
    """
    Return the number of seconds from ``time1`` to ``time2``.

    Both arguments are strings in the form ``YYYY-mm-dd HH:MM:SS``. The result is
    a float and is negative when ``time2`` lies before ``time1``.
    """
    start, end = (datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S") for stamp in (time1, time2))
    return (end - start).total_seconds()


# Update faiss and json libraries when adding memory
def add_new_memory(query_data, query_score, query_timestep, recalled_memories, add_vectors, loaded_index, loaded_json,
                   memory_faiss_path, memory_json_path):
    """
    Append one memory record to the faiss index and to the JSON store, persisting both.

    ``query_data`` is modified in place: its 'Score' and 'Time' entries are set and
    any 'recalled_memory' entry is dropped before the record is appended to
    ``loaded_json``. ``recalled_memories`` is accepted but not used here.

    :param query_data: dict describing the new memory
    :param query_score: value stored under 'Score'
    :param query_timestep: value stored under 'Time'
    :param recalled_memories: unused
    :param add_vectors: vectors handed to ``loaded_index.add``
    :param loaded_index: in-memory faiss index to extend
    :param loaded_json: in-memory list of memory records to extend
    :param memory_faiss_path: file the index is written to
    :param memory_json_path: file the record list is written to
    """
    # Extend the index and persist it right away.
    loaded_index.add(add_vectors)
    faiss.write_index(loaded_index, memory_faiss_path)

    # Stamp the record, strip the transient recall payload (if any) and store it.
    query_data['Score'] = query_score
    query_data['Time'] = query_timestep
    query_data.pop('recalled_memory', None)
    loaded_json.append(query_data)
    print('add new memory: ', query_data)

    with open(memory_json_path, "w", encoding="utf-8") as json_out:
        json.dump(loaded_json, json_out, ensure_ascii=False, indent=4)


# Update faiss library when deleting memory
def update_faiss(faiss_vectors, memory_faiss_path):
    """
    Rebuild a flat inner-product faiss index from ``faiss_vectors`` and save it.

    :param faiss_vectors: 2-D array; its second axis gives the index dimension
    :param memory_faiss_path: file the rebuilt index is written to
    """
    rebuilt_index = faiss.IndexFlatIP(faiss_vectors.shape[1])

    # The vectors are converted to float32 before being added to the index.
    vectors_f32 = np.array(faiss_vectors).astype('float32')
    rebuilt_index.add(vectors_f32)

    faiss.write_index(rebuilt_index, memory_faiss_path)


# Update json and faiss libraries when deleting memory
def delete_low_score_memory_json(filtered_index, loaded_json, memory_json_path, memory_delete_json_path,
                                 memory_vector_path, memory_faiss_path):
    """
    Remove one memory record and keep the JSON, vector and faiss files in sync.

    The record at ``filtered_index`` is popped from ``loaded_json`` (in place),
    stamped with the current system time and appended to the deletion log. The
    matching row is removed from the saved vector array and the faiss index is
    rebuilt from the remaining vectors.

    :param filtered_index: position of the record / vector row to remove
    :param loaded_json: in-memory list of memory records
    :param memory_json_path: file holding the live memory records
    :param memory_delete_json_path: existing JSON file holding the list of deleted records
    :param memory_vector_path: file holding the saved vector array
    :param memory_faiss_path: file the rebuilt index is written to
    """
    evicted = loaded_json.pop(filtered_index)
    print('removed_record', evicted)
    evicted["system_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Append to the deletion log: read the list, then rewrite the same file
    # handle from the start.
    with open(memory_delete_json_path, "r+", encoding="utf-8") as archive_file:
        archived_records = json.load(archive_file)
        archived_records.append(evicted)
        archive_file.seek(0)
        archive_file.truncate()
        json.dump(archived_records, archive_file, ensure_ascii=False, indent=4)

    # Persist the shortened list of live records.
    with open(memory_json_path, "w", encoding="utf-8") as live_file:
        json.dump(loaded_json, live_file, ensure_ascii=False, indent=4)

    # Drop the matching vector row, save the array and rebuild the index from it.
    remaining_vectors = np.delete(np.load(memory_vector_path), [filtered_index], axis=0)
    np.save(memory_vector_path, remaining_vectors)
    update_faiss(remaining_vectors, memory_faiss_path)


# Memory initialization function
def memory_initialize(memory_json_path, memory_delete_json_path, memory_vector_path, memory_faiss_path):
    """
    Create (or reset) an empty memory store.

    Writes an empty JSON list to both JSON files, an empty 1024-dimensional
    inner-product faiss index, and an empty (0, 1024) vector array.

    :param memory_json_path: file for the live memory records; its parent
        directory is created when missing
    :param memory_delete_json_path: file for the deleted-record log
    :param memory_vector_path: file for the saved vector array
    :param memory_faiss_path: file for the faiss index
    """
    # os.path.dirname returns '' for a bare file name; in that case there is
    # nothing to create (os.makedirs('') would raise).
    memory_dir = os.path.dirname(memory_json_path)
    if memory_dir:
        os.makedirs(memory_dir, exist_ok=True)

    # Both JSON stores start out as empty lists.
    for json_path in (memory_json_path, memory_delete_json_path):
        with open(json_path, "w", encoding="utf-8") as json_file:
            json.dump([], json_file)

    # Create an empty faiss index file
    index = faiss.IndexFlatIP(1024)
    faiss.write_index(index, memory_faiss_path)

    # Create an empty vector file
    query_faiss_vectors = np.empty((0, 1024))
    np.save(memory_vector_path, query_faiss_vectors)


def memory_initialize_from_file(origin_data_path, memory_json_path, memory_delete_json_path, memory_vector_path,
                                memory_faiss_path):
    """
    Reset the memory store and seed it from a JSON-lines file.

    Every non-blank line of ``origin_data_path`` must be a JSON object with the
    keys 'model_name', 'query', 'plan', 'response', 'Score' and 'Time'. For each
    line the query is embedded, the embedding is added to a fresh 1024-dimensional
    inner-product faiss index and to the saved vector array, and the record is
    appended to the memory JSON. The deleted-record log is reset to an empty list.

    :param origin_data_path: UTF-8 JSON-lines file with the seed records
    :param memory_json_path: file for the live memory records; its parent
        directory is created when missing
    :param memory_delete_json_path: file for the deleted-record log
    :param memory_vector_path: file for the saved vector array
    :param memory_faiss_path: file for the faiss index
    """
    # os.path.dirname returns '' for a bare file name; in that case there is
    # nothing to create (os.makedirs('') would raise).
    memory_dir = os.path.dirname(memory_json_path)
    if memory_dir:
        os.makedirs(memory_dir, exist_ok=True)

    # Reset both JSON stores and the index file up front, so a failure while
    # reading the seed file leaves an empty store behind.
    for json_path in (memory_json_path, memory_delete_json_path):
        with open(json_path, "w", encoding="utf-8") as json_file:
            json.dump([], json_file)

    index = faiss.IndexFlatIP(1024)
    faiss.write_index(index, memory_faiss_path)

    record_keys = ('model_name', 'query', 'plan', 'response', 'Score', 'Time')
    records = []
    # The leading empty float64 block keeps the saved array's (0, 1024) shape
    # when the seed file has no records, and its float64 dtype otherwise.
    vector_blocks = [np.empty((0, 1024))]

    with open(origin_data_path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            ori_data = json.loads(line)
            query_data = {key: ori_data[key] for key in record_keys}

            input_embeddings = trion20_onnx_embedding_clip([query_data['query']]).astype('float32')
            vector_blocks.append(input_embeddings)
            index.add(input_embeddings)
            records.append(query_data)

    # Stack once at the end instead of re-copying the whole array per line.
    np.save(memory_vector_path, np.vstack(vector_blocks))
    faiss.write_index(index, memory_faiss_path)
    with open(memory_json_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)


def longestCommonSubsequence(list_a, list_b) -> int:
    """
    Return the length of the longest common subsequence of two sequences.

    Elements are compared with ``==``, so any element type is supported.

    :param list_a: first sequence (supports ``len`` and indexing)
    :param list_b: second sequence (supports ``len`` and indexing)
    :return: LCS length; 0 when either sequence is empty
    """
    m = len(list_a)
    n = len(list_b)
    if m == 0 or n == 0:
        return 0

    # Only the previous row of the DP table is needed:
    # prev_row[j] is the LCS length of list_a[:i - 1] and list_b[:j].
    prev_row = [0] * (n + 1)
    for i in range(1, m + 1):
        item_a = list_a[i - 1]
        curr_row = [0] * (n + 1)
        for j in range(1, n + 1):
            if item_a == list_b[j - 1]:
                curr_row[j] = prev_row[j - 1] + 1
            else:
                curr_row[j] = max(prev_row[j], curr_row[j - 1])
        prev_row = curr_row
    return prev_row[n]


# Update memory main function
def _memory_softmax(x):
    """Softmax of ``x``, shifted by its maximum before exponentiation."""
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)


def _recalled_memory_weights(query_data, recalled_memories):
    """Return (answer_weights, plan_weights): softmax-normalised similarity lists, one entry per recalled memory."""
    recalled_answers = [recalled_memory["response"] for recalled_memory in recalled_memories]
    recalled_answers_embeddings = trion20_onnx_embedding_clip(recalled_answers).astype('float32')
    query_answer_embeddings = trion20_onnx_embedding_clip([query_data["response"]]).astype('float32')

    query_plan = query_data["plan"]
    recalled_plans = [recalled_memory["plan"] for recalled_memory in recalled_memories]

    # Answer similarity is the dot product of the embeddings (this equals the
    # cosine similarity only if the embeddings are L2-normalised).
    answer_similarities = np.dot(recalled_answers_embeddings,
                                 query_answer_embeddings.T).flatten()
    answer_weights = _memory_softmax(answer_similarities).tolist()

    # Plan similarity is LCS length divided by the longer plan's length; it is
    # 0 when either plan is empty.
    plan_similarities = []
    for plan in recalled_plans:
        if len(plan) == 0 or len(query_plan) == 0:
            plan_similarities.append(0)
        else:
            longest_sub_plan_len = longestCommonSubsequence(plan, query_plan)
            longest_plan_len = max(len(plan), len(query_plan))
            plan_similarities.append(longest_sub_plan_len * 1.0 / longest_plan_len)
    plan_weights = _memory_softmax(np.asarray(plan_similarities)).tolist()

    return answer_weights, plan_weights


def _rescore_recalled_memories(recalled_memories, answer_weights, plan_weights, query_score, query_timestep,
                               upper_threshold, memory_json_path, reward,
                               time_weight=1.0, answer_score_weight=1.0, plan_score_weight=1.0):
    """Raise (``reward=True``) or lower (``reward=False``) the stored 'Score' of every recalled memory and rewrite the JSON store."""
    loaded_json = load_json(memory_json_path)
    for i in range(len(recalled_memories)):
        filtered_index = recalled_memories[i]['Index']
        answer_score_i = answer_weights[i]
        plan_score_i = plan_weights[i]
        if reward:
            # Rewarded memories are also refreshed to the current timestep.
            loaded_json[filtered_index]['Time'] = query_timestep
            answer_diff = answer_score_i * abs(query_score - upper_threshold)
            plan_diff = plan_score_i * abs(query_score - upper_threshold)
            score_diff = answer_score_weight * answer_diff + plan_score_weight * plan_diff
        else:
            # Timesteps are subtracted directly, so they must be numeric. If the
            # timestep is a system-time string, use
            # calculate_time_diff(loaded_json[filtered_index]['Time'], query_timestep) instead.
            time_diff = - abs(loaded_json[filtered_index]['Time'] - query_timestep)
            # NOTE(review): the penalty is scaled by the distance to the UPPER
            # threshold although this branch is entered via the lower threshold;
            # kept as in the original - confirm this is intended.
            answer_score_diff = - answer_score_i * abs(query_score - upper_threshold)
            plan_score_diff = - plan_score_i * abs(query_score - upper_threshold)
            score_diff = time_weight * time_diff + answer_score_weight * answer_score_diff + plan_score_weight * plan_score_diff
        loaded_json[filtered_index]['Score'] += score_diff

    # Write the updated data back to the original JSON file
    with open(memory_json_path, "w", encoding="utf-8") as f:
        json.dump(loaded_json, f, ensure_ascii=False, indent=4)


def _append_query_memory(query_data, query_score, query_timestep, recalled_memories,
                         memory_json_path, memory_vector_path, memory_faiss_path):
    """Embed the query and append it to the faiss index, the JSON store and the saved vector array."""
    loaded_index = faiss.read_index(memory_faiss_path)
    loaded_json = load_json(memory_json_path)
    q_texts = [query_data["query"]]
    input_embeddings = trion20_onnx_embedding_clip(q_texts).astype('float32')
    add_new_memory(query_data, query_score, query_timestep, recalled_memories, input_embeddings, loaded_index,
                   loaded_json, memory_faiss_path, memory_json_path)
    # Update vector array
    faiss_vectors = np.load(memory_vector_path)
    faiss_vectors = np.vstack((faiss_vectors, input_embeddings))
    np.save(memory_vector_path, faiss_vectors)


# Update memory main function
def update_agent_memory(query_data, recalled_memories, current_memory_cnt, query_score, query_timestep,
                        memory_score_list, memory_json_path, memory_vector_path, memory_faiss_path):
    """
    Update the persisted agent memory after a query has been scored.

    Thresholds: when ``memory_score_list`` is non-empty, the upper threshold is
    mean + std of the list (never below 0) and the lower threshold is mean - std;
    otherwise they default to 1.0 and 0.0.

    Behaviour:
      * Fewer than 100 stored memories: the query is always added as a new
        memory first. Afterwards, if memories were recalled, their scores are
        raised when ``query_score >= upper`` or lowered when ``query_score < lower``.
      * Otherwise, with recalled memories: ``query_score >= upper`` raises the
        recalled scores and then adds the query as a new memory;
        ``query_score < lower`` only lowers the recalled scores.
      * Otherwise, without recalled memories: the query is added only when
        ``query_score >= upper``.

    :param query_data: dict with at least 'query', 'response' and 'plan'; it is
        modified in place when it is stored as a new memory
    :param recalled_memories: list of dicts with 'response', 'plan' and 'Index'
        (position of the memory in the JSON store)
    :param current_memory_cnt: number of memories currently stored
    :param query_score: score of the current query
    :param query_timestep: timestep of the current query (numeric, it is
        subtracted from stored 'Time' values)
    :param memory_score_list: scores used to derive the thresholds
    :param memory_json_path: file holding the live memory records
    :param memory_vector_path: file holding the saved vector array
    :param memory_faiss_path: file holding the faiss index
    """
    time_weight = 1.0  # Time Weight
    answer_score_weight = 1.0  # Score Weight
    plan_score_weight = 1.0

    # Calculate the upper and lower bounds to update the thresholds
    upper_threshold = 1.0
    lower_threshold = 0.0
    if len(memory_score_list) > 0:
        all_seq_rewards = np.asarray(memory_score_list)
        mean_seq_rewards = np.mean(all_seq_rewards)
        std_seq_rewards = np.std(all_seq_rewards)
        upper_threshold = mean_seq_rewards + std_seq_rewards
        # If the current upper_bound is less than 0, set upper_bound to 0. Sequences less than 0 cannot enter memory.
        upper_threshold = max(upper_threshold, 0)
        lower_threshold = mean_seq_rewards - std_seq_rewards

    weights = dict(time_weight=time_weight, answer_score_weight=answer_score_weight,
                   plan_score_weight=plan_score_weight)

    if current_memory_cnt < 100:
        print('The memory in the file is less than 100. Add new memory.')
        print('new_memory: ', query_data)

        # Add new memory to json and faiss files
        _append_query_memory(query_data, query_score, query_timestep, recalled_memories,
                             memory_json_path, memory_vector_path, memory_faiss_path)

        # Update the recalled memory score and time
        if len(recalled_memories) > 0:
            answer_weights, plan_weights = _recalled_memory_weights(query_data, recalled_memories)

            if query_score >= upper_threshold:
                print('Increase the recalled memories.')
                _rescore_recalled_memories(recalled_memories, answer_weights, plan_weights, query_score,
                                           query_timestep, upper_threshold, memory_json_path, True, **weights)

            elif query_score < lower_threshold:
                print('Decrease the recalled memories.')
                _rescore_recalled_memories(recalled_memories, answer_weights, plan_weights, query_score,
                                           query_timestep, upper_threshold, memory_json_path, False, **weights)

    else:
        if len(recalled_memories) > 0:
            answer_weights, plan_weights = _recalled_memory_weights(query_data, recalled_memories)

            if query_score >= upper_threshold:
                print('Increase the recalled memories.')
                _rescore_recalled_memories(recalled_memories, answer_weights, plan_weights, query_score,
                                           query_timestep, upper_threshold, memory_json_path, True, **weights)

                # The store is full: a new memory is only admitted when it
                # scored at or above the upper threshold.
                _append_query_memory(query_data, query_score, query_timestep, recalled_memories,
                                     memory_json_path, memory_vector_path, memory_faiss_path)

            elif query_score < lower_threshold:
                print('Decrease the recalled memories.')
                _rescore_recalled_memories(recalled_memories, answer_weights, plan_weights, query_score,
                                           query_timestep, upper_threshold, memory_json_path, False, **weights)

        else:
            print('There is no recalled memory. identify if new memory match upper bound...')

            if query_score >= upper_threshold:
                # Add new memory to json and faiss files
                _append_query_memory(query_data, query_score, query_timestep, recalled_memories,
                                     memory_json_path, memory_vector_path, memory_faiss_path)
