import os
import json
import argparse
import logging
import time
import numpy as np
import torch
import dgl
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, RandomSampler, DataLoader
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from gensim.models import Word2Vec
from cpp_tokenizer import tokenize_c
from java_tokenizer import tokenize_java
from tqdm import tqdm

import pdb

# Logging: records at INFO and above go both to a per-run log file, named
# with the start timestamp and created in the current working directory, and
# to the console.
_LOG_FILE = f"word2vec_training_{time.strftime('%Y%m%d%H%M%S')}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(_LOG_FILE), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

def preprocess_data(language, datasource):
    """Tokenize every code snippet in ``datasource`` into a flat list of sequences.

    Args:
        language: ``'java'`` (tokenized with ``tokenize_java``) or ``'c'``
            (tokenized with ``tokenize_c``).
        datasource: DataFrame with a ``nodes_codes`` column. Each cell is an
            iterable of code strings (presumably one per graph node; confirm
            against the dataset schema).

    Returns:
        A flat list with one token sequence per code snippet, across all rows.
        This is the ``sentences`` input passed to ``Word2Vec``.

    Raises:
        ValueError: if ``language`` is not a supported language.
    """
    tokenizers = {'java': tokenize_java, 'c': tokenize_c}
    # Fail fast with a clear message. Previously an unknown language left
    # the per-snippet token variable unbound and crashed inside the loop.
    try:
        tokenize = tokenizers[language]
    except KeyError:
        raise ValueError(
            f"Unsupported language {language!r}; expected one of {sorted(tokenizers)}"
        ) from None

    all_tokens = []
    logger.info("Start preprocessing the data...")

    # Iterate the column itself rather than `.loc[i]` over `range(len(df))`,
    # which only works when the frame has a default 0..n-1 index.
    for codes in tqdm(datasource['nodes_codes'], desc="Handle file nodes"):
        all_tokens.extend(tokenize(code) for code in codes)

    logger.info(f"The preprocessing is completed, and a total of {len(all_tokens)} token sequences are extracted")
    return all_tokens

def visualize_stats(all_tokens, output_dir):
    """Save a figure describing the token-sequence length distribution.

    The figure is written to ``<output_dir>/token_stats.png`` and has two panels:

    * left: histogram of sequence lengths (50 bins), with the mean and the
      median marked by dashed vertical lines;
    * right: with sequences sorted by length, the cumulative share of all
      tokens contributed by sequences up to each length.

    Args:
        all_tokens: Iterable of token sequences; only ``len()`` of each is used.
        output_dir: Directory the PNG is written to (it is not created here).

    If ``all_tokens`` is empty, a warning is logged and no file is written.
    """
    token_lengths = [len(tokens) for tokens in all_tokens]
    if not token_lengths:
        # Nothing to plot: np.mean([]) is nan and cum_sum[-1] would raise IndexError.
        logger.warning("No token sequences to visualize; skipping token statistics plot")
        return

    mean_len = np.mean(token_lengths)
    median_len = np.median(token_lengths)
    sorted_lengths = sorted(token_lengths)
    cum_sum = np.cumsum(sorted_lengths)
    # Guard the 0/0 case where every sequence is empty.
    if cum_sum[-1] > 0:
        cum_share = cum_sum / cum_sum[-1]
    else:
        cum_share = np.zeros(len(cum_sum))
    out_path = os.path.join(output_dir, 'token_stats.png')

    fig = plt.figure(figsize=(12, 6))
    try:
        plt.subplot(1, 2, 1)
        plt.hist(token_lengths, bins=50, color='skyblue', edgecolor='black')
        plt.title('Token Sequence Length Distribution')
        plt.xlabel('Sequence Length')
        plt.ylabel('Frequency of Occurrence')
        plt.axvline(mean_len, color='red', linestyle='dashed', linewidth=1, label=f'average {mean_len:.1f}')
        plt.axvline(median_len, color='green', linestyle='dashed', linewidth=1, label=f'median {median_len}')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(sorted_lengths, cum_share, 'b-')
        plt.title('The cumulative proportion of Token length')
        plt.xlabel('Sequence Length')
        plt.ylabel('Cumulative Proportion')

        plt.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure even if drawing/saving fails so figures don't accumulate.
        plt.close(fig)
    logger.info(f"The Token statistics graph has been saved to {out_path}")

def _write_training_report(model, report_path, elapsed_seconds, top_k=10):
    """Write a plain-text summary of a trained Word2Vec ``model`` to ``report_path``."""
    # Explicit UTF-8: the report contains non-ASCII unit suffixes ('秒', '次')
    # that a platform default encoding such as cp1252 cannot encode.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"Training completion time: {time.ctime()}\n")
        f.write(f"Total time consumption: {elapsed_seconds:.2f}秒\n")
        f.write(f"Vocabulary list size: {len(model.wv)}\n")
        f.write(f"The total number of tokens in the corpus: {model.corpus_total_words}\n")

        f.write("\nHigh-frequency Vocabulary:\n")
        # `index_to_key` is ordered by descending frequency (Word2Vec's default
        # sorted_vocab=1), so its head holds the most frequent words. The values
        # of `key_to_index` are vocabulary positions, not counts, so sorting
        # them in reverse would list the *rarest* words instead.
        for word in model.wv.index_to_key[:top_k]:
            f.write(f"{word}: {model.wv.get_vecattr(word, 'count')}次\n")

        if hasattr(model, 'get_latest_training_loss'):
            f.write(f"Final training loss: {model.get_latest_training_loss()}\n")


def train_word2vec(
    *,
    project='bears_and_bugs',
    language='java',
    input_files=None,
    output_root='../saved_models/word2vec',
    vector_size=128,
    window=10,
    min_count=5,
    workers=12,
    epochs=10,
    sg=1,
    plot_stats=False,
):
    """Train a Word2Vec model on tokenized code snippets and save it with a report.

    Called with no arguments, it trains on the Java ``bears_and_bugs`` dataset
    with the original hyperparameters. For the C dataset, call e.g.
    ``train_word2vec(project='SNOL_joern', language='c')``.

    Args:
        project: Dataset name. It selects the default input files and names
            the output directory ``<output_root>/<project>_word2vec``.
        language: Language passed to ``preprocess_data`` (``'java'`` or ``'c'``).
        input_files: JSONL files to concatenate. Defaults to
            ``../data/processed/<project>/{train,valid,test}.jsonl``.
        output_root: Parent directory for the output directory.
        vector_size, window, min_count, workers, epochs, sg: Forwarded to
            ``gensim.models.Word2Vec`` (``sg=1`` selects skip-gram).
        plot_stats: If true, also save a token-length statistics plot.

    Outputs (in the output directory): ``word2vec_model.bin`` (full model),
    ``word_vectors.kv`` (KeyedVectors only) and ``training_report.txt``.

    Raises:
        ValueError: if ``input_files`` is empty or ``language`` is unsupported.
    """
    start_time = time.time()

    if input_files is None:
        input_files = [
            f'../data/processed/{project}/{split}.jsonl'
            for split in ('train', 'valid', 'test')
        ]
    input_files = list(input_files)
    if not input_files:
        raise ValueError("input_files must contain at least one JSONL file")

    logger.info(f'Start processing the {project} dataset...')
    output_dir = os.path.join(output_root, f'{project}_word2vec')
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Output directory: {output_dir}")

    datasource = pd.concat(
        (pd.read_json(file, lines=True) for file in input_files),
        ignore_index=True
    )

    all_tokens = preprocess_data(language, datasource)
    if plot_stats:
        visualize_stats(all_tokens, output_dir)

    logger.info("training word2vec...")
    # Implicit concatenation: the previous backslash continuation embedded the
    # next line's indentation (a run of spaces) into the log message.
    logger.info(
        f"Model parameters: vector_size={vector_size}, window={window}, "
        f"min_count={min_count}, workers={workers}, epochs={epochs}, sg={sg}"
    )

    model = Word2Vec(
        sentences=all_tokens,
        vector_size=vector_size,
        window=window,
        min_count=min_count,  # was mistakenly `window`, silently using 10 instead of 5
        workers=workers,
        epochs=epochs,
        sg=sg,
        compute_loss=True
    )

    model_path = os.path.join(output_dir, 'word2vec_model.bin')
    model.save(model_path)
    logger.info(f"The model has been saved to: {model_path}")

    vector_path = os.path.join(output_dir, 'word_vectors.kv')
    model.wv.save(vector_path)
    logger.info(f"The word vector has been saved to: {vector_path}")

    report_path = os.path.join(output_dir, 'training_report.txt')
    _write_training_report(model, report_path, time.time() - start_time)
    logger.info(f"The training report has been saved to {report_path}")
    
if __name__ == '__main__':
    # Train with the default (Java bears_and_bugs) configuration; see
    # train_word2vec's keyword arguments for other datasets/hyperparameters.
    train_word2vec()