import pickle
import dgl
import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
import torchtext
import tqdm
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import scipy.sparse as sp
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import laplacian
from scipy.sparse import coo_matrix, diags, identity, csr_matrix, find, triu
from julia.api import Julia
from my_utils.utils import spade_nonetworkx,sagman_nonetworkx,construct_weighted_adj,spectral_embedding_eig,construct_weighted_adj,hnsw,SPF
import torch.nn.functional as F
import copy
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
import networkx as nx
import pandas as pd

def jaccard_similarity(list1, list2):
    """Return the Jaccard similarity |A ∩ B| / |A ∪ B| of two iterables.

    Duplicates are ignored because both inputs are converted to sets.

    Args:
        list1: First iterable of hashable items.
        list2: Second iterable of hashable items.

    Returns:
        The similarity as a float in [0, 1], or 0 when both inputs are empty
        (the union is empty, so the ratio is undefined; 0 matches what
        ``jaccard_similarity_eval`` returns in that case).
    """
    set1 = set(list1)
    set2 = set(list2)
    union = len(set1 | set2)
    # Guard against 0 / 0 when both inputs are empty.
    if not union:
        return 0
    return len(set1 & set2) / union

def jaccard_similarity_eval(list1, list2):
    """Jaccard similarity of two iterables, treating each as a set.

    Returns the size of the intersection divided by the size of the union,
    or 0 when the union is empty (both inputs empty).
    """
    first, second = set(list1), set(list2)
    union_size = len(first | second)
    if not union_size:
        return 0
    return len(first & second) / union_size

def calculate_recommendation_similarity(original_recommendations, perturbed_recommendations):
    """Average row-wise Jaccard similarity between two recommendation tables.

    Both arguments must expose ``.shape`` and integer row indexing, and their
    shapes must match (checked with ``assert``). Row ``i`` of each table is
    converted to a list and compared with ``jaccard_similarity_eval``; the
    mean over all rows is returned.

    NOTE(review): the set comparison relies on element hashing/equality, so
    rows are presumably numpy arrays of item ids -- confirm against the caller.
    """
    assert original_recommendations.shape == perturbed_recommendations.shape
    row_count = original_recommendations.shape[0]

    running_total = 0
    for row in range(row_count):
        before = list(original_recommendations[row])
        after = list(perturbed_recommendations[row])
        running_total += jaccard_similarity_eval(before, after)

    return running_total / row_count

def create_knn_graph_adj_matrix(total_score, k):
    """Build an undirected k-nearest-neighbour graph over the rows of ``total_score``.

    Pairwise similarity between rows is the Jaccard similarity of their
    entries (``jaccard_similarity``). Each row is connected to its ``k`` most
    similar other rows; edges are undirected and weighted by the similarity.

    Args:
        total_score: 2-D array-like with one row per user.
        k: Number of neighbours to connect each row to. It is capped at
            ``num_users - 1``; a non-positive ``k`` produces no edges.

    Returns:
        The sparse adjacency matrix of the graph, as returned by
        ``nx.adjacency_matrix``.
    """
    # pdist evaluates the Python metric once per pair of rows, so this is
    # quadratic in the number of rows.
    similarity_matrix = squareform(pdist(total_score, metric=jaccard_similarity))
    num_users = similarity_matrix.shape[0]
    knn_graph = nx.Graph()
    knn_graph.add_nodes_from(range(num_users))

    num_neighbors = min(k, num_users - 1)
    if num_neighbors > 0:
        for i in range(num_users):
            # squareform() zero-fills the diagonal, so a row's own entry is
            # NOT its largest similarity. Mask it out explicitly rather than
            # assuming it sorts last (which would discard the true nearest
            # neighbour and could add a self-loop).
            row = similarity_matrix[i].copy()
            row[i] = -np.inf
            nearest_neighbors = np.argsort(row)[-num_neighbors:]
            for neighbor in nearest_neighbors:
                if not knn_graph.has_edge(i, neighbor):
                    knn_graph.add_edge(i, neighbor, weight=similarity_matrix[i][neighbor])

    return nx.adjacency_matrix(knn_graph)

def build_full_adj_matrix(item_to_user,user_to_item):
    """Assemble the square adjacency matrix of the bipartite item/user graph.

    Nodes are ordered items first, then users, on BOTH axes, giving the block
    layout::

        [[ 0            , item_to_user ],
         [ user_to_item , 0            ]]

    Args:
        item_to_user: Sparse matrix of shape (num_items, num_users).
        user_to_item: Sparse matrix of shape (num_users, num_items).

    Returns:
        CSR matrix of shape (num_items + num_users, num_items + num_users).
    """
    num_items, num_users = item_to_user.shape
    zero_item_to_item = sp.csr_matrix((num_items, num_items))
    zero_user_to_user = sp.csr_matrix((num_users, num_users))

    # The zero blocks sit on the diagonal so that row and column node
    # orderings agree; putting the interaction blocks on the diagonal would
    # index columns as [users, items] while rows are [items, users].
    top = sp.hstack([zero_item_to_item, item_to_user])
    bottom = sp.hstack([user_to_item, zero_user_to_user])
    full_adj_matrix = sp.vstack([top, bottom])

    return full_adj_matrix.tocsr()


class PinSAGEModel(nn.Module):
    """Item-to-item model combining a feature projector, a SAGE network and a scorer.

    The three sub-modules come from the project's ``layers`` module:
    ``LinearProjector`` (stored as ``proj``), ``SAGENet`` (``sage``) and
    ``ItemToItemScorer`` (``scorer``).
    """

    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()

        # Sub-modules are created in this order: projector, SAGE net, scorer.
        self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims)
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks):
        """Return the hinge loss ``max(0, neg_score - pos_score + 1)``.

        Item representations are computed once from ``blocks`` and scored on
        both the positive and the negative graph.
        """
        item_repr = self.get_repr(blocks)
        positive = self.scorer(pos_graph, item_repr)
        negative = self.scorer(neg_graph, item_repr)
        margin_violation = negative - positive + 1
        return margin_violation.clamp(min=0)

    def get_repr(self, blocks):
        """Return item representations for the destination nodes of ``blocks``.

        The result is the projection of the last block's destination data
        plus the SAGE output computed from the projection of the first
        block's source data.
        """
        first_block, last_block = blocks[0], blocks[-1]
        src_projection = self.proj(first_block.srcdata)
        dst_projection = self.proj(last_block.dstdata)
        return dst_projection + self.sage(blocks, src_projection)

def make_symmetric_csr(adj_matrix_csr):
    """Symmetrize a CSR matrix by taking the element-wise maximum with its transpose.

    Raises:
        ValueError: If the input is not a ``csr_matrix`` instance.
    """
    if isinstance(adj_matrix_csr, csr_matrix):
        transposed = adj_matrix_csr.T
        return adj_matrix_csr.maximum(transposed)

    # Anything that is not a csr_matrix is rejected.
    raise ValueError("Input matrix must be in CSR format")

def convert_to_scipy_csr(dgl_sparse_matrix):
    """Convert a sparse matrix exposing ``indices()``, ``val`` and ``shape`` to SciPy CSR.

    ``indices()`` is read as a pair (row indices, column indices) and ``val``
    as the matching values; each is turned into a numpy array via
    ``.numpy()``. A COO matrix is built from them and converted to CSR.
    """
    index_pair = dgl_sparse_matrix.indices()
    rows = index_pair[0].numpy()
    cols = index_pair[1].numpy()
    data = dgl_sparse_matrix.val.numpy()

    as_coo = sp.coo_matrix((data, (rows, cols)), shape=dgl_sparse_matrix.shape)
    return as_coo.tocsr()

def add_random_edges(graph, top_node_list, num_edges=2):
    """Add ``num_edges`` random ('user', 'watched', 'movie') edges for each listed user.

    For every user id in ``top_node_list``, ``num_edges`` distinct movie ids
    are drawn with ``np.random.choice`` from ``range(number of 'movie' nodes)``.
    All new edges are added to ``graph`` in a single ``add_edges`` call.

    Returns:
        The same ``graph`` object that was passed in.
    """
    movie_count = graph.number_of_nodes('movie')

    src_ids, dst_ids = [], []
    for user in top_node_list:
        # Sampling without replacement: a user never gets the same movie twice
        # within one call.
        picked = np.random.choice(movie_count, num_edges, replace=False)
        src_ids += [user] * num_edges
        dst_ids += list(picked)

    graph.add_edges(
        torch.tensor(src_ids, dtype=torch.int64),
        torch.tensor(dst_ids, dtype=torch.int64),
        etype=('user', 'watched', 'movie'),
    )
    return graph

def add_selected_edges(graph, selected_users, selected_movies, user_ntype, movie_ntype, edge_type):
    """Connect each selected user to up to three selected movies it is not yet linked to.

    Movies are tried in the order given. An edge of type
    ``(user_ntype, edge_type, movie_ntype)`` is added only when
    ``graph.has_edges_between`` reports that none exists, and at most three
    edges are added per user.

    Returns:
        The same ``graph`` object that was passed in.
    """
    users = torch.tensor(selected_users, dtype=torch.int64)
    movies = torch.tensor(selected_movies.copy(), dtype=torch.int64)
    canonical_etype = (user_ntype, edge_type, movie_ntype)

    for user in users:
        added = 0
        for movie in movies:
            if added >= 3:
                break
            if graph.has_edges_between(user, movie, etype=canonical_etype):
                continue
            graph.add_edges(user, movie, etype=canonical_etype)
            added += 1

    return graph


def build_user_knn_graph(user_item_matrix, k):
    """Return a CSR connectivity k-NN graph built from user cosine similarities.

    The cosine-similarity matrix between the rows of ``user_item_matrix`` is
    computed first; ``kneighbors_graph`` is then run on that similarity
    matrix (its rows are the points), with ``k`` neighbours, connectivity
    mode and no self-connections.
    """
    similarity = cosine_similarity(user_item_matrix)
    neighbor_graph = kneighbors_graph(
        similarity,
        n_neighbors=k,
        mode='connectivity',
        include_self=False,
    )
    return neighbor_graph.tocsr()


def embedding_user_top10(user_top_items, num_items, k):
    """Return the user-user cosine-similarity matrix of the users' top-item sets.

    A binary (num_users x num_items) indicator matrix is filled with 1 at the
    item indices listed for each user, and the cosine similarity between its
    rows is returned.

    ``k`` is accepted but not used by this function.
    """
    user_count = user_top_items.shape[0]

    # One-hot encode each user's item list.
    indicator = np.zeros((user_count, num_items), dtype=int)
    for row in range(user_count):
        indicator[row, user_top_items[row]] = 1

    return cosine_similarity(indicator)

def filter_array(arr, number1, number2):
    """Return the elements of ``arr`` lying in the closed interval [number1, number2]."""
    at_least_lower = arr >= number1
    at_most_upper = arr <= number2
    return arr[at_least_lower & at_most_upper]

def select_random_elements(arr, num_elements):
    """Return ``num_elements`` entries of ``arr`` chosen at random without replacement.

    Positions are drawn with ``np.random.choice`` and used to index ``arr``.

    Raises:
        ValueError: If ``num_elements`` exceeds ``len(arr)``.
    """
    population = len(arr)
    if num_elements > population:
        raise ValueError("num_elements is greater than the length of the array")

    picked_positions = np.random.choice(population, num_elements, replace=False)
    return arr[picked_positions]



def main(sagman_switch, user_ranking,perturbation_level):
    """Compare how random extra edges change top-10 recommendations for two user groups.

    Steps, in the order the code performs them:

    1. Load the pickled dataset and the training graph from
       ``./data_processed``, load the saved model from
       ``./epoch100_trained_model.pt`` onto CUDA, compute item
       representations and obtain per-user top-10 items from
       ``evaluation.evaluate_nn``.
    2. Obtain a user ranking (``TopNodeList``): computed with
       ``sagman_nonetworkx`` when ``sagman_switch`` is truthy, otherwise
       ``user_ranking`` is used as-is.
    3. Take the first 1% of the ranking as "unstable" users and the last 1%
       as "stable" users.
    4. Add ``perturbation_level`` random edges per unstable user, recompute
       the recommendations and measure the mean Jaccard similarity with the
       original top-10 lists for those users.
    5. Reload dataset and graph from disk and repeat step 4 for the stable
       users.

    Args:
        sagman_switch: When truthy, compute the ranking; otherwise reuse
            ``user_ranking``.
        user_ranking: Ranking reused when ``sagman_switch`` is falsy. It is
            accessed through ``.shape[0]`` and slicing, so it is presumably a
            numpy array returned by a previous call -- verify against caller.
        perturbation_level: Forwarded as ``num_edges`` to ``add_random_edges``.

    Returns:
        Tuple ``(jaccard_unstable, jaccard_stable, TopNodeList)``.
    """

    # Load dataset
    # NOTE(review): pickle.load / torch.load execute arbitrary code from the
    # file; only use with trusted files.
    data_info_path = './data_processed/data.pkl'
    with open(data_info_path, "rb") as f:
        dataset = pickle.load(f)
    train_g_path = './data_processed/train_g.bin'
    g_list, _ = dgl.load_graphs(train_g_path)
    dataset["train-graph"] = g_list[0]
    
    g = dataset["train-graph"]
    val_matrix = dataset["val-matrix"].tocsr()
    test_matrix = dataset["test-matrix"].tocsr()
    item_texts = dataset["item-texts"]
    user_ntype = dataset["user-type"]
    item_ntype = dataset["item-type"]
    user_to_item_etype = dataset["user-to-item-type"]
    timestamp = dataset["timestamp-edge-column"]

    # Assign user and movie IDs and use them as features (to learn an individual trainable
    # embedding for each entity)
    g.nodes[user_ntype].data["id"] = torch.arange(g.num_nodes(user_ntype))
    g.nodes[item_ntype].data["id"] = torch.arange(g.num_nodes(item_ntype))

    # Adjacency of both edge directions, converted to SciPy CSR for the
    # spectral-embedding branch below.
    m2u_adj = g.adjacency_matrix(etype = 'watched-by')
    u2m_adj = g.adjacency_matrix(etype = 'watched')
    m2u_adj1 = convert_to_scipy_csr(m2u_adj)
    u2m_adj1 = convert_to_scipy_csr(u2m_adj)

    # Prepare torchtext dataset and Vocabulary
    textset = {}
    tokenizer = get_tokenizer(None)

    textlist = []
    batch_first = True

    # Every text field of every item is tokenized into ONE shared list
    # (item-major order).
    # NOTE(review): with more than one key in item_texts, position i of
    # textlist no longer corresponds to item i -- confirm this is intended.
    for i in range(g.num_nodes(item_ntype)):
        for key in item_texts.keys():
            l = tokenizer(item_texts[key][i].lower())
            textlist.append(l)
    # Each key gets the same token list plus a vocabulary built from it.
    for key, field in item_texts.items():
        vocab2 = build_vocab_from_iterator(
            textlist, specials=["<unk>", "<pad>"]
        )
        textset[key] = (
            textlist,
            vocab2,
            vocab2.get_stoi()["<pad>"],
            batch_first,
        )

    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, 32
    )
    # The positional numbers below are sampler hyper-parameters whose meaning
    # is defined in sampler_module -- TODO confirm.
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        2,
        0.5,
        10,
        3,
        2,
    )
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=2,
    )
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=32,
        collate_fn=collator.collate_test,
        num_workers=2,
    )
    # The training iterator is created but never consumed in this function.
    dataloader_it = iter(dataloader)

    # Load the whole saved model object and move it to the GPU.
    model = torch.load('./epoch100_trained_model.pt').cuda()
    model.eval()

    # Baseline item representations on the unperturbed graph.
    with torch.no_grad():
        item_batches = torch.arange(g.num_nodes(item_ntype)).split(32)
        h_item_batches = []
        for blocks in dataloader_test:
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to("cuda:0")

            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)

    original_h_item = copy.copy(h_item)
    # presumably the trailing 10 and 32 are top-k and batch size -- confirm
    # against evaluation.evaluate_nn.
    mean_score, top_10_socre, user_top10_item = evaluation.evaluate_nn(dataset, h_item, 10, 32)
    original_user_top10_item = copy.copy(user_top10_item)
    # Build the user-user KNN graph
    the_k = 30
    #user_knn_graph_in = build_user_knn_graph(u2m_adj1, the_k)
    #user_knn_graph_out = build_user_knn_graph_top10(user_top10_item, 3706, the_k)


    #if sagman_switch:
    if sagman_switch:
        # Input-side graph: spectral embedding of the symmetrized full
        # adjacency, restricted to the rows from index 3706 onward.
        # NOTE(review): 3706 is hard-coded; presumably the number of item
        # nodes, so the slice keeps the user rows -- confirm.
        adj_graph_in = build_full_adj_matrix(m2u_adj1,u2m_adj1)
        adj_graph_in = make_symmetric_csr(adj_graph_in)
        spec_embed = spectral_embedding_eig(adj_graph_in,'',use_feature=False,adj_norm=True)
        spec_embed = spec_embed[3706:]
        neighs, distance = hnsw(spec_embed, k=the_k)
        embed_adj_mtx = construct_weighted_adj(neighs, distance)
        embed_adj_mtx,inter_edge_adj = SPF(embed_adj_mtx, 4)
    
        # Output-side graph: k-NN graph over the users' top-10 item lists.
        #the_output = embedding_user_top10(user_top10_item,m2u_adj1.shape[0],the_k)
        #neighs, distance = hnsw(the_output, k=the_k)
        #output_adj = construct_weighted_adj(neighs, distance)
        output_adj = create_knn_graph_adj_matrix(user_top10_item, the_k)
        output_adj,inter_edge_adj = SPF(output_adj, 4)
        TopEig, _, TopNodeList, _ = sagman_nonetworkx(embed_adj_mtx, output_adj, k=the_k)
    else:
        TopNodeList = user_ranking
    #TopNodeList = filter_array(TopNodeList, 0, h_item.shape[0])
    # First 1% of the ranking vs. last 1% of the ranking.
    # NOTE(review): if the ranking has fewer than 100 entries the count is 0,
    # and TopNodeList[-0:] selects the WHOLE ranking for stable_user.
    unstable_user = TopNodeList[:int(TopNodeList.shape[0]*0.01)]
    stable_user = TopNodeList[-int(TopNodeList.shape[0]*0.01):]
    
    # NOTE(review): 6040 is hard-coded (presumably the number of users), and
    # random_user is not used later in this function.
    users_id = np.arange(0, 6040)
    random_user = select_random_elements(users_id,int(users_id.shape[0]*0.01))

    # Perturb the graph (in place) around the unstable users.
    g = add_random_edges(g, unstable_user, num_edges=perturbation_level)
    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, 32
    )
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        2,
        0.5,
        10,
        3,
        2,
    )
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=2,
    )
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=32,
        collate_fn=collator.collate_test,
        num_workers=2,
    )
    dataloader_it = iter(dataloader)


    # Item representations on the graph perturbed around unstable users.
    with torch.no_grad():
        item_batches = torch.arange(g.num_nodes(item_ntype)).split(32)
        h_item_batches = []
        for blocks in dataloader_test:
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to("cuda:0")

            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)

    unstable_h_item = copy.copy(h_item)
    # Replace the old graph in the dataset with the updated graph
    dataset["train-graph"] = g
    mean_score, top_10_socre, user_top10_item = evaluation.evaluate_nn(dataset, h_item, 10, 32)

    # Mean Jaccard similarity of original vs. perturbed top-10 lists,
    # restricted to the unstable users.
    jaccard_unstable = calculate_recommendation_similarity(original_user_top10_item[unstable_user], user_top10_item[unstable_user])

    #unstable_user_score1 = top_10_socre.any(axis=1)[unstable_user].mean()

    #print('after attack')
    #print('unstable user: the mean jaccard similarity of original and perturbed top 10 item: {}'.format(jaccard_unstable))

    # Load dataset
    # Reloading from disk discards the edges added for the unstable users, so
    # the stable-user perturbation starts from the original graph.
    data_info_path = './data_processed/data.pkl'
    with open(data_info_path, "rb") as f:
        dataset = pickle.load(f)
    train_g_path = './data_processed/train_g.bin'
    g_list, _ = dgl.load_graphs(train_g_path)
    dataset["train-graph"] = g_list[0]

    g = dataset["train-graph"]
    g.nodes[user_ntype].data["id"] = torch.arange(g.num_nodes(user_ntype))
    g.nodes[item_ntype].data["id"] = torch.arange(g.num_nodes(item_ntype))
    # Modify g: perturb the graph around the stable users.
    g = add_random_edges(g, stable_user, num_edges=perturbation_level)
    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, 32
    )
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        2,
        0.5,
        10,
        3,
        2,
    )
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=2,
    )
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=32,
        collate_fn=collator.collate_test,
        num_workers=2,
    )
    dataloader_it = iter(dataloader)


    # Item representations on the graph perturbed around stable users.
    with torch.no_grad():
        item_batches = torch.arange(g.num_nodes(item_ntype)).split(32)
        h_item_batches = []
        for blocks in dataloader_test:
            for i in range(len(blocks)):
                blocks[i] = blocks[i].to("cuda:0")

            h_item_batches.append(model.get_repr(blocks))
        h_item = torch.cat(h_item_batches, 0)

    stable_h_item = copy.copy(h_item)
    # Replace the old graph in the dataset with the updated graph
    dataset["train-graph"] = g
    mean_score, top_10_socre, user_top10_item = evaluation.evaluate_nn(dataset, h_item, 10, 32)

    # Same similarity measure, restricted to the stable users.
    jaccard_stable = calculate_recommendation_similarity(original_user_top10_item[stable_user], user_top10_item[stable_user])


    #stable_user_score1 = top_10_socre.any(axis=1)[stable_user].mean()
    #print('the stable user score: {}'.format(stable_user_score1))
    #print('stable user: the mean jaccard similarity of original and perturbed top 10 item: {}'.format(jaccard_stable))


    return jaccard_unstable, jaccard_stable, TopNodeList
    




 

if __name__ == '__main__':
    # The ranking is computed only on the very first call to main(); every
    # later call reuses the ranking returned by the previous one.
    sagman_switch = True
    user_ranking = []
    for perturbation_level in range(1, 6):
        unstable_list = []
        stable_list = []
        # 20 repetitions per perturbation level; the reported numbers are means.
        for _ in range(20):
            outcome = main(sagman_switch, user_ranking, perturbation_level)
            jaccard_unstable, jaccard_stable, user_ranking = outcome
            unstable_list.append(jaccard_unstable)
            stable_list.append(jaccard_stable)
            sagman_switch = False

        stable_mean = np.mean(stable_list)
        unstable_mean = np.mean(unstable_list)
        print('perturbation level: {}'.format(perturbation_level))
        print('stable user:   the mean jaccard similarity of original and perturbed top 10 item: {}'.format(stable_mean))
        print('unstable user: the mean jaccard similarity of original and perturbed top 10 item: {}'.format(unstable_mean))

    # Disabled plotting code, kept as an unused string literal.
    '''''
    # Combine the data and create a label
    #data = stable_list + unstable_list
    #labels = ['Stable User'] * len(stable_list) + ['Unstable User'] * len(unstable_list)
    # Create a DataFrame for each list with an 'Index' column and a 'Value' column
    print('unstable: {}'.format(unstable_list))
    print('stable: {}'.format(stable_list))
    df1 = pd.DataFrame({'Index': range(len(unstable_list)), 'Value': unstable_list, 'List': 'Unstable'})
    df2 = pd.DataFrame({'Index': range(len(stable_list)), 'Value': stable_list, 'List': 'Stable'})
    df = pd.concat([df1, df2])
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=df, x='Index', y='Value', hue='List', palette='viridis', linewidth=2.5)
    plt.title('Comparison of Two Lists', fontsize=18, fontweight='bold')
    plt.xlabel('Iteration', fontsize=14)
    plt.ylabel('Jaccard Similarity', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(title='User Type', title_fontsize='13', fontsize='12')

    plt.savefig('comparison_plot.png')

    # Show the plot
    plt.show()
    '''''