import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd


def lira_hinge_loss(logits, labels):
    """Return the LiRA hinge score for every example.

    The score is the logit of the true class minus the largest logit among
    all other classes: ``logits[i, ..., labels[i]] - max_{c != labels[i]} logits[i, ..., c]``.

    Args:
        logits: array whose first axis indexes examples and whose last axis
            indexes classes.
        labels: integer class index per example, paired with the first axis
            of ``logits``.

    Returns:
        Array of hinge scores (one per example for 2-D ``logits``).
    """
    rows = jnp.arange(logits.shape[0])
    true_logit = logits[rows, ..., labels]
    # Replace the true-class entry with -inf so the max only sees competing classes.
    competitors = logits.at[rows, ..., labels].set(float("-inf"))
    best_competitor = competitors.max(axis=-1)
    return true_logit - best_competitor


def get_perms_per_pair_model_jax(key, dataset_size, add_canary=False):
    """Split a shuffled ``range(dataset_size)`` into the training sets of one model pair.

    Args:
        key: JAX PRNG key used to shuffle the indices.
        dataset_size: number of (non-canary) examples; indices are
            ``0 .. dataset_size - 1``.
        add_canary: if True, the canary index ``dataset_size`` is appended to
            the first set.

    Returns:
        ``(set1_indices, set2_indices)``, each with a leading axis of size 1 so
        that the sets of several pairs can be concatenated along axis 0.

        * ``add_canary=False``: two disjoint sets of ``dataset_size // 2``
          indices each (for odd ``dataset_size`` one shuffled index is unused).
        * ``add_canary=True``: both sets have ``dataset_size // 2 + 1`` entries.
          ``set1`` is the last ``dataset_size // 2`` shuffled indices followed
          by the canary. ``set2`` is the first ``dataset_size // 2 + 1``
          shuffled indices. For even ``dataset_size`` the two sets therefore
          share exactly one index.
    """
    half = dataset_size // 2
    indices = jnp.arange(dataset_size, dtype=int)
    indices = jax.random.permutation(key, indices, independent=True)
    # Use an explicit start index instead of a negative one.
    # ``-dataset_size // 2`` parses as ``(-dataset_size) // 2`` (one element
    # too many for odd sizes). ``indices[-0:]`` would select everything when
    # ``half == 0``.
    tail = indices[dataset_size - half:]
    if not add_canary:
        set1_indices = indices[:half]
        set2_indices = tail
    else:
        canary_index = dataset_size
        set1_indices = jnp.append(tail, canary_index)
        # One extra index so both sets have the same length (set1 holds the canary).
        set2_indices = indices[:half + 1]
    return set1_indices[None, :], set2_indices[None, :]

def pretty_print(x, num_parts=2):
    """Format an underscore-separated run name as lines of ``key: value`` pairs.

    ``x`` is split on ``"_"``. The first two tokens are skipped and the rest
    are read as alternating ``field, value`` tokens. Pairs whose field is
    ``"cname"`` are dropped, and the field ``"mnist"`` is shown as
    ``"canary_label"``. The pairs are grouped into ``num_parts`` lines of
    ``len(pairs) // num_parts`` pairs each, plus one final line for any
    leftover pairs. Empty lines are omitted.

    Args:
        x: run-name string to format.
        num_parts: number of equal-size lines to spread the pairs over.

    Returns:
        The formatted, newline-separated string.
    """
    tokens = x.split("_")
    pairs = [
        ("canary_label" if field == "mnist" else field, value)
        for field, value in zip(tokens[2::2], tokens[3::2])
        if field != "cname"
    ]
    part_size = len(pairs) // num_parts
    parts = [pairs[i * part_size:(i + 1) * part_size] for i in range(num_parts)]
    parts.append(pairs[num_parts * part_size:])  # leftover pairs, possibly empty
    # Skip empty groups. Otherwise an exact split leaves a trailing blank line
    # and very short names start with blank lines.
    return "\n".join(
        ", ".join(f"{key}: {value}" for key, value in part)
        for part in parts
        if part
    )



def get_perms_per_pair_model(rng, dataset_size, add_canary=False):
    """Split a shuffled ``range(dataset_size)`` into two disjoint halves for one model pair.

    Args:
        rng: NumPy random generator (anything with an in-place ``shuffle``,
            e.g. ``np.random.default_rng()``).
        dataset_size: number of (non-canary) examples.
        add_canary: if True, the canary index ``dataset_size`` is appended to
            the first set.

    Returns:
        ``(set1_indices, set2_indices)``, each with a leading axis of size 1.
        Both sets contain ``dataset_size // 2`` distinct shuffled indices and
        are disjoint (one index is unused when ``dataset_size`` is odd). With
        ``add_canary`` the first set is one element longer: it ends with the
        canary index.
    """
    indices = np.arange(dataset_size, dtype=int)
    hs = dataset_size // 2
    rng.shuffle(indices)
    # We need exactly half of the dataset in each set. Truncate the second half
    # to ``hs`` elements so odd dataset sizes still give equal-size sets.
    set1_indices = indices[:hs]
    set2_indices = indices[hs:2 * hs]
    if add_canary:
        canary_index = dataset_size
        # Add/remove relationship: the set with the canary differs by one extra example.
        set1_indices = np.append(set1_indices, canary_index)
    return set1_indices[None, :], set2_indices[None, :]

def generate_model_perms_given_train_perms(keys, train_perms, num_models, num_epochs, steps_per_epoch, batch_size):
    """Generate the per-epoch mini-batch index schedule of every model.

    For each model ``m`` and epoch ``e``, the training indices
    ``train_perms[m]`` are shuffled with ``keys[m, e]``. The LAST
    ``steps_per_epoch * batch_size`` shuffled entries are kept and reshaped
    into ``steps_per_epoch`` batches. Any leftover entries at the front are
    dropped for that epoch.

    Note: the drop happens after shuffling, so a given example (including a
    canary appended at the end of ``train_perms[m]``) is NOT guaranteed to
    appear in every epoch when fewer than ``len(train_perms[m])`` entries
    are kept.

    Args:
        keys: PRNG keys indexable as ``keys[model_idx, epoch_idx]``.
        train_perms: per-model training indices, indexed as
            ``train_perms[model_idx]`` (presumably shape
            ``(num_models, train_size)`` — confirm against callers).
        num_models: number of models to generate schedules for.
        num_epochs: number of epochs per model.
        steps_per_epoch: number of batches per epoch.
        batch_size: number of indices per batch.

    Returns:
        Array of shape ``(num_models, num_epochs, steps_per_epoch, batch_size)``.
    """
    num_kept = steps_per_epoch * batch_size
    model_perms = []
    for model_idx in range(num_models):
        epoch_perms = []
        for epoch_idx in range(num_epochs):
            indices = jax.random.permutation(keys[model_idx, epoch_idx], train_perms[model_idx])
            # Explicit start index: ``indices[-num_kept:]`` would return the
            # whole array when ``num_kept == 0``, breaking the reshape.
            start = max(indices.shape[0] - num_kept, 0)
            epoch_perms.append(indices[start:].reshape(steps_per_epoch, batch_size))
        model_perms.append(jnp.stack(epoch_perms))
    return jnp.stack(model_perms)


def generate_train_test_model_perms(key, num_models, num_epochs, dataset_size, batch_size, force_num_steps_per_epoch=None):
    """Generate train/test index sets and batch schedules for LiRA IN/OUT models.

    Models are built in pairs with
    ``get_perms_per_pair_model_jax(..., add_canary=True)``. The first
    ``num_models // 2`` models (IN) train on the sets containing the canary
    index ``dataset_size``. The last ``num_models // 2`` models (OUT) train on
    the partner sets, which never contain it. Each model's test set is its
    partner's training set.

    Returns:
        train_perms: Array(num_models, train_size)
        test_perms: Array(num_models, train_size)
        model_perms_in: Array(num_epochs, steps_per_epoch, batch_size, num_models // 2)
        model_perms_out: Array(num_epochs, steps_per_epoch, batch_size, num_models // 2)
        where ``train_size`` is the length of each pair set
        (``dataset_size // 2 + 1`` for ``dataset_size >= 2``) and
        ``steps_per_epoch = train_size // batch_size``.

    Raises:
        ValueError: if ``num_models`` is not a positive even number.
        NotImplementedError: if ``force_num_steps_per_epoch`` is not None.
    """
    # Models come in IN/OUT pairs, so an even, positive count is required.
    if num_models <= 0 or num_models % 2 != 0:
        raise ValueError("Need an even number of models (half for IN, half for OUT)")
    if force_num_steps_per_epoch is not None:
        # Fail fast rather than after generating all permutations.
        raise NotImplementedError("Not supported yet")

    num_pairs = num_models // 2
    canary_index = dataset_size

    # Train/test sets: one key per IN/OUT pair.
    key, *pair_keys = jax.random.split(key, num=num_pairs + 1)
    set1_perms, set2_perms = zip(*[get_perms_per_pair_model_jax(pair_key, dataset_size, add_canary=True) for pair_key in pair_keys])

    # Rows [0, num_pairs) are IN models (canary sets), rows [num_pairs, num_models) OUT models.
    train_perms = jnp.concatenate([*set1_perms, *set2_perms], axis=0)
    # Each model is evaluated on its partner's training set.
    test_perms = jnp.concatenate([*set2_perms, *set1_perms], axis=0)

    # Sanity check: the canary is in the training set of every IN model and of no OUT model.
    # train_perms and test_perms hold the same rows in a different order, so comparing their
    # value counts would always pass and is not a useful check.
    canary_counts = jnp.sum(train_perms == canary_index, axis=1)
    assert jnp.all(canary_counts[:num_pairs] == 1) and jnp.all(canary_counts[num_pairs:] == 0)

    train_ds_size = train_perms.shape[1]
    steps_per_epoch = train_ds_size // batch_size

    # One key per (model, epoch). Keep each key's own trailing shape instead of hard-coding 2.
    splits = jax.random.split(key, num=(num_models * num_epochs) + 1)
    model_keys = splits[1:].reshape((num_models, num_epochs) + splits.shape[1:])
    model_perms = generate_model_perms_given_train_perms(model_keys, train_perms, num_models, num_epochs, steps_per_epoch, batch_size)

    model_perms = model_perms.transpose((1, 2, 3, 0))  # final order: num_epochs, steps_per_epoch, batch_size, num_models
    model_perms_in = model_perms[..., :num_pairs]
    model_perms_out = model_perms[..., num_pairs:]

    # Sanity check: no index repeats within any model's epoch (sorted neighbours must differ).
    per_epoch = jnp.sort(model_perms.reshape(num_epochs, steps_per_epoch * batch_size, num_models), axis=1)
    assert jnp.all(per_epoch[:, 1:] != per_epoch[:, :-1])
    # Sanity check: every IN model contains the canary in at least one epoch. Batches are
    # truncated after shuffling, so this can fail by chance when the canary is dropped every epoch.
    assert jnp.all(jnp.sum(model_perms_in == canary_index, axis=(0, 1, 2)) > 0)
    # Sanity check: no OUT model contains the canary.
    assert not jnp.any(model_perms_out == canary_index)

    return train_perms, test_perms, model_perms_in, model_perms_out
