# -*- coding: utf-8 -*-


import numpy as np
from statistics import *
from scipy.stats import sigmaclip
import matplotlib.pyplot as plt
# Noise scales for the synthetic experiment: sigma_tau is the standard
# deviation of the label noise tau added to Y (0.0 -> noiseless labels);
# sigma_epsilon is not referenced by the code in this file section.
sigma_tau = sigma_epsilon = 0.0

def calculate_renormalization(X, x, Y, P, N_0, width_range):
  """Compare finite-width and large-width linear-kernel predictors per width.

  For every width N in ``width_range`` a new inner list is appended to each of
  the module-level result lists ``u_0_lists``, ``u_large_lists``,
  ``u_large_true_lists``, ``f_lists``, ``f_large_lists``, ``df_lists`` and
  ``df_large_lists``; ``nb_rep`` (module-level) repetitions are then recorded
  into those inner lists. Each repetition draws first-layer weights
  W ~ N(0, I / N_0) for a width-``large_N`` layer (module-level) and for a
  width-N layer, uses the identity activation, and stores:

  * u_0_lists / u_large_lists: the positive root of
    u**2 - (1 - alpha) * u - alpha * r = 0 with alpha = P / N and
    r = Y^T K^+ Y / P, using K_0 = X X^T resp. the large-width kernel;
  * f_lists / f_large_lists: the squared predictor (K_xX K^+ Y)**2 at ``x``;
  * df_lists / df_large_lists: K_xx - K_xX K^+ K_Xx at ``x``;
  * u_large_true_lists: the ratio df (width N) / df (large width).

  NOTE(review): the weights here have covariance I / N_0, so the sampled
  kernels concentrate around X X^T / N_0, whereas K_0 = X X^T is unscaled;
  confirm which normalization is intended before comparing u_0 with u_large.

  Args:
    X: training inputs, shape (P, N_0).
    x: test input, shape (N_0,).
    Y: training targets, shape (P,).
    P: number of training points.
    N_0: input dimension.
    width_range: iterable of hidden-layer widths N.

  Returns:
    True (results are written to the module-level lists).
  """
  def u_plus(alpha, r):
    # Positive root of u**2 - (1 - alpha) * u - alpha * r = 0.
    return (1 - alpha + np.sqrt((alpha - 1)**2 + 4*alpha*r))/2

  # Quantities that depend only on the data are identical for every width and
  # repetition, so compute them (and the expensive pseudo-inverse) once.
  K_0 = np.dot(X, X.T)
  r_0 = 1/P*np.dot(Y.T, np.dot(np.linalg.pinv(K_0), Y))
  weight_mean = np.zeros(N_0)
  weight_cov = 1/N_0*np.eye(N_0)

  for N in width_range:
    for results in (u_0_lists, u_large_lists, u_large_true_lists, f_lists,
                    f_large_lists, df_lists, df_large_lists):
      results.append([])
    alpha = P/N
    u_0_plus = u_plus(alpha, r_0)
    for _ in range(nb_rep):
      # Large-width reference layer (identity activation).
      W_large = np.random.multivariate_normal(weight_mean, weight_cov, size=large_N)
      phi_H_large = np.dot(W_large, X.T)
      phi_H_x_large = np.dot(W_large, x.T)
      K_large = (1/large_N)*np.dot(phi_H_large.T, phi_H_large)
      K_x_x_large = (1/large_N)*np.dot(phi_H_x_large.T, phi_H_x_large)
      K_x_X_large = (1/large_N)*np.dot(phi_H_x_large.T, phi_H_large)
      K_large_pinv = np.linalg.pinv(K_large)
      r_large = 1/P*np.dot(Y.T, np.dot(K_large_pinv, Y))
      u_0_lists[-1].append(u_0_plus)
      u_large_lists[-1].append(u_plus(alpha, r_large))

      # Width-N layer (identity activation).
      W = np.random.multivariate_normal(weight_mean, weight_cov, size=N)
      phi_H = np.dot(W, X.T)
      phi_H_x = np.dot(W, x.T)
      K = (1/N)*np.dot(phi_H.T, phi_H)
      K_x_X = (1/N)*np.dot(phi_H_x.T, phi_H)
      K_x_x = (1/N)*np.dot(phi_H_x.T, phi_H_x)
      K_pinv = np.linalg.pinv(K)

      f_mp_2 = np.dot(K_x_X, np.dot(K_pinv, Y.T))**2
      f_large_2 = np.dot(K_x_X_large, np.dot(K_large_pinv, Y.T))**2
      df_mp = K_x_x - np.dot(K_x_X, np.dot(K_pinv, K_x_X.T))
      df_large = K_x_x_large - np.dot(K_x_X_large, np.dot(K_large_pinv, K_x_X_large.T))
      f_lists[-1].append(f_mp_2)
      df_lists[-1].append(df_mp)
      f_large_lists[-1].append(f_large_2)
      df_large_lists[-1].append(df_large)
      u_large_true_lists[-1].append(df_mp/df_large)
  return True

def calculate_marchenko_pastur_map(eig_values, n , N, d, max_iter=10000):
  """Numerically solve a Marchenko-Pastur fixed-point equation for a density.

  The input spectrum ``eig_values`` is turned into an empirical density rho on
  the grid ``real_part = linspace(1e-4, 20, 2000)``: a 5000-bin histogram,
  restricted to non-empty bins whose edge is <= 20, then linearly
  interpolated. Starting from a random real guess (drawn from NumPy's global
  RNG), it iterates

      m(z) <- sum_t rho(t) * dt / (t * (1 - gamma - gamma * z * m(z)) - z)

  with gamma = n / N at z = real_part + 1e-4 i, until successive iterates
  differ by at most 1e-4 everywhere, and returns |Im m| / pi on the grid.

  Args:
    eig_values: 1-D collection of eigenvalues defining the input spectrum.
    n: numerator of gamma = n / N.
    N: denominator of gamma = n / N.
    d: unused; kept for interface compatibility.
    max_iter: maximum number of fixed-point updates. If it is reached before
      convergence a RuntimeWarning is issued and the last iterate is used.

  Returns:
    (real_part, marchenko_pastur_map): two arrays of shape (2000,).
  """
  import warnings

  nb_bins = 5000
  left, right = 0.0001, 20
  nb_dis = 2000
  epsilon = 0.0001  # convergence tolerance (max-norm between iterates)

  # Empirical density. Bin heights are prepended with a placeholder so they
  # line up with the nb_bins + 1 edges (each height sits at its bin's right
  # edge); the placeholder survives the non-zero mask and is then reset to 0
  # so the density starts from zero at the leftmost kept edge.
  heights, edges = np.histogram(eig_values, bins=nb_bins, density=True)
  heights = np.concatenate((np.ones(1), heights))
  keep = (heights != 0.) & (edges <= right)
  heights = heights[keep]
  heights[0] = 0.
  real_part = np.linspace(left, right, nb_dis)
  rho_mp = np.interp(real_part, edges[keep], heights)

  z_grid = real_part + 0.0001j
  gamma = n/N
  # nb_dis linspace points span nb_dis - 1 intervals.
  step = (right - left)/(nb_dis - 1)

  def update(m):
    # One fixed-point step, vectorised over (grid eigenvalue t, point z):
    # rows index t, columns index z; the sum over t is a Riemann sum.
    denom = np.outer(real_part, 1 - gamma - gamma*z_grid*m)
    denom -= z_grid
    return step*(rho_mp[:, None]/denom).sum(axis=0)

  m_prev = np.random.normal(loc=0, scale=1.0, size=z_grid.shape)
  m_next = update(m_prev)
  n_iter = 1
  while np.max(np.abs(m_next - m_prev)) > epsilon:
    if n_iter >= max_iter:
      warnings.warn(
          "Marchenko-Pastur fixed-point iteration did not converge after "
          f"{n_iter} updates", RuntimeWarning)
      break
    m_prev, m_next = m_next, update(m_next)
    n_iter += 1
  print(n_iter)
  marchenko_pastur_map = (1/np.pi)*np.abs(m_next.imag)
  return real_part, marchenko_pastur_map

import sys
def calculate_predictor_estimates_N(X,x,Y,P, N_0, N, marchenko_pastur_map, real_part):
  """Estimate the test-point predictor from a sampled kernel spectrum.

  For each of ``nb_rep`` (module-level) repetitions:
    * draws M = 1000 eigenvalues from the density ``marchenko_pastur_map`` on
      the grid ``real_part`` by inverse-CDF sampling;
    * builds K = V diag(lambda) V^T with V of shape (P, M) and entries
      N(0, 1/M), plus k_x = V diag(lambda) v and k_xx = v^T diag(lambda) v
      from an independent vector v of shape (M, 1);
    * records the log-likelihood of ``Y`` under N(0, K), the mean
      k_x^T K^+ Y and the variance k_xx - k_x^T K^+ k_x.
  Each repetition's mean and variance are then reweighted by
  L_i / mean_j(L_j), where L_i is that repetition's likelihood.

  Args:
    X, x, N_0, N: unused; kept for interface compatibility.
    Y: training targets, length P.
    P: number of training points.
    marchenko_pastur_map: density values on ``real_part``.
    real_part: increasing grid of eigenvalue locations.

  Returns:
    (f_approx_array, df_approx_array) with shapes (nb_rep, 1) and
    (nb_rep, 1, 1).

  Raises:
    ValueError: if ``marchenko_pastur_map`` has no positive mass on the grid.
  """
  M = 1000
  Y_vec = np.asarray(Y, dtype=float)

  # cdf[k] is the (right-Riemann) mass of the density on
  # [real_part[0], real_part[k + 1]]. Normalising it guarantees every uniform
  # draw in [0, 1) maps onto a grid point instead of falling off the end.
  cdf = np.cumsum(marchenko_pastur_map[1:]*np.diff(real_part))
  total_mass = cdf[-1]
  if not total_mass > 0:
    raise ValueError("marchenko_pastur_map has no positive mass on real_part")
  cdf = cdf/total_mass

  log_likelihoods = []
  means = []
  variances = []
  for _ in range(nb_rep):
    uniform_samples = np.random.uniform(0, 1, M)
    # Index of the first cdf entry >= u, i.e. the inverse CDF on the grid.
    mp_samples = real_part[np.searchsorted(cdf, uniform_samples, side='left') + 1]
    V_K = np.random.multivariate_normal(np.zeros(M), 1/M*np.eye(M), size=P)
    v_K = np.random.multivariate_normal(np.zeros(M), 1/M*np.eye(M), size=1).T
    # V diag(lambda) without materialising the M x M diagonal matrix.
    V_scaled = V_K*mp_samples
    K_approx = np.dot(V_scaled, V_K.T)
    K_x_X_approx = np.dot(V_scaled, v_K)
    K_x_x_approx = np.dot(v_K.T*mp_samples, v_K)
    K_inv_approx = np.linalg.pinv(K_approx)
    K_inv_Y = np.dot(K_inv_approx, Y_vec)

    # Gaussian log-likelihood log N(Y; 0, K). Working in log space avoids the
    # under/overflow of (2 pi)^(-P/2) det(K)^(-1/2) exp(...) for large P.
    sign, logdet = np.linalg.slogdet(K_approx)
    if sign <= 0:
      log_lik = -np.inf
    else:
      log_lik = -0.5*(P*np.log(2*np.pi) + logdet + np.dot(Y_vec, K_inv_Y))
    log_likelihoods.append(log_lik)
    means.append(np.dot(K_x_X_approx.T, K_inv_Y))
    variances.append(K_x_x_approx - np.dot(K_x_X_approx.T, np.dot(K_inv_approx, K_x_X_approx)))

  # Likelihood ratios relative to the best repetition: the common factor
  # cancels in L_i / mean(L), so this equals the direct ratio without overflow.
  log_likelihoods = np.array(log_likelihoods)
  weights = np.exp(log_likelihoods - np.max(log_likelihoods))
  weights = weights/np.mean(weights)
  f_approx_array = np.array(means)*weights[:, None]
  df_approx_array = np.array(variances)*weights[:, None, None]
  return f_approx_array, df_approx_array

def calculate_predictor_estimates(X,x,Y,P, N_0, width_range, mnist = False):
  """Plot kernel spectra and run the spectrum-based predictor estimate per width.

  For each N in ``width_range``:
    * draws 10 independent weight matrices W ~ N(0, I) of shape (N, N_0) and
      (``large_N``, N_0) (module-level ``large_N``), forms the identity-
      activation kernels K = H^T H / N and K_large = H^T H / large_N with
      H = W X^T, and pools their eigenvalues;
    * passes the pooled K_large eigenvalues to
      ``calculate_marchenko_pastur_map_mnist`` (when ``mnist``) or
      ``calculate_marchenko_pastur_map`` to obtain a predicted density;
    * shows a histogram of the K eigenvalues against that density, then
      histograms of the eigenvalues of K_0 = X X^T and of K_large;
    * calls ``calculate_predictor_estimates_N`` and appends its two arrays to
      the module-level ``f_approx_lists`` / ``df_approx_lists``.

  Args:
    X: training inputs, shape (P, N_0).
    x, Y: test input and training targets, forwarded to
      ``calculate_predictor_estimates_N``.
    P: number of training points.
    N_0: input dimension.
    width_range: iterable of widths N.
    mnist: select the MNIST variant of the density solver.

  Returns:
    True (results are written to module-level lists and shown as plots).
  """
  n_spectrum_reps = 10
  # K_0 does not depend on the sampled weights, so diagonalise it only once.
  # Kernels here are symmetric, so eigvalsh returns real eigenvalues directly.
  eigenvalues_K_0 = list(np.linalg.eigvalsh(np.dot(X, X.T)))

  for N in width_range:
    eigenvalues_K = []
    eigenvalues_K_large = []
    for _ in range(n_spectrum_reps):
      W = np.random.multivariate_normal(np.zeros(N_0), np.eye(N_0), size=N)
      W_large = np.random.multivariate_normal(np.zeros(N_0), np.eye(N_0), size=large_N)
      # Identity activation: phi(H) = H.
      phi_H = np.dot(W, X.T)
      phi_H_large = np.dot(W_large, X.T)
      K = (1/N)*np.dot(phi_H.T, phi_H)
      K_large = (1/large_N)*np.dot(phi_H_large.T, phi_H_large)
      eigenvalues_K.extend(np.linalg.eigvalsh(K))
      eigenvalues_K_large.extend(np.linalg.eigvalsh(K_large))

    if mnist:
      real_part, marchenko_pastur_map = calculate_marchenko_pastur_map_mnist(eigenvalues_K_large, P, N, N_0)
    else:
      real_part, marchenko_pastur_map = calculate_marchenko_pastur_map(eigenvalues_K_large, P, N, N_0)

    print(mean(eigenvalues_K))
    plt.figure()
    plt.hist(eigenvalues_K, density=True, bins=1000, label='Simulated')
    plt.xlim((-0.1, 1.5))
    plt.plot(real_part, marchenko_pastur_map, label='Theoretical')
    plt.xlabel('Eigenvalues of output kernel random matrix')
    plt.ylabel('Probability density')
    plt.legend()
    plt.show()

    plt.figure()
    plt.hist(eigenvalues_K_0, density=True, bins=100, label='Simulated')
    left, right = plt.xlim()
    plt.legend()
    # Show the K_large spectrum on the same x-range as the K_0 spectrum.
    plt.figure()
    plt.xlim((left, right))
    plt.hist(eigenvalues_K_large, density=True, bins=100, label='Simulated')
    plt.xlabel('Eigenvalues of output kernel random matrix')
    plt.ylabel('Probability density')
    plt.legend()
    plt.show()

    f_approx_array, df_approx_array = calculate_predictor_estimates_N(X, x, Y, P, N_0, N, marchenko_pastur_map, real_part)
    f_approx_lists.append(f_approx_array)
    df_approx_lists.append(df_approx_array)
  return True

# Install pyro on first use. The IPython shell escape `!pip3 install pyro-ppl`
# is not valid Python syntax, so install through pip from Python instead, and
# only when pyro is missing.
try:
    import pyro
except ImportError:
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pyro-ppl"])
    import pyro
# Explicit submodule imports so the pyro.nn / pyro.infer.autoguide attribute
# accesses used in this file resolve.
import pyro.nn
import pyro.infer.autoguide
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# torch is needed for PyroLinear below; the file's other torch imports come later.
import torch

pyro.set_rng_seed(101)

# torch.nn.Linear whose parameters can be replaced by Pyro sample sites.
PyroLinear = pyro.nn.PyroModule[torch.nn.Linear]

class BNN(pyro.nn.PyroModule):
    """Bayesian network with one ReLU hidden layer and a scalar output.

    Priors: independent N(0, 1) on every entry of ``fc1.weight``, ``fc1.bias``
    and ``fc2.weight`` (no prior is placed on ``fc2``'s bias here). The
    observation noise scale ``sigma`` has a Uniform(0, 0.1) prior and the
    observations are Normal(network output, sigma).

    Args:
        input_size: number of input features.
        hidden_size: number of hidden units.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.fc1 = PyroLinear(input_size, hidden_size)
        self.fc1.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([hidden_size, input_size]).to_event(2))
        self.fc1.bias   = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([hidden_size]).to_event(1))

        self.fc2 = PyroLinear(hidden_size, 1)
        self.fc2.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([1, hidden_size]).to_event(2))

        self.relu = torch.nn.ReLU()

    def forward(self, x, y=None):
        """Run the generative model; condition the "obs" site on ``y`` if given.

        Returns the network output before observation noise (last dim 1).
        """
        hidden = self.relu(self.fc1(x))
        sigma = pyro.sample("sigma", dist.Uniform(0., 0.1))
        mu = self.fc2(hidden)
        # NOTE(review): the plate size is the size of the batch passed in, so
        # when training on mini-batches the likelihood is not rescaled to the
        # full dataset size relative to the priors; confirm this is intended.
        with pyro.plate("data", hidden.shape[0]):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        return mu

def train_bnn(X,x,Y,P, N_0, width_range):
  """Fit a BNN by SVI for each width and record predictive samples at ``x``.

  For every N in ``width_range``:
    * builds ``BNN(N_0, N)`` with an ``AutoDiagonalNormal`` guide;
    * clears the Pyro param store;
    * runs one epoch of SVI (Adam, lr 1e-2, Trace_ELBO) over (X, Y) in shuffled
      mini-batches of size 1;
    * draws 50 posterior-predictive samples of the "obs" site at the single
      test input ``x``, appending them to the module-level ``pred_bnn_lists``.

  Args:
    X: training inputs, shape (P, N_0).
    x: test input with N_0 entries.
    Y: training targets.
    P: unused; kept for interface compatibility.
    N_0: input dimension.
    width_range: iterable of hidden widths N.

  Returns:
    True.
  """
  # Imported locally: in top-to-bottom execution of this file, train_bnn is
  # called before the module-level torch imports further down.
  import torch
  from torch.utils.data import DataLoader, TensorDataset

  num_epochs = 1
  # The data do not depend on the width, so build the tensors once.
  train_dataset = TensorDataset(torch.as_tensor(X, dtype=torch.float32),
                                torch.as_tensor(Y, dtype=torch.float32))
  x_test = torch.as_tensor(x, dtype=torch.float32).reshape(1, N_0)

  for N in width_range:
    bnn = BNN(N_0, N)
    guide = pyro.infer.autoguide.AutoDiagonalNormal(bnn)
    pyro.clear_param_store()
    svi = SVI(model=bnn,
              guide=guide,
              optim=Adam({"lr": 1e-2}),
              loss=Trace_ELBO(num_particles=1))

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    for _ in range(num_epochs):
      for batch_x, batch_y in train_loader:
        svi.step(batch_x, batch_y)

    predictive = pyro.infer.Predictive(model=bnn, guide=guide, num_samples=50)
    predictions = predictive.forward(x_test)['obs'][:, 0, 0]
    pred_bnn_lists.append(predictions)
  return True

#Renormalization factor IID Gaussian
# torch (and TensorDataset, used by train_bnn) are needed by the calls below,
# but the file's own torch imports only come later.
import torch
from torch.utils.data import TensorDataset


def _two_sem(samples):
  """Return twice the standard error of the mean over all elements of *samples*."""
  samples = np.asarray(samples, dtype=float)
  return 2*np.std(samples, ddof=1)/np.sqrt(samples.size)


# Module-level result containers filled by the experiment functions
# (one inner entry per width).
u_0_lists = []
u_large_lists = []
u_large_true_lists = []
f_lists = []
f_large_lists = []
df_lists= []
df_large_lists = []
f_approx_lists = []
df_approx_lists = []
u_approx_lists = []
pred_bnn_lists = []
nb_rep=10
max_N = 1000
large_N = 10*max_N
width_range = range(20,max_N, 50)
P = 200
N_0 = 500
alpha_0 = P/N_0
mean_x = np.zeros(N_0)
cov_x = 1/N_0*np.eye(N_0)
beta = 10*np.ones(N_0)
# P training points plus one held-out test point (the last row).
X_raw = np.random.multivariate_normal(mean_x, cov_x, size = P+1)
tau = np.random.normal(loc = 0, scale = sigma_tau, size = P+1)
Y = np.dot(X_raw[0:P,:],beta) + tau[0:P]
X = X_raw[0:P,:]
x = X_raw[P,:]
y = np.dot(x, beta) + tau[P]

calculate_renormalization(X,x,Y,P,N_0,width_range)
calculate_predictor_estimates(X,x,Y,P, N_0, width_range, mnist = False)
u_approx_list = [np.mean(df_approx)/mean(df_large)
                 for df_approx, df_large in zip(df_approx_lists, df_large_lists)]
train_bnn(X,x,Y, P, N_0, width_range)

widths = list(width_range)

# Error bars are two standard errors of the mean (2 * std / sqrt(n)).
plt.figure()
plt.plot(widths, u_approx_list)
plt.errorbar(x=widths, y=[mean(u_large_list) for u_large_list in u_large_lists],
             yerr=[_two_sem(u_large_list) for u_large_list in u_large_lists],
             marker='o', linestyle='')

plt.figure()
plt.plot(widths, [float(torch.mean(pred_bnn)) for pred_bnn in pred_bnn_lists])
plt.errorbar(x=widths, y=[np.mean(f_approx_list) for f_approx_list in f_approx_lists],
             yerr=[_two_sem(f_approx_list) for f_approx_list in f_approx_lists],
             marker='o', linestyle='')

plt.figure()
plt.plot(widths, [float(torch.var(pred_bnn)) for pred_bnn in pred_bnn_lists])
plt.errorbar(x=widths, y=[np.mean(df_approx_list) for df_approx_list in df_approx_lists],
             yerr=[_two_sem(df_approx_list) for df_approx_list in df_approx_lists],
             marker='o', linestyle='')

import os
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from statistics import *

# Keep roughly 1 / div_factor of the selected MNIST images.
div_factor = 120
class SubLoader(datasets.MNIST):
    """MNIST restricted to digits 0 and 1, randomly subsampled.

    All arguments are forwarded to ``datasets.MNIST``. Afterwards only images
    labelled 0 or 1 are kept, and of those a random subset of
    ``count // div_factor`` images (module-level ``div_factor``) is retained,
    drawn without replacement from NumPy's global RNG.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        keep = (self.targets == 0) | (self.targets == 1)
        self.data = self.data[keep]
        self.targets = self.targets[keep]
        n_kept = self.targets.shape[0]
        # Without replacement: duplicated images would repeat rows of the
        # design matrix built from this dataset and make X X^T singular.
        chosen = torch.as_tensor(
            np.random.choice(n_kept, size=n_kept // div_factor, replace=False))
        self.data = self.data[chosen]
        self.targets = self.targets[chosen]

    def get_data(self):
        """Return the filtered and subsampled (data, targets) tensors."""
        return self.data, self.targets



# Train/test splits of the subsampled 0-vs-1 MNIST data (use datasets.MNIST
# instead of SubLoader for the full splits).
train_dataset, test_dataset = (
    SubLoader('./data', train=is_train, download=True,
              transform=transforms.Compose([transforms.ToTensor()]))
    for is_train in (True, False)
)

# Each loader yields its whole split as one shuffled batch.
train_loader, test_loader = (
    DataLoader(split, batch_size=split.targets.shape[0], shuffle=True)
    for split in (train_dataset, test_dataset)
)

import sys
import matplotlib.pyplot as plt
fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()
fig3, ax3 = plt.subplots()
# Show three images from the first training batch, one per figure.
batch = next(iter(train_loader))
for axis, image_index in zip((ax1, ax2, ax3), (13, 10, 49)):
  axis.imshow(batch[0][image_index, 0, :, :])

def calculate_marchenko_pastur_map_mnist(eig_values, n , N, d, max_iter=10000):
  """Numerically solve a Marchenko-Pastur fixed-point equation (MNIST variant).

  The input spectrum ``eig_values`` is turned into an empirical density rho on
  the grid ``real_part = linspace(1e-4, 20, 2000)``: a 5000-bin histogram,
  keeping every non-empty bin (no upper cutoff on the bin edges), then
  linearly interpolated onto the grid. Starting from a random real guess
  (drawn from NumPy's global RNG), it iterates

      m(z) <- sum_t rho(t) * dt / (t * (1 - gamma - gamma * z * m(z)) - z)

  with gamma = n / N at z = real_part + 1e-4 i, until successive iterates
  differ by at most 1e-4 everywhere, and returns |Im m| / pi on the grid.

  Args:
    eig_values: 1-D collection of eigenvalues defining the input spectrum.
    n: numerator of gamma = n / N.
    N: denominator of gamma = n / N.
    d: unused; kept for interface compatibility.
    max_iter: maximum number of fixed-point updates. If it is reached before
      convergence a RuntimeWarning is issued and the last iterate is used.

  Returns:
    (real_part, marchenko_pastur_map): two arrays of shape (2000,).
  """
  import warnings

  nb_bins = 5000
  left, right = 0.0001, 20
  nb_dis = 2000
  epsilon = 0.0001  # convergence tolerance (max-norm between iterates)

  # Empirical density. Bin heights are prepended with a placeholder so they
  # line up with the nb_bins + 1 edges (each height sits at its bin's right
  # edge); the placeholder survives the non-zero mask and is then reset to 0
  # so the density starts from zero at the leftmost kept edge.
  heights, edges = np.histogram(eig_values, bins=nb_bins, density=True)
  heights = np.concatenate((np.ones(1), heights))
  keep = heights != 0.
  heights = heights[keep]
  heights[0] = 0.
  real_part = np.linspace(left, right, nb_dis)
  rho_mp = np.interp(real_part, edges[keep], heights)

  z_grid = real_part + 0.0001j
  gamma = n/N
  # nb_dis linspace points span nb_dis - 1 intervals.
  step = (right - left)/(nb_dis - 1)

  def update(m):
    # One fixed-point step, vectorised over (grid eigenvalue t, point z):
    # rows index t, columns index z; the sum over t is a Riemann sum.
    denom = np.outer(real_part, 1 - gamma - gamma*z_grid*m)
    denom -= z_grid
    return step*(rho_mp[:, None]/denom).sum(axis=0)

  m_prev = np.random.normal(loc=0, scale=1.0, size=z_grid.shape)
  m_next = update(m_prev)
  n_iter = 1
  while np.max(np.abs(m_next - m_prev)) > epsilon:
    if n_iter >= max_iter:
      warnings.warn(
          "Marchenko-Pastur fixed-point iteration did not converge after "
          f"{n_iter} updates", RuntimeWarning)
      break
    m_prev, m_next = m_next, update(m_next)
    n_iter += 1
  print(n_iter)
  marchenko_pastur_map = (1/np.pi)*np.abs(m_next.imag)
  return real_part, marchenko_pastur_map

#Renormalization MNIST


def _two_sem(samples):
  """Return twice the standard error of the mean over all elements of *samples*."""
  samples = np.asarray(samples, dtype=float)
  return 2*np.std(samples, ddof=1)/np.sqrt(samples.size)


# Module-level result containers filled by the experiment functions
# (one inner entry per width).
u_0_lists = []
u_large_lists = []
u_large_true_lists = []
f_lists = []
f_large_lists = []
df_lists= []
df_large_lists = []
f_approx_lists = []
df_approx_lists = []
nb_rep=10
max_N = 1000
large_N = 10*max_N
# Full sweep: width_range = range(20, max_N, 50)
width_range = [500,600]

batch = next(iter(train_loader))
# Flatten the images to rows and convert to float64 NumPy arrays (rather than
# torch tensors) so the NumPy helpers build float64 kernels.
X_raw = batch[0].reshape(batch[0].shape[0], -1).numpy().astype(np.float64)
labels = batch[1].numpy().astype(np.float64)
# All but the last image are training points; the last is the test point.
P = X_raw.shape[0] - 1
X = X_raw[0:P,:]
x = X_raw[P,:]
Y = labels[0:P]
y = labels[P]
N_0 = X.shape[1]

calculate_renormalization(X,x,Y,P,N_0,width_range)
calculate_predictor_estimates(X,x,Y,P, N_0, width_range, mnist = True)
u_approx_list = [np.mean(df_approx)/mean(df_large)
                 for df_approx, df_large in zip(df_approx_lists, df_large_lists)]

# Error bars are two standard errors of the mean (2 * std / sqrt(n)).
plt.figure()
plt.plot(list(width_range), u_approx_list)
plt.errorbar(x=list(width_range), y=[mean(u_large_list) for u_large_list in u_large_lists],
             yerr=[_two_sem(u_large_list) for u_large_list in u_large_lists],
             marker='o', linestyle='')
