import os
import sys
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, PreTrainedTokenizer
from datasets import load_dataset, IterableDataset
from tqdm import tqdm
from scipy.optimize import curve_fit
from typing import List, Tuple, Dict
import json
import pandas as pd

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from configs import Config
from utils import dataclass_from_file, compute_compression_rates, get_base_vocab_size
from fast_compression import batch_lzw_compress
from lzw_tokenizer import LZW_Tokenizer


def compute_token_efficiency(tokenizer: "PreTrainedTokenizer", text: str) -> float:
    """Return the number of UTF-8 bytes of ``text`` per token produced by ``tokenizer``.

    Higher values mean the tokenizer packs ``text`` into fewer tokens.

    Args:
        tokenizer: Any object exposing ``encode(text)`` that returns a sized
            sequence of token ids.
        text: The string to tokenize.

    Returns:
        ``len(text.encode("utf-8")) / len(tokenizer.encode(text))``, or ``0.0``
        when the tokenizer yields no tokens (e.g. empty ``text`` with a
        tokenizer that adds no special tokens) instead of raising
        ``ZeroDivisionError``.

    NOTE(review): ``encode`` is called with its defaults, so any special
    tokens the tokenizer adds by default (such as a BOS token) are counted
    in the denominator.
    """
    num_tokens = len(tokenizer.encode(text))
    if num_tokens == 0:
        # No tokens means no meaningful ratio; report 0.0 rather than crash.
        return 0.0
    num_bytes = len(text.encode("utf-8"))
    return num_bytes / num_tokens


def compute_token_efficiency_for_dataset(
    tokenizer: PreTrainedTokenizer,
    dataset: IterableDataset,
    num_samples: int = 1000,
) -> float:
    """Average bytes-per-token of ``tokenizer`` over the first records of ``dataset``.

    Reads at most ``num_samples`` records from ``dataset`` (each must provide a
    ``"text"`` field), scores every record with non-empty text using
    ``compute_token_efficiency``, and returns the unweighted mean.

    Empty texts are skipped: they contain zero bytes, so their ratio is either
    0.0 or undefined depending on whether the tokenizer adds special tokens.
    Counting them would make rows of a tokenizer comparison differ for reasons
    unrelated to compression.

    NOTE: this is a mean of per-sample ratios (short samples weigh as much as
    long ones), not ``total_bytes / total_tokens``.

    Args:
        tokenizer: Tokenizer passed through to ``compute_token_efficiency``.
        dataset: Any iterable of mappings with a ``"text"`` key.
        num_samples: Maximum number of records to read; values ``<= 0`` read
            nothing.

    Returns:
        The mean bytes-per-token, or ``0.0`` if no non-empty record was read.
    """
    from itertools import islice

    # islice stops before pulling a record once the limit is reached, so a
    # non-positive limit reads nothing (islice rejects negative stops, hence max).
    limit = max(num_samples, 0)

    total_efficiency = 0.0
    count = 0
    for sample in tqdm(
        islice(dataset, limit), desc="Computing token efficiency", total=limit
    ):
        text = sample["text"]
        if not text:
            continue
        total_efficiency += compute_token_efficiency(tokenizer, text)
        count += 1

    return total_efficiency / count if count > 0 else 0.0


def compute_token_efficiency_matrix(
    tokenizers: List[PreTrainedTokenizer],
    datasets: List[IterableDataset],
    num_samples: int = 100,
) -> np.ndarray:
    """Build a ``(len(tokenizers), len(datasets))`` array of bytes-per-token scores.

    Entry ``[i, j]`` is ``compute_token_efficiency_for_dataset`` applied to
    ``tokenizers[i]`` and the first ``num_samples`` records of ``datasets[j]``.

    Each dataset's record prefix is read once up front and reused for every
    tokenizer. A streaming dataset is therefore fetched once rather than once
    per tokenizer. Every tokenizer is also scored on exactly the same records,
    even when a dataset is a one-shot iterator that a second pass would find
    exhausted.

    Args:
        tokenizers: Tokenizers to compare; one matrix row each.
        datasets: Iterables of mappings with a ``"text"`` key; one column each.
        num_samples: Records read from each dataset; values ``<= 0`` read none.

    Returns:
        A float array of shape ``(len(tokenizers), len(datasets))``.
    """
    from itertools import islice

    # Only the scored prefix is materialized, so memory is bounded by
    # num_samples records per dataset.
    limit = max(num_samples, 0)
    sample_prefixes = [list(islice(dataset, limit)) for dataset in datasets]

    token_efficiency_matrix = np.zeros((len(tokenizers), len(sample_prefixes)))
    for i, tokenizer in enumerate(tokenizers):
        for j, samples in enumerate(sample_prefixes):
            token_efficiency_matrix[i, j] = compute_token_efficiency_for_dataset(
                tokenizer, samples, num_samples
            )
    return token_efficiency_matrix


def prettyprint_token_efficiency_matrix(
    token_efficiency_matrix: np.ndarray, column_names: List[str], row_names: List[str]
) -> None:
    """Print ``token_efficiency_matrix`` as a labelled table rounded to 2 decimals.

    ``row_names`` label the rows and ``column_names`` label the columns; their
    lengths must match the matrix's first and second dimensions respectively.
    """
    labelled = pd.DataFrame(
        token_efficiency_matrix,
        index=row_names,
        columns=column_names,
    )
    rounded = labelled.round(2)
    print(rounded)


if __name__ == "__main__":
    # Compare bytes-per-token of several pretrained tokenizers, each with and
    # without an LZW_Tokenizer wrapper, across the zip2zip-1B dataset configs.

    NUM_SAMPLES = 1_000

    # (row label, Hugging Face model id). This single table determines which
    # tokenizers are loaded AND the row order/labels of the printed matrix, so
    # labels cannot drift out of sync with the tokenizer objects.
    BASE_TOKENIZERS = [
        ("Llama-32K", "meta-llama/Llama-2-7b-hf"),
        ("Llama-128K", "meta-llama/Llama-3.2-1B-Instruct"),
        ("Phi-200K", "microsoft/Phi-4-mini-instruct"),
        ("Gemma-256K", "google/gemma-3-1b-it"),
        ("Qwen-150K", "Qwen/Qwen3-0.6B"),
        ("DeepSeek-128K", "deepseek-ai/DeepSeek-V3-Base"),
    ]

    all_tokenizers = []
    tokenizer_names = []
    for label, model_id in BASE_TOKENIZERS:
        base_tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        # Each base tokenizer row is immediately followed by its LZW-wrapped row.
        all_tokenizers += [base_tokenizer, LZW_Tokenizer(base_tokenizer)]
        tokenizer_names += [label, f"{label}-LZW"]

    # Dataset configs of the corpus, loaded with streaming=True; this list also
    # fixes the column order of the printed matrix.
    DATASET_REPO = "XXX/zip2zip-1B"
    dataset_names = ["code", "math", "chat", "multilingual", "knowledge"]
    all_datasets: List[IterableDataset] = [
        load_dataset(DATASET_REPO, split="train", name=name, streaming=True)
        for name in dataset_names
    ]

    token_efficiency_matrix = compute_token_efficiency_matrix(
        all_tokenizers, all_datasets, num_samples=NUM_SAMPLES
    )

    prettyprint_token_efficiency_matrix(
        token_efficiency_matrix, column_names=dataset_names, row_names=tokenizer_names
    )
