import sys

sys.path.append('../../TINY/')
import TINY
import UTILS
import math
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import copy
import load_data_Loader
import GLOBALS
import SOLVE_EB as EB
from define_devices import my_device_0, my_device_1

from settings import starting_architecture_RN18 as starting_architecture

# --- Experiment hyper-parameters -------------------------------------------
# Stored on the project-wide GLOBALS module; several of them are copied into
# the model configuration dictionary further down in this script.
GLOBALS.reduction = 64
GLOBALS.rescale = 'theta'
GLOBALS.lr = 1e-2
GLOBALS.batch_size = 32
GLOBALS.architecture_growth = 'Our'
GLOBALS.lambda_method = 1e-3
# Training budget handed to RN.train_batch after each growth step
# (presumably a fraction of an epoch -- TODO confirm against TINY.train_batch).
nbr_epochs_betw_adding = 0.25

# --- Skip connections of the starting architecture --------------------------
skip_connections = starting_architecture.skip_connections_RN18
for (i, j) in zip(skip_connections['in'], skip_connections['out']):
    print('skip connection between layer', i, ' and ', j)

# The bare `skip_fct.keys()` expression that used to follow this assignment
# was a no-op outside a notebook cell and has been removed.
skip_fct = starting_architecture.skip_fct_RN18

# --- Layer table of the starting architecture -------------------------------
skeleton = starting_architecture.skeleton_RN18
layer_name = starting_architecture.layer_name_RN18

print('Reduction is :', GLOBALS.reduction)
print('depth  |  layer type | in_chan->out_chan / output size')
# The smallest key of `skeleton` is skipped, as in the original listing.
# Layers whose name starts with 'C' are reported as convolutions, every other
# layer as linear. `sep=''` reproduces the original string concatenation.
for k in sorted(skeleton)[1:]:
    if layer_name[k][0] == 'C':
        print('depth ', k, ' |  Conv | ', skeleton[k]['in_channels'], ' -> ',
              skeleton[k]['out_channels'], sep='')
    else:
        print('depth ', k, ' | Linear | ', skeleton[k]['size'], sep='')

# --- Per-depth functions ----------------------------------------------------
fct = starting_architecture.fct_RN18
for k in sorted(fct):
    print('fct_', k, ' := ', fct[k])

# --- Width settings exposed by the starting-architecture module -------------
# Three families of values are read: (a, b, c, d), the "0" variants and the
# "inf" variants. The "inf" - "0" differences are used below in this script to
# build the growth budget.
a, b, c, d = (starting_architecture.a, starting_architecture.b,
              starting_architecture.c, starting_architecture.d)
a0, b0, c0, d0 = (starting_architecture.a0, starting_architecture.b0,
                  starting_architecture.c0, starting_architecture.d0)
ainf, binf, cinf, dinf = (starting_architecture.ainf, starting_architecture.binf,
                          starting_architecture.cinf, starting_architecture.dinf)

print('architecture growth from ', (a0, b0, c0, d0), 'to', (a, b, c, d))

# Growth budget per depth: for each listed depth, the difference between the
# "inf" and the "0" width of its stage. Insertion order (2, 4, 7, 9, 12, 14,
# 17, 19) matters because the growth loop walks the keys in that order.
to_add = {
    depth: target_width - initial_width
    for stage_depths, target_width, initial_width in (
        ((2, 4), ainf, a0),
        ((7, 9), binf, b0),
        ((12, 14), cinf, c0),
        ((17, 19), dinf, d0),
    )
    for depth in stage_depths
}

# Configuration dictionary handed to the TINY constructor. The meaning of the
# individual entries is defined by TINY itself and is not visible here.
dico_parameters = {
    # Architecture description (the skeleton is deep-copied so the model
    # receives its own copy of the structure).
    'skeleton': copy.deepcopy(skeleton),
    'Loss': UTILS.Loss_entropy,
    'fct': fct,
    'layer_name': layer_name,
    # Optimisation / growth hyper-parameters.
    'init_deplacement': 1e-8,
    'batch_size': GLOBALS.batch_size,
    'lr': GLOBALS.lr,
    'rescale': GLOBALS.rescale,
    'exp': 2,
    'lambda_method': GLOBALS.lambda_method,
    'accroissement_decay': 1e-8,
    'lu_conv': 0.001,
    'max_batch_estimation': 100,
    'max_amplitude': 20.,
    'ind_lmbda_shape': 1000,
    # Input shape; presumably (channels, height, width) -- TODO confirm.
    'init_X_shape': [3, 32, 32],
    'skip_connections': skip_connections,
    'skip_fcts': skip_fct,
    # Dataset sizes are hard-coded rather than read from the loaded data.
    'len_train_dataset': 50000,
    'len_test_dataset': 10000,
    # Depths of the growth budget whose layer name starts with 'C'.
    'T_j_depth': [depth for depth in to_add if layer_name[depth][0] == 'C'],
    'selection_neuron': UTILS.selection_neuron_seuil,
    'how_to_define_batchsize': UTILS.indices_non_constant,
    'depth_seuil': starting_architecture.depth_seuil,
    'architecture_growth': GLOBALS.architecture_growth,
}
RN = TINY.TINY(dico_parameters)

# Attach the CIFAR-100 datasets and build the initial loaders, both shuffled
# and both batched with the model's `max_batch_estimation`.
RN.training_data, RN.test_data = load_data_Loader.load_database_CIFAR100(AugD=True)
loader_kwargs = dict(batch_size=RN.max_batch_estimation, shuffle=True)
RN.tr_loader = DataLoader(RN.training_data, **loader_kwargs)
RN.te_loader = DataLoader(RN.test_data, **loader_kwargs)

# Table of per-depth output sizes as recorded by the model. The smallest key
# of `skeleton` is skipped; layers whose name starts with 'C' are labelled as
# convolutions, all others as linear.
print('depth  |  layer type | output size')
for k in sorted(skeleton)[1:]:
    kind_label = ' |  Conv | ' if layer_name[k][0] == 'C' else ' | Linear | '
    print('depth ', k, kind_label, RN.outputs_size_after_activation[k], sep='')

# Containers for experiment tracking: a (still empty) DataFrame, four empty
# arrays for accuracies / losses on train and test, and a time-like array
# that starts at 0.
df_tracker = pd.DataFrame()
A_tr = np.array([])
A_te = np.array([])
L_tr = np.array([])
L_te = np.array([])
T = np.array([0])

# Number of passes over the growth budget.
nbr_pass = 8

# Directory prefix used when saving results.
path = 'resultats/'


def stabilize_training(model=None, nbr_parameters_before=None):
    """Rescale the model's batch size after its parameter count has changed.

    The new batch size is ``ceil(sqrt(current / before) * batch_size)``, where
    ``current`` is ``model.count_parameters()`` at call time.

    Args:
        model: Object exposing ``count_parameters()`` and a writable
            ``batch_size`` attribute. Defaults to the module-level ``RN`` so
            that the historical no-argument call keeps working.
        nbr_parameters_before: Parameter count measured before the growth
            step. Defaults to the module-level ``nbr_parameters_avant``.

    Returns:
        None. ``model.batch_size`` is updated in place.
    """
    # Fall back to the script-level globals only when nothing is passed in;
    # they are resolved at call time, exactly as the original body did.
    if model is None:
        model = RN
    if nbr_parameters_before is None:
        nbr_parameters_before = nbr_parameters_avant

    growth_ratio = model.count_parameters() / nbr_parameters_before
    model.batch_size = math.ceil(np.sqrt(growth_ratio) * model.batch_size)


# Number of passes over the growth budget actually used by the loop below;
# this assignment replaces any value given to `nbr_pass` earlier in the script.
nbr_pass = 1
# Flag set on the model; its effect is implemented inside TINY and is not
# visible from this script.
RN.force_small_estimation_batch = True


# Index used to name the saved model file ('model_<count>').
count = 1
# Outer loop: one pass over the growth budget, then one model checkpoint.
for pas in range(nbr_pass):
    # One growth + training step per entry of `to_add`.
    size_to_add = len(to_add.keys())
    # size_to_add = 1
    for j in tqdm(range(size_to_add)):
        # Parameter count before growing; read as a global by
        # stabilize_training() to rescale the batch size afterwards.
        nbr_parameters_avant = RN.count_parameters()
        # Strategy 1 (disabled, kept for reference): evaluate every candidate
        # depth and grow the one ranked first.
        ########### Select the best depth (time consuming)############
        # gc.collect()
        # torch.cuda.empty_cache()
        # RN.tr_loader = DataLoader(RN.training_data, batch_size=RN.max_batch_estimation, shuffle=True)
        # RN.te_loader = DataLoader(RN.test_data, batch_size=RN.max_batch_estimation, shuffle=True)

        # depth_ajout = list(to_add.keys())
        # depth_in_decreasing_criterion, dico_EB = EB.where_is_EB_best_solved(RN, depths = depth_ajout)

        # best_depth = depth_in_decreasing_criterion[0]
        # dico_EB_bd = dico_EB[best_depth]
        # alpha, omega, bias_alpha, vps = dico_EB_bd['alpha'], dico_EB_bd['omega'], dico_EB_bd['bias_alpha'], dico_EB_bd['vps']
        # lambda_method = dico_EB_bd['beta_min']

        # RN.alpha, RN.omega, RN.bias_alpha, RN.valeurs_propres = None, None, None, []
        # RN.TAB_Add = None

        # if dico_EB_bd['accroissement'] > 0 :
        #    RN.dico_w, RN.lambda_method = dico_EB_bd['dico_w'], dico_EB_bd['beta_min']
        #    EB.add_neurons(RN, best_depth, alpha = alpha, omega = omega, bias_alpha = bias_alpha, valeurs_propres = vps)
        #    RN.lambda_method = torch.tensor(0., device = my_device_0)

        #    to_add[best_depth] -= alpha.shape[0]
        #    if to_add[best_depth] == 0 :
        #        del to_add[best_depth]

        # Strategy 2 (active): grow the depths in the insertion order of
        # `to_add`, one depth per inner iteration.
        ######## Or just add neurons in depth order (fast) ######
        k = list(to_add.keys())[j]
        # Placeholders with the names used by strategy 1; they are not read
        # again in the visible part of this script.
        best_depth, dico_EB = k, {k: {'accroissement': 0., 'portion_gain': 0}}
        RN.dico_w = None
        # The call order below is significant: first the batch-size helper and
        # compute_optimal_update are run for depth k + 1 (update=False, so
        # presumably nothing is applied to the weights -- TODO confirm in
        # SOLVE_EB), then the batch-size helper and add_neurons for depth k.
        RN.how_to_define_batchsize(RN, k + 1, method='NG')
        EB.compute_optimal_update(RN, k + 1, update=False, compute_gain=False)
        RN.how_to_define_batchsize(RN, k, method='Add')
        EB.add_neurons(RN, k, update=True)
        # Decrease the remaining budget by the first dimension of RN.alpha
        # (assumes add_neurons stored the added neurons there -- TODO confirm).
        # Unlike strategy 1, exhausted depths are not removed from `to_add`.
        to_add[k] -= RN.alpha.shape[0]

        #### training loop ####
        # Rescale RN.batch_size according to the parameter growth, then
        # rebuild both loaders with the new batch size.
        stabilize_training()
        RN.tr_loader = DataLoader(RN.training_data, batch_size=RN.batch_size, shuffle=True)
        RN.te_loader = DataLoader(RN.test_data, batch_size=RN.batch_size, shuffle=True)
        # A fresh SGD optimizer is created at every step so that it is built
        # from the model's current parameters.
        optimizer = torch.optim.SGD(RN.parameters(), lr=RN.lr)
        # The seven returned values are not used further in the visible code
        # because the tracking update below is disabled.
        l_tr, l_te, l_va, a_tr, a_te, a_va, t = RN.train_batch(nbr_epochs_betw_adding, optimizer=optimizer)
        # update_quantity_of_interest()

        # df_performance = pd.DataFrame.from_dict({'L_tr': L_tr, 'L_te': L_te, 'A_tr': A_tr, 'A_te': A_te, 'T': T[1:]})
        # df_tracker.to_csv(path + 'df_tracker.csv', index=False)
        # df_performance.to_csv(path + '/df_performance.csv', index=False)
        # NOTE(review): T and L_tr are not modified anywhere in the visible
        # code while the tracking update is commented out, so these two
        # attributes keep receiving the initial values.
        RN.T = T[-1]
        RN.len_L_tr = len(L_tr)

    # Checkpoint once per pass, after all growth steps of that pass.
    # NOTE(review): '/' is appended to `path`; if `path` already ends with a
    # slash the resulting directory string contains a double slash.
    UTILS.save_model_to_file(RN, path=path + '/', name='model_' + str(count))
    count += 1
