import numpy as np
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import matplotlib
matplotlib.rcParams['image.cmap'] = 'jet'
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def compute_graph_laplacian_matrix(w):
    """Return the combinatorial graph Laplacian ``L = D - W`` of a weight matrix.

    ``D`` is the diagonal matrix holding the row sums of ``w``.  The input is
    not modified; a new array is returned (integer inputs are promoted to
    float by the ``0.0 - w`` negation).
    """
    degrees = np.sum(w, axis=1)
    lap = 0.0 - w
    # Add each vertex degree onto its diagonal entry in one vectorized step.
    diag = np.arange(w.shape[0])
    lap[diag, diag] += degrees
    return lap

def generate_feature_values(zmat, zvec, l):
    """Synthesize vertex features from the low end of the Laplacian spectrum.

    ``np.linalg.eigh`` returns eigenvalues in ascending order, so the first
    ``dimh`` eigenvectors of ``l`` form the basis that ``zmat`` and ``zvec``
    combine.  The smallest eigenvalues are printed for inspection.

    Args:
        zmat: (dimh, dimy) coefficient matrix for the vector-valued features.
        zvec: dimh coefficients for the scalar feature (1-D or column vector).
        l: symmetric graph Laplacian; ``eigh`` reads only its lower triangle.

    Returns:
        ``(y, f)`` where ``y`` has shape (l.shape[0], dimy) and is scaled so
        its mean squared entry is ``dimy``, and ``f`` has shape
        (l.shape[0], 1) and is scaled so its mean squared entry is 1.
    """
    dimh, dimy = zmat.shape
    eigvals, eigvecs = np.linalg.eigh(l)
    print(eigvals[:6])
    basis = eigvecs[:, :dimh]

    def rescale(arr, target_mean_square):
        # Scale arr so that np.mean(arr * arr) equals target_mean_square.
        return arr / np.sqrt(np.mean(arr * arr) / target_mean_square)

    y = rescale(np.dot(basis, zmat), dimy)
    f = rescale(np.dot(basis, zvec).reshape(-1, 1), 1)
    return y, f

def generate_graph_weights_and_features(numv, dimy, dimx, dimh, gamma_ref=100.0):
    """Build a random Gaussian-kernel graph and smooth features on it.

    Vertices are ``numv`` points drawn uniformly from ``[0, 1)^dimx``.  Edge
    weights are ``exp(-gamma * |x_j - x_i|^2)`` with
    ``gamma = gamma_ref / numv ** (1 / dimy)``.  Features are drawn from the
    span of the ``dimh`` lowest eigenvectors of the unnormalized graph
    Laplacian (via ``compute_graph_laplacian_matrix`` and
    ``generate_feature_values``).

    Args:
        numv: number of vertices.
        dimy: dimension of the ambient (feature) space; width of ``y``.
        dimx: dimension of the cube the vertex positions are sampled from.
        dimh: number of Laplacian eigenvectors spanning the space of
            functions; must satisfy ``2 <= dimh <= numv``.
        gamma_ref: reference kernel sharpness before the ``numv`` rescaling.

    Returns:
        Tuple of float64 torch tensors ``(w, y, f)``: ``w`` is the
        (numv, numv) weight matrix with each row normalized to sum to 1,
        ``y`` holds the vector features and ``f`` the (numv, 1) scalar
        feature.

    Raises:
        ValueError: if ``dimh`` is outside ``[2, numv]``.
    """
    # zvec[1] is indexed below and at most numv eigenvectors exist.
    if not 2 <= dimh <= numv:
        raise ValueError(f'dimh must be in [2, numv={numv}], got {dimh}')

    x = np.random.rand(numv, dimx)  # example with a cube
    # NOTE(review): the kernel width is rescaled with dimy (feature dimension)
    # rather than dimx (the dimension the points live in) -- confirm intended.
    gamma = gamma_ref / (numv ** (1.0 / dimy))

    # Pairwise squared distances by broadcasting: diff[i, j] = x[j] - x[i].
    # Uses a (numv, numv, dimx) temporary instead of a Python double loop.
    diff = x[np.newaxis, :, :] - x[:, np.newaxis, :]
    w = np.exp(-gamma * np.sum(diff * diff, axis=-1))

    # The Laplacian is built from the unnormalized weights, so compute it
    # before row-normalizing (compute_graph_laplacian_matrix returns a copy).
    lmat = compute_graph_laplacian_matrix(w)
    # Each row contains w[i, i] == exp(0) == 1, so every row sum is >= 1.
    w = w / np.sum(w, axis=1, keepdims=True)

    zmat = np.random.normal(0.0, 1.0, (dimh, dimy))
    # The scalar feature uses only the second-lowest eigenvector. The random
    # draw is kept so the global RNG stream matches previous behavior.
    zvec = np.random.normal(0.0, 1.0, (dimh, 1))
    zvec[:, 0] = 0.0
    zvec[1, 0] = 1.0
    y, f = generate_feature_values(zmat, zvec, lmat)
    print('y.shape = ' + str(y.shape))
    print('np.mean(f * f) = ' + str(np.mean(f * f)))
    return (torch.from_numpy(w).double(),
            torch.from_numpy(y).double(),
            torch.from_numpy(f).double())


