import jax
import optax
from tqdm import tqdm
import jax.numpy as jnp
from flax.training import train_state
from flax.training.common_utils import onehot


def create_train_state(model, learning_rate, params, mask=None):
    """Build a Flax ``TrainState`` optimised with Adam.

    Args:
        model: Flax module; its ``apply`` method becomes the state's ``apply_fn``.
        learning_rate: Learning rate forwarded to ``optax.adam``.
        params: Initial parameter pytree.
        mask: Optional ``param_labels`` argument for ``optax.multi_transform``
            (a label tree or a callable producing one). Leaves labelled
            ``'trainable'`` are updated with Adam. Leaves labelled ``'frozen'``
            receive zero updates. When ``None``, every parameter is trained
            with Adam.

    Returns:
        A ``train_state.TrainState`` wrapping ``params`` and the optimizer.
    """
    # Compare against None explicitly: ``mask`` is a pytree or callable, and
    # truthiness would silently fall back to "train everything" for an empty
    # label tree instead of surfacing the misconfiguration.
    if mask is not None:
        tx = optax.multi_transform(
            {'trainable': optax.adam(learning_rate), 'frozen': optax.set_to_zero()},
            mask,
        )
    else:
        tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def cross_entropy_loss(logits, labels):
    """Return the mean softmax cross-entropy of ``logits`` against ``labels``.

    ``labels`` holds class indices. They are one-hot encoded to the width of
    the last axis of ``logits`` before the optax loss is applied, and the
    per-position losses are averaged.
    """
    num_classes = logits.shape[-1]
    targets = onehot(labels, num_classes)
    per_position = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return per_position.mean()
def compute_metrics(logits, labels):
    """Summarise one batch of predictions.

    Returns a dict with ``'loss'`` (mean cross-entropy from
    ``cross_entropy_loss``) and ``'accuracy'`` (fraction of positions where
    the arg-max over the last axis of ``logits`` equals ``labels``).
    """
    predicted = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predicted == labels),
    }
@jax.jit
def train_step(state, batch, rng):
    """Run one jitted gradient step on ``batch``.

    The model is applied with ``deterministic=False`` and ``rng`` as the
    ``'dropout'`` stream. Metrics are computed from the logits produced by the
    parameters *before* the update.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` has ``'loss'`` and ``'accuracy'``.
    """
    targets = batch['label']

    def objective(candidate_params):
        # Returns (scalar loss, logits); logits ride along as auxiliary output.
        predictions = state.apply_fn(
            {'params': candidate_params},
            batch['input'],
            pad_mask=batch['mask'],
            deterministic=False,
            rngs={'dropout': rng},
        )
        return cross_entropy_loss(predictions, targets), predictions

    gradient_fn = jax.grad(objective, has_aux=True)
    gradients, predictions = gradient_fn(state.params)
    updated_state = state.apply_gradients(grads=gradients)
    return updated_state, compute_metrics(predictions, targets)
@jax.jit
def eval_step(state, batch):
    """Score ``batch`` with the current params (dropout off) and return its metrics."""
    variables = {'params': state.params}
    predictions = state.apply_fn(
        variables,
        batch['input'],
        pad_mask=batch['mask'],
        deterministic=True,
    )
    return compute_metrics(predictions, batch['label'])

def train_epoch(state, train_loader, rng):
    """Train for one full pass over ``train_loader``.

    Each batch gets a fresh dropout key split from ``rng``. A tqdm progress
    bar shows the current batch's loss and accuracy.

    Args:
        state: ``TrainState`` to update.
        train_loader: Iterable of batch dicts with ``'input'``, ``'mask'`` and
            ``'label'`` entries. If it supports ``len()``, tqdm uses that for
            the progress-bar total; otherwise the bar has no total.
        rng: JAX PRNG key used to derive per-batch dropout keys.

    Returns:
        ``(state, metrics)``. ``metrics`` holds the epoch ``'loss'`` and
        ``'accuracy'``, averaged over all labels seen. Each batch is weighted
        by its label count, so a smaller final batch is not over-weighted.
        Both values are NaN if the loader yields no batches.
    """
    losses, accuracies, label_counts = [], [], []
    # Context manager guarantees the bar is closed even if a step raises.
    with tqdm(train_loader, desc="Training") as pbar:
        for batch in pbar:
            rng, input_rng = jax.random.split(rng)
            state, metrics = train_step(state, batch, input_rng)
            losses.append(metrics['loss'])
            accuracies.append(metrics['accuracy'])
            # Batch metrics are means over this batch's labels; record how
            # many labels that mean covers so the epoch mean is exact.
            label_counts.append(jnp.size(batch['label']))
            # Update tqdm postfix with current batch metrics
            pbar.set_postfix({"loss": f"{metrics['loss']:.4f}", "acc": f"{metrics['accuracy']:.4f}"})

    weights = jnp.asarray(label_counts, dtype=jnp.float32)
    total = jnp.sum(weights)  # 0 for an empty loader -> 0/0 = NaN, as before
    avg_loss = jnp.sum(jnp.asarray(losses) * weights) / total
    avg_accuracy = jnp.sum(jnp.asarray(accuracies) * weights) / total
    return state, {'loss': avg_loss, 'accuracy': avg_accuracy}

def evaluate_model(state, dataloader):
    """Evaluate ``state`` over every batch of ``dataloader`` with dropout off.

    A tqdm progress bar shows the current batch's loss and accuracy.

    Args:
        state: ``TrainState`` whose params are evaluated.
        dataloader: Iterable of batch dicts with ``'input'``, ``'mask'`` and
            ``'label'`` entries. If it supports ``len()``, tqdm uses that for
            the progress-bar total; otherwise the bar has no total.

    Returns:
        Dict with ``'loss'`` and ``'accuracy'`` averaged over all labels seen.
        Each batch is weighted by its label count, so a smaller final batch
        does not skew the result. Both values are NaN if the loader yields no
        batches.
    """
    losses, accuracies, label_counts = [], [], []
    # Context manager guarantees the bar is closed even if a step raises.
    with tqdm(dataloader, desc="Evaluating") as pbar:
        for batch in pbar:
            metrics = eval_step(state, batch)
            losses.append(metrics['loss'])
            accuracies.append(metrics['accuracy'])
            # Batch metrics are means over this batch's labels; record how
            # many labels that mean covers so the overall mean is exact.
            label_counts.append(jnp.size(batch['label']))
            # Show current batch metrics
            pbar.set_postfix({
                "loss": f"{metrics['loss']:.4f}",
                "acc": f"{metrics['accuracy']:.4f}"
            })

    weights = jnp.asarray(label_counts, dtype=jnp.float32)
    total = jnp.sum(weights)  # 0 for an empty loader -> 0/0 = NaN, as before
    avg_loss = jnp.sum(jnp.asarray(losses) * weights) / total
    avg_accuracy = jnp.sum(jnp.asarray(accuracies) * weights) / total
    return {'loss': avg_loss, 'accuracy': avg_accuracy}