import openai
import random
from igraph import *

# OpenAI API key, read by invoke_with_image(). The value here is a placeholder:
# supply a real key locally and avoid committing it to source control.
KEY = 'Your key'
import base64
from PIL import Image
import io
import igraph as ig
import numpy as np
import pickle
import os
import scipy.sparse as sp
import json
import shutil
import time
from cynetdiff.utils import networkx_to_ic_model, networkx_to_lt_model
import networkx as nx


def encode_image_to_base64(image, default_format='PNG'):
    """Serialize a PIL image and return its bytes as a base64 ASCII string.

    The image is re-encoded in its own ``image.format`` (the format it was
    loaded from). Images produced by PIL itself (``Image.new``, ``convert``,
    ``resize``, ...) have ``format`` set to ``None``, and PIL cannot save to an
    in-memory buffer without an explicit format, so ``default_format`` is used
    for those.

    Parameters:
    - image: a PIL ``Image`` (anything exposing ``.format`` and
      ``.save(fp, format=...)``).
    - default_format: format used when ``image.format`` is empty.

    Returns:
    - str: base64 encoding of the serialized image bytes.
    """
    buffered = io.BytesIO()
    image.save(buffered, format=image.format or default_format)
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def invoke_with_image(query, image_file=None):
    """Send ``query`` (optionally with one image) to the OpenAI chat API and return the reply text.

    Parameters:
    - query: the text prompt.
    - image_file: optional path or file object accepted by ``PIL.Image.open``;
      when given, the image is embedded in the request as a base64 data URL.

    Returns:
    - str: the content of the first choice's message.
    """
    openai.api_key = KEY
    content = [{"type": "text", "text": query}]

    if image_file is not None:
        # Context manager releases the file handle that Image.open keeps open.
        with Image.open(image_file) as image:
            base64_image = encode_image_to_base64(image)
            # Label the data URL with the image's real MIME type; previously
            # every image (PNG, GIF, ...) was declared as image/jpeg.
            mime_type = Image.MIME.get(image.format, "image/jpeg")
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
        })

    messages = [{"role": "user", "content": content}]
    response = openai.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=messages,
        temperature=0.8,
        max_tokens=1024,
    )
    return response.choices[0].message.content


def save_list_to_text_file(data_list, file_path):
    """Write every element of ``data_list`` to ``file_path``, one element per line.

    Elements are rendered through an f-string (i.e. their ``str`` form), each line
    is newline-terminated (including the last), and any existing file is replaced.
    """
    rendered_lines = (f"{entry}\n" for entry in data_list)
    with open(file_path, 'w') as handle:
        handle.writelines(rendered_lines)


def merge_communities_in_graph(graph, clusters, communities_to_merge):
    """
    Merge groups of communities and store the relabelled membership on the graph.

    Parameters:
    - graph: graph whose ``vs['community']`` vertex attribute is overwritten.
    - clusters: clustering object exposing a ``membership`` list (one community ID per vertex);
      it is not modified.
    - communities_to_merge: a list of lists; each inner list holds community IDs that
      become a single community.

    Returns:
    - final_membership: list of ints, one per vertex, renumbered so IDs are consecutive
      from 0 while keeping the relative order of the surviving IDs.

    Example:
    - communities_to_merge = [[1, 2], [3, 4]] merges 1 with 2 and 3 with 4, giving two
      separate merged communities.
    """
    # Every ID in a group is redirected to the group's smallest ID. If an ID is listed
    # in several groups, the group processed last wins.
    redirect = {}
    for group in communities_to_merge:
        target_id = min(group)
        redirect.update((old_id, target_id) for old_id in group)

    merged = [redirect.get(cid, cid) for cid in clusters.membership]

    # Compact the surviving IDs to 0..k-1.
    dense_id = {cid: rank for rank, cid in enumerate(sorted(set(merged)))}
    final_membership = [dense_id[cid] for cid in merged]

    graph.vs['community'] = final_membership
    return final_membership


def merge_communities(graph, initial_communities, target_n_communities):
    """Greedily merge communities until at most ``target_n_communities`` remain.

    Each round takes the smallest community (the first one on ties) and merges it
    into the community sharing the most edges with it (the first one on ties). The
    IDs above the removed community are shifted down so IDs stay consecutive.

    Parameters:
    - graph: igraph graph the clustering belongs to.
    - initial_communities: ``ig.VertexClustering`` over ``graph``.
    - target_n_communities: desired number of communities.

    Returns:
    - ``ig.VertexClustering`` with at most ``target_n_communities`` communities, or a
      single community if the target is below 1.
    """
    while len(initial_communities) > target_n_communities:
        membership = initial_communities.membership
        sizes = initial_communities.sizes()
        smallest = sizes.index(min(sizes))

        # Count edges between the smallest community and every other community.
        community_edges = {i: 0 for i in range(len(sizes)) if i != smallest}
        if not community_edges:
            break  # only one community left; nothing to merge into
        for u, v in graph.get_edgelist():
            cu = membership[u]
            cv = membership[v]
            if cu == smallest and cv != smallest:
                community_edges[cv] += 1
            elif cv == smallest and cu != smallest:
                community_edges[cu] += 1

        closest = max(community_edges, key=community_edges.get)

        # Once `smallest` disappears, every ID above it shifts down by one. The
        # merged vertices must land on the *shifted* ID of `closest`. Otherwise they
        # end up in the community that used to follow `closest` (the original bug).
        merged_id = closest - 1 if closest > smallest else closest

        new_membership = []
        for cid in membership:
            if cid == smallest:
                new_membership.append(merged_id)
            elif cid > smallest:
                new_membership.append(cid - 1)
            else:
                new_membership.append(cid)

        initial_communities = ig.VertexClustering(graph, membership=new_membership)

    return initial_communities


def top_nodes(metric, seed_size):
    """Return the indices of the ``seed_size`` highest-scoring nodes.

    ``metric`` is an iterable of comparable scores, one per node in index order
    (e.g. a degree or betweenness list). The result is ordered from highest to
    lowest score. The sort is stable, so tied scores keep ascending index order.
    """
    scores = list(metric)
    ranked_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return ranked_indices[:seed_size]


def top_byCommunity(graph, ratio, communities, metric='degree', verbose=True):
    """Collect the highest-ranked nodes of every community.

    Within each community, nodes are sorted by the chosen centrality in descending
    order. The sort is stable, so ties keep the community's node order. The first
    ``max(int(len(community) * ratio), 1)`` nodes are kept, so every community
    contributes at least one node.

    Parameters:
    - graph: igraph graph. Only the centrality named by ``metric`` is computed.
      Previously betweenness and closeness were always computed and then discarded.
    - ratio: fraction of each community to select (e.g. 0.05 for the top 5%).
    - communities: iterable of communities, each an iterable of vertex indices.
    - metric: 'degree' (default, the original behaviour), 'betweenness' or 'closeness'.
    - verbose: when True (default), print each community's selected nodes, as before.

    Returns:
    - list of selected vertex indices, grouped community by community.

    Raises:
    - ValueError: if ``metric`` is not a supported centrality name.
    """
    if metric not in ('degree', 'betweenness', 'closeness'):
        raise ValueError(f"metric must be 'degree', 'betweenness' or 'closeness', got {metric!r}")
    scores = getattr(graph, metric)()

    selected = []
    for community in communities:
        ranked = sorted(community, key=scores.__getitem__, reverse=True)
        cutoff = max(int(len(ranked) * ratio), 1)  # always keep at least one node
        chosen = ranked[:cutoff]
        if verbose:
            print(chosen)
        selected.extend(chosen)

    return selected


def compute_column_averages(list_of_lists):
    """Return the arithmetic mean of each column of a list of numeric rows.

    Columns are formed with ``zip``, so every row is truncated to the length of
    the shortest row. An empty input, or rows with no columns, yields ``[]``.
    Each column produced by ``zip`` has one value per row and is never empty.
    """
    return [sum(column) / len(column) for column in zip(*list_of_lists)]


def precompute_metrics(graph):
    """Compute the per-vertex centralities used to rank neighbours in the local search.

    Returns a ``(degrees, betweenness)`` tuple holding the results of
    ``graph.degree()`` and ``graph.betweenness()``.
    """
    # Tuple elements are evaluated left to right, so degree() still runs first.
    return graph.degree(), graph.betweenness()


def local_search_influence_maximization(graph, initial_seeds, max_iter, model):
    """Refine a seed set by swapping seeds for high-ranked neighbours (first-improvement search).

    Each round visits the current seeds in order. For each seed, its neighbours are
    ranked by degree or by betweenness, chosen at random per seed. The best neighbour
    not already a seed is swapped in, and the swap is kept as soon as it raises the
    estimated spread. The search stops after ``max_iter`` rounds, or earlier when a
    full round brings no improvement.

    Parameters:
    - graph: igraph graph.
    - initial_seeds: iterable of vertex indices; duplicates are removed.
    - max_iter: maximum number of improvement rounds.
    - model: diffusion model, one of 'IC', 'LT' or 'SI'.

    Returns:
    - list of vertex indices (the refined seed set; order is not meaningful).

    Raises:
    - ValueError: if ``model`` is not 'IC', 'LT' or 'SI'. Previously this surfaced
      as an UnboundLocalError.
    """
    if model not in ('IC', 'LT', 'SI'):
        raise ValueError(f"Unsupported diffusion model {model!r}; expected 'IC', 'LT' or 'SI'")

    def estimate_spread(seeds, mode):
        # IC/LT take a simulation-budget mode; SI_objective has a fixed budget.
        if model == 'IC':
            return IC(graph, seeds, mode)
        if model == 'LT':
            return LT(graph, seeds, mode)
        return SI_objective(graph, seeds)

    degrees, betweenness = precompute_metrics(graph)
    current_seeds = list(set(initial_seeds))
    # The baseline uses the 'evaluation' budget while candidates use the cheaper
    # 'optimization' budget, as in the original search.
    best_edv = estimate_spread(current_seeds, 'evaluation')

    for _ in range(max_iter):
        improved = False
        seed_set = set(current_seeds)  # O(1) membership tests below
        for seed in current_seeds:
            neighbors = list(graph.neighbors(seed))

            # Randomly alternate between the two rankings to diversify the search.
            ranking = degrees if random.choice([True, False]) else betweenness
            sorted_neighbors = sorted(neighbors, key=ranking.__getitem__, reverse=True)

            # Highest-ranked neighbour that is not already a seed.
            selected_neighbor = next((nb for nb in sorted_neighbors if nb not in seed_set), None)
            if selected_neighbor is None:
                continue

            new_seeds = list((seed_set - {seed}) | {selected_neighbor})
            new_edv = estimate_spread(new_seeds, 'optimization')

            if new_edv > best_edv:
                current_seeds = new_seeds
                best_edv = new_edv
                improved = True
                break  # restart the round from the updated seed set
        if not improved:
            break

    return current_seeds


def SI_objective(graph, initial_infected, beta=0.1, num_simulations=100, num_steps=5):
    """Monte-Carlo estimate of SI-model spread, used as a search objective.

    Runs ``num_simulations`` independent discrete-time SI simulations of
    ``num_steps`` steps each. In every step, each infected node tries to infect
    each susceptible neighbour with probability ``beta``. New infections take
    effect at the end of the step. The cumulative infected count per step is
    averaged over the simulations.

    Parameters:
    - graph: graph exposing ``vcount()`` and ``neighbors(v)`` (igraph-style).
    - initial_infected: iterable of vertex indices infected at step 0;
      duplicates are counted once.
    - beta, num_simulations, num_steps: simulation settings. The defaults
      match the previous hard-coded values.

    Returns:
    - float: mean, over steps 1..num_steps, of the average cumulative infected count.

    Raises:
    - ValueError: if ``num_simulations`` or ``num_steps`` is below 1 (the mean
      would be undefined).
    """
    if num_simulations < 1:
        raise ValueError("num_simulations must be at least 1")
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")

    n = graph.vcount()
    seeds = set(initial_infected)
    # Neighbour lists do not change during the run: fetch them once instead of
    # querying the graph for every node, step and simulation.
    adjacency = [graph.neighbors(v) for v in range(n)]
    total_infected_counts = np.zeros(num_steps + 1)

    for _ in range(num_simulations):
        infected = [v in seeds for v in range(n)]
        infected_counts = np.zeros(num_steps + 1)
        infected_counts[0] = len(seeds)

        for step in range(1, num_steps + 1):
            new_infected = set()
            for node in range(n):
                if infected[node]:
                    for neighbor in adjacency[node]:
                        if not infected[neighbor] and np.random.rand() < beta:
                            new_infected.add(neighbor)

            # Synchronous update: infections from this step spread from the next step on.
            for node in new_infected:
                infected[node] = True
            infected_counts[step] = infected_counts[step - 1] + len(new_infected)

        total_infected_counts += infected_counts

    # Step 0 (the seeds alone) is excluded from the objective.
    return np.mean(total_infected_counts[1:] / num_simulations)


def simulate_SI_model(graph, initial_infected, num_steps, beta=0.1):
    """Run one discrete-time SI simulation and return the cumulative infected count per step.

    In every step, each infected node tries to infect each susceptible neighbour
    with probability ``beta``. New infections take effect at the end of the step.

    Parameters:
    - graph: graph exposing ``vcount()`` and ``neighbors(v)`` (igraph-style).
    - initial_infected: iterable of vertex indices infected at step 0;
      duplicates are counted once.
    - num_steps: number of simulated steps.
    - beta: infection probability per contact (default 0.1, as before).

    Returns:
    - numpy array of length ``num_steps + 1``. Entry ``t`` is the number of
      infected nodes after step ``t``; entry 0 is the number of distinct seeds.
    """
    n = graph.vcount()
    seeds = set(initial_infected)
    # Neighbour lists do not change during the run; fetch them once.
    adjacency = [graph.neighbors(v) for v in range(n)]
    infected = [v in seeds for v in range(n)]
    infected_counts = np.zeros(num_steps + 1)
    infected_counts[0] = len(seeds)

    for step in range(1, num_steps + 1):
        new_infected = set()
        for node in range(n):
            if infected[node]:
                for neighbor in adjacency[node]:
                    if not infected[neighbor] and np.random.rand() < beta:
                        new_infected.add(neighbor)

        for node in new_infected:
            infected[node] = True
        infected_counts[step] = infected_counts[step - 1] + len(new_infected)

    return infected_counts


def average_simulation_results(graph, initial_seed, num_simulations=100, num_steps=30):
    """Average ``simulate_SI_model`` trajectories over repeated runs.

    Parameters:
    - graph: graph passed through to ``simulate_SI_model``.
    - initial_seed: iterable of initially infected vertex indices.
    - num_simulations: number of runs to average (default 100, as before).
    - num_steps: steps per run (default 30, as before).

    Returns:
    - list of ``num_steps + 1`` floats: the mean cumulative infected count at each step.

    Raises:
    - ValueError: if ``num_simulations`` is below 1.
    """
    if num_simulations < 1:
        raise ValueError("num_simulations must be at least 1")

    totals = np.zeros(num_steps + 1)
    for _ in range(num_simulations):
        totals += simulate_SI_model(graph, initial_seed, num_steps)

    return (totals / num_simulations).tolist()


def IC(G, seed_set, mode, activation_prob=0.1):
    """Monte-Carlo estimate of Independent Cascade spread of ``seed_set`` on ``G``.

    Parameters:
    - G: igraph graph. It is converted with ``igraph_to_networkx_directed``, so each
      edge is usable in both directions.
    - seed_set: iterable of igraph vertex indices.
    - mode: 'evaluation' (100000 simulations) or 'optimization' (5000 simulations).
    - activation_prob: uniform edge activation probability (default 0.1, as before).

    Returns:
    - float: average of cynetdiff's ``get_num_activated_nodes()`` over all simulations.

    Raises:
    - ValueError: for an unknown ``mode``. Previously this surfaced as an
      UnboundLocalError.
    """
    simulations_by_mode = {'evaluation': 100000, 'optimization': 5000}
    if mode not in simulations_by_mode:
        raise ValueError(f"mode must be 'evaluation' or 'optimization', got {mode!r}")
    n_sim = simulations_by_mode[mode]

    nx_graph = igraph_to_networkx_directed(G)

    # Build a cynetdiff Independent Cascade model from the NetworkX graph.
    model, node_mapping = networkx_to_ic_model(nx_graph, activation_prob=activation_prob)

    # When cynetdiff relabels nodes it returns a mapping from original labels to
    # model indices; seeds must be translated, or the wrong nodes get seeded.
    if node_mapping is not None:
        seed_set = [node_mapping[node] for node in seed_set]
    model.set_seeds(seed_set)

    total_activated = 0.0
    for _ in range(n_sim):
        model.reset_model()  # start each simulation from the seed set only
        model.advance_until_completion()
        total_activated += model.get_num_activated_nodes()

    return total_activated / n_sim


def igraph_to_networkx_directed(ig_graph):
    """
    Convert an igraph graph to a NetworkX DiGraph whose edges point both ways.

    Every vertex 0..vcount()-1 is added first. This keeps isolated vertices, which
    have no edges and were previously dropped, and makes the NetworkX node order
    match the igraph vertex ids. Each igraph edge (u, v) becomes the two arcs
    u->v and v->u, so edge direction in a directed input is not preserved. Edge
    attributes are not copied, and parallel edges collapse into a single arc.
    """
    G = nx.DiGraph()
    G.add_nodes_from(range(ig_graph.vcount()))

    for u, v in ig_graph.get_edgelist():
        G.add_edge(u, v)  # original direction
        G.add_edge(v, u)  # opposite direction
    return G


def LT(G, seed_set, mode):
    """Monte-Carlo estimate of Linear Threshold spread of ``seed_set`` on ``G``.

    Parameters:
    - G: igraph graph. It is converted with ``igraph_to_networkx_directed``, so each
      edge is usable in both directions.
    - seed_set: iterable of igraph vertex indices.
    - mode: 'evaluation' (100000 simulations) or 'optimization' (5000 simulations).

    Returns:
    - float: average of cynetdiff's ``get_num_activated_nodes()`` over all simulations.

    Raises:
    - ValueError: for an unknown ``mode``. Previously this surfaced as an
      UnboundLocalError.
    """
    simulations_by_mode = {'evaluation': 100000, 'optimization': 5000}
    if mode not in simulations_by_mode:
        raise ValueError(f"mode must be 'evaluation' or 'optimization', got {mode!r}")
    n_sim = simulations_by_mode[mode]

    nx_graph = igraph_to_networkx_directed(G)

    # Build a cynetdiff Linear Threshold model (cynetdiff's default weights/thresholds).
    model, node_mapping = networkx_to_lt_model(nx_graph)

    # When cynetdiff relabels nodes it returns a mapping from original labels to
    # model indices; seeds must be translated, or the wrong nodes get seeded.
    if node_mapping is not None:
        seed_set = [node_mapping[node] for node in seed_set]
    model.set_seeds(seed_set)

    total_activated = 0.0
    for _ in range(n_sim):
        model.reset_model()  # start each simulation from the seed set only
        model.advance_until_completion()
        total_activated += model.get_num_activated_nodes()

    return total_activated / n_sim


def graph_positions_GD(g, file_path):
    """Load precomputed node coordinates from JSON and build edge coordinate lists for Plotly.

    The JSON file must hold ``node_x`` and ``node_y`` lists indexed by vertex id.
    It is read unconditionally, so a missing file raises ``FileNotFoundError``.

    Returns:
    - ``[node_x, node_y], [edge_x, edge_y]``. The edge lists hold both endpoint
      coordinates of every edge followed by ``None``, so each edge is drawn as a
      separate segment.
    """
    with open(file_path, 'r') as file:
        saved = json.load(file)
        node_x = saved['node_x']
        node_y = saved['node_y']

    coords = np.column_stack((node_x, node_y))
    edge_x, edge_y = [], []
    for edge in g.es:
        head, tail = edge.tuple
        head_x, head_y = coords[head]
        tail_x, tail_y = coords[tail]
        edge_x.extend((head_x, tail_x, None))
        edge_y.extend((head_y, tail_y, None))

    return [node_x, node_y], [edge_x, edge_y]


def distance(position, node1, node2):
    """Return the Euclidean distance between two nodes of a ``graph_positions_*`` result.

    ``position[0]`` must be ``[node_x, node_y]``, the first element returned by
    ``graph_positions_GD`` / ``graph_positions_old``.
    """
    xs = position[0][0]
    ys = position[0][1]
    dx = xs[node1] - xs[node2]
    dy = ys[node1] - ys[node2]
    return np.sqrt(dx ** 2 + dy ** 2)


def load_community(graph):
    """Load the pickled communities of dataset ``graph`` from ``result/community/``.

    The path is relative to the current working directory. ``pickle.load`` can
    execute arbitrary code, so only load files this project wrote itself.
    """
    community_file = "result/community/" + graph + "_communities.pkl"
    with open(community_file, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded




def graph_positions_old(g, style):
    """Compute a graph layout and build node/edge coordinate lists for Plotly.

    Both ``random`` and ``numpy.random`` are reseeded with 42 before the layout is
    computed, for reproducibility. This reseeding is a global side effect.

    Parameters:
    - g: graph exposing ``layout(style)`` and ``es``.
    - style: layout name passed to ``g.layout`` (e.g. 'fr' for Fruchterman-Reingold, or 'kk').

    Returns:
    - ``[node_x, node_y], [edge_x, edge_y]``. The edge lists hold both endpoint
      coordinates of every edge followed by ``None``.
    """
    random.seed(42)
    np.random.seed(42)
    coords = g.layout(style)

    edge_x, edge_y = [], []
    for edge in g.es:
        src, dst = edge.tuple
        edge_x.extend((coords[src][0], coords[dst][0], None))
        edge_y.extend((coords[src][1], coords[dst][1], None))

    node_x = [xy[0] for xy in coords]
    node_y = [xy[1] for xy in coords]
    return [node_x, node_y], [edge_x, edge_y]


def generate_incremental_filename(directory, base_filename):
    """Return the first path ``<directory>/<base_filename>_<i>.png`` (i = 1, 2, ...) that does not exist.

    The existence check is not atomic: another process may create the same file
    before the caller writes it.
    """
    index = 1
    while True:
        candidate = os.path.join(directory, f"{base_filename}_{index}.png")
        if not os.path.exists(candidate):
            return candidate
        index += 1


def clear_folder(folder_path):
    """Delete every file, symlink and sub-directory inside ``folder_path`` (best effort).

    The folder itself is kept. Any entry that cannot be removed is reported
    with a printed message, and the remaining entries are still processed.
    """
    for entry_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, entry_name)
        try:
            removable_as_file = os.path.isfile(file_path) or os.path.islink(file_path)
            if removable_as_file:
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Deliberately best-effort: report and continue with the next entry.
            print(f'Failed to delete {file_path}. Reason: {e}')


def graph_read(GrapH):
    """Load one of the benchmark graphs by name.

    All paths are relative to the current working directory, which must contain
    the ``CSD dataset/`` folder.

    Parameters:
    - GrapH: dataset name: 'dolphins', 'karate', 'lesmis', 'polbooks', 'power',
      'sex', 'router' or 'facebook'.

    Returns:
    - igraph ``Graph``. 'sex' and 'router' are passed through
      ``simplify(combine_edges='sum')``; 'facebook' is read as an undirected edge list.

    Raises:
    - ValueError: for an unknown dataset name. Previously this surfaced as an
      UnboundLocalError.
    """
    plain_gml = {
        'dolphins': r"CSD dataset/dolphins/dolphins.gml",
        'karate': r"CSD dataset/karate/karate.gml",
        'lesmis': r"CSD dataset/lesmis/lesmis.gml",
        'polbooks': r"CSD dataset/polbooks/polbooks.gml",
        # Was an absolute path on one developer's machine; use the same
        # relative layout as every other dataset.
        'power': r"CSD dataset/power/power.gml",
    }
    simplified_gml = {
        'sex': r"CSD dataset/sex.gml",
        'router': r"CSD dataset/router.gml",
    }
    ncol_files = {
        'facebook': r"CSD dataset/facebook.txt",
    }

    if GrapH in plain_gml:
        return Graph.Read_GML(plain_gml[GrapH])
    if GrapH in simplified_gml:
        # Merge parallel edges, summing their edge attributes.
        return Graph.Read_GML(simplified_gml[GrapH]).simplify(combine_edges='sum')
    if GrapH in ncol_files:
        return Graph.Read_Ncol(ncol_files[GrapH], directed=False)

    known = sorted({**plain_gml, **simplified_gml, **ncol_files})
    raise ValueError(f"Unknown graph name {GrapH!r}; expected one of {known}")


