import math
import random
from collections import deque
from typing import List

import networkx as nx

from function_call_agent.graph_search.graph_tool import graph_to_tree_string, draw_dependency_graph


class HybridSearcher:
    def __init__(self, G: nx.DiGraph, graph_degree: int = 3, data_type=None):
        """Store the dependency graph and the search hyper-parameters.

        Args:
            G: Directed graph whose nodes carry a ``'type'`` attribute
                (``'api'`` or ``'param'``).
            graph_degree: Depth bound used by the random walks in
                ``_per_api_initialize``.
            data_type: Dataset tag read by ``is_required``.
        """
        self.G = G

        # Reversed copy of G restricted to the edges accepted by edge_filter.
        kept_edges_view = nx.subgraph_view(self.G, filter_edge=self.edge_filter)
        self.sub_graph = kept_edges_view.reverse(copy=True)

        self.graph_degree = graph_degree

        # Partition the node names by their 'type' attribute in one pass.
        api_names = set()
        param_names = set()
        for name, attrs in G.nodes(data=True):
            kind = attrs.get('type')
            if kind == 'api':
                api_names.add(name)
            if kind == 'param':
                param_names.add(name)
        self.api_nodes = api_names
        self.param_nodes = param_names

        # Annealing / genetic-search settings.
        self.ori_temperature = 300  # starting temperature for each target API
        self.cooling_rate = 0.7     # base of the per-iteration cooling factor
        self.population_size = 40   # chromosomes kept per generation
        self.max_nodes = 24         # node budget consulted by get_max_graph
        self.data_type = data_type

    def is_required(self, api, param):
        try:
            if api.startswith('param-'):
                return True
            if self.data_type == 'api_bank':
                return True
            else:
                return eval(self.G.nodes[api]['prop']['input_param'])[param[6:]]['is_required'] == 'true'
        except Exception as e:
            return False

    def _per_api_initialize(self, api: str, temperature) -> List[str]:
        """Create the initial chromosome population for ``api``.

        Each individual starts from ``api`` plus all of its direct
        predecessors in ``self.G``, then follows a random walk through
        ``get_predecessors`` (choosing among the first three candidates) for
        at most ``graph_degree + 1`` steps. The visited nodes are encoded and
        perturbed with ``_sa_mutation``.

        Returns:
            The distinct chromosomes produced (duplicates are dropped).
        """
        depth_limit = self.graph_degree if self.graph_degree is not None else 3
        candidates = []

        for _ in range(self.population_size):
            visited = set(self.G.predecessors(api))
            visited.add(api)

            cursor = api
            depth = 0
            while depth <= depth_limit and cursor:
                visited.add(cursor)
                options = self.get_predecessors(cursor)
                if options:
                    # Only the first three candidates are eligible.
                    pick = random.randint(0, min(2, len(options) - 1))
                    cursor = options[pick]
                else:
                    cursor = None
                depth += 1

            encoded = self._encode_chromosome(api, list(visited))
            candidates.append(self._sa_mutation(encoded, temperature))

        return list(set(candidates))

    def _api_specific_fitness(self, chrom: str, target_api: str) -> float:
        """Score a chromosome for ``target_api`` (higher is better).

        The chromosome is decoded to a node set, and the induced view of
        ``self.G`` without ``'depends_on'`` edges is scored as a weighted sum:

        * 0.35 x closeness centrality of ``target_api`` in the view,
        * 0.15 x ``log1p`` of the fraction of decoded nodes that are params,
        * 0.30 x a term that decays with average depth and node count,
        * 0.15 x a saturating function of the total edge weight,
        * 0.05 x (1 - average clustering).

        Args:
            chrom: Bit-string chromosome.
            target_api: API node the chromosome was built for.

        Returns:
            The combined fitness value.
        """
        nodes = self._decode_chromosome(chrom, target_api)

        def keep_edge(u, v):
            # Edge attributes are plain values; compare the label directly.
            return self.G.get_edge_data(u, v).get('label') != 'depends_on'

        subgraph = nx.subgraph_view(
            self.G,
            filter_node=lambda n: n in nodes,
            filter_edge=keep_edge,
        )

        api_centrality = nx.closeness_centrality(subgraph, u=target_api, wf_improved=True)

        # nodes always contains target_api, so the divisions below are safe.
        param_density = sum(1 for n in nodes if n in self.param_nodes) / len(nodes)

        avg_depth = sum(nx.shortest_path_length(self.G, n, target_api) for n in nodes) / len(nodes)

        total_weight = 0
        for u, v in subgraph.edges():
            total_weight += self.G.get_edge_data(u, v).get('weight', 0)
        # Saturates towards 1 as the total weight grows.
        weight_penalty = total_weight / (100 + total_weight)

        path_complexity = 1 - nx.average_clustering(subgraph)

        nodes_num = len(subgraph.nodes) / 3
        combined_penalty = 0.2 * math.exp(-avg_depth / 10) + 0.8 * math.exp(-nodes_num / 8)
        return (0.35 * api_centrality + 0.15 * math.log1p(param_density) + 0.3 * combined_penalty +
                0.15 * weight_penalty + 0.05 * path_complexity)

    def _genetic_operations_per_api(self, population, target_api: str, temperature: float):
        """Produce the next generation of chromosomes for ``target_api``.

        The best-scoring chromosomes (up to 60% of ``population_size``) are
        carried over unchanged. The remaining slots are filled with children
        made by single-point crossover of two parents drawn at random from
        ``population``; each child may additionally go through
        ``_adaptive_sa_mutation``.

        Returns:
            A list of ``population_size`` chromosomes.
        """
        scored = [(chrom, self._api_specific_fitness(chrom, target_api)) for chrom in population]
        scored.sort(key=lambda pair: -pair[1])

        elite_count = int(0.6 * self.population_size)
        next_generation = [chrom for chrom, _score in scored[:elite_count]]

        gene_order = sorted(self.get_filter_graph(target_api).nodes())
        while len(next_generation) < self.population_size:
            parent1, parent2 = random.choices(population, k=2)

            # NOTE(review): the upper bound simplifies to min(len(parent1), 2)
            # because max(len - 2, len) is always len, so the cut point is
            # always 1 or 2 -- confirm this is the intended crossover range.
            cut_upper = min(max(len(parent1) - 2, len(parent1)), 2)
            cut = random.randint(1, cut_upper)
            child = parent1[:cut] + parent2[cut:]

            # Mutation becomes more likely as the temperature drops.
            if random.random() < 0.7 - 0.3 * (temperature / 1000):
                child = self._adaptive_sa_mutation(target_api, child, temperature, gene_order)
            next_generation.append(child)

        return next_generation

    def _sa_mutation(self, chrom: str, temperature: float) -> str:
        bits = list(chrom)
        mutation_num = int(temperature / 100)
        for _ in range(mutation_num):
            idx = random.randint(0, len(bits) - 1)
            bits[idx] = '1' if bits[idx] == '0' else '0'
        return ''.join(bits)

    def edge_filter(self, u, v):
        edge_data = self.G.get_edge_data(u, v)
        return edge_data.get('label', '').cast() != "depends_on"

    def get_filter_graph(self, target_api):
        """Return the radius-3 directed ego graph of ``target_api``.

        The neighbourhood is taken in ``self.sub_graph``, the reversed copy
        of ``self.G`` built in ``__init__`` from the edges accepted by
        ``edge_filter``.
        """
        neighbourhood = nx.ego_graph(self.sub_graph, target_api, radius=3, undirected=False)
        return neighbourhood

    def _encode_chromosome(self, target_api, path_nodes: list) -> str:
        """Encode ``path_nodes`` as a bit string over the neighbourhood of ``target_api``.

        Bit *i* corresponds to the *i*-th node of the sorted node list of
        ``get_filter_graph(target_api)`` and is ``'1'`` when that node is in
        ``path_nodes``, otherwise ``'0'``.
        """
        gene_order = sorted(self.get_filter_graph(target_api).nodes())
        bits = []
        for candidate in gene_order:
            if candidate in path_nodes:
                bits.append('1')
            else:
                bits.append('0')
        return ''.join(bits)

    def get_predecessors(self, node):
        predecessors = sorted([
                                p for p in self.G.predecessors(node)
                                if (self.G.get_edge_data(p, node).get('label').cast() != 'depends_on') and
                                   (self.is_required(node, p))],
                              key=lambda n: self.G.get_edge_data(n, node).get('weight', 0).cast(),
                              reverse=True)
        predecessors = predecessors[:3] if node.startswith('param') else predecessors
        return predecessors

    def _decode_chromosome(self, chrom: str, target_api: str) -> set:
        """Translate a chromosome back into a set of graph nodes.

        Bit *i* refers to the *i*-th node of the sorted node list of
        ``get_filter_graph(target_api)`` -- the same ordering that
        ``_encode_chromosome`` writes and ``_adaptive_sa_mutation`` indexes.
        Selected nodes that cannot reach ``target_api`` in ``self.G`` are
        ignored, ``target_api`` itself is always included, and every selected
        API node additionally contributes its ``get_predecessors``.

        Args:
            chrom: Bit-string chromosome.
            target_api: API node the chromosome was built for.

        Returns:
            The decoded node set.
        """
        # Decode against the neighbourhood ordering used for encoding;
        # zipping with all nodes of G would pair bits with unrelated nodes.
        gene_order = sorted(self.get_filter_graph(target_api).nodes())

        selected = {
            node for bit, node in zip(chrom, gene_order)
            if bit == '1' and nx.has_path(self.G, node, target_api)
        }
        selected.add(target_api)

        expanded = set(selected)
        for node in selected:
            if node in self.api_nodes:
                expanded.update(self.get_predecessors(node))
        return expanded

    def _merge_subgraphs(self, subgraphs, target_apis) -> nx.DiGraph:
        """Trim each per-API subgraph and compose them into one graph.

        Args:
            subgraphs: Graphs produced for each target API, in the same order
                as ``target_apis``.
            target_apis: API node names matching ``subgraphs`` one to one.

        Returns:
            A ``DiGraph`` containing, for every pair, the part of the
            subgraph restricted to ``get_max_graph(subgraph, target_api)``.
        """
        merged_graph = nx.DiGraph()
        for sg, target_api in zip(subgraphs, target_apis):
            # Compute the kept node set once per subgraph rather than
            # re-walking the graph for every node the filter is asked about.
            kept_nodes = self.get_max_graph(sg, target_api)
            trimmed = nx.subgraph_view(
                sg,
                filter_node=lambda n, kept=kept_nodes: n in kept,
            )
            merged_graph = nx.compose(merged_graph, trimmed)
        return merged_graph

    def get_max_graph(self, sg, target_api):
        all_nodes = [target_api]
        current_nodes = [target_api]
        while current_nodes:
            node = current_nodes[0]
            current_nodes = current_nodes[1:]
            if len(all_nodes) > self.max_nodes and node.startswith('param'):
                break
            for pred in sg.predecessors(node):
                current_nodes.append(pred)
                all_nodes.append(pred)
        return set(all_nodes)

    def _adaptive_sa_mutation(self, target_api: str, chrom: str, temperature: float, all_nodes) -> str:
        bits = list(chrom)

        for i in range(len(bits)):
            node = all_nodes[i]
            if bits[i] == '1' and node in self.api_nodes:
                for pred in self.G.predecessors(node):
                    if self.G.nodes[pred].get('type') == 'param':
                        if pred in all_nodes:
                            pred_idx = all_nodes.index(pred)
                            bits[pred_idx] = '1'

        mutation_rate = 0.1 + (1 - temperature / 1000) * 0.6
        for i in range(len(bits)):
            if random.random() < mutation_rate:
                bits[i] = '1' if random.random() < 0.8 else '0'

        return ''.join(bits)

    def hybrid_search(self, target_apis: List[str]):
        """Run the annealed genetic search for every target API.

        For each API a population is initialised at ``ori_temperature`` and
        evolved for at most 11 generations, or until the temperature falls to
        1 or below. The fittest chromosome is turned into a subgraph, and all
        subgraphs are merged.

        Args:
            target_apis: Names of the API nodes to search for.

        Returns:
            A tuple ``(tree_string, merged_graph)`` where ``tree_string`` is
            the result of ``get_graph`` for the merged graph.
        """
        per_api_graphs = []

        for api in target_apis:
            temperature = self.ori_temperature
            population = self._per_api_initialize(api, temperature)

            for generation in range(1, 12):
                if not temperature > 1:
                    break
                # The cooling exponent grows with the generation number.
                temperature *= (self.cooling_rate ** (1 + generation / 5))
                population = self._genetic_operations_per_api(population, api, temperature)

            def score(chrom):
                return self._api_specific_fitness(chrom, api)

            winner = max(population, key=score)
            per_api_graphs.append(self._build_api_subgraph(winner, api))

        merged = self._merge_subgraphs(per_api_graphs, target_apis)
        tree_string = self.get_graph(target_apis, merged)
        return tree_string, merged

    def _build_api_subgraph(self, chrom: str, target_api: str):
        """Build the subgraph view described by ``chrom`` for ``target_api``.

        The decoded node set is extended with the ``get_predecessors`` of
        every API node in it. Each node that can reach ``target_api`` in
        ``self.G`` then contributes all nodes on one shortest path to it.

        Args:
            chrom: Bit-string chromosome.
            target_api: API node the chromosome was built for.

        Returns:
            A read-only view of ``self.G`` restricted to the collected nodes
            and to edges whose label is not ``"depends_on"``.
        """
        nodes = self._decode_chromosome(chrom, target_api)

        # Make sure every selected API carries its required predecessors.
        required = set()
        for api in nodes & self.api_nodes:
            required.update(self.get_predecessors(api))
        nodes = nodes | required

        connected_nodes = set()
        for node in nodes:
            try:
                # The path contains both endpoints, so node is added too.
                connected_nodes.update(nx.shortest_path(self.G, node, target_api))
            except nx.NetworkXNoPath:
                # Nodes that cannot reach the target are left out.
                continue

        def keep_edge(u, v):
            # Edge labels are plain strings; compare them directly.
            return self.G.get_edge_data(u, v).get('label', '') != "depends_on"

        return nx.subgraph_view(
            self.G,
            filter_node=lambda n: n in connected_nodes,
            filter_edge=keep_edge,
        )

    def get_graph(self, target_apis, pruned_graph):
        """Render ``pruned_graph`` through ``graph_to_tree_string`` (not reversed)."""
        return graph_to_tree_string(pruned_graph, target_apis=target_apis, reverse=False)


if __name__ == '__main__':
    # Small smoke-test graph: params point at the APIs that consume them.
    G = nx.DiGraph()
    G.add_nodes_from([
        ('api-delete_account', {'type': 'api'}),
        ('param1', {'type': 'param'}),
        ('api-get_order_id', {'type': 'api'}),
        ('param2', {'type': 'param'}),
        ('api-get_time', {'type': 'api'}),
    ])
    G.add_edges_from([
        ('param1', 'api-delete_account', {'weight': 0.8}),
        ('param2', 'api-get_order_id', {'weight': 0.9}),
        ('param2', 'api-get_time', {'weight': 0.7}),

    ])

    searcher = HybridSearcher(G, graph_degree=10)
    # hybrid_search returns (tree_string, merged_graph); only the graph is
    # inspected here.
    _tree_str, pruned_graph = searcher.hybrid_search(
        target_apis=['api-delete_account', 'api-get_order_id']
    )

    print(pruned_graph.nodes())
    print(pruned_graph.edges())