#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import copy
import time
print(torch.__version__)
#from tensorflow.keras.datasets import mnist
# Module-level state shared by the functions and classes of this file.
# `activity` is read as activity[deep] by MyModel_linear_conv.compute_NG_conv /
# compute_NG_lin; nothing in the code visible here fills it — presumably it is
# populated elsewhere (e.g. during a forward pass). TODO confirm.
activity = {}
# NOTE(review): `range(1,  + 1)` evaluates to `range(1, 1)`, so this dict is
# always empty; the upper bound looks like a missing variable — confirm intent.
activity_out = {i : [] for i in range(1,  + 1)}

# Caches keyed by layer index: dico_T[deep] is written by
# MyModel_linear_conv.creation_T_j and dico_S[deep] by creation_T_j_T_T_J_S.
dico_T = {}
dico_S = {}
import copy  # duplicate of the import at the top of the file (harmless)

# Use the GPU when one is available. `device` (a torch.device) and `my_device`
# (the same choice as a string) are the default placement used throughout.
if torch.cuda.is_available():
    device = torch.device('cuda')
    my_device = 'cuda'
else:
    device = torch.device('cpu')
    my_device = 'cpu'
    
# Project-local module: call_load_data uses its load_database_MNIST /
# load_database_CIFAR10 functions and X_/Y_{train,test}_rescale attributes.
import load_data  


print(my_device, 'my_device')
    
def call_load_data(name, normalize = True) :
    """Load a dataset through ``load_data`` and publish it as module globals on ``device``.

    Parameters
    ----------
    name : str
        'MNIST' calls ``load_data.load_database_MNIST()`` and 'CIFAR' calls
        ``load_data.load_database_CIFAR10()``. Any other value loads nothing
        and reuses whatever tensors ``load_data`` currently exposes.
    normalize : bool
        When true, standardise every feature with the training-set mean and
        variance. Features whose variance is exactly 0 are divided by 1.

    Side effects
    ------------
    Rebinds the module globals ``X_train_rescale``, ``Y_train_rescale``,
    ``X_test_rescale`` and ``Y_test_rescale``.
    """
    global X_train_rescale, Y_train_rescale, X_test_rescale, Y_test_rescale

    if name == 'MNIST':
        load_data.load_database_MNIST()
    elif name == 'CIFAR':
        load_data.load_database_CIFAR10()

    # Grab the CPU tensors first, then move all four to the selected device.
    cpu_tensors = (load_data.X_train_rescale, load_data.Y_train_rescale,
                   load_data.X_test_rescale, load_data.Y_test_rescale)
    X_train_rescale, Y_train_rescale, X_test_rescale, Y_test_rescale = [t.to(device) for t in cpu_tensors]

    if normalize:
        # Statistics come from the training split only and are reused for the test split.
        feature_mean = X_train_rescale.mean(dim = 0)
        feature_var = X_train_rescale.var(dim = 0)
        # Adding 1 where the variance is 0 keeps constant features finite (divide by 1).
        is_constant = (feature_var == 0).to(device)
        feature_std = torch.sqrt(feature_var + is_constant.float())
        X_train_rescale = (X_train_rescale - feature_mean) / feature_std
        X_test_rescale = (X_test_rescale - feature_mean) / feature_std
    
def calculate_accuracy(y_pred, y_true) :
    """Return the share of rows whose arg-max index agrees between ``y_pred`` and ``y_true``.

    Both inputs are compared along ``dim=1`` (e.g. scores vs. one-hot targets).
    The value is computed on the tensors' device and returned as a Python float.
    """
    predicted_labels = torch.max(y_pred, dim = 1).indices
    target_labels = torch.max(y_true, dim = 1).indices
    hits = (predicted_labels == target_labels).int()
    return (hits.sum() / len(hits)).cpu().item()



def permute_sparse(input, dims):
    """Return a sparse COO tensor equal to ``input`` with its dimensions reordered by ``dims``.

    ``dims`` lists, for each output dimension, which input dimension it comes from
    (same convention as ``Tensor.permute``). The result reuses ``input``'s values;
    its index rows are reordered accordingly.
    """
    order = torch.as_tensor(dims, dtype = torch.long)
    old_shape = input.size()
    new_shape = torch.Size([old_shape[d] for d in order.tolist()])
    permuted_indices = input._indices()[order]
    return torch.sparse_coo_tensor(indices=permuted_indices, values=input._values(), size=new_shape)


def inv_APPLY(GTG, DVTV, precision = 1e-2) :
    """Approximately solve ``Y @ GTG = DVTV`` for ``Y`` with L-BFGS.

    ``Y`` starts from a standard-normal draw of shape
    ``(DVTV.shape[0], GTG.shape[0])`` on ``my_device``. The mean squared residual
    is minimised for ``rows * cols`` optimiser steps, and progress is printed
    every 100 steps. ``precision`` is currently unused. Returns the optimised
    leaf tensor ``Y``, which still has ``requires_grad=True``.
    """
    n_rows, n_cols = DVTV.shape[0], GTG.shape[0]
    solution = torch.randn((n_rows, n_cols), device = my_device)
    solution.requires_grad = True

    optimizer = torch.optim.LBFGS([solution], lr=1)

    def closure():
        optimizer.zero_grad()
        residual = torch.matmul(solution, GTG) - DVTV
        loss = (residual ** 2).mean()
        loss.backward()
        return loss

    for step in range(n_rows * n_cols):
        if step % 100 == 0:
            print(step)
            print(closure())
        optimizer.step(closure)

    print(closure(), 'dernier_step')
    return solution

class MyModel_copy_L_L(torch.nn.Module) :
    """Evaluate ``model_to_copy`` with K extra neurons added to linear layer ``deep_lmbda``.

    In ``forward(x, lmbda, deep_lmbda)`` the extra neurons compute
    ``fct[deep_lmbda](lmbda * (x @ alpha.T + bias_alpha))`` from the input of layer
    ``deep_lmbda``. Their outputs are fed into layer ``deep_lmbda + 1`` through the
    extra weight columns ``omega``. alpha, omega and the concatenated weights are
    wrapped in ``Parameter(..., requires_grad=False)``, so no gradient reaches them
    through those modules.

    Parameters
    ----------
    model_to_copy : source model (init_structure, deep, Loss, layer_name, layer, where,
        decrease, init_deplacement, fct, break_conv, outputs_size_before_activation).
    alpha : (K, in_features of layer deep_lmbda) weight of the extra neurons.
    omega : (out_features of layer deep_lmbda + 1, K); concatenated along dim 1 with
        that layer's weight.
    bias_alpha : (K,) bias of the extra neurons.
    """

    def __init__(self, model_to_copy, alpha, omega, bias_alpha) :
        super(MyModel_copy_L_L, self).__init__()
        
        self.init_structure = copy.deepcopy(model_to_copy.init_structure)
        self.deep = copy.deepcopy(model_to_copy.deep)
        self.Loss = copy.deepcopy(model_to_copy.Loss)
        self.layer_name = copy.deepcopy(model_to_copy.layer_name)
        self.layer = {}
        for j in range(1, self.deep + 1) :
            if self.layer_name[j] == 'L' :
                self.layer[j] = torch.nn.Linear(self.init_structure[j-1]['size'], self.init_structure[j]['size'], bias = True, device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            if self.layer_name[j] == 'C' :
                self.layer[j] = torch.nn.Conv2d(self.init_structure[j]['in_channel'], self.init_structure[j]['out_channel'], 
                                                self.init_structure[j]['kernel_size'], device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            
        self.where = copy.deepcopy(model_to_copy.where)
        self.decrease = copy.deepcopy(model_to_copy.decrease)
        self.init_deplacement = model_to_copy.init_deplacement
        self.fct = model_to_copy.fct
        # NOTE(review): this rebinds self.layer to the source model's dict and discards
        # the per-layer copies built above, so this object shares the original layers.
        # Kept as-is; decide whether sharing or copying is intended.
        self.layer = model_to_copy.layer
        self.break_conv = model_to_copy.break_conv
        self.outputs_size_before_activation = model_to_copy.outputs_size_before_activation
        self.alpha = alpha
        self.omega = omega
        self.bias_alpha = bias_alpha

    def forward(self, x, lmbda, deep_lmbda) :
        """Run the network, growing linear layer ``deep_lmbda`` by ``alpha.shape[0]`` neurons.

        ``lmbda`` scales the pre-activation of the extra neurons. If ``deep_lmbda``
        is never reached (for example, it is not a linear layer index), this is a
        plain forward pass. ``self.where`` ends at ``self.deep + 1``.
        """
        self.where = 1
        while self.where <= self.deep and self.layer_name[self.where] == 'C' :
            x = self.fct[self.where](self.layer[self.where](x))
            self.where += 1
        
        x = x.flatten(start_dim = 1)

        while self.where <= self.deep :
            if self.where == deep_lmbda :
                # Frozen Linear carrying the extra neurons' weights.
                m = torch.nn.Linear(self.layer[self.where].in_features, self.alpha.shape[0], device = my_device)
                m.weight =  torch.nn.parameter.Parameter(self.alpha, requires_grad = False)
                m.bias =  torch.nn.parameter.Parameter(self.bias_alpha, requires_grad = False)
                
                x_en_plus  =  self.fct[self.where](lmbda * m(x))
                x = self.fct[self.where](self.layer[self.where](x))
                
                self.where += 1
                
                # Next layer with omega appended as extra input columns for the new neurons.
                m = torch.nn.Linear(self.layer[self.where].in_features + x_en_plus.shape[1], self.layer[self.where].out_features, device = my_device)
                m.weight = torch.nn.parameter.Parameter(torch.cat([self.layer[self.where].weight, self.omega], dim = 1), requires_grad = False)
                m.bias = torch.nn.parameter.Parameter(self.layer[self.where].bias, requires_grad = False)
               
                x = self.fct[self.where](m(torch.cat([x, x_en_plus], dim = 1)))
            else :
                x = self.layer[self.where](x)
                x = self.fct[self.where](x)
            self.where += 1
        
        # m / x_en_plus only exist when the deep_lmbda branch ran; they are released
        # on return, so no explicit `del` (which failed when the branch was skipped).
        return x
class MyModel_copy_C_L(torch.nn.Module) :
    """Evaluate ``model_to_copy`` with K extra channels added to conv layer ``deep_lmbda``.

    In ``forward(x, lmbda, deep_lmbda)`` the extra channels are a frozen Conv2d with
    weight ``alpha`` and bias ``bias_alpha``; their pre-activation is scaled by
    ``lmbda``. Their flattened outputs feed the following linear layer through the
    extra weight columns ``omega``.

    Parameters
    ----------
    model_to_copy : source model (same attributes as the other copy classes).
    alpha : Conv2d weight of the extra channels; ``alpha.shape[0]`` = K.
    omega : concatenated along dim 1 with the next linear layer's weight, so
        (out_features of that layer, K * spatial size of the new channels).
    bias_alpha : (K,) bias of the extra channels.
    """
    
    def __init__(self, model_to_copy, alpha, omega, bias_alpha) :
        super(MyModel_copy_C_L, self).__init__()
        
        self.init_structure = copy.deepcopy(model_to_copy.init_structure)
        self.deep = copy.deepcopy(model_to_copy.deep)
        self.Loss = copy.deepcopy(model_to_copy.Loss)
        self.layer_name = copy.deepcopy(model_to_copy.layer_name)
        self.layer = {}
        for j in range(1, self.deep + 1) :
            if self.layer_name[j] == 'L' :
                self.layer[j] = torch.nn.Linear(self.init_structure[j-1]['size'], self.init_structure[j]['size'], bias = True, device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            if self.layer_name[j] == 'C' :
                self.layer[j] = torch.nn.Conv2d(self.init_structure[j]['in_channel'], self.init_structure[j]['out_channel'], 
                                                self.init_structure[j]['kernel_size'], device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            
        self.where = copy.deepcopy(model_to_copy.where)
        self.decrease = copy.deepcopy(model_to_copy.decrease)
        self.init_deplacement = model_to_copy.init_deplacement
        self.fct = model_to_copy.fct
        # NOTE(review): this rebinds self.layer to the source model's dict and discards
        # the per-layer copies built above, so this object shares the original layers.
        self.layer = model_to_copy.layer
        self.break_conv = model_to_copy.break_conv
        self.outputs_size_before_activation = model_to_copy.outputs_size_before_activation
        self.alpha = alpha
        self.omega = omega
        self.bias_alpha = bias_alpha

    def forward(self, x, lmbda, deep_lmbda) :
        """Run the network, growing conv layer ``deep_lmbda`` (followed by a linear layer).

        If ``deep_lmbda`` is not one of the leading conv layers, this is a plain
        forward pass. ``self.where`` ends at ``self.deep + 1``.
        """
        self.where = 1
        while self.where <= self.deep and self.layer_name[self.where] == 'C' :
            if self.where == deep_lmbda :
                # Frozen Conv2d carrying the extra channels' filters.
                m = torch.nn.Conv2d(self.init_structure[self.where]['in_channel'], self.alpha.shape[0], self.init_structure[self.where]['kernel_size'], device = my_device)
                m.weight =  torch.nn.parameter.Parameter(self.alpha, requires_grad = False)
                m.bias =  torch.nn.parameter.Parameter(self.bias_alpha, requires_grad = False)
                x_en_plus  = self.fct[self.where](lmbda * m(x))
                
                x = self.fct[self.where](self.layer[self.where](x))
                
                self.where += 1
                
                x = x.flatten(start_dim = 1)
                x_en_plus = x_en_plus.flatten(start_dim = 1)
                
                # Next (linear) layer with omega appended as extra input columns.
                m = torch.nn.Linear(self.layer[self.where].in_features + x_en_plus.shape[1], self.layer[self.where].out_features, device = my_device)
                m.weight = torch.nn.parameter.Parameter(torch.cat([self.layer[self.where].weight, self.omega], dim = 1), requires_grad = False)
                m.bias = torch.nn.parameter.Parameter(self.layer[self.where].bias, requires_grad = False)
               
                x = self.fct[self.where](m(torch.cat([x, x_en_plus], dim = 1)))
            else :
                x = self.fct[self.where](self.layer[self.where](x))
                
            self.where += 1

        # When the branch above did not run, the conv output still has to be
        # flattened before the linear layers. flatten(start_dim=1) is a no-op on
        # the already-flat tensor produced by the branch.
        if self.where <= self.deep :
            x = x.flatten(start_dim = 1)

        while self.where <= self.deep :
            x = self.layer[self.where](x)
            x = self.fct[self.where](x)
            self.where += 1
        # No explicit `del m, x_en_plus`: they only exist when the branch ran.
        return x


    
class MyModel_copy_C_C(torch.nn.Module) :
    """Evaluate ``model_to_copy`` with K extra channels added to conv layer ``deep_lmbda``.

    In ``forward(x, lmbda, deep_lmbda)`` the extra channels are a frozen Conv2d with
    weight ``alpha`` / bias ``bias_alpha``; their pre-activation is scaled by
    ``lmbda``. Their output goes through a frozen Conv2d with weight ``omega`` and
    zero bias, and is added to the pre-activation of conv layer ``deep_lmbda + 1``.

    Parameters
    ----------
    model_to_copy : source model (same attributes as the other copy classes).
    alpha : Conv2d weight of the extra channels; ``alpha.shape[0]`` = K.
    omega : Conv2d weight (out_channel of layer deep_lmbda + 1, K, kh, kw).
    bias_alpha : (K,) bias of the extra channels.
    """
    
    def __init__(self, model_to_copy, alpha, omega, bias_alpha) :
        super(MyModel_copy_C_C, self).__init__()
        
        self.init_structure = copy.deepcopy(model_to_copy.init_structure)
        self.deep = copy.deepcopy(model_to_copy.deep)
        self.Loss = copy.deepcopy(model_to_copy.Loss)
        self.layer_name = copy.deepcopy(model_to_copy.layer_name)
        self.layer = {}
        for j in range(1, self.deep + 1) :
            if self.layer_name[j] == 'L' :
                self.layer[j] = torch.nn.Linear(self.init_structure[j-1]['size'], self.init_structure[j]['size'], bias = True, device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            if self.layer_name[j] == 'C' :
                self.layer[j] = torch.nn.Conv2d(self.init_structure[j]['in_channel'], self.init_structure[j]['out_channel'], 
                                                self.init_structure[j]['kernel_size'], device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            
        self.where = copy.deepcopy(model_to_copy.where)
        self.decrease = copy.deepcopy(model_to_copy.decrease)
        self.init_deplacement = model_to_copy.init_deplacement
        self.fct = model_to_copy.fct
        # NOTE(review): this rebinds self.layer to the source model's dict and discards
        # the per-layer copies built above, so this object shares the original layers.
        self.layer = model_to_copy.layer
        self.break_conv = model_to_copy.break_conv
        self.outputs_size_before_activation = model_to_copy.outputs_size_before_activation
        self.alpha = alpha
        self.omega = omega
        self.bias_alpha = bias_alpha

    def forward(self, x, lmbda, deep_lmbda) :
        """Run the network, growing conv layer ``deep_lmbda`` (followed by a conv layer).

        If ``deep_lmbda`` is not one of the leading conv layers, this is a plain
        forward pass. ``self.where`` ends at ``self.deep + 1``.
        """
        self.where = 1
        while self.where <= self.deep and self.layer_name[self.where] == 'C' :
            
            if self.where == deep_lmbda :
                # Frozen Conv2d carrying the extra channels' filters.
                m = torch.nn.Conv2d(self.init_structure[self.where]['in_channel'], self.alpha.shape[0], self.init_structure[self.where]['kernel_size'], device = my_device)
                m.weight =  torch.nn.parameter.Parameter(self.alpha, requires_grad = False)
                m.bias =  torch.nn.parameter.Parameter(self.bias_alpha, requires_grad = False)
                x_en_plus  = self.fct[self.where](lmbda * m(x))
                x = self.fct[self.where](self.layer[self.where](x))
                self.where += 1
                
                # Frozen, bias-free Conv2d mapping the K new channels into the next layer.
                m = torch.nn.Conv2d(self.alpha.shape[0], self.init_structure[self.where]['out_channel'], self.init_structure[self.where]['kernel_size'], device = my_device)
                m.weight = torch.nn.parameter.Parameter(self.omega, requires_grad = False)
                m.bias = torch.nn.parameter.Parameter(torch.zeros(self.omega.shape[0], device = my_device), requires_grad = False)
                x_en_plus = m(x_en_plus)
                
                x = self.fct[self.where](self.layer[self.where](x) + x_en_plus)
            else :
                x = self.fct[self.where](self.layer[self.where](x))
                
            self.where += 1
        
        x = torch.flatten(x, start_dim = 1)

        while self.where <= self.deep :
            x = self.layer[self.where](x)
            x = self.fct[self.where](x)
            self.where += 1
        # No explicit `del m, x_en_plus`: they only exist when the branch ran.
        return x

class MyModel_naturel(torch.nn.Module) :
    """Evaluate ``model_to_copy`` with layer ``deep_lmbda`` shifted along a natural-gradient step.

    ``forward(x, lmbda, deep_lmbda)`` adds ``lmbda * m(x)`` to the pre-activation of
    layer ``deep_lmbda``. Here ``m`` is a frozen Conv2d (4-D weight) or Linear
    carrying ``natural_gradient['weight']`` and ``natural_gradient['bias']``.

    Parameters
    ----------
    model_to_copy : source model (same attributes as the copy classes).
    natural_gradient : dict with 'weight' and 'bias' tensors shaped like layer
        ``deep_lmbda``'s parameters.
    """
    def __init__(self, model_to_copy, natural_gradient) :
        super(MyModel_naturel, self).__init__()
        
        self.init_structure = copy.deepcopy(model_to_copy.init_structure)
        self.deep = copy.deepcopy(model_to_copy.deep)
        self.Loss = copy.deepcopy(model_to_copy.Loss)
        self.layer_name = copy.deepcopy(model_to_copy.layer_name)
        self.layer = {}
        for j in range(1, self.deep + 1) :
            if self.layer_name[j] == 'L' :
                self.layer[j] = torch.nn.Linear(self.init_structure[j-1]['size'], self.init_structure[j]['size'], bias = True, device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            if self.layer_name[j] == 'C' :
                self.layer[j] = torch.nn.Conv2d(self.init_structure[j]['in_channel'], self.init_structure[j]['out_channel'], 
                                                self.init_structure[j]['kernel_size'], device = my_device)
                self.layer[j].weight = copy.deepcopy(model_to_copy.layer[j].weight)
                self.layer[j].bias = copy.deepcopy(model_to_copy.layer[j].bias)
            
        self.where = copy.deepcopy(model_to_copy.where)
        self.decrease = copy.deepcopy(model_to_copy.decrease)
        self.init_deplacement = model_to_copy.init_deplacement
        self.fct = model_to_copy.fct
        # NOTE(review): this rebinds self.layer to the source model's dict and discards
        # the per-layer copies built above, so this object shares the original layers.
        self.layer = model_to_copy.layer
        self.break_conv = model_to_copy.break_conv
        self.outputs_size_before_activation = model_to_copy.outputs_size_before_activation
        self.natural_gradient = natural_gradient

    def forward(self, x, lmbda, deep_lmbda) :
        """Run the network with layer ``deep_lmbda`` perturbed by ``lmbda * natural_gradient``.

        ``self.where`` ends at ``self.deep + 1``.
        """
        self.where = 1
        # Bounded like the other models: an all-conv network would otherwise look
        # up layer_name[deep + 1].
        while self.where <= self.deep and self.layer_name[self.where] == 'C' :
            
            if self.where == deep_lmbda :
                # Built without device=; its initial weight/bias are replaced right away.
                if len(self.natural_gradient['weight'].shape) > 2 :
                    m = torch.nn.Conv2d(self.init_structure[self.where]['in_channel'], self.init_structure[self.where]['out_channel'], self.init_structure[self.where]['kernel_size'])
                else : 
                    m = torch.nn.Linear(self.layer[self.where].in_features, self.layer[self.where].out_features)
                m.weight =  torch.nn.parameter.Parameter(self.natural_gradient['weight'], requires_grad = False)
                m.bias =  torch.nn.parameter.Parameter(self.natural_gradient['bias'], requires_grad = False)
                x = self.layer[self.where](x) + lmbda * m(x)
            else :
                x = self.layer[self.where](x)
                
            x = self.fct[self.where](x)
            self.where += 1
        
        x = torch.flatten(x, start_dim = 1)

        while self.where <= self.deep :
            if self.where == deep_lmbda :
                if len(self.natural_gradient['weight'].shape) > 2 :
                    m = torch.nn.Conv2d(self.init_structure[self.where]['in_channel'], self.init_structure[self.where]['out_channel'], self.init_structure[self.where]['kernel_size'], device = my_device)
                else : 
                    m =torch.nn.Linear(self.layer[self.where].in_features, self.layer[self.where].out_features, device = my_device)
                m.weight =  torch.nn.parameter.Parameter(self.natural_gradient['weight'], requires_grad = False)
                m.bias =  torch.nn.parameter.Parameter(self.natural_gradient['bias'], requires_grad = False)
                x = self.layer[self.where](x) + lmbda * m(x)
            else :
                x = self.layer[self.where](x)
            x = self.fct[self.where](x)
            self.where += 1
        
        # No explicit `del m`: m only exists when deep_lmbda was reached.
        return x


class MyModel_linear_conv(torch.nn.Module):
    
    def __init__(self, parameters):
        """Build the network and its bookkeeping from the ``parameters`` dict.

        Keys read: 'method_inv', 'init_structure', 'Loss', 'layer_name',
        'init_deplacement', 'fct', 'outputs_size_before_activation', 'register',
        'empty', 'method', 'batchsize_estimation'.

        ``init_structure[0]`` describes the input; entries 1..deep describe the
        layers, whose kind ('L' linear / 'C' conv) comes from ``layer_name``.
        The sparse operators ``dico_T[j]`` / ``dico_S[j]`` are built for every
        conv -> conv pair (j, j + 1). After that, the layers are created on
        ``my_device``.
        """
        super().__init__()
        cfg = parameters
        self.method_inv = cfg['method_inv']
        self.init_structure = cfg['init_structure']
        self.deep = len(self.init_structure) - 1
        self.Loss = cfg['Loss']
        self.layer_name = cfg['layer_name']
        self.where = None
        self.decrease = 1
        self.init_deplacement = cfg['init_deplacement']
        self.fct = cfg['fct']
        self.layer = {}
        # Assumes the conv layers come first (as forward() does): they are 1..break_conv - 1.
        self.break_conv = list(self.layer_name.values()).count('C') + 1
        self.outputs_size_before_activation = cfg['outputs_size_before_activation']
        self.register = cfg['register']
        self.df = pd.DataFrame({})
        self.empty = cfg['empty']
        self.information = {}
        self.method = cfg['method']
        self.batchsize_estimation = cfg['batchsize_estimation']

        for j in range(1, self.break_conv - 1):
            self.creation_T_j(j)
            self.creation_T_j_T_T_J_S(j)

        for j in range(1, self.deep + 1):
            kind = self.layer_name[j]
            if kind == 'L':
                self.layer[j] = torch.nn.Linear(self.init_structure[j - 1]['size'], self.init_structure[j]['size'], bias = True, device = my_device)
            elif kind == 'C':
                spec = self.init_structure[j]
                self.layer[j] = torch.nn.Conv2d(spec['in_channel'], spec['out_channel'], spec['kernel_size'], device = my_device)
            
    def forward(self, x):
        """Run the network: leading conv layers, then flatten, then the remaining layers.

        Each layer ``j`` is followed by ``self.fct[j]``. ``self.where`` is left at
        ``self.deep + 1``.
        """
        self.where = 1
        # Bound check keeps an all-convolutional network from looking up layer_name[deep + 1].
        while self.where <= self.deep and self.layer_name[self.where] == 'C' :
            x = self.layer[self.where](x)
            x = self.fct[self.where](x)
            self.where += 1
            
        x = torch.flatten(x, start_dim = 1)
        
        while self.where <= self.deep :
            x = self.layer[self.where](x)
            x = self.fct[self.where](x)
            self.where += 1
        
        return x
    
    
    
    def register_df(self) :
        """Append ``self.information`` as a new row of ``self.df``, then reset it to ``{}``.

        Uses ``pd.concat``: ``DataFrame.append`` was deprecated in pandas 1.4 and
        removed in pandas 2.0.
        """
        new_row = pd.DataFrame([self.information])
        self.df = pd.concat([self.df, new_row], ignore_index = True)
        self.information = {} 
        
    def count_parameters(self) :
        """Return the number of trainable scalars in ``self.layer[1..deep]``."""
        total = 0
        for j in range(1, self.deep + 1):
            for param in self.layer[j].parameters():
                if param.requires_grad:
                    total += param.numel()
        return total
    
    def initialize(self, args = None) :
        """(Re)initialise every layer: weight via a per-kind init function, bias to zero.

        ``args`` maps a layer kind (a value of ``self.layer_name``, e.g. 'L' / 'C')
        to ``{'fct': callable, 'arguments': str or None}``. ``fct`` is called on the
        layer's weight, followed by the extra arguments written in the
        ``arguments`` string. Each bias is replaced by a new zero Parameter on
        ``my_device``, so an optimizer built earlier still holds the old bias.

        Although ``args`` defaults to None, it is required whenever there is at
        least one layer (indexing None raises TypeError).

        Security: the ``arguments`` string is executed with ``eval``; only pass
        trusted configuration.
        """
        for key in self.layer:
            extra = args[self.layer_name[key]]['arguments']
            if extra is None:
                eval("args[self.layer_name[key]]['fct'](self.layer[key].weight)")
            else:
                eval("args[self.layer_name[key]]['fct'](self.layer[key].weight," + extra + ")")
            self.layer[key].bias = torch.nn.Parameter(torch.zeros((self.layer[key].bias.shape), device = my_device))
        
        

    def ft_optimizer(self, mon_opti = None, lr = 0.0001, mu= 0) :
        """Create ``self.optimizer`` over all parameters of ``self.layer[1..deep]``.

        ``mon_opti`` None or 'Adam' selects Adam(lr); any other value selects
        SGD(lr, momentum=mu).
        """
        params = [param for j in range(1, self.deep + 1) for param in self.layer[j].parameters()]
        if mon_opti in (None, 'Adam'):
            self.optimizer = torch.optim.Adam(params, lr=lr)
        else:
            self.optimizer = torch.optim.SGD(params, lr=lr, momentum = mu)
    
    def update_architecture(self, deep, alpha_shape, omega_shape) :
        """Record that ``alpha_shape[0]`` neurons/channels were added to layer ``deep``.

        Updates ``init_structure`` and ``outputs_size_before_activation`` for layer
        ``deep``. If layer ``deep + 1`` exists and is a conv layer, its
        ``in_channel`` also grows. ``omega_shape`` is accepted for interface
        compatibility but is not used.
        """
        n_new = alpha_shape[0]
        struct = self.init_structure[deep]
        if self.layer_name[deep] == 'C' :
            # size / out_channel is the per-channel spatial size; rescale to the new channel count.
            struct['size'] = int(struct['size'] / struct['out_channel'] * (n_new + struct['out_channel']))
            struct['out_channel'] += n_new
            self.outputs_size_before_activation[deep][2] += n_new
        else :
            struct['size'] += n_new
            self.outputs_size_before_activation[deep][0] += n_new
        
        # .get: the last layer has no successor (indexing would raise KeyError).
        if self.layer_name.get(deep + 1) == 'C' :
            self.init_structure[deep + 1]['in_channel'] += n_new
        
       
    
    def creation_T_j(self, deep) :
        """Build the sparse patch operator between conv layers ``deep`` and ``deep + 1``.

        Stores ``T_tot`` of shape ``(n_out, kh * kw, n_in)`` in ``dico_T[deep]``.
        ``n_in = size / out_channel`` of layer ``deep`` and ``n_out`` is the same
        quantity for layer ``deep + 1``. ``T_tot[j]`` is a 0/1 matrix that picks,
        from the flattened output of layer ``deep``, the window read at output
        position ``j`` of layer ``deep + 1``. This assumes stride 1, no padding,
        and row-major flattening with row widths
        ``outputs_size_before_activation[deep][0]`` / ``[deep + 1][0]`` — TODO confirm.

        The window uses the kernel of layer ``deep + 1`` (the layer that reads the
        patches). This matches the ``(out_channel, kh, kw)`` reshape of omega in
        compute_optimal_K_neurone_conv. ``kernel_size[0]`` is used as the window
        width, so square kernels are assumed.
        """
        kernel_s = self.init_structure[deep + 1]['kernel_size']
        n_in = int(self.init_structure[deep]['size'] / self.init_structure[deep]['out_channel'])
        n_out = int(self.init_structure[deep + 1]['size'] / self.init_structure[deep + 1]['out_channel'])
        width_in = self.outputs_size_before_activation[deep][0]
        width_out = self.outputs_size_before_activation[deep + 1][0]

        # Window at output position 0: window row k selects kernel_s[0] consecutive
        # inputs starting at k * width_in.
        T_0 = torch.zeros((kernel_s[0] * kernel_s[1], n_in), device = my_device)
        for k in range(kernel_s[1]) :
            T_0[k * kernel_s[0]:(k + 1) * kernel_s[0],
                k * width_in:k * width_in + kernel_s[0]] = torch.eye(kernel_s[0], device = my_device)

        slices = []
        for j in range(1, n_out + 1) :
            slices.append(T_0)
            # Slide the window one input to the right. After the last output of a
            # row, jump by the window width to reach the start of the next row.
            shift = kernel_s[0] if j % width_out == 0 else 1
            T_0 = torch.roll(T_0, shifts = shift, dims = 1)
        # A single stack instead of re-concatenating a growing tensor every step.
        T_tot = torch.stack(slices, dim = 0).to_sparse(3)
        dico_T[deep] = T_tot
       
    
    def creation_T_j_T_T_J_S (self, deep) :
        """Store ``S = sum_k T[k]^T @ T[k]`` for ``T = dico_T[deep]`` in ``dico_S[deep]`` (sparse)."""
        T = dico_T[deep]
        gram_terms = [
            torch.sparse.mm(T[k].transpose(1, 0), T[k]).unsqueeze(0)
            for k in range(T.shape[0])
        ]
        dico_S[deep] = torch.sparse.sum(torch.cat(gram_terms), dim = 0)
       

    def gradient_naturel_conv(self, layer, update, lmbda = 1.) :
        """Move layer ``layer`` by ``lmbda * update`` and register the new tensors with the optimizer.

        ``update`` holds 'weight' and 'bias' tensors shaped like the layer's
        parameters. The weight and bias are replaced by new Parameters and each is
        added to ``self.optimizer`` as its own param group. The previous Parameters
        stay in their original group.
        """
        with torch.no_grad():
            target = self.layer[layer]
            target.weight = torch.nn.parameter.Parameter(lmbda *update['weight'] + target.weight.detach(), requires_grad = True)
            target.bias = torch.nn.parameter.Parameter(lmbda *update['bias'] + target.bias.detach(), requires_grad = True)

            for tensor in (target.weight, target.bias):
                tensor.requires_grad = True
            self.optimizer.add_param_group({'params' : target.weight})
            self.optimizer.add_param_group({'params' : target.bias})
            
            
            
    def gradient_naturel_lin(self, layer, delta_w, lmbda = 1.) :
        """Move linear layer ``layer`` by ``lmbda * delta_w`` and register the new tensors with the optimizer.

        ``delta_w`` holds 'weight' and 'bias' tensors shaped like the layer's
        parameters. The weight and bias are replaced by new Parameters and each is
        added to ``self.optimizer`` as its own param group. The previous Parameters
        stay in their original group.
        """
        with torch.no_grad():
            target = self.layer[layer]
            shifted_weight = torch.nn.parameter.Parameter(target.weight + lmbda *delta_w['weight'], requires_grad = True)
            target.weight = shifted_weight
            shifted_bias = torch.nn.parameter.Parameter(target.bias + lmbda * delta_w['bias'], requires_grad = True)
            target.bias = shifted_bias

            for tensor in (target.weight, target.bias):
                tensor.requires_grad = True
            self.optimizer.add_param_group({'params' : target.weight})
            self.optimizer.add_param_group({'params' : target.bias})
            
    def compute_NG_conv(self, deep, ind) :
        """Apply a natural-gradient step to conv layer ``deep`` and return the unexplained displacement.

        Steps:
        1. Take the desired displacement ``DV = -deplacement_voulu(deep, ind)``.
        2. Obtain the step ``{'weight', 'bias'}`` from ``eval_expressivity_bottleneck_conv``.
        3. Get a step size ``lmbda`` from ``compute_natural_decay_upgrade``, apply the
           step via ``gradient_naturel_conv`` and log it in
           ``self.information['lambda_NG']``.

        Returns ``DV - conv(activity[deep])``, where ``conv`` carries the step's
        weights; ``activity[deep]`` is presumably the input of layer ``deep`` —
        TODO confirm.
        """
        DV = -self.deplacement_voulu(deep, ind)
        _, step = self.eval_expressivity_bottleneck_conv(deep, DV)

        lmbda = self.compute_natural_decay_upgrade(deep, step, ind, exp = 10, lowest_deplacement = self.init_deplacement)
        self.gradient_naturel_conv(deep, step, lmbda = lmbda)
        self.information['lambda_NG'] = lmbda

        layer_input = activity[deep]
        reference = self.layer[deep]
        step_conv = torch.nn.Conv2d(reference.in_channels, reference.out_channels, reference.kernel_size, device = my_device)
        step_conv.weight = torch.nn.parameter.Parameter(step['weight'] , requires_grad = False)
        step_conv.bias = torch.nn.parameter.Parameter(step['bias'], requires_grad = False)

        with torch.no_grad():
            DV_eff = step_conv(layer_input)

        return DV - DV_eff
    
    def compute_NG_lin(self, deep, ind) :
        """Apply a natural-gradient step to linear layer ``deep`` and return the unexplained displacement.

        ``eval_expressivity_bottleneck_lin`` returns a matrix whose last column is
        the bias and whose other columns are the weight step. The step is scaled by
        ``compute_natural_decay_upgrade``, applied through ``gradient_naturel_lin``
        and logged in ``self.information['lambda_NG']``. Returns
        ``DV - linear(activity[deep])``, where ``linear`` carries the (unscaled)
        step.
        """
        DV = -self.deplacement_voulu(deep, ind)
        w_star = self.eval_expressivity_bottleneck_lin(deep, DV)
        step = {'weight' : w_star[:, :-1], 'bias' : w_star[:, -1]}
        lmbda = self.compute_natural_decay_upgrade(deep, step, ind, exp = 10, lowest_deplacement = self.init_deplacement)
        self.gradient_naturel_lin(deep, step, lmbda = lmbda)
        self.information['lambda_NG'] = lmbda

        n_out, n_in = step['weight'].shape
        step_lin = torch.nn.Linear(n_in, n_out, bias = True, device = my_device)
        step_lin.weight = torch.nn.parameter.Parameter(step['weight'])
        step_lin.bias = torch.nn.parameter.Parameter(step['bias'])

        with torch.no_grad():
            DV_eff = step_lin(activity[deep])

        return DV - DV_eff
        
    def compute_NG(self, deep, ind) :
        """Dispatch the natural-gradient step for layer ``deep`` by layer kind.

        Returns None when the model is ``empty``. Otherwise it returns the residual
        displacement from compute_NG_conv ('C') or compute_NG_lin ('L').
        """
        if self.empty:
            return None
        if self.layer_name[deep] == 'C':
            DV_proj = self.compute_NG_conv(deep, ind)
        if self.layer_name[deep] == 'L':
            DV_proj = self.compute_NG_lin(deep, ind)
        return DV_proj
    
    def create_MM(self, M_1) :
        """Return ``sum_b M_1[b].T @ M_1[b]`` for a batch ``M_1`` of shape (B, n, p); result is (p, p)."""
        per_sample = torch.bmm(M_1.transpose(1, 2), M_1)
        return per_sample.sum(dim = 0)
    
    def create_MDV(self, DV, M) :
        """Correlate ``DV`` (B, b) with ``M`` (B, n, p) over the batch.

        Returns an (n * b, p) tensor whose row ``i * b + j`` equals
        ``sum_a DV[a, j] * M[a, i, :]``.
        """
        # (B, b) -> (1, 1, b, B)
        DV_rows = DV.transpose(0, 1)[None, None]
        # (B, n, p) -> (n, B, p): the matmul then contracts over the batch.
        M_by_position = M.transpose(0, 1)
        # (1, n, b, p) -> (n * b, p)
        return torch.matmul(DV_rows, M_by_position).flatten(start_dim = 0, end_dim = 2)
    
    def create_MSSM(self, M, deep) :
        """Return ``sum_b M[b].T @ S @ M[b]`` with ``S = dico_S[deep]`` (densified).

        ``M`` is batched (B, n, p) with ``S`` of shape (n, n); the result is (p, p).
        """
        S_dense = dico_S[deep].unsqueeze(0).unsqueeze(0).to_dense()
        SM = torch.matmul(S_dense, M)
        return torch.bmm(M.transpose(1, 2), SM[0]).sum(dim = 0)
    
    
    def create_TMDV_test(self, M, T_tot, DV) :
        """Correlate the desired displacement with the patches read by the next conv layer.

        Parameters
        ----------
        M : (B, n, p) tensor.
        T_tot : sparse (P, kk, n) tensor (see creation_T_j).
        DV : (B, C, h, w) tensor with h * w == P; the last two dims are flattened.

        Returns
        -------
        (C * kk, p) tensor whose row ``c * kk + k`` is
        ``sum_{b, j} DV[b, c, j] * (T_tot[j] @ M[b])[k, :]``.
        Rows are channel-major, which matches create_TMDV and the
        ``(K, out_channel, kh, kw)`` reshape of omega in
        compute_optimal_K_neurone_conv. The previous version returned them
        kernel-offset-major (``k * C + c``).
        """
        DV = DV.flatten(start_dim = -2)                                  # (B, C, P)
        channel = DV.shape[1]
        nbr_param = M.shape[2]
        
        DV_augmented = torch.unsqueeze(torch.permute(DV, (1, 2, 0)), dim = 1)  # (C, 1, P, B)
        M_permute = torch.permute(M, (1, 0, 2))                          # (n, B, p)
        MDV = torch.matmul(DV_augmented, M_permute)                      # (C, n, P, p)
        
        MDV_flat = torch.permute(MDV, (1, 2, 0, 3)).flatten(start_dim = 2)  # (n, P, C * p)
        T_permute = permute_sparse(T_tot, (2, 1, 0))                     # (n, kk, P), sparse
        
        TMDV_ = torch.bmm(T_permute, MDV_flat).sum(dim = 0)              # (kk, C * p)
        # (kk, C, p) -> (C, kk, p) -> (C * kk, p): channel-major row order.
        TMDV_ = TMDV_.reshape((TMDV_.shape[0], channel, nbr_param)).transpose(0, 1).flatten(start_dim = 0, end_dim = 1)
        
        return TMDV_
    
    def create_TMDV(self, M, T_tot, DV) :
        """Dense reference implementation of create_TMDV_test.

        In the visible code, the only call to this method is commented out.

        Parameters
        ----------
        M : (B, n, p) tensor.
        T_tot : sparse (P, kk, n) tensor. It is densified here, so memory grows
            with P * kk * n.
        DV : (B, C, h, w) tensor with h * w == P; the last two dims are flattened.

        Returns
        -------
        (C * kk, p) tensor whose row ``c * kk + k`` is
        ``sum_{b, j} DV[b, c, j] * (T_tot[j] @ M[b])[k, :]``.
        """
        T_sq = torch.unsqueeze(T_tot, dim = 1)                 # (P, 1, kk, n)
        DV = DV.flatten(start_dim = -2)                        # (B, C, P)
        
        TM = torch.matmul(T_sq.to_dense(), M)                  # (P, B, kk, p)
        
        DV_perm = torch.permute(DV, (1, 0, 2))                 # (C, B, P)
        TM_perm = torch.permute(TM, (2, 3, 1, 0))              # (kk, p, B, P)
        
        DV_flat = DV_perm.flatten(start_dim = 1)               # (C, B * P)
        TM_flat = TM_perm.flatten(start_dim = 2)               # (kk, p, B * P)
        
        DV_sq = torch.unsqueeze(DV_flat, dim = 0)
        DV_repeat = DV_sq.repeat(TM_flat.shape[0], 1, 1)       # (kk, C, B * P)
        TM_permute = torch.permute(TM_flat, (0, 2, 1))         # (kk, B * P, p)
        
        TM_DV = torch.bmm(DV_repeat, TM_permute)               # (kk, C, p)
        TM_DV = torch.permute(TM_DV, (1, 0, 2))                # (C, kk, p)
        TM_DV = TM_DV.flatten(start_dim = 0, end_dim = 1)      # (C * kk, p)
        
        return TM_DV
    
   
    def compute_optimal_K_neurone_conv(self, MSSM, TMDV, deep) :
        """Solve for the weights of new neurons/channels to add at layer ``deep``.

        Despite the name, linear layers are handled too (the 'L' branch).

        Parameters
        ----------
        MSSM : (p, p) matrix, the quadratic form on the new neurons' input weights.
            In compute_add_neurone it comes from create_MSSM or create_MM.
        TMDV : 2-D matrix with p columns (it is used as ``TMDV.T``). It correlates
            the desired displacement with the new neurons' inputs (create_TMDV_test /
            create_MDV in compute_add_neurone).
        deep : index of the layer that receives the new neurons.

        Returns
        -------
        (alpha, bias_alpha, omega, valeurs_propres)
            alpha : input weights of the new neurons with the bias column removed.
                For a conv layer, reshaped to (K, in_channel, kh, kw) using
                ``init_structure[deep]``.
            bias_alpha : last column of ``alpha.T`` (presumably matching the column
                of ones appended to M by the caller).
            omega : outgoing weights. For conv -> conv, reshaped to
                (out_channel, K, kh, kw) with layer ``deep + 1``'s sizes. For a
                linear layer, transposed. For conv -> linear, returned as the SVD
                produced it (the caller reshapes it).
            valeurs_propres : singular values of ``S_moins_1_demi @ TMDV.T``.

        ``self.method`` (if not None) may post-process alpha, omega and the singular
        values. If ``layer_name[deep]`` is neither 'C' nor 'L', ``bias_alpha`` is
        never assigned and the return raises UnboundLocalError.
        """
        # MSSM = U diag(vps) Vh; for a symmetric PSD MSSM, Vh = U.T.
        U, vps, _ = torch.linalg.svd(MSSM)
        # diag(sqrt(vps)) stored as a sparse COO matrix.
        sqrt_sigma = torch.sparse_coo_tensor(torch.arange(vps.shape[0], device = my_device).repeat(2, 1), torch.sqrt(vps), (vps.shape[0], vps.shape[0]))
        
        # S_1_demi is the upper-triangular R of QR(diag(sqrt(vps)) @ U), so
        # R^T R = U^T diag(vps) U.
        # NOTE(review): MSSM = U diag(vps) U^T, and the steps below (pinv(R) @ N,
        # then pinv(R).T @ u) are the exact whitening only if R R^T = MSSM.
        # Neither identity obviously holds for R built from diag(sqrt(vps)) @ U
        # (rather than @ U.T); verify against the intended derivation.
        Q, S_1_demi = torch.linalg.qr(torch.sparse.mm(sqrt_sigma, U))
        
        # Pseudo-inverse, so a rank-deficient MSSM does not raise.
        S_moins_1_demi = torch.linalg.pinv(S_1_demi)
        S_moins_1_demin_N = torch.matmul(S_moins_1_demi, TMDV.T)
        # u: left singular vectors (one new neuron per column); omega: right singular vectors (rows).
        u, valeurs_propres, omega = torch.linalg.svd(S_moins_1_demin_N, full_matrices = False)
        alpha = torch.matmul(S_moins_1_demi.T, u)
        # Optional user hook, e.g. to select/scale a subset of the K neurons.
        if self.method != None :
            alpha, omega, valeurs_propres = self.method(alpha, omega, valeurs_propres)
        if self.layer_name[deep] == 'C' :
            alpha = alpha.T
            # Last column is the bias; the remaining ones are the flattened filter.
            bias_alpha = alpha[:, -1]
            alpha = alpha[:, :-1]
            alias = self.init_structure[deep]
            alpha = alpha.reshape((alpha.shape[0], alias['in_channel'], alias['kernel_size'][0], alias['kernel_size'][1]))
        
        if self.layer_name[deep] == 'C' and self.layer_name[deep + 1] == 'C' :
            alias_2 = self.init_structure[deep + 1]
            # Columns of omega are indexed (out_channel, kh, kw) of layer deep + 1;
            # permute to Conv2d weight layout (out_channel, K, kh, kw).
            omega = omega.reshape((omega.shape[0], alias_2['out_channel'], alias_2['kernel_size'][0], alias_2['kernel_size'][1]))
            omega = torch.permute(omega, (1, 0, 2, 3))
        if self.layer_name[deep] == 'L' :
            alpha = alpha.T
            bias_alpha = alpha[:, -1]
            alpha = alpha[:, :-1]
            # (out_features of layer deep + 1, K): extra input columns for that layer.
            omega = omega.T
            
        # Release intermediates early.
        del U, vps, sqrt_sigma, Q, S_1_demi, S_moins_1_demi, S_moins_1_demin_N, u
        
        return(alpha, bias_alpha, omega, valeurs_propres)
    

        
    
    def compute_add_neurone(self, deep, ind, DV_proj) :
        """Compute new neurons between layers ``deep`` and ``deep + 1`` and insert them.

        The statistics needed by ``compute_optimal_K_neurone_conv`` are built
        according to the pair ``(layer_name[deep], layer_name[deep + 1])``.
        The step size is chosen with ``compute_decay_upgrade`` and the neurons
        are inserted into the network. On the very first insertion the two
        layers are rebuilt first.

        Parameters
        ----------
        deep : int
            Index of the layer that receives the new neurons.
        ind : indices
            Mini-batch indices used by the step-size search.
        DV_proj : torch.Tensor
            Projected desired update of layer ``deep + 1``.

        Returns
        -------
        ``None`` if ``deep`` is the last layer. Otherwise
        ``(alpha, omega, bias_alpha, lambda_w, valeurs_propres)``.

        Raises
        ------
        ValueError
            If the layer pair is not one of ('C','C'), ('C','L'), ('L','L').
            Previously this surfaced as an UnboundLocalError on ``alpha``.
        """
        if deep == self.deep :
            return(None)
        #DV_proj = -self.deplacement_voulu(deep + 1, ind)

        pair = (self.layer_name[deep], self.layer_name[deep + 1])

        if pair == ('C', 'C') :
            M = self.compute_M(deep)
            # Extra constant column so the solver also produces a bias.
            M_1 = torch.cat([M, torch.ones((M.shape[0], M.shape[1], 1), device = my_device)], dim = 2)
            T_j = dico_T[deep]
            MSSM = self.create_MSSM(M_1, deep)
            TM_DV = self.create_TMDV_test(M_1, T_j, DV_proj)
            alpha, bias_alpha, omega, valeurs_propres = self.compute_optimal_K_neurone_conv(MSSM, TM_DV, deep)
            del M, M_1, MSSM

        elif pair == ('C', 'L') :
            M = self.compute_M(deep)
            M_1 = torch.cat([M, torch.ones((M.shape[0], M.shape[1], 1), device = my_device)], dim = 2)
            MM = self.create_MM(M_1)
            MDV = self.create_MDV(DV_proj, M_1)
            alpha, bias_alpha, omega, valeurs_propres = self.compute_optimal_K_neurone_conv(MM, MDV, deep)
            # Turn omega into a slice of the next Linear weight: one column per
            # (spatial position, new channel) of the flattened conv output.
            omega = omega.reshape((omega.shape[0], self.outputs_size_before_activation[deep][0] * self.outputs_size_before_activation[deep][1],
                                   self.outputs_size_before_activation[deep + 1][0]))
            omega = omega.flatten(start_dim = 0, end_dim = 1)
            omega = omega.permute((1, 0))
            del M, M_1, MM

        elif pair == ('L', 'L') :
            V = torch.cat([activity[deep], torch.ones((activity[deep].shape[0], 1), device = my_device)], dim = 1)
            M = torch.matmul(V.T, V)/V.shape[0]
            MDV = torch.matmul(V.T, DV_proj)/V.shape[0]
            alpha, bias_alpha, omega, valeurs_propres = self.compute_optimal_K_neurone_conv(M, MDV.T, deep)
            del V, M, MDV

        else :
            raise ValueError('compute_add_neurone: unsupported layer pair %r at depth %d' % (pair, deep))

        lambda_w = self._install_new_neurons(deep, ind, alpha, omega, bias_alpha)

        if self.register :
            self.information.update({'valeurs_propres_'  + str(i) : valeurs_propres[i] for i in range(len(valeurs_propres))})
            self.information['lambda_new_neurone'] = lambda_w

        self.update_architecture(deep, alpha.shape, omega.shape)

        return(alpha, omega, bias_alpha, lambda_w, valeurs_propres)

    def _install_new_neurons(self, deep, ind, alpha, omega, bias_alpha) :
        """Choose the step size for the new neurons, insert them, and return that step size."""
        lambda_w = self.compute_decay_upgrade(alpha, omega, bias_alpha, deep, ind, exp = 10, lowest_deplacement = self.init_deplacement)
        if self.empty :
            # First insertion: the layers are rebuilt from zero width.
            self.create_first_layer(deep, alpha, omega, bias_alpha, lambda_w = lambda_w)
            self.empty = False
        else :
            self.add_K_neurons_linear_convolution(deep, alpha, omega, bias_alpha, lambda_w = lambda_w)
        return(lambda_w)
    
    
    
    def eval_variation(self, batch_size, eval_batch = 10, to_repeat = 1) :
        """Shrink ``self.decrease`` when the gradient norm is too noisy across mini-batches.

        Each round proceeds as follows:
        - Evaluate the summed gradient norm on ``eval_batch`` random mini-batches
          of size ``batch_size``, using a deep copy of the model.
        - If ``log10(0.05 * mean / var) < 0``, multiply ``self.decrease`` by
          ``min(1, 0.05 / (var / mean))`` and start another round.
        - Otherwise stop.

        At most ``to_repeat`` rounds are run. Unlike the previous recursive
        version, every round uses the caller's ``eval_batch``; the old recursion
        silently fell back to the default of 10.
        """
        while to_repeat > 0 :
            model_copy = copy.deepcopy(self)
            max_ind = X_train_rescale.shape[0]
            L = torch.zeros(eval_batch, device = my_device)

            for j in range(eval_batch) :
                model_copy.optimizer.zero_grad()
                ind = torch.randperm(max_ind)[:batch_size]
                L_save = model_copy.Loss(model_copy(X_train_rescale[ind]), Y_train_rescale[ind]) * model_copy.decrease
                # NOTE(review): other methods address layers through ``self.layer``;
                # confirm that ``linear`` is the intended attribute here.
                params = [param for k in range(1, model_copy.deep) for param in model_copy.linear[k].parameters()]
                grad = torch.autograd.grad(L_save, params)
                L[j] = sum([torch.norm(grad_i) for grad_i in grad])
            del model_copy

            var = L.var()
            mean = L.mean()

            decrease = torch.log10(0.05 * mean /var)
            # ``not (x < 0)`` rather than ``x >= 0`` so that a NaN also stops, as before.
            if not decrease < 0 :
                break
            #self.change_lr(decrease = 1e-3)
            print(decrease,-torch.log10(-decrease),  'on decrease', self.decrease * min(1, 0.05/(var/mean)))
            self.decrease = self.decrease * min(1,0.05/(var/mean))
            to_repeat = to_repeat - 1
            
    def change_lr(self, decrease = 1e-1) :
        """Scale the learning rate of every optimizer param group by the factor ``decrease``."""
        for group in self.optimizer.param_groups :
            group.update(lr = decrease * group['lr'])
    
    
    
    def deplacement_voulu(self, DV_deep, ind_MB) :
        """Return the gradient of the loss w.r.t. the pre-activation output of layer ``DV_deep``.

        The mini-batch ``X_train_rescale[ind_MB]`` is propagated up to layer
        ``DV_deep`` (flattening at ``self.break_conv``). The output of that
        layer's module becomes a fresh leaf, the rest of the network is run, and
        ``d Loss / d output`` is returned.

        Parameter ``requires_grad`` flags are switched off during the
        computation. Each flag is then restored to its previous value, even if
        an exception occurs. The old code forced every flag to True afterwards,
        which unfroze intentionally frozen parameters.
        """
        params = [param for j in range(1, self.deep + 1) for param in self.layer[j].parameters()]
        saved_flags = [param.requires_grad for param in params]
        for param in params :
            param.requires_grad_(False)
        try :
            # Gradient w.r.t. x is needed even if the caller is inside torch.no_grad().
            with torch.enable_grad() :
                x = X_train_rescale[ind_MB]
                for j in range(1, DV_deep) :
                    if j == self.break_conv :
                        x = torch.flatten(x, start_dim = 1)
                    x = self.fct[j](self.layer[j](x))

                if DV_deep == self.break_conv :
                    x = torch.flatten(x, start_dim = 1)

                # Fresh leaf: the gradient is taken w.r.t. this pre-activation tensor.
                x = self.layer[DV_deep](x).detach().requires_grad_(True)

                y = self.fct[DV_deep](x)
                for j in range(DV_deep + 1, self.deep + 1) :
                    if j == self.break_conv :
                        y = torch.flatten(y, start_dim = 1)
                    y = self.fct[j](self.layer[j](y))

                loss = self.Loss(y, Y_train_rescale[ind_MB])
                _DV = torch.autograd.grad(loss, x)[0]
        finally :
            for param, flag in zip(params, saved_flags) :
                param.requires_grad_(flag)

        return(_DV)
    
        
    def help_add_K_neurons_linear_left(self, layer, weight_to_add, bias_to_add, lambda_w = 1) :
        """Widen the Linear module ``self.layer[layer]`` with extra output neurons.

        The rows of ``weight_to_add`` and the entries of ``bias_to_add``, both
        scaled by ``lambda_w``, are appended after the existing ones. The module
        is rebuilt with the larger ``out_features``. The new weight and bias
        Parameters are then registered with ``self.optimizer`` as two separate
        param groups.
        """
        previous = self.layer[layer]
        n_new = weight_to_add.shape[0]

        grown = {}
        for attr, extra in (('weight', weight_to_add), ('bias', bias_to_add)) :
            stacked = torch.cat([getattr(previous, attr), extra * lambda_w])
            grown[attr] = torch.nn.parameter.Parameter(stacked, requires_grad = True)

        replacement = torch.nn.Linear(previous.in_features, previous.out_features + n_new)
        replacement.weight = grown['weight']
        replacement.bias = grown['bias']
        self.layer[layer] = replacement

        for attr in ('weight', 'bias') :
            self.optimizer.add_param_group({'params' : getattr(replacement, attr)})
        
    
    def help_add_K_neurons_conv2d_left(self, layer, weight_to_add, bias_to_add, lambda_w = 1) :
        """Append output channels (filters) to the Conv2d module ``self.layer[layer]``.

        ``weight_to_add`` (shape ``(K, in_channels, kh, kw)``) and
        ``bias_to_add`` (shape ``(K,)``), both scaled by ``lambda_w``, are
        concatenated after the existing filters. The rebuilt layer keeps the old
        stride, padding, dilation and padding mode. Previously only the stride
        was kept, so a padded layer silently changed its output size. The new
        Parameters are registered with ``self.optimizer`` as two param groups.
        """
        old = self.layer[layer]

        new_weight = torch.nn.parameter.Parameter(torch.cat([old.weight, weight_to_add * lambda_w]), requires_grad = True)
        new_bias = torch.nn.parameter.Parameter(torch.cat([old.bias, bias_to_add * lambda_w]), requires_grad = True)

        self.layer[layer] = torch.nn.Conv2d(old.in_channels, old.out_channels + weight_to_add.shape[0],
                                            kernel_size = old.kernel_size, stride = old.stride,
                                            padding = old.padding, dilation = old.dilation,
                                            padding_mode = old.padding_mode)
        self.layer[layer].weight = new_weight
        self.layer[layer].bias = new_bias

        self.optimizer.add_param_group({'params' : self.layer[layer].weight})
        self.optimizer.add_param_group({'params' : self.layer[layer].bias})
        
    def help_add_K_neurons_linear_right(self, layer, weight_to_add, lambda_w = 1) :
        """Widen the Linear module ``self.layer[layer]`` with extra input features.

        ``weight_to_add`` (scaled by ``lambda_w``) is appended as new columns of
        the weight matrix; the bias Parameter is carried over unchanged. Only
        the new weight Parameter is registered with ``self.optimizer``.
        """
        previous = self.layer[layer]
        extra_columns = weight_to_add.shape[1]

        widened = torch.nn.parameter.Parameter(
            torch.cat([previous.weight, weight_to_add * lambda_w], dim = 1),
            requires_grad = True,
        )

        replacement = torch.nn.Linear(previous.in_features + extra_columns, previous.out_features)
        replacement.weight = widened
        replacement.bias = previous.bias
        self.layer[layer] = replacement

        self.optimizer.add_param_group({'params' : replacement.weight})
        
        
    def help_add_K_neurons_conv2d_right(self, layer, weight_to_add, lambda_w = 1) :
        """Append input channels to the Conv2d module ``self.layer[layer]``.

        ``weight_to_add`` (scaled by ``lambda_w``) is concatenated along the
        input-channel axis (dim 1) of the kernel; the bias Parameter is reused
        unchanged. The rebuilt layer keeps the old stride, padding, dilation and
        padding mode. Previously padding/dilation were dropped, changing the
        output size. Only the new weight Parameter is registered with
        ``self.optimizer``.
        """
        old = self.layer[layer]

        new_weight = torch.cat([old.weight, weight_to_add * lambda_w], dim = 1)
        new_weight = torch.nn.parameter.Parameter(new_weight, requires_grad = True)

        self.layer[layer] = torch.nn.Conv2d(old.in_channels + weight_to_add.shape[1], old.out_channels,
                                            kernel_size = old.kernel_size, stride = old.stride,
                                            padding = old.padding, dilation = old.dilation,
                                            padding_mode = old.padding_mode)
        self.layer[layer].weight = new_weight
        self.layer[layer].bias = old.bias

        self.optimizer.add_param_group({'params' : self.layer[layer].weight})
        
    
    def add_K_neurons_linear_convolution(self, layer, new_weight_1, new_weight_2, bias_1, lambda_w = 1.0) :
        """Insert new neurons between ``layer`` and ``layer + 1``.

        The outgoing side of ``layer`` receives ``new_weight_1`` / ``bias_1``
        (scaled by ``lambda_w``). The incoming side of ``layer + 1`` receives
        ``new_weight_2`` (unscaled). Layers with index below
        ``self.break_conv`` are treated as Conv2d, the others as Linear.
        """
        conv_limit = self.break_conv

        grow_outputs = (self.help_add_K_neurons_conv2d_left if layer < conv_limit
                        else self.help_add_K_neurons_linear_left)
        grow_outputs(layer, new_weight_1, bias_1, lambda_w = lambda_w)

        next_layer = layer + 1
        grow_inputs = (self.help_add_K_neurons_conv2d_right if next_layer < conv_limit
                       else self.help_add_K_neurons_linear_right)
        grow_inputs(next_layer, new_weight_2, lambda_w = 1)



    def create_first_layer(self, deep, alpha, omega, bias_alpha, lambda_w = 1.0) :
        """Rebuild layers ``deep`` and ``deep + 1`` with zero width, then insert the first neurons.

        Layer ``deep`` gets 0 output channels/features and layer ``deep + 1``
        gets 0 input channels/features. The neurons are then added with
        ``add_K_neurons_linear_convolution``.

        The empty modules keep the stride, padding, dilation and padding mode
        of the layers they replace. They are also moved to the device of the
        old weights, so the later ``torch.cat`` in the ``help_add_*`` methods
        does not mix devices. Previously only ``kernel_size`` was kept and the
        modules were created on the CPU.
        """
        print('je cree le premier layer')
        old_left = self.layer[deep]
        old_right = self.layer[deep + 1]

        if self.layer_name[deep] == 'C' :
            self.layer[deep] = torch.nn.Conv2d(old_left.in_channels, 0, kernel_size = old_left.kernel_size,
                                               stride = old_left.stride, padding = old_left.padding,
                                               dilation = old_left.dilation,
                                               padding_mode = old_left.padding_mode).to(old_left.weight.device)
        if self.layer_name[deep] == 'L' :
            self.layer[deep] = torch.nn.Linear(old_left.in_features, 0).to(old_left.weight.device)

        if self.layer_name[deep + 1] == 'C' :
            self.layer[deep + 1] = torch.nn.Conv2d(0, old_right.out_channels, kernel_size = old_right.kernel_size,
                                                   stride = old_right.stride, padding = old_right.padding,
                                                   dilation = old_right.dilation,
                                                   padding_mode = old_right.padding_mode).to(old_right.weight.device)
        if self.layer_name[deep + 1] == 'L' :
            self.layer[deep + 1] = torch.nn.Linear(0, old_right.out_features).to(old_right.weight.device)

        self.add_K_neurons_linear_convolution(deep, alpha, omega, bias_alpha, lambda_w)


    def train_batch(self, perm_train, batch_size, epoch, limite_temps = None, reduction = 1, change_lr = False) :
        """Train on the global training set with mini-batch optimizer steps.

        Parameters
        ----------
        perm_train : indices
            Permutation applied to ``X_train_rescale`` / ``Y_train_rescale``.
        batch_size : int-like
            Mini-batch size; converted with ``int``.
        epoch : number
            Number of passes over the data. It is replaced by a very large
            value when ``limite_temps`` is given.
        limite_temps : float or None
            Wall-clock budget in seconds.
        reduction : int
            Metrics are recorded every ``reduction`` steps.
        change_lr : bool
            If True, ``eval_variation`` is called before training and then at
            exponentially spaced steps.

        Returns
        -------
        tuple of numpy arrays
            ``(loss_train, loss_test, accuracy_test, time, accuracy_train)``.

        Each mini-batch holds exactly ``batch_size`` consecutive samples,
        wrapping around to the start of the data at the end of an epoch. The
        previous slicing built batches of twice the requested size and never
        actually wrapped around.
        """
        X_train = X_train_rescale[perm_train]
        Y_train = Y_train_rescale[perm_train]
        X_test = X_test_rescale
        Y_test = Y_test_rescale

        loss_list_train = []
        loss_list_test = []
        accuracy = []
        accuracy_train = []
        my_time = []
        criterion = self.Loss
        optimizer = self.optimizer
        batch_size = int(batch_size)
        n_train = len(X_train)
        X_input = None
        if limite_temps is not None :
            # Effectively unbounded: the time budget stops training.
            epoch = 1000000

        t0 = time.time()
        if change_lr :
            self.eval_variation(batch_size, 10)
        to_eval = 10
        # Offsets of one mini-batch; taken modulo n_train so batches wrap around.
        offsets = torch.arange(batch_size, device = X_train.device)
        for t in range(int(epoch * n_train / batch_size)) :
            batch_idx = ((t * batch_size) % n_train + offsets) % n_train
            X_input = X_train[batch_idx]
            Y_input = Y_train[batch_idx]

            optimizer.zero_grad()
            Y_pred = self(X_input)
            Loss = criterion(Y_pred, Y_input) * self.decrease
            Loss.backward()
            optimizer.step()
            if change_lr and t % to_eval == 0 :
                self.eval_variation(batch_size, 10)
                to_eval = to_eval * 2
            with torch.no_grad() :
                if int(t/reduction) * reduction == t :
                    y_pred = self(X_test)
                    loss_list_test.append(criterion(y_pred, Y_test).item())
                    loss_list_train.append(Loss.detach().item())
                    accuracy.append(calculate_accuracy(y_pred, Y_test))
                    accuracy_train.append(calculate_accuracy(self(X_input), Y_input))
                    my_time.append(time.time() - t0)
            if limite_temps is not None and time.time() - t0 > limite_temps :
                break

        # Final test evaluation; the train metrics repeat the last recorded values.
        with torch.no_grad() :
            if X_input is not None :
                y_pred = self(X_test)
                loss_list_test.append(criterion(y_pred, Y_test).item())
                loss_list_train.append(loss_list_train[-1])
                accuracy.append(calculate_accuracy(y_pred, Y_test))
                accuracy_train.append(accuracy_train[-1])
                my_time.append(time.time() - t0)
        return(np.array(loss_list_train), np.array(loss_list_test), np.array(accuracy), np.array(my_time), np.array(accuracy_train))
            
    
        
    def eval_expressivity_bottleneck_lin_save(self, deep, DV, method = 'LBFGS') :
        """Least-squares weight update for Linear layer ``deep`` matching ``DV``.

        Computes the same quantity as ``eval_expressivity_bottleneck_lin``, but
        without dropping input columns whose second moment is zero.
        ``activity[deep]`` is extended with a constant column, and the returned
        matrix is ``(DV^T V / n) @ inverse(V^T V / n)``. The inverse is either
        ``inv_APPLY`` (``self.method_inv == 'approx'``) or a CPU pseudo-inverse
        (``'exact'``). ``method`` is unused.

        Returns ``None`` when no activity is recorded. The emptiness test uses
        ``len()``, which works for both the initial ``[]`` and a tensor; the old
        ``tensor == []`` comparison raised for non-empty tensors.

        Raises
        ------
        ValueError
            For any other ``self.method_inv``. Previously an UnboundLocalError.
        """
        if len(activity[deep]) == 0 :
            print('None')
            return(None)

        V_k = torch.cat([activity[deep], torch.ones((len(activity[deep]), 1), device = my_device)], dim = 1)
        GTG = torch.matmul(V_k.T, V_k)/V_k.shape[0]
        x_to_apply = torch.matmul(DV.T, V_k)/V_k.shape[0]

        if self.method_inv == 'approx' :
            return(inv_APPLY(GTG, x_to_apply))
        if self.method_inv == 'exact' :
            return(torch.matmul(x_to_apply, torch.linalg.pinv(GTG.cpu()).to(my_device)))
        raise ValueError('unknown method_inv: %r' % (self.method_inv,))
    
    def eval_expressivity_bottleneck_lin(self, deep, DV, method = 'LBFGS') :
        """Least-squares weight update for Linear layer ``deep`` matching ``DV``.

        ``activity[deep]`` is extended with a constant column and its second
        moment ``GTG`` is formed. With ``self.method_inv == 'exact'``, the
        columns whose diagonal entry in ``GTG`` is zero are dropped, the reduced
        normal equations are solved with a CPU pseudo-inverse, and the solution
        is scattered back with zeros at the dropped positions. With ``'approx'``
        the full system goes through ``inv_APPLY``. ``method`` is unused.
        Returns ``None`` when the recorded activity has no rows.
        """
        acts = activity[deep]
        if acts.shape[0] == 0 :
            print('None')
            return(None)

        V_k = torch.cat([acts, torch.ones((len(acts), 1), device = my_device)], dim = 1)
        n_rows = V_k.shape[0]
        GTG = torch.matmul(V_k.T, V_k) / n_rows
        keep = (torch.diagonal(GTG, 0) != 0).to(my_device)

        if self.method_inv == 'approx' :
            delta_w_star = inv_APPLY(GTG, torch.matmul(DV.T, V_k) / n_rows)
        elif self.method_inv == 'exact' :
            V_kept = V_k[:, keep]
            rhs = torch.matmul(DV.T, V_kept) / V_kept.shape[0]
            gram_inv = torch.linalg.pinv(GTG[keep][:, keep].cpu()).to(my_device)
            reduced = torch.matmul(rhs, gram_inv)
            print(reduced[0][:10], 'try to figure out')
            delta_w_star = torch.zeros((reduced.shape[0], keep.shape[0]), device = my_device)
            delta_w_star[:, keep] = reduced
        return(delta_w_star)
    
    
    
    def compute_M(self, layer) :
        """Return the matrix of input patches seen by convolution ``layer``.

        ``M[b, p, :]`` is the window of ``activity[layer][b]`` flattened in
        (channel, kernel-row, kernel-col) order. Its top-left corner is
        position ``p`` of the ``i_1 x i_2`` grid, enumerated row-major with
        stride 1 and no padding, where ``(i_1, i_2)`` come from
        ``self.outputs_size_before_activation[layer]``. The result has shape
        ``(batch, i_1 * i_2, C * k_1 * k_2)`` and lives on ``my_device`` with
        the default float dtype.

        Implemented with ``torch.nn.functional.unfold`` instead of a Python
        loop over positions; the values are identical.

        Raises
        ------
        ValueError
            If the ``i_1 x i_2`` grid does not fit inside the activity for this
            kernel size. Previously this surfaced as a shape mismatch inside the
            loop.
        """
        with torch.no_grad() :
            V_k = activity[layer]
            i_1, i_2, _ = self.outputs_size_before_activation[layer]
            k_1, k_2 = self.init_structure[layer]['kernel_size']
            h_out = V_k.shape[2] - k_1 + 1
            w_out = V_k.shape[3] - k_2 + 1
            if i_1 > h_out or i_2 > w_out :
                raise ValueError('compute_M: output grid (%d, %d) larger than valid window grid (%d, %d)' % (i_1, i_2, h_out, w_out))
            # unfold -> (batch, C * k_1 * k_2, h_out * w_out); each column is one window
            # flattened channel-major, i.e. torch.flatten(window, start_dim=1).
            patches = torch.nn.functional.unfold(V_k, (k_1, k_2))
            patches = patches.reshape(V_k.shape[0], patches.shape[1], h_out, w_out)[:, :, :i_1, :i_2]
            M = patches.flatten(start_dim = 2).transpose(1, 2)
            # Same device/dtype/layout as the zero buffer the old loop filled.
            return(M.to(device = my_device, dtype = torch.get_default_dtype()).contiguous())
    def eval_expressivity_bottleneck_conv(self, layer, DV) :
        """Least-squares kernel update for convolution ``layer`` that best reproduces ``DV``.

        The patch matrix from ``self.compute_M(layer)`` is flattened over
        (sample, position) and extended with a constant column for the bias.
        For every output channel ``k`` at once, the method then solves
        ``(MM^T MM) kernel = MM^T vec(DV[:, k])`` with a pseudo-inverse.

        Parameters
        ----------
        layer : int
            Index of the convolutional layer.
        DV : torch.Tensor
            Desired variation, indexed ``DV[:, k, :, :]`` per output channel.
            Its (sample, row, col) flattening must match the row order of
            ``compute_M``. Presumably ``(batch, out_channel, H_out, W_out)`` --
            TODO confirm against callers.

        Returns
        -------
        ``None`` if no activity is recorded. Otherwise
        ``(pseudo_inv, {'weight': (C_out, in_channel, k_1, k_2), 'bias': (C_out,)})``.

        Notes
        -----
        The result now lives on the same device as the computation.
        Previously ``delta_kernel`` was always allocated on the CPU.
        The emptiness test uses ``len()`` because ``tensor == []`` raises for
        non-empty tensors.
        """
        if len(activity[layer]) == 0 :
            print('None')
            return(None)
        with torch.no_grad() :
            M = self.compute_M(layer)
            i_3 = self.outputs_size_before_activation[layer][2]
            k_1, k_2 = self.init_structure[layer]['kernel_size']
            in_channel = self.init_structure[layer]['in_channel']

            MM = torch.flatten(M, start_dim = 0, end_dim = 1)
            MM = torch.cat([MM, torch.ones((MM.shape[0], 1), device = my_device)], dim = 1)

            MTM = torch.matmul(MM.T, MM)
            print('shape MTM ', MTM.shape, 'pour le layer ', layer)
            pseudo_inv = torch.linalg.pinv(MTM)

            # Column k holds DV[:, k].flatten(), i.e. rows ordered (sample, row, col).
            targets = DV[:, :i_3].transpose(0, 1).reshape(i_3, -1).T
            delta_kernel = torch.matmul(pseudo_inv, torch.matmul(MM.T, targets)).T
            del MM, MTM
            return(pseudo_inv, {'weight' : delta_kernel[:, : k_1 * k_2 * in_channel].reshape((delta_kernel.shape[0], in_channel, k_1, k_2)),
                                'bias' : delta_kernel[:, -1] })
            
            
            
            
            
    def compute_optimale_K_neurones_lin(self, deep, DV_proj) :
        """Whitened SVD used to choose new neurons for Linear layer ``deep``.

        ``activity[deep]`` is extended with a constant column, giving
        ``S = V^T V / n`` and ``N = V^T DV_proj / n``. Rows/columns of ``S``
        whose diagonal is ``<= 1e-3`` are dropped. For the rest,
        ``S_1demi = inv(cholesky(S))`` is computed; if the Cholesky factorisation
        fails, ``10 * I`` is added first. ``S_1demi`` is then expanded back to
        full size with zeros at the dropped indices.

        Returns
        -------
        (torch.linalg.svd(S_1demi @ N, full_matrices=False), S_1demi)

        The regulariser and the re-inserted zeros are now created on the device
        of ``S``; they were previously created on the CPU, which failed on GPU.
        The bare ``except`` is narrowed to ``RuntimeError``, which covers
        ``torch.linalg.LinAlgError``.
        """
        with torch.no_grad() :
            V_k_1 = torch.cat([activity[deep], torch.ones((len(activity[deep]), 1), device = my_device)], dim = 1)
            S = torch.matmul(V_k_1.T, V_k_1) / len(V_k_1)
            N = torch.matmul(V_k_1.T, DV_proj)/len(V_k_1)

            keep = torch.diag(S) > 1e-3
            S_reduced = S[keep][:, keep]

            try :
                S_1demi_reduced = torch.linalg.inv(torch.linalg.cholesky(S_reduced))
            except RuntimeError :
                eye = torch.eye(len(S_reduced), device = S_reduced.device, dtype = S_reduced.dtype)
                S_1demi_reduced = torch.linalg.inv(torch.linalg.cholesky(S_reduced + 10 * eye))

            # Scatter the reduced factor back; dropped rows/columns stay zero.
            kept = torch.nonzero(keep).flatten()
            S_1demi = torch.zeros_like(S)
            S_1demi[kept.unsqueeze(1), kept] = S_1demi_reduced
            del S, S_reduced
            return(torch.linalg.svd(torch.matmul(S_1demi, N), full_matrices = False), S_1demi)
        
    
    
    def compute_decay_upgrade(self, alpha, omega, bias_alpha, dee, ind, exp = 10, lowest_deplacement = 0.0001) :
        """Line search for the scale ``lambda`` applied to the new neurons (alpha, omega, bias_alpha).

        A helper model (``MyModel_copy_*``, chosen from the layer pair at
        ``dee``) is evaluated as ``model_copy(X, lmbda, dee)``. Presumably this
        is the network with the new neurons scaled by ``lmbda`` -- confirm in
        the ``MyModel_copy_*`` classes. The search runs in four phases:

        1. Starting from ``self.init_deplacement``, divide lambda by ``exp``
           while the loss is above the loss at lambda = 0.
        2. While the loss is exactly unchanged, multiply the step by ``exp``
           and try ``lambda = step``.
        3. Repeatedly grow lambda by increasing steps while the loss decreases,
           then step back and shrink the step by ``exp``.
        4. Take sign-of-gradient steps of size ``deplacement``. This and phase 3
           continue until ``deplacement <= lowest_deplacement`` or the loss
           stops changing.

        Returns ``lmbda_save``, the last lambda recorded before a loss decrease
        was observed (a tensor).

        NOTE(review): ``lowest_deplacement`` bounds only the outer loop; the
        start value is ``self.init_deplacement``, not a parameter.
        NOTE(review): if the loss rises for every positive lambda, phase 1 can
        shrink lambda/deplacement until they underflow to 0. Phase 2 then
        multiplies 0 by ``exp`` forever -- appears able to loop indefinitely.
        NOTE(review): ``model_copy`` is unbound (NameError) for a layer pair
        other than C/C, C/L or L/*.
        """
        init_deplacement = self.init_deplacement
        # Helper model matching the layer types around ``dee``.
        if self.layer_name[dee] == 'C' and self.layer_name[dee + 1] == 'C' :
            model_copy = MyModel_copy_C_C(self, alpha, omega, bias_alpha)
        if self.layer_name[dee] == 'C' and self.layer_name[dee + 1] == 'L' :
            model_copy = MyModel_copy_C_L(self, alpha, omega, bias_alpha)
        if self.layer_name[dee] == 'L' :
            model_copy = MyModel_copy_L_L(self, alpha, omega, bias_alpha)
        model_copy.to(my_device)
        lmbda = torch.tensor(init_deplacement, device = my_device)
        deplacement = torch.tensor(init_deplacement, device = my_device)
        with torch.no_grad() :
            # Reference loss with lambda = 0 and first trial at the initial step.
            L_save = model_copy.Loss(model_copy(X_train_rescale[ind],  torch.tensor(0., device = my_device), dee), Y_train_rescale[ind])
            next_L = model_copy.Loss(model_copy(X_train_rescale[ind],  lmbda, dee), Y_train_rescale[ind])
            # Phase 1: shrink lambda while it makes the loss worse.
            while next_L.item() > L_save.item() :
                lmbda = lmbda / exp
                deplacement = deplacement / exp
                next_L = model_copy.Loss(model_copy(X_train_rescale[ind], lmbda, dee), Y_train_rescale[ind])
                
                print('je descends avant l optimisation de lmbda :', lmbda)
            # Phase 2: grow the step while the loss is exactly unchanged.
            while (L_save.item() == next_L.item()) :
                deplacement = deplacement * exp
                lmbda = deplacement 
                next_L = model_copy.Loss(model_copy(X_train_rescale[ind],  lmbda, dee), Y_train_rescale[ind])
                print('je double', deplacement, 'L : ', L_save, 'next_L:', next_L, 'vrai Loss', self.Loss(self(X_train_rescale[ind]), Y_train_rescale[ind]))
            if next_L.item() > L_save.item() :
                lmbda = lmbda / exp
        diff_loss_0 = abs(L_save.item() - next_L.item())
        diff_loss = 10 * diff_loss_0
        lmbda_save = copy.deepcopy(lmbda)
        while deplacement > lowest_deplacement and diff_loss > 0.0 :
            with torch.no_grad() :
                # Phase 3: keep increasing lambda (with growing steps) while the loss decreases.
                # From here on L_save becomes a Python float (``.item()``).
                while next_L < L_save :
                    lmbda_save = copy.deepcopy(lmbda)
                    print('lmbda', lmbda, 'next_L :', next_L, 'vrai Loss', self.Loss(self(X_train_rescale[ind]), Y_train_rescale[ind]))
                    deplacement = deplacement * exp
                    lmbda = lmbda + deplacement
                    L_save = next_L.item()
                    next_L = model_copy.Loss(model_copy(X_train_rescale[ind], lmbda, dee),Y_train_rescale[ind])
                    diff_loss = abs(next_L.item() - L_save)
            # Step back and refine with a smaller step.
            lmbda = lmbda - deplacement/exp
            deplacement = deplacement / exp
            next_L = model_copy.Loss(model_copy(X_train_rescale[ind],  lmbda, dee), Y_train_rescale[ind])
            diff_loss = abs(next_L.item() - L_save)
            # Phase 4: sign-of-gradient step on lambda.
            # NOTE(review): next_L is re-evaluated at the *pre-step* lambda (equal to the one just
            # saved), so the condition is normally false after one pass -- confirm intent.
            while next_L.item() < L_save :
                lmbda_save = copy.deepcopy(lmbda)
                L_save = next_L.item()
                
                lmbda_tensor = torch.tensor(lmbda.clone().detach().float(), requires_grad = True, device = my_device)
                next_L = model_copy.Loss(model_copy(X_train_rescale[ind], lmbda_tensor, dee), Y_train_rescale[ind])
                grad = torch.sign(torch.autograd.grad(next_L, lmbda_tensor)[0]).item()
                lmbda = lmbda + grad * deplacement
                #print('lmbda', lmbda, 'next_L :', next_L, 'vrai Loss', self.Loss(self(X_train_rescale[ind]), Y_train_rescale[ind]), 'decrease')
            #print('deplacement :', deplacement, 'Loss :', next_L)
        
        del model_copy
        return(lmbda_save)
    
    
    def compute_natural_decay_upgrade(self, dee, natural_gradient, ind, exp = 10, lowest_deplacement = 0.0001) :
        """Line search for the step size ``lambda`` along ``natural_gradient``.

        Mirrors ``compute_decay_upgrade``, but uses ``MyModel_naturel(self,
        natural_gradient)`` as the helper model. Presumably
        ``model_copy(X, lmbda, dee)`` evaluates the network with the update
        scaled by ``lmbda`` -- confirm in ``MyModel_naturel``. Phases:
        shrink-while-worse, grow-while-unchanged, grow-while-improving with
        step refinement, then sign-of-gradient steps. The loop stops once
        ``deplacement <= lowest_deplacement`` or the loss stops changing.

        Differences from ``compute_decay_upgrade`` visible here:
        - The grow-while-unchanged phase recomputes the lambda = 0 loss on every
          iteration and applies the overshoot check inside that loop.
        - The final losses at ``lmbda`` and at ``lmbda_save`` are printed.

        Returns ``lmbda_save`` (a tensor).

        NOTE(review): the same potential infinite loop as in
        ``compute_decay_upgrade`` exists if lambda/deplacement underflow to 0.
        """
        init_deplacement = self.init_deplacement
        model_copy = MyModel_naturel(self, natural_gradient)
        model_copy.to(my_device)
        lmbda = torch.tensor(init_deplacement, device = my_device)
        deplacement = torch.tensor(init_deplacement, device = my_device)
        with torch.no_grad() :
            # Reference loss with lambda = 0 and first trial at the initial step.
            L_save = model_copy.Loss(model_copy(X_train_rescale[ind],  torch.tensor(0., device = my_device), dee), Y_train_rescale[ind])
            next_L = model_copy.Loss(model_copy(X_train_rescale[ind],  lmbda, dee), Y_train_rescale[ind])
            # Shrink lambda while it makes the loss worse.
            while next_L.item() > L_save.item() :
                lmbda = lmbda / exp
                deplacement = deplacement / exp
                next_L = model_copy.Loss(model_copy(X_train_rescale[ind],  lmbda, dee), Y_train_rescale[ind])
                #self.init_deplacement = deplacement
                print('je descends avant l optimisation de lmbda  natuirel:', lmbda)
            # Grow the step while the loss is exactly unchanged (reference loss recomputed each time).
            while (L_save.item() == next_L.item()) :
                deplacement = deplacement * exp
                lmbda = deplacement 
                L_save = model_copy.Loss(model_copy(X_train_rescale[ind],  torch.tensor(0., device = my_device), dee), Y_train_rescale[ind])
                next_L = model_copy.Loss(model_copy(X_train_rescale[ind], lmbda, dee), Y_train_rescale[ind])
                print('je double avant l optimisation de lmbda  natuirel:', lmbda, L_save, next_L)
                if next_L.item() > L_save.item() :
                    lmbda = lmbda / exp
        diff_loss_0 = abs(L_save.item() - next_L.item())
        diff_loss = 10 * diff_loss_0
        lmbda_save = copy.deepcopy(lmbda)
        while deplacement > lowest_deplacement and diff_loss > 0.0 :
            with torch.no_grad() :
                # Increase lambda with growing steps while the loss decreases (L_save becomes a float).
                while next_L < L_save :
                    lmbda_save = copy.deepcopy(lmbda)
                    deplacement = deplacement * exp
                    lmbda = lmbda + deplacement
                    L_save = next_L.item()
                    next_L = model_copy.Loss(model_copy(X_train_rescale[ind],  lmbda, dee), Y_train_rescale[ind])
                    diff_loss = abs(next_L.item() - L_save)
                    #print(lmbda, 'lmbda')
            # Step back and refine with a smaller step.
            lmbda = lmbda - deplacement/exp
            deplacement = deplacement / exp
            next_L = model_copy.Loss(model_copy(X_train_rescale[ind], lmbda , dee), Y_train_rescale[ind])
            diff_loss = abs(next_L.item() - L_save)
            # Sign-of-gradient step; as in compute_decay_upgrade, next_L is re-evaluated at the
            # pre-step lambda, so this appears to run at most once per outer iteration.
            while next_L.item() < L_save :
                lmbda_save = copy.deepcopy(lmbda)
                L_save = next_L.item()
                #lmbda_tensor = lmbda.requires_grad(True)
                lmbda_tensor = torch.tensor(lmbda.clone().detach().float(), requires_grad = True, device = my_device)
                next_L = model_copy.Loss(model_copy(X_train_rescale[ind], lmbda_tensor, dee), Y_train_rescale[ind])
                grad = torch.sign(torch.autograd.grad(next_L, lmbda_tensor)[0]).item()
                lmbda = lmbda + grad * deplacement
            #print('deplacement :', deplacement, 'diff_loss :', diff_loss, 'diff_loss_0dee :', diff_loss_0, 'Loss' , next_L)
        # Diagnostic: final losses at the current and at the returned lambda.
        with torch.no_grad() :
            print(model_copy.Loss(model_copy(X_train_rescale[ind], lmbda , dee), Y_train_rescale[ind]), 'derniere Loss')
            print(model_copy.Loss(model_copy(X_train_rescale[ind], lmbda_save , dee), Y_train_rescale[ind]), 'derniere Loss')
        del model_copy
        return(lmbda_save)
    
    
    def activities_input(self, module, i, o) :
        """Forward hook: store the detached first input of the hooked module in ``activity[self.where]``."""
        first_input = i[0]
        activity[self.where] = first_input.detach()
    
