import jax.numpy as jnp
import numpy as np
import jax
from functools import partial
from torch.utils import data
import os
import torch
from jax.scipy.linalg import eigh
from jax import vmap, pmap, jit, lax, devices




# Expose only GPU 0 to CUDA for this process.
# NOTE(review): this runs after `jax` and `torch` have been imported; it
# presumably still takes effect because both initialise CUDA lazily -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Number of devices JAX reports at import time. The functions below recompute
# this locally, so this module-level value may only matter to external importers.
num_devices = len(devices())


def compute_u_from_index(index, X, x_train, row_mean, matrix_mean, evecs_t, num_test_samples=None):
    """Project one example of ``X`` onto the retained eigenvectors.

    The kernel mean embedding inner product (``kme``) between the example
    ``X[index]`` and every training example is computed, centred with the
    training statistics, and multiplied by ``evecs_t.T``.

    Args:
        index: Position of the example inside ``X``.
        X: Array indexable as ``X[index]``; each entry must reshape to
            ``(num_samples, dim)``.
        x_train: Training examples of shape ``(num_examples, num_samples, dim)``.
        row_mean: Per-training-example centring term with ``num_examples``
            entries. The callers in this file pass the *negated* column means
            of the uncentred kernel matrix, so it is added, not subtracted.
        matrix_mean: Mean of the uncentred training kernel matrix.
        evecs_t: Retained eigenvectors, shape ``(num_examples, k)``.
        num_test_samples: How many leading samples of the example to use.
            ``None`` (default) uses all of them; values larger than
            ``num_samples`` also use all of them.

    Returns:
        1-D array of length ``k`` with the projected features.
    """
    num_examples, num_samples, dim = x_train.shape
    if num_test_samples is None:
        num_test_samples = num_samples
    # Shape (1, <=num_test_samples, dim): a single "distribution" of samples.
    samples = jnp.array(X[index]).reshape(1, num_samples, dim)[:, :num_test_samples, :]

    # kme between every training example (outer axis) and the test example.
    pairwise_kme = vmap(vmap(kme, (None, 0)), (0, None))

    # The training set is processed in (up to) five chunks, presumably to bound
    # the peak memory of the vmapped kernel evaluation. The boundaries are the
    # same as before when num_examples is a multiple of 5, but any size works.
    num_chunks = 5
    bounds = [(c * num_examples) // num_chunks for c in range(num_chunks + 1)]
    chunks = [
        pairwise_kme(x_train[lo:hi], samples)
        for lo, hi in zip(bounds[:-1], bounds[1:])
        if hi > lo
    ]
    samples_inner_product = jnp.concatenate(chunks, axis=0)  # (num_examples, 1)

    # Centre the new kernel column with the training statistics:
    # k(x, x_i) - mean_j k(x, x_j) + row_mean[i] + matrix_mean.
    test_total = (
        -jnp.mean(samples_inner_product)
        + row_mean.reshape(num_examples, 1)
        + matrix_mean
        + samples_inner_product
    )
    u = evecs_t.T @ test_total  # (k, 1)
    return u.T.reshape(-1)







def k_fn(x, y, sigma=15.0):
    """Gaussian (RBF) kernel value between two points.

    Args:
        x: First point.
        y: Second point, broadcast-compatible with ``x``.
        sigma: Bandwidth; the difference is divided by it before squaring.

    Returns:
        Scalar ``exp(-0.5 * sum(((x - y) / sigma) ** 2))``.
    """
    scaled_gap = (x - y) / sigma
    squared_radius = jnp.sum(scaled_gap * scaled_gap)
    return jnp.exp(-0.5 * squared_radius)

# k_fn of one fixed point against every row of the second argument.
_k_one_vs_rows = vmap(k_fn, in_axes=(None, 0))

# Pairwise kernel matrix: entry (i, j) is k_fn(x[i], y[j]).
kernel = vmap(_k_one_vs_rows, in_axes=(0, None))


@jit
def kme(x, y):
    """Mean of the pairwise kernel matrix between the rows of ``x`` and ``y``.

    This is the average of ``k_fn(x[i], y[j])`` over all pairs ``(i, j)``.
    """
    gram = kernel(x, y)
    return gram.mean()



def extract_digits():
    """Return the 70 indices in ``0..99`` that form the training split.

    An index is selected according to its tens and ones digits:

    * tens digit below 5: kept only when the ones digit is also below 5;
    * tens digit 5 to 8: always kept;
    * tens digit 9: kept only when the ones digit is at least 5.

    Returns:
        1-D integer ``jnp`` array of the selected indices in increasing order.
    """
    def is_selected(k):
        tens, ones = k // 10 % 10, k % 10
        if tens < 5:
            return ones < 5
        if tens >= 9:
            return ones >= 5
        return True  # tens digit in 5..8

    return jnp.array([k for k in range(100) if is_selected(k)])

# Torch device handle: CUDA when torch can see a GPU, otherwise CPU.
# Not referenced by the functions visible in this part of the file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compute_u_train_test_ref(X):
    """Reference features: the per-example mean over the sample axis.

    Args:
        X: Torch tensor indexed by example along axis 0 (indices 0..99 are
            used), with samples on axis 1 and feature dimension on axis 2.

    Returns:
        ``(u_train, u_test)`` torch tensors of shape ``(70, dim)`` and
        ``(30, dim)``: the sample means of the examples selected by
        ``extract_digits()`` and of the remaining ones, respectively.
    """
    samples = jnp.array(X.detach().cpu())
    train_ids = extract_digits()
    test_ids = np.setdiff1d(np.arange(100), train_ids)
    n_samples, n_dim = samples.shape[1], samples.shape[2]

    def sample_mean(ids, count):
        # Average over the sample axis, then hand the result over to torch.
        block = samples[ids].reshape(count, n_samples, n_dim)
        return torch.from_numpy(np.array(jnp.mean(block, axis=1)))

    return sample_mean(train_ids, 70), sample_mean(test_ids, 30)
    
def sort_eigen(x):
    """Fix the sign of each eigenvector stored as a column of ``x``.

    Eigenvectors are only determined up to a factor of -1. To obtain a
    standardised basis, every column is flipped where necessary so that its
    first entry is non-negative. Despite the name, no reordering is done.

    Args:
        x: 2-D array whose columns are eigenvectors.

    Returns:
        Array of the same shape with each column multiplied by +1 or -1.
    """
    sign = jnp.sign(x[0])
    # jnp.sign(0) is 0, which would wipe out the whole column; a column whose
    # first entry is exactly zero is therefore left unchanged.
    sign = jnp.where(sign == 0, jnp.ones_like(sign), sign)
    return x * sign.reshape(1, x.shape[1])

def compute_u_train_test(X, num_sensors=10):
    """Compute eigen-feature vectors for the training and held-out examples.

    The kernel matrix of kernel mean embeddings (``kme``) between all training
    examples is built, double-centred and eigendecomposed. Training features
    are the leading eigenvectors scaled by their eigenvalues; held-out
    features come from ``compute_u_from_index``.

    Args:
        X: Torch tensor indexed by example along axis 0 (indices 0..99 are
            used), with samples on axis 1 and feature dimension on axis 2.
        num_sensors: Number of leading eigen-directions to keep.

    Returns:
        ``(u_train, u_test)`` torch tensors with one row per training example
        (those selected by ``extract_digits()``) and per remaining example,
        and up to ``num_sensors`` columns.

    Side effects:
        Writes the centred kernel matrix to ``data/mnist_cov.npy`` (the
        directory is created if missing) and prints progress information.
    """
    X = jnp.array(X.detach().cpu())
    training_idx = extract_digits()
    num_examples = int(training_idx.shape[0])
    num_samples, dim = X.shape[1], X.shape[2]
    x_train = X[training_idx].reshape(num_examples, num_samples, dim)

    # The kernel matrix is symmetric, so only entries with i <= j are computed.
    rows, cols = jnp.triu_indices(num_examples)
    pair_idx = jnp.stack([rows, cols]).T
    num_entries = pair_idx.shape[0]

    # Lay the pairs out as (device, scan step, batch, 2). The list is padded
    # with dummy (0, 0) pairs so that any device count divides it evenly; the
    # surplus results are dropped again below.
    num_devices = len(devices())
    batch_size = 5
    pad_len = (-num_entries) % (num_devices * batch_size)
    if pad_len:
        padding = jnp.zeros((pad_len, 2), dtype=pair_idx.dtype)
        pair_idx = jnp.concatenate([pair_idx, padding], axis=0)
    pair_idx = pair_idx.reshape(num_devices, -1, batch_size, 2)

    def pair_kme(pair):
        return kme(x_train[pair[0], :, :], x_train[pair[1], :, :])

    def body_fn(carry, batch):
        # Nothing is carried between steps; scan only serialises the batches.
        return carry, vmap(pair_kme)(batch)

    def scan_fn(device_pairs):
        return lax.scan(body_fn, None, device_pairs)[1]

    entries = pmap(scan_fn, axis_name='i')(pair_idx)
    entries = entries.reshape(-1)[:num_entries]

    # Fill the upper triangle and mirror it (the diagonal is counted once).
    cov = jnp.zeros((num_examples, num_examples))
    cov = cov.at[rows, cols].set(entries)
    cov = cov + cov.T - jnp.diag(jnp.diag(cov))

    # Double-centre the kernel matrix. Note that row_mean holds the NEGATED
    # column means, so every centring term below is added.
    matrix_mean = jnp.mean(cov)
    row_mean = -jnp.mean(cov, axis=0)
    row_mean_tile = jnp.tile(row_mean, (num_examples, 1))
    cov = cov + matrix_mean + row_mean_tile + row_mean_tile.T
    # Make sure the target directory exists so the save cannot fail after the
    # expensive computation above.
    os.makedirs('data', exist_ok=True)
    jnp.save('data/mnist_cov.npy',cov)
    print("Done")

    # Eigendecomposition, sign-normalised and ordered by decreasing |eigenvalue|.
    evals, evecs = eigh(cov)
    evecs = sort_eigen(evecs)
    order = jnp.abs(evals).argsort()[::-1]
    evals = evals[order]
    evecs = evecs[:, order]

    # Keep the num_sensors leading eigenpairs.
    evals_t = evals[:num_sensors]
    evecs_t = evecs[:, :num_sensors]
    ratio = jnp.sum(evals_t)/jnp.sum(evals)
    print("Variance Explained by first {} components: {}".format(num_sensors,ratio))

    # Training features: eigenvectors scaled column-wise by their eigenvalues.
    u_train = evecs_t * evals_t[None, :]

    # Held-out features: project every example not in the training split.
    remaining_indices = np.setdiff1d(np.arange(100), training_idx)
    u_test = jnp.stack([
        compute_u_from_index(i, X, x_train, row_mean, matrix_mean, evecs_t)
        for i in remaining_indices
    ])

    u_train = torch.from_numpy(np.array(u_train))
    u_test = torch.from_numpy(np.array(u_test))
    return u_train, u_test

def compute_u_test(X, num_sensors=10, num_test_samples = 20):
    """Compute eigen-feature vectors for the held-out examples only.

    The kernel matrix of kernel mean embeddings (``kme``) between all training
    examples is built, double-centred and eigendecomposed. Every held-out
    example is then projected with ``compute_u_from_index`` using only its
    first ``num_test_samples`` samples.

    Args:
        X: Torch tensor indexed by example along axis 0 (indices 0..99 are
            used), with samples on axis 1 and feature dimension on axis 2.
        num_sensors: Number of leading eigen-directions to keep.
        num_test_samples: Number of leading samples of each held-out example
            used when evaluating its kernel mean embedding.

    Returns:
        Torch tensor with one row per example not selected by
        ``extract_digits()`` and up to ``num_sensors`` columns.

    Side effects:
        Prints progress information.
    """
    X = jnp.array(X.detach().cpu())
    training_idx = extract_digits()
    num_examples = int(training_idx.shape[0])
    num_samples, dim = X.shape[1], X.shape[2]
    x_train = X[training_idx].reshape(num_examples, num_samples, dim)

    # The kernel matrix is symmetric, so only entries with i <= j are computed.
    rows, cols = jnp.triu_indices(num_examples)
    pair_idx = jnp.stack([rows, cols]).T
    num_entries = pair_idx.shape[0]

    # Lay the pairs out as (device, scan step, batch, 2). The list is padded
    # with dummy (0, 0) pairs so that any device count divides it evenly; the
    # surplus results are dropped again below.
    num_devices = len(devices())
    batch_size = 5
    pad_len = (-num_entries) % (num_devices * batch_size)
    if pad_len:
        padding = jnp.zeros((pad_len, 2), dtype=pair_idx.dtype)
        pair_idx = jnp.concatenate([pair_idx, padding], axis=0)
    pair_idx = pair_idx.reshape(num_devices, -1, batch_size, 2)

    def pair_kme(pair):
        return kme(x_train[pair[0], :, :], x_train[pair[1], :, :])

    def body_fn(carry, batch):
        # Nothing is carried between steps; scan only serialises the batches.
        return carry, vmap(pair_kme)(batch)

    def scan_fn(device_pairs):
        return lax.scan(body_fn, None, device_pairs)[1]

    entries = pmap(scan_fn, axis_name='i')(pair_idx)
    entries = entries.reshape(-1)[:num_entries]

    # Fill the upper triangle and mirror it (the diagonal is counted once).
    cov = jnp.zeros((num_examples, num_examples))
    cov = cov.at[rows, cols].set(entries)
    cov = cov + cov.T - jnp.diag(jnp.diag(cov))

    # Double-centre the kernel matrix. Note that row_mean holds the NEGATED
    # column means, so every centring term below is added.
    matrix_mean = jnp.mean(cov)
    row_mean = -jnp.mean(cov, axis=0)
    row_mean_tile = jnp.tile(row_mean, (num_examples, 1))
    cov = cov + matrix_mean + row_mean_tile + row_mean_tile.T
    print("Done")

    # Eigendecomposition, sign-normalised and ordered by decreasing |eigenvalue|.
    evals, evecs = eigh(cov)
    evecs = sort_eigen(evecs)
    order = jnp.abs(evals).argsort()[::-1]
    evals = evals[order]
    evecs = evecs[:, order]

    # Keep the num_sensors leading eigenpairs.
    evals_t = evals[:num_sensors]
    evecs_t = evecs[:, :num_sensors]
    ratio = jnp.sum(evals_t)/jnp.sum(evals)
    print("Variance Explained by first {} components: {}".format(num_sensors,ratio))

    # Held-out features: project every example not in the training split.
    remaining_indices = np.setdiff1d(np.arange(100), training_idx)
    u_test = jnp.stack([
        compute_u_from_index(
            i, X, x_train, row_mean, matrix_mean, evecs_t,
            num_test_samples=num_test_samples,
        )
        for i in remaining_indices
    ])
    return torch.from_numpy(np.array(u_test))