import numpy as np

from itertools import product
from tqdm import tqdm

from templates import flooding
from templates import flooding_utils

import json 
from pathlib import Path
from datetime import datetime

import itertools
import argparse

from initialization_utils import flooding_final_messages, create_node_features, two_hop_coloring

from saving_utils import save_test_samples
# Threshold used to binarize kernel predictions: any score strictly above
# floating-point noise is treated as a 1, everything else as a 0.
eps = 1e-10

if __name__ == "__main__":
    # Script: for every stored graph and every binary source message, run the
    # kernel-based flooding model for a fixed number of rounds and collect each
    # distinct (un-normalized) encoded node state together with its binarized
    # prediction, then save them via save_test_samples.

    parser = argparse.ArgumentParser(description="Generate flooding test samples")
    parser.add_argument("--D", type=int, required=True)
    parser.add_argument("--l", type=int, required=True)
    parser.add_argument(
        "--rounds",
        type=int,
        default=250,
        help="number of message-passing rounds simulated per (graph, message) pair",
    )
    args = parser.parse_args()

    D = args.D
    l = args.l
    num_rounds = args.rounds

    print("Generate flooding test samples")
    # Number of distinct local ids (two_hop_coloring is used with values 1..D^2+1).
    d = D**2 + 1
    file_path = f"l{l}_D{D}_n5to20"
    X, Y_train = flooding.get_dataset(l, d)
    X_train = np.eye(len(X))
    Y_train = np.array(Y_train, dtype=np.float64)

    # X_train is the identity, so X_train @ Y_train == Y_train exactly; doing
    # the product once here avoids an (n x N) @ (N x N) multiply every round.
    weights = X_train @ Y_train
    n_features = weights.shape[1]

    # Columns [0, split_point) of the next state are taken from each node's own
    # prediction; columns [split_point, n_features) are replaced by A @ prediction
    # (i.e. summed over rows of other nodes, weighted by the adjacency matrix).
    split_point = l + 3 * d * l
    if split_point > n_features:
        raise ValueError(
            f"split point {split_point} exceeds the {n_features} output features"
        )

    # Load the graphs from the JSON file
    with open(f"saved_adjacencies/graphs_D{D}.json", "r", encoding="utf-8") as f:
        graphs_data = json.load(f)

    num_graphs = len(graphs_data)
    num_messages = 2**l
    total_cases = num_graphs * num_messages

    counter = 1
    test_samples = []
    seen = set()  # un-normalized encodings already stored (de-duplication key)

    # Loop through each graph's adjacency matrix
    for graph in graphs_data:
        g_idx = graph["graph_id"]
        n = graph["n"]
        A = np.array(graph["adjacency_matrix"])

        # compute true 2-hop local ids (values in 1..D^2+1)
        local_ids = two_hop_coloring(A, D)

        # iterate all binary messages of length l
        for bits in itertools.product([0, 1], repeat=l):
            message = list(bits)
            s = 0  # node 0 is always the flooding source
            # NOTE(review): final_messages is not used below; the call is kept
            # so that any error it raises is unchanged.
            final_messages = flooding_final_messages(A, s, message)
            X_test_list = create_node_features(n, s, l, d, message, local_ids)

            for _ in range(num_rounds):
                X_test_encoded_l = []
                X_test_encoded_nonorm = []
                for x_test in X_test_list:
                    x_e_no_norm, x_enc = flooding_utils.encode_data(x_test, X)
                    X_test_encoded_l.append(x_enc)
                    X_test_encoded_nonorm.append(x_e_no_norm)

                X_test = np.array(X_test_encoded_l, dtype=np.float64)
                Y_pred = X_test @ weights
                Y_pred_round = np.where(Y_pred > eps, 1.0, 0.0)

                for x_t, x_e, y_t in zip(X_test_list, X_test_encoded_nonorm, Y_pred_round):
                    key = tuple(x_e)
                    if key not in seen:
                        # Copy the row: a view would keep the whole
                        # prediction matrix of this round alive.
                        test_samples.append((x_t, x_e, y_t.copy()))
                        seen.add(key)

                # Next node state: own columns before split_point, adjacency-
                # aggregated columns from split_point on. Equivalent to
                # Y @ P_c + A @ Y @ P_m with block-identity projections.
                next_state = Y_pred_round.copy()
                next_state[:, split_point:] = A @ Y_pred_round[:, split_point:]

                X_test_list = [flooding.unflatten_sample(row, l, d) for row in next_state]

            print(f"finished processing {counter} out of {total_cases}")
            counter += 1

    # Take metadata shapes from a stored sample instead of leftover loop
    # variables, which are unbound when no case was processed.
    if test_samples:
        _, sample_x_e, sample_y_t = test_samples[0]
        input_shape = getattr(sample_x_e, "shape", None)
        output_shape = getattr(sample_y_t, "shape", None)
    else:
        input_shape = output_shape = None

    # save final versions with final metadata:
    save_test_samples(
        test_samples,
        file_path,
        input_shape=input_shape,
        output_shape=output_shape,
    )
