import argparse
import os
import pickle

from dsl.gdl import * 
# Assuming data_loader contains these functions and they return an object
# with 'graphs' and 'graph_to_label' attributes.
from data_loader import load_BACE, load_BBBP, load_Data, Data


def load_dataset(dataset_name):
    """Load a dataset by name through the ``data_loader`` helpers.

    ``"BBBP"`` and ``"BACE"`` are served by their dedicated loaders; any
    other name is passed to ``load_Data``.

    Args:
        dataset_name: Name of the dataset to load.

    Returns:
        Whatever the selected loader returns. The rest of this file reads
        its ``graphs`` and ``graph_to_label`` attributes.
    """
    dedicated_loaders = (
        ("BBBP", load_BBBP),
        ("BACE", load_BACE),
    )
    for known_name, loader in dedicated_loaders:
        if dataset_name == known_name:
            return loader()
    # No dedicated loader: fall back to the generic one.
    return load_Data(dataset_name)


def load_learned_programs(dataset_name, k):
    """Load, de-duplicate and rank the learned GDL programs of a dataset.

    Every ``*.pickle`` file in ``datasets/<dataset_name>/learned_GDL_programs``
    is expected to contain an iterable of
    ``(label, program, score, chosen_graphs)`` tuples. Two entries are
    treated as the same program when they select the same set of graphs;
    only the first one encountered is kept. The survivors are ordered by
    score, highest first, and truncated to the requested number.

    As a side effect the selected programs are also written to
    ``embeddings/gdl_programs.pickle`` (the file ``generate_embeddings_tmp``
    reads back).

    Args:
        dataset_name: Dataset whose learned programs should be loaded.
        k: Number of top-scoring programs to keep. ``0`` means "keep as many
            programs as there are distinct labels in the pickle files".

    Returns:
        A pair ``(gdl_programs, chosen_graphs_per_program)`` of parallel
        lists: the kept programs and, for each, the graphs it selects.

    Raises:
        FileNotFoundError: If the learned-programs directory does not exist.
    """
    program_dir = os.path.join("datasets", dataset_name, "learned_GDL_programs")
    if not os.path.isdir(program_dir):
        raise FileNotFoundError(f"Learned programs directory not found at: {program_dir}")

    # os.listdir() returns entries in arbitrary order. Sorting keeps the
    # "first occurrence wins" de-duplication below reproducible across runs.
    program_files = sorted(
        f for f in os.listdir(program_dir) if f.endswith(".pickle")
    )

    seen_programs = set()
    learned_programs = []
    labels = set()
    for pkl_file in program_files:
        path = os.path.join(program_dir, pkl_file)
        try:
            # NOTE(security): unpickling can execute arbitrary code, so the
            # program directory must only contain trusted files.
            with open(path, "rb") as f:
                learned_tuple_set = pickle.load(f)
        except (pickle.UnpicklingError, EOFError) as e:
            print(f"Warning: Could not read pickle file {path}. Skipping. Error: {e}")
            continue

        for learned_tuple in learned_tuple_set:
            # learned_tuple is expected to be (label, program, score, chosen_graphs).
            # The frozenset of chosen graphs is the identity of a program.
            program_representation = frozenset(learned_tuple[3])
            # Labels are counted even for duplicates: they drive the k == 0 case.
            labels.add(learned_tuple[0])
            if program_representation not in seen_programs:
                learned_programs.append(learned_tuple)
                seen_programs.add(program_representation)

    # Sort programs by score in descending order.
    learned_programs.sort(key=lambda x: x[2], reverse=True)
    print(f"Found {len(learned_programs)} unique learned programs.")

    limit = len(labels) if k == 0 else k
    top_k_programs = learned_programs[:limit]

    gdl_programs = [p[1] for p in top_k_programs]
    chosen_graphs_per_program = [p[3] for p in top_k_programs]

    # The output directory may not exist yet on a fresh checkout; create it
    # before writing instead of failing with FileNotFoundError.
    cache_dir = "embeddings"
    os.makedirs(cache_dir, exist_ok=True)
    with open(os.path.join(cache_dir, "gdl_programs.pickle"), "wb") as f:
        pickle.dump(gdl_programs, f)

    return gdl_programs, chosen_graphs_per_program


def create_embeddings(data, chosen_graphs_per_program):
    """Build binary feature rows (X) and labels (Y) for every graph.

    Args:
        data: Object exposing ``graphs`` (only its length is used) and
            ``graph_to_label``, indexed by graph position.
        chosen_graphs_per_program: One container of graph indices per
            program; each program contributes one feature column.

    Returns:
        ``(X, Y)`` where ``X[i][j]`` is 1 when graph ``i`` is among the
        graphs chosen by program ``j`` (0 otherwise) and ``Y[i]`` is the
        label of graph ``i``.
    """
    graph_count = len(data.graphs)
    rows = [[] for _ in range(graph_count)]
    targets = [data.graph_to_label[position] for position in range(graph_count)]

    # Append one column at a time: every program extends each graph's row.
    for selected in chosen_graphs_per_program:
        for position, row in enumerate(rows):
            row.append(int(position in selected))
    return rows, targets


def save_embeddings(dataset_name, X, Y, gdl_programs):
    """Pickle the embeddings, labels and programs into ``embeddings/``.

    Three files are written, each suffixed with the dataset name:
    ``X_<name>.pickle``, ``Y_<name>.pickle`` and
    ``gdl_programs_<name>.pickle``. The directory is created when missing.

    Args:
        dataset_name: Used to build the three file names.
        X: Feature rows to store.
        Y: Labels to store.
        gdl_programs: Programs to store.
    """
    output_dir = "embeddings"
    os.makedirs(output_dir, exist_ok=True)

    artifacts = (
        (f"X_{dataset_name}.pickle", X),
        (f"Y_{dataset_name}.pickle", Y),
        (f"gdl_programs_{dataset_name}.pickle", gdl_programs),
    )
    for file_name, payload in artifacts:
        with open(os.path.join(output_dir, file_name), "wb") as handle:
            pickle.dump(payload, handle)

    print(f"Embeddings saved successfully in '{output_dir}' directory.")


def generate_embeddings(dataset_name, k):
    """Run the whole pipeline: load data, load programs, featurize, save.

    Progress is reported on standard output at every stage.

    Args:
        dataset_name: Dataset to process; forwarded to ``load_dataset``,
            ``load_learned_programs`` and ``save_embeddings``.
        k: Number of programs to use, forwarded to ``load_learned_programs``.
    """
    print(f"--- Starting embedding generation for dataset: {dataset_name} ---")

    # Stage 1: the graphs and their labels.
    data = load_dataset(dataset_name)
    print(f"Loaded dataset with {len(data.graphs)} graphs.")

    # Stage 2: the ranked programs and the graphs each one selects.
    loaded = load_learned_programs(dataset_name, k)
    gdl_programs, graph_selections = loaded
    print(f"Using top {len(gdl_programs)} programs for embedding generation.")

    # Stage 3: one binary feature per program, one row per graph.
    X, labels = create_embeddings(data, graph_selections)
    print(f"Created embeddings of shape ({len(X)}, {len(X[0]) if X else 0}).")

    # Stage 4: persist everything.
    save_embeddings(dataset_name, X, labels, gdl_programs)

    print(f"--- Finished embedding generation for dataset: {dataset_name} ---")

def generate_embeddings_tmp(dataset_name, k):
    """Build a single embedding row from the cached program list.

    Unlike ``generate_embeddings``, the programs are read from
    ``embeddings/gdl_programs.pickle`` and each one is re-evaluated on the
    loaded data with ``eval_GDL_program_on_graphs_GC``. A feature is 1 when
    that evaluation yields exactly one element and 0 otherwise. The only
    label stored is ``data.graph_to_label[0]``.

    NOTE(review): this looks intended for a dataset holding a single
    graph -- confirm with the author.

    Args:
        dataset_name: Dataset to load and to name the output files after.
        k: Unused; the number of programs is whatever the cached file holds.
    """
    print(f"--- Starting embedding generation for dataset: {dataset_name} ---")

    data = load_dataset(dataset_name)
    print(f"Loaded dataset with {len(data.graphs)} graphs.")

    with open('embeddings/gdl_programs.pickle', 'rb') as cache_file:
        gdl_programs = pickle.load(cache_file)

    print(f"Using top {len(gdl_programs)} programs for embedding generation.")

    Y = [data.graph_to_label[0]]
    # One feature per program: did it select exactly one graph?
    feature_row = [
        int(len(eval_GDL_program_on_graphs_GC(program, data)) == 1)
        for program in gdl_programs
    ]
    X = [feature_row]
    print(f"Created embeddings of shape ({len(X)}, {len(X[0]) if X else 0}).")

    save_embeddings(dataset_name, X, Y, gdl_programs)

    print(f"--- Finished embedding generation for dataset: {dataset_name} ---")



def main():
    """Parse the command line and run the embedding generation.

    Options: ``-d/--dataset`` (required) and ``-k/--k`` (int, default 32).
    Exceptions raised by the pipeline are caught here and reported with
    ``print``; they are not re-raised.
    """
    cli = argparse.ArgumentParser(
        description="Generate graph embeddings from learned GDL programs."
    )
    option_specs = (
        (
            ("-d", "--dataset"),
            {"required": True, "help": "Input dataset name (e.g., BBBP, BACE)."},
        ),
        (
            ("-k", "--k"),
            {
                "type": int,
                "default": 32,
                "help": "Number of top k programs to use for features. Default: 32.",
            },
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    options = cli.parse_args()

    try:
        generate_embeddings(options.dataset, options.k)
    except FileNotFoundError as missing:
        print(f"Error: A required file or directory was not found.")
        print(missing)
    except Exception as failure:
        print(f"An unexpected error occurred: {failure}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
