import os
import argparse
import random
import pandas as pd
from collections import deque

def create_grounded_coloring_program(N, edges, add_comments):
    """Build the lines of a grounded probabilistic 3-coloring program.

    Args:
        N: Number of nodes; node facts and coloring rules are emitted for
            the ids 1..N inclusive.
        edges: Sequence of (source, dest, prob) triples. It is traversed
            three times, so it must be re-iterable (e.g. a list).
        add_comments: When truthy, a header comment and a blank line are
            placed at the top of the program.

    Returns:
        A list of program lines (no trailing newlines), in this order:
        optional header, probabilistic edge facts, node facts, coloring
        rules, symmetric ``e/2`` rules, and same-color constraints.
    """
    palette = ("red", "green", "blue")
    node_ids = range(1, N + 1)

    program = [f"% Grounded Coloring Program for N={N}", ""] if add_comments else []

    # One probabilistic fact per input triple, in input order.
    program.extend(f"{prob}::edge({source},{dest})." for source, dest, prob in edges)

    # Every node exists with certainty.
    program.extend(f"1.0::node({i})." for i in node_ids)
    program.append("")

    # A node takes a color exactly when it takes neither of the other two.
    for i in node_ids:
        for color in palette:
            first_other, second_other = (c for c in palette if c != color)
            program.append(
                f"{color}({i}) :- node({i}), not {first_other}({i}), not {second_other}({i})."
            )
    program.append("")

    # Each directed edge fact induces adjacency in both directions.
    for source, dest, _ in edges:
        for a, b in ((source, dest), (dest, source)):
            program.append(f"e({a},{b}) :- edge({source},{dest}).")
    program.append("")

    # Forbid equal colors on adjacent nodes; each ordered pair gets its
    # constraints only once even if it shows up through several edges.
    constrained_pairs = set()
    for source, dest, _ in edges:
        for a, b in ((source, dest), (dest, source)):
            if (a, b) in constrained_pairs:
                continue
            constrained_pairs.add((a, b))
            program.extend(
                f":- e({a},{b}), {color}({a}), {color}({b})." for color in palette
            )

    return program

def build_graph(df):
    """
    Groups the rows of an edge table into a directed adjacency mapping.

    `df` must expose the columns 'source', 'target' and 'rating'; rows are
    read with `iterrows()`, so neighbours keep the row order of the table.
    Returns a dictionary: source -> list of (target, scaled_rating), where
    scaled_rating is (rating + 11) / 22.
    """
    adjacency = {}
    for _, record in df.iterrows():
        origin, destination, score = (
            record[column] for column in ('source', 'target', 'rating')
        )
        # Shift-and-scale the raw rating; presumably ratings lie in
        # [-10, 10], which would keep the result strictly inside (0, 1)
        # -- TODO confirm against the dataset.
        weight = (score + 11) / 22
        adjacency.setdefault(origin, []).append((destination, weight))
    return adjacency

def snowball_sample(graph, start_node=None, max_nodes=100):
    """
    Performs snowball (breadth-first) sampling starting from a given node.

    Args:
        graph: Mapping source -> list of (target, rating) pairs.
        start_node: Node to expand from. When None, one key of `graph` is
            drawn with `random.choice`.
        max_nodes: Expansion stops as soon as this many distinct nodes have
            been discovered (the start node counts as the first one).

    Returns:
        A list of sampled (source, target, rating) edges in which the
        original node labels are replaced by consecutive integer ids,
        assigned in discovery order with the start node as 1. The list is
        empty when the graph is empty, when the start node has no outgoing
        edges in `graph`, or when `max_nodes` is 1 or less.
    """
    if start_node is None:
        # random.choice raises IndexError on an empty sequence; an empty
        # graph simply has nothing to sample.
        if not graph:
            return []
        start_node = random.choice(list(graph.keys()))

    # Doubles as the "visited" set: a node is discovered exactly when it
    # receives an id.
    node_mapping = {start_node: 1}
    queue = deque([start_node])
    sampled_edges = []

    while queue and len(node_mapping) < max_nodes:
        current = queue.popleft()
        for neighbor, rating in graph.get(current, []):
            if neighbor not in node_mapping:
                node_mapping[neighbor] = len(node_mapping) + 1
                queue.append(neighbor)
            # Edges to already-discovered nodes are kept as well.
            sampled_edges.append((node_mapping[current], node_mapping[neighbor], rating))
            # The edge that reaches the node budget is still recorded, then
            # the remaining neighbours of `current` are skipped.
            if len(node_mapping) >= max_nodes:
                break

    return sampled_edges

def snowball_sampling(bitcoin_data, N):
    """
    Samples edges from an edge table in one step.

    Builds the adjacency mapping for `bitcoin_data` with `build_graph` and
    runs `snowball_sample` on it from a randomly chosen start node, using
    `N` as the `max_nodes` limit. Returns whatever `snowball_sample`
    returns.
    """
    return snowball_sample(build_graph(bitcoin_data), max_nodes=N)

if __name__ == '__main__':
    # Command-line entry point: for every graph size in [beginning, end],
    # sample a subgraph from the dataset and write one .pasp program.
    cli = argparse.ArgumentParser(description='Create Grounded Coloring Problem programs.')
    cli.add_argument('beginning', type=int, help='Start size of the graph')
    cli.add_argument('end', type=int, nargs='?', default=None, help='End size of the graph (optional, defaults to beginning)')
    cli.add_argument('-c', action='store_true', help='Add comments to output')
    cli.add_argument('--bitcoin', type=str, required=True, help='Path to the Bitcoin dataset')
    cli.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility')
    options = cli.parse_args()

    first_size = options.beginning
    # A missing end size means a single program of size `beginning`.
    last_size = first_size if options.end is None else options.end

    # Seed once up front so the whole batch of samples is reproducible.
    if options.seed is not None:
        random.seed(options.seed)

    # header=0 together with names= replaces the file's first row with
    # these column names.
    bitcoin_data = pd.read_csv(
        options.bitcoin,
        names=['source', 'target', 'rating', 'time'],
        header=0,
    )

    for N in range(first_size, last_size + 1):
        # Each size gets its own output directory, created on demand.
        out_dir = f"plp/programs/coloring_{N}"
        os.makedirs(out_dir, exist_ok=True)
        out_file = os.path.join(out_dir, f"coloring_{N}.pasp")

        # A fresh snowball sample provides the graph structure for this size.
        sampled = snowball_sampling(bitcoin_data, N)
        program = create_grounded_coloring_program(N, sampled, options.c)

        # Lines are newline-separated; no newline follows the last line.
        with open(out_file, 'w') as handle:
            handle.write("\n".join(program))
