import json

from sentence_transformers import SentenceTransformer
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, dendrogram
import pickle
from openai import OpenAI
import os
import copy
from metaflow_workbench import *



# ---- Run configuration -------------------------------------------------
# Sampling temperature used for the first generation attempt of every merge;
# merge_nodes lowers it by 0.1 after each attempt.
init_temperature = 0.7
# Run tag; it is embedded in the name of the output pickle written at the end
# of the script. Original author's note: "v3" = no privileged, "sft" = SFT
# (NOTE(review): the exact meaning of "no privileged" is not visible here).
version = "workbench_sft"
# OpenAI-compatible client pointed at a local endpoint; the API key is left
# empty.
generator_openai_client = OpenAI(
    api_key="",
    base_url="http://localhost:8888/v1",
)
# Model name sent with every generation request.
generator_model_name = 'metaflow_generator'
# Forwarded to make_summarization_prompt as ``ICL_number``.
ICL_number = 0
# Linkage distance above which the merge loop stops.
distance_bound = 0.5
# Default number of generation attempts per merge (default of merge_nodes).
max_try=2


# One reward function is built per agent type; ``reward_weights`` is paired
# positionally with ``reward_func_list`` when the weighted reward is summed.
agent_type_list = ["react", "code_lines_ratio"]
reward_func_list = [make_correct_ratio_reward_func(agent_type) for agent_type in agent_type_list]
reward_weights = [0.7, 0.3]


# Echo the effective configuration at start-up.
print('agent_type_list:', agent_type_list)
print('reward_func_list:', reward_func_list)
# NOTE(review): ``num_repetitions`` is not defined anywhere in this file;
# presumably it is provided by ``from metaflow_workbench import *`` -- confirm,
# otherwise this line raises NameError.
print('num_repetitions:', num_repetitions)
print('max_try', max_try)
# ====================================


# Pickled mapping of leaf id -> leaf node dict.
# NOTE: pickle executes arbitrary code on load; only use trusted files here.
leaf_node_path = './data/split/leaf_nodes_train.pkl'

with open(leaf_node_path, 'rb') as fp:
    leaf_nodes = pickle.load(fp)

# Stamp every node with its own key so a node can be identified on its own.
for leaf_key, leaf in leaf_nodes.items():
    leaf['id'] = leaf_key

# Plain-dict view keyed by leaf id (shares the node objects with leaf_nodes).
LeafId2LeafNodes = dict(leaf_nodes.items())

embedding_model_path = 'models/sentence-transformers/all-MiniLM-L6-v2'
embedding_model = SentenceTransformer(embedding_model_path)


def _leaf_embedding(leaf):
    """Embed the concatenation of a leaf's base template and query."""
    text = leaf['base_template'] + leaf['query']
    return embedding_model.encode(text, show_progress_bar=False)


# One embedding per leaf, in the iteration order of leaf_nodes.
LeafId2LeafEmb = {leaf['id']: _leaf_embedding(leaf) for leaf in leaf_nodes.values()}

# ================================
def call_llm(messages: list[dict], client, model='gpt-4o-mini', temperature=0.0, max_tokens=8192) -> str:
    """
    Call an LLM with a history of messages and return the response text.

    Args:
        messages: Chat history passed as ``messages`` to the completion call.
        client: Object exposing ``chat.completions.create(**kwargs)``.
        model: Model name sent with the request.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The content of the first choice. An empty string is returned when the
        response has no choices or when the first choice carries no content,
        so callers always receive a ``str``.

    Raises:
        Whatever ``client.chat.completions.create`` raises; errors are not
        handled here.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=1.0,
        n=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        logit_bias={},
        max_tokens=max_tokens,
    )
    if not response.choices:
        return ""
    content = response.choices[0].message.content
    # ``content`` may be None; normalise so the ``-> str`` contract holds and
    # downstream string handling does not fail on None.
    return "" if content is None else content


def make_meta_id(id1, id2):
    """Build the id of a merged node from two '|'-separated ids.

    The components of both ids are de-duplicated, sorted, and joined back
    with '|', so the result does not depend on argument order.
    """
    components = {*id1.split('|'), *id2.split('|')}
    ordered = sorted(components)
    return '|'.join(ordered)



def build_GRPO_input(node1, node2):
    """Assemble the sample dict describing a pair of nodes to be merged.

    For each node the workflow is taken from ``'toolcall_and_results'`` when
    that key is present, otherwise from ``'workflow'``. Non-string workflows
    are serialised with ``json.dumps(..., indent=2)``.

    Returns:
        dict with keys ``query1``, ``workflow1``, ``query2``, ``workflow2``
        and ``leaf_ids`` (the set of '|'-separated id components of both
        nodes).

    Raises:
        Exception: if a node has neither workflow key.
    """

    def pick_workflow(node):
        # 'toolcall_and_results' takes precedence over 'workflow'.
        for field in ('toolcall_and_results', 'workflow'):
            if field in node:
                return node[field]
        raise Exception(f"No workflow found in node {node}")

    def as_text(workflow):
        if isinstance(workflow, str):
            return workflow
        return json.dumps(workflow, indent=2)

    # Resolve both workflows before serialising either one, so a missing
    # workflow is reported before any serialisation problem.
    raw_first = pick_workflow(node1)
    raw_second = pick_workflow(node2)
    text_first = as_text(raw_first)
    text_second = as_text(raw_second)

    return {
        'query1': node1['query'],
        'workflow1': text_first,
        'query2': node2['query'],
        'workflow2': text_second,
        'leaf_ids': {*node1['id'].split('|'), *node2['id'].split('|')},
    }


# Sections expected in a generator completion.
_META_QUERY_PATTERN = r'\*\*Meta Query\*\*:\s*(.*?)\s*\*\*Meta Workflow\*\*'
_META_WORKFLOW_PATTERN = r'\*\*Meta Workflow\*\*:\s*```json\s*([\s\S]*?)```'
# Number of times a failing LLM call is retried within one attempt.
_LLM_CALL_RETRIES = 5


def _parse_meta_completion(completion):
    """Extract ``(meta_query, meta_workflow)`` from a generator completion.

    ``meta_query`` is None when the query section is not found.

    Raises:
        ValueError: if the completion is not a string, has no fenced JSON
            workflow block, or that block is not valid JSON.
    """
    # Imported locally because this file has no top-level ``import re``.
    import re

    if not isinstance(completion, str):
        raise ValueError(f"completion is not a string: {type(completion).__name__}")

    query_match = re.search(_META_QUERY_PATTERN, completion, re.DOTALL)
    meta_query = query_match.group(1).strip() if query_match else None

    workflow_match = re.search(_META_WORKFLOW_PATTERN, completion, re.DOTALL)
    if workflow_match is None:
        raise ValueError("completion has no '**Meta Workflow**' JSON block")
    # json.JSONDecodeError is a subclass of ValueError.
    meta_workflow = json.loads(workflow_match.group(1).strip())
    return meta_query, meta_workflow


def merge_nodes(node1, node2, max_try=max_try):
    """Generate a meta node that summarises ``node1`` and ``node2``.

    The generator model is sampled up to ``max_try`` times, starting at
    ``init_temperature`` and lowering the temperature by 0.1 (never below 0)
    after each scored attempt. Every completion is scored with the weighted
    sum of ``reward_func_list``; completions with a positive reward and a
    parseable workflow become candidates. Sampling stops early once a
    candidate reaches a reward of at least 0.99.

    Args:
        node1, node2: Node dicts accepted by ``build_GRPO_input``.
        max_try: Maximum number of generation attempts.

    Returns:
        The candidate dict with the highest ``(reward, len(workflow))``. It
        has the keys ``id``, ``query``, ``workflow``, ``reward``, ``source``
        and one ``reward_<agent_type>`` entry per agent type.

    Raises:
        ValueError: if no attempt produced a usable candidate.
    """
    sample = build_GRPO_input(node1, node2)
    sample['workflow1_lines'] = count_static_tool_calls(sample['workflow1'])
    sample['workflow2_lines'] = count_static_tool_calls(sample['workflow2'])

    prompt = make_summarization_prompt(sample, ICL_number=ICL_number, leaf_nodes=leaf_nodes)['prompt']

    # The reward functions are called with list-valued arguments; here each
    # list holds the single sample being merged.
    leaf_ids = [sample['leaf_ids']]
    workflow1_lines = [sample['workflow1_lines']]
    workflow2_lines = [sample['workflow2_lines']]

    temperature = init_temperature
    candidate_meta_samples = []

    for _attempt in range(max_try):
        for _ in range(_LLM_CALL_RETRIES):
            try:
                completion = call_llm(prompt, client=generator_openai_client, model=generator_model_name, temperature=temperature)
            except Exception as e:
                # Best-effort boundary around the remote call: report and retry.
                print(f'Build metaflow Error in {prompt}', e)
            else:
                break
        else:
            # Every call failed; move on to the next attempt.
            continue

        completions = [[{'content': completion}]]
        reward_list = [
            reward_func(
                completions,
                leaf_ids=leaf_ids,
                workflow1_lines=workflow1_lines,
                workflow2_lines=workflow2_lines,
            )[0]
            for reward_func in reward_func_list
        ]
        reward = np.sum([weight * reward_i for reward_i, weight in zip(reward_list, reward_weights)])

        if reward > 0.0:
            try:
                meta_query, meta_workflow = _parse_meta_completion(completion)
            except ValueError as err:
                # An unparseable completion costs one attempt instead of
                # aborting the whole merge.
                print(f'Build metaflow Error: unparseable completion (reward={reward})', err)
            else:
                meta_node = {
                    'id': make_meta_id(node1['id'], node2['id']),
                    'query': meta_query,
                    'workflow': meta_workflow,
                    'reward': reward,
                    'source': [node1, node2]
                }
                for agent_type, agent_reward in zip(agent_type_list, reward_list):
                    meta_node[f"reward_{agent_type}"] = agent_reward
                candidate_meta_samples.append(meta_node)

                if reward >= 0.99:
                    break

        # Sample more conservatively next time; a temperature below 0 is
        # never requested.
        temperature = max(0.0, temperature - 0.1)

    if not candidate_meta_samples:
        raise ValueError(
            f"merge_nodes: no usable candidate for {node1['id']} + {node2['id']} "
            f"after {max_try} attempt(s)"
        )

    # Highest reward wins; ties go to the longer workflow.
    return max(
        candidate_meta_samples,
        key=lambda candidate: (candidate['reward'], len(candidate['workflow']))
    )


# Average-linkage clustering of the leaf embeddings under cosine distance.
embeddings = list(LeafId2LeafEmb.values())
distance_matrix = pdist(embeddings, metric='cosine')
Z = linkage(distance_matrix, method='average')

leaf_ids = list(leaf_nodes.keys())
# Work on a deep copy so adding meta nodes never touches ``leaf_nodes``.
all_nodes = copy.deepcopy(leaf_nodes)
merge_steps = []    # id of the meta node created for each processed row of Z
distance_list = []  # linkage distance of each processed row of Z


def _node_id_for_cluster(cluster_index):
    """Map a linkage cluster index to the matching key of ``all_nodes``.

    Indices below ``len(leaf_ids)`` are original leaves; larger indices refer
    to the cluster created by an earlier row of ``Z``.
    """
    if cluster_index < len(leaf_ids):
        return leaf_ids[int(cluster_index)]
    return merge_steps[int(cluster_index) - len(leaf_ids)]


print("Each merge step:")
for i, (cluster_1, cluster_2, distance, sample_count) in enumerate(Z):
    # Stop at the first merge whose distance exceeds the bound.
    if distance > distance_bound:
        break

    node1_id = _node_id_for_cluster(cluster_1)
    node2_id = _node_id_for_cluster(cluster_2)

    meta_node = merge_nodes(all_nodes[node1_id], all_nodes[node2_id])
    meta_node_id = make_meta_id(node1_id, node2_id)
    all_nodes[meta_node_id] = meta_node

    print(f"Step {i + 1}:")
    print(f"  Merged clusters {int(cluster_1)} and {int(cluster_2)} with distance {distance:.3f}.")
    print(f"  New cluster contains {int(sample_count)} samples.")
    print(f"  Merged nodes: {meta_node_id}.")
    print(f"  Meta node reward: {meta_node['reward']}")
    print(f"  Meta node workflow: {meta_node['workflow']}")

    # Record the step so later rows of Z can refer to this cluster.
    merge_steps.append(meta_node_id)
    distance_list.append(distance)

# Persist leaves and meta nodes together.
output_path = f'./input/metaflow_tree_{version}.pkl'
with open(output_path, 'wb') as fp:
    pickle.dump(all_nodes, fp)


print('END')