import time
from functools import partial
from math import sqrt
import numpy as np
import torch
import jax
import jax.numpy as jnp
from sklearn.utils import gen_batches, check_random_state


def get_inverse_permutation(perm):
    """Compute the inverse of a permutation.

    Parameters
    ----------
    perm : sequence of int, length n
        A permutation of ``0, ..., n - 1``.

    Returns
    -------
    iperm : np.ndarray of dtype uint32, shape (n,)
        Array such that ``iperm[perm[i]] == i`` for every ``i``.
    """
    # Pin an integer index dtype so empty inputs and plain lists index cleanly.
    perm = np.asarray(perm, dtype=np.intp)
    cnt = len(perm)
    iperm = np.empty(cnt, dtype=np.uint32)
    # Single vectorized scatter instead of a Python-level loop over entries.
    iperm[perm] = np.arange(cnt, dtype=np.uint32)
    return iperm


class HopfieldNet(object):
    """Classical Hopfield network with outer-product weights (JAX).

    XXX TODO: Use PyTorch instead of JAX here!
    """

    def update(self, *args, **kwargs):
        """Apply the update operator; simply delegates to ``retrieve``."""
        return self.retrieve(*args, **kwargs)

    def retrieve(self, X, Q, batch_size=200, backend="cpu",
                 scale=1.):
        """Run one synchronous sign update of the queries.

        Parameters
        ----------
        X : 2D array, one stored memory per row.
        Q : 2D array, one query per row.
        batch_size : number of query rows handled per step.
        backend : JAX backend name; its first device is used.
        scale : the weight matrix is divided by ``scale ** 2``.

        Returns
        -------
        Host-side array shaped like ``Q`` holding the sign of each
        updated neuron.
        """
        target = jax.devices(backend=backend)[0]
        memories = jax.device_put(X, target)
        probes = jax.device_put(Q, target)
        _, n_features = memories.shape

        # Outer-product weights with the self-connections zeroed out.
        no_self_loops = 1. - jnp.eye(n_features)
        weights = (memories.T @ memories / scale ** 2) * no_self_loops
        weights_t = weights.T

        recalled = jax.device_put(jnp.zeros_like(probes), target)
        for rows in gen_batches(len(recalled), batch_size):
            signs = jnp.sign(probes[rows] @ weights_t)
            recalled = recalled.at[rows].set(signs)
        return jax.device_get(recalled)


class SPSHopfieldNet(HopfieldNet):
    """PyTorch implementation of our proposed Product-of-Sums Hopfield
    Network (PSHN).

    Parameters
    ----------
    n_features : int
        Number of neurons; must be a multiple of ``group_size``.
    group_size : int
        Number of neurons in each group.
    random_state : None, int or RandomState
        Default seed for the neuron permutation drawn in ``retrieve``; used
        whenever ``retrieve`` is called without its own ``random_state``.
    """
    def __init__(self, n_features, group_size, random_state=None):
        super().__init__()
        assert n_features % group_size == 0
        self.n_features = n_features
        self.group_size = group_size
        self.random_state = random_state

    def retrieve(self, X, Q, permute_neurons=True, batch_size=200,
                 backend="cpu", random_state=None):
        """Memory retrieval: For each row of Q (i.e a query), we retrieve the
        corresponding row of X (the database of memories)

        Parameters
        ----------
        X : 2D array-like with ``n_features`` columns (memories).
        Q : 2D array-like with ``n_features`` columns (queries).
        permute_neurons : if True, neurons are shuffled with one random
            permutation before grouping and un-shuffled in the output.
        batch_size : number of queries processed per step.
        backend : torch device specifier passed to ``Tensor.to``.
        random_state : seed / RandomState for the permutation. When None,
            the ``random_state`` given to the constructor is used.

        Returns
        -------
        NumPy array shaped like ``Q`` with the sign of each updated neuron.
        """
        X = torch.tensor(X).to(backend)
        device = X.device
        print("Pytoch: Using device %s" % device)
        Q = torch.tensor(Q).to(device)
        n_groups = self.n_features // self.group_size
        assert X.shape[1] == self.n_features
        assert Q.shape[1] == self.n_features
        if permute_neurons:
            # Fall back to the instance-level seed, which was previously
            # stored by the constructor but never consulted.
            if random_state is None:
                random_state = self.random_state
            rng = check_random_state(random_state)
            perm = rng.permutation(self.n_features).astype(int)
            iperm = get_inverse_permutation(perm).astype(int)
            X = X[:, perm]
            Q = Q[:, perm]

        # zeros_like already matches the dtype and device of Q.
        out = torch.zeros_like(Q)
        out = out.reshape((len(out), n_groups, self.group_size))
        X_ = X.reshape(len(X), n_groups, self.group_size)
        for batch in gen_batches(len(Q), batch_size):
            q = Q[batch]
            q = q.reshape(len(q), n_groups, self.group_size) / self.group_size
            # Z[m, M, k]: overlap of query m with memory M within group k.
            Z = torch.einsum('mkg,Mkg->mMk', q, X_)

            # Product over all groups divided by the own group, i.e. the
            # product over the *other* groups. 0 / 0 gives NaN, mapped to 0.
            # NOTE(review): when exactly one group overlap is zero the true
            # product over the other groups is nonzero, yet this yields 0 --
            # confirm that is acceptable.
            Z = Z.prod(axis=2, keepdims=True) / Z
            Z = torch.nan_to_num(Z, nan=0.)
            # Weighted sum of the memories, group by group.
            aux = torch.einsum('mMk,Mkg->mkg', Z, X_)
            out[batch] = aux
        out = out.reshape(-1, self.n_features)
        out = torch.sign(out)
        if permute_neurons:
            # Undo the shuffling so columns line up with the input order.
            out = out[:, iperm]
        return out.cpu().numpy()


def F_(X, Q, mode="Ours", group_size=28, deg=20, overlap_scale=10.):
    """Interaction function applied to the memory/query overlaps.

    Parameters
    ----------
    X : 2D array, one memory per row.
    Q : 2D array, one query per row.
    mode : one of "Classical", "Demircigil-poly" or "Demircigil-expo".
        Any other value, including the default "Ours", is not handled here.
    group_size : unused; accepted so existing callers keep working.
    deg : exponent used when ``mode == "Demircigil-poly"``.
    overlap_scale : divisor applied to the raw overlaps ``X @ Q.T``
        (previously hard-coded to 10).

    Returns
    -------
    Array of shape (len(X), len(Q)).

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the three supported values.
    """
    overlap = X @ Q.T / overlap_scale
    if mode == "Classical":
        return overlap ** 2
    elif mode == "Demircigil-poly":
        return overlap ** deg
    elif mode == "Demircigil-expo":
        return jnp.exp(overlap)
    else:
        raise NotImplementedError(mode)


def update(X, queries, permute_neurons=False, n_iters=1,
           mode="Demircigil-expo", random_state=None, verbose=1,
           deg=4, batch_size=200, backend="cpu", loopy=False,
           scale=1., **kwargs):
    """Vectorized implementation of Hopfield network update rule.

    Parameters
    ----------
    X : 2D array, one stored memory per row.
    queries : 2D array, one query per row.
    permute_neurons : shuffle the neurons with a fresh random permutation on
        every iteration (the shuffle is undone before accumulating).
    n_iters : number of update iterations.
    mode : "Ours" (delegates to ``SPSHopfieldNet``), "Classical" (delegates
        to ``HopfieldNet``), "Demircigil-expo" or "Demircigil-poly".
    random_state : seed / RandomState used for the neuron permutations.
    verbose : if truthy, print the mode being run.
    deg : polynomial degree for "Demircigil-poly".
    batch_size : number of queries processed per step.
    backend : device/backend name handed to JAX or PyTorch.
    loopy : if True, update neuron by neuron through energy differences
        computed with ``F_``; ``scale`` is not used on that path.
    scale : divisor applied to the query/memory similarities on the
        vectorized (non-loopy) path.
    **kwargs : forwarded to ``SPSHopfieldNet`` when ``mode == "Ours"``,
        otherwise to ``F_``.

    Returns
    -------
    Array shaped like ``queries``: a NumPy array for "Ours" and "Classical",
    a JAX array for the other modes.
    """
    if verbose:
        print("Running", mode, "...")
    if mode == "Ours":
        n_features = len(X[0])
        # NOTE(review): `permute_neurons` is a named parameter, so it can
        # never appear in kwargs; this always evaluates to True and overrides
        # whatever the caller passed. Left as is to keep current results.
        permute_neurons = kwargs.pop("permute_neurons", True)
        sps_hnet = SPSHopfieldNet(n_features=n_features, **kwargs)
        # Forward the seed so that the permutation is reproducible; it used
        # to be dropped, making `random_state` a no-op in this mode.
        return sps_hnet.update(X, queries, permute_neurons=permute_neurons,
                               batch_size=batch_size,
                               backend=backend,
                               random_state=random_state)
    if mode == "Classical":
        hnet = HopfieldNet()
        return hnet.update(X, queries, batch_size=batch_size,
                           backend=backend)

    # XXX jax-ify the following code block!
    device = jax.devices(backend=backend)[0]
    X = jax.device_put(X, device)
    queries = jax.device_put(queries, device)
    out = jax.device_put(jnp.zeros_like(queries, dtype=X.dtype), device)

    n_features = X.shape[1]
    F__ = partial(F_, mode=mode, deg=deg, **kwargs)
    rng = check_random_state(random_state)
    # A single inner pass, with or without neuron permutation.
    n_inner_iters = 1
    for _ in range(n_iters):
        for _ in range(n_inner_iters):
            out_ = jax.device_put(jnp.zeros_like(out), device)
            if permute_neurons:
                perm = rng.permutation(n_features)
                iperm = get_inverse_permutation(perm)
                X = X[:, perm]
                queries = queries[:, perm]
            for batch in gen_batches(len(queries), batch_size):
                q = queries[batch]
                # XXX find a way to vectorize the following loop ?
                if loopy:
                    for neuron in range(n_features):
                        # Energy with the neuron clamped to +1 versus -1.
                        plus = q.at[:, neuron].set(1.)
                        minus = q.at[:, neuron].set(-1.)
                        energy_diff = F__(X, plus) - F__(X, minus)
                        if np.ndim(energy_diff) == 2:
                            # Collapse the memory axis.
                            energy_diff = energy_diff.mean(axis=0)
                        out_ = out_.at[batch, neuron].set(energy_diff)
                else:
                    Z = (q @ X.T) / scale  # mxM
                    if mode == "Demircigil-expo":
                        Z = jnp.exp(Z)
                    elif mode == "Demircigil-poly":
                        # jnp (not np) keeps the computation in JAX.
                        Z = jnp.power(Z, deg - 1)
                    else:
                        raise NotImplementedError
                    out_ = out_.at[batch].set(jnp.sign(Z @ X))
            if permute_neurons:
                # Restore the original neuron order everywhere.
                X = X[:, iperm]
                queries = queries[:, iperm]
                out_ = out_[:, iperm]
            out += out_
        # NOTE(review): `out` is never reset, so from the second outer
        # iteration on the previous state is added to the new update before
        # the sign is taken -- confirm this is intended.
        out = jnp.sign(out)
        queries = out
    return out


def retrieve(*args, **kwargs):
    """Alias of ``update``: forwards all arguments and returns its result."""
    return update(*args, **kwargs)


def run_experiment(conf, db, corrupt_input_data,
                   true_input_data=None,
                   mega_bytes=1000, backend="cpu",
                   compute_metrics=None, verbose=1,
                   **kwargs):
    """Run one retrieval experiment and collect its metrics.

    Parameters
    ----------
    conf : dict
        Must contain "mode". Every entry except "method" and "use_grouping"
        is forwarded to ``update`` as a keyword argument (unless the mode is
        "Indexing"), and all entries are copied into the result record.
    db : 2D array, one stored memory per row.
    corrupt_input_data : 2D array of queries to reconstruct.
    true_input_data : optional ground truth; when given (and no
        ``compute_metrics`` is supplied) the exact-match rate is reported
        under the key "succ".
    mega_bytes : budget used to derive the batch size as
        ``mega_bytes * 1000000 // (len(db) * n_features)``, floored at 1.
    backend : forwarded to ``update``.
    compute_metrics : optional callable receiving this function's
        ``locals()`` and returning a dict of metrics.
    verbose : if truthy, print a one-line summary.
    **kwargs : forwarded to ``update``.

    Returns
    -------
    list with a single dict holding "M" (database size), the metrics
    (always including "duration", in seconds) and the entries of ``conf``.
    """
    # Local names are kept stable on purpose: locals() is handed to
    # `compute_metrics`, which may look any of them up.
    n_features = db.shape[-1]
    res = []
    deg = 4
    M = len(db)
    t0 = time.time()
    mode = conf["mode"]
    mem_unit = M * n_features
    # Floor at 1 so a large database cannot produce a batch size of 0, and
    # guard the divisor so an empty database does not divide by zero.
    batch_size = max(1, mega_bytes * 1000000 // max(1, mem_unit))
    if mode == "Indexing":
        def retriever(X, Q):
            # Nearest stored memory by dot-product similarity.
            # NOTE(review): the arguments are ignored in favour of the
            # enclosing `db` / `corrupt_input_data` -- confirm against how
            # `grouper.predict` invokes `final`.
            sim = corrupt_input_data@db.T
            return db[sim.argmax(axis=1)]
    else:
        retriever = partial(
            update, verbose=0,
            batch_size=batch_size, backend=backend,
            **{k: v for (k, v) in conf.items()
                if k not in ["method", "use_grouping"]},
            **kwargs)
    if conf.get("use_grouping", False):
        # NOTE(review): `grouper` is not defined in this function; it has to
        # exist at module level for this branch to work.
        recon = grouper.predict(corrupt_input_data, final=retriever)
    else:
        recon = retriever(db, corrupt_input_data)
    duration = time.time() - t0

    if compute_metrics is not None:
        metrics = compute_metrics(locals())
    elif true_input_data is not None:
        success = np.all(recon == true_input_data, axis=1).mean()
        metrics = dict(succ=success)
    else:
        # Nothing to score against: report the timing only instead of
        # failing on a call to None.
        metrics = {}
    metrics["duration"] = duration
    if verbose:
        print("conf=%s, DB size=%s, metrics=%s" % (conf, M, metrics))
    res.append(dict(M=M, **metrics, **conf))
    return res
