import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
import pandas as pd

from tqdm import tqdm
import glob
import os
from datetime import datetime
from itertools import product
import time
import math
import sys
sys.path.append("./src")
from BP_EBDCorInfoMaxModelsBackward11 import EBDCorInfoMaxHopfield 
from torch_utilsn import *

import warnings
warnings.filterwarnings("ignore")

from IPython.core.debugger import Pdb
import sys
sys.path.append("./src")
from IPython.display import clear_output


def my_clip(M,clevel,msg=''):
    """Clip every element of tensor ``M`` into ``[-clevel, clevel]``.

    Parameters
    ----------
    M : torch.Tensor
        Values to clip.
    clevel : float
        Clipping level (expected to be non-negative).
    msg : str, optional
        Printed whenever at least one element of ``M`` exceeds ``clevel`` in
        magnitude (an empty line is printed when ``msg`` is left empty).

    Returns
    -------
    torch.Tensor
        A new tensor holding the clipped values; ``M`` itself is not modified.
    """
    if torch.any(torch.abs(M) > clevel):
        print(msg)
    # torch.clamp replaces the former masked-multiply formulation, which relied
    # on a bare `sign` name and turned +/-inf into NaN (inf * 0 == nan).
    return torch.clamp(M, -clevel, clevel)

def update_list(lr, cnt2, period=3):
    """Return a new list equal to ``lr`` with all entries but one set to ``0``.

    The surviving position is ``cnt2 % period``, so as ``cnt2`` increases the
    kept entry cycles through positions ``0 .. period-1``.

    Parameters
    ----------
    lr : sequence
        Values to mask (list, tuple or 1-D array).
    cnt2 : int
        Counter selecting which position to keep.
    period : int, optional
        Length of the cycle.  Defaults to 3 (the previously hard-coded value);
        pass ``len(lr)`` to cycle through every entry of ``lr``.

    Returns
    -------
    list
        ``lr[k]`` at position ``k = cnt2 % period`` and ``0`` everywhere else.
        If ``k >= len(lr)`` every entry is ``0``.
    """
    keep = cnt2 % period
    return [value if i == keep else 0 for i, value in enumerate(lr)]



device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# Adjust file/directory organization.  working_path currently equals the
# current directory, so the chdir below is a no-op unless working_path is edited.
current_directory = os.getcwd()
working_path = current_directory
os.chdir(working_path)

# Create the results directory if it is missing; exist_ok avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs("../Results", exist_ok=True)

pickle_name_for_results = "BPMNISTAseed6batch1New2.pkl"

# Accumulates one row per (setting, seed, epoch) during training.
RESULTS_DF = pd.DataFrame( columns = ['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list', 'Tst_ACC_list'])

# Normalize(mean=0, std=1) leaves the ToTensor output unchanged.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,))])

# Training: batch size 1, shuffled.  Test: batches of 20, fixed order.
mnist_dset_train = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(mnist_dset_train, batch_size=1, shuffle=True, num_workers=0)

mnist_dset_test = torchvision.datasets.MNIST('data', train=False, transform=transform, target_transform=None, download=True)
test_loader = torch.utils.data.DataLoader(mnist_dset_test, batch_size=20, shuffle=False, num_workers=0)


# hard_sigmoid is not defined in this file -- presumably provided by the
# `from torch_utilsn import *` star import; confirm.
activation = hard_sigmoid
# 784 = 28x28 MNIST pixels, two hidden layers of 500 units, 10 output classes.
architecture = [784, 500, 500, 10]


############# HYPERPARAMS GRID SEARCH LISTS #########################
# The *_list names passed to itertools.product in the training loop are the
# grid-search axes; seed_list gives the seeds run for every setting; all other
# names are fixed for the whole run.  Where a name used to be assigned twice,
# only the effective (last) value is kept and the earlier value is noted.
beta = 1

lambda_eb_list = [0.99999]
EPS_DIV = 100000
epsilon = 0.15/EPS_DIV
one_over_epsilon = 1 / epsilon

SCALE_FB = 1.0
# Initial per-layer learning rates; 'fb'[0] is NaN.
lr_start_list = [{'ff': np.array([0.11, 0.06, 0.035]), 'fb': np.array([ np.nan, 1.125*SCALE_FB, 0.375*SCALE_FB])}]

lr_decay_multiplier_list = [0.95]
neural_lr_start_list = [0.05/EPS_DIV]
neural_lr_stop = 0.001/EPS_DIV
neural_lr_rule_list = ["constant"]
neural_lr_decay_multiplier = 0.01
neural_dynamic_iterations_nudged = 10
neural_dynamic_iterations_free_list = [50]  # previously [30]
hopfield_g_list = [0.1]
use_random_sign_beta = False
use_three_phase_list = [False]
setting_number = 0
# Divides the running train-accuracy estimate; keep equal to train_loader's batch_size.
batchsize = 1
n_epochs = 250
seed_list = [10*(j+6) for j in range(10)]  # 60, 70, ..., 150


lateral_init_scalev = 1.0
lateral_scalev = [1.0, 1.0, 1.0]
Wff_init_scalev = [1.0, 1.0, 1.0]
subt_meanv = 0
include_backv = 1.0
momentum_ffv = 0.9999  # previously 0.999
momentum_fbv = 0.9999  # previously 0.999

#######################################################
#      ACTIVATION SPARSITY                         ####
#######################################################
act_l1_lr_ffv = [7e-2*3*1.5/20, 3e-2*3*1.5/20, 0]  # previously without the /20
# Feedback rates are the feedforward ones shifted by one layer.
act_l1_lr_fbv = [0, act_l1_lr_ffv[0], act_l1_lr_ffv[1]]

#######################################################
#     POWER NORMALIZATION                            ##
#######################################################
# NOTE(review): this list has 4 entries while the other per-layer lists have 3;
# "2,5" looks like a typo for "2.5" (i.e. [2.5, 2.5, 0.1]).  Kept unchanged to
# preserve results -- confirm how the model indexes it.
layer_pow_targetv = [2.5, 2, 5, 0.1]
act_pow_lr_ffv = [4e-2/20, 1e-1/20, 1e-10]  # previously [4e-2, 1e-1, 1e-10]

########################################################
# THREE FACTOR LEARNING     LRs                        #
########################################################
# previously [40*2, 25*2, 1e5]; A1f [0]*2 [1]*2, A1g additional [0]*2 [1]*2, A1h removed *2's in g
lr_erv = [40*2/20, 25*2/20, 1e5/20]
lr_er2v = [0, 0, 0]
# previously 0.0005/100*100*10*10*10 (QW4 *100, QW5 /100, QW7 *100, A1d/A1e/A1f *10 each)
br_update_lateralv = 0.5/20

#########################################################
#   LATERAL COVARIANCE INVERSE UPDATE                  #
########################################################
lambda_list = [0.99999995]  # history: [0.9999] -> [0.999999] (QW3) -> [0.99999995]

###############################
# PREDICTION FILTER LR #######
###############################
lr_frvv = 1e-6/1e2/1e10  # A1h: made very small by the extra /1e10


############################################
# WEIGHT SPARSITY ###############
##################################
L1_DIVIDE = 1e4
#lr_weight_l1vv=[3e-4/L1_DIVIDE/2/5/4,1e-4*10/L1_DIVIDE*4*2*5*100,1e-4/L1_DIVIDE]
#lr_weight_fb_l1vv=[0,3e-5/50,1e-5/5]
lr_weight_l1vv = [0, 0, 0]
lr_weight_fb_l1vv = [0, 0, 0]

#######################################
#      WEIGHT DECAY                  #
######################################
weight_decayv = True
L2_SCALE = 8.0/20  # previously 8.0
lr_weight_l2vv = [1e-4*L2_SCALE, 1e-4*L2_SCALE, 1e-4*L2_SCALE]
lr_weight_fb_l2vv = [0, 1e-4*L2_SCALE, 1e-4*L2_SCALE]

###########################################
#       GRADIENT CLIPPING                ##
###########################################
# NOTE(review): none of these clip lists is passed to the model in this script.
#ff_grad_clipv=[0.01,0.01,0.01]
#fb_grad_clipv=[0.01,0.01,0.01]
#bb_grad_clipv=[0.01,0.01,0.01]
ff_grad_clipv = [10, 10, 10]
fb_grad_clipv = [10, 10, 10]
bb_grad_clipv = [10, 10, 10]
ff_weight_clipv = [100, 100, 100]
fb_weight_clipv = [100, 100, 100]
bb_weight_clipv = [100, 100, 100]


# Sparsity bookkeeping: per-layer fractions and running histories that start at
# 0 and are appended to during training (activation histories are EMAs).
spw = [0, 0, 0]
spw_fb = [0, 0, 0]
act_sp = [0, 0, 0]
act_sp_list0 = [0]
act_sp_list1 = [0]
act_sp_list2 = [0]
wff_sp_list0 = [0]
wff_sp_list1 = [0]
wff_sp_list2 = [0]
wfb_sp_list0 = [0]
wfb_sp_list1 = [0]



# Model checkpoint path: same directory and stem as the results pickle.
# NOTE: a single path, so each seed's checkpoint overwrites the previous one.
model_save_path = os.path.join("../Results", os.path.splitext(pickle_name_for_results)[0] + ".pth")

# Learning-rate decay scale applied per `cnt2` step.  It is constant, so it is
# set once here instead of on every training iteration.
# History: A1a 5*4/20*5, A1b 5*4/20, A1c 5*4/20*2, A1d *3
CNT_SCALE = 5*4/20*3
# Pruning period and magnitude-threshold divisors; only referenced by the
# disabled pruning code kept below as a string literal.
Freq = 19
DIV_SCALE = [12*L1_DIVIDE/40*5, 10*L1_DIVIDE/20*15, 15*L1_DIVIDE]

# Grid search: one "setting" per combination of the *_list hyperparameters,
# each trained once for every seed in seed_list.
for (lambda_, lambda_eb_, lr_start, lr_decay_multiplier, neural_lr_start, neural_lr_rule,
     neural_dynamic_iterations_free, hopfield_g, use_three_phase) in product(
        lambda_list, lambda_eb_list, lr_start_list, lr_decay_multiplier_list,
        neural_lr_start_list, neural_lr_rule_list, neural_dynamic_iterations_free_list,
        hopfield_g_list, use_three_phase_list):
    setting_number += 1
    hyperparams_dict = {"lr_start" : lr_start, "lr_decay_multiplier" : lr_decay_multiplier,
                        "neural_dynamic_iterations_free" : neural_dynamic_iterations_free,
                        "neural_dynamic_iterations_nudged" : neural_dynamic_iterations_nudged,
                        "neural_lr_rule" : neural_lr_rule, "neural_lr" : neural_lr_start,
                        "epsilon" : epsilon, "lambda" : lambda_,"lambda_eb": lambda_eb_,
                        "architecture" : architecture,
                        "three_phase" : use_three_phase}
    for seed_ in seed_list:
        np.random.seed(seed_)
        torch.manual_seed(seed_)
        trn_acc_est = []
        tr_ar = 0  # exponential moving average of per-batch training accuracy
        trn_acc_list = []
        tst_acc_list = []

        # Train / test
        model = EBDCorInfoMaxHopfield(architecture = architecture, lambda_ = lambda_, lambda_eb_=lambda_eb_,
                                      epsilon = epsilon, lr_fr=lr_frvv, lr_er=lr_erv, lr_er2=lr_er2v,
                                      include_forw=1, include_back=include_backv,
                                      use_preact=0, subt_mean=subt_meanv, br_update_lateral=br_update_lateralv,
                                      activation = activation,
                                      act_l1_lr_ff=act_l1_lr_ffv, act_l1_lr_fb=act_l1_lr_fbv,
                                      momentum_ff=momentum_ffv, momentum_fb=momentum_fbv,
                                      lateral_init_scale=lateral_init_scalev,
                                      lateral_scale=lateral_scalev, Wff_init_scale=Wff_init_scalev,
                                      layer_pow_target=layer_pow_targetv, act_pow_lr_ff=act_pow_lr_ffv)
        debug_iteration_point = 1
        cnt = 0   # training samples seen for this seed
        cnt2 = 0  # incremented every 200 samples; drives the decay schedules below
        for epoch_ in range(n_epochs):

            # Re-randomise args_w / args_ph at the start of every epoch (and
            # again every 100 samples below).
            model.args_w = 100*torch.rand(1).item() + 0.5
            model.args_ph = 3.14*torch.rand(1).item()
            model.non_fun_der_scale = 1.0
            # tqdm wraps the DataLoader itself (not the enumerate object) so the
            # progress bar knows the total number of batches.
            for idx, (x, y) in enumerate(tqdm(train_loader)):
                cnt += 1
                if cnt % 200 == 199:  # previously every 10 / 500 samples
                    cnt2 += 1
                    #model.momentum_ff=0.3+0.65*torch.rand(1).item()

                if cnt % 100 == 99:
                    model.args_w = 100*torch.rand(1).item() + 0.5
                    model.args_ph = 3.14*torch.rand(1).item()
                    model.non_fun_der_scale = 1.0
                x, y = x.to(device), y.to(device)
                # Flatten each image and transpose so that columns are samples.
                x = x.view(x.size(0), -1).T
                y_one_hot = F.one_hot(y, 10).to(device).T
                take_debug_logs_ = (idx % 500 == 0)
                # Momentum starts at 0.999 and approaches 0.9999 as cnt2 grows (A1b).
                model.momentum_ff = 1/(cnt2+1)*0.999 + (1-1/(cnt2+1))*0.9999
                model.momentum_fb = 1/(cnt2+1)*0.999 + (1-1/(cnt2+1))*0.9999
                # From epoch 400 on the schedules are frozen at their last values.
                if epoch_ < 400:
                    # lr['ff'] / lr['fb'] decay as 1/(CNT_SCALE*cnt2 + 1).
                    lr = {'ff' : lr_start['ff']/(CNT_SCALE*cnt2+1), 'fb' : lr_start['fb']/(CNT_SCALE*cnt2+1)}
                    lr_weight_l1v = [val/(CNT_SCALE*cnt2/30*30+1) for val in lr_weight_l1vv]  # QW10 *30
                    lr_weight_fb_l1v = [val/(CNT_SCALE*cnt2/30*30+1) for val in lr_weight_fb_l1vv]  # QW10 *30
                    lr_weight_l2v = [val/(CNT_SCALE*cnt2/30/10+1) for val in lr_weight_l2vv]
                    lr_weight_fb_l2v = [val/(CNT_SCALE*cnt2/30/10+1) for val in lr_weight_fb_l2vv]
                    model.br_update_lateral = br_update_lateralv*lr['ff'][0]/lr_start['ff'][0]/(1/(CNT_SCALE*cnt2/1e3+1))  # /np.sqrt(CNT_SCALE*cnt2/30+1) QW#9
                    # Q1 /10 #Q2 /1000 #Q3 /np->* np /1000->/10000
                    # QW1 1e5->1e6 # QW2 1e6->1e4 QW6 1e4->1e3 QW8 1e3->1e4 QW9 1e4->1e3
                    model.lr_er = [val/(1/(CNT_SCALE*cnt2/1e3+1)) for val in lr_erv]

                neurons = model.batch_step_hopfield(x, y_one_hot, hopfield_g,
                                                    lr, neural_lr_start, neural_lr_stop, neural_lr_rule,
                                                    neural_lr_decay_multiplier, neural_dynamic_iterations_free,
                                                    neural_dynamic_iterations_nudged, beta,
                                                    use_three_phase, take_debug_logs_, weight_decay = weight_decayv,
                                                    lr_weight_l1=lr_weight_l1v, lr_weight_fb_l1=lr_weight_fb_l1v,
                                                    lr_weight_l2=lr_weight_l2v, lr_weight_fb_l2=lr_weight_fb_l2v)
                q = torch.argmax(neurons[2], axis=0)
                tr_ar = 0.99*tr_ar + 0.01*torch.sum(1.0*(q == y)).item()/batchsize
                # EMAs of the fraction of exactly-zero entries in neurons[0..2].
                # All three read the list's last entry.  act_sp_list0 used to read
                # [cnt-1], which equals the last entry only for the first seed:
                # cnt restarts at 0 for every seed but these lists are never reset.
                act_sp_list0.append(0.99*act_sp_list0[-1]+0.01*(torch.sum(neurons[0]==0)/neurons[0].numel()).item())
                act_sp_list1.append(0.99*act_sp_list1[-1]+0.01*(torch.sum(neurons[1]==0)/neurons[1].numel()).item())
                act_sp_list2.append(0.99*act_sp_list2[-1]+0.01*(torch.sum(neurons[2]==0)/neurons[2].numel()).item())
                ''' # Version 5 commented out pruning
                if (np.mod(idx,Freq+1)==Freq):
                    W=model.Wff[0]['weight']
                    peak_magnitude = torch.max(torch.abs(W))
                    #Step 2: Calculate 1/10th of the peak magnitude
                    threshold = peak_magnitude /DIV_SCALE[0]*1.5
                    # Step 3: Create a mask for elements less than the threshold
                    mask = torch.abs(W) < threshold
                    #Step 4: Set elements below threshold to zero
                    model.Wff[0]['weight'][mask] = 0
                    spw[0]=torch.sum(model.Wff[0]['weight']==0)/(model.Wff[0]['weight']).numel()
                    wff_sp_list0.append(0.9*wff_sp_list0[-1]+0.1*spw[0].item())
                if (np.mod(idx,Freq+1)==Freq):
                    W=model.Wfb[1]['weight']
                    peak_magnitude = torch.max(torch.abs(W))
                    #Step 2: Calculate 1/10th of the peak magnitude
                    threshold = peak_magnitude /DIV_SCALE[0]/1.3
                    # Step 3: Create a mask for elements less than the threshold
                    mask = torch.abs(W) < threshold
                    #Step 4: Set elements below threshold to zero
                    model.Wfb[1]['weight'][mask] = 0
                    spw_fb[0]=torch.sum(model.Wfb[1]['weight']==0)/(model.Wfb[1]['weight']).numel()
                    wfb_sp_list0.append(0.9*wfb_sp_list0[-1]+0.1*spw_fb[0].item())
                if (np.mod(idx,Freq+1)==Freq):
                    W=model.Wfb[2]['weight']
                    peak_magnitude = torch.max(torch.abs(W))
                    #Step 2: Calculate 1/10th of the peak magnitude
                    threshold = peak_magnitude /DIV_SCALE[0]
                    # Step 3: Create a mask for elements less than the threshold
                    mask = torch.abs(W) < threshold
                    #Step 4: Set elements below threshold to zero
                    model.Wfb[2]['weight'][mask] = 0
                    spw_fb[1]=torch.sum(model.Wfb[2]['weight']==0)/(model.Wfb[2]['weight']).numel()
                    wfb_sp_list1.append(0.9*wfb_sp_list1[-1]+0.1*spw_fb[1].item())
                if (np.mod(idx,Freq+1)==Freq):
                    W=model.Wff[1]['weight']
                    peak_magnitude = torch.max(torch.abs(W))
                    #Step 2: Calculate 1/10th of the peak magnitude
                    threshold = peak_magnitude / DIV_SCALE[1]*8.0*2.0
                    # Step 3: Create a mask for elements less than the threshold
                    mask = torch.abs(W) < threshold
                    #Step 4: Set elements below threshold to zero
                    model.Wff[1]['weight'][mask] = 0
                    spw[1]=torch.sum(model.Wff[1]['weight']==0)/(model.Wff[1]['weight']).numel()
                    wff_sp_list1.append(0.9*wff_sp_list1[-1]+0.1*spw[1].item())
                if (np.mod(idx,Freq+1)==Freq):
                    W=model.Wff[2]['weight']
                    peak_magnitude = torch.max(torch.abs(W))
                    #Step 2: Calculate 1/10th of the peak magnitude
                    threshold = peak_magnitude / DIV_SCALE[2]*4.0
                    # Step 3: Create a mask for elements less than the threshold
                    mask = torch.abs(W) < threshold
                    #Step 4: Set elements below threshold to zero
                    model.Wff[2]['weight'][mask] = 0
                    spw[2]=torch.sum(model.Wff[2]['weight']==0)/(model.Wff[2]['weight']).numel()
                    wff_sp_list2.append(0.9*wff_sp_list2[-1]+0.1*spw[2].item())
                    '''

            # Full evaluation on both training and test sets after every epoch.
            trn_acc = evaluateEBDCorInfoMaxHopfield(model, train_loader, hopfield_g, neural_lr_start,
                                                            neural_lr_stop, neural_lr_rule,
                                                            neural_lr_decay_multiplier,
                                                            neural_dynamic_iterations_free,
                                                            device, printing = False)
            tst_acc = evaluateEBDCorInfoMaxHopfield(model, test_loader, hopfield_g, neural_lr_start,
                                                            neural_lr_stop, neural_lr_rule,
                                                            neural_lr_decay_multiplier,
                                                            neural_dynamic_iterations_free,
                                                            device, printing = False)

            trn_acc_list.append(trn_acc)
            tst_acc_list.append(tst_acc)
            # NOTE(review): these are references to the model's weight objects,
            # not copies.  If the model updates them in place, every stored row
            # ends up holding the final weights -- confirm (cloning would store
            # one full weight snapshot per epoch).
            Wff0 = model.Wff[0]['weight']
            Wff1 = model.Wff[1]['weight']
            Wff2 = model.Wff[2]['weight']
            Wfb0 = model.Wfb[0]['weight']
            Wfb1 = model.Wfb[1]['weight']
            Wfb2 = model.Wfb[2]['weight']
            Web0 = model.Web[0]['weight']
            Web1 = model.Web[1]['weight']
            Web2 = model.Web[2]['weight']
            B0 = model.B[0]['weight']
            B1 = model.B[1]['weight']
            B2 = model.B[2]['weight']

            Result_Dict = {"setting_number" : setting_number, "seed" : seed_, "Model" : "EBDv9",
                           "Hyperparams" : hyperparams_dict, "Trn_ACC_list" : trn_acc_list, "Tst_ACC_list" : tst_acc_list,
                           "Wff0":Wff0, "Wff1" : Wff1, "Wff2" : Wff2, "Wfb0": Wfb0, "Wfb1": Wfb1, "Wfb2": Wfb2,
                           "Web0": Web0, "Web1": Web1, "Web2": Web2, "B0": B0, "B1": B1, "B2": B2,
                           "neurons":neurons, "spw":spw,
                           "act_sp_list0":act_sp_list0, "act_sp_list1":act_sp_list1, "act_sp_list2":act_sp_list2,
                           "wff_sp_list0":wff_sp_list0, "wff_sp_list1":wff_sp_list1, "wff_sp_list2":wff_sp_list2,
                           "wfb_sp_list0":wfb_sp_list0, "wfb_sp_list1":wfb_sp_list1}

            # Results and the model are re-saved after every epoch.
            RESULTS_DF = pd.concat([RESULTS_DF, pd.DataFrame([Result_Dict])], ignore_index=True)
            RESULTS_DF.to_pickle(os.path.join("../Results", pickle_name_for_results))
            torch.save(model, model_save_path)

# Final save (also writes the pickle when no epoch ran).
RESULTS_DF.to_pickle(os.path.join("../Results", pickle_name_for_results))