import pandas as pd
import os
import numpy as np
import json
import math
import typing
from typing import List,Dict,Any
from transformers import AutoTokenizer,PreTrainedTokenizerFast
from tokenizers import decoders,models,normalizers,pre_tokenizers,processors,trainers,Tokenizer
import nltk
import torch
import torch.nn as nn
from collections import Counter
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer, Linear
import math
import datasets
import dataset
from tools.params import get_params
import argparse
from tqdm import tqdm
import spacy
import pytextrank


def build_cache_path(dataset_name, mid_name, split, tokenizer_name, max_len_name):
    """Return the path of one pre-tokenized cache file.

    The file lives under ``results/cache/tokenized_results/`` and its name is
    the five components joined by underscores, with a ``.pt`` suffix.
    """
    return 'results/cache/tokenized_results/{}_{}_{}_{}_{}.pt'.format(
        dataset_name, mid_name, split, tokenizer_name, max_len_name)


def textrank_sentences(doc):
    """Return a ``[sent_id, text, score]`` record for each summary sentence.

    ``doc`` is a spaCy document produced by a pipeline that contains the
    ``textrank`` component. The phrase and sentence limits are set to the
    document's own phrase and sentence counts, presumably so that every
    sentence is returned. Records keep the order in which ``summary`` yields
    them with ``preserve_order=False`` (by importance score, per the original
    author's note).
    """
    num_phrases = len(list(doc._.phrases))
    num_sents = len(list(doc.sents))
    # NOTE(review): stock pytextrank's summary() yields sentence spans; this
    # loop expects (sent_id, sentence, score) triples -- confirm the installed
    # pytextrank version provides them.
    return [
        [sent_id, str(sentence), score]
        for sent_id, sentence, score in doc._.textrank.summary(
            limit_phrases=num_phrases,
            limit_sentences=num_sents,
            preserve_order=False,
        )
    ]


def add_token_spans(sentences, tokenizer, first_pos=1):
    """Append a ``[start, end]`` token span to every sentence record, in place.

    Args:
        sentences: list of ``[sent_id, text, ...]`` records; each record gets
            one extra trailing element ``[start, end]``. The list order itself
            is left untouched.
        tokenizer: callable returning a mapping with an ``'input_ids'`` entry
            for a sentence string.
        first_pos: position given to the first token of the lowest-numbered
            sentence. Defaults to 1 because position 0 is the [CLS] token.

    Spans are assigned in ascending ``sent_id`` order and are contiguous. The
    end position is exclusive, so a span can be used directly as
    ``tokens[start:end]``.
    """
    start_pos = first_pos
    for record in sorted(sentences, key=lambda rec: rec[0]):
        encoded = tokenizer(
            record[1],
            padding=False,
            truncation=False,
            return_attention_mask=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            return_tensors=None,
        )
        # Drop the first and last ids before counting -- assumes the tokenizer
        # wraps every sentence in exactly one leading and one trailing special
        # token (TODO confirm for the configured tokenizer).
        end_pos = start_pos + len(encoded['input_ids'][1:-1])
        record.append([start_pos, end_pos])
        start_pos = end_pos


def main():
    """Add TextRank sentence data to the pre-tokenized key-phrase caches.

    Reads the config file named by ``--config`` (the placeholder string
    ``'ERROR'`` is used when the flag is omitted), builds the data object, and
    for every available split loads the ``key_phrase_chunk_rep`` and
    ``key_phrase_chunk_rep2`` caches, stores a ``'textrank_data'`` entry on
    each sample, and writes both caches back to the files they came from.

    Raises:
        ValueError: if the data object's attention mode is not one of the
            supported key-phrase modes, or if the two caches of a split do not
            contain the same number of samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='ERROR')
    args = parser.parse_args()

    config = get_params(args.config)
    data_config = config['DATA']
    tokenizer_name = (data_config['tokenizer_type'] + '_'
                      + data_config['tokenizer_name'].replace('/', '_'))
    print("data_config ", data_config)
    daobj = dataset.get_data(data_config, 'gpu', config, pre_cache=False)

    attn_mode = getattr(daobj, 'attn_mode', {'name': 'default', 'param1': 0})
    supported_modes = ('key_phrase_split', 'key_phrase_split2',
                       'key_phrase_chunk_rep', 'key_phrase_chunk_rep2')
    # Explicit raise rather than assert, so the check survives `python -O`.
    if attn_mode['name'] not in supported_modes:
        raise ValueError('Only support key_phrase_split and key_phrase_chunk_rep now')

    mid_name1 = 'key_phrase_chunk_rep'
    mid_name2 = 'key_phrase_chunk_rep2'
    max_len_name = 'whole_doc'
    dataset_name = daobj.global_config['DATA']['dataset_name']

    nlp = spacy.load("en_core_web_sm")
    nlp.add_pipe("textrank")

    for split in ['train', 'val', 'test']:
        if split not in daobj.datasets:
            continue

        # Build each path once so a cache is always saved to the file it was
        # loaded from.
        cache_path1 = build_cache_path(dataset_name, mid_name1, split, tokenizer_name, max_len_name)
        cache_path2 = build_cache_path(dataset_name, mid_name2, split, tokenizer_name, max_len_name)

        # NOTE(review): torch.load unpickles its input; only run this on cache
        # files produced by this project.
        pre_tokenized_cache1 = torch.load(cache_path1)
        pre_tokenized_cache2 = torch.load(cache_path2)

        # zip() would silently stop at the shorter cache and leave the other
        # one partially annotated. Assumes the caches are sized sequences.
        if len(pre_tokenized_cache1) != len(pre_tokenized_cache2):
            raise ValueError(
                'Cache size mismatch for split {!r}: {} vs {} samples'.format(
                    split, len(pre_tokenized_cache1), len(pre_tokenized_cache2)))

        for sample1, sample2 in zip(pre_tokenized_cache1, pre_tokenized_cache2):
            # Only the first cache's text is analysed; the second cache gets
            # the same annotation -- presumably both hold the same documents
            # in the same order (verify against the cache builder).
            ranked_sentences = textrank_sentences(nlp(sample1['sentences']))
            add_token_spans(ranked_sentences, daobj.tokenizer)
            sample1['textrank_data'] = ranked_sentences
            sample2['textrank_data'] = ranked_sentences

        torch.save(pre_tokenized_cache1, cache_path1)
        torch.save(pre_tokenized_cache2, cache_path2)


if __name__ == "__main__":
    main()
            
       

    