import math
from functools import partial

import haiku as hk
import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm

from medium_rl.network.encoder import EncoderTransformer

# Wide, compact array printing (4 decimals, 32 edge items, 200-char lines) so
# debugging output of batched arrays stays readable.
jnp.set_printoptions(linewidth=200, edgeitems=32, precision=4)

# Reserved special-token ids. Presumably CLS / padding / end-of-sequence markers
# for the token inputs -- verify against the tokenizer that produces `x`.
CLS, PAD, EOS = 0, 1, 2


def get_prob(output_logits):
    """Convert raw model logits to probabilities with the logistic sigmoid.

    Works elementwise, so the output has the same shape as ``output_logits``.
    """
    probabilities = jax.nn.sigmoid(output_logits)
    return probabilities


def classification_loss_fn(output_logits, y):
    """Binary cross-entropy loss and accuracy for a single-logit binary classifier.

    Args:
        output_logits: Raw (pre-sigmoid) model outputs whose last axis has size 1;
            that axis is squeezed away before comparing with ``y``.
        y: 0/1 targets with the shape of ``output_logits`` minus its last axis.

    Returns:
        ``(loss, accuracy)``: the mean binary cross-entropy, and the fraction of
        examples whose predicted probability, thresholded at 0.5, equals ``y``.
    """
    logits = output_logits.squeeze(-1)

    # BCE evaluated in log-space: log(sigmoid(z)) == log_sigmoid(z) and
    # log(1 - sigmoid(z)) == log_sigmoid(-z). Taking log of the sigmoid directly
    # breaks for large |z| in float32: the sigmoid saturates to exactly 0 or 1,
    # the log becomes -inf, and 0 * -inf turns the loss (and its gradient) into NaN.
    loss = -jnp.mean(y * jax.nn.log_sigmoid(logits) + (1 - y) * jax.nn.log_sigmoid(-logits))

    # Same sigmoid transform as get_prob; threshold at 0.5 for hard predictions.
    prob = jax.nn.sigmoid(logits)
    accuracy = jnp.mean((prob > 0.5).astype(jnp.float32) == y)

    return loss, accuracy


def regression_loss_fn(output_logits, y):
    """Mean squared error loss with mean absolute error as the auxiliary metric.

    Args:
        output_logits: Model predictions. They may carry one extra trailing axis of
            size 1 relative to ``y`` (e.g. ``(batch, 1)`` vs ``(batch,)``); that
            axis is squeezed so predictions and targets line up elementwise.
        y: Regression targets.

    Returns:
        ``(mse, mae)`` averaged over all elements.
    """
    preds = output_logits
    # Without this, (batch, 1) - (batch,) broadcasts to (batch, batch) and the
    # mean silently runs over every prediction/target pair instead of matching
    # pairs. Shapes are static under jit, so this Python branch is trace-safe.
    if jnp.ndim(preds) == jnp.ndim(y) + 1 and jnp.shape(preds)[-1] == 1:
        preds = jnp.squeeze(preds, axis=-1)

    error = preds - y
    mse = jnp.mean(error**2)
    mae = jnp.mean(jnp.abs(error))

    return mse, mae


@partial(jax.jit, static_argnames=["forward", "optimizer", "loss_fn"])
def train_step(state, rng, x_batch, y_batch, forward, optimizer, loss_fn):
    """Run one jitted optimisation step on a single mini-batch.

    ``forward``, ``optimizer`` and ``loss_fn`` are static (hashed) jit arguments,
    so every distinct combination of them -- and every distinct batch shape --
    is compiled separately.

    Args:
        state: Dict holding ``"step"``, ``"params"`` and ``"opt_state"``.
        rng: PRNG key handed to ``forward.apply`` in training mode.
        x_batch: Mini-batch inputs.
        y_batch: Mini-batch targets.
        forward: Transformed model exposing ``apply(params, rng, x, is_training=...)``.
        optimizer: Object exposing ``update(grads, opt_state, params)``.
        loss_fn: Callable ``(logits, y) -> (loss, aux_metric)``; only ``loss`` is
            differentiated, ``aux_metric`` is passed through.

    Returns:
        ``(new_state, loss, aux_metric)``; ``new_state["step"]`` is incremented by one.
    """

    def objective(trainable, key, inputs, targets):
        predictions = forward.apply(trainable, key, inputs, is_training=True)
        return loss_fn(predictions, targets)

    current = state["params"]
    loss_and_grad = jax.value_and_grad(objective, has_aux=True)
    (loss, aux_metric), gradients = loss_and_grad(current, rng, x_batch, y_batch)

    updates, next_opt_state = optimizer.update(gradients, state["opt_state"], current)
    next_params = optax.apply_updates(current, updates)

    next_state = dict(
        step=state["step"] + 1,
        params=next_params,
        opt_state=next_opt_state,
    )
    return next_state, loss, aux_metric


def split_dataset(x, y, rng, val_percent=0.2):
    """Randomly partition ``(x, y)`` into training and validation subsets.

    A single permutation of the first axis is drawn from ``rng``. Its first
    ``int(len * val_percent)`` indices form the validation set and the rest form
    the training set, so ``x`` and ``y`` stay row-aligned.

    Returns:
        ``(x_train, x_val, y_train, y_val)``.
    """
    n_total = x.shape[0]
    n_val = int(n_total * val_percent)
    order = jax.random.permutation(rng, n_total)

    val_idx, train_idx = order[:n_val], order[n_val:]
    return x[train_idx], x[val_idx], y[train_idx], y[val_idx]


def eval_step(params, x, y, forward, loss_fn, batch_size):
    """Evaluate ``loss_fn`` over ``(x, y)`` in batches with the model in eval mode.

    Each batch is run through ``forward.apply`` with ``rng=None`` and
    ``is_training=False``. Both outputs of ``loss_fn`` are weighted by the batch
    size and divided by the total sample count. When ``loss_fn`` returns
    per-batch means, this yields dataset-level means even if the last batch is
    short. An empty ``x`` leads to a ``0 / 0`` division (ZeroDivisionError).

    Returns:
        ``(mean_loss, mean_aux_metric)``.
    """
    num_samples = x.shape[0]
    weighted_loss = 0
    weighted_aux = 0

    for start in range(0, num_samples, batch_size):
        stop = min(start + batch_size, num_samples)
        count = stop - start
        batch_logits = forward.apply(params, None, x[start:stop], is_training=False)
        batch_loss, batch_aux = loss_fn(batch_logits, y[start:stop])
        weighted_loss += batch_loss * count
        weighted_aux += batch_aux * count

    return weighted_loss / num_samples, weighted_aux / num_samples


def get_val_stats(params, x, forward, batch_size):
    """Summarise the raw model outputs over ``x`` (min, max, mean, std).

    The model runs in eval mode (``rng=None``, ``is_training=False``), one batch
    at a time. Each batch's output has its trailing axis squeezed, which
    presumes one output per sample of shape ``(batch, 1)``.

    Args:
        params: Model parameters.
        x: Inputs indexed along axis 0.
        forward: Transformed model exposing ``apply(params, rng, x, is_training=...)``.
        batch_size: Number of samples per forward pass.

    Returns:
        Dict with keys ``"min"``, ``"max"``, ``"mean"``, ``"std"`` (population std).

    Raises:
        ValueError: If ``x`` has no samples (the statistics are undefined).
    """
    num_samples = x.shape[0]
    if num_samples == 0:
        raise ValueError("get_val_stats requires at least one sample")

    # Gather per-batch outputs and concatenate once. Writing each batch into a
    # preallocated array with `.at[...].set` outside jit copies the whole buffer
    # on every batch, which makes the loop O(N^2 / batch_size).
    chunks = []
    for start in range(0, num_samples, batch_size):
        x_batch = x[start : start + batch_size]
        logits = forward.apply(params, None, x_batch, is_training=False)
        chunks.append(logits.squeeze(-1))

    # Use the default float dtype (what `jnp.zeros` produces) so the statistics'
    # dtype does not depend on the model's output dtype.
    all_logits = jnp.concatenate(chunks).astype(jnp.zeros(()).dtype)

    return {
        "min": all_logits.min(),
        "max": all_logits.max(),
        "mean": all_logits.mean(),
        "std": all_logits.std(),
    }


def train_model(
    x,
    y,
    model_cfg,
    type="classification",
    batch_size=256,
    learning_rate=1e-4,
    max_epochs=250,
    val_percent=0.2,
    weight_decay=1e-6,
    patience=15,
    seed=0,
):
    """Train an ``EncoderTransformer`` with AdamW and validation-based early stopping.

    The data is split once into train/validation sets. Each epoch reshuffles the
    training set, runs ``train_step`` over its mini-batches and then evaluates on
    the validation set. The parameters with the lowest validation loss are kept.

    Args:
        x: Model inputs indexed along axis 0. Presumably integer token ids of shape
            ``(num_samples, seq_len)``: the model is initialised with an int32 dummy
            of shape ``(1, x.shape[1])``. Confirm against callers.
        y: Targets row-aligned with ``x`` (0/1 labels for classification).
        model_cfg: Keyword arguments forwarded to ``EncoderTransformer``.
        type: ``"classification"`` (sigmoid BCE loss, accuracy metric) or
            ``"regression"`` (MSE loss, MAE metric). The name shadows the builtin
            but is kept for backward compatibility with keyword callers.
        batch_size: Mini-batch size for training and evaluation.
        learning_rate: AdamW learning rate.
        max_epochs: Upper bound on the number of epochs.
        val_percent: Fraction of samples held out for validation.
        weight_decay: AdamW weight decay.
        patience: Training stops once the validation loss has failed to improve
            for more than ``patience`` consecutive epochs.
        seed: Seed from which the split, init, shuffle and per-step keys derive.

    Returns:
        ``(model_cfg, state, best_params, val_stats)``:
        - ``state`` is the training state after the last epoch run.
        - ``best_params`` are the parameters with the lowest validation loss.
        - ``val_stats`` summarises the raw validation outputs of ``best_params``.

    Raises:
        ValueError: If ``type`` is unknown, or if the split leaves the training or
            validation set empty.
    """
    # Resolve the objective first so a bad `type` fails before any model setup.
    if type == "classification":
        loss_fn = classification_loss_fn
        aux_name = "acc"
    elif type == "regression":
        loss_fn = regression_loss_fn
        aux_name = "mae"
    else:
        raise ValueError(f"Invalid type: {type}")

    rng = jax.random.PRNGKey(seed)
    rng, split_rng = jax.random.split(rng)
    x_train, x_val, y_train, y_val = split_dataset(
        x,
        y,
        split_rng,
        val_percent=val_percent,
    )
    num_train_samples = x_train.shape[0]
    num_val_samples = x_val.shape[0]
    # An empty split would otherwise surface only later, as a ZeroDivisionError
    # in the epoch-mean / eval_step averaging after model init.
    if num_train_samples == 0 or num_val_samples == 0:
        raise ValueError(
            f"val_percent={val_percent} leaves {num_train_samples} training and "
            f"{num_val_samples} validation samples; both splits must be non-empty"
        )

    forward = hk.transform(lambda x, is_training=False: EncoderTransformer(**model_cfg)(x, is_training))
    rng, init_rng = jax.random.split(rng)
    params = forward.init(init_rng, jnp.ones((1, x.shape[1]), dtype=jnp.int32))
    optimizer = optax.adamw(learning_rate, weight_decay=weight_decay)
    opt_state = optimizer.init(params)

    state = {
        "step": jnp.array(0),
        "params": params,
        "opt_state": opt_state,
    }
    best_params, best_val_loss = state["params"], float("inf")
    early_stop_count = 0
    num_batches = math.ceil(num_train_samples / batch_size)

    # The context manager closes the bar even when early stopping breaks the loop.
    with tqdm(range(max_epochs)) as pbar:
        for epoch in pbar:
            rng, shuffle_rng = jax.random.split(rng)
            shuffled_indices = jax.random.permutation(shuffle_rng, num_train_samples)

            # Per-batch metrics for this epoch (aux metric is acc or mae).
            epoch_train_losses = []
            epoch_train_accs = []

            for i in range(num_batches):
                rng, train_rng = jax.random.split(rng)
                curr_idxs = shuffled_indices[i * batch_size : (i + 1) * batch_size]
                x_batch, y_batch = x_train[curr_idxs], y_train[curr_idxs]
                state, train_loss, train_acc = train_step(
                    state, train_rng, x_batch, y_batch, forward, optimizer, loss_fn
                )
                epoch_train_losses.append(train_loss)
                epoch_train_accs.append(train_acc)

            val_loss, val_acc = eval_step(state["params"], x_val, y_val, forward, loss_fn, batch_size)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_params = state["params"]
                early_stop_count = 0
            else:
                early_stop_count += 1
                if early_stop_count > patience:
                    # tqdm.write prints without corrupting the active progress bar.
                    tqdm.write("Stopping early")
                    break

            # Unweighted mean over batches (the last batch may be smaller).
            mean_train_loss = sum(epoch_train_losses) / len(epoch_train_losses)
            mean_train_acc = sum(epoch_train_accs) / len(epoch_train_accs)

            pbar.set_description(f"Epoch {epoch}")
            pbar.set_postfix(
                {
                    "train_loss": f"{mean_train_loss:.4f}",
                    f"train_{aux_name}": f"{mean_train_acc:.4f}",
                    "val_loss": f"{val_loss:.4f}",
                    f"val_{aux_name}": f"{val_acc:.4f}",
                }
            )

    val_stats = get_val_stats(best_params, x_val, forward, batch_size)

    return model_cfg, state, best_params, val_stats
