import time
import argparse
import torch
import keyword
from transformers import (
    LlamaTokenizerFast,
    LlamaTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)

from datasets import load_dataset
from src.utils import device, touch, load_models, dtype, warmup, torch_timer

import numpy as np
import keyword
import nltk
from nltk.tokenize import word_tokenize
from nltk.parse import RecursiveDescentParser
import keyword
import networkx as nx
from nltk.tokenize import word_tokenize
import matplotlib.pyplot as plt


def _resolve_dataset(dataset):
    """Return the loading recipe for a supported dataset name.

    Both the short names (``"xsum"``, ``"finance-alpaca"``) and the full hub
    paths (``"EdinburghNLP/xsum"``, ``"gbharti/finance-alpaca"``) are accepted.

    Returns:
        A tuple ``(load_args, split, separator_text, fields)`` where
        ``load_args`` are the positional arguments for ``load_dataset``,
        ``split`` is the split to read, ``separator_text`` is the text whose
        token ids delimit segments, and ``fields`` are the two example keys
        that are concatenated into one sample.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    aliases = {
        "EdinburghNLP/xsum": "xsum",
        "gbharti/finance-alpaca": "finance-alpaca",
    }
    recipes = {
        # Code: segments are lines.
        "openai_humaneval": (
            ("openai_humaneval",),
            "test",
            "\n",
            ("prompt", "canonical_solution"),
        ),
        # Summarization: segments are delimited by "." rather than newlines.
        "xsum": (
            ("EdinburghNLP/xsum",),
            "test",
            ".",
            ("document", "summary"),
        ),
        "gsm8k": (
            ("gsm8k", "main"),
            "test",
            "\n",
            ("question", "answer"),
        ),
        "finance-alpaca": (
            ("gbharti/finance-alpaca",),
            "train",
            "\n",
            ("instruction", "output"),
        ),
    }

    key = aliases.get(dataset, dataset)
    if key not in recipes:
        supported = ", ".join(sorted(recipes))
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of: {supported}"
        )
    return recipes[key]


def _separator_distances(ids, separator_ids):
    """Return the token gaps between consecutive separator occurrences.

    The first gap is measured from position 0 of ``ids``.
    """
    gaps = []
    previous = 0
    for position, token_id in enumerate(ids):
        if token_id in separator_ids:
            gaps.append(position - previous)
            previous = position
    return gaps


def eda_distances(dataset, tokenizer_path):
    """Exploratory data analysis of a specific dataset.

    Tokenizes every sample of the dataset and measures how many tokens lie
    between consecutive separator tokens (newline for most datasets, "." for
    xsum), as that is the granularity at which humans are expected to
    provide feedback. Prints the total number of tokens seen.

    Args:
        dataset: One of ``"openai_humaneval"``, ``"xsum"``, ``"gsm8k"``,
            ``"finance-alpaca"``, or the hub path of xsum / finance-alpaca.
        tokenizer_path: Name or path passed to ``from_pretrained``.

    Returns:
        A dict mapping each observed distance to its number of occurrences,
        ordered from most to least frequent.

    Raises:
        ValueError: If ``dataset`` is not supported.
    """
    from collections import Counter

    # Validate the dataset name first so a typo fails before any download.
    load_args, split, separator_text, fields = _resolve_dataset(dataset)

    if tokenizer_path == "hf-internal-testing/llama-tokenizer":
        tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # The first id is dropped on the assumption that the tokenizer prepends a
    # BOS token. NOTE(review): some tokenizers may also emit a leading-space
    # token here, which would then count as a separator too -- confirm.
    separator_ids = frozenset(tokenizer.encode(separator_text)[1:])

    first_field, second_field = fields
    examples = load_dataset(*load_args)[split]
    samples = [example[first_field] + example[second_field] for example in examples]

    total_tokens = 0
    distances = []
    for sample in samples:
        ids = tokenizer.encode(sample)
        total_tokens += len(ids)
        distances.extend(_separator_distances(ids, separator_ids))

    # Frequency of each distance, most frequent first (no plotting).
    freq = dict(Counter(distances).most_common())
    print(f"Total tokens: {total_tokens}")
    return freq


def eda():
    """Run the distance analysis on every supported dataset.

    Uses the Llama test tokenizer and passes the short dataset names that
    ``eda_distances`` dispatches on; hub paths such as "EdinburghNLP/xsum"
    would not match its "xsum" / "finance-alpaca" branches.
    """
    tokenizer_path = "hf-internal-testing/llama-tokenizer"

    for dataset_name in ("openai_humaneval", "xsum", "gsm8k", "finance-alpaca"):
        eda_distances(dataset_name, tokenizer_path)
