import jax
import jax.numpy as jnp
from transformers import GPT2Tokenizer, GPT2Model
from solve.models.cvx_relu_mlp import CVX_ReLU_MLP  
from solve.optimizers.cronos import run, init_cronos_state 
from utils.data_utils import tokenize_data
from inference import inference
from utils.dataset_creation import get_dataset_dicts
from jax.extend.backend import get_backend

# Report the platform of the backend JAX selected (presumably "cpu", "gpu"
# or "tpu" -- confirm against the JAX docs) so the run log shows where the
# computation is executed.
print(get_backend().platform)

# JAX version should be 0.4.33
# (informational only: the version is printed here, not enforced).
print("Running JAX Version =",jax.__version__)

# models: distilGPT, GPT2, GPT2-M


# # 1: Initialize Pre-trained GPT-2
def initialize_pretrained_model():
    """
    Load the pre-trained "gpt2" tokenizer and base model.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``GPT2Tokenizer`` (created with
        ``clean_up_tokenization_spaces=False``) followed by the ``GPT2Model``.
    """
    checkpoint = "gpt2"
    # The tokenizer is loaded first, then the model weights.
    return (
        GPT2Tokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=False),
        GPT2Model.from_pretrained(checkpoint),
    )


# # 3: Define Convex Neural Network and Train with CRONOS
def train_with_cronos(prompts, chosen_embeddings, rejected_embeddings):
    """
    Fit a CVX_ReLU_MLP on the prompt features using the CRONOS solver.

    Args:
        prompts: Sequence of per-sample arrays; they are stacked with
            ``jnp.stack`` to form the model's ``X``.
        chosen_embeddings: Indexable collection of arrays. Only the size of
            the last axis of its first entry is read (used as the feature
            dimension).
        rejected_embeddings: Accepted for interface symmetry; not read by
            this function.

    Returns:
        A ``(v, w, metrics)`` tuple, where ``(v, w)`` and ``metrics`` are the
        two values returned by ``run``.
    """
    num_prompts = len(prompts)
    embed_dim = chosen_embeddings[0].shape[-1]  # embedding size

    # NOTE(review): `features` has len(prompts) rows, whereas `labels` holds
    # 2 * len(prompts) entries (alternating 1 = chosen, 0 = rejected).
    # Confirm that CVX_ReLU_MLP really expects this pairing.
    features = jnp.stack(prompts)  # X must be a JAX array
    labels = jnp.array([1, 0] * num_prompts)
    model = CVX_ReLU_MLP(
        X=features,
        y=labels,
        P_S=embed_dim,
        beta=0.1,
        rho=1.0,
        seed=42,
    )

    # Presumably sets up the hyperplane cuts and other internal parameters.
    model.init_model()

    # NOTE(review): the state built here is never handed to run(); verify
    # whether run() creates its own state or this value should be passed on.
    _cronos_state = init_cronos_state(embed_dim, num_prompts, embed_dim)

    # Solver settings forwarded to CRONOS.
    solver_options = dict(
        rank=10,  # rank for the preconditioner
        beta=0.1,
        gamma_ratio=0.5,
        admm_iters=100,
        pcg_iters=20,
        check_opt=True,
        verbose=True,
    )

    print("Training with CRONOS ...")
    weights, metrics = run(model, solver_options, "classification")
    v, w = weights

    return v, w, metrics


#------------------------------------------------------------------------------

# # 5: Full Pipeline
if __name__ == "__main__":
    # Step 1: load the GPT-2 tokenizer and backbone used to embed the text.
    gpt2_tokenizer, gpt2_model = initialize_pretrained_model()

    # Step 2: fetch the dataset and turn it into prompt / chosen / rejected
    # features.
    preference_data = get_dataset_dicts()
    prompt_feats, chosen_feats, rejected_feats = tokenize_data(
        preference_data, gpt2_tokenizer, gpt2_model
    )

    # Step 3: fit the convex network with CRONOS.
    v, w, metrics = train_with_cronos(prompt_feats, chosen_feats, rejected_feats)

    # Step 4: run the trained weights on a sample prompt and show the outcome.
    test_prompt = "Tell me about Troy"
    result = inference(test_prompt, gpt2_tokenizer, gpt2_model, v, w)
    print(f"Result for '{test_prompt}': {result}")




# ------------------------------------------------------------------------------
# convex dpo loss for completeness
def convex_dpo_loss(predictions_chosen, predictions_rejected, beta):
    """
    Computes the convexified DPO loss for evaluation.

    The loss is the mean negative log-sigmoid of the scaled preference
    margin::

        mean(-log(sigmoid(beta * (chosen - rejected))))
          == mean(log(1 + exp(-beta * (chosen - rejected))))

    It is non-negative and shrinks as the chosen predictions exceed the
    rejected ones, so lower is better.

    Args:
        predictions_chosen: Array of model outputs for the chosen responses.
        predictions_rejected: Array of model outputs for the rejected
            responses; must broadcast against ``predictions_chosen``.
        beta: Scale applied to the margin before the log-sigmoid.

    Returns:
        Scalar JAX array holding the mean loss.
    """
    logits = predictions_chosen - predictions_rejected
    # logaddexp(0, z) equals log(1 + exp(z)) but stays finite for large z,
    # where a literal exp() would overflow to inf. No leading minus sign:
    # log(1 + exp(-x)) is already -log(sigmoid(x)).
    return jnp.mean(jnp.logaddexp(0.0, -beta * logits))
