from typing import Union

import jax
import jax.numpy as jnp  # JAX NumPy
from flax.training import train_state  # Useful dataclass to keep train state
import optax  # Optimizers
import wandb
import numpy as np
from functools import partial

from keras.src.backend.jax.nn import sigmoid

from evosax import ParameterReshaper
from flax import linen as nn
# Maps `wandb.config.opt_name` to an optax optimizer factory. The train-state
# helpers in this module call the selected factory as
# `factory(learning_rate, momentum)`.
optimizer = dict(
    adam=optax.adam,
    sgd=optax.sgd,
)

# Module-level handles read as globals by the jitted step functions.
# `net` is assigned by `create_train_state`; `param_reshaper` by
# `init_param_reshaper`. Both stay None until those run.
net: Union[nn.Module, None] = None
param_reshaper: Union[ParameterReshaper, None] = None

@jax.jit
def smooth_indicator(x, epsilon=1e-6, k=50.0):
    """
    Smooth, differentiable surrogate for the indicator ``|x| > epsilon``.

    Computes ``sigmoid(k * (|x| - epsilon))`` element-wise. Values with
    ``|x|`` well above ``epsilon + 1/k`` map close to 1. Because ``|x| >= 0``,
    the output never falls below ``sigmoid(-k * epsilon)``. With the defaults
    (k=50, epsilon=1e-6) that floor is about 0.5, so ``|x| <= epsilon`` maps
    to roughly 0.5, not 0. `compute_loss` rescales the result with ``2*s - 1``.

    Uses ``jax.nn.sigmoid``, which stays finite (value and gradient) for large
    ``k``. A hand-written ``1 / (1 + exp(-z))`` overflows ``exp`` there, and
    its gradient becomes NaN.
    """
    return jax.nn.sigmoid(k * (jnp.abs(x) - epsilon))

@jax.jit
def compute_loss(x, epsilon=1e-6, k=50.0, target_proportion=0.8):
    """
    Differentiable penalty on the proportion of "non-zero" elements of x.

    Each element is scored with ``2 * smooth_indicator(x, epsilon, k) - 1``,
    which is about 0 for ``|x| <= epsilon`` and approaches 1 for larger
    ``|x|``. The mean score ``p`` is then passed through a soft hinge:
    ``softplus(k * (p - target_proportion)) / k``. The result is about 0 when
    ``p`` is below ``target_proportion`` and about
    ``p - target_proportion`` when it is above.

    Arguments:
    x : jnp.array
        Input array.
    epsilon : float
        Small threshold to consider a value as "non-zero".
    k : float
        Sharpness of both the indicator and the hinge (the same value is used
        for both).
    target_proportion : float
        Proportion of non-zero elements above which the penalty grows.

    Returns:
    loss : float
        Differentiable loss value.
    """
    smooth_indicators = smooth_indicator(x, epsilon, k) * 2 - 1
    proportion = jnp.mean(smooth_indicators)
    # softplus == log(1 + exp(.)) but does not overflow for large arguments.
    return jax.nn.softplus((proportion - target_proportion) * k) / k


def init_param_reshaper(p_reshaper):
    """Install `p_reshaper` as this module's global `param_reshaper`.

    The training step functions call ``param_reshaper.reshape_single_net``
    through the global, so this has to run before any of them is traced.
    """
    global param_reshaper
    param_reshaper = p_reshaper


def prepare_batches(X, Y, batch_size, rng, batch_in_gpu=False):
    """Shuffle `X` and `Y` together and split them into full batches.

    A single permutation drawn from `rng` is applied to both arrays. Only
    ``len(X) // batch_size`` complete batches are kept; the remainder is
    dropped.

    Arguments:
    X, Y : array-like
        Inputs and targets indexed along axis 0.
    batch_size : int
        Number of examples per batch.
    rng : jax.random key
        Key for the shuffle.
    batch_in_gpu : bool
        Currently unused; kept for call-site compatibility.

    Returns:
    (X_batches, Y_batches), each of shape ``(n_batches, batch_size, ...)``.
    """
    n_steps = len(X) // batch_size
    order = jax.random.permutation(rng, len(X))
    batch_idx = order[:n_steps * batch_size].reshape((n_steps, batch_size))
    return tuple(np.take(arr, batch_idx, axis=0) for arr in (X, Y))

def create_train_state(rng, network : nn.Module, learning_rate, momentum):
    """Initialise `network` and wrap it in a fresh `TrainState`.

    Side effect: stores `network` in the module-global `net`, which the jitted
    step functions read.

    The network is initialised on ``jnp.ones(wandb.config.pholder)``, with
    `rng` passed both as the init key and as the module's extra call argument.
    The optimizer is chosen by ``wandb.config.opt_name`` from `optimizer`.
    """
    global net
    net = network
    dummy_input = jnp.ones(wandb.config.pholder)
    params = net.init(rng, dummy_input, rng)['params']
    make_tx = optimizer[wandb.config.opt_name]
    # NOTE(review): for 'adam', optax.adam's second positional parameter is
    # `b1`, so `momentum` becomes Adam's first-moment decay; for 'sgd' it is
    # the SGD momentum. Confirm this is intended.
    tx = make_tx(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=tx,
    )

def update_train_state(learning_rate, momentum, params):
    """Build a new `TrainState` around existing `params` with a fresh optimizer.

    Reuses the module-global `net` (so `create_train_state` must have run
    first) and builds a new optimizer from ``wandb.config.opt_name``. Any
    previous optimizer state is not carried over.
    """
    make_tx = optimizer[wandb.config.opt_name]
    tx = make_tx(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=tx,
    )

@jax.jit
def cross_entropy_loss(*, logits, labels, num_classes=10):
    """Mean softmax cross-entropy between `logits` and integer `labels`.

    Arguments:
    logits : jnp.array
        Unnormalised class scores; the class axis is the last axis.
    labels : jnp.array
        Integer class indices with shape ``logits.shape[:-1]``.
    num_classes : int
        Kept for backward compatibility and not used. When passed explicitly
        to this jitted function it becomes a traced value, and
        ``jax.nn.one_hot`` needs a concrete class count. The count is
        therefore read from the static ``logits.shape[-1]``, as in
        `compute_metrics_eval`.

    Returns:
    loss : float
        Mean of the per-example cross-entropy.
    """
    labels_onehot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

@jax.jit
def compute_metrics(*, logits, labels, num_classes=10):
    """Return ``(loss, accuracy)`` for a batch of logits and integer labels.

    `loss` comes from `cross_entropy_loss` (with `num_classes` forwarded).
    `accuracy` is the fraction of positions whose arg-max over the last axis
    of `logits` equals the label.
    """
    loss = cross_entropy_loss(logits=logits, labels=labels, num_classes=num_classes)
    hits = jnp.argmax(logits, axis=-1) == labels
    return loss, hits.mean()


@jax.jit
def compute_metrics_eval(logits, labels):
    """Return ``(mean softmax cross-entropy, accuracy)`` for an eval batch.

    The class count is taken from the last axis of `logits`, so any number of
    classes works.
    """
    n_classes = logits.shape[-1]
    targets = jax.nn.one_hot(labels, num_classes=n_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    hits = jnp.argmax(logits, axis=-1) == labels
    return per_example.mean(), hits.mean()


@jax.jit
def eval_step(params, batch, rng):
    """Evaluate the module-global `net` on one batch.

    `batch` is a dict with ``'image'`` and ``'label'`` entries; `rng` is
    forwarded to ``net.apply`` as an extra call argument.

    Returns:
    (loss, accuracy) from `compute_metrics_eval`.
    """
    images = batch['image']
    logits = net.apply({'params': params}, images, rng)
    return compute_metrics_eval(logits=logits, labels=batch['label'])


def eval_model(params, test_ds, rng):
    """Evaluate `params` on `test_ds` in batches of ``wandb.config.batch_size``.

    Arguments:
    params : pytree
        Network parameters forwarded to `eval_step`.
    test_ds : dict
        Mapping with at least an ``'image'`` entry. Every entry is sliced
        along its first axis with the same batch bounds.
    rng : jax.random key
        Split once; the second half is passed to every `eval_step` call.

    Returns:
    (mean_loss, mean_accuracy), averaged over full batches only. The trailing
    ``test_size % batch_size`` examples that do not fill a batch are skipped.

    Raises:
    ValueError
        If the test set holds fewer examples than one batch.
    """
    _, rng_net = jax.random.split(rng)
    batch_size = wandb.config.batch_size
    test_size = test_ds['image'].shape[0]
    num_batches = test_size // batch_size
    if num_batches == 0:
        raise ValueError(
            f"test set has {test_size} examples, fewer than one batch "
            f"(batch_size={batch_size}); nothing to evaluate")

    total_loss = 0
    total_accuracy = 0
    for i in range(num_batches):
        start = i * batch_size
        batch = {k: v[start:start + batch_size] for k, v in test_ds.items()}
        loss, accuracy = eval_step(params, batch, rng_net)
        total_loss += loss
        total_accuracy += accuracy

    return total_loss / num_batches, total_accuracy / num_batches


@partial(jax.jit, static_argnames=('padding',))
def train_step(base, noise, w, x, y, lr, padding, rng):
    """Take one plain gradient-descent step on the subspace coefficients `w`.

    The flat parameter vector is ``base + (noise @ w).reshape(-1)``, with the
    last `padding` entries of the projected part dropped when ``padding > 0``.
    It is reshaped by the global `param_reshaper` and fed to the global `net`.
    `w` moves by ``-lr * dloss/dw``.

    Returns:
    (w, loss, accuracy); loss and accuracy are computed at the pre-update `w`.
    """
    trim = -padding if padding > 0 else None

    def objective(coeffs):
        flat_params = base + jnp.dot(noise, coeffs).reshape(-1)[:trim]
        net_params = param_reshaper.reshape_single_net(flat_params.squeeze())
        out = net.apply({'params': net_params}, x, rng)
        return cross_entropy_loss(logits=out, labels=y), out

    (_, logits), w_grad = jax.value_and_grad(objective, has_aux=True)(w)
    new_w = w - w_grad * lr
    loss, accuracy = compute_metrics(logits=logits, labels=y)
    return new_w, loss, accuracy


@partial(jax.jit, static_argnames=('padding',))
def train_step2_m(base, noise, grads, w, v, w_momentum, v_momentum, x, y, lr, beta, padding, rng):
    """Take one momentum step on both coefficient sets `w` and `v`.

    The flat parameter vector is
    ``base + (noise @ w).reshape(-1)[:-padding] + (grads @ v).reshape(-1)``
    (no trimming when ``padding == 0``). Each momentum buffer is an
    exponential moving average of its gradient,
    ``m = beta * m + (1 - beta) * g``, and each parameter moves by
    ``-lr * m``.

    Returns:
    (w, v, w_momentum, v_momentum, loss, accuracy); loss and accuracy are
    computed at the pre-update `w`, `v`.
    """

    def loss_fn(_w, _v):
        noise_delta = jnp.dot(noise, _w).reshape(-1)
        if padding > 0:
            # Trim before adding `base`, so `base` has the unpadded length
            # expected by `train_step` / `train_step2`.
            noise_delta = noise_delta[:-padding]
        combined_x = base + noise_delta
        combined_x += jnp.dot(grads, _v).reshape(-1)
        params = param_reshaper.reshape_single_net(combined_x.squeeze())
        _logits = net.apply({'params': params}, x, rng)
        _loss = cross_entropy_loss(logits=_logits, labels=y)
        return _loss, _logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True, argnums=(0, 1))
    (_, logits), (w_grad, v_grad) = grad_fn(w, v)

    # Exponential moving average of the gradients.
    w_momentum = beta * w_momentum + (1 - beta) * w_grad
    v_momentum = beta * v_momentum + (1 - beta) * v_grad

    w = w - lr * w_momentum
    v = v - lr * v_momentum

    loss_value, accuracy = compute_metrics(logits=logits, labels=y)

    return w, v, w_momentum, v_momentum, loss_value, accuracy

@partial(jax.jit, static_argnames=('padding',))
def train_step_m(base, noise, w, w_momentum, x, y, lr, beta, padding, rng):
    """Take one momentum step on the subspace coefficients `w`.

    The flat parameter vector is ``base + (noise @ w).reshape(-1)`` with the
    last `padding` entries of the projected part dropped when ``padding > 0``.
    The momentum buffer is an exponential moving average of the gradient,
    ``m = beta * m + (1 - beta) * g``, and `w` moves by ``-lr * m``.

    Returns:
    (w, w_momentum, loss, accuracy); loss and accuracy are computed at the
    pre-update `w`.
    """

    def loss_fn(_w):
        noise_delta = jnp.dot(noise, _w).reshape(-1)
        if padding > 0:
            # Trim before adding `base`, so `base` has the unpadded length
            # expected by `train_step`.
            noise_delta = noise_delta[:-padding]
        combined_x = base + noise_delta
        params = param_reshaper.reshape_single_net(combined_x.squeeze())
        _logits = net.apply({'params': params}, x, rng)
        _loss = cross_entropy_loss(logits=_logits, labels=y)
        return _loss, _logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), w_grad = grad_fn(w)

    # Exponential moving average of the gradient.
    w_momentum = beta * w_momentum + (1 - beta) * w_grad

    w = w - lr * w_momentum

    loss_value, accuracy = compute_metrics(logits=logits, labels=y)

    return w, w_momentum, loss_value, accuracy


def train_epoch(base, noise, x, y, batch_size, lr, parts, padding, rng):
    """Train the subspace coefficients `w` for one epoch with `train_step`.

    `w` starts as zeros of shape ``(noise.shape[1], parts)`` on every call and
    is updated once per full batch. The trailing partial batch is skipped.
    Batches are shuffled and gathered on the host, then copied one at a time
    to the first GPU. If no GPU backend exists, JAX's default device is used.

    Returns:
    (w, mean_loss, mean_accuracy), the means taken over this epoch's batches.
    """
    xb, yb = prepare_batches(x, y, batch_size, rng)

    try:
        device = jax.devices('gpu')[0]
    except RuntimeError:
        # No GPU backend available: fall back instead of crashing.
        device = jax.devices()[0]

    # NOTE(review): `rng` also seeded the shuffle above, and the same `rng_net`
    # is passed to every step, so any randomness `net` draws from it repeats
    # each step. Confirm this is intended.
    _, rng_net = jax.random.split(rng)
    w = jnp.zeros((noise.shape[1], parts))

    batch_loss, batch_acc = [], []
    for x_host, y_host in zip(xb, yb):
        xbi = jax.device_put(x_host, device=device)
        ybi = jax.device_put(y_host, device=device)
        w, loss, accuracy = train_step(base, noise, w, xbi, ybi, lr, padding, rng_net)
        batch_loss.append(loss)
        batch_acc.append(accuracy)

    return w, jnp.array(batch_loss).mean(), jnp.array(batch_acc).mean()

@partial(jax.jit, static_argnames=('padding',))
def train_step2(base, noise, grads, w, v, x, y, lr, padding, rng):
    """Take one plain gradient-descent step on both coefficient sets `w` and `v`.

    The flat parameter vector is ``base + (noise @ w)`` (flattened, with the
    last `padding` entries dropped when ``padding > 0``) plus
    ``(grads @ v)`` (flattened). Here `grads` is a basis matrix, not a
    gradient. Both `w` and `v` move by ``-lr * gradient``.

    Returns:
    (w, v, loss, accuracy); loss and accuracy are computed at the pre-update
    `w`, `v`.
    """
    trim = -padding if padding > 0 else None

    def objective(w_coeffs, v_coeffs):
        flat = base + jnp.dot(noise, w_coeffs).reshape(-1)[:trim]
        flat += jnp.dot(grads, v_coeffs).reshape(-1)
        net_params = param_reshaper.reshape_single_net(flat.squeeze())
        out = net.apply({'params': net_params}, x, rng)
        return cross_entropy_loss(logits=out, labels=y), out

    grad_of = jax.value_and_grad(objective, has_aux=True, argnums=(0, 1))
    (_, logits), (w_grad, v_grad) = grad_of(w, v)
    new_w = w - w_grad * lr
    new_v = v - v_grad * lr
    loss, accuracy = compute_metrics(logits=logits, labels=y)
    return new_w, new_v, loss, accuracy

def train_epoch2(base, noise, grads, x, y, batch_size, lr, parts, padding, rng):
    """Train coefficients `w` and `v` for one epoch with `train_step2`.

    `w` starts as zeros of shape ``(noise.shape[1], parts)`` and `v` as zeros
    of shape ``(grads.shape[1], 1)`` on every call. Both are updated once per
    full batch; the trailing partial batch is skipped. Batches are shuffled
    and gathered on the host, then copied one at a time to the first GPU. If
    no GPU backend exists, JAX's default device is used.

    Returns:
    (w, v, mean_loss, mean_accuracy), the means taken over this epoch's
    batches.
    """
    xb, yb = prepare_batches(x, y, batch_size, rng)

    try:
        device = jax.devices('gpu')[0]
    except RuntimeError:
        # No GPU backend available: fall back instead of crashing.
        device = jax.devices()[0]

    # NOTE(review): `rng` also seeded the shuffle above, and the same `rng_net`
    # is passed to every step, so any randomness `net` draws from it repeats
    # each step. Confirm this is intended.
    _, rng_net = jax.random.split(rng)
    w = jnp.zeros((noise.shape[1], parts))
    v = jnp.zeros((grads.shape[1], 1))

    batch_loss, batch_acc = [], []
    for x_host, y_host in zip(xb, yb):
        xbi = jax.device_put(x_host, device=device)
        ybi = jax.device_put(y_host, device=device)
        w, v, loss, accuracy = train_step2(base, noise, grads, w, v, xbi, ybi, lr, padding, rng_net)
        batch_loss.append(loss)
        batch_acc.append(accuracy)

    return w, v, jnp.array(batch_loss).mean(), jnp.array(batch_acc).mean()

def train_epoch2_m(base, noise, grads, x, y, batch_size, lr, parts, padding, rng, beta=0.9):
    """Train coefficients `w` and `v` for one epoch with `train_step2_m`.

    `w`, `v` and their momentum buffers start at zeros on every call:
    ``(noise.shape[1], parts)`` for `w` and ``(grads.shape[1], 1)`` for `v`.
    They are updated once per full batch; the trailing partial batch is
    skipped. `beta` is the momentum decay (default 0.9, the previously
    hard-coded value). Batches are shuffled and gathered on the host, then
    copied one at a time to the first GPU. If no GPU backend exists, JAX's
    default device is used.

    Returns:
    (w, v, mean_loss, mean_accuracy). The momentum buffers are not returned.
    """
    xb, yb = prepare_batches(x, y, batch_size, rng)

    try:
        device = jax.devices('gpu')[0]
    except RuntimeError:
        # No GPU backend available: fall back instead of crashing.
        device = jax.devices()[0]

    # NOTE(review): `rng` also seeded the shuffle above, and the same `rng_net`
    # is passed to every step, so any randomness `net` draws from it repeats
    # each step. Confirm this is intended.
    _, rng_net = jax.random.split(rng)
    w = jnp.zeros((noise.shape[1], parts))
    v = jnp.zeros((grads.shape[1], 1))
    w_momentum = jnp.zeros((noise.shape[1], parts))
    v_momentum = jnp.zeros((grads.shape[1], 1))

    batch_loss, batch_acc = [], []
    for x_host, y_host in zip(xb, yb):
        xbi = jax.device_put(x_host, device=device)
        ybi = jax.device_put(y_host, device=device)
        w, v, w_momentum, v_momentum, loss, accuracy = train_step2_m(
            base, noise, grads, w, v, w_momentum, v_momentum,
            xbi, ybi, lr, beta, padding, rng_net)
        batch_loss.append(loss)
        batch_acc.append(accuracy)

    return w, v, jnp.array(batch_loss).mean(), jnp.array(batch_acc).mean()


def train_epoch_m(base, noise, x, y, batch_size, lr, parts, padding, rng, beta=0.9):
    """Train the subspace coefficients `w` for one epoch with `train_step_m`.

    `w` and its momentum buffer start as zeros of shape
    ``(noise.shape[1], parts)`` on every call. They are updated once per full
    batch; the trailing partial batch is skipped. `beta` is the momentum
    decay (default 0.9, the previously hard-coded value). Batches are
    shuffled and gathered on the host, then copied one at a time to the first
    GPU. If no GPU backend exists, JAX's default device is used.

    Returns:
    (w, mean_loss, mean_accuracy). The momentum buffer is not returned.
    """
    xb, yb = prepare_batches(x, y, batch_size, rng)

    try:
        device = jax.devices('gpu')[0]
    except RuntimeError:
        # No GPU backend available: fall back instead of crashing.
        device = jax.devices()[0]

    # NOTE(review): `rng` also seeded the shuffle above, and the same `rng_net`
    # is passed to every step, so any randomness `net` draws from it repeats
    # each step. Confirm this is intended.
    _, rng_net = jax.random.split(rng)
    w = jnp.zeros((noise.shape[1], parts))
    w_momentum = jnp.zeros((noise.shape[1], parts))

    batch_loss, batch_acc = [], []
    for x_host, y_host in zip(xb, yb):
        xbi = jax.device_put(x_host, device=device)
        ybi = jax.device_put(y_host, device=device)
        w, w_momentum, loss, accuracy = train_step_m(
            base, noise, w, w_momentum, xbi, ybi, lr, beta, padding, rng_net)
        batch_loss.append(loss)
        batch_acc.append(accuracy)

    return w, jnp.array(batch_loss).mean(), jnp.array(batch_acc).mean()