#!/usr/bin/env python

import os, time
import sys
import argparse
import yaml
import numpy as np
import pandas as pd
import sqlite3
import torch
import torch.nn as nn
from torch.nn.utils import _stateless
import functools
from functorch import make_functional, vmap, jacrev
from model import FullyConnected
from scipy.interpolate import interp1d
from Critical_Parameters import get_all_initializations
import gc

#from IPython import embed

#from mm_plots import mm_plots

# Module-level state shared between sample() and the forward hooks below.
# NOTE(review): main() assigns a *local* `device`, so this global stays None
# unless set elsewhere; sample() passes it to `.to(...)`.
device=None
# Hook name -> detached output of the hooked pre-activation layer.
# Rebound to a fresh dict by sample() for every network instance.
pre_act_logs={}
# Hook names in registration order; sample() reads entries [-1] and [-2].
pre_act_names=[]
# post_act_logs={}
# post_act_names=[]
# Functional form of the current network, set by sample() via
# make_functional() and read by fnet_single().
fnet=None


def pre_activation_hook(name):
    """Build a forward hook that records a module's output in ``pre_act_logs``.

    Args:
        name: key under which the hooked module's output is stored.

    Returns:
        A callable ``hook(model, input, output)`` suitable for
        ``nn.Module.register_forward_hook``. Each call stores
        ``output.detach()`` in the module-level ``pre_act_logs`` dict.
    """
    def hook(model, input, output):
        # Item assignment uses whatever dict the global name is bound to at
        # call time, so sample() may rebind pre_act_logs between networks.
        pre_act_logs[name] = output.detach()
    return hook

def post_activation_hook(name):
    """Build a forward hook that records a module's output in ``post_act_logs``.

    Args:
        name: key under which the hooked module's output is stored.

    Returns:
        A callable ``hook(model, input, output)`` suitable for
        ``nn.Module.register_forward_hook``. Each call stores
        ``output.detach()`` in the module-level ``post_act_logs`` dict.
    """
    # The module-level `post_act_logs = {}` definition is commented out, so
    # make sure the global exists; otherwise a hook firing before sample()
    # has bound it would raise NameError.
    globals().setdefault("post_act_logs", {})

    def hook(model, input, output):
        post_act_logs[name] = output.detach()
    return hook
       
def ntk(module: nn.Module, input1: torch.Tensor, input2: torch.Tensor,parameters: dict[str, nn.Parameter] = None,
    compute='full') -> torch.Tensor:
    if compute=='full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute=='trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute=='diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        raise ValueError(compute)

    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    def func(*params: torch.Tensor, _input: torch.Tensor = None):
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output  # (N, C)

    jac1: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input1), values, vectorize=True)
    jac2: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input2), values, vectorize=True)
    jac1 = [j.flatten(2) for j in jac1]
    jac2 = [j.flatten(2) for j in jac2]
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]).sum(0)
    return result

 
def fnet_single(params, x):
    """Evaluate the module-level functional network ``fnet`` on one sample.

    A leading batch axis of size 1 is added before the call and removed
    afterwards, so ``x`` is a single unbatched input.
    """
    batched = x.unsqueeze(0)
    return fnet(params, batched).squeeze(0)

def empirical_ntk(fnet_single, params, x1, x2, compute='full'):
    """Empirical NTK J(x1) . J(x2)^T of a functional network.

    Args:
        fnet_single: ``f(params, x)`` evaluated on one unbatched input.
        params: parameter pytree (e.g. the tuple from ``make_functional``).
        x1, x2: input batches with N and M rows.
        compute: contraction over output indices a, b:
            'full' -> (N, M, a, b); 'trace' -> (N, M) over a shared a;
            'diagonal' -> (N, M, a); 'minimal' -> (N, M) summed over a and b.

    Returns:
        The kernel summed over all parameter tensors.

    Raises:
        ValueError: if ``compute`` is not one of the modes above.
    """
    contractions = {
        'full': 'Naf,Mbf->NMab',
        'trace': 'Naf,Maf->NM',
        'diagonal': 'Naf,Maf->NMa',
        'minimal': 'Naf,Mbf->NM',
    }
    # Validate before the expensive Jacobians. A bare `assert False` is
    # stripped under `python -O` and would hand einsum a None expression.
    try:
        einsum_expr = contractions[compute]
    except KeyError:
        raise ValueError(f"unsupported compute mode: {compute!r}") from None

    # Per-sample Jacobian w.r.t. params, vectorised over the batch axis;
    # each entry is flattened to (batch, out, P).
    per_sample_jac = vmap(jacrev(fnet_single), (None, 0))
    jac1 = [j.flatten(2) for j in per_sample_jac(params, x1)]
    jac2 = [j.flatten(2) for j in per_sample_jac(params, x2)]

    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    return result.sum(0)

def count_parameters(model,verbose=False):
    """Return the number of trainable (``requires_grad``) parameters of ``model``.

    Args:
        model: an ``nn.Module``.
        verbose: if True, also print a per-tensor breakdown and the total.
            Uses ``prettytable`` when it is installed and plain tab-separated
            text otherwise. ``PrettyTable`` was previously referenced without
            being imported, so ``verbose=True`` raised NameError.

    Returns:
        Total number of elements across parameters with ``requires_grad``.
    """
    rows = [(name, parameter.numel())
            for name, parameter in model.named_parameters()
            if parameter.requires_grad]
    total_params = sum(count for _, count in rows)

    if verbose:
        try:
            from prettytable import PrettyTable
        except ImportError:
            PrettyTable = None
        if PrettyTable is not None:
            table = PrettyTable(["Modules", "Parameters"])
            for name, count in rows:
                table.add_row([name, count])
            print(table)
        else:
            print("Modules\tParameters")
            for name, count in rows:
                print(f"{name}\t{count}")
        print("Total Trainable Params:"+str(total_params))
    return total_params

def empirical_dntk(fnet_single,params,x1,x2,compute='trace',compute_ntk='full'):
    """Derivative of the empirical NTK with respect to the network parameters.

    Args:
        fnet_single, params, x1, x2: forwarded to ``empirical_ntk``.
        compute: 'full' returns one tensor per parameter tensor, each divided
            by the total number of parameter entries in ``params``. 'trace'
            sums over the parameter axes and over parameter tensors.
        compute_ntk: NTK mode passed to ``empirical_ntk``. With
            ``compute='trace'`` only 'trace' and 'full' are supported.

    Raises:
        ValueError: on an unsupported ``compute`` / ``compute_ntk`` combination.
    """
    # Validate up front: previously an unknown `compute` silently returned
    # None and an unknown `compute_ntk` raised UnboundLocalError.
    if compute not in ('full', 'trace'):
        raise ValueError(f"unsupported compute mode: {compute!r}")
    if compute == 'trace' and compute_ntk not in ('trace', 'full'):
        raise ValueError(f"compute='trace' does not support compute_ntk={compute_ntk!r}")

    dNTK_tensor = jacrev(empirical_ntk, argnums=1)(fnet_single, params, x1, x2, compute_ntk)

    if compute == 'full':
        # The original divided by count_parameters(net), but `net` is not
        # defined in this module. Count the entries of `params` instead.
        n_params = sum(p.numel() for p in params)
        return [d / n_params for d in dNTK_tensor]

    # compute == 'trace'
    # NOTE(review): for compute_ntk='full' each entry is
    # (N, M, a, b, *param_shape). flatten(3) folds `b` into the summed axis
    # too, leaving (N, M, a); confirm this is intended rather than flatten(4).
    keep = 2 if compute_ntk == 'trace' else 3
    return torch.stack([d.flatten(keep).sum(dim=keep) for d in dNTK_tensor]).sum(dim=0)

def sample(test_input,network_params,n_samples):
    """Sample an ensemble of randomly initialised networks and record raw statistics.

    For each of ``n_samples`` fresh ``FullyConnected`` networks, this:
      1. registers forward hooks on every ``pre_layers`` entry;
      2. computes the empirical NTK (``compute='minimal'``) on ``test_input``;
      3. runs one forward pass;
      4. builds the per-network metric and output-correlator samples.

    Args:
        test_input: batch of inputs; main() passes an (n_inputs, input_dim)
            tensor. It is moved to the CPU before use.
        network_params: tuple (width, depth, input_dim, activation, Cw, Cb,
            control, power).
        n_samples: number of networks to draw.

    Returns:
        A list with one row per network:
        [input_dim, width, depth, Cw, Cb, Ghat_ab, Ghat_pre_ab, ZZhat_ab,
         Zhat_a, ZZZhat_abc, ZZZZhat_abcd, NTKhat_ab].
        The last seven entries are NumPy arrays.

    Side effects:
        Rebinds the module globals pre_act_logs, pre_act_names,
        post_act_logs, post_act_names and fnet for every network.
    """
    data = []
    width, depth, input_dim, activation, Cw, Cb, control, power = network_params
   
    # Loop over all architectural ensembles
    for net_sample_index in range(n_samples):

        # Only store one network at a time
        torch.cuda.empty_cache()
        
        # Create current instance.
        # Fresh activation logs for this network; the hooks write into
        # whichever dict the global name is bound to when they fire.
        global pre_act_logs, post_act_logs,pre_act_names,post_act_names
        pre_act_logs={}
        pre_act_names=[]
        post_act_logs={}
        post_act_names=[]
       
        # The composite "relu-c*" activations build their network with plain ReLU.
        if activation in ["relu-cswish","relu-cgelu"]:
            act = "relu"
        else:
            act=activation
          
        Net = FullyConnected(width, depth, input_dim, act, Cw, Cb, control, power)
#         for i in  range(Net.num_post_layers):
#             Net.post_layers[i].register_forward_hook(post_activation_hook("post_layers["+str(i)+"]"))
#             post_act_names.append("post_layers["+str(i)+"]")
        # Log the output of every pre-activation layer under "pre_layers[i]";
        # pre_act_names preserves registration order.
        for i in  range(Net.num_pre_layers):
            Net.pre_layers[i].register_forward_hook(pre_activation_hook("pre_layers["+str(i)+"]"))
            pre_act_names.append("pre_layers["+str(i)+"]")
            
        # NOTE(review): `device` is the module global, which main() does not set.
        Net.to(device)
        
        
        # compute NTK
        
        # convert network to function
        global fnet
        fnet, params = make_functional(Net)
        
        # NOTE(review): params_new (odd-index tensors scaled by width) is
        # built but never used below; the NTK uses the unscaled `params`.
        params_new=[]
        for i in range(0,len(params)):
            if i%2==1:
                params_new.append(torch.nn.parameter.Parameter(params[i]*width))
            else:
                params_new.append(params[i])
        params_new=tuple(params_new)
        
        # Empirical NTK on CPU copies of the inputs. 'minimal' sums over both
        # output indices, giving an (N, N) matrix for N inputs.
        NTKhat_ab=empirical_ntk(fnet_single, params, test_input.to(torch.device("cpu")),
                      test_input.to(torch.device("cpu")), compute='minimal').to(device)
        
        NTKhat_ab=NTKhat_ab.data.cpu().numpy()

        
        # Forward pass; this also refreshes pre_act_logs through the hooks.
        # Sanity check: the network output must equal the last logged
        # pre-activation (pre_act_logs, not post_act_logs).
        z_out = Net(test_input.to(torch.device("cpu"))).to(device)
        assert np.sum(abs(z_out.data.cpu().numpy()-pre_act_logs[pre_act_names[-1]].data.cpu().numpy()))==0.0
       
        # Activations sigma(z) of the last and second-to-last hooked layers
        # (Net.activation comes from model.py).
#         z_out = pre_act_logs[-1]
        sigma_out = Net.activation(pre_act_logs[pre_act_names[-1]])
        sigma_pre_out = Net.activation(pre_act_logs[pre_act_names[-2]])

        # One-point correlator
           

        ### Two-point correlator
        ## Also track the metric G (4.72), which should
        ## predict the next layer's 2-pt correlator (4.73)

        # Definition of Ghat_ab, Eq.(4.70), for the last layer:
        # Cb + Cw/width * sum_j sigma_aj sigma_bj
        Ghat_ab = Cb + Cw*torch.einsum('aj,bj->ab', sigma_out, sigma_out)/width
        Ghat_ab = Ghat_ab.data.cpu().numpy()
        
        # Definition of Ghat_ab, Eq.(4.70), for the second-to-last layer
        Ghat_pre_ab = Cb + Cw*torch.einsum('aj,bj->ab', sigma_pre_out, sigma_pre_out)/width
        Ghat_pre_ab = Ghat_pre_ab.data.cpu().numpy()
            
            
        # sum_ij Z_{a,i} Z_{b,j} for the last layer (= Zhat_a * Zhat_b)
        ZZhat_ab = torch.einsum('ai,bj->ab', z_out, z_out)
        ZZhat_ab = ZZhat_ab.data.cpu().numpy()
            
            
        # sum_i Z_{a,i} for the last layer (sum over axis 1 of z_out)
        Zhat_a = np.sum(z_out.data.cpu().numpy(),axis=1)
            
            
        ### Three-point correlator
        # sum_ijk Z_{a,i} Z_{b,j} Z_{c,k} for the last layer
        ZZZhat_abc = torch.einsum('ai,bj,ck->abc', z_out, z_out, z_out)
        ZZZhat_abc = ZZZhat_abc.data.cpu().numpy()
            
            
        ### Four-point correlator    
        # sum_ijkl Z_{a,i} Z_{b,j} Z_{c,k} Z_{d,l} for the last layer
        ZZZZhat_abcd = torch.einsum('ai,bj,ck,dl->abcd', z_out, z_out,z_out, z_out)
        ZZZZhat_abcd = ZZZZhat_abcd.data.cpu().numpy()
            
        # Row layout must match the column indices used in process_moments().
        data.append([input_dim,width,depth,Cw,Cb,
                     Ghat_ab,Ghat_pre_ab,ZZhat_ab,Zhat_a,ZZZhat_abc,ZZZZhat_abcd,NTKhat_ab])
        
    return data

def ntk_stats(NTKhat):
    """Summarise an ensemble of NTK samples.

    Args:
        NTKhat: array whose axis 0 indexes ensemble members.

    Returns:
        [mean |entry| of the ensemble-averaged kernel,
         mean |entry| of its diagonal].
    """
    avg_kernel = np.mean(NTKhat, axis=0)
    magnitude = np.abs(avg_kernel)
    return [np.mean(magnitude), np.mean(np.abs(avg_kernel.diagonal()))]
    

def second_moments(Zhat,ZZhat):
    """L1 summaries of the 2-point correlator and its connected part.

    Args:
        Zhat: per-sample 1-point values, axis 0 indexing ensemble members.
        ZZhat: per-sample 2-point values, axis 0 indexing ensemble members.

    Returns:
        [[ZZ_L1, ZZ_diag_L1, ZZ_off_diag_L1],
         [ZZC_L1, ZZC_diag_L1, ZZC_off_diag_L1]], where ZZC = <ZZ> - <Z><Z>.
        "off_diag" is the mean |entry| with the diagonal zeroed; the zeroed
        entries still count in the denominator.
    """
    def l1_summary(mat):
        off_diag = mat.copy()
        off_diag[range(mat.shape[0]), range(mat.shape[1])] = 0.0
        return [np.mean(np.abs(mat)),
                np.mean(np.abs(mat.diagonal())),
                np.mean(np.abs(off_diag))]

    mean_one = np.mean(Zhat, axis=0)
    mean_two = np.mean(ZZhat, axis=0)
    connected = mean_two - np.einsum('a,b->ab', mean_one, mean_one)
    return [l1_summary(mean_two), l1_summary(connected)]
    
def third_moments(Zhat,ZZhat,ZZZhat):
    """L1 summaries of the 3-point correlator and its connected part.

    Args:
        Zhat, ZZhat, ZZZhat: per-sample 1-, 2- and 3-point products, axis 0
            indexing ensemble members.

    Returns:
        [mean |<z_a z_b z_c>|, mean |connected 3-pt correlator|], where
        connected = <abc> - <a><bc> - <b><ac> - <c><ab> + 2<a><b><c>.
    """
    Z = np.mean(Zhat,axis=0)
    ZZ = np.mean(ZZhat,axis=0)
    ZZZ = np.mean(ZZZhat,axis=0)

    # 3-pt correlator
    ZZZ_L1 = np.mean(np.abs(ZZZ))

    # 3-pt connected correlator (third joint cumulant).
    # Each index assignment is spelled out with einsum rather than
    # np.transpose permutations, which are easy to get wrong: [2,0,1] maps
    # to <z_b><z_c z_a>, not <z_c><z_a z_b>. All three <z><zz> terms are
    # subtracted, so a deterministic ensemble gives exactly zero.
    ZZZC = (ZZZ
            - np.einsum('a,bc->abc', Z, ZZ)
            - np.einsum('b,ac->abc', Z, ZZ)
            - np.einsum('c,ab->abc', Z, ZZ)
            + 2 * np.einsum('a,b,c->abc', Z, Z, Z))

    ZZZC_L1 = np.mean(np.abs(ZZZC))

    return [ZZZ_L1,ZZZC_L1]
  
def fourth_moments(Zhat,ZZhat,ZZZhat,ZZZZhat):
    """L1 summaries of the 4-point correlator and its connected part.

    Args:
        Zhat, ZZhat, ZZZhat, ZZZZhat: per-sample 1- to 4-point products,
            axis 0 indexing ensemble members.

    Returns:
        [mean |<z_a z_b z_c z_d>|, mean |connected 4-pt correlator|], where
        connected = <abcd> - sum_4 <a><bcd> - sum_3 <ab><cd>
                    + 2 sum_6 <a><b><cd> - 6 <a><b><c><d>.
    """
    Z = np.mean(Zhat,axis=0)
    ZZ = np.mean(ZZhat,axis=0)
    ZZZ = np.mean(ZZZhat,axis=0)
    ZZZZ = np.mean(ZZZZhat,axis=0)

    # 4-pt correlator
    ZZZZ_L1 = np.mean(np.abs(ZZZZ))

    # 4-pt connected correlator (fourth joint cumulant).
    # Every term is written out with explicit einsum indices. The np.transpose
    # permutations used previously ([2,0,1,3], [3,0,2,1], [1,2,0,3],
    # [1,3,2,0]) duplicated some index assignments and omitted others, so the
    # result was neither symmetric in (a, b, c, d) nor the cumulant.
    ZZZZC = (ZZZZ
             # - <a><bcd> over the four choices of the singleton
             - np.einsum('a,bcd->abcd', Z, ZZZ)
             - np.einsum('b,acd->abcd', Z, ZZZ)
             - np.einsum('c,abd->abcd', Z, ZZZ)
             - np.einsum('d,abc->abcd', Z, ZZZ)
             # - <ab><cd> over the three pairings
             - np.einsum('ab,cd->abcd', ZZ, ZZ)
             - np.einsum('ac,bd->abcd', ZZ, ZZ)
             - np.einsum('ad,bc->abcd', ZZ, ZZ)
             # + 2 <a><b><cd> over the six choices of the pair
             + 2 * (np.einsum('a,b,cd->abcd', Z, Z, ZZ)
                    + np.einsum('a,c,bd->abcd', Z, Z, ZZ)
                    + np.einsum('a,d,bc->abcd', Z, Z, ZZ)
                    + np.einsum('b,c,ad->abcd', Z, Z, ZZ)
                    + np.einsum('b,d,ac->abcd', Z, Z, ZZ)
                    + np.einsum('c,d,ab->abcd', Z, Z, ZZ))
             - 6 * np.einsum('a,b,c,d->abcd', Z, Z, Z, Z))

    ZZZZC_L1 = np.mean(np.abs(ZZZZC))

    return [ZZZZ_L1,ZZZZC_L1]
        
def metrics(Ghat,Ghat_pre,width):
    """Summaries of the ensemble-averaged metrics and the 4-pt vertex.

    Args:
        Ghat: per-sample last-layer metric, axis 0 indexing ensemble members.
        Ghat_pre: per-sample second-to-last-layer metric, same layout.
        width: hidden-layer width, the prefactor of the vertex.

    Returns:
        [G_L1, G_diag, G_diag_L1, G_off_diag_L1,
         G_pre_L1, G_pre_diag, G_pre_diag_L1, G_pre_off_diag_L1, V_L1].
        The *_diag entries are the diagonals themselves; "off_diag" means the
        mean |entry| with the diagonal zeroed, zeros still counted.
    """
    def l1_parts(mat):
        zeroed = mat.copy()
        zeroed[range(mat.shape[0]), range(mat.shape[1])] = 0.0
        diag = mat.diagonal()
        return [np.mean(np.abs(mat)), diag, np.mean(np.abs(diag)), np.mean(np.abs(zeroed))]

    G = np.mean(Ghat, axis=0)
    G_pre = np.mean(Ghat_pre, axis=0)
    last_layer = l1_parts(G)
    prev_layer = l1_parts(G_pre)

    # 4-pt vertex, eq. (4.76): width * (<G_ab G_cd> - <G_ab><G_cd>),
    # built from the second-to-last-layer metric samples.
    count = Ghat.shape[0]
    second = np.einsum('nab,ncd->abcd', Ghat_pre, Ghat_pre) / count
    vertex = width * (second - np.einsum('ab,cd->abcd', G_pre, G_pre))

    return last_layer + prev_layer + [np.mean(np.abs(vertex))]

def process_correlations(data,width):
    """Reduce per-network correlator samples to summary statistics.

    Args:
        data: rows produced by sample(). Columns 5-10 hold Ghat_ab,
            Ghat_pre_ab, ZZhat_ab, Zhat_a, ZZZhat_abc and ZZZZhat_abcd.
        width: hidden-layer width, forwarded to metrics().

    Returns:
        Flat list:
        [Z_a, Z_L1,
         ZZ_L1, ZZ_diag_L1, ZZ_off_diag_L1,
         ZZC_L1, ZZC_diag_L1, ZZC_off_diag_L1,
         G_L1, G_diag, G_diag_L1, G_off_diag_L1,
         G_pre_L1, G_pre_diag, G_pre_diag_L1, G_pre_off_diag_L1,
         ZZZ_L1, ZZZC_L1, ZZZZ_L1, ZZZZC_L1, V_L1].
    """
    def column(idx):
        # Stack one field across all ensemble members along axis 0.
        return np.array([row[idx] for row in data])

    one_pt = column(8)
    Z_a = np.mean(one_pt, axis=0)
    Z_L1 = np.mean(np.abs(Z_a))

    two_pt = column(7)
    (ZZ_L1, ZZ_diag_L1, ZZ_off_diag_L1), (ZZC_L1, ZZC_diag_L1, ZZC_off_diag_L1) = \
        second_moments(one_pt, two_pt)

    three_pt = column(9)
    ZZZ_L1, ZZZC_L1 = third_moments(one_pt, two_pt, three_pt)

    four_pt = column(10)
    ZZZZ_L1, ZZZZC_L1 = fourth_moments(one_pt, two_pt, three_pt, four_pt)

    # Drop the large correlator arrays before materialising the metric arrays.
    del four_pt, two_pt, one_pt, three_pt
    gc.collect()

    metric_last = column(5)
    metric_prev = column(6)
    (G_L1, G_diag, G_diag_L1, G_off_diag_L1,
     G_pre_L1, G_pre_diag, G_pre_diag_L1, G_pre_off_diag_L1,
     V_L1) = metrics(metric_last, metric_prev, width)
    del metric_last, metric_prev
    gc.collect()

    return [
        Z_a, Z_L1,
        ZZ_L1, ZZ_diag_L1, ZZ_off_diag_L1,
        ZZC_L1, ZZC_diag_L1, ZZC_off_diag_L1,
        G_L1, G_diag, G_diag_L1, G_off_diag_L1,
        G_pre_L1, G_pre_diag, G_pre_diag_L1, G_pre_off_diag_L1,
        ZZZ_L1, ZZZC_L1,
        ZZZZ_L1, ZZZZC_L1,
        V_L1,
    ]


def process_moments(data, n_samples):
    """Turn the raw ensemble rows from sample() into one flat results record.

    Row layout of ``data`` (as built by sample()):
        0 input_dim     1 width        2 depth       3 Cw
        4 Cb            5 Ghat_ab      6 Ghat_pre_ab 7 ZZhat_ab
        8 Zhat_a        9 ZZZhat_abc  10 ZZZZhat_abcd
       11 NTKhat_ab

    Network hyperparameters are read from the first row. The returned list
    follows the column order used by save_data(), with the per-input G_diag
    and G_pre_diag values expanded inline.
    """
    input_dim, width, depth, Cw, Cb = data[0][:5]

    (Z_a, Z_L1,
     ZZ_L1, ZZ_diag_L1, ZZ_off_diag_L1,
     ZZC_L1, ZZC_diag_L1, ZZC_off_diag_L1,
     G_L1, G_diag, G_diag_L1, G_off_diag_L1,
     G_pre_L1, G_pre_diag, G_pre_diag_L1, G_pre_off_diag_L1,
     ZZZ_L1, ZZZC_L1,
     ZZZZ_L1, ZZZZC_L1,
     V_L1) = process_correlations(data, width)

    kernels = np.array([row[11] for row in data])
    NTK_L1, NTK_diag_L1 = ntk_stats(kernels)
    del kernels
    gc.collect()

    record = [input_dim, width, depth, Cw, Cb, n_samples, Z_L1]
    record += [ZZ_L1, ZZ_diag_L1, ZZ_off_diag_L1, ZZC_L1, ZZC_diag_L1, ZZC_off_diag_L1]
    record += [G_L1, G_diag_L1, G_off_diag_L1]
    record += G_diag.tolist()
    record += [G_pre_L1, G_pre_diag_L1, G_pre_off_diag_L1]
    record += G_pre_diag.tolist()
    record += [ZZZ_L1, ZZZC_L1]
    record += [ZZZZ_L1, ZZZZC_L1]
    record += [V_L1, NTK_L1, NTK_diag_L1]
    return record

def save_data(post_data,seed,activation,config,critical,control,ninput,table="metric"):
    """Append one results record to ``<config["RESULT"]["path"]>/stats.db``.

    Args:
        post_data: flat record from process_moments(). Its length must match
            the column list built here, which depends on ``ninput``.
        seed, activation, critical, control: run metadata stored as extra
            columns.
        config: parsed YAML config. Reads ARCHITECTURE.activation_power and
            RESULT.path.
        ninput: number of test inputs, i.e. how many G_diag_i and
            G_pre_diag_i columns to expect.
        table: SQLite table to append to. Defaults to "metric", the name
            that was previously hard-coded.
    """
    import contextlib

    gdiag_list = [f"G_diag_{i}" for i in range(ninput)]
    gpre_diag_list = [f"G_pre_diag_{i}" for i in range(ninput)]
    columns = (
        ['input_dim', 'width', 'depth', 'Cw', 'Cb', 'n_samples']
        + ['Z_L1']
        + ['ZZ_L1', 'ZZ_diag_L1', 'ZZ_off_diag_L1',
           'ZZC_L1', 'ZZC_diag_L1', 'ZZC_off_diag_L1']
        + ['G_L1', 'G_diag_L1', 'G_off_diag_L1'] + gdiag_list
        + ['G_pre_L1', 'G_pre_diag_L1', 'G_pre_off_diag_L1'] + gpre_diag_list
        + ['ZZZ_L1', 'ZZZC_L1']
        + ['ZZZZ_L1', 'ZZZZC_L1']
        + ['V_L1'] + ['NTK_L1', 'NTK_diag_L1']
    )
    df = pd.DataFrame([post_data], columns=columns)

    df['ratio'] = df['depth']/df['width']
    df['seed'] = seed
    df['critical'] = critical
    df['act'] = activation
    df['control'] = control
    df['power'] = config['ARCHITECTURE']['activation_power']

    # Create the results folder. makedirs(exist_ok=True) also creates missing
    # parents and tolerates a folder created concurrently by another run,
    # which the exists()/mkdir() pair did not.
    folder = config["RESULT"]["path"]
    os.makedirs(folder, exist_ok=True)

    # Save to the SQLite database. closing() releases the connection even if
    # to_sql raises (e.g. on a schema mismatch with an existing table).
    dbname = os.path.join(folder, "stats.db")
    with contextlib.closing(sqlite3.connect(dbname, timeout=180)) as connection:
        df.to_sql(table, connection, if_exists="append")

     
    
def parse_args():
    """Parse the script's command-line options and return the argparse namespace."""
    option_specs = [
        ('--config', dict(default='config/input_relu.yaml', type=str,
                          help='where the hyperparameters are stored, yaml file')),
        ('--seed', dict(default=-1, type=int,
                        help='seed for RNG, if positive will be used instead of the one from the config file')),
        ('--stats_name', dict(default='metric', type=str,
                              help='name of table for stats db')),
        ('--critical', dict(default=-1, type=int,
                            help='if to perform critical initialization')),
        ('--depth', dict(default=-1, type=int,
                         help='depth for computations')),
        ('--activation_control', dict(default=-1.0, type=float,
                                      help='activation control parameter, if positive will be used instead of the config value')),
        ('--device', dict(default='cpu', type=str,
                          help='cuda/cpu')),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def _load_or_create_inputs(folder, n_inputs, input_dim, device):
    """Return the shared test inputs, creating ``folder/input_data.pt`` on first use.

    The tensor has shape (n_inputs, input_dim). Every column is
    ``linspace(0, 5, n_inputs + 1)`` with its leading 0 dropped. The file
    always stores these unscaled values.
    """
    path = folder + "/input_data.pt"
    if os.path.exists(path):
        return torch.load(path).to(device)
    print("generate input dataset")
    grid = torch.linspace(0, 5, n_inputs + 1).to(device)
    inputs = torch.stack([grid[1:]] * input_dim).to(device).T
    torch.save(inputs, path)
    return inputs


def main():
    """Run the ensemble experiment described by the YAML config.

    Steps:
      1. Parse CLI args and load the config.
      2. Choose the device.
      3. Resolve (Cw, Cb). Critical values come from
         get_all_initializations() when ``--critical 1`` is given and the
         activation is one of the smooth ones.
      4. Build or reload the shared test inputs.
      5. For each configured width: sample the ensemble, reduce it with
         process_moments() and append the record via save_data().

    Raises:
        SystemExit: if the config file is not valid YAML.
        ValueError: if ``--device`` is neither "cuda" nor "cpu".
    """
    args = parse_args()

    with open(args.config, "r") as stream:
        try:
            config = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            # Nothing below can run without a config. Stop here instead of
            # failing later with a NameError on `config`.
            raise SystemExit(1) from exc

    # Get cpu or gpu device for training.
    # NOTE(review): this binds a local `device`; the module-level `device`
    # used inside sample() is left untouched.
    if args.device == "cuda":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()
    elif args.device == "cpu":
        device = torch.device("cpu")
    else:
        raise ValueError(f"unsupported --device {args.device!r}; expected 'cuda' or 'cpu'")

    print("Using "+str(device)+" device")

    # architecture parameters
    architecture = config['ARCHITECTURE']

    widths = architecture['width']
    depth = architecture['depth']
    if args.depth > 0:
        depth = args.depth
    input_dim = architecture['input_dim']
    variance = architecture['variance']

    activation = architecture['activation']
    control = architecture['activation_control']
    power = architecture['activation_power']

    critical = args.critical

    if args.activation_control > 0.0:
        control = args.activation_control/100.0

    smooth_activations = ("swish", "gelu", "gumbel", "guderman", "algebraic",
                          "relu-cswish", "relu-cgelu")
    # K* is set only for critical initialisation of a smooth activation;
    # the test inputs are then rescaled by sqrt(K*).
    Kstar = None
    if activation in smooth_activations:
        if critical == 1:
            print("critical")
            if activation in ["relu-cswish", "swish"]:
                act = "swish"
            elif activation in ['relu-cgelu', 'gelu']:
                act = "gelu"
            else:
                act = activation
            df = get_all_initializations()
            row = df.loc[df['act'] == act]
            Kstar = control ** 2 * row['kstar'].values[0]
            Cw = row['cw'].values[0]
            Cb = control ** 2 * row['cb'].values[0]
            variance = [Cw, Cb]
    else:
        variance = [2.0, 0.0]
    print("activation: "+activation)

    # sampling parameters
    seed = config['SAMPLING']['seed']
    if args.seed > 0:
        seed = args.seed
    n_samples = config['SAMPLING']['num_samples']
    n_inputs = config['SAMPLING']['num_inputs']

    # set random seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)

    # create the results folder (including missing parents)
    folder = config["RESULT"]["path"]
    os.makedirs(folder, exist_ok=True)

    # NOTE(review): --stats_name is parsed but save_data() is called without
    # it, so records always go to the default table.
    Cw, Cb = variance[0], variance[1]
    with torch.no_grad():
        # Generate common inputs for all networks of the same input_dim.
        input_tensor = _load_or_create_inputs(folder, n_inputs, input_dim, device)
        # Rescale after either branch. Previously only reloaded inputs were
        # scaled, so the run that first created the file used different
        # inputs from every later run.
        if Kstar is not None:
            input_tensor *= np.sqrt(Kstar)
        print("Cw,Cb = "+str(Cw)+", "+str(Cb))
        print("depth = "+str(depth))
        for width in widths:
            print("width = "+str(width))
            network_params = width, depth, input_dim, activation, Cw, Cb, control, power
            time_start = time.perf_counter()
            data = sample(input_tensor,network_params,n_samples)
            time_end = time.perf_counter()
            print("time to samle = ",(time_end - time_start))
            time_start = time.perf_counter()
            data = process_moments(data,n_samples)
            time_end = time.perf_counter()
            print("time to process data = ",(time_end - time_start))
            save_data(data,seed,activation,config,critical,control,n_inputs)
            del data
            gc.collect()


    
if __name__ == '__main__':
    # Entry point: run the experiment and report total wall-clock time.
    started = time.perf_counter()
    main()
    elapsed = time.perf_counter() - started
    print('All done, in ' + str(elapsed) + 's')
