import argparse
import argparse
import functools
import os
import pickle
import re
import string
import sys

from nltk.corpus import stopwords
from sklearn.feature_extraction.text import CountVectorizer

# Translation table and digit regex are built once at import time instead of on
# every call (preprocess_text runs once per document).
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)
_DIGITS_RE = re.compile(r'\d+')


@functools.lru_cache(maxsize=None)
def _english_stopwords():
    """Return NLTK's English stopwords as a frozenset, loading the corpus once.

    A failed load (e.g. the corpus is not downloaded yet) raises and is not
    cached, so a later call after ``nltk.download('stopwords')`` succeeds.
    """
    return frozenset(stopwords.words('english'))


def preprocess_text(text):
    """Normalize *text* for bag-of-words vectorization.

    Lowercases, removes ``string.punctuation`` characters and digit runs,
    splits on whitespace and drops NLTK English stopwords.

    Args:
        text: Raw document string.

    Returns:
        The remaining words joined by single spaces.
    """
    text = text.lower().translate(_PUNCTUATION_TABLE)
    text = _DIGITS_RE.sub('', text)
    english_stopwords = _english_stopwords()
    return ' '.join(word for word in text.split() if word not in english_stopwords)

def remove_unwanted_function_pointer(vectorizer):
    """Detach the custom preprocessor from a fitted vectorizer before pickling.

    The object is modified in place and also returned, so the call can be
    chained. Because ``preprocessor`` becomes ``None``, later ``transform``
    calls on this object will not run ``preprocess_text``.
    """
    setattr(vectorizer, "preprocessor", None)
    return vectorizer
    
def process_imdb(vocab_size=20000):
    """Load the Keras IMDb training reviews and fit a binary bag-of-words vectorizer.

    Args:
        vocab_size: Passed as ``num_words`` to ``imdb.load_data`` and as
            ``max_features`` to the ``CountVectorizer``.

    Returns:
        ``(vectorizer_X, X_train)``: the fitted ``CountVectorizer`` and the
        binary document-term matrix of the training reviews.
    """
    import nltk
    from tensorflow.keras.datasets import imdb

    # preprocess_text reads the NLTK stopword corpus; make sure it is available
    # (process_1billion does the same).
    nltk.download('stopwords', quiet=True)

    print("Loading IMDb dataset from Keras...")
    # Keras reserves ids 0..index_from for padding/start/OOV/unused markers and
    # shifts real word ids up by index_from. The defaults are spelled out
    # because the decoding below depends on them.
    index_from = 3
    (X_train_tokenized, _), (_, _) = imdb.load_data(
        num_words=vocab_size, start_char=1, oov_char=2, index_from=index_from
    )

    word_index = imdb.get_word_index()
    index_to_word = {v + index_from: k for k, v in word_index.items()}

    # Skip reserved marker ids rather than decoding them to "[START]"/"[UNK]":
    # preprocess_text strips the brackets, which turned the markers into the
    # ordinary words "start"/"unk" (the start marker begins every review).
    X_train_raw = [
        " ".join(
            index_to_word[i]
            for i in review
            if i > index_from and i in index_to_word
        )
        for review in X_train_tokenized
    ]

    vectorizer_X = CountVectorizer(
        preprocessor=preprocess_text,
        max_features=vocab_size,
        binary=True
    )
    X_train = vectorizer_X.fit_transform(X_train_raw)

    print("Vectorization complete.")
    print("Shape of X_train:", X_train.shape)

    return vectorizer_X, X_train

def process_1billion(num_words=40000, input_file="train_v2.txt"):
    """Fit a binary bag-of-words vectorizer on a one-sentence-per-line text file.

    Args:
        num_words: ``max_features`` for the ``CountVectorizer``.
        input_file: UTF-8 text file; each non-blank line is one sentence.

    Returns:
        ``(vectorizer_X, X)``: the fitted ``CountVectorizer`` and the binary
        document-term matrix, one row per non-blank line.
    """
    import nltk

    # preprocess_text needs the NLTK stopword corpus.
    nltk.download('stopwords', quiet=True)

    print(sys.version)
    # Stream line by line instead of read().split("\n"): the old form kept the
    # whole file string and its split list in memory at the same time.
    with open(input_file, encoding='utf-8') as f:
        sentences = [line.rstrip("\n") for line in f if line.strip()]
    print(f"Number of sentences: {len(sentences)}")

    vectorizer_X = CountVectorizer(
        preprocessor=preprocess_text,
        max_features=num_words,
        binary=True
    )
    X = vectorizer_X.fit_transform(sentences)
    current_vocab_size = len(vectorizer_X.vocabulary_)
    print(f"Vocabulary size: {current_vocab_size}")
    print("Vectorisation completed")

    return vectorizer_X, X

def _build_dataset(db_name):
    """Vectorize the dataset named *db_name* ('imdb' or '1billion')."""
    if db_name == "imdb":
        return process_imdb(vocab_size=20000)
    return process_1billion(num_words=40000, input_file="train_v2.txt")


def _rows_to_tokens(X, vocabulary):
    """Return, for each row of sparse matrix *X*, the words of its non-zero columns."""
    reverse_vocab = {index: word for word, index in vocabulary.items()}
    csr = X.tocsr()
    # Slice indptr/indices directly: iterating a sparse matrix row by row builds
    # a new 1xN matrix per row, which is very slow on millions of sentences.
    return [
        [reverse_vocab[index] for index in csr.indices[start:end]]
        for start, end in zip(csr.indptr[:-1], csr.indptr[1:])
    ]


def main(dataset_name=None):
    """Vectorize a dataset and pickle the artifacts under ``data/<db_name>/``.

    Args:
        dataset_name: ``"imdb"`` or ``"1billion"``. When ``None`` the user is
            prompted to choose interactively. Any other value prints an error
            and returns without doing anything.

    Writes ``X.pickle`` (document-term matrix), ``vectorizer_X.pickle`` (the
    fitted vectorizer with its preprocessor set to ``None``) and
    ``tokenized_sentences.pickle`` (the vocabulary words present in each row).
    """
    if dataset_name is None:
        print("Select the dataset:")
        print("1. IMDB")
        print("2. 1 Billion Word")

        choice = input("Enter the number corresponding to your choice: ")
        db_name = {"1": "imdb", "2": "1billion"}.get(choice.strip())
        if db_name is None:
            print("Invalid choice. Please select a valid dataset.")
            return
    elif dataset_name in ("imdb", "1billion"):
        db_name = dataset_name
    else:
        print("Invalid dataset name. Please select either 'imdb' or '1billion'.")
        return

    # Both the interactive and the named path build the data here, so passing
    # dataset_name no longer reaches the pickling code with nothing defined.
    vectorizer_X, X = _build_dataset(db_name)
    vectorizer_X = remove_unwanted_function_pointer(vectorizer_X)

    out_dir = os.path.join("data", db_name)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "X.pickle"), "wb") as f_X:
        pickle.dump(X, f_X, protocol=4)
    with open(os.path.join(out_dir, "vectorizer_X.pickle"), "wb") as f:
        pickle.dump(vectorizer_X, f)

    tokenized_sentences = _rows_to_tokens(X, vectorizer_X.vocabulary_)
    with open(os.path.join(out_dir, "tokenized_sentences.pickle"), "wb") as f:
        pickle.dump(tokenized_sentences, f, protocol=4)

if __name__ == "__main__":
    # Command-line entry point: without --dataset_name, main() prompts interactively.
    # (The former module-level `input = lambda ...` override was dead code: main()
    # only calls input() when no dataset name is given, and it shadowed a builtin.)
    parser = argparse.ArgumentParser(
        description="Vectorize a text dataset and pickle the results under data/<name>/."
    )
    parser.add_argument(
        '--dataset_name',
        type=str,
        required=False,
        help="Name of the dataset to use ('imdb' or '1billion'); prompts when omitted",
    )
    args = parser.parse_args()
    main(args.dataset_name)