import torch
import os.path

import networkx as nx
import scipy.sparse as sp
from math import log
from collections import defaultdict
import pandas as pd
from dataset import TextDataset
# from utils import get_corpus_path
from os.path import join, exists
from tqdm import tqdm
from transformers import BertTokenizer
import torch
from torch.optim.swa_utils import AveragedModel
from model import TextGNN
from eval import eval,MovingAverage

from pprint import pprint



# Pretrained BERT checkpoint. Its lower-casing tokenizer turns each document
# into the token ids that get_vocab2 returns and that become the graph's word
# nodes.
model_name = 'google-bert/bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(
    model_name,
    do_lower_case=True,
)

def _read_csv_split(paths, as_list):
    """Read ``paths`` as one CSV, or as a list of CSVs when ``as_list`` is true."""
    if as_list:
        return [pd.read_csv(path) for path in paths]
    return pd.read_csv(paths)


def _paired_column_split(frame, human_col, machine_col, llm, limit=None):
    """Return ``(texts, labels)`` from one frame holding both classes.

    The result is up to ``limit`` texts from ``human_col`` (label 'human'),
    followed by up to ``limit`` texts from ``machine_col`` (label ``llm``).
    ``limit=None`` keeps every row.
    """
    human = frame[human_col].values[:limit].tolist()
    machine = frame[machine_col].values[:limit].tolist()
    return human + machine, ['human'] * len(human) + [llm] * len(machine)


def _stacked_file_split(frames, limit=None):
    """Concatenate the 'generation' texts and 'model' labels of each frame.

    At most ``limit`` rows are taken per frame; ``None`` keeps every row.
    """
    lines, labels = [], []
    for frame in frames:
        lines.extend(frame['generation'].values[:limit].tolist())
        labels.extend(frame['model'].values[:limit].tolist())
    return lines, labels


def _encode_labels(raw_labels, llm):
    """Map each label to 0 if it contains 'human', else to 1 if it contains ``llm``.

    Raises:
        ValueError: for a label matching neither. Silently dropping such a
            label would misalign labels with documents.
    """
    encoded = []
    for label in raw_labels:
        if 'human' in label:
            encoded.append(0)
        elif llm in label:
            encoded.append(1)
        else:
            raise ValueError(f"label {label!r} contains neither 'human' nor {llm!r}")
    return encoded


def build_text_graph_dirdataset(train_file, valid_file,test_file, catch_path=None,use_cach = True, window_size=20,
                                max_train_per_class=1000,llm='chatgpt'):
    """Load or build the word/document graph dataset for train/valid/test splits.

    Args:
        train_file, valid_file, test_file: CSV path(s).
            A single path selects the layout by substring: 'hc3' or 'm4' in
            the train path. A list of paths per split selects the 'raid'
            layout, used when ``train_file[0]`` contains 'raid'.
        catch_path: file used to cache the built dataset with
            ``torch.save``. ``None`` disables both loading and saving.
        use_cach: when true and ``catch_path`` exists, return the cached
            object instead of rebuilding.
        window_size: sliding-window length for word co-occurrence edges.
        max_train_per_class: row cap for the training data only (per
            column for hc3/m4, per file for raid). Valid/test are not capped.
        llm: label string for machine-written text. Labels containing
            'human' map to 0; otherwise labels containing ``llm`` map to 1.

    Returns:
        The ``TextDataset`` built from the graph, or whatever object was
        stored at ``catch_path``.

    Raises:
        ValueError: if the layout cannot be recognised from ``train_file``,
            or a label matches neither 'human' nor ``llm``.
    """
    cache_enabled = catch_path is not None
    if use_cach and cache_enabled and os.path.exists(catch_path):
        # NOTE(review): torch.load unpickles arbitrary objects, so only point
        # catch_path at files written by this function. torch>=2.6 defaults to
        # weights_only=True, which may refuse to unpickle a TextDataset.
        # TODO confirm against the installed torch version.
        return torch.load(catch_path)

    # A list train_file means every split is given as a list of CSV paths.
    as_list = isinstance(train_file, list)
    train_data = _read_csv_split(train_file, as_list)
    val_data = _read_csv_split(valid_file, as_list)
    test_data = _read_csv_split(test_file, as_list)

    if 'hc3' in train_file:
        columns = ('human_answers', 'chatgpt_answers')
    elif 'm4' in train_file:
        columns = ('human_text', 'machine_text')
    elif as_list and 'raid' in train_file[0]:
        columns = None
    else:
        raise ValueError('no such dataset')

    if columns is None:
        # raid: one file per source; the 'model' column already holds the label.
        train_lines, train_labels = _stacked_file_split(train_data, max_train_per_class)
        val_lines, val_labels = _stacked_file_split(val_data)
        test_lines, test_labels = _stacked_file_split(test_data)
    else:
        human_col, machine_col = columns
        train_lines, train_labels = _paired_column_split(
            train_data, human_col, machine_col, llm, max_train_per_class)
        val_lines, val_labels = _paired_column_split(val_data, human_col, machine_col, llm)
        test_lines, test_labels = _paired_column_split(test_data, human_col, machine_col, llm)

    # Documents are numbered train, then val, then test. train_split records
    # which split each document index belongs to.
    doc_list = []
    train_split = {}
    for split_name, lines in (('train', train_lines), ('val', val_lines), ('test', test_lines)):
        for line in lines:
            train_split[len(doc_list)] = split_name
            doc_list.append(line.strip())
    train_doc_list = doc_list[:len(train_lines)]

    labels = _encode_labels(train_labels + val_labels + test_labels, llm)
    print(max(labels),min(labels))

    if len(labels) != len(doc_list):
        raise ValueError(f'{len(labels)} labels for {len(doc_list)} documents')

    # The vocabulary comes from the training texts only; every text is encoded.
    word_freq, encoded_docs = get_vocab2(train_doc_list, doc_list)
    vocab = list(word_freq.keys())

    _, word_doc_freq = build_word_doc_edges1(encoded_docs)
    word_id_map = {word: i for i, word in enumerate(vocab)}

    sparse_graph = build_edges1(encoded_docs, word_id_map, vocab, word_doc_freq, window_size,direction=True)
    docs_dict = dict(enumerate(encoded_docs))
    textdataset = TextDataset('raid_m', sparse_graph, labels, vocab, word_id_map, docs_dict, None,
                       train_test_split=train_split)
    if cache_enabled:
        torch.save(textdataset,catch_path)
    return textdataset

def build_edges1(doc_list, word_id_map, vocab, word_doc_freq, window_size=20,direction=True):
    """Build the document/word adjacency matrix of the text graph.

    Node ``i < len(doc_list)`` is document ``i``. Node ``len(doc_list) + k``
    is the word with id ``k`` in ``word_id_map``, and ``vocab[k]`` is its token.

    Edges:
    * word-word: positive PMI of co-occurrence inside sliding windows of
      ``window_size`` tokens. With ``direction=True`` a pair is only counted
      from the earlier token to the later token of a window; otherwise it is
      counted in both orientations.
    * doc-word: (occurrences of the word in the document) *
      log(num_docs / document frequency of the word). Each (doc, word)
      pair is emitted once.

    Tokens missing from ``word_id_map`` are ignored. The returned matrix is
    symmetrised as the element-wise maximum of A and A^T.

    Args:
        doc_list: token sequences, one per document.
        word_id_map: token -> word id (the caller numbers ids 0..len(vocab)-1).
        vocab: word id -> token.
        word_doc_freq: token -> number of documents containing it.
        window_size: sliding-window length.
        direction: count word pairs one way (True) or both ways (False).

    Returns:
        A scipy sparse matrix of shape (num_docs + len(vocab), num_docs + len(vocab)).
    """
    num_docs = len(doc_list)

    # Sliding windows; a document no longer than the window is a single window.
    windows = []
    for tokens in tqdm(doc_list):
        doc_length = len(tokens)
        if doc_length <= window_size:
            windows.append(tokens)
        else:
            for start in range(doc_length - window_size + 1):
                windows.append(tokens[start: start + window_size])

    # Number of windows each token appears in.
    word_window_freq = defaultdict(int)
    for window in tqdm(windows):
        for token in set(window):
            word_window_freq[token] += 1

    # Co-occurrence counts of word-id pairs within a window.
    word_pair_count = defaultdict(int)
    for window in tqdm(windows):
        # Look up vocabulary ids once per window (None = out of vocabulary)
        # instead of once per pair inside the quadratic loop.
        ids = [word_id_map.get(token) for token in window]
        for i in range(1, len(ids)):
            later_id = ids[i]
            if later_id is None:
                continue
            for j in range(i):
                earlier_id = ids[j]
                if earlier_id is None or earlier_id == later_id:
                    continue
                if direction:
                    word_pair_count[(earlier_id, later_id)] += 1
                else:
                    word_pair_count[(later_id, earlier_id)] += 1
                    word_pair_count[(earlier_id, later_id)] += 1

    row = []
    col = []
    weight = []

    # PMI as word-word weights; non-positive PMI edges are dropped.
    num_window = len(windows)
    for (i, j), count in tqdm(word_pair_count.items()):
        word_freq_i = word_window_freq[vocab[i]]
        word_freq_j = word_window_freq[vocab[j]]
        pmi = log((1.0 * count / num_window) /
                  (1.0 * word_freq_i * word_freq_j / (num_window * num_window)))
        if pmi <= 0:
            continue
        row.append(num_docs + i)
        col.append(num_docs + j)
        weight.append(pmi)

    # tf-idf as doc-word weights. Each (doc, word) pair must be emitted once:
    # csr_matrix sums duplicate coordinates, so emitting it per occurrence
    # would turn freq * idf into freq**2 * idf.
    for doc_idx, tokens in enumerate(doc_list):
        term_freq = defaultdict(int)  # word id -> count, in first-seen order
        for token in tokens:
            word_id = word_id_map.get(token)
            if word_id is not None:
                term_freq[word_id] += 1
        for word_id, freq in term_freq.items():
            idf = log(1.0 * num_docs /
                      (word_doc_freq[vocab[word_id]] + 1E-8))
            row.append(doc_idx)
            col.append(num_docs + word_id)
            weight.append(freq * idf)

    number_nodes = num_docs + len(vocab)
    adj_mat = sp.csr_matrix((weight, (row, col)), shape=(number_nodes, number_nodes))
    # Element-wise max(A, A^T): where A^T > A take A^T, otherwise keep A.
    adj = adj_mat + adj_mat.T.multiply(adj_mat.T > adj_mat) - adj_mat.multiply(adj_mat.T > adj_mat)
    return adj


def get_vocab2(tain_text_list, text_list):
    """Tokenise texts with the module-level ``tokenizer``.

    Args:
        tain_text_list: training texts; only these contribute to the token counts.
        text_list: every text to encode (the caller passes train, validation
            and test texts together).

    Returns:
        ``(word_freq, doc_lists)``:
        - ``word_freq`` is a ``defaultdict(int)`` mapping each token id to its
          number of occurrences across ``tain_text_list``, with keys in
          first-seen order.
        - ``doc_lists`` holds the ``input_ids`` of every text in
          ``text_list``, in order.
    """
    token_counts = defaultdict(int)
    for text in tqdm(tain_text_list):
        for token_id in tokenizer(text)["input_ids"]:
            token_counts[token_id] += 1

    encoded_docs = [tokenizer(text)["input_ids"] for text in tqdm(text_list)]
    return token_counts, encoded_docs


def build_word_doc_edges1(doc_list):
    """Index which documents each token occurs in.

    Args:
        doc_list: sequence of token sequences, one per document.

    Returns:
        ``(words_in_docs, word_doc_freq)``:
        - ``words_in_docs`` is a ``defaultdict(set)`` from token to the
          indices of the documents containing it.
        - ``word_doc_freq`` maps each token to the number of such documents.
    """
    words_in_docs = defaultdict(set)
    for doc_idx, tokens in enumerate(doc_list):
        for token in tokens:
            words_in_docs[token].add(doc_idx)

    word_doc_freq = {token: len(doc_ids) for token, doc_ids in words_in_docs.items()}
    return words_in_docs, word_doc_freq

# model_name='davinci'
# model_name='cohere'
# model_name='dolly'
# model_name='bloomz'

# domain = 'reddit'
# domain = 'peerread'
# domain = 'arxiv'

# domain = 'recipes'
# domain = 'poetry'
# domain = 'reviews'

# model_name = 'llama-chat'
# model_name = 'mpt'
# model_name = 'gpt4'
# model_name = 'mistral'


def main(dataset_name = 'raid',domain = 'reviews',model_name='llama-chat',
         device='cuda:2', num_epochs=5000, stats_only=True):
    """Build or load the graph dataset for one configuration, print its size,
    and optionally train a GCN on it.

    Args:
        dataset_name: 'hc3', 'm4' or 'raid'; selects the file layout under
            ``../dataset_lists``.
        domain: domain prefix of the CSV file names.
        model_name: generator name. It is part of the m4/raid file names and
            is passed to ``build_text_graph_dirdataset`` as ``llm``.
        device: torch device for node features and training.
        num_epochs: maximum number of training epochs.
        stats_only: when True (the default, matching the previous behavior),
            return right after printing split sizes, token count and edge
            count, without training.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    print(dataset_name,domain,model_name)
    # validation_metric = 'accuracy'
    validation_metric = 'auroc'

    if dataset_name == 'hc3':
        data_dir = '../dataset_lists/hc3'
        save_dir = f'{data_dir}/unweighted_{domain}'
        train_file = f'{data_dir}/{domain}_train.csv'
        valid_file = f'{data_dir}/{domain}_valid.csv'
        test_file = f'{data_dir}/{domain}_test.csv'
    elif dataset_name == 'm4':
        data_dir = '../dataset_lists/m4'
        save_dir = f'{data_dir}/unweighted_{domain}_{model_name}'
        train_file = f'{data_dir}/{domain}_{model_name}_train.csv'
        valid_file = f'{data_dir}/{domain}_{model_name}_valid.csv'
        test_file = f'{data_dir}/{domain}_{model_name}_test.csv'
    elif dataset_name == 'raid':
        data_dir = '../dataset_lists/raid'
        save_dir = f'{data_dir}/unweighted_{domain}_{model_name}_human'
        # Machine and human texts live in separate files. The machine file is
        # listed first, which fixes the document order inside the graph.
        train_file = [f'{data_dir}/{domain}_{model_name}_train.csv',
                      f'{data_dir}/{domain}_human_train.csv']
        valid_file = [f'{data_dir}/{domain}_{model_name}_valid.csv',
                      f'{data_dir}/{domain}_human_valid.csv']
        test_file = [f'{data_dir}/{domain}_{model_name}_test.csv',
                     f'{data_dir}/{domain}_human_test.csv']
    else:
        raise  ValueError('dataset name error')
    os.makedirs(save_dir, exist_ok=True)

    textdataset =  build_text_graph_dirdataset(train_file=train_file,valid_file=valid_file,test_file=test_file,
                                               catch_path = save_dir +'/'+'dataset_cache.pth',use_cach=True,
                                               window_size=20,llm=model_name)

    words_length = textdataset.words_length
    train_dataset, val_dataset, test_dataset = textdataset.tvt_split()
    for split in (train_dataset, val_dataset, test_dataset):
        split.init_node_feats('one_hot_node_init', device)
    pyg_graph = train_dataset.get_pyg_graph('cpu')

    print(dataset_name)
    print("training:", len(train_dataset.docs), " valid: ", len(val_dataset.docs)," test: ", len(test_dataset.docs))
    print("tokens:",words_length," edges: ", pyg_graph.edge_index.shape[1])
    if stats_only:
        return

    moving_avg = MovingAverage(1000, validation_metric != 'loss')
    best_path = save_dir + '/direct_graph_gnnmodel_best.pth'

    base_model = TextGNN(pred_type='softmax', node_embd_type='gcn',
            num_layers=2, layer_dim_list=[words_length,64,2], act='silu',
            bn=True, num_labels=2, class_weights=True, dropout=True,use_weight=False)
    base_model.to(device)
    # The optimizer steps base_model, and swa_model accumulates the running
    # average of its weights. Stepping the AveragedModel and then averaging
    # it with itself would leave the "average" equal to the live weights.
    # use_buffers=True also averages buffers; presumably these include
    # BatchNorm running stats from bn=True, which eval mode relies on.
    # TODO confirm TextGNN uses torch.nn BatchNorm.
    swa_model = AveragedModel(base_model, use_buffers=True)
    swa_model.to(device)
    # swa_model.load_state_dict(torch.load(best_path, map_location='cpu'))
    pyg_graph = train_dataset.get_pyg_graph(device)
    optimizer = torch.optim.Adam(base_model.parameters(), lr=5E-4)

    def evaluate():
        """Run swa_model on the val and test splits in eval mode without gradients."""
        # eval() disables dropout and uses stored norm statistics, so val/test
        # forwards do not behave like (or update state as) training steps.
        swa_model.eval()
        with torch.no_grad():
            val_out = swa_model(pyg_graph, val_dataset)
            test_out = swa_model(pyg_graph, test_dataset)
        return val_out, test_out

    (_, preds_val), (_, preds_test) = evaluate()
    pprint(eval(preds_val, val_dataset))
    pprint(eval(preds_test, test_dataset))

    best_results_val = 0
    best_results_test = 0
    saved_checkpoint = False
    for epoch in range(num_epochs):
        base_model.train()
        optimizer.zero_grad()
        loss, _ = base_model(pyg_graph, train_dataset)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        swa_model.update_parameters(base_model)

        (val_loss, preds_val), (_, preds_test) = evaluate()
        val_loss = val_loss.item()
        eval_res_val = eval(preds_val, val_dataset)
        print("Epoch: {:04d}, Train Loss: {:.5f}".format(epoch, loss))
        print("Val Loss: {:.5f}".format(val_loss))

        eval_res_val["loss"] = val_loss

        eval_res_test = eval(preds_test, test_dataset)
        moving_avg.add_to_moving_avg(eval_res_val[validation_metric])
        if best_results_val <= eval_res_val[validation_metric]:
            torch.save(swa_model.state_dict(), best_path)
            saved_checkpoint = True
            best_results_val = eval_res_val[validation_metric]
            best_results_test = eval_res_test[validation_metric]

            if best_results_val >= 1.0:
                break

        print("current eval: ", eval_res_val[validation_metric])
        print("current test: ", eval_res_test[validation_metric])
        print("best eval: ", best_results_val)
        print("best test: ", best_results_test)

    # Only reload a checkpoint written by this run; a stale file from an
    # earlier run must not be picked up when no epoch saved one.
    if saved_checkpoint:
        swa_model.load_state_dict(torch.load(best_path, map_location=device))
    (_, preds_val), (_, preds_test) = evaluate()
    pprint(eval(preds_val, val_dataset))
    pprint(eval(preds_test, test_dataset))

if __name__=="__main__":
    # Sweep every (dataset, domain, generator) combination in order: all M4
    # runs first, then all RAID runs. The 'wikihow' domain is currently left
    # out of the M4 sweep.
    sweeps = (
        ('m4', ('reddit', 'peerread', 'arxiv'),
         ('davinci', 'cohere', 'dolly', 'bloomz')),
        ('raid', ('recipes', 'poetry', 'reviews'),
         ('llama-chat', 'gpt4', 'mpt', 'mistral')),
    )
    for sweep_dataset, sweep_domains, sweep_generators in sweeps:
        for sweep_domain in sweep_domains:
            for generator in sweep_generators:
                main(dataset_name=sweep_dataset, domain=sweep_domain, model_name=generator)

    # For a single run with main()'s defaults, call main() instead.


