from utils.configs import config_path, config

import jax
import jax.numpy as jnp  # JAX NumPy
from flax.training import train_state  # Useful dataclass to keep train state
import optax  # Optimizers
import wandb
import numpy as np
from functools import partial

# Lookup table from the configured optimizer name (`wandb.config.opt_name`)
# to the optax factory that builds the gradient transformation.
optimizer = dict(adam=optax.adam, sgd=optax.sgd)

# Network module shared by the step functions in this file; it stays None
# until `create_train_state` assigns it.
net = None

def prepare_batches(X, Y, batch_size, rng, batch_in_gpu=False):
    """Shuffle `X` and `Y` together and cut them into full batches.

    A single random permutation of the sample indices is drawn from `rng`;
    the trailing samples that do not fill a complete batch are dropped.

    Args:
        X: Inputs, indexed along axis 0.
        Y: Targets, indexed along axis 0 with the same ordering as `X`.
        batch_size: Number of samples per batch.
        rng: JAX PRNG key used for the permutation.
        batch_in_gpu: Currently unused; kept for interface compatibility.

    Returns:
        `(X_batches, Y_batches)`, each with leading axes
        `(len(X) // batch_size, batch_size)`.
    """
    n_samples = len(X)
    n_batches = n_samples // batch_size
    n_used = n_batches * batch_size

    shuffled = jax.random.permutation(rng, n_samples)
    batch_index = shuffled[:n_used].reshape((n_batches, batch_size))

    # The same index grid is applied to both arrays so pairs stay aligned.
    return np.take(X, batch_index, axis=0), np.take(Y, batch_index, axis=0)


def create_train_state(rng, network, learning_rate, momentum):
    """Initialise the network parameters and wrap them in a `TrainState`.

    Side effect: stores `network` in the module-level `net`, which the step
    functions in this file read.

    Args:
        rng: JAX PRNG key; passed to `network.init` both as the init key and
            as the trailing call argument.
        network: Module exposing `init` and `apply`.
        learning_rate: First positional argument of the optimizer factory.
        momentum: Second positional argument of the optimizer factory
            selected by `wandb.config.opt_name`.

    Returns:
        A fresh `train_state.TrainState`.
    """
    global net
    net = network

    # A tensor of ones with the configured placeholder shape drives init.
    placeholder = jnp.ones(wandb.config.pholder)
    variables = net.init(rng, placeholder, rng)
    params = variables['params']

    make_tx = optimizer[wandb.config.opt_name]
    tx = make_tx(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=net.apply, params=params, tx=tx)

def update_train_state(learning_rate, momentum, params):
    """Build a new `TrainState` around existing `params`.

    The optimizer is rebuilt from `wandb.config.opt_name`, so the returned
    state starts from a fresh optimizer state. Relies on the module-level
    `net` having been set by `create_train_state`.

    Args:
        learning_rate: First positional argument of the optimizer factory.
        momentum: Second positional argument of the optimizer factory.
        params: Parameters to carry into the new state.

    Returns:
        A new `train_state.TrainState`.
    """
    make_tx = optimizer[wandb.config.opt_name]
    new_tx = make_tx(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=net.apply, params=params, tx=new_tx)


@jax.jit
def cross_entropy_loss(*, logits, labels):
    """Return the mean softmax cross-entropy of `logits` against `labels`.

    `labels` are class indices; they are one-hot encoded to
    `config.num_classes` columns before the loss is evaluated.
    """
    targets = jax.nn.one_hot(labels, num_classes=config.num_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return per_example.mean()

@jax.jit
def compute_metrics_eval(*, logits, labels):
    """Return `(loss, accuracy)` for a batch of logits and integer labels.

    NOTE(review): a second `compute_metrics_eval` is defined later in this
    module and rebinds the name, so this keyword-only version is not the one
    callers get after import. Consider removing one of the two.
    """
    predicted = jnp.argmax(logits, -1)
    accuracy = jnp.mean(predicted == labels)
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, accuracy

@jax.jit
def compute_metrics(*, logits, labels, params):
    """Return `(loss, accuracy, l1_penalty, l2_penalty)` for one batch.

    Args:
        logits: Network outputs; the last axis indexes classes.
        labels: Integer class labels.
        params: Parameter pytree whose leaves feed the penalties.

    Returns:
        The cross-entropy loss (without regularisation), the fraction of
        argmax predictions equal to `labels`, the sum of absolute parameter
        values scaled by `config.l1_reg`, and the sum of squared parameter
        values scaled by `config.l2_reg`.
    """
    weights = jax.tree_util.tree_leaves(params)

    abs_total = sum(jnp.sum(jnp.abs(w)) for w in weights)
    sq_total = sum(jnp.sum(jnp.square(w)) for w in weights)
    l1_penalty = abs_total * config.l1_reg
    l2_penalty = sq_total * config.l2_reg

    hits = jnp.argmax(logits, -1) == labels
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, jnp.mean(hits), l1_penalty, l2_penalty

@jax.jit
def compute_metrics_eval(logits, labels):
    """Return `(loss, accuracy)` for a batch of logits and integer labels.

    The number of classes is taken from the last axis of `logits`. The loss
    is the mean softmax cross-entropy; the accuracy is the fraction of
    argmax predictions equal to `labels`.
    """
    n_classes = logits.shape[-1]
    targets = jax.nn.one_hot(labels, num_classes=n_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    hits = jnp.argmax(logits, -1) == labels
    return jnp.mean(per_example), jnp.mean(hits)


@jax.jit
def eval_step(params, batch, rng):
    """Run the module-level `net` on one batch and score it.

    Args:
        params: Parameters passed to `net.apply` under the `'params'` key.
        batch: Mapping with `'image'` (network input) and `'label'` entries.
        rng: Forwarded to `net.apply` as its trailing positional argument.

    Returns:
        `(loss, accuracy, predictions, labels)`, where `predictions` is the
        argmax of the logits over the last axis and `labels` is
        `batch['label']` passed through unchanged.
    """
    images = batch['image']
    labels = batch['label']
    logits = net.apply({'params': params}, images, rng)
    loss, accuracy = compute_metrics_eval(logits=logits, labels=labels)
    return loss, accuracy, jnp.argmax(logits, axis=-1), labels


def calculate_class_wise_accuracy(predictions, labels, num_classes):
    """Return the accuracy of `predictions` separately for each class.

    Args:
        predictions: Predicted class indices.
        labels: True class indices, aligned with `predictions`.
        num_classes: Number of classes; classes `0 .. num_classes - 1` are
            scored.

    Returns:
        Array of length `num_classes`; entry `i` is the fraction of samples
        labelled `i` that were predicted as `i`. A class with no samples
        yields NaN (0 / 0).
    """
    per_class = []
    for cls in range(num_classes):
        in_class = labels == cls
        n_correct = jnp.sum(predictions[in_class] == cls)
        per_class.append((n_correct, jnp.sum(in_class)))

    # Counts are converted to the default float dtype before dividing.
    class_correct = jnp.array([hit for hit, _ in per_class], dtype=float)
    class_total = jnp.array([count for _, count in per_class], dtype=float)
    return class_correct / class_total


def demographic_parity(class_accuracy):
    """Return the spread (standard deviation) of the per-class accuracies.

    NaN entries are skipped. `calculate_class_wise_accuracy` produces NaN for
    a class with no samples, and a plain standard deviation would turn the
    whole metric into NaN as soon as one class is missing.

    Args:
        class_accuracy: Array of per-class accuracies.

    Returns:
        Scalar standard deviation over the non-NaN entries; NaN only when
        every entry is NaN.
    """
    return jnp.nanstd(class_accuracy)


def eval_model(params, test_ds, rng):
    """Evaluate `params` on `test_ds` in batches of `wandb.config.batch_size`.

    Only full batches are evaluated; trailing samples that do not fill a
    batch are skipped. The same derived key is passed to every `eval_step`.

    Args:
        params: Parameters forwarded to `eval_step`.
        test_ds: Mapping of arrays sliced along axis 0; must contain
            `'image'` (its leading dimension defines the dataset size) and
            whatever else `eval_step` reads.
        rng: JAX PRNG key; split once to obtain the key used for evaluation.

    Returns:
        `(mean_loss, mean_accuracy, class_accuracy, fairness_metric)`, where
        the first two are averages of the per-batch values.

    Raises:
        ValueError: If the dataset holds fewer samples than one batch.
    """
    rng, rng_net = jax.random.split(rng)
    batch_size = wandb.config.batch_size
    test_size = test_ds['image'].shape[0]
    num_batches = test_size // batch_size
    if num_batches == 0:
        # Fail with an explicit message instead of the opaque error that
        # concatenating an empty list would raise further down.
        raise ValueError(
            f"eval_model needs at least one full batch: got {test_size} "
            f"samples for batch_size={batch_size}")

    total_loss = 0
    total_accuracy = 0
    all_predictions = []
    all_labels = []

    for i in range(num_batches):
        start = i * batch_size
        batch = {k: v[start:start + batch_size] for k, v in test_ds.items()}
        loss, accuracy, predictions, labels = eval_step(params, batch, rng_net)
        total_loss += loss
        total_accuracy += accuracy
        all_predictions.append(predictions)
        all_labels.append(labels)

    all_predictions = jnp.concatenate(all_predictions)
    all_labels = jnp.concatenate(all_labels)

    # Use the configured class count (the one the training loss uses) rather
    # than a hard-coded 10, so other datasets are scored correctly.
    class_accuracy = calculate_class_wise_accuracy(
        all_predictions, all_labels, num_classes=config.num_classes)
    fairness_metric = demographic_parity(class_accuracy)

    return (total_loss / num_batches,
            total_accuracy / num_batches,
            class_accuracy,
            fairness_metric)


@jax.jit
def train_step(state, X, y, rng):
    """Apply one gradient update with L1 and L2 regularisation.

    Args:
        state: Train state providing `params` and `apply_gradients`.
        X: Input batch passed to the module-level `net`.
        y: Integer labels for `X`.
        rng: Forwarded to `net.apply` as its trailing positional argument.

    Returns:
        `(state, loss, accuracy, l1_loss, l2_loss)`. The metrics come from
        `compute_metrics`, using the logits of the pre-update forward pass
        and the parameters after the update; `loss` is the cross-entropy
        without the regularisation terms.
    """

    def regularized_loss(params):
        logits = net.apply({'params': params}, X, rng)
        data_loss = cross_entropy_loss(logits=logits, labels=y)
        weights = jax.tree_util.tree_leaves(params)
        abs_total = sum(jnp.sum(jnp.abs(w)) for w in weights)
        sq_total = sum(jnp.sum(jnp.square(w)) for w in weights)
        penalty = config.l1_reg * abs_total + config.l2_reg * sq_total
        # Logits ride along as auxiliary output so they are not recomputed.
        return data_loss + penalty, logits

    value_and_grad = jax.value_and_grad(regularized_loss, has_aux=True)
    (_, logits), grads = value_and_grad(state.params)
    new_state = state.apply_gradients(grads=grads)
    loss, accuracy, l1_loss, l2_loss = compute_metrics(
        logits=logits, labels=y, params=new_state.params)
    return new_state, loss, accuracy, l1_loss, l2_loss

def to_gpu(x):
    """Place `x` on the first device JAX reports for the 'gpu' backend."""
    first_gpu = jax.devices('gpu')[0]
    return jax.device_put(x, device=first_gpu)

def to_cpu(x):
    """Place `x` on the first device JAX reports for the 'cpu' backend."""
    first_cpu = jax.devices('cpu')[0]
    return jax.device_put(x, device=first_cpu)

def train_epoch(state, X, Y, b_rng, t_rng):
    """Run `train_step` over one shuffled pass of `(X, Y)`.

    The data is cut into batches of `config.batch_size`; samples that do not
    fill a complete batch are dropped. When `config.dataset == 'cifar10'` or
    `config.cpu_batching` is set, the batch indices are placed on the CPU
    and each batch is moved to the first GPU just before its step.

    Args:
        state: Initial train state.
        X: Inputs, indexed along axis 0.
        Y: Targets aligned with `X`.
        b_rng: PRNG key used to shuffle the sample order.
        t_rng: PRNG key passed unchanged to every `train_step` call.

    Returns:
        `(state, loss, accuracy, l1_loss, l2_loss)` with each metric averaged
        over the steps of the epoch.
    """
    n_samples = len(X)
    batch_size = config.batch_size
    n_steps = n_samples // batch_size

    order = jax.random.permutation(b_rng, n_samples)
    batch_index = order[:n_steps * batch_size].reshape((n_steps, batch_size))
    if config.dataset == 'cifar10' or config.cpu_batching:
        batch_index = jax.device_put(batch_index, device=jax.devices('cpu')[0])

    X_batches = np.take(X, batch_index, axis=0, mode='wrap')
    Y_batches = np.take(Y, batch_index, axis=0, mode='wrap')

    # One list per metric, in the order train_step returns them.
    history = {'loss': [], 'accuracy': [], 'l1': [], 'l2': []}
    current = state
    for x, y in zip(X_batches, Y_batches):
        if config.dataset == 'cifar10' or config.cpu_batching:
            x = jax.device_put(x, device=jax.devices('gpu')[0])
            y = jax.device_put(y, device=jax.devices('gpu')[0])
        current, *step_metrics = train_step(current, x, y, t_rng)
        for series, value in zip(history.values(), step_metrics):
            series.append(value)

    return (current, *(jnp.array(series).mean() for series in history.values()))


@jax.pmap
def train_step_pmap(state, X, y, rng):
    """Train for a single step on each mapped slice of the inputs.

    Unlike `train_step`, the optimised loss here is the plain cross-entropy
    without L1/L2 penalties.

    Args:
        state: Train state providing `params` and `apply_gradients`.
        X: Input batch passed to the module-level `net`.
        y: Integer labels for `X`.
        rng: Forwarded to `net.apply` as its trailing positional argument.

    Returns:
        `(state, loss, accuracy)`.
    """

    def loss_fn(params):
        logits = net.apply({'params': params}, X, rng)
        loss = cross_entropy_loss(logits=logits, labels=y)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    # compute_metrics requires `params` and returns four values; the two
    # regularisation penalties are not part of this function's result.
    loss, accuracy, _, _ = compute_metrics(
        logits=logits, labels=y, params=state.params)
    return state, loss, accuracy

@jax.pmap
def train_step_pmap2(state, X, y, rng):
    """Run `train_step` vectorised over the leading axis of each mapped slice.

    Inside every pmap slice, `state`, `X` and `y` are vmapped over axis 0
    while `rng` is shared by all vmapped instances. The result has the same
    five-element structure that `train_step` returns.
    """
    batched_step = jax.vmap(train_step, in_axes=(0, 0, 0, None))
    return batched_step(state, X, y, rng)

def train_single_epoch_pmap(state, X, y, batch_size, rng):
    """Train on one randomly sampled batch through `train_step_pmap2`.

    Despite the name, only a single step is run: `batch_size` sample indices
    are drawn along axis 2 of `X` / `y` and one parallel step is applied.

    Args:
        state: Train state forwarded to `train_step_pmap2`.
        X: Inputs; axis 0 is the pmapped axis and axis 2 indexes samples.
            # assumes axis 1 is the vmapped (per-client) axis — TODO confirm
        y: Targets laid out like `X`.
        batch_size: Number of samples drawn for the step.
        rng: JAX PRNG key used for sampling and for the per-device keys.

    Returns:
        `(state, loss, accuracy)` with both metrics averaged over every
        value returned by the parallel step.
    """
    train_ds_size = X.shape[2]

    # Slice (not index) the permutation: `[batch_size]` would select a single
    # sample and drop the batch axis entirely.
    perm = jax.random.permutation(rng, train_ds_size)[:batch_size]
    rng, rng_net = jax.random.split(rng)

    X_batch = jnp.take(X, perm, axis=2)
    Y_batch = jnp.take(y, perm, axis=2)
    rngs = jax.random.split(rng_net, num=X.shape[0])  # one key per pmapped slice

    # train_step_pmap2 returns five values (state, loss, accuracy, l1, l2);
    # the penalties are not part of this function's result.
    state, loss, accuracy, _, _ = train_step_pmap2(state, X_batch, Y_batch, rngs)

    return state, jnp.mean(loss), jnp.mean(accuracy)


def train_epoch_pmap(state, X, y, batch_size, rng):
    """Train for a single epoch through `train_step_pmap2`.

    Samples are shuffled along axis 2 of `X` / `y` and processed in full
    batches; trailing samples that do not fill a batch are skipped. The same
    per-device keys are used for every step of the epoch.

    Args:
        state: Train state forwarded to `train_step_pmap2`.
        X: Inputs; axis 0 is the pmapped axis and axis 2 indexes samples.
            # assumes axis 1 is the vmapped (per-client) axis — TODO confirm
        y: Targets laid out like `X`.
        batch_size: Number of samples per step.
        rng: JAX PRNG key used for shuffling and for the per-device keys.

    Returns:
        `(state, mean_loss, mean_accuracy)`, where the means are taken over
        the steps of the epoch and keep the leading axes of the values
        returned by `train_step_pmap2`. Both are 0 when no full batch fits.
    """
    train_ds_size = X.shape[2]
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    rng, rng_net = jax.random.split(rng)

    # Rearrange the data once so each step is a plain slice along axis 2.
    X = jnp.take(X, perms, axis=2)
    Y = jnp.take(y, perms, axis=2)

    # Loop-invariant: one key per pmapped slice.
    rngs = jax.random.split(rng_net, num=X.shape[0])

    total_loss = 0
    total_acc = 0

    for i in range(steps_per_epoch):
        # train_step_pmap2 returns five values (state, loss, accuracy, l1,
        # l2); the penalties are not part of this function's result.
        state, loss, accuracy, _, _ = train_step_pmap2(
            state, X[:, :, i], Y[:, :, i], rngs)
        total_loss = total_loss + loss
        total_acc = total_acc + accuracy

    if steps_per_epoch:
        total_loss = total_loss / steps_per_epoch
        total_acc = total_acc / steps_per_epoch

    return state, total_loss, total_acc


def train_single_batch(state, X, y, batch_size, rng):
    """Train on one randomly sampled batch of `batch_size` samples.

    Args:
        state: Train state forwarded to `train_step`.
        X: Inputs, indexed along axis 0.
        y: Targets aligned with `X`.
        batch_size: Number of samples drawn for the step.
        rng: JAX PRNG key used for sampling and for the training step.

    Returns:
        `(state, loss, accuracy)` for the single step.
    """
    train_ds_size = len(X)

    # Slice (not index) the permutation: `[batch_size]` would select a single
    # sample and drop the batch axis entirely.
    perm = jax.random.permutation(rng, train_ds_size)[:batch_size]
    rng, rng_net = jax.random.split(rng)

    X_batch = jnp.take(X, perm, axis=0)
    Y_batch = jnp.take(y, perm, axis=0)

    # train_step returns five values (state, loss, accuracy, l1, l2); the
    # penalties are not part of this function's result.
    state, loss, accuracy, _, _ = train_step(state, X_batch, Y_batch, rng_net)

    return state, jnp.mean(loss), jnp.mean(accuracy)

@partial(jax.jit, static_argnums=(1,), backend='cpu')
def generate_permutation(rng_key, size):
    """Return a random permutation of `range(size)`, compiled for the CPU backend.

    `size` is a static argument, so each distinct value is compiled
    separately.
    """
    permuted = jax.random.permutation(rng_key, size)
    return permuted


