import os.path

import numpy as np
import torch
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
import pickle
from explainers.gnn_explainer import GNNExplainer
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import os
import torch
import pandas as pd
from math import log
import scipy.sparse as sp
from dataset import TextDataset
from torch.optim.swa_utils import AveragedModel
from model import TextGNN
from eval import eval,MovingAverage

from pprint import pprint

from torch_geometric.utils.subgraph import subgraph,k_hop_subgraph
from transformers import BertTokenizer
from collections import OrderedDict
from collections import defaultdict
from tqdm import tqdm

from os.path import join, exists

from sklearn import metrics



model_name = 'google-bert/bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)

# Inverse of the tokenizer vocabulary: each vocab value (token id) is mapped
# back to its key (token), keeping the vocabulary's iteration order.
vocab_dict = tokenizer.vocab
reversed_dict = OrderedDict((token_id, token) for token, token_id in vocab_dict.items())
print("")


def _read_csv_files(files):
    """Read one CSV path into a DataFrame, or a list of CSV paths into a list of DataFrames."""
    if isinstance(files, list):
        return [pd.read_csv(path) for path in files]
    return pd.read_csv(files)


def _file_spec_mentions(file_spec, key):
    """Return True if `key` is a substring of the path, or of any path in a list of paths.

    A plain ``key in file_spec`` is a substring test for a string but an
    element-equality test for a list, which never matches a full file path.
    """
    if isinstance(file_spec, (list, tuple)):
        return any(key in path for path in file_spec)
    return key in file_spec


def _human_machine_lines(frame, human_col, machine_col, limit=None):
    """Return (texts, labels): the rows of `human_col` followed by the rows of `machine_col`.

    `limit` caps the rows taken from each column (None keeps all). The labels
    are 'human' for the first group and 'chatgpt' for the second.
    """
    human = frame[human_col].values[:limit].tolist()
    machine = frame[machine_col].values[:limit].tolist()
    return human + machine, ['human'] * len(human) + ['chatgpt'] * len(machine)


def _generation_model_lines(frames, limit=None):
    """Concatenate the 'generation' texts and 'model' labels of several frames, each capped at `limit` rows."""
    lines, labels = [], []
    for frame in frames:
        lines.extend(frame['generation'].values[:limit].tolist())
        labels.extend(frame['model'].values[:limit].tolist())
    return lines, labels


def _label_to_id(label):
    """Map a label string to 0 if it contains 'human', else 1 if it contains 'chatgpt'."""
    if 'human' in label:
        return 0
    if 'chatgpt' in label:
        return 1
    raise ValueError(f"unrecognised label {label!r}: expected it to contain 'human' or 'chatgpt'")


def build_text_graph_dirdataset(train_file, valid_file,test_file, catch_path=None,use_cach = True, window_size=20,
                                max_train_per_class=1000):
    """Build the document/token graph dataset for a train/valid/test split, or load it from cache.

    The dataset family is detected from the train path(s):

    * 'hc3'  -> columns ``human_answers`` / ``chatgpt_answers``
    * 'm4'   -> columns ``human_text`` / ``machine_text``
    * 'raid' -> a list of CSVs per split, columns ``generation`` / ``model``

    Args:
        train_file, valid_file, test_file: a CSV path each, or (raid) a list of
            CSV paths whose rows are concatenated in list order.
        catch_path: cache file. If ``use_cach`` is true and the file exists, its
            ``torch.load`` result is returned directly. Otherwise the freshly
            built dataset is saved there. Saving is skipped when None.
        use_cach: whether to try the cache first.
        window_size: sliding-window length for the token co-occurrence edges.
        max_train_per_class: row cap for each train column (hc3/m4) or each
            train CSV (raid). It is also applied to each raid test CSV.

    Returns:
        The ``TextDataset`` (or whatever object the cache file holds).
        Document labels are 0 for 'human' and 1 for 'chatgpt'.

    Raises:
        ValueError: unknown dataset family, or a label containing neither
            'human' nor 'chatgpt'.
    """
    if use_cach and catch_path is not None and os.path.exists(catch_path):
        return torch.load(catch_path)

    train_data = _read_csv_files(train_file)
    val_data = _read_csv_files(valid_file)
    test_data = _read_csv_files(test_file)

    if _file_spec_mentions(train_file, 'hc3'):
        columns = ('human_answers', 'chatgpt_answers')
    elif _file_spec_mentions(train_file, 'm4'):
        columns = ('human_text', 'machine_text')
    else:
        columns = None

    if columns is not None:
        train_lines, train_labels = _human_machine_lines(train_data, *columns, limit=max_train_per_class)
        val_lines, val_labels = _human_machine_lines(val_data, *columns)
        # Both halves of the test split are taken from the test CSV.
        test_lines, test_labels = _human_machine_lines(test_data, *columns)
    elif _file_spec_mentions(train_file, 'raid'):
        train_lines, train_labels = _generation_model_lines(train_data, limit=max_train_per_class)
        val_lines, val_labels = _generation_model_lines(val_data)
        # NOTE(review): test CSVs are capped at max_train_per_class rows each,
        # unlike validation; kept as in the original — confirm it is intended.
        test_lines, test_labels = _generation_model_lines(test_data, limit=max_train_per_class)
    else:
        raise ValueError('no such dataset')

    # Document order is train, then val, then test; train_split maps position -> split name.
    train_doc_list = [line.strip() for line in train_lines]
    doc_list = train_doc_list + [line.strip() for line in val_lines + test_lines]
    split_names = ['train'] * len(train_lines) + ['val'] * len(val_lines) + ['test'] * len(test_lines)
    train_split = dict(enumerate(split_names))

    labels = [_label_to_id(label) for label in train_labels + val_labels + test_labels]
    print(max(labels), min(labels))
    assert len(labels) == len(doc_list)

    # Vocabulary comes from the training documents only; every document is tokenized.
    word_freq, doc_lists_ = get_vocab2(train_doc_list, doc_list)
    vocab = list(word_freq.keys())

    _, word_doc_freq = build_word_doc_edges1(doc_lists_)
    word_id_map = {word: i for i, word in enumerate(vocab)}

    sparse_graph = build_edges1(doc_lists_, word_id_map, vocab, word_doc_freq, window_size, direction=True)
    docs_dict = dict(enumerate(doc_lists_))
    textdataset = TextDataset('raid_m', sparse_graph, labels, vocab, word_id_map, docs_dict, None,
                              train_test_split=train_split)
    if catch_path is not None:
        torch.save(textdataset, catch_path)
    return textdataset

def build_edges1(doc_list, word_id_map, vocab, word_doc_freq, window_size=20,direction=True):
    """Build the symmetric document/word adjacency matrix of the text graph.

    Node layout: ids ``0 .. len(doc_list) - 1`` are documents, and id
    ``len(doc_list) + k`` is the word ``vocab[k]``.

    Edges:

    * word-word: positive PMI of token co-occurrence over sliding windows of
      ``window_size`` tokens. A document no longer than the window forms a
      single window. With ``direction=True`` each co-occurring pair is counted
      only as (earlier token, later token). With ``direction=False`` both
      orders are counted. Pairs with PMI <= 0 are dropped.
    * document-word: the word's count in the document times
      ``log(num_docs / (word_doc_freq[word] + 1e-8))``, with one entry per
      distinct in-vocabulary word of the document.

    The matrix is finally symmetrised as the element-wise maximum of A and
    A.T, so the result is symmetric even when ``direction=True``.

    Args:
        doc_list: documents, each a sequence of tokens.
        word_id_map: token -> index into ``vocab``. Tokens missing from it
            are ignored.
        vocab: index -> token (expected to be the inverse of ``word_id_map``).
        word_doc_freq: token -> number of documents containing it.
        window_size: sliding-window length.
        direction: count only forward (earlier -> later) pairs when True.

    Returns:
        A scipy sparse matrix of shape
        ``(len(doc_list) + len(vocab), len(doc_list) + len(vocab))``.
    """
    num_docs = len(doc_list)

    # Sliding windows over every document.
    windows = []
    for words in tqdm(doc_list):
        doc_length = len(words)
        if doc_length <= window_size:
            windows.append(words)
        else:
            for start in range(doc_length - window_size + 1):
                windows.append(words[start:start + window_size])

    # Number of windows each token appears in (counted once per window).
    word_window_freq = defaultdict(int)
    for window in tqdm(windows):
        for word in set(window):
            word_window_freq[word] += 1

    # Co-occurrence counts of word-index pairs inside a window.
    word_pair_count = defaultdict(int)
    for window in tqdm(windows):
        for i in range(1, len(window)):
            word_i = window[i]
            if word_i not in word_id_map:
                continue
            word_i_id = word_id_map[word_i]
            for j in range(i):
                word_j = window[j]
                if word_j not in word_id_map:
                    continue
                word_j_id = word_id_map[word_j]
                if word_i_id == word_j_id:
                    continue
                # window[j] precedes window[i]: the forward pair is (j, i).
                word_pair_count[(word_j_id, word_i_id)] += 1
                if not direction:
                    word_pair_count[(word_i_id, word_j_id)] += 1

    row = []
    col = []
    weight = []

    # Word-word edges weighted by positive PMI.
    num_window = len(windows)
    for (i, j), count in tqdm(word_pair_count.items()):
        word_freq_i = word_window_freq[vocab[i]]
        word_freq_j = word_window_freq[vocab[j]]
        pmi = log((1.0 * count / num_window) /
                  (1.0 * word_freq_i * word_freq_j / (num_window * num_window)))
        if pmi <= 0:
            continue
        row.append(num_docs + i)
        col.append(num_docs + j)
        weight.append(pmi)

    # Term counts of each (document, word) pair.
    doc_word_freq = defaultdict(int)
    for i, words in enumerate(doc_list):
        for word in words:
            if word in word_id_map:
                doc_word_freq[(i, word_id_map[word])] += 1

    # Document-word edges weighted by tf * idf.
    for i, words in enumerate(doc_list):
        doc_word_set = set()
        for word in words:
            # Emit one entry per distinct word. scipy sums duplicate (row, col)
            # entries, so one entry per occurrence would turn tf*idf into tf**2*idf.
            if word not in word_id_map or word in doc_word_set:
                continue
            doc_word_set.add(word)
            word_id = word_id_map[word]
            freq = doc_word_freq[(i, word_id)]
            row.append(i)
            col.append(num_docs + word_id)
            idf = log(1.0 * num_docs /
                      (word_doc_freq[vocab[word_id]]+1E-8))
            weight.append(freq * idf)

    number_nodes = num_docs + len(vocab)
    adj_mat = sp.csr_matrix((weight, (row, col)), shape=(number_nodes, number_nodes))
    # Element-wise max(A, A.T): take A.T wherever it is larger than A.
    adj = adj_mat + adj_mat.T.multiply(adj_mat.T > adj_mat) - adj_mat.multiply(adj_mat.T > adj_mat)
    return adj


def get_vocab2(tain_text_list, text_list):
    """Tokenize texts with the module-level ``tokenizer``.

    Args:
        tain_text_list: training texts. Every input id the tokenizer produces
            for them is counted.
        text_list: all texts to convert to input-id lists.

    Returns:
        (word_freq, doc_lists): a defaultdict(int) mapping input id -> count
        over ``tain_text_list``, and a list with the input ids of each text
        in ``text_list``, in order.
    """
    word_freq = defaultdict(int)
    for text in tqdm(tain_text_list):
        for token_id in tokenizer(text)["input_ids"]:
            word_freq[token_id] += 1

    doc_lists = [tokenizer(text)["input_ids"] for text in tqdm(text_list)]
    return word_freq, doc_lists


def build_word_doc_edges1(doc_list):
    """Index which documents contain each token.

    Args:
        doc_list: sequence of documents, each an iterable of hashable tokens.

    Returns:
        (words_in_docs, word_doc_freq): a defaultdict(set) mapping token ->
        set of positions (in ``doc_list``) of the documents containing it, and
        a dict mapping token -> number of such documents.
    """
    words_in_docs = defaultdict(set)
    for doc_pos, tokens in enumerate(doc_list):
        for token in tokens:
            words_in_docs[token].add(doc_pos)

    word_doc_freq = {token: len(positions) for token, positions in words_in_docs.items()}
    return words_in_docs, word_doc_freq


def _dataset_files(dataset_name, domain, model_name):
    """Resolve the output/cache directory and the train/valid/test CSV paths for one experiment.

    Returns ``(save_dir, train_file, valid_file, test_file)``. For 'raid' each
    split is a ``[machine_csv, human_csv]`` list; otherwise it is a single
    path. ``save_dir`` is created if it does not exist.

    Raises:
        ValueError: if ``dataset_name`` is not 'hc3', 'm4' or 'raid'.
    """
    splits = ('train', 'valid', 'test')
    data_dir = f'../dataset_lists/{dataset_name}'
    if dataset_name == 'hc3':
        save_dir = f'{data_dir}/unweighted_{domain}'
        split_files = [f'{data_dir}/{domain}_{split}.csv' for split in splits]
    elif dataset_name == 'm4':
        save_dir = f'{data_dir}/unweighted_{domain}_{model_name}'
        split_files = [f'{data_dir}/{domain}_{model_name}_{split}.csv' for split in splits]
    elif dataset_name == 'raid':
        save_dir = f'{data_dir}/unweighted_{domain}_{model_name}_human'
        split_files = [[f'{data_dir}/{domain}_{model_name}_{split}.csv',
                        f'{data_dir}/{domain}_human_{split}.csv'] for split in splits]
    else:
        raise ValueError('dataset name error')
    os.makedirs(save_dir, exist_ok=True)
    train_file, valid_file, test_file = split_files
    return save_dir, train_file, valid_file, test_file


def _resolve_top_k(num_edges, top_k_ratio):
    """Convert a top-k spec to an edge count.

    Values < 1 are fractions of ``num_edges`` (floored); values >= 1 are
    absolute counts.
    """
    if top_k_ratio < 1:
        return int(num_edges * top_k_ratio)
    return int(top_k_ratio)


def _masked_prediction(model, x, edge_index, edge_attr, mask, idx, label):
    """Predict node ``idx`` on the subgraph restricted to the edges selected by ``mask``.

    Args:
        mask: boolean numpy array over the columns of ``edge_index`` and the
            rows of ``edge_attr``.

    Returns:
        A tuple ``(kept_edges, prob, correct)``:

        * ``kept_edges``: number of kept edges.
        * ``prob``: softmax probability of class 1 at ``idx``.
        * ``correct``: 1 if the argmax prediction equals ``label``, else 0.
    """
    graph = Data(x=x, edge_index=edge_index[:, mask], edge_attr=edge_attr[mask])
    with torch.no_grad():
        pred = model.forward1(graph)[idx]
    prob = torch.softmax(pred, dim=0)[1].item()
    correct = int(torch.argmax(pred, dim=-1) == label)
    return int(mask.sum()), prob, correct


def local_explain_preseve_evaluation_unweights(dataset_name = 'm4',domain = 'wikihow',model_name='chatgpt',
                                               topk_ratios=(1, 5, 10, 50, 100, 0.001, 0.005, 0.01, 0.05, 0.1,
                                                            0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
                                               device='cuda:2'):
    """Evaluate edge-retention / edge-removal fidelity on every test document node.

    The function loads (or builds) the cached text graph for one
    dataset/domain. It restores the trained TextGNN from
    ``<save_dir>/direct_graph_gnnmodel_best.pth`` and prints validation and
    test metrics. Then, for every test node, it extracts the node's 1-hop
    subgraph and ranks its edges. For each entry of ``topk_ratios`` it:

    * "maintain": keeps only the first ``top_k`` edges of the ranking;
    * "remove": drops those ``top_k`` edges and keeps the rest;

    and records the class-1 probability and the correctness of the prediction
    at that node. Mean edge counts, accuracy and AUROC are printed per ratio.

    Args:
        dataset_name: 'hc3', 'm4' or 'raid'.
        domain: domain name used in the CSV and cache paths.
        model_name: generator name used in m4/raid paths (unused for hc3).
        topk_ratios: values < 1 are fractions of the subgraph's edges
            (floored); values >= 1 are absolute edge counts.
        device: torch device for node features, graph and model.

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    save_dir, train_file, valid_file, test_file = _dataset_files(dataset_name, domain, model_name)

    textdataset = build_text_graph_dirdataset(train_file=train_file, valid_file=valid_file, test_file=test_file,
                                              catch_path=save_dir + '/' + 'dataset_cache.pth', use_cach=True,
                                              window_size=20)

    words_length = textdataset.words_length
    train_dataset, val_dataset, test_dataset = textdataset.tvt_split()
    for split_dataset in (train_dataset, val_dataset, test_dataset):
        split_dataset.init_node_feats('one_hot_node_init', device)

    pyg_graph = train_dataset.get_pyg_graph(device)

    model = TextGNN(pred_type='softmax', node_embd_type='gcn',
                    num_layers=2, layer_dim_list=[words_length, 64, 2], act='silu',
                    bn=True, num_labels=2, class_weights=True, dropout=True, use_weight=False)
    model = AveragedModel(model)
    model.to(device)
    model.load_state_dict(torch.load(save_dir + '/direct_graph_gnnmodel_best.pth', map_location=device))
    model = model.module
    # nn.Module starts in training mode. The model is built with bn=True and
    # dropout=True, so switch to eval mode for inference-time behaviour.
    model.eval()

    with torch.no_grad():
        _, preds_val = model(pyg_graph, val_dataset)
        _, preds_test = model(pyg_graph, test_dataset)
        eval_res_val = eval(preds_val, val_dataset)
        eval_res_test = eval(preds_test, test_dataset)
        pprint(eval_res_val)
        pprint(eval_res_test)

    explainer = GNNExplainer(device, model)

    print(dataset_name, domain)
    print('test')

    num_ratios = len(topk_ratios)
    pred_label = []
    total_ori_nodes = []
    total_ori_edges = []
    ori_correct = 0
    maintain_correct = [0] * num_ratios
    remove_correct = [0] * num_ratios
    maintain_pred_values = [[] for _ in range(num_ratios)]
    remove_pred_values = [[] for _ in range(num_ratios)]
    maintain_sub_edges = [[] for _ in range(num_ratios)]
    remove_sub_edges = [[] for _ in range(num_ratios)]

    for idx in tqdm(test_dataset.node_ids):
        label = test_dataset.labels[idx]
        pred_label.append(label)
        subset, edge_index, _, edge_mask = k_hop_subgraph(node_idx=idx, num_hops=1,
                                                          edge_index=pyg_graph.edge_index)
        edge_attr = pyg_graph.edge_attr[edge_mask]
        total_ori_nodes.append(len(subset))
        total_ori_edges.append(edge_index.shape[1])

        new_pyg_graph = Data(x=pyg_graph.x, edge_index=edge_index, edge_attr=edge_attr)
        edge_imp = explainer.explain_nodes(new_pyg_graph, idx=idx)
        # NOTE(review): the explainer's scores are discarded and replaced by
        # uniform random scores (a random baseline). Confirm this is intended
        # before reporting the results as explainer fidelity.
        edge_imp = np.random.rand(edge_imp.shape[0])
        num_scored = edge_imp.shape[0]

        with torch.no_grad():
            ori_pred = model.forward1(new_pyg_graph)[idx]
        ori_correct += int(torch.argmax(ori_pred, dim=-1) == label)

        # Ascending order: the first top_k positions are the lowest-scored edges.
        order = np.argsort(edge_imp, axis=0)

        for k, top_k_ratio in enumerate(topk_ratios):
            top_k = _resolve_top_k(num_scored, top_k_ratio)
            keep_mask = np.zeros(num_scored, dtype=bool)
            keep_mask[order[:top_k]] = True

            kept, prob, correct = _masked_prediction(model, pyg_graph.x, edge_index, edge_attr,
                                                     keep_mask, idx, label)
            maintain_sub_edges[k].append(kept)
            maintain_pred_values[k].append(prob)
            maintain_correct[k] += correct

            kept, prob, correct = _masked_prediction(model, pyg_graph.x, edge_index, edge_attr,
                                                     ~keep_mask, idx, label)
            remove_sub_edges[k].append(kept)
            remove_pred_values[k].append(prob)
            remove_correct[k] += correct

    num_nodes = len(pred_label)
    if num_nodes == 0:
        print('no test nodes to evaluate')
        return

    mean_ori_edges = np.mean(total_ori_edges)
    print("ori_nodes : ", np.mean(total_ori_nodes))
    print("ori_edges : ", mean_ori_edges)
    print("total ori correct:", ori_correct / num_nodes)
    print(" ")

    for k, top_k_ratio in enumerate(topk_ratios):
        # Mean number of edges removed from the 1-hop subgraph.
        print("remove sub_edges : ", mean_ori_edges - np.mean(remove_sub_edges[k]), end=" ")
        auroc = metrics.roc_auc_score(pred_label, remove_pred_values[k])
        print(f"remove {top_k_ratio}, total correct:", remove_correct[k] / num_nodes, "auroc:", auroc)

    for k, top_k_ratio in enumerate(topk_ratios):
        print("maintain sub_edges : ", np.mean(maintain_sub_edges[k]), end=" ")
        auroc = metrics.roc_auc_score(pred_label, maintain_pred_values[k])
        print(f"maintain {top_k_ratio}, total correct:", maintain_correct[k] / num_nodes, "auroc:", auroc)



if __name__=="__main__":
    # (dataset, domains) pairs, evaluated in this order.
    experiments = (
        ('hc3', ('open_qa', 'wiki_csai', 'medicine', 'finance')),
        ('m4', ('reddit', 'peerread', 'wikihow', 'arxiv')),
        ('raid', ('recipes', 'books', 'poetry', 'reviews')),
    )
    for dataset, domains in experiments:
        for domain in domains:
            local_explain_preseve_evaluation_unweights(dataset_name=dataset, domain=domain)

## MORF
