import jax.numpy as jnp
from flax.training import train_state
import jax
#import transformers
from transformers import GPT2Tokenizer, GPT2Model, FlaxGPT2LMHeadModel, AutoTokenizer, AutoModelForCausalLM
from solve.models.cvx_relu_mlp import CVX_ReLU_MLP  
from solve.optimizers.cronos import run, init_cronos_state 
from utils.data_utils import tokenize_data
from inference import inference
from utils.dataset_creation import get_dataset_dicts
from jax.extend.backend import get_backend
from jax import lax
import wandb
import time
from jax.lib import xla_bridge
import os
import pickle

# TODO: check dtype of jax bfloat16 everywhere
# JAX version should be 0.4.33 (the installed version is printed below).


# Report, at import time, the platform name of JAX's default backend and the
# installed JAX version so they show up in run logs.
print(get_backend().platform)
print("Running JAX Version =",jax.__version__)


# models: distilGPT2, GPT2, GPT2-M, FlaxGPT2
# methods: DPO, SFT, SimPO, ORPO
# datasets: IMDb, StanfordSHP, edu-tutor

# Embedding cache directory: <current working directory>/embeddings.
# The pipeline reads/writes per-example "<kind>_<i>.pkl" files here.
# Created eagerly at import time (no error if it already exists).
PROJECT_DIR = os.getcwd()
SAVE_DIR = os.path.join(PROJECT_DIR, "embeddings")
os.makedirs(SAVE_DIR, exist_ok=True)  

# Hugging Face model id used to load the tokenizer and base model.
MODEL_NAME = 'distilgpt2'
# NOTE(review): policy_dtype is only referenced by a commented-out
# torch_dtype argument in this file; it is otherwise unused here.
policy_dtype = jnp.bfloat16

# Starts a Weights & Biases run as a side effect of importing this module.
# NOTE(review): the run name says "gpt1" while MODEL_NAME is distilgpt2 —
# confirm the intended run name.
wandb.init(
    project="cvx_dpo",
    # config={
    # "architecture": model,
    # "dataset": dataset,
    # "method": method,
    # }
    name="gpt1_customdataset",
)

# 1: Initialize pre-trained GPT-2 (Hugging Face `GPT2Model`, no LM head)
def initialize_pretrained_model(model_name=None):
    """Load the tokenizer and base GPT-2 model used to compute embeddings.

    Args:
        model_name: Hugging Face model id or local path. Defaults to the
            module-level ``MODEL_NAME`` when ``None``.

    Returns:
        ``(tokenizer, model)``. The tokenizer pads on the left and uses the
        EOS token as its pad token. The model is a ``GPT2Model`` loaded with
        ``low_cpu_mem_usage=True``.
    """
    if model_name is None:
        model_name = MODEL_NAME

    tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
    # Left padding: presumably so the last position of every padded row is a
    # real token when embeddings are taken from the end of the sequence.
    tokenizer.padding_side = "left"
    # GPT-2 style tokenizers usually have no pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): a torch_dtype=policy_dtype argument was disabled here;
    # policy_dtype is a JAX dtype, which a PyTorch model would not accept as is.
    model = GPT2Model.from_pretrained(model_name, low_cpu_mem_usage=True)

    return tokenizer, model


# 3: Enter CRONOS
def _build_preference_dataset(chosen_embeddings, rejected_embeddings):
    """Interleave chosen/rejected embeddings into one labelled design matrix.

    Row ``2*i`` is ``chosen_embeddings[i]`` (label 1) and row ``2*i + 1`` is
    ``rejected_embeddings[i]`` (label 0), so the labels follow the pattern
    ``[1, 0] * n_pairs``.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    chosen = jnp.asarray(chosen_embeddings)
    rejected = jnp.asarray(rejected_embeddings)
    if chosen.shape != rejected.shape:
        raise ValueError(
            "chosen and rejected embeddings must have the same shape, got "
            f"{chosen.shape} and {rejected.shape}"
        )
    n_pairs = chosen.shape[0]
    # (n, ...) x2 -> (n, 2, ...) -> (2n, ...): chosen/rejected rows alternate,
    # which matches the [1, 0] * n label layout below.
    features = jnp.stack([chosen, rejected], axis=1).reshape(
        (2 * n_pairs,) + chosen.shape[1:]
    )
    labels = jnp.array([1, 0] * n_pairs)
    return features, labels


def train_with_cronos(prompts, chosen_embeddings, rejected_embeddings):
    """
    Trains a convex neural network using the CRONOS optimizer in JAX.

    The training set has one row per response: each chosen embedding
    (label 1) is followed by its rejected embedding (label 0). Previously
    ``prompts`` (n rows) was paired with 2n labels, and the response
    embeddings were never used.

    Args:
        prompts: prompt embeddings. They are not used as training features;
            the parameter is kept for interface compatibility.
        chosen_embeddings: embeddings of the preferred responses, one per pair.
        rejected_embeddings: embeddings of the rejected responses, with the
            same shape as ``chosen_embeddings``.

    Returns:
        ``(v, w, metrics)`` as produced by the CRONOS ``run`` routine.

    Raises:
        ValueError: if chosen and rejected embeddings differ in shape.
    """
    features, labels = _build_preference_dataset(chosen_embeddings, rejected_embeddings)
    feature_dim = features.shape[-1]

    # Instantiate CVX_ReLU_MLP model
    model = CVX_ReLU_MLP(
        X=features,  # JAX array of response features, chosen/rejected interleaved
        y=labels,  # Labels (1 for chosen, 0 for rejected), aligned with X rows
        P_S=feature_dim,
        beta=0.1,
        rho=1.0,
        seed=42,
    )
    model.init_model()

    # CRONOS parameters
    admm_params = {
        "rank": 10,  # Rank for preconditioner
        "beta": 0.1,
        "gamma_ratio": 0.5,
        "admm_iters": 100,
        "pcg_iters": 20,
        "check_opt": True,
        "verbose": True,
    }

    # Train using CRONOS
    print("Training with CRONOS ...")
    (v, w), metrics = run(model, admm_params, "classification")

    return v, w, metrics

# log VRAM
def get_vram_usage(device=None):
    """Return the memory currently in use on a JAX device, in MiB.

    Uses the public ``jax.Device.memory_stats()`` API. The previous version
    called ``memory_info()`` on the (deprecated) ``jax.lib.xla_bridge``
    backend object, which does not provide that method.

    Args:
        device: the device to query; defaults to ``jax.devices()[0]``.

    Returns:
        ``bytes_in_use / 1024**2`` as a float. Returns ``None`` when the
        device reports no memory statistics; ``memory_stats()`` may return
        ``None`` or omit ``bytes_in_use``, e.g. on some CPU backends.
    """
    if device is None:
        device = jax.devices()[0]
    stats = device.memory_stats()
    if not stats or "bytes_in_use" not in stats:
        return None
    return stats["bytes_in_use"] / (1024 ** 2)

# log TFLOP/s
def get_tflops(start_time, num_operations, end_time=None):
    """Return achieved throughput in TFLOP/s between two wall-clock times.

    Args:
        start_time: start timestamp in seconds, on the ``time.time()`` clock.
        num_operations: number of floating-point operations performed.
        end_time: optional end timestamp on the same clock; defaults to now.

    Returns:
        ``num_operations / elapsed_seconds / 1e12``.

    Raises:
        ValueError: if the elapsed time is not positive. This can happen with
            a coarse or adjusted wall clock; a rate is meaningless then.
    """
    if end_time is None:
        end_time = time.time()
    elapsed_time = end_time - start_time  # seconds
    if elapsed_time <= 0:
        raise ValueError(
            f"elapsed time must be positive to compute TFLOP/s, got {elapsed_time}"
        )
    return num_operations / (elapsed_time * 1e12)


# ---------------------------------------------------------------------------
# 5: Pipeline

# Embedding kinds cached per dataset example, in the order returned by the loader.
_EMBEDDING_KINDS = ("prompt", "chosen", "rejected")


def _embedding_path(kind, index):
    """Return the cache path of the ``index``-th embedding of ``kind``."""
    return os.path.join(SAVE_DIR, f"{kind}_{index}.pkl")


def _cached_embeddings_exist(n_examples):
    """Return True when every prompt/chosen/rejected pickle is cached.

    An empty dataset counts as "not cached". Otherwise ``all()`` over an
    empty generator would report True, and stacking zero arrays would crash.
    """
    if n_examples == 0:
        return False
    return all(
        os.path.exists(_embedding_path(kind, i))
        for i in range(n_examples)
        for kind in _EMBEDDING_KINDS
    )


def _load_cached_embeddings(n_examples):
    """Load cached embeddings and stack each kind into one JAX array.

    Returns:
        ``(prompts, chosen_embeddings, rejected_embeddings)``, each stacked
        along a new leading axis of length ``n_examples``.

    NOTE: uses ``pickle.load``; SAVE_DIR must only contain files written by
    this pipeline — never load pickles from an untrusted source.
    """
    stacked = []
    for kind in _EMBEDDING_KINDS:
        arrays = []
        for i in range(n_examples):
            with open(_embedding_path(kind, i), "rb") as f:
                arrays.append(jnp.array(pickle.load(f)))
        stacked.append(jnp.stack(arrays))
    prompts, chosen_embeddings, rejected_embeddings = stacked
    return prompts, chosen_embeddings, rejected_embeddings


if __name__ == "__main__":
    # Initialize pre-trained GPT-2
    tokenizer, pretrained_model = initialize_pretrained_model()
    print("Finished loading model!")

    # Prepare dataset
    dataset = get_dataset_dicts()
    print("Finished loading dataset, checking for saved embeddings...")

    # Load embeddings from the on-disk cache, or generate (and save) them.
    n_examples = len(dataset)
    if _cached_embeddings_exist(n_examples):
        prompts, chosen_embeddings, rejected_embeddings = _load_cached_embeddings(n_examples)
        print("Loaded embeddings from disk.")
    else:
        prompts, chosen_embeddings, rejected_embeddings = tokenize_data(
            dataset, tokenizer, pretrained_model, save_dir=SAVE_DIR
        )
        print("Generated and saved embeddings to disk.")

    print(f"Prompts dtype: {prompts.dtype}")

    # Train with CRONOS optimizer
    print("Starting training with CRONOS...")
    v, w, metrics = train_with_cronos(prompts, chosen_embeddings, rejected_embeddings)

    # TODO: log accuracy / loss (metrics) / TFLOP/s to wandb before finishing.
    wandb.finish()

    # Perform inference
    test_prompt = "Tell me about Troy"
    result = inference(test_prompt, tokenizer, pretrained_model, v, w)
    print(f"Result for '{test_prompt}': {result}")


# ------------------------------------------------------------------------------
# convex DPO loss for completeness
def convex_dpo_loss(predictions_chosen, predictions_rejected, beta):
    """
    Computes the convexified DPO loss for evaluation.

        margin = predictions_chosen - predictions_rejected
        loss   = mean(log(1 + exp(-beta * margin))) = mean(-log sigmoid(beta * margin))

    The loss is non-negative and convex in the margin. It decreases as chosen
    scores exceed rejected scores. The previous version returned the negative
    of this quantity: a concave log-likelihood whose minimization favours
    rejected responses.

    Args:
        predictions_chosen: scores for the preferred responses.
        predictions_rejected: scores for the rejected responses (broadcastable
            against ``predictions_chosen``).
        beta: temperature scaling the margin.

    Returns:
        Scalar mean loss.
    """
    logits = predictions_chosen - predictions_rejected
    # softplus(x) = log(1 + exp(x)), computed without overflow for large |x|
    # (the naive form gives inf once exp(-beta * logits) overflows).
    return jnp.mean(jax.nn.softplus(-beta * logits))
