import jax.numpy as jnp
import numpy as np
import jax
from functools import partial
from torch.utils import data
import os
import torch
from jax.scipy.linalg import eigh
from jax import vmap, pmap, jit, lax, devices




# Restrict CUDA to GPU 0. NOTE(review): this assignment runs after jax and
# torch have been imported; it presumably still takes effect because neither
# library has touched the GPU yet at this point -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Number of devices visible to JAX.
num_devices = len(jax.devices())








def k_fn(x, y, sigma=25.0):
    """Evaluate a Gaussian (RBF) kernel between two arrays.

    Computes ``exp(-0.5 * sum(((x - y) / sigma) ** 2))``; the sum runs over
    every element of the (broadcast) difference, so the result is a scalar.

    Args:
        x: First input array.
        y: Second input array, broadcast-compatible with ``x``.
        sigma: Length scale dividing the difference before squaring.

    Returns:
        Scalar kernel value.
    """
    squared_distance = jnp.sum(jnp.square((x - y) / sigma))
    return jnp.exp(-0.5 * squared_distance)

# Pairwise kernel matrix: kernel(x, y)[i, j] == k_fn(x[i], y[j]).
# The inner map sweeps the leading axis of y for one fixed x[i]; the outer map
# then sweeps the leading axis of x.
_k_fn_over_y = vmap(k_fn, in_axes=(None, 0))
kernel = vmap(_k_fn_over_y, in_axes=(0, None))

@jit
def kme(x, y):
    """Return the mean of all pairwise kernel values between ``x`` and ``y``.

    Builds the full pairwise matrix with ``kernel`` and averages every entry.
    (Presumably "kme" stands for kernel mean embedding: this average equals
    the inner product of the two empirical mean embeddings.)

    Args:
        x: Array whose leading axis indexes the first set of samples.
        y: Array whose leading axis indexes the second set of samples.

    Returns:
        Scalar mean kernel value.
    """
    gram = kernel(x, y)
    return gram.mean()




# Torch device handle: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the three datasets. NOTE(review): torch.load unpickles the file
# contents, so these paths must only ever point at trusted files.
X_1 = torch.load("EMNIST.pt")
X_2 = torch.load("omniglot_train.pt")
X_3 = torch.load("omniglot_test.pt")

# Samples kept per example after the reshape below.
num_samples = 20
# NOTE(review): 964 and 53 are presumably the example counts of two of the
# datasets loaded above -- confirm, and consider deriving them from the
# tensors' shapes instead of hard-coding them.
num_examples = 964 + 53 + X_3.shape[0]

# Earlier data source, kept for reference:
#X = torch.load("data/characters.pth")
#X = jnp.array(X)

# Stack the datasets along the example axis, keep the first num_examples
# entries, and flatten each sample so X becomes
# (num_examples, num_samples, features) as a JAX array.
X = torch.cat((X_1, X_2, X_3), dim=0)
X = jnp.array(X[:num_examples, :, :].reshape(num_examples, num_samples, -1))

# Set to True to load a previously saved covariance matrix from
# 'data/shapes_cov.npy' instead of recomputing it.
load_cov = False

if not load_cov:
    num_devices = len(devices())

    # The matrix is symmetric, so only the upper triangle (diagonal included)
    # is evaluated. Each row of idx is one (row, col) pair of example indices.
    rows, cols = jnp.triu_indices(num_examples)
    idx = jnp.stack([rows, cols]).T
    num_entries = idx.shape[0]
    batch_size = 1

    # pmap needs a leading axis of length num_devices and scan needs equally
    # sized steps, so the pair list must be a multiple of
    # num_devices * batch_size. The triangle size n * (n + 1) / 2 rarely is,
    # so pad the list with (0, 0) pairs instead of failing in the reshape;
    # the padded results are sliced off again when the matrix is filled.
    entries_per_step = num_devices * batch_size
    num_padding = (-num_entries) % entries_per_step
    if num_padding:
        padding = jnp.zeros((num_padding, 2), dtype=idx.dtype)
        idx = jnp.concatenate([idx, padding], axis=0)
    num_iter = idx.shape[0] // entries_per_step
    entries_per_device = num_iter * batch_size

    # Compute entries
    # Layout: (device, scan step, batch element, (row, col)).
    idx = idx.reshape(num_devices, num_iter, batch_size, 2)

    def kme_fn(idx):
        # Mean kernel value between the two examples named by one index pair.
        return kme(X[idx[0], :, :], X[idx[1], :, :])

    def body_fn(carry, idx):
        # One scan step: evaluate a whole batch of pairs. The carry is unused
        # and only threaded through to satisfy lax.scan's signature.
        out = vmap(kme_fn)(idx)
        return out, out

    def scan_fn(idx):
        # Per-device work: scan over that device's steps and keep the stacked
        # per-step outputs.
        return lax.scan(body_fn, jnp.zeros(batch_size), idx)[1]

    entries = pmap(scan_fn, axis_name='i')(idx)

    # Construct covariance matrix
    # Flattening restores the original pair order; drop the padded tail.
    cov = jnp.zeros((num_examples, num_examples))
    cov = cov.at[rows, cols].set(entries.reshape(-1)[:num_entries])
    # Mirror the upper triangle; the diagonal would otherwise be counted twice.
    cov = cov + cov.T - jnp.diag(jnp.diag(cov))

    # Centered covariance matrix
    # Double centering: K - column means - row means + grand mean. Broadcasting
    # avoids materialising two extra (num_examples, num_examples) arrays.
    matrix_mean = jnp.mean(cov)
    row_mean = -jnp.mean(cov, axis=0)
    cov = cov + matrix_mean + row_mean[None, :] + row_mean[:, None]

    # Make sure the output directory exists so the expensive result is not
    # lost to a missing-directory error at save time.
    os.makedirs('data', exist_ok=True)
    jnp.save('data/shapes_cov.npy', cov)
    print("Done")
else:
    cov = jnp.load('data/shapes_cov.npy')
# Eigendecomposition of the (symmetric) covariance matrix, reordered so that
# eigenvalues appear by decreasing absolute value; eigenvector columns are
# permuted to match.
evals, evecs = eigh(cov)
idx = jnp.argsort(jnp.abs(evals))[::-1]
evals, evecs = evals[idx], evecs[:, idx]

# Keep the leading num_sensors components and report the share of the total
# eigenvalue sum that they account for.
num_sensors = 1024
evals_t, evecs_t = evals[:num_sensors], evecs[:, :num_sensors]
ratio = evals_t.sum() / evals.sum()
print("Variance Explained by first {} components: {}".format(num_sensors,ratio))


