import jax
import optax
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm
from flax.training.train_state import TrainState
import wandb 
# ---------- Train state, metrics, and train/eval loops ----------

def create_train_state(model, args, len_train_loader, label_mask=None):
    """Build the Flax ``TrainState`` used by the training and eval loops.

    Args:
        model: Object exposing ``__call__`` (stored as ``apply_fn``) and
            ``params`` (used as the initial parameters).
        args: Namespace-like object; ``epochs`` and ``lr`` are always read,
            ``weight_decay`` only when ``label_mask`` is truthy.
        len_train_loader: Number of optimisation steps per epoch.
        label_mask: Optional parameter labelling handed to
            ``optax.multi_transform`` as ``param_labels``. Parameters labelled
            ``'trainable'`` get gradient clipping followed by AdamW; those
            labelled ``'frozen'`` receive zero updates.

    Returns:
        A ``TrainState`` wrapping the model's parameters and the optimizer.
    """
    n_steps = args.epochs * len_train_loader
    # Linear warmup over the first 10% of steps, then cosine decay to zero.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=args.lr,
        warmup_steps=int(0.1 * n_steps),
        decay_steps=n_steps,
        end_value=0.0,
    )

    if not label_mask:
        # NOTE(review): without a label mask every update is zeroed, so the
        # returned state never changes its parameters. Confirm this eval-only
        # behaviour is intended rather than "train everything".
        optimizer = optax.set_to_zero()
    else:
        clipped_adamw = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(
                learning_rate=lr_schedule,
                b1=0.9,
                b2=0.999,
                eps=1e-8,
                weight_decay=args.weight_decay,
            ),
        )
        optimizer = optax.multi_transform(
            transforms={
                'trainable': clipped_adamw,
                'frozen': optax.set_to_zero(),
            },
            param_labels=label_mask,
        )

    return TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
    )

def accuracy(logits, labels, topk=(1,)):
    """Compute top-k accuracy percentages for one batch.

    Args:
        logits: Two-dimensional score array indexed as (example, class); the
            last axis is sorted to rank the classes.
        labels: Class indices, one per example along axis 0.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one scalar per entry of ``topk`` (in the same order): the
        percentage of examples whose label is among the ``k`` highest-scoring
        classes.
    """
    largest_k = max(topk)
    num_examples = labels.shape[0]
    # argsort is ascending: flip it so the best class comes first, then keep
    # only as many columns as the largest requested k.
    ranked = jnp.argsort(logits, axis=-1)[:, ::-1][:, :largest_k]
    targets = labels[:, None]
    return [
        100.0 * jnp.sum(jnp.any(ranked[:, :k] == targets, axis=1)) / num_examples
        for k in topk
    ]
@jax.jit
def train_step(state, batch, rng):
    """Run one jitted optimisation step and report metrics for the batch.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        batch: Mapping with ``'images'`` and ``'labels'`` entries.
        rng: PRNG key for this step; forwarded to the model as its dropout
            key so stochastic layers can run in training mode.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` holds ``'loss'`` (mean
        cross-entropy), ``'acc1'`` and ``'acc5'`` (top-1 / top-5 percentages),
        all computed from the training-mode forward pass.
    """
    def loss_fn(params):
        # The step key was previously accepted but never used, leaving the
        # model in train mode without a dropout key.
        # NOTE(review): assumes apply_fn accepts ``dropout_rng`` like the
        # Hugging Face Flax model ``__call__`` — confirm for custom models.
        outputs = state.apply_fn(
            params=params,
            pixel_values=batch['images'],
            dropout_rng=rng,
            train=True,
        )
        logits = outputs.logits
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['labels']
        ).mean()
        # Logits are returned as auxiliary output so accuracy can reuse them.
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)

    acc1, acc5 = accuracy(logits, batch['labels'], topk=(1, 5))
    metrics = {'loss': loss, 'acc1': acc1, 'acc5': acc5}
    return state, metrics
@jax.jit
def eval_step(state, batch):
    """Evaluate one batch with the model in inference mode (jitted).

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        batch: Mapping with ``'images'`` and ``'labels'`` entries.

    Returns:
        Dict with ``'loss'`` (mean cross-entropy over the batch), ``'acc1'``
        and ``'acc5'`` (top-1 / top-5 accuracy percentages).
    """
    images, targets = batch['images'], batch['labels']
    logits = state.apply_fn(
        params=state.params, pixel_values=images, train=False
    ).logits
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    top1, top5 = accuracy(logits, targets, topk=(1, 5))
    return {'loss': per_example_loss.mean(), 'acc1': top1, 'acc5': top5}
def train_epoch(state, train_loader, rng):
    """Train for one pass over ``train_loader``.

    Args:
        state: Train state passed to, and replaced by the result of,
            ``train_step`` on every batch.
        train_loader: Iterable yielding ``(images, labels)`` pairs that
            ``np.asarray`` can convert.
        rng: PRNG key; split once per step to derive that step's key.

    Returns:
        ``(state, metrics)`` where ``metrics`` holds the sample-weighted
        averages ``'loss'``, ``'acc1'`` and ``'acc5'`` over the epoch. If the
        loader yields no batches the averages are NaN instead of raising
        ``ZeroDivisionError``.
    """
    loss_sum, acc1_sum, acc5_sum, seen = 0.0, 0.0, 0.0, 0
    pbar = tqdm(train_loader, desc="Training", leave=False)

    for step, (images, labels) in enumerate(pbar):
        batch = {
            'images': jax.device_put(np.asarray(images)),
            'labels': jax.device_put(np.asarray(labels)),
        }
        rng, step_rng = jax.random.split(rng)
        state, metrics = train_step(state, batch, step_rng)

        # Convert each scalar to a Python float once and reuse it for the
        # running sums, the periodic log and the progress bar.
        loss = float(metrics['loss'])
        acc1 = float(metrics['acc1'])
        acc5 = float(metrics['acc5'])

        # Weight by batch size so a short final batch is not over-counted.
        batch_size = batch['images'].shape[0]
        loss_sum += loss * batch_size
        acc1_sum += acc1 * batch_size
        acc5_sum += acc5 * batch_size
        seen += batch_size

        if (step + 1) % 200 == 0:
            wandb.log({
                "train/step_loss": loss,
                "train/step_acc1": acc1,
                "train/step_acc5": acc5,
            })

        pbar.set_postfix({
            'Loss': f'{loss:.4f}',
            'Acc@1': f'{acc1:.2f}%',
            'Acc@5': f'{acc5:.2f}%'
        })

    if seen == 0:
        # No samples were processed: the averages are undefined.
        undefined = float('nan')
        return state, {"loss": undefined, "acc1": undefined, "acc5": undefined}

    return state, {
        "loss": loss_sum / seen,
        "acc1": acc1_sum / seen,
        "acc5": acc5_sum / seen,
    }
def evaluate(state, val_loader, batch_limit=0):
    """Evaluate ``state`` on ``val_loader`` and report averaged metrics.

    Args:
        state: Train state passed to ``eval_step``.
        val_loader: Iterable yielding ``(images, labels)`` pairs; must also
            support ``len()`` (used for the progress-bar total).
        batch_limit: When positive, stop after this many batches; ``0`` (the
            default) evaluates every batch.

    Returns:
        ``(avg_acc1, avg_acc5, avg_loss)`` as sample-weighted averages over
        the evaluated batches, matching how ``train_epoch`` aggregates. If no
        batch is evaluated all three are NaN rather than the function failing
        on unbound locals.
    """
    # Size the progress bar by the number of batches that will actually run.
    num_batches = len(val_loader)
    if batch_limit > 0:
        num_batches = min(num_batches, batch_limit)

    loss_sum, acc1_sum, acc5_sum, seen = 0.0, 0.0, 0.0, 0
    # Defined up front so the summary below works even for an empty loader.
    avg_loss = avg_acc1 = avg_acc5 = float('nan')
    pbar = tqdm(enumerate(val_loader), total=num_batches, desc="Evaluating", leave=False)

    for i, (images, labels) in pbar:
        if batch_limit > 0 and i >= batch_limit:
            break

        batch = {
            'images': jax.device_put(np.asarray(images)),
            'labels': jax.device_put(np.asarray(labels)),
        }
        metrics = eval_step(state, batch)

        # Weight each batch by its size so a smaller final batch does not
        # skew the averages.
        batch_size = batch['images'].shape[0]
        loss_sum += float(metrics['loss']) * batch_size
        acc1_sum += float(metrics['acc1']) * batch_size
        acc5_sum += float(metrics['acc5']) * batch_size
        seen += batch_size

        avg_loss = loss_sum / seen
        avg_acc1 = acc1_sum / seen
        avg_acc5 = acc5_sum / seen

        pbar.set_postfix({'Loss': f'{avg_loss:.4f}','Acc@1': f'{avg_acc1:.2f}%','Acc@5': f'{avg_acc5:.2f}%',})
        # Running averages are logged after every batch.
        wandb.log({'val/step_loss': avg_loss,'val/step_acc1': avg_acc1,'val/step_acc5': avg_acc5,})
    print(f"Top-1 Accuracy: {avg_acc1:.2f}%, Top-5 Accuracy: {avg_acc5:.2f}%, Loss: {avg_loss:.4f}")
    return avg_acc1, avg_acc5, avg_loss
