from my_util import *

# Gaussian Process with graph Fourier transform, distance-based filtering, RBF kernel
def BOIM_full(G, config, num_iterations, num_of_sims, candidate_size, diffusion_model, number_of_sources, allowed_shortest_distance, number_of_clusters):
    """Pick a source set by Bayesian optimisation over clustered spectral signals.

    Candidate source sets come from ``create_candidate_set_pool_filtering``.
    They are transformed with the inverse of the normalized-Laplacian
    eigenvector matrix and clustered with KMeans. One random signal per
    cluster is simulated to seed a GP surrogate. Each of the
    ``num_iterations`` rounds then draws one signal per cluster, evaluates
    the one with the highest Expected Improvement and appends it to the
    training data. Finally every candidate signal is scored with the
    surrogate's predictive mean and the best one is decoded.

    Args:
        G: Graph handed to networkx, the candidate helpers and the simulators.
        config: Passed through unchanged to ``effectIC`` / ``effectLT``.
        num_iterations: Number of Bayesian-optimisation rounds.
        num_of_sims: Forwarded to the spread estimator.
        candidate_size: Forwarded to ``create_candidate_set_pool_filtering``.
        diffusion_model: ``'ic'`` or ``'lt'``.
        number_of_sources: Size of the source set to decode.
        allowed_shortest_distance: Forwarded to
            ``create_candidate_set_pool_filtering``.
        number_of_clusters: Number of KMeans clusters; also the size of the
            initial design and of each round's choice set.

    Returns:
        The result of ``find_source_set_from_fourier`` for the candidate
        signal with the highest predicted mean.

    Raises:
        NotImplementedError: If ``diffusion_model`` is not ``'ic'`` or ``'lt'``.
    """
    # Resolve the spread estimator once so an unknown model fails before any
    # expensive work and both loops share the same validated dispatch.
    if diffusion_model == 'ic':
        estimate_spread = effectIC
    elif diffusion_model == 'lt':
        estimate_spread = effectLT
    else:
        raise NotImplementedError("Diffusion model not recognized.")

    def fit_surrogate(X, Y):
        # Fit a single-output GP model to the observed data.
        gp = RBFSingleTaskGP(X, Y)
        fit_gpytorch_model(ExactMarginalLogLikelihood(gp.likelihood, gp))
        return gp

    nl = nx.normalized_laplacian_matrix(G)
    _, eig_vect = np.linalg.eigh(nl.todense())
    UT = np.linalg.inv(eig_vect)
    UT_inv = eig_vect

    candidate_sets = create_candidate_set_pool_filtering(G, candidate_size, number_of_sources, allowed_shortest_distance)
    sets_after_fourier_transfer = fourier_transfer_for_all_candidate_set(candidate_sets, UT)

    kmeans = KMeans(n_clusters=number_of_clusters, random_state=0).fit(sets_after_fourier_transfer)

    # Bucket the transformed signals by their cluster label.
    groups = [[] for _ in range(number_of_clusters)]
    for label, signal in zip(kmeans.labels_, sets_after_fourier_transfer):
        groups[label].append(signal)

    # Initial design: one simulated signal per cluster.
    train_X = []
    train_Y = []
    for group in groups:
        selected_signal = random.choice(group)
        source_set = find_source_set_from_fourier(selected_signal, number_of_sources, UT_inv)
        spread, _ = estimate_spread(G, config, source_set, num_of_sims)
        train_X.append(torch.FloatTensor(selected_signal))
        train_Y.append([float(spread)])

    train_X = torch.stack(train_X)
    train_Y = torch.tensor(train_Y)

    model = None
    for _ in range(num_iterations):
        # From each cluster sample one signal; the acquisition function
        # then picks the most promising of these.
        choices = torch.stack(
            [torch.FloatTensor(random.sample(group, 1)[0]) for group in groups]
        )

        model = fit_surrogate(train_X, train_Y)
        acq_func = ExpectedImprovement(model=model, best_f=train_Y.max())

        candidate, _ = optimize_acqf_discrete(
            acq_function=acq_func,
            q=1,  # one candidate per round
            choices=choices)

        selected = find_source_set_from_fourier(candidate[0].tolist(), number_of_sources, UT_inv)
        spread, _ = estimate_spread(G, config, selected, num_of_sims)

        # Append the new observation as a (1, 1) row; building it with the
        # right shape avoids the deprecated non-inplace Tensor.resize.
        new_Y = torch.tensor([[float(spread)]], dtype=train_Y.dtype)
        train_X = torch.cat([train_X, candidate], dim=0)
        train_Y = torch.cat([train_Y, new_Y], dim=0)

    if model is None:
        # No optimisation round ran, so no surrogate exists yet; fit one on
        # the initial design instead of failing on an unbound name.
        model = fit_surrogate(train_X, train_Y)

    # Score every candidate with the surrogate mean. As before, the surrogate
    # is the one fitted at the start of the last round, so it has not seen
    # that round's observation.
    best = float('-inf')
    identified_signal = None
    with torch.no_grad():
        for signal in sets_after_fourier_transfer:
            y_pred = model(torch.FloatTensor([signal])).loc.item()
            if y_pred > best:
                best = y_pred
                identified_signal = signal

    return find_source_set_from_fourier(identified_signal, number_of_sources, UT_inv)

def BOIM_no_GSS(G, config, num_iterations, num_of_sims, candidate_size, diffusion_model, number_of_sources, allowed_shortest_distance, num_initial_samples=20, num_choices=20):
    """Bayesian optimisation over spectral signals without the clustering step.

    Candidate source sets come from ``create_candidate_set_pool_filtering``.
    The surrogate is seeded with ``num_initial_samples`` randomly drawn sets.
    Each round then draws ``num_choices`` random sets, evaluates the one with
    the highest Expected Improvement and adds it to the training data. The
    best *observed* signal is decoded and returned.

    Args:
        G: Graph handed to networkx, the candidate helpers and the simulators.
        config: Passed through unchanged to ``effectIC`` / ``effectLT``.
        num_iterations: Number of Bayesian-optimisation rounds.
        num_of_sims: Forwarded to the spread estimator.
        candidate_size: Forwarded to ``create_candidate_set_pool_filtering``.
        diffusion_model: ``'ic'`` or ``'lt'``.
        number_of_sources: Size of the source set to decode.
        allowed_shortest_distance: Forwarded to
            ``create_candidate_set_pool_filtering``.
        num_initial_samples: Size of the initial design (previously fixed at 20).
        num_choices: Random candidates offered to the acquisition function in
            each round (previously fixed at 20).

    Returns:
        The result of ``find_source_set_from_fourier`` for the training signal
        with the highest observed spread.

    Raises:
        NotImplementedError: If ``diffusion_model`` is not ``'ic'`` or ``'lt'``.
    """
    # Validate the model once so both loops share the same dispatch.
    if diffusion_model == 'ic':
        estimate_spread = effectIC
    elif diffusion_model == 'lt':
        estimate_spread = effectLT
    else:
        raise NotImplementedError("Diffusion model not recognized.")

    nl = nx.normalized_laplacian_matrix(G)
    _, eig_vect = np.linalg.eigh(nl.todense())
    UT = np.linalg.inv(eig_vect)
    UT_inv = eig_vect

    candidate_sets = create_candidate_set_pool_filtering(G, candidate_size, number_of_sources, allowed_shortest_distance)

    # Initial design: random candidate sets with their simulated spread.
    train_X = []
    train_Y = []
    for _ in range(num_initial_samples):
        selected_set = random.sample(candidate_sets, 1)[0]
        selected_signal = create_signal_from_source_set(G, selected_set, UT)
        spread, _ = estimate_spread(G, config, selected_set, num_of_sims)
        train_X.append(torch.FloatTensor(selected_signal))
        train_Y.append([float(spread)])

    train_X = torch.stack(train_X)
    train_Y = torch.tensor(train_Y)

    for _ in range(num_iterations):
        # Draw a fresh random batch of candidate signals for this round.
        batch = []
        for _ in range(num_choices):
            selected_set = random.sample(candidate_sets, 1)[0]
            batch.append(torch.FloatTensor(create_signal_from_source_set(G, selected_set, UT)))
        choices = torch.stack(batch)

        # Fit a single-output GP model to the observed data.
        model = RBFSingleTaskGP(train_X, train_Y)
        fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

        acq_func = ExpectedImprovement(model=model, best_f=train_Y.max())

        candidate, _ = optimize_acqf_discrete(
            acq_function=acq_func,
            q=1,  # one candidate per round
            choices=choices)

        selected = find_source_set_from_fourier(candidate[0].tolist(), number_of_sources, UT_inv)
        spread, _ = estimate_spread(G, config, selected, num_of_sims)

        # Build the new observation directly as a (1, 1) row instead of using
        # the deprecated non-inplace Tensor.resize.
        new_Y = torch.tensor([[float(spread)]], dtype=train_Y.dtype)
        train_X = torch.cat([train_X, candidate], dim=0)
        train_Y = torch.cat([train_Y, new_Y], dim=0)

    # train_Y has one column, so the flat argmax is the row of the best
    # observation (the first one on ties).
    best_row = int(torch.argmax(train_Y))
    # Pass plain floats, the same form the optimisation loop hands to the decoder.
    best_signal = train_X[best_row].tolist()

    return find_source_set_from_fourier(best_signal, number_of_sources, UT_inv)

# Gaussian Process with node selection, distance filtering, RBF kernel


def BOIM_no_fourier(G, config, num_iterations, num_of_sims, candidate_size, diffusion_model, number_of_sources, allowed_shortest_distance, num_initial_samples=20, num_choices=20):
    """Bayesian optimisation over binary node-indicator vectors.

    Source sets are encoded as 0/1 vectors over the ``candidate_size``
    highest-degree nodes of ``G`` instead of spectral signals. The surrogate
    is seeded with ``num_initial_samples`` random candidate sets. Each round
    offers ``num_choices`` random sets to Expected Improvement and simulates
    the winner.

    Args:
        G: Graph handed to the candidate helper and the simulators; its
            ``degree`` view defines the encoding.
        config: Passed through unchanged to ``effectIC`` / ``effectLT``.
        num_iterations: Number of Bayesian-optimisation rounds.
        num_of_sims: Forwarded to the spread estimator.
        candidate_size: Number of top-degree nodes used for the encoding; also
            forwarded to ``create_candidate_set_pool_filtering``.
        diffusion_model: ``'ic'`` or ``'lt'``.
        number_of_sources: Forwarded to ``create_candidate_set_pool_filtering``.
        allowed_shortest_distance: Forwarded to
            ``create_candidate_set_pool_filtering``.
        num_initial_samples: Size of the initial design (previously fixed at 20).
        num_choices: Candidates offered per round (previously fixed at 20).

    Returns:
        List of node ids whose indicator is 1 in the best observed encoding.

    Raises:
        NotImplementedError: If ``diffusion_model`` is not ``'ic'`` or ``'lt'``.
    """
    if diffusion_model == 'ic':
        estimate_spread = effectIC
    elif diffusion_model == 'lt':
        estimate_spread = effectLT
    else:
        raise NotImplementedError("Diffusion model not recognized.")

    candidate_sets = create_candidate_set_pool_filtering(G, candidate_size, number_of_sources, allowed_shortest_distance)

    # Encoding basis: the highest-degree nodes, kept as bare node ids.
    by_degree = sorted(G.degree, key=lambda x: x[1], reverse=True)
    candidate_nodes = [node for node, _ in by_degree[:candidate_size]]

    def encode(source_set):
        # 1 where the basis node belongs to the source set, 0 elsewhere.
        return [1 if node in source_set else 0 for node in candidate_nodes]

    # Initialise the GP model with several (encoding, spread) pairs.
    rows = []
    spreads = []
    for index in random.sample(range(len(candidate_sets)), num_initial_samples):
        source_set = candidate_sets[index]
        spread, _ = estimate_spread(G, config, source_set, num_of_sims)
        rows.append(encode(source_set))
        spreads.append([float(spread)])

    train_X = torch.tensor(rows, dtype=torch.double)
    train_Y = torch.tensor(spreads, dtype=torch.double)

    for _ in range(num_iterations):
        # Fit a single-output GP model to the observed data.
        model = RBFSingleTaskGP(train_X, train_Y)
        fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

        acq_func = ExpectedImprovement(model=model, best_f=train_Y.max())

        # Random candidate sets for this round, encoded as a
        # (num_choices, d) tensor of discrete choices.
        samples = random.sample(candidate_sets, num_choices)
        choices = torch.tensor([encode(source_set) for source_set in samples], dtype=torch.double)

        candidate, _ = optimize_acqf_discrete(
            acq_function=acq_func,
            q=1,  # one candidate per round
            choices=choices)

        # Decode the chosen indicator vector back to node ids. The previous
        # code appended (node, degree) tuples here, unlike the node ids used
        # for the initial design and the final result.
        flags = candidate.reshape(-1).tolist()
        selected = [node for node, flag in zip(candidate_nodes, flags) if flag == 1]

        spread, _ = estimate_spread(G, config, selected, num_of_sims)

        # Append as explicitly shaped double rows; this avoids the deprecated
        # Tensor.resize and mixing float32 into the float64 training data.
        new_Y = torch.tensor([[float(spread)]], dtype=train_Y.dtype)
        train_X = torch.cat([train_X, candidate.reshape(1, -1)], dim=0)
        train_Y = torch.cat([train_Y, new_Y], dim=0)

    # Row of the best observed spread (first one on ties).
    best_row = int(torch.argmax(train_Y))
    best_flags = train_X[best_row].tolist()

    return [node for node, flag in zip(candidate_nodes, best_flags) if int(flag) == 1]

# Gaussian Process with graph Fourier transform, no filtering, RBF kernel

def BOIM_no_filtering(G, config, num_iterations, num_of_sims, candidate_size, diffusion_model, number_of_sources, num_initial_samples=20, num_choices=20):
    """Bayesian optimisation over spectral signals with an unfiltered pool.

    Candidate source sets come from ``create_candidate_set_pool``. The
    surrogate is seeded with ``num_initial_samples`` random sets. Each round
    offers ``num_choices`` random sets, as spectral signals, to Expected
    Improvement and simulates the winner. The best *observed* signal is
    decoded and returned.

    Args:
        G: Graph handed to networkx, the candidate helpers and the simulators.
        config: Passed through unchanged to ``effectIC`` / ``effectLT``.
        num_iterations: Number of Bayesian-optimisation rounds.
        num_of_sims: Forwarded to the spread estimator.
        candidate_size: Forwarded to ``create_candidate_set_pool``.
        diffusion_model: ``'ic'`` or ``'lt'``.
        number_of_sources: Size of the source set to decode.
        num_initial_samples: Size of the initial design (previously fixed at 20).
        num_choices: Candidates offered per round (previously fixed at 20).

    Returns:
        The result of ``find_source_set_from_fourier`` for the training signal
        with the highest observed spread.

    Raises:
        NotImplementedError: If ``diffusion_model`` is not ``'ic'`` or ``'lt'``.
    """
    if diffusion_model == 'ic':
        estimate_spread = effectIC
    elif diffusion_model == 'lt':
        estimate_spread = effectLT
    else:
        raise NotImplementedError("Diffusion model not recognized.")

    nl = nx.normalized_laplacian_matrix(G)
    _, eig_vect = np.linalg.eigh(nl.todense())
    UT = np.linalg.inv(eig_vect)
    UT_inv = eig_vect

    # The pool is only sampled from; the bulk transform of every candidate
    # set that used to follow was never read, so it is no longer computed.
    candidate_sets = create_candidate_set_pool(G, candidate_size, number_of_sources)

    # Initial design: random candidate sets with their simulated spread.
    train_X = []
    train_Y = []
    for _ in range(num_initial_samples):
        selected_set = random.sample(candidate_sets, 1)[0]
        selected_signal = create_signal_from_source_set(G, selected_set, UT)
        spread, _ = estimate_spread(G, config, selected_set, num_of_sims)
        train_X.append(torch.FloatTensor(selected_signal))
        train_Y.append([float(spread)])

    train_X = torch.stack(train_X)
    train_Y = torch.tensor(train_Y)

    for _ in range(num_iterations):
        # Draw a fresh random batch of candidate signals for this round.
        batch = []
        for _ in range(num_choices):
            selected_set = random.sample(candidate_sets, 1)[0]
            batch.append(torch.FloatTensor(create_signal_from_source_set(G, selected_set, UT)))
        choices = torch.stack(batch)

        # Fit a single-output GP model to the observed data.
        model = RBFSingleTaskGP(train_X, train_Y)
        fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

        acq_func = ExpectedImprovement(model=model, best_f=train_Y.max())

        candidate, _ = optimize_acqf_discrete(
            acq_function=acq_func,
            q=1,  # one candidate per round
            choices=choices)

        selected = find_source_set_from_fourier(candidate[0].tolist(), number_of_sources, UT_inv)
        spread, _ = estimate_spread(G, config, selected, num_of_sims)

        # Build the new observation directly as a (1, 1) row instead of using
        # the deprecated non-inplace Tensor.resize.
        new_Y = torch.tensor([[float(spread)]], dtype=train_Y.dtype)
        train_X = torch.cat([train_X, candidate], dim=0)
        train_Y = torch.cat([train_Y, new_Y], dim=0)

    # train_Y has one column, so the flat argmax is the row of the best
    # observation (the first one on ties).
    best_row = int(torch.argmax(train_Y))
    # Pass plain floats, the same form the optimisation loop hands to the decoder.
    best_signal = train_X[best_row].tolist()

    return find_source_set_from_fourier(best_signal, number_of_sources, UT_inv)


# Gaussian Process with node selection, no filtering, RBF kernel

def BOIM_vanilla(G, config, num_iterations, num_of_sims, candidate_size, diffusion_model, number_of_sources, num_initial_samples=20, num_choices=20):
    """Bayesian optimisation over binary node indicators with an unfiltered pool.

    Source sets are encoded as 0/1 vectors over the ``candidate_size``
    highest-degree nodes of ``G``. The surrogate is seeded with
    ``num_initial_samples`` random candidate sets. Each round offers
    ``num_choices`` random sets to Expected Improvement and simulates the
    winner.

    Args:
        G: Graph handed to the candidate helper and the simulators; its
            ``degree`` view defines the encoding.
        config: Passed through unchanged to ``effectIC`` / ``effectLT``.
        num_iterations: Number of Bayesian-optimisation rounds.
        num_of_sims: Forwarded to the spread estimator.
        candidate_size: Number of top-degree nodes used for the encoding; also
            forwarded to ``create_candidate_set_pool``.
        diffusion_model: ``'ic'`` or ``'lt'``.
        number_of_sources: Forwarded to ``create_candidate_set_pool``.
        num_initial_samples: Size of the initial design (previously fixed at 20).
        num_choices: Candidates offered per round (previously fixed at 20).

    Returns:
        List of node ids whose indicator is 1 in the best observed encoding.

    Raises:
        NotImplementedError: If ``diffusion_model`` is not ``'ic'`` or ``'lt'``.
    """
    if diffusion_model == 'ic':
        estimate_spread = effectIC
    elif diffusion_model == 'lt':
        estimate_spread = effectLT
    else:
        raise NotImplementedError("Diffusion model not recognized.")

    # This variant takes no distance threshold, so it uses the unfiltered
    # pool builder with the same three arguments as BOIM_no_filtering. The
    # previous call to the filtering builder omitted the distance argument
    # that every other call site in this file supplies.
    candidate_sets = create_candidate_set_pool(G, candidate_size, number_of_sources)

    # Encoding basis: the highest-degree nodes, kept as bare node ids.
    by_degree = sorted(G.degree, key=lambda x: x[1], reverse=True)
    candidate_nodes = [node for node, _ in by_degree[:candidate_size]]

    def encode(source_set):
        # 1 where the basis node belongs to the source set, 0 elsewhere.
        return [1 if node in source_set else 0 for node in candidate_nodes]

    # Initialise the GP model with several (encoding, spread) pairs.
    rows = []
    spreads = []
    for index in random.sample(range(len(candidate_sets)), num_initial_samples):
        source_set = candidate_sets[index]
        spread, _ = estimate_spread(G, config, source_set, num_of_sims)
        rows.append(encode(source_set))
        spreads.append([float(spread)])

    train_X = torch.tensor(rows, dtype=torch.double)
    train_Y = torch.tensor(spreads, dtype=torch.double)

    for _ in range(num_iterations):
        # Fit a single-output GP model to the observed data.
        model = RBFSingleTaskGP(train_X, train_Y)
        fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

        acq_func = ExpectedImprovement(model=model, best_f=train_Y.max())

        # Random candidate sets for this round, encoded as a
        # (num_choices, d) tensor of discrete choices.
        samples = random.sample(candidate_sets, num_choices)
        choices = torch.tensor([encode(source_set) for source_set in samples], dtype=torch.double)

        candidate, _ = optimize_acqf_discrete(
            acq_function=acq_func,
            q=1,  # one candidate per round
            choices=choices)

        # Decode the chosen indicator vector back to node ids. The previous
        # code appended (node, degree) tuples here, unlike the node ids used
        # for the initial design and the final result.
        flags = candidate.reshape(-1).tolist()
        selected = [node for node, flag in zip(candidate_nodes, flags) if flag == 1]

        spread, _ = estimate_spread(G, config, selected, num_of_sims)

        # Append as explicitly shaped double rows; this avoids the deprecated
        # Tensor.resize and mixing float32 into the float64 training data.
        new_Y = torch.tensor([[float(spread)]], dtype=train_Y.dtype)
        train_X = torch.cat([train_X, candidate.reshape(1, -1)], dim=0)
        train_Y = torch.cat([train_Y, new_Y], dim=0)

    # Row of the best observed spread (first one on ties).
    best_row = int(torch.argmax(train_Y))
    best_flags = train_X[best_row].tolist()

    return [node for node, flag in zip(candidate_nodes, best_flags) if int(flag) == 1]


def eigen(g, config, budget):
    """Greedily pick ``budget`` nodes by eigenvector centrality.

    Works on a fresh graph of the same class as ``g`` that holds its nodes
    and edges, with each edge's ``weight`` set from
    ``config.config["edges"]['threshold']``. In every round the node with the
    highest ``nx.eigenvector_centrality_numpy`` score is recorded and removed
    before the scores are recomputed. The centrality call is made without a
    ``weight`` argument.

    Args:
        g: Source graph; it is not modified.
        config: Object whose ``config["edges"]['threshold']`` maps ``(a, b)``
            edge tuples to weights.
        budget: Number of nodes to select.

    Returns:
        List of the selected nodes in selection order.
    """
    work = g.__class__()
    work.add_nodes_from(g)
    work.add_edges_from(g.edges)
    for src, dst in work.edges():
        work[src][dst]['weight'] = config.config["edges"]['threshold'][(src, dst)]

    chosen = []
    for _ in range(budget):
        scores = nx.eigenvector_centrality_numpy(work)
        ranking = sorted(scores, key=scores.get, reverse=True)
        top = ranking[0]
        chosen.append(top)
        # Drop the winner so the next round ranks the remaining nodes only.
        work.remove_node(top)

    return chosen

def degree(g, config, budget):
    """Greedily pick ``budget`` nodes by degree centrality.

    Works on a fresh graph of the same class as ``g`` that holds its nodes
    and edges, with each edge's ``weight`` set from
    ``config.config["edges"]['threshold']``. In every round the node with the
    highest ``nx.centrality.degree_centrality`` score is recorded and removed
    before the scores are recomputed.

    Args:
        g: Source graph; it is not modified.
        config: Object whose ``config["edges"]['threshold']`` maps ``(a, b)``
            edge tuples to weights.
        budget: Number of nodes to select.

    Returns:
        List of the selected nodes in selection order.
    """
    work = g.__class__()
    work.add_nodes_from(g)
    work.add_edges_from(g.edges)
    for src, dst in work.edges():
        work[src][dst]['weight'] = config.config["edges"]['threshold'][(src, dst)]

    chosen = []
    for _ in range(budget):
        scores = nx.centrality.degree_centrality(work)
        ranking = sorted(scores, key=scores.get, reverse=True)
        top = ranking[0]
        chosen.append(top)
        # Drop the winner so the next round ranks the remaining nodes only.
        work.remove_node(top)

    return chosen

def pi(g, config, budget):
    """Greedily pick ``budget`` nodes by a five-term product score.

    Works on a fresh graph of the same class as ``g`` whose edge ``weight``
    values come from ``config.config["edges"]['threshold']``. In each round,
    with ``A`` the weighted adjacency matrix of the remaining graph, the
    score of a node is the row sum of ``1 - prod_{k=1..5} (1 - A**k)``.
    The node with the highest score is recorded and removed.

    Args:
        g: Source graph; it is not modified.
        config: Object whose ``config["edges"]['threshold']`` maps ``(a, b)``
            edge tuples to weights.
        budget: Number of nodes to select.

    Returns:
        List of the selected nodes in selection order.
    """
    g_greedy = g.__class__()
    g_greedy.add_nodes_from(g)
    g_greedy.add_edges_from(g.edges)

    for a, b in g_greedy.edges():
        g_greedy[a][b]['weight'] = config.config["edges"]['threshold'][(a, b)]

    result = []

    for _ in range(budget):
        # Materialise the node order once per round; it is shared by the
        # adjacency matrix and the score lookup below.
        nodes = list(g_greedy.nodes())
        n = len(nodes)

        A = nx.to_numpy_array(g_greedy, nodelist=nodes)

        # NOTE(review): np.power is an element-wise power, not a matrix
        # power; kept as in the original. Confirm this is intended.
        miss = np.ones((n, n))
        for hop in range(1, 6):
            miss = np.multiply(miss, 1.0 - np.power(A, hop))

        P = 1.0 - miss

        # Row sums of P, written as a product with a ones column.
        scores = np.matmul(P, np.ones((n, 1)))

        value = dict(zip(nodes, scores[:, 0]))

        selected = sorted(value, key=value.get, reverse=True)[0]

        result.append(selected)
        g_greedy.remove_node(selected)

    return result

def sigma(g, config, budget):
    """Greedily pick ``budget`` nodes by a five-term additive score.

    Works on a fresh graph of the same class as ``g`` whose edge ``weight``
    values come from ``config.config["edges"]['threshold']``. In each round,
    with ``A`` the weighted adjacency matrix of the remaining graph, the
    score of a node is ``1 + sum_{k=1..5} rowsum(A**k)``. The node with the
    highest score is recorded and removed.

    Args:
        g: Source graph; it is not modified.
        config: Object whose ``config["edges"]['threshold']`` maps ``(a, b)``
            edge tuples to weights.
        budget: Number of nodes to select.

    Returns:
        List of the selected nodes in selection order.
    """
    g_greedy = g.__class__()
    g_greedy.add_nodes_from(g)
    g_greedy.add_edges_from(g.edges)

    for a, b in g_greedy.edges():
        g_greedy[a][b]['weight'] = config.config["edges"]['threshold'][(a, b)]

    result = []

    for _ in range(budget):
        nodes = list(g_greedy.nodes())
        n = len(nodes)

        ones_col = np.ones((n, 1))

        # Build the matrix from the weighted working copy (as pi does), so
        # the weights assigned above are the ones that get used.
        A = nx.to_numpy_array(g_greedy, nodelist=nodes)

        # Accumulate into a separate array. Aliasing the ones column would
        # let the in-place += change the vector used in later products.
        scores = ones_col.copy()
        for hop in range(1, 6):
            # NOTE(review): np.power is an element-wise power, not a matrix
            # power; kept as in the original. Confirm this is intended.
            scores += np.matmul(np.power(A, hop), ones_col)

        value = dict(zip(nodes, scores[:, 0]))

        selected = sorted(value, key=value.get, reverse=True)[0]

        result.append(selected)
        g_greedy.remove_node(selected)

    return result

def Netshield(g, config, budget):
    """Pick ``budget`` nodes with a NetShield-style eigenvector score.

    Works on a fresh graph of the same class as ``g`` whose edge ``weight``
    values come from ``config.config["edges"]['threshold']``. With ``lam``
    and ``u`` the last eigenvalue and (absolute) eigenvector returned by
    ``np.linalg.eigh`` for the adjacency matrix ``A``, each round scores
    every node as ``2*lam*u**2 - 2*(A[:, S] @ u[S])*u``, where ``S`` holds
    the nodes already picked, and adds the arg-max.

    Args:
        g: Source graph; it is not modified.
        config: Object whose ``config["edges"]['threshold']`` maps ``(a, b)``
            edge tuples to weights.
        budget: Number of nodes to select.

    Returns:
        List of the selected node labels in selection order.
    """
    g_greedy = g.__class__()
    g_greedy.add_nodes_from(g)
    g_greedy.add_edges_from(g.edges)

    for a, b in g_greedy.edges():
        g_greedy[a][b]['weight'] = config.config["edges"]['threshold'][(a, b)]

    # Fix the node order so matrix indices can be mapped back to labels.
    node_list = list(g_greedy.nodes())

    # Densify once. Dense slicing and @ behave the same whether networkx
    # returns a SciPy sparse matrix or a sparse array, whereas `*` on a
    # sparse array is element-wise.
    A = nx.adjacency_matrix(g_greedy, nodelist=node_list).toarray()

    # eigh returns eigenvalues in ascending order, so the last pair is the
    # largest one.
    eigvals, eigvecs = np.linalg.eigh(A)
    lam = eigvals[-1]
    u = np.abs(np.real(eigvecs[:, -1]).flatten())

    v = (2 * lam * np.ones(len(u))) * np.power(u, 2)

    picked = []
    for _ in range(budget):
        # Coupling of every node with the nodes already picked.
        b = A[:, picked] @ u[picked]

        score = v - 2 * b * u
        score[picked] = -1  # never pick the same index twice

        picked.append(int(np.argmax(score)))

    # Return node labels rather than matrix positions; the two coincide only
    # when the nodes are labelled 0..n-1 in iteration order.
    return [node_list[i] for i in picked]

def Soboldeg(g, config, budget):
    """Select nodes by degree, then prune with ``SobolT`` rankings.

    First ``2 * budget`` nodes are chosen greedily by
    ``nx.centrality.degree_centrality`` on a weighted working copy of ``g``
    (the top node is removed after each pick). Then, ``budget`` times,
    ``simulationIC(1, g, ...)`` is run on the current selection, ``SobolT``
    scores it, and the node ranked last in descending score order is dropped.

    Args:
        g: Source graph; only the working copy has nodes removed.
        config: Object whose ``config["edges"]['threshold']`` maps ``(a, b)``
            edge tuples to weights; also forwarded to ``simulationIC``.
        budget: Number of pruning rounds; the pool starts at twice this size.

    Returns:
        List of the nodes that survive pruning, in selection order.
    """
    work = g.__class__()
    work.add_nodes_from(g)
    work.add_edges_from(g.edges)
    for src, dst in work.edges():
        work[src][dst]['weight'] = config.config["edges"]['threshold'][(src, dst)]

    # Phase 1: over-select by degree centrality.
    pool = []
    for _ in range(2 * budget):
        scores = nx.centrality.degree_centrality(work)
        top = sorted(scores, key=scores.get, reverse=True)[0]
        pool.append(top)
        work.remove_node(top)

    # Phase 2: repeatedly drop the node ranked last by SobolT.
    for _ in range(budget):
        df = simulationIC(1, g, pool, config)
        ST = SobolT(df, pool)
        descending = sorted(ST, key=ST.get, reverse=True)
        weakest = descending[-1]
        pool.remove(weakest)

    return pool

def degreeDis(g, config, budget):
    """Select ``budget`` seeds with a weighted degree-discount heuristic.

    Every node starts with priority equal to minus the sum of the ``weight``
    attributes of its incident edges in ``g``. After a seed is popped from
    the ``hd.heapdict`` queue, each neighbour ``v`` not yet selected
    accumulates the connecting edge weight in ``t[v]`` and is re-queued with
    score ``d[v] - 2*t[v] - (d[v] - t[v])*t[v]``.

    Args:
        g: Graph whose edges carry a ``weight`` attribute.
        config: Unused; kept so the signature matches the other selectors.
        budget: Number of seeds to select.

    Returns:
        List of the selected seeds in selection order.
    """
    picked = []
    strength = {}
    seeded_weight = {}
    # Scores are stored negated; presumably heapdict pops the smallest
    # priority first, which makes this a max-queue (confirm against hd).
    queue = hd.heapdict()

    for node in g.nodes():
        total = sum(g[node][nbr]['weight'] for nbr in g[node])
        strength[node] = total
        queue[node] = -total
        seeded_weight[node] = 0

    for _ in range(budget):
        seed, _priority = queue.popitem()
        picked.append(seed)
        for nbr in g.neighbors(seed):
            if nbr in picked:
                continue
            seeded_weight[nbr] += g[seed][nbr]['weight']
            d_v = strength[nbr]
            t_v = seeded_weight[nbr]
            queue[nbr] = -(d_v - 2*t_v - (d_v - t_v)*t_v)

    return picked

def SoboldegreeDis(g, config, budget):
    """Over-select seeds by degree discount, then prune with ``SobolT``.

    First ``2 * budget`` seeds are chosen with the same weighted
    degree-discount procedure as ``degreeDis``. Then, ``budget`` times,
    ``simulationIC(10, g, ...)`` is run on the current selection, ``SobolT``
    scores it, and the node ranked last in descending score order is dropped.

    Args:
        g: Graph whose edges carry a ``weight`` attribute.
        config: Forwarded to ``simulationIC``.
        budget: Number of pruning rounds; the pool starts at twice this size.

    Returns:
        List of the seeds that survive pruning, in selection order.
    """
    picked = []
    strength = {}
    seeded_weight = {}
    # Scores are stored negated; presumably heapdict pops the smallest
    # priority first, which makes this a max-queue (confirm against hd).
    queue = hd.heapdict()

    for node in g.nodes():
        total = sum(g[node][nbr]['weight'] for nbr in g[node])
        strength[node] = total
        queue[node] = -total
        seeded_weight[node] = 0

    # Phase 1: degree-discount selection of twice the budget.
    for _ in range(2 * budget):
        seed, _priority = queue.popitem()
        picked.append(seed)
        for nbr in g.neighbors(seed):
            if nbr in picked:
                continue
            seeded_weight[nbr] += g[seed][nbr]['weight']
            d_v = strength[nbr]
            t_v = seeded_weight[nbr]
            queue[nbr] = -(d_v - 2*t_v - (d_v - t_v)*t_v)

    # Phase 2: repeatedly drop the node ranked last by SobolT.
    for _ in range(budget):
        df = simulationIC(10, g, picked, config)
        ST = SobolT(df, picked)
        descending = sorted(ST, key=ST.get, reverse=True)
        weakest = descending[-1]
        picked.remove(weakest)

    return picked