import jax.numpy as jnp
from jax import jit, vmap
from jax.nn import relu
from .dataset_creation import get_dataset_dicts
import torch
import pickle
import os


# # dataset_dicts = get_dataset_dicts() 


# # from cronos gpt 2 exp same idea
# hidden_size = model(**batch).last_hidden_state.size(-1)
# last_hidden_states_np = model(**batch).last_hidden_state.detach().cpu().numpy()

def _embed_text(text, tokenizer, model):
    """Return the attention-masked mean of ``model``'s last hidden state for ``text``.

    The result is a NumPy array with the sequence axis (dim 1) averaged away;
    its leading axis is the tokenizer's batch axis (1 for a single string).
    """
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Move inputs onto the model's device when it exposes one (e.g. Hugging Face
    # models keep a ``.device`` property); otherwise leave them as produced.
    device = getattr(model, "device", None)
    if device is not None:
        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}
    # Inference only: don't build an autograd graph we would immediately discard.
    with torch.no_grad():
        hidden = model(**tokens).last_hidden_state
    mask = tokens.get("attention_mask")
    if mask is None:
        pooled = hidden.mean(dim=1)
    else:
        # Exclude padding positions from the average; a plain mean lets pad
        # tokens dilute the embedding whenever the tokenizer had to pad.
        mask = mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    # .cpu() is required before .numpy() if the model ran on an accelerator.
    return pooled.detach().cpu().numpy()


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    # SECURITY: pickle.load can execute arbitrary code; only point ``save_dir``
    # at a cache directory that you created and trust.
    with open(path, "rb") as f:
        return pickle.load(f)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` atomically (write a temp file, then rename it)."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    # os.replace is atomic on one filesystem, so an interrupted run never leaves
    # a truncated file that a later run would mistake for a valid cache entry.
    os.replace(tmp_path, path)


def tokenize_data(dataset, tokenizer, model, save_dir):
    """
    Tokenize and embed every entry of ``dataset`` and return the embeddings as JAX arrays.

    For each entry, the ``"prompt"``, ``"chosen"`` and ``"rejected"`` texts are
    tokenized, run through ``model`` and mean-pooled over the sequence axis
    (padding positions excluded via the attention mask when one is returned).
    Embeddings are cached under ``save_dir`` as ``prompt_{idx}.pkl``,
    ``chosen_{idx}.pkl`` and ``rejected_{idx}.pkl`` and reused on later runs.

    NOTE: the cache is keyed only by position in ``dataset``. Clear ``save_dir``
    when the dataset, tokenizer or model changes, or stale embeddings are returned.

    Args:
        dataset: Iterable of mappings with ``"prompt"``, ``"chosen"`` and
            ``"rejected"`` text fields.
        tokenizer: Callable tokenizer returning PyTorch tensors. If its
            ``pad_token`` is None it is set to ``eos_token`` (mutates the tokenizer).
        model: Callable whose output exposes ``last_hidden_state``.
            NOTE(review): presumably this should be in eval mode so dropout does
            not randomize cached embeddings — confirm at the call site.
        save_dir: Cache directory; created if it does not exist.

    Returns:
        Tuple ``(prompts, embeddings_chosen, embeddings_rejected)`` of JAX arrays,
        each stacking the per-entry pooled embeddings along a new leading axis.

    Raises:
        ValueError: If ``dataset`` is empty (there is nothing to stack).
    """
    # Padding requires a pad token; fall back to EOS as is common for GPT-style tokenizers.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    os.makedirs(save_dir, exist_ok=True)

    fields = ("prompt", "chosen", "rejected")
    collected = {field: [] for field in fields}

    for idx, entry in enumerate(dataset):
        paths = {field: os.path.join(save_dir, f"{field}_{idx}.pkl") for field in fields}
        if all(os.path.exists(path) for path in paths.values()):
            embeddings = {field: _load_pickle(paths[field]) for field in fields}
        else:
            embeddings = {field: _embed_text(entry[field], tokenizer, model) for field in fields}
            for field in fields:
                _dump_pickle(embeddings[field], paths[field])
        for field in fields:
            collected[field].append(jnp.array(embeddings[field]))

    if not collected["prompt"]:
        raise ValueError("tokenize_data() needs a non-empty dataset to stack embeddings.")

    # Stack the per-entry arrays into one JAX array per field.
    return tuple(jnp.stack(collected[field]) for field in fields)





# # 2: Prepare Dataset
# def prepare_dataset():
#     """
#     Simulates preference data for the prompt "Tell me about Troy."
#     """
#     dataset = [
#         {
#             "prompt": "Tell me about Troy",
#             "chosen": "Troy was an ancient city located in present-day Turkey.",
#             "rejected": "Troy is a name."
#         },
#         # Add more samples as needed...
#     ] # dataset is a list of dicts
#     return dataset