from .sbm_loader import load_temporarl_edgelist


import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import sys

from .normal_util import save_object
import timeit
import rrcf
import pandas as pd
from sklearn.preprocessing import normalize
from . import metrics


"""
NOTE: SPOTLIGHT assumes that the node ordering persists over time.
Generate a source or a destination dictionary.
G: the graph at snapshot 0
p: probability of sampling a node into the dictionary

Returns: a dictionary of the selected source or destination nodes for a subgraph
"""


def make_src_dict(G, p):
    """Sample the nodes of ``G`` independently into a membership dictionary.

    Args:
        G: Graph-like object whose ``nodes()`` method yields the candidates.
        p: Inclusion threshold; a node is kept when ``random.random() <= p``.

    Returns:
        dict mapping every kept node to ``1``.
    """
    # Exactly one random draw per node, in the iteration order of G.nodes().
    return {node: 1 for node in G.nodes() if random.random() <= p}


"""
Main algorithm for SPOTLIGHT.
G_times: a list of networkx graphs, one per snapshot, in temporal order
K: the number of subgraphs to track
p: probability of sampling a node into the source set of a subgraph
q: probability of sampling a node into the destination set of a subgraph

Returns: a list of SPOTLIGHT embeddings (np arrays), one per snapshot
"""


def SPOTLIGHT(
    G_times, K, p, q, is_bipartite=False, all_src_nodes=None, all_dst_nodes=None
):
    """Compute one K-dimensional SPOTLIGHT sketch per graph snapshot.

    K random (source set, destination set) pairs are drawn once and reused
    for every snapshot. Entry ``i`` of a snapshot's sketch is the summed
    weight of the edges ``(u, v)`` with ``u`` in the i-th source set and
    ``v`` in the i-th destination set.

    Args:
        G_times: Iterable of graph snapshots exposing ``edges()``,
            ``nodes()`` and ``edges.data("weight", default=1)``.
        K: Number of sketch dimensions (sampled subgraphs).
        p: Threshold for sampling a node into a source set.
        q: Threshold for sampling a node into a destination set.
        is_bipartite: When true, source sets are drawn only from source
            nodes and destination sets only from destination nodes;
            otherwise both are drawn from the full node set.
        all_src_nodes: Optional pre-computed source nodes.
        all_dst_nodes: Optional pre-computed destination nodes. The
            pre-computed collections are used only when both are given.

    Returns:
        list of ``np.ndarray`` of length K, one per snapshot, in order.
    """
    precomputed = all_src_nodes is not None and all_dst_nodes is not None

    if is_bipartite:
        if precomputed:
            src_nodes, dst_nodes = all_src_nodes, all_dst_nodes
        else:
            # Fallback: derive the two partitions from the edge endpoints.
            src_nodes, dst_nodes = set(), set()
            for snapshot in G_times:
                for tail, head in snapshot.edges():
                    src_nodes.add(tail)
                    dst_nodes.add(head)
        src_pool, dst_pool = src_nodes, dst_nodes
    else:
        if precomputed:
            all_nodes = all_src_nodes.union(all_dst_nodes)
        else:
            # Fallback: gather every node that appears in any snapshot.
            all_nodes = set()
            for snapshot in G_times:
                all_nodes.update(snapshot.nodes())
        # Both endpoint sets are sampled from the same node pool.
        src_pool = dst_pool = all_nodes

    # Draw the K source/destination sets: sources first, then destinations,
    # with one random draw per candidate node.
    src_dicts, dst_dicts = [], []
    for _ in range(K):
        src_dicts.append({node: 1 for node in src_pool if random.random() <= p})
        dst_dicts.append({node: 1 for node in dst_pool if random.random() <= q})

    if is_bipartite:
        print(
            f"Bipartite SPOTLIGHT: {len(src_nodes)} source nodes, {len(dst_nodes)} destination nodes"
        )
    else:
        print(f"Non-bipartite SPOTLIGHT: {len(all_nodes)} total nodes")

    # Sketch every snapshot against the fixed sampled subgraphs.
    sl_embs = []
    for snapshot in G_times:
        sketch = np.zeros(K)
        for tail, head, weight in snapshot.edges.data("weight", default=1):
            for slot, (sources, dests) in enumerate(zip(src_dicts, dst_dicts)):
                if tail not in sources or head not in dests:
                    continue
                sketch[slot] += weight
        sl_embs.append(sketch)

    return sl_embs


def make_src_dict_from_nodes(nodes, p):
    """Sample an iterable of nodes independently into a membership dictionary.

    Args:
        nodes: Iterable of hashable node identifiers.
        p: Inclusion threshold; a node is kept when ``random.random() <= p``.

    Returns:
        dict mapping every kept node to ``1``.
    """
    # Exactly one random draw per node, in iteration order.
    return {candidate: 1 for candidate in nodes if random.random() <= p}


def rrcf_offline(X, num_trees=50, tree_size=50):
    """Score every row of ``X`` with a batch robust random cut forest.

    Each tree is built on a random subset of the rows (sampled without
    replacement). A row's score is its collusive displacement (CoDisp)
    averaged over the trees that contain it.

    Args:
        X: Sequence of feature vectors, one per point (converted with
            ``np.asarray``).
        num_trees: Number of trees in the forest.
        tree_size: Lower bound of the per-tree sample size, capped at
            ``len(X)``. The actual size is drawn from
            ``[min(tree_size, n), min(2 * min(tree_size, n), n + 1))``.

    Returns:
        list of float with one average CoDisp per row of ``X``. Rows that
        were never sampled into any tree get ``0.0`` rather than NaN, so
        they cannot be ranked as top anomalies by a later ``argsort``.
    """
    n = len(X)
    X = np.asarray(X)

    # Candidate sample sizes for an individual tree.
    sample_size = min(tree_size, n)
    sample_size_range = range(max(1, sample_size), min(2 * sample_size, n + 1))

    # Ensure we have a valid range
    if len(sample_size_range) == 0:
        sample_size_range = range(1, n + 1)

    # Construct forest
    forest = []
    while len(forest) < num_trees:
        # Select random subsets of points uniformly
        sample_size = np.random.choice(sample_size_range)
        ixs = np.random.choice(n, size=sample_size, replace=False)
        # Add sampled trees to forest
        try:
            # Ensure X[ixs] is 2D
            sample_data = X[ixs].reshape(-1, X.shape[1]) if len(ixs) == 1 else X[ixs]
            forest.append(rrcf.RCTree(sample_data, index_labels=ixs))
        except Exception as e:
            # Surface the failure for debugging, then let the caller decide.
            print(f"Error creating RRCF tree: {e}")
            raise

    # Accumulate CoDisp per point together with the number of trees that
    # actually contained the point.
    codisp_sum = pd.Series(0.0, index=np.arange(n))
    tree_count = np.zeros(n)
    for tree in forest:
        codisp = pd.Series({leaf: tree.codisp(leaf) for leaf in tree.leaves})
        codisp_sum[codisp.index] += codisp
        np.add.at(tree_count, codisp.index.values, 1)

    # Average only where a point was sampled at least once. A plain division
    # would yield NaN (0 / 0) for unsampled points, and NaN sorts last in
    # ``argsort``, i.e. it would be treated as the most anomalous score.
    avg_codisp = codisp_sum.to_numpy(dtype=float, copy=True)
    sampled = tree_count > 0
    avg_codisp[sampled] /= tree_count[sampled]
    return avg_codisp.tolist()


def run_SPOTLIGHT(
    edgefile,
    K=50,
    window=5,
    percent_ranked=0.05,
    use_rrcf=True,
    seed=0,
    is_bipartite=False,
):
    """Run SPOTLIGHT on a temporal edge list and detect anomalous snapshots.

    Args:
        edgefile: Path handed to ``load_temporarl_edgelist``. The sketches
            are pickled to this path with ``".txt"`` replaced by ``".pkl"``.
        K: Number of sketch dimensions.
        window: Unused by this function; kept for interface compatibility.
        percent_ranked: Fraction of snapshots reported as outliers on the
            RRCF path.
        use_rrcf: Score with a robust random cut forest when true, otherwise
            with ``simple_detector``. If RRCF raises, ``simple_detector`` is
            used as a fallback.
        seed: Seed for ``random`` and, when it is an integer, for NumPy's
            global generator so that RRCF sampling is reproducible too.
        is_bipartite: Forwarded to ``SPOTLIGHT``.

    Returns:
        tuple ``(outliers, sl_time, scores)``: detected snapshot indices,
        the SPOTLIGHT sketching time in seconds, and per-snapshot scores.
    """
    random.seed(seed)
    # RRCF sampling and the stabilising noise below draw from np.random, so
    # seeding only ``random`` would leave the results non-reproducible.
    # NumPy accepts seeds in [0, 2**32) only; fold other ints into range.
    if isinstance(seed, (int, np.integer)):
        np.random.seed(int(seed) % (2**32))
    p = 0.2
    q = 0.2

    G_times, all_src_nodes, all_dst_nodes = load_temporarl_edgelist(
        edgefile, draw=False
    )
    start = timeit.default_timer()
    sl_embs = SPOTLIGHT(
        G_times,
        K,
        p,
        q,
        is_bipartite=is_bipartite,
        all_src_nodes=all_src_nodes,
        all_dst_nodes=all_dst_nodes,
    )
    end = timeit.default_timer()
    sl_time = end - start
    print("SPOTLIGHT time: " + str(sl_time) + "\n")
    # NOTE(review): if ``edgefile`` contains no ".txt" the pickle is written
    # to the ``edgefile`` path itself — confirm this is intended.
    save_object(sl_embs, edgefile.replace(".txt", ".pkl"))

    if use_rrcf:
        # Check for NaN/inf values in embeddings
        sl_embs_array = np.array(sl_embs)

        # Remove any NaN/inf values
        if np.isnan(sl_embs_array).any() or np.isinf(sl_embs_array).any():
            print("Warning: Found NaN or Inf values in embeddings, replacing with 0")
            sl_embs_array = np.nan_to_num(
                sl_embs_array, nan=0.0, posinf=0.0, neginf=0.0
            )
            sl_embs = sl_embs_array.tolist()

        # Add small noise to prevent all-zero embeddings that cause RRCF issues
        if np.all(sl_embs_array == 0):
            print(
                "Warning: All embeddings are zero, adding small noise for RRCF stability"
            )
            noise = np.random.normal(0, 1e-6, sl_embs_array.shape)
            sl_embs_array = sl_embs_array + noise
            sl_embs = sl_embs_array.tolist()

        # Check for low variance that can cause RRCF issues. The ndim guard
        # keeps an empty snapshot list (1-D array) from raising IndexError.
        if sl_embs_array.ndim == 2 and sl_embs_array.shape[1] > 1:
            variances = np.var(sl_embs_array, axis=0)
            if np.any(variances < 1e-10):  # Very low variance
                print("Warning: Low variance detected, adding noise for RRCF stability")
                noise = np.random.normal(0, 1e-6, sl_embs_array.shape)
                sl_embs_array = sl_embs_array + noise
                sl_embs = sl_embs_array.tolist()

        start = timeit.default_timer()
        num_trees = 50
        tree_size = 151

        try:
            scores = rrcf_offline(sl_embs, num_trees=num_trees, tree_size=tree_size)
            end = timeit.default_timer()
            a_time = end - start
            print("rrcf time: " + str(a_time) + "\n")
            scores = np.asarray(scores)
            num_ranked = int(scores.shape[0] * percent_ranked)
            if num_ranked > 0:
                # Indices of the highest-scoring snapshots, in ascending order.
                outliers = np.sort(scores.argsort()[-num_ranked:])
            else:
                # ``[-0:]`` would select every snapshot, so report none.
                outliers = np.array([], dtype=np.intp)
        except Exception as e:
            # Deliberate best-effort: any RRCF failure falls back to the
            # simple difference detector instead of aborting the run.
            print(f"RRCF failed: {e}")
            print("Falling back to simple detector...")
            start = timeit.default_timer()
            scores, outliers = simple_detector(sl_embs)
            end = timeit.default_timer()
            a_time = end - start
            print("simple detector time: " + str(a_time) + "\n")

    else:
        start = timeit.default_timer()
        scores, outliers = simple_detector(sl_embs)
        end = timeit.default_timer()
        a_time = end - start
        print("sum predictor time: " + str(a_time) + "\n")

    return outliers, sl_time, scores


def find_anomalies(scores, percent_ranked, initial_window):
    """Return the indices of the highest-scoring entries, in ascending order.

    Args:
        scores: Sequence of per-snapshot anomaly scores (not modified).
        percent_ranked: Fraction of entries to report;
            ``round(len(scores) * percent_ranked)`` indices are returned.
        initial_window: The first ``initial_window + 1`` scores are reset to
            0 before ranking. The extra one accounts for the scores being
            differences between consecutive snapshots.

    Returns:
        np.ndarray of integer indices, sorted ascending. Empty when the
        requested fraction rounds down to zero entries.
    """
    scores = np.array(scores)
    # Slice assignment tolerates series shorter than the warm-up period.
    # The clamp keeps a negative window from zeroing scores at the tail.
    warmup = max(initial_window + 1, 0)
    scores[:warmup] = 0

    num_ranked = int(round(len(scores) * percent_ranked))
    if num_ranked <= 0:
        # ``argsort()[-0:]`` would return every index instead of none.
        return np.array([], dtype=np.intp)

    outliers = scores.argsort()[-num_ranked:]
    outliers.sort()
    return outliers


def simple_detector(sl_embs, plot=False):
    """Detect anomalies from the change in total sketch mass between snapshots.

    Args:
        sl_embs: Sequence of per-snapshot embeddings (array-likes).
        plot: When true, draw the difference scores and save the figure to
            ``simple_diff.pdf``. Nothing is written when false.

    Returns:
        tuple ``(scores, events)``: the list of consecutive differences of
        the per-snapshot sums (the first entry is 0), and the indices
        selected by ``find_anomalies(scores, 0.05, 10)``.
    """
    sums = [np.sum(sl) for sl in sl_embs]

    # The first snapshot has no predecessor, so its difference is 0.
    diffs = [0]
    diffs.extend(curr - prev for prev, curr in zip(sums, sums[1:]))

    events = find_anomalies(diffs, 0.05, 10)
    scores = diffs

    # Only touch matplotlib and the filesystem when plotting was requested;
    # otherwise every call would write a figure file as a side effect.
    if plot:
        plt.plot(scores)
        plt.savefig("simple_diff.pdf")
        plt.close()
    return scores, events


if __name__ == "__main__":
    # Benchmark driver: run SPOTLIGHT several times with different seeds and
    # report accuracy against the known change points plus sketching time.
    # NOTE: the module uses relative imports, so it has to be executed as a
    # module of its package rather than as a standalone script.
    fname = "SBM1000"
    use_rrcf = False
    K = 50

    if use_rrcf:
        print("using robust random cut forest")
    else:
        print("using simple sum predictor")

    real_events = [16, 31, 61, 76, 91, 106, 136]
    accus = []
    sl_times = []
    runs = 5
    for seed in range(runs):
        # run_SPOTLIGHT returns (outliers, sl_time, scores); the scores are
        # not needed for the accuracy summary.
        anomalies, sl_time, _scores = run_SPOTLIGHT(
            fname, K=K, use_rrcf=use_rrcf, seed=seed
        )
        accu = metrics.compute_accuracy(anomalies, real_events)
        accus.append(accu)
        sl_times.append(sl_time)

    accus = np.asarray(accus)
    sl_times = np.asarray(sl_times)
    print(" the mean accuracy is : ", np.mean(accus))
    print(" the std is : ", np.std(accus))

    print(" the mean spotlight time is : ", np.mean(sl_times))
    print(" the std is : ", np.std(sl_times))
