import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
import pandas as pd

from tqdm import tqdm
import glob
import os
from datetime import datetime
from itertools import product
import time
import math
import sys
import ast
import torchvision.transforms as transforms

sys.path.append("./src")
from CIM_EBD_Models import EBDCorInfoMaxHopfield
from torch_utilsn import *

import warnings
warnings.filterwarnings("ignore")

from IPython.core.debugger import Pdb
sys.path.append("./src")
from IPython.display import clear_output

# Let cuDNN benchmark candidate kernels and reuse the fastest one per input
# shape (only affects cuDNN-backed ops).
# NOTE(review): per the PyTorch reproducibility notes this can make runs
# non-bit-reproducible even with fixed seeds — disable if exact
# reproducibility across the seed sweep matters.
torch.backends.cudnn.benchmark = True


def my_clip(M, clevel, msg=''):
    """Clip the entries of ``M`` to the closed interval ``[-clevel, clevel]``.

    Args:
        M: Tensor to clip.
        clevel: Clipping magnitude (expected to be non-negative).
        msg: Message printed when at least one entry of ``M`` has magnitude
            strictly greater than ``clevel``, i.e. when clipping really
            changes something.

    Returns:
        A new tensor with the shape of ``M``; entries with ``|M| >= clevel``
        become ``sign(M) * clevel``. NaN entries are propagated unchanged.
    """
    if torch.any(torch.abs(M) > clevel):
        print(msg)
    # torch.clamp instead of the mask-multiply form
    # ``M * (|M| < c) + sign(M) * c * (|M| >= c)``: that form evaluates
    # ``inf * 0`` for infinite entries and yields NaN instead of +/-clevel.
    return torch.clamp(M, min=-clevel, max=clevel)


def update_list(lr, cnt2):
    """Return a copy of ``lr`` in which only one position keeps its value.

    The surviving position cycles with ``cnt2`` over indices 0, 1 and 2
    (``cnt2 mod 3``); every other position is replaced by 0. When that index
    lies past the end of ``lr``, the result is all zeros.
    """
    keep = np.mod(cnt2, 3)
    return [value if position == keep else 0 for position, value in enumerate(lr)]


def _make_cifar10_loaders(batchsize):
    """Build the CIFAR-10 loaders used by ``run_experiment``.

    Returns:
        ``(train_loader, train_loader2, test_loader)``: the shuffled training
        loader with ``batchsize`` samples per step, a shuffled training loader
        with batch size 20 used for accuracy evaluation, and the unshuffled
        test loader with batch size 20.
    """
    # Train and test share one deterministic transform (no augmentation).
    # The per-channel std is multiplied by 3, so normalized inputs have a
    # standard deviation of roughly 1/3 rather than 1.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                         std=(3*0.2023, 3*0.1994, 3*0.2010))
    ])
    cifar_dset_train = torchvision.datasets.CIFAR10('data', train=True, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(cifar_dset_train, batch_size=batchsize, shuffle=True, num_workers=4)
    train_loader2 = torch.utils.data.DataLoader(cifar_dset_train, batch_size=20, shuffle=True, num_workers=4)

    cifar_dset_test = torchvision.datasets.CIFAR10('data', train=False, transform=transform, download=True)
    test_loader = torch.utils.data.DataLoader(cifar_dset_test, batch_size=20, shuffle=False, num_workers=4)
    return train_loader, train_loader2, test_loader


def _prune_small_weights(layers, prune_scale):
    """Zero small-magnitude entries of every ``layers[i]['weight']``.

    Entries whose magnitude is not greater than
    ``prune_scale * max(|weight|)`` are set to zero; the pruned tensor replaces
    ``layers[i]['weight']`` (the container is updated in place).
    """
    # Index-based access (as in the original loops) so both lists and
    # int-keyed dicts of layer dicts are supported.
    for i in range(len(layers)):
        weight = layers[i]['weight']
        threshold = torch.max(torch.abs(weight)) * prune_scale
        layers[i]['weight'] = weight * (torch.abs(weight) > threshold)


def run_experiment(
    lambda_eb_list=[0.999999],
    lambda_eb2_=0.99999999,
    beta=1,
    EPS_DIV=0.1,
    SCALE_FB=1.0,
    lr_decay_multiplier_list=[0.95],
    neural_lr_start_list=[0.5],  # 0.05/EPS_DIV with EPS_DIV=0.1
    neural_lr_decay_multiplier=0.01,
    neural_dynamic_iterations_nudged=10,
    neural_dynamic_iterations_free_list=[50],
    hopfield_g_list=[0.1],
    batchsize=1,
    n_epochs=25000,
    seed_list=[104, 114, 124, 134, 144, 154, 164, 174, 184, 194],
    lateral_init_scalev=1.0,
    lateral_scalev=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0,1.0, 1.0, 1.0, 1.0, 1.0],
    Wff_init_scalev=[1.0,1.0, 1.0, 1.0, 1.0, 1.0,1.0, 1.0, 1.0, 1.0, 1.0],
    subt_meanv=0,
    include_backv=0.0,
    momentum_ffv=0.9999,
    momentum_fbv=0.9999,
    act_l1_lr_ffv=[0.01575, 0.01575, 0.01575, 0.01575,0.0175, 0.01575, 0.01575, 0.01575, 0.00675, 0],
    layer_pow_targetv=[2.5, 2.5, 2.5, 2.5,2.5,2.5, 2.5, 2.5, 2, 5, 0.1],
    act_pow_lr_ffv=[0.002, 0.002,0.002, 0.002,0.002, 0.002, 0.002, 0.002, 0.005, 1e-10],
    lr_erv=[9.6, 9.6, 9.6, 9.6, 9.6, 9.6, 9.6, 9.6, 6.0, 10000.0],
    lr_er2v=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    br_update_lateralv=0.5/15,
    br_update_lateralv_cont=0.003333,
    lambda_list=[0.99999995],
    lr_frvv=1e-18,
    L1_DIVIDE=1e4,
    lr_weight_l1vv=[0, 0, 0, 0, 0, 0,0, 0, 0, 0],
    lr_weight_fb_l1vv=[0, 0, 0, 0, 0, 0,0, 0, 0, 0],
    lr_weight_bb_l2=[0, 0, 0, 0, 0, 0,0, 0, 0, 0],
    lr_weight_fb_l2=[0, 0, 0, 0, 0, 0,0, 0, 0, 0],
    lr_weight_l2=[0, 0, 0, 0, 0, 0,0, 0, 0, 0],
    weight_decayv=True,
    L2_SCALE=4.0,
    CNT_SCALE=3,
    pickle_name_for_results = "BPCIFAR10batch1DeepN1d.pkl",
    train_startv=[9,8,7,6,5,4,3,2,1,0,0,0],
    load_network_weightsv=0,
    pickle_name_for_weightsv='TestWeights.pkl',
    bias_include=0,
    weight_prune_period=10000,
    weight_prune_scale=0.00001,
    architecture = [32 * 32 * 3, 500, 500, 500, 500, 500, 500, 500, 500, 500, 10]
):
    """Train ``EBDCorInfoMaxHopfield`` on CIFAR-10 over a hyperparameter grid.

    For every combination of ``lambda_list`` x ``lambda_eb_list`` x
    ``lr_decay_multiplier_list`` x ``neural_lr_start_list`` x
    ``neural_dynamic_iterations_free_list`` x ``hopfield_g_list``, and for
    every seed in ``seed_list``, a fresh model is built and trained for
    ``n_epochs`` epochs.

    Schedules applied during training (``cnt`` counts training steps and
    ``cnt2`` is incremented once every 10 steps):

    * ``model.lambda_eb_`` moves from ``lambda_eb_`` towards ``lambda_eb2_``
      as ``1 / (epoch / 5 + 1)`` decays.
    * ff/fb learning rates are ``lr_start / (CNT_SCALE * cnt2 + 1)`` during
      the first 400 epochs and are frozen afterwards.
    * Momentum, weight-regularization rates, ``br_update_lateral`` and
      ``lr_er`` are rescheduled during the first 240 epochs. The lateral base
      rate is ``br_update_lateralv`` in epoch 0 of each seed and
      ``br_update_lateralv_cont`` afterwards.
    * Every ``weight_prune_period`` steps, entries of ``model.Wff`` and
      ``model.B`` below ``weight_prune_scale * max|W|`` are zeroed.

    Train/test accuracy is evaluated on the last batch of each epoch. Result
    rows are appended to a DataFrame that is pickled to
    ``../Results/<pickle_name_for_results>``. The current model is saved to
    ``../Results/BPCIFAR10batch1DeepNew.pth`` at the end of every epoch.

    Most other arguments are forwarded unchanged to the model constructor or
    to ``model.batch_step_hopfield``. ``L1_DIVIDE`` and ``L2_SCALE`` are
    accepted for interface compatibility but are not used here.

    NOTE(review): the default ``layer_pow_targetv`` contains ``2, 5``, which
    looks like a decimal-comma typo for ``2.5`` — confirm before relying on
    the default.
    """
    # Compute epsilon and neural_lr_stop from EPS_DIV
    epsilon = 0.15 / EPS_DIV
    neural_lr_stop = 0.001 / EPS_DIV

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Results are written relative to the current working directory.
    results_dir = "../Results"
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, pickle_name_for_results)

    RESULTS_DF = pd.DataFrame(columns=['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list', 'Tst_ACC_list'])

    train_loader, train_loader2, test_loader = _make_cifar10_loaders(batchsize)
    # Evaluate after the final batch of every epoch. Deriving the index from
    # len(train_loader) keeps this correct for any batch size; a fixed 49999
    # is only reachable when batchsize == 1.
    last_batch_idx = len(train_loader) - 1

    # hard_sigmoid comes from torch_utilsn (star-imported at file level).
    activation = hard_sigmoid

    # Initial per-layer learning rates; feedback rates are scaled by SCALE_FB.
    lr_start_list = [{
        'ff': np.array([0.1, 0.1, 0.1, 0.1,0.1,0.1,0.1, 0.11, 0.06, 0.035]),
        'fb': np.array([np.nan, 1.125 * SCALE_FB, 0.375 * SCALE_FB, 0.375 * SCALE_FB,  0.375 * SCALE_FB,0.375 * SCALE_FB, 0.375 * SCALE_FB,0.375 * SCALE_FB,0.375 * SCALE_FB, 0.375 * SCALE_FB, 0.375 * SCALE_FB])
    }]

    neural_lr_rule_list = ["constant"]
    use_three_phase_list = [False]

    setting_number = 0

    # Loop over the hyperparameter grid
    for lambda_, lambda_eb_, lr_start, lr_decay_multiplier, neural_lr_start, neural_lr_rule, neural_dynamic_iterations_free, hopfield_g, use_three_phase in product(
        lambda_list, lambda_eb_list, lr_start_list, lr_decay_multiplier_list,
        neural_lr_start_list, neural_lr_rule_list, neural_dynamic_iterations_free_list,
        hopfield_g_list, use_three_phase_list
    ):
        setting_number += 1
        hyperparams_dict = {
            "lr_start": lr_start,
            "lr_decay_multiplier": lr_decay_multiplier,
            "neural_dynamic_iterations_free": neural_dynamic_iterations_free,
            "neural_dynamic_iterations_nudged": neural_dynamic_iterations_nudged,
            "neural_lr_rule": neural_lr_rule,
            "neural_lr": neural_lr_start,
            "epsilon": epsilon,
            "lambda": lambda_,
            "lambda_eb": lambda_eb_,
            "architecture": architecture,
            "three_phase": use_three_phase
        }
        for seed_ in seed_list:
            np.random.seed(seed_)
            torch.manual_seed(seed_)
            trn_acc_list = []
            tst_acc_list = []

            # Exponential moving averages of the fraction of exactly-zero
            # activations in neurons[0], neurons[1] and neurons[2].
            act_sp_list0 = [0]
            act_sp_list1 = [0]
            act_sp_list2 = [0]
            # Feedback L1 activation rates: the feedforward rates shifted by
            # one layer, with 0 for the first layer.
            act_l1_lr_fbv = [0] + act_l1_lr_ffv[:-1]

            # Initialize the model
            model = EBDCorInfoMaxHopfield(
                architecture=architecture,
                lambda_=lambda_,
                lambda_eb_=lambda_eb_,
                epsilon=epsilon,
                lr_fr=lr_frvv,
                lr_er=lr_erv,
                lr_er2=lr_er2v,
                include_forw=1,
                include_back=include_backv,
                use_preact=0,
                subt_mean=subt_meanv,
                br_update_lateral=br_update_lateralv,
                activation=activation,
                act_l1_lr_ff=act_l1_lr_ffv,
                act_l1_lr_fb=act_l1_lr_fbv,
                momentum_ff=momentum_ffv,
                momentum_fb=momentum_fbv,
                lateral_init_scale=lateral_init_scalev,
                lateral_scale=lateral_scalev,
                Wff_init_scale=Wff_init_scalev,
                layer_pow_target=layer_pow_targetv,
                act_pow_lr_ff=act_pow_lr_ffv,
                lr_weight_l2=lr_weight_l2,
                lr_weight_bb_l2=lr_weight_bb_l2,
                lr_weight_fb_l2=lr_weight_fb_l2,
                lr_weight_l1=lr_weight_l1vv,
                train_start=train_startv,
                load_network_weights=load_network_weightsv,
                pickle_name_for_weights=pickle_name_for_weightsv,
                bias_include=bias_include
            )

            cnt = 0   # training steps taken for this seed
            cnt2 = 0  # incremented once every 10 steps; drives the decays
            for epoch_ in range(n_epochs):
                eind = epoch_ / 5 + 1
                lind = 1 / eind
                # Move lambda_eb from lambda_eb_ towards lambda_eb2_.
                model.lambda_eb_ = lambda_eb_ * lind + lambda_eb2_ * (1 - lind)

                # Per-seed local: epoch 0 uses br_update_lateralv, later epochs
                # br_update_lateralv_cont. Not reassigning the argument keeps
                # the switch from leaking into the next seed/setting.
                br_update_lateral_base = br_update_lateralv if epoch_ == 0 else br_update_lateralv_cont

                model.args_w = 100 * torch.rand(1).item() + 0.5
                model.args_ph = 3.14 * torch.rand(1).item()
                model.non_fun_der_scale = 1.0
                for idx, (x, y) in enumerate(tqdm(train_loader)):
                    cnt += 1
                    if np.mod(cnt, 10) == 9:
                        cnt2 += 1

                    if np.mod(cnt, 100) == 99:
                        model.args_w = 100 * torch.rand(1).item() + 0.5
                        model.args_ph = 3.14 * torch.rand(1).item()
                        model.non_fun_der_scale = 1.0
                    x, y = x.to(device), y.to(device)
                    # Flatten each image and put samples in columns.
                    x = x.view(x.size(0), -1).T
                    y_one_hot = F.one_hot(y, 10).to(device).T
                    take_debug_logs_ = (idx % 500 == 0)

                    if epoch_ < 240:
                        momentum = 1 / (cnt2 + 1) * 0.99 + (1 - 1 / (cnt2 + 1)) * 0.999
                        model.momentum_ff = momentum
                        model.momentum_fb = momentum
                    if epoch_ < 400:
                        lr = {
                            'ff': lr_start['ff'] / (CNT_SCALE * cnt2 + 1),
                            'fb': lr_start['fb'] / (CNT_SCALE * cnt2 + 1)
                        }
                    if epoch_ < 240:
                        # `/ 30 * 30` is mathematically a no-op; kept so the
                        # floating-point schedule is unchanged.
                        lr_weight_l1v = [val / (CNT_SCALE * cnt2 / 30 * 30 + 1) for val in lr_weight_l1vv]
                        lr_weight_fb_l1v = [val / (CNT_SCALE * cnt2 / 30 * 30 + 1) for val in lr_weight_fb_l1vv]
                        lr_weight_l2v = [val / (CNT_SCALE * cnt2 / 30 / 10 + 1) for val in lr_weight_l2]
                        lr_weight_fb_l2v = [val / (CNT_SCALE * cnt2 / 30 / 10 + 1) for val in lr_weight_fb_l2]
                        lr_weight_bb_l2v = [val / (CNT_SCALE * cnt2 / 30 / 10 + 1) for val in lr_weight_bb_l2]
                        model.br_update_lateral = br_update_lateral_base * lr['ff'][0] / lr_start['ff'][0] / (1 / (CNT_SCALE * cnt2 / 1e3 + 1))
                        model.lr_er = [val / (1 / (CNT_SCALE * cnt2 / 1e3 + 1)) for val in lr_erv]

                    neurons = model.batch_step_hopfield(
                        x, y_one_hot, hopfield_g,
                        lr, neural_lr_start, neural_lr_stop, neural_lr_rule,
                        neural_lr_decay_multiplier, neural_dynamic_iterations_free,
                        neural_dynamic_iterations_nudged, beta,
                        use_three_phase, take_debug_logs_,
                        weight_decay=weight_decayv,
                        lr_weight_l1=lr_weight_l1v,
                        lr_weight_fb_l1=lr_weight_fb_l1v,
                        lr_weight_l2=lr_weight_l2v,
                        lr_weight_fb_l2=lr_weight_fb_l2v, lr_weight_bb_l2=lr_weight_bb_l2v, epoch=epoch_
                    )
                    act_sp_list0.append(0.99 * act_sp_list0[-1] + 0.01 * (torch.sum(neurons[0] == 0).item() / neurons[0].numel()))
                    act_sp_list1.append(0.99 * act_sp_list1[-1] + 0.01 * (torch.sum(neurons[1] == 0).item() / neurons[1].numel()))
                    act_sp_list2.append(0.99 * act_sp_list2[-1] + 0.01 * (torch.sum(neurons[2] == 0).item() / neurons[2].numel()))

                    if np.mod(cnt, weight_prune_period) == (weight_prune_period - 1):
                        _prune_small_weights(model.Wff, weight_prune_scale)
                        _prune_small_weights(model.B, weight_prune_scale)

                    if idx == last_batch_idx:
                        trn_acc = evaluateEBDCorInfoMaxHopfield(
                            model, train_loader2, hopfield_g, neural_lr_start,
                            neural_lr_stop, neural_lr_rule,
                            neural_lr_decay_multiplier, neural_dynamic_iterations_free,
                            device, printing=False
                        )
                        tst_acc = evaluateEBDCorInfoMaxHopfield(
                            model, test_loader, hopfield_g, neural_lr_start,
                            neural_lr_stop, neural_lr_rule,
                            neural_lr_decay_multiplier, neural_dynamic_iterations_free,
                            device, printing=False
                        )
                        trn_acc_list.append(trn_acc)
                        tst_acc_list.append(tst_acc)

                        # NOTE: the weight containers and the accuracy and
                        # sparsity lists are stored by reference, not copied.
                        # Later in-place updates (e.g. pruning, further
                        # appends) are visible through every in-memory row.
                        # Each pickle captures the state at write time.
                        Result_Dict = {
                            "setting_number": setting_number,
                            "seed": seed_,
                            "Model": "EBDv9",
                            "Hyperparams": hyperparams_dict,
                            "Trn_ACC_list": trn_acc_list,
                            "Tst_ACC_list": tst_acc_list,
                            "B":model.B,
                            "Wff":model.Wff,
                            "Wfb":model.Wfb,
                            "Web":model.Web,
                            "Web2":model.Web2,
                            "layer_bias":model.layer_bias,
                            "neurons": neurons,
                            "act_sp_list0": act_sp_list0,
                            "act_sp_list1": act_sp_list1,
                            "act_sp_list2": act_sp_list2,
                        }
                        RESULTS_DF = pd.concat([RESULTS_DF, pd.DataFrame([Result_Dict])], ignore_index=True)
                        RESULTS_DF.to_pickle(results_path)

                # End of epoch: record a summary row and checkpoint the model.
                Result_Dict = {
                    "setting_number": setting_number,
                    "seed": seed_,
                    "Model": "EBDv9",
                    "Hyperparams": hyperparams_dict,
                    "Trn_ACC_list": trn_acc_list,
                    "Tst_ACC_list": tst_acc_list,
                    "neurons": neurons,
                    "act_sp_list0": act_sp_list0,
                    "act_sp_list1": act_sp_list1,
                    "act_sp_list2": act_sp_list2,
                }
                RESULTS_DF = pd.concat([RESULTS_DF, pd.DataFrame([Result_Dict])], ignore_index=True)
                RESULTS_DF.to_pickle(results_path)
                torch.save(model, '../Results/BPCIFAR10batch1DeepNew.pth')
    RESULTS_DF.to_pickle(results_path)


def main():
    """Parse command-line hyperparameters and launch ``run_experiment``.

    List-valued options either take several space-separated values
    (``nargs='+'``) or a single Python literal parsed by ``str2list``,
    e.g. ``--hopfield_g_list "[0.1, 0.2]"``. A bare scalar such as ``0.1`` is
    accepted there too and becomes a one-element list.
    """
    parser = argparse.ArgumentParser(
        description="Train EBDCorInfoMaxHopfield with command-line hyperparameters"
    )

    def str2list(v):
        """Parse a Python literal; wrap a scalar into a one-element list."""
        try:
            value = ast.literal_eval(v)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise argparse.ArgumentTypeError("Expected a list, got {}".format(v)) from err
        # The values are iterated by itertools.product downstream, so a bare
        # scalar must be wrapped rather than passed through.
        if isinstance(value, (list, tuple)):
            return value
        return [value]

    parser.add_argument('--lambda_eb_list', type=float, nargs='+', default=[0.999999],
                        help="Space-separated lambda_eb values (e.g. 0.99999 0.999999)")
    parser.add_argument('--lambda_eb2_', type=float, default=0.99999999,
                        help="Lambda EB2 value")
    parser.add_argument('--beta', type=float, default=1,
                        help="Beta value")
    parser.add_argument('--EPS_DIV', type=float, default=0.1,
                        help="EPS_DIV value")
    parser.add_argument('--SCALE_FB', type=float, default=1.0,
                        help="Scale_FB value")
    parser.add_argument('--lr_decay_multiplier_list', type=str2list, default=[0.95],
                        help="List of lr decay multipliers")
    parser.add_argument('--neural_lr_start_list', type=str2list, default=[0.5],
                        help="List of neural lr start values")
    parser.add_argument('--neural_lr_decay_multiplier', type=float, default=0.01,
                        help="Neural lr decay multiplier")
    parser.add_argument('--neural_dynamic_iterations_nudged', type=int, default=10,
                        help="Neural dynamic iterations nudged")
    parser.add_argument('--neural_dynamic_iterations_free_list', type=int,  nargs='+', default=[50],
                        help="List of neural dynamic free iterations")
    parser.add_argument('--hopfield_g_list', type=str2list, default=[0.1],
                        help="List of hopfield g values")
    parser.add_argument('--batchsize', type=int, default=1,
                        help="Batch size")
    parser.add_argument('--n_epochs', type=int, default=25000,
                        help="Number of epochs")
    parser.add_argument('--seed_list', type=int, nargs='+', default=[104,114,124,134,144,154,164,174,184,194],
                        help="List of seed values")
    parser.add_argument('--lateral_init_scalev', type=float, default=1.0,
                        help="Lateral initialization scale")
    parser.add_argument('--lateral_scalev', type=float, nargs='+', default=[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0],
                        help="List of lateral scale values")
    parser.add_argument('--Wff_init_scalev', type=float, nargs='+', default=[1.0, 1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0],
                        help="List of Wff initialization scale values")
    parser.add_argument('--subt_meanv', type=float, default=0,
                        help="subt_mean value")
    parser.add_argument('--bias_include', type=float, default=0.0,
                        help="bias_include value")
    parser.add_argument('--include_backv', type=float, default=0.0,
                        help="Include back value")
    parser.add_argument('--momentum_ffv', type=float, default=0.9999,
                        help="Momentum FF value")
    parser.add_argument('--momentum_fbv', type=float, default=0.9999,
                        help="Momentum FB value")
    parser.add_argument('--act_l1_lr_ffv', type=float, nargs='+', default=[0.01575,0.01575, 0.01575,0.01575, 0.01575,0.01575, 0.01575, 0.01575, 0.00675, 0],
                        help="List of act l1 learning rates for feedforward")
    # One target per layer, matching the 10-entry per-layer lists around it.
    # (Previously written with decimal commas, "2,5", which split into
    # separate 2 and 5 entries.)
    parser.add_argument('--layer_pow_targetv', type=float, nargs='+', default=[2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,0.1],
                        help="List of layer power target values")
    parser.add_argument('--act_pow_lr_ffv', type=float, nargs='+', default=[0.002,0.002,0.002,0.002,0.002,0.002,0.002,0.002,0.005,1e-10],
                        help="List of activation power learning rates for feedforward")
    parser.add_argument('--lr_erv', type=float, nargs='+', default=[9.6, 9.6, 9.6, 9.6,9.6,9.6, 9.6, 9.6, 6.0, 10000.0],
                        help="List of lr_erv values")
    parser.add_argument('--lr_er2v', type=float, nargs='+', default=[0,0,0,0,0,0,0,0,0,0],
                        help="List of lr_er2v values")
    parser.add_argument('--br_update_lateralv', type=float, default=0.5/15,
                        help="BR update lateral value used in the first epoch")
    parser.add_argument('--br_update_lateralv_cont', type=float, default=0.00333333333,
                        help="BR update lateral value used after the first epoch")
    parser.add_argument('--lambda_list', type=float, nargs='+', default=[0.99999995],
                        help="List of lambda values")
    parser.add_argument('--lr_frvv', type=float, default=1e-18,
                        help="Learning rate for prediction filter")
    parser.add_argument('--L1_DIVIDE', type=float, default=1e4,
                        help="L1_DIVIDE value")
    parser.add_argument('--lr_weight_l1vv', type=float, nargs='+', default=[0,0,0,0,0,0,0,0,0,0],
                        help="List of lr weight l1 values")
    parser.add_argument('--lr_weight_fb_l1vv', type=float, nargs='+',default=[0,0,0,0,0,0,0,0,0,0],
                        help="List of lr weight fb l1 values")
    parser.add_argument('--lr_weight_l2v', type=float, nargs='+', default=[0,0,0,0,0,0,0,0,0,0],
                        help="List of lr weight ff l2 values")
    parser.add_argument('--lr_weight_fb_l2v', type=float, nargs='+', default=[0,0,0,0,0,0,0,0,0,0],
                        help="List of lr weight fb l2 values")
    parser.add_argument('--lr_weight_bb_l2v', type=float, nargs='+', default=[0,0,0,0,0,0,0,0,0,0],
                        help="List of lr weight bb l2 values")
    parser.add_argument('--weight_decayv', type=lambda x: (str(x).lower() in ['true','1','yes']),
                        default=True, help="Weight decay flag (True/False)")
    parser.add_argument('--L2_SCALE', type=float, default=4.0,
                        help="L2_SCALE value")
    parser.add_argument('--architecture', type=int, nargs='+', default=[32 * 32 * 3, 4000, 500, 500, 500, 500, 500, 500, 500, 500, 10],
                        help="Layer sizes, input layer first")
    parser.add_argument('--CNT_SCALE', type=float, default=3,
                        help="CNT_SCALE value")
    parser.add_argument('--pickle_name_for_results', type=str, default="CIMEBD1.pkl", help="Output filename")
    parser.add_argument('--train_start', type=float, nargs='+', default=[9,8,7,6,5,4,3,2,1,0,0],
                        help="List of per-layer train_start values passed to the model")
    parser.add_argument('--load_network_weights', type=int, default=0,
                        help="Load network weights from file")
    parser.add_argument('--pickle_name_for_weights', type=str, default="TestC42param.pkl", help="Weights filename")
    parser.add_argument('--weight_prune_period', type=int, default=25000,
                        help="Prune small weights every this many training steps")
    parser.add_argument('--weight_prune_scale', type=float, default=0.0001,
                        help="weight_prune_scale")

    args = parser.parse_args()

    run_experiment(
        lambda_eb_list=args.lambda_eb_list,
        lambda_eb2_=args.lambda_eb2_,
        beta=args.beta,
        EPS_DIV=args.EPS_DIV,
        SCALE_FB=args.SCALE_FB,
        lr_decay_multiplier_list=args.lr_decay_multiplier_list,
        neural_lr_start_list=args.neural_lr_start_list,
        neural_lr_decay_multiplier=args.neural_lr_decay_multiplier,
        neural_dynamic_iterations_nudged=args.neural_dynamic_iterations_nudged,
        neural_dynamic_iterations_free_list=args.neural_dynamic_iterations_free_list,
        hopfield_g_list=args.hopfield_g_list,
        batchsize=args.batchsize,
        n_epochs=args.n_epochs,
        seed_list=args.seed_list,
        lateral_init_scalev=args.lateral_init_scalev,
        lateral_scalev=args.lateral_scalev,
        Wff_init_scalev=args.Wff_init_scalev,
        subt_meanv=args.subt_meanv,
        include_backv=args.include_backv,
        momentum_ffv=args.momentum_ffv,
        momentum_fbv=args.momentum_fbv,
        act_l1_lr_ffv=args.act_l1_lr_ffv,
        layer_pow_targetv=args.layer_pow_targetv,
        act_pow_lr_ffv=args.act_pow_lr_ffv,
        lr_erv=args.lr_erv,
        lr_er2v=args.lr_er2v,
        br_update_lateralv=args.br_update_lateralv,
        br_update_lateralv_cont=args.br_update_lateralv_cont,
        lambda_list=args.lambda_list,
        lr_frvv=args.lr_frvv,
        L1_DIVIDE=args.L1_DIVIDE,
        lr_weight_l1vv=args.lr_weight_l1vv,
        lr_weight_fb_l1vv=args.lr_weight_fb_l1vv,
        lr_weight_bb_l2=args.lr_weight_bb_l2v,
        lr_weight_l2=args.lr_weight_l2v,
        lr_weight_fb_l2=args.lr_weight_fb_l2v,
        weight_decayv=args.weight_decayv,
        L2_SCALE=args.L2_SCALE,
        CNT_SCALE=args.CNT_SCALE,
        pickle_name_for_results=args.pickle_name_for_results,
        train_startv=args.train_start,
        load_network_weightsv=args.load_network_weights,
        pickle_name_for_weightsv=args.pickle_name_for_weights,
        # Previously parsed but never forwarded, so --bias_include was ignored.
        bias_include=args.bias_include,
        weight_prune_period=args.weight_prune_period,
        architecture=args.architecture,
        weight_prune_scale=args.weight_prune_scale
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
