import copy
import json
import os
import pickle
import re

import matplotlib.pyplot as plt
import numpy as np
from openai import OpenAI
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist
from sentence_transformers import SentenceTransformer

from train_metaflow_generator_GRPO import *


# Sampling temperature used for the first merge attempt of every node pair.
init_temperature = 0.7


# Endpoint credentials for the generator model. They are read from the
# environment so secrets do not have to be written into this file; when the
# variables are unset the values fall back to empty strings.
api_key = os.environ.get("METAFLOW_GENERATOR_API_KEY", "")
base_url = os.environ.get("METAFLOW_GENERATOR_BASE_URL", "")
# Local path of the sentence-embedding model used to embed leaf queries.
embedding_model_path = 'models/sentence-transformers/all-MiniLM-L6-v2'


# `set_name` selects the leaf-node input file; `version` and `set_name`
# together name the output tree file.
set_name = "dev"
version = "sft78"
generator_openai_client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
generator_model_name = 'metaflow_generator'


# One reward function is built per agent type; `reward_weights` is aligned
# index-by-index with `agent_type_list` and is used to combine the individual
# rewards into a single weighted score.
agent_type_list = ["privileged_FullCodeRefl", "code_lines_ratio"]
reward_func_list = [make_correct_ratio_reward_func(agent_type) for agent_type in agent_type_list]
reward_weights = [0.6, 0.4]


# ====================================



def make_meta_id(id1, id2):
    """Build the id of a merged node from two '|'-separated ids.

    The components of both ids are pooled, de-duplicated, sorted and joined
    back together with '|', so the result does not depend on argument order.
    """
    components = {*id1.split('|'), *id2.split('|')}
    ordered = sorted(components)
    return '|'.join(ordered)

def _parse_meta_completion(completion):
    """Extract the meta query and meta workflow from a generator completion.

    Returns a ``(meta_query, meta_workflow)`` tuple, or ``None`` when the
    completion does not contain a ``**Meta Query**: ... **Meta Workflow**:``
    section. The workflow is every ```json / ```python fenced block found in
    the completion, re-fenced and joined with ``\\n---\\n``.
    """
    pattern_meta_query = r'\*\*Meta Query\*\*:\s*(.*?)\s*\*\*Meta Workflow\*\*:'
    pattern_code_blocks = r'```(json|python)\s*([\s\S]*?)```'

    meta_query_match = re.search(pattern_meta_query, completion, re.DOTALL)
    if meta_query_match is None:
        return None
    meta_query = meta_query_match.group(1).strip().strip('-')

    code_blocks = re.findall(pattern_code_blocks, completion, re.DOTALL)
    meta_workflow = "\n---\n".join(f"```{lang}\n{code}```" for lang, code in code_blocks)
    return meta_query, meta_workflow


def merge_nodes(node1, node2, max_try=5):
    """Merge two workflow nodes into a single meta node using the generator LLM.

    Up to ``max_try`` completions are sampled, starting at ``init_temperature``
    and lowering the temperature by 0.1 (never below 0) after every attempt
    that produced a completion. Each completion is scored with every function
    in ``reward_func_list`` and the scores are combined using
    ``reward_weights``. Completions with a non-negative combined reward that
    can be parsed become candidates.

    Args:
        node1: Node dict providing at least 'id', 'query' and 'workflow'.
        node2: Node dict providing at least 'id', 'query' and 'workflow'.
        max_try: Number of completions to sample.

    Returns:
        The candidate meta node with the highest reward (ties broken by the
        longer workflow text). It carries 'id', 'query', 'workflow', 'reward',
        'source' and one 'reward_<agent_type>' entry per agent type.

    Raises:
        ValueError: If no attempt produced a usable candidate.
    """
    sample = {
        'query1': node1['query'],
        'workflow1': node1['workflow'],
        'query2': node2['query'],
        'workflow2': node2['workflow'],
        'leaf_ids': set(node1['id'].split('|') + node2['id'].split('|')),
    }
    sample['workflow1_lines'] = count_python_lines_in_markdown(sample['workflow1'])
    sample['workflow2_lines'] = count_python_lines_in_markdown(sample['workflow2'])

    prompt = make_summarization_prompt(sample, ICL_number=ICL_number, version=2)['prompt']

    temperature = init_temperature
    candidate_meta_samples = []

    for _attempt in range(max_try):
        # Retry transient generator failures a few times before giving up on
        # this attempt; the for/else skips the attempt when every retry failed.
        for _retry in range(5):
            try:
                completion = call_llm(prompt, client=generator_openai_client, model=generator_model_name, temperature=temperature)
                break
            except Exception as e:
                print(f'Build metaflow Error in {prompt}', e)
        else:
            continue

        # The reward functions take batched inputs, so wrap the single sample
        # in one-element lists and read back the first score.
        completions = [[{'content': completion}]]
        leaf_ids = [sample['leaf_ids']]
        workflow1_lines = [sample['workflow1_lines']]
        workflow2_lines = [sample['workflow2_lines']]

        reward_list = [
            reward_func(completions, leaf_ids=leaf_ids, workflow1_lines=workflow1_lines, workflow2_lines=workflow2_lines)[0]
            for reward_func in reward_func_list
        ]
        reward = np.sum([weight * reward_i for reward_i, weight in zip(reward_list, reward_weights)])

        if reward >= 0.0:
            parsed = _parse_meta_completion(completion)
            if parsed is None:
                # A malformed completion must not abort the whole build; drop
                # it and let the remaining attempts try again.
                print('Build metaflow Error: completion has no Meta Query / Meta Workflow section, skipping it')
            else:
                meta_query, meta_workflow = parsed
                meta_node = {
                    'id': make_meta_id(node1['id'], node2['id']),
                    'query': meta_query,
                    'workflow': meta_workflow,
                    'reward': reward,
                    'source': [node1, node2],
                }
                for agent_type, agent_reward in zip(agent_type_list, reward_list):
                    meta_node[f"reward_{agent_type}"] = agent_reward

                candidate_meta_samples.append(meta_node)

        # Sample more conservatively on the next attempt, but never request a
        # negative temperature.
        temperature = max(0.0, temperature - 0.1)

    if not candidate_meta_samples:
        raise ValueError(
            f"merge_nodes produced no usable meta node for {node1['id']!r} and "
            f"{node2['id']!r} after {max_try} attempts"
        )

    return max(
        candidate_meta_samples,
        key=lambda x: (x['reward'], len(x['workflow']))
    )

# Build the metaflow tree: embed every leaf query, cluster the embeddings
# hierarchically, and replay the merge order while asking the generator to
# produce a meta node for every merged pair.

# NOTE(security): pickle.load can execute arbitrary code contained in the
# file; only load leaf-node files that come from a trusted source.
with open(f'./input/leaf_nodes_{set_name}.pkl', 'rb') as fp:
    leaf_nodes_train = pickle.load(fp)
leaf_nodes_train = {node['id']: node for node in leaf_nodes_train}

leaf_nodes = leaf_nodes_train
embedding_model = SentenceTransformer(embedding_model_path)

LeafId2LeafEmb = {node['id']: embedding_model.encode(node['query'], show_progress_bar=False) for node in leaf_nodes.values()}

# Dicts preserve insertion order, so row i of `embeddings` belongs to
# `leaf_ids[i]`.
embeddings = list(LeafId2LeafEmb.values())
leaf_ids = list(leaf_nodes.keys())

distance_matrix = pdist(embeddings, metric='cosine')

Z = linkage(distance_matrix, method='average')

merge_steps = []
distance_list = []

all_nodes = copy.deepcopy(leaf_nodes)

print("Each merge step:")
for i, merge in enumerate(Z):
    cluster_1, cluster_2, distance, sample_count = merge

    # Cluster indices below len(leaf_ids) are original leaves; larger indices
    # refer to the cluster created by merge step (index - len(leaf_ids)).
    node1_id = leaf_ids[int(cluster_1)] if cluster_1 < len(leaf_ids) else merge_steps[int(cluster_1) - len(leaf_ids)]
    node2_id = leaf_ids[int(cluster_2)] if cluster_2 < len(leaf_ids) else merge_steps[int(cluster_2) - len(leaf_ids)]

    node1 = all_nodes[node1_id]
    node2 = all_nodes[node2_id]
    meta_node = merge_nodes(node1, node2)
    print(meta_node)

    meta_node_id = make_meta_id(node1_id, node2_id)

    all_nodes[meta_node_id] = meta_node

    print(f"Step {i + 1}:")
    print(f"  Merged clusters {int(cluster_1)} and {int(cluster_2)} with distance {distance:.3f}.")
    print(f"  New cluster contains {int(sample_count)} samples.")
    print(f"  Merged nodes: {meta_node_id}.")
    print(f"  Meta Node Reward: {meta_node['reward']}")

    merge_steps.append(meta_node_id)
    distance_list.append(distance)

# Persist the tree before plotting: plt.show() may block or fail, and the
# generated nodes should not be lost if it does.
with open(f'./input/metaflow_tree_{version}_{set_name}.pkl', 'wb') as fp:
    pickle.dump(all_nodes, fp)

# Histogram of merge distances: the x axis is the distance, the y axis is the
# number of merge steps that fall into each bin.
plt.figure(figsize=(8, 5))
plt.hist(distance_list, bins=10, color='blue', edgecolor='black')
plt.title('Distribution of Cluster Merge Distances')
plt.xlabel('Distance')
plt.ylabel('Number of Merge Steps')
plt.grid(True)
plt.show()