from __future__ import print_function
import argparse
import ast
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import copy
import numpy as np
from sklearn.model_selection import StratifiedKFold

import plots.utils
import imdb_utils
#from plots.utils import *

# ---------------------------------------------------------------------------
# Experiment configuration (module-level switches read by the rest of the file)
# ---------------------------------------------------------------------------

# Training method variant: 0, 1 or 2.
method = 0

# Dataset name: 'mnist', 'fashion_mnist' or 'imdb'.
dataset = 'imdb'
# Model number; the epoch budget chosen below depends on it.
model_nr = 12
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

mini_batch = True
# For mini_batch=False a higher batch_size is numerically better.
batch_size = 128

cross_val = False
calculate_test_stats = True
calculate_train_stats = True

# Number of trainings of the same model; reinterpreted as the number of folds
# when cross_val is True (see the folds_num handling below).
trainings_same_model = 15
trainings_different_model = 1

separate_hyperparameters_for_optimized = False

lr_mul = 1

show_grad = False

seed = 11

# Epoch budget.  Each known model_nr has one value for method < 1 and one for
# method >= 1; any other model_nr trains for 500 epochs.  (The value currently
# does not depend on trainings_different_model.)
if model_nr == 6:
    epochs = 15
elif model_nr == 9:
    epochs = 50 if method >= 1 else 125
elif model_nr == 11:
    epochs = 40
elif model_nr == 12:
    epochs = 150 if method >= 1 else 200
else:
    epochs = 500

model_save = False
model_load = False

# Per-method settings.
if method == 1:
    # iter_count may be changed for method 1, but must stay >= 2.
    iter_count, with_optimizer_parameter_copy = 2, None
elif method == 2:
    # iter_count must stay 1 for method 2.  with_optimizer_parameter_copy
    # defaults to False; it is not known which setting is better.
    iter_count, with_optimizer_parameter_copy = 1, False
else:
    iter_count, with_optimizer_parameter_copy = -1, None

# With cross-validation the configured repeat count becomes the fold count and
# each model is trained once.
if cross_val:
    folds_num, trainings_same_model = trainings_same_model, 1
else:
    folds_num = None

# A value of 0 turns the denominator-based step rule off.
step_is_fraction_of_optimizer_denominator = 0.
denominator_mul = .99 if step_is_fraction_of_optimizer_denominator != 0. else None

# Floating point type: torch.float, torch.double, etc.
d_type = torch.float

def pretraining_finish_criterion(stats, epoch, threshold=55.):
    """Tell whether pre-training may stop after *epoch*.

    Args:
        stats: Mapping of statistic name to per-epoch sequences, or a falsy
            value (``None`` / empty mapping) when nothing has been recorded.
        epoch: Index into ``stats['train_accuracy']``.
        threshold: Minimum value of the last ``train_accuracy`` entry of the
            epoch that counts as finished (defaults to the former hard-coded
            55.0).

    Returns:
        ``True`` when ``stats['train_accuracy'][epoch][-1] >= threshold``;
        ``False`` otherwise, including when ``stats`` is falsy or has no
        ``'train_accuracy'`` key.  The result is always a ``bool`` (previously
        the falsy ``stats`` object itself could be returned).
    """
    if not stats or 'train_accuracy' not in stats:
        return False
    return bool(stats['train_accuracy'][epoch][-1] >= threshold)

# Optional pre-training phase; None disables it.  An enabled configuration is a
# dict such as:
# {'method':0,'method_after_pretraining':method,'pretraining_finish_criterion':pretraining_finish_criterion}
pretraining = None

if pretraining:
    # A freshly configured pre-training phase has not finished yet.
    pretraining['finished'] = False

# Switches for method=1.
# Correct the gradients propagated through nonlinear layers.
average_gradient_of_nonlinear_layers_enhancement = True
# Average the loss gradient of both models (method=1 with iter_count = 2);
# seems to have no effect.
average_gradient_of_loss = False

class NN(nn.Module):
    """Network built from an ordered list of layers.

    Keeps per-layer outputs and per-layer output gradients so that the custom
    backward passes defined on this class can correct gradients layer by layer.
    """

    # Deprecated: layer classes routed to gradient_calc().  The tuple is empty,
    # so isinstance() checks against it never match.
    # TODO: delete all classes that do not exist among model.layers.
    nonlinear_complex=()
    # Element-wise activation classes routed to gradient_calc_simple().
    nonlinear_simple=(torch.nn.SELU,torch.nn.ELU,torch.nn.Sigmoid,torch.nn.Tanh)

    # Interpolation weight between the original input (0) and the candidate
    # input (1) when perturbing a feature in gradient_calc() /
    # gradient_calc_last_layer2().
    step_factor=1
    # Weight of the finite-difference estimate when it is blended with the
    # existing gradient in gradient_calc() / gradient_calc_last_layer2().
    gradient_factor=1

    # Same blending weight, used by gradient_calc_simple().
    gradient_factor_simple_layers=1

    # Smallest |difference| for which a finite-difference quotient is formed
    # (float32 machine epsilon).
    small_num = (torch.finfo(torch.float32).eps)
    stats=None
    # Read as the input of layer 0 by the backward-correction methods; it is
    # assigned outside the code visible here.
    input_cached=None

    # The decrease multiplier below is the inverse of the increase multiplier
    # raised to this power.
    gradient_factor_how_many_times_faster_decreases_than_increases=4
    # Applying this multiplier 100 times doubles a value.
    # NOTE(review): the previous comment spoke of "2 times through 300
    # updates", which does not match the exponent 1/100 - confirm the intent.
    gradient_factor_increase_multipler=(2**(1/100))
    gradient_factor_decrease_multipler=1/gradient_factor_increase_multipler**gradient_factor_how_many_times_faster_decreases_than_increases
    def loss_signal(self,higher:bool):
        """Receive a "loss got higher / lower" signal; currently a no-op.

        The adaptive update of ``step_factor`` that used to follow (shrink by
        ``gradient_factor_decrease_multipler`` when *higher* is true, otherwise
        grow by ``gradient_factor_increase_multipler``, then cap at 0.25 and
        copy ``gradient_factor`` into ``gradient_factor_simple_layers``) sat
        behind an unconditional ``return`` and never executed, so the
        unreachable statements were removed.

        Args:
            higher: Whether the loss increased.  Ignored.

        Returns:
            None.  No attribute of the model is modified.
        """
        return

    def __init__(self,layers):
        super(NN, self).__init__()

        self.layers=layers
        self.layer_list=nn.ModuleList(self.layers)

        self.intermediate_outputs=[None for i in range(len(self.layers))]
        self.intermediate_outputs_grad=[None for i in range(len(self.layers))]


        if step_is_fraction_of_optimizer_denominator!=0.:
            for l in self.layers:
                if hasattr(l,'weight'):
                    l.weight.denominator=0.*l.weight
                    l.bias.denominator = 0. * l.bias
                    # l.weight.denominator.requires_grad=False
                    # l.bias.denominator.requires_grad = False
        #torch.autograd.set_detect_anomaly(True)


    @torch.no_grad()#copy of gradient_calc
    def gradient_calc_last_layer2(self,layer_input,layer_input_candidate,layer_input_candidate_grad,layer_out,layer_out_candidate_gradient,layer,loss_fn,target):
        input_mod=layer_input.detach().clone()
        # layer_input_candidate=layer_input_candidate.detach().clone()
        # layer_input_candidate+=0.001
        reshaped_input = layer_input.reshape((layer_input.size()[0], -1)).detach()
        reshaped_input_mod=input_mod.reshape((input_mod.size()[0],-1)).detach()
        reshaped_input_candidate=layer_input_candidate.reshape((input_mod.size()[0],-1)).detach()
        #reshaped_input_grad=layer_input.grad.reshape((layer_input.grad.size()[0],-1)).detach()
        reshaped_input_grad=layer_input_candidate_grad.reshape((layer_input_candidate_grad.size()[0],-1)).detach()

        reshaped_output_candidate_gradient=layer_out_candidate_gradient.reshape((layer_out_candidate_gradient.size()[0],-1)).detach()
        #last_index=-1
        #out=torch.empty_like(layer_input)
        for i in range(reshaped_input_mod.size()[1]):
            if i>0:
                reshaped_input_mod[:,i-1]=reshaped_input[:,i-1].detach()


            out1 = layer(input_mod)
            loss1 = loss_fn(out1, target)

            #reshaped_input_mod[:, i] = reshaped_input[:, i].detach() + 0.0001 #test
            #reshaped_input_mod[:, i] = reshaped_input_candidate[:, i].detach() #original one
            reshaped_input_mod[:, i] = reshaped_input_candidate[:, i].detach()*self.step_factor+reshaped_input[:, i]*(1.-self.step_factor)

            out = layer(input_mod)
            # out=nn.LogSoftmax(dim=1)(input_mod.detach().clone())
            loss2 = loss_fn(out, target)

            for j in range(reshaped_input_grad.size()[0]):
                diff=reshaped_input_mod[j,i]-reshaped_input[j,i]
                if torch.abs(diff)>self.small_num:
                    #reshaped_input_grad[j,i]=torch.sum((out[j].detach()-layer_out[j].detach())*layer_out.grad[j].detach()/diff)
                    reshaped_input_grad[j, i] =reshaped_input_grad[j, i]*(1-self.gradient_factor)+self.gradient_factor* (loss2 - loss1) / diff.detach()
                #else:
                #    print("Diff<=very small number")
    @torch.no_grad()
    def gradient_calc_last_layer(self, layer_input, layer_input_candidate, layer_input_candidate_grad, layer_out,
                      layer_out_candidate_gradient, layer,loss_fn,target):#doesn't work well
        input_mod = layer_input.detach().clone()
        # layer_input_candidate=layer_input_candidate.detach().clone()
        # layer_input_candidate+=0.001
        reshaped_input = layer_input.reshape((layer_input.size()[0], -1)).detach()
        reshaped_input_mod = input_mod.reshape((input_mod.size()[0], -1)).detach()
        reshaped_input_candidate = layer_input_candidate.reshape((input_mod.size()[0], -1)).detach()
        # reshaped_input_grad=layer_input.grad.reshape((layer_input.grad.size()[0],-1)).detach()
        reshaped_input_grad = layer_input_candidate_grad.reshape((layer_input_candidate_grad.size()[0], -1)).detach()

        reshaped_output_candidate_gradient = layer_out_candidate_gradient.reshape(
            (layer_out_candidate_gradient.size()[0], -1)).detach()
        # last_index=-1
        # out=torch.empty_like(layer_input)


        for i in range(reshaped_input_mod.size()[1]):
            if i > 0:
                reshaped_input_mod[:, i - 1] = reshaped_input[:, i - 1].detach()

            out1 = layer(input_mod)
            loss1 = loss_fn(out1, target)

            # reshaped_input_mod[:, i] = reshaped_input[:, i].detach() + 0.0001 #test
            #reshaped_input_mod[:, i] = reshaped_input_candidate[:, i].detach() #original one
            reshaped_input_mod[:, i] = reshaped_input_candidate[:, i].detach() * 0.5 + reshaped_input[:, i] * 0.5


            out = layer(input_mod)
            # out=nn.LogSoftmax(dim=1)(input_mod.detach().clone())
            loss2 = loss_fn(out,target)

            for j in range(reshaped_input_grad.size()[0]):
                diff = reshaped_input_mod[j, i] - reshaped_input[j, i]
                if torch.abs(diff) > self.small_num:
                    # reshaped_input_grad[j,i]=torch.sum((out[j].detach()-layer_out[j].detach())*layer_out.grad[j].detach()/diff)
                    #reshaped_input_grad[j, i] = reshaped_input_grad[j, i] * 0.5 + 0.5 * (loss2-loss1) / diff.detach()
                    reshaped_input_grad[j, i] =reshaped_input_grad[j, i]*0.5+0.5* (loss2 - loss1) / diff.detach()
                # else:
                #     print("Diff<=very small number")


    @torch.no_grad()
    def gradient_calc_simple(self,layer_input,layer_input_candidate,layer_input_candidate_grad,layer_out,layer_out_candidate_gradient,layer):
        # #input_mod=layer_input.detach().clone()
        # # layer_input_candidate=layer_input_candidate.detach().clone()
        # # layer_input_candidate+=0.001
        # reshaped_input = layer_input.reshape((layer_input.size()[0], -1)).detach()
        # #reshaped_input_mod=input_mod.reshape((input_mod.size()[0],-1)).detach()
        # reshaped_input_candidate=layer_input_candidate.reshape((layer_input_candidate.size()[0],-1)).detach()
        # #reshaped_input_grad=layer_input.grad.reshape((layer_input.grad.size()[0],-1)).detach()
        # reshaped_input_grad=layer_input_candidate_grad.reshape((layer_input_candidate_grad.size()[0],-1)).detach()
        #
        # reshaped_output_candidate_gradient=layer_out_candidate_gradient.reshape((layer_out_candidate_gradient.size()[0],-1)).detach()
        # #last_index=-1
        # #out=torch.empty_like(layer_input)
        out = layer(layer_input_candidate)

        layer_input_candidate_grad[:]=torch.where(torch.abs(layer_input_candidate-layer_input)>self.small_num,
                                                  layer_input_candidate_grad*(1.-self.gradient_factor_simple_layers)+self.gradient_factor_simple_layers*(out-layer_out)*layer_out_candidate_gradient/(layer_input_candidate-layer_input),
                                                  layer_input_candidate_grad)

        # for i in range(reshaped_input.size()[1]):#that code maybe works
        #
        #     for j in range(reshaped_input_grad.size()[0]):
        #         diff=reshaped_input_candidate[j,i]-reshaped_input[j,i]
        #         if torch.abs(diff)>self.small_num:
        #             #reshaped_input_grad[j,i]=torch.sum((out[j].detach()-layer_out[j].detach())*layer_out.grad[j].detach()/diff)
        #             reshaped_input_grad[j,i]=reshaped_input_grad[j,i].detach()*(1.-self.gradient_factor)+self.gradient_factor*((out[j,i].detach()-layer_out[j,i].detach())*reshaped_output_candidate_gradient[j,i].detach()/diff.detach())
        #         #else:
        #         #    print("Diff<=very small number")
        #         #else default gradient value is kept
    @torch.no_grad()
    def gradient_calc(self,layer_input,layer_input_candidate,layer_input_candidate_grad,layer_out,layer_out_candidate_gradient,layer):
        input_mod=layer_input.detach().clone()
        # layer_input_candidate=layer_input_candidate.detach().clone()
        # layer_input_candidate+=0.001
        reshaped_input = layer_input.reshape((layer_input.size()[0], -1)).detach()
        reshaped_input_mod=input_mod.reshape((input_mod.size()[0],-1)).detach()
        reshaped_input_candidate=layer_input_candidate.reshape((input_mod.size()[0],-1)).detach()
        #reshaped_input_grad=layer_input.grad.reshape((layer_input.grad.size()[0],-1)).detach()
        reshaped_input_grad=layer_input_candidate_grad.reshape((layer_input_candidate_grad.size()[0],-1)).detach()

        reshaped_output_candidate_gradient=layer_out_candidate_gradient.reshape((layer_out_candidate_gradient.size()[0],-1)).detach()
        #last_index=-1
        #out=torch.empty_like(layer_input)
        for i in range(reshaped_input_mod.size()[1]):
            if i>0:
                reshaped_input_mod[:,i-1]=reshaped_input[:,i-1].detach()

            #reshaped_input_mod[:, i] = reshaped_input[:, i].detach() + 0.0001 #test
            #reshaped_input_mod[:, i] = reshaped_input_candidate[:, i].detach() #original one
            reshaped_input_mod[:, i] = reshaped_input_candidate[:, i].detach()*self.step_factor+reshaped_input[:, i]*(1.-self.step_factor)

            out=layer(input_mod)
            #out=nn.LogSoftmax(dim=1)(input_mod.detach().clone())

            for j in range(reshaped_input_grad.size()[0]):
                diff=reshaped_input_mod[j,i]-reshaped_input[j,i]
                if torch.abs(diff)>self.small_num:
                    #reshaped_input_grad[j,i]=torch.sum((out[j].detach()-layer_out[j].detach())*layer_out.grad[j].detach()/diff)
                    reshaped_input_grad[j,i]=reshaped_input_grad[j,i]*(1.-self.gradient_factor)+self.gradient_factor*torch.sum((out[j].detach()-layer_out[j].detach())*reshaped_output_candidate_gradient[j].detach()/diff.detach())
                # else:
                #     print("Diff<=very small number")
                #else default gradient value is kept

            #reshaped_input_grad[:, i]=1
            #last_index=i

        #reshaped_input_mod[:, reshaped_input_mod.size()[1]-1] = reshaped_input[:, reshaped_input_mod.size()[1]-1].detach()
        #reshaped_input_grad._version-=1
    def backward_grad_correction_with_weight_change2(self,loss,model2,weight_change=True,accumulate_gradients=False):
        """Manual layer-by-layer backward pass with gradient correction and optional weight update.

        Walks the layers from last to first.  For each layer it
        (1) fills ``weight.grad`` / ``bias.grad`` of parametrised layers,
        (2) propagates the gradient to the previous layer's cached output and,
        when ``average_gradient_of_nonlinear_layers_enhancement`` is set,
        corrects it using ``model2``'s cached outputs, and
        (3) if *weight_change* is true, updates the layer's weight and bias in
        place relative to ``model2``'s parameters.

        Reads the module-level switches ``average_gradient_of_loss``,
        ``average_gradient_of_nonlinear_layers_enhancement``,
        ``step_is_fraction_of_optimizer_denominator`` and ``denominator_mul``.

        Args:
            loss: Tensor differentiated w.r.t. the last cached intermediate
                output.
            model2: Second model whose ``layers``, ``intermediate_outputs`` and
                (depending on the switches) ``intermediate_outputs_grad`` and
                parameter ``.grad`` values are read.
            weight_change: Whether to modify this model's weights and biases.
            accumulate_gradients: Add to already present ``.grad`` values
                instead of overwriting them.

        Returns:
            None.

        NOTE(review): ``self.intermediate_outputs`` and ``self.input_cached``
        must already be populated (presumably by a forward pass that is not
        visible here), and every layer that has a ``weight`` is assumed to
        also have a non-None ``bias``.
        """
        start_ind = len(self.intermediate_outputs) - 1
        ind = start_ind
        # Gradient of the loss w.r.t. the output of the last layer.
        grad = torch.autograd.grad(outputs=loss,
                                   inputs=self.intermediate_outputs[start_ind],
                                   # inputs=self.intermediate_outputs[0],
                                   grad_outputs=None
                                   , retain_graph=True)
        self.intermediate_outputs_grad[ind] = grad[0]

        if average_gradient_of_loss:
            # Average with the gradient stored by model2 for the same layer.
            self.intermediate_outputs_grad[ind]=0.5*(grad[0]+model2.intermediate_outputs_grad[ind])

        for ind in range(start_ind, -1, -1):
            if hasattr(self.layers[ind], 'weight'):
                # Re-run this layer on its cached input so that a fresh graph
                # node exists for the autograd calls below (presumably needed
                # because parameters are modified in place by this method -
                # TODO confirm).
                if ind>0:
                    self.intermediate_outputs[ind] = self.layers[ind](self.intermediate_outputs[ind-1])
                else:
                    self.intermediate_outputs[ind]=self.layers[ind](self.input_cached)

                # Parameter gradients from the (possibly corrected) gradient
                # w.r.t. this layer's output.
                grad_outputs = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                                   inputs=[self.layers[ind].weight,self.layers[ind].bias],
                                                   # inputs=self.intermediate_outputs[0],
                                                   grad_outputs=self.intermediate_outputs_grad[ind]
                                                   , retain_graph=True)

                if accumulate_gradients and self.layers[ind].weight.grad is not None:
                    self.layers[ind].weight.grad += grad_outputs[0]
                    self.layers[ind].bias.grad+=grad_outputs[1]
                else:
                    self.layers[ind].weight.grad = grad_outputs[0]
                    self.layers[ind].bias.grad = grad_outputs[1]

            if ind > 0:
                # Propagate the gradient to the output of the previous layer.
                grad_layers = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                                  inputs=self.intermediate_outputs[ind - 1],
                                                  # inputs=self.intermediate_outputs[0],
                                                  grad_outputs=self.intermediate_outputs_grad[ind]
                                                  , retain_graph=True,
                                                  # allow_unused=True
                                                  )
                self.intermediate_outputs_grad[ind - 1] = grad_layers[0]

                if average_gradient_of_nonlinear_layers_enhancement:

                    # Deprecated path: nonlinear_complex is an empty tuple, so
                    # this isinstance() check never matches.
                    if isinstance(self.layers[ind], self.nonlinear_complex):#deprecated
                        log=False
                        if log:
                            print("----")
                            print(self.intermediate_outputs_grad[ind - 1])
                        self.gradient_calc(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                                           self.intermediate_outputs_grad[ind - 1],
                                           self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                                           self.layers[ind])
                        if log:
                            print(self.intermediate_outputs_grad[ind - 1])

                    # Element-wise activations: correct the propagated gradient
                    # in place, using model2's cached input of this layer as
                    # the candidate point.
                    elif isinstance(self.layers[ind], self.nonlinear_simple):
                        self.gradient_calc_simple(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                                           self.intermediate_outputs_grad[ind - 1],
                                           self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                                           self.layers[ind])

            if weight_change and hasattr(self.layers[ind], 'weight'):
                # requires_grad is switched off for the duration of the
                # in-place parameter writes and restored afterwards.
                self.layers[ind].weight.requires_grad=False
                self.layers[ind].bias.requires_grad=False

                if step_is_fraction_of_optimizer_denominator!=0.:
                    # Running accumulator: squared gradient difference between
                    # the two models plus the previous value scaled by
                    # denominator_mul.
                    diff_w=self.layers[ind].weight.grad.detach()-model2.layers[ind].weight.grad.detach()
                    self.layers[ind].weight.denominator=diff_w*diff_w+denominator_mul*self.layers[ind].weight.denominator.detach()
                    diff_b=self.layers[ind].bias.grad.detach()-model2.layers[ind].bias.grad.detach()
                    self.layers[ind].bias.denominator = diff_b*diff_b+denominator_mul*self.layers[ind].bias.denominator.detach()

                    # New value = model2's parameter plus an offset of
                    #   clip(grad_diff / denominator, -s, s)
                    #     * (model2_param - own_param) * |denominator / model2_grad|
                    # with s = step_is_fraction_of_optimizer_denominator.  The
                    # offset is 0 wherever the denominator or model2's gradient
                    # is zero.
                    self.layers[ind].weight[:] = model2.layers[ind].weight.detach()+torch.where(self.layers[ind].weight.denominator!=0.,torch.where(
                        model2.layers[ind].weight.grad!=0.,
                        torch.clip(
                            (self.layers[ind].weight.grad.detach()-model2.layers[ind].weight.grad.detach())/self.layers[ind].weight.denominator.detach(),
                            #1./(1.+step_is_fraction_of_optimizer_denominator)-1.,
                            -step_is_fraction_of_optimizer_denominator,
                            step_is_fraction_of_optimizer_denominator)\
                            *(model2.layers[ind].weight.detach()-self.layers[ind].weight.detach())*torch.abs(self.layers[ind].weight.denominator.detach()/model2.layers[ind].weight.grad.detach()),
                        0.),0.)
                    self.layers[ind].bias[:] = model2.layers[ind].bias.detach() + torch.where(self.layers[ind].bias.denominator!=0.,torch.where(
                        model2.layers[ind].bias.grad != 0.,
                        torch.clip(
                            (self.layers[ind].bias.grad.detach() - model2.layers[ind].bias.grad.detach()) /
                            self.layers[ind].bias.denominator.detach(),
                            #1. / (1. + step_is_fraction_of_optimizer_denominator) - 1.,
                             -step_is_fraction_of_optimizer_denominator,
                            step_is_fraction_of_optimizer_denominator) \
                        * (model2.layers[ind].bias.detach() - self.layers[ind].bias.detach()) * torch.abs(
                            self.layers[ind].bias.denominator.detach() / model2.layers[ind].bias.grad.detach()),
                        0.),0.)

                else:
                    # Default rule: move every parameter by its element-wise
                    # distance to model2's parameter, against the sign of its
                    # own gradient.
                    self.layers[ind].weight[:] -= torch.abs(
                        self.layers[ind].weight.detach() - model2.layers[ind].weight.detach()) * torch.sign(
                        self.layers[ind].weight.grad.detach())
                    self.layers[ind].bias[:] -= torch.abs(
                        self.layers[ind].bias.detach() - model2.layers[ind].bias.detach()) * torch.sign(
                        self.layers[ind].bias.grad.detach())

                self.layers[ind].weight.requires_grad = True
                self.layers[ind].bias.requires_grad = True

    def backward_grad_correction_with_weight_change(self,loss,model2,loss_fn,target):
        """Manually backpropagate ``loss`` layer by layer and move the weights of each layer.

        The layers are visited from ``start_ind`` down to 0. For every layer
        that has a ``weight`` attribute the layer output is recomputed from the
        cached previous activation (``self.input_cached`` for layer 0), the
        weight/bias gradients are obtained with ``torch.autograd.grad`` and
        stored in ``.grad``. The gradient w.r.t. the previous activation is
        stored in ``self.intermediate_outputs_grad[ind - 1]``. Finally the
        layer's weight and bias are shifted in place by ``|self - model2|``
        in the direction opposite to the sign of their gradient.

        Reads the module-level flag
        ``average_gradient_of_nonlinear_layers_enhancement``.

        Args:
            loss: Loss tensor to differentiate; must be a scalar because
                ``grad_outputs=None`` is passed to ``torch.autograd.grad``.
            model2: Second network; its layer weights/biases and its
                ``intermediate_outputs`` are read, nothing is written to it.
            loss_fn: When it equals ``F.cross_entropy`` the walk starts at the
                second-to-last cached output instead of the last one.
            target: Not used by the active code (only by commented-out calls).
        """
        # NOTE(review): the "- 2" start presumably mirrors forward(), which does not
        # execute a trailing Softmax layer in training mode -- confirm.
        start_ind = len(self.intermediate_outputs) - 2 if loss_fn==F.cross_entropy else len(self.intermediate_outputs) - 1
        ind = start_ind#len(self.intermediate_outputs) - 1
        # Gradient of the loss w.r.t. the output the walk starts from.
        grad = torch.autograd.grad(outputs=loss,
                                   inputs=self.intermediate_outputs[start_ind],
                                   # inputs=self.intermediate_outputs[0],
                                   grad_outputs=None
                                   , retain_graph=True)
        self.intermediate_outputs_grad[ind] = grad[0]
        for ind in range(start_ind, -1, -1):
            if hasattr(self.layers[ind], 'weight'):
                # Re-run the layer so the cached output is connected to the layer's
                # current weight/bias before differentiating w.r.t. them.
                if ind>0:
                    self.intermediate_outputs[ind] = self.layers[ind](self.intermediate_outputs[ind-1])
                else:
                    self.intermediate_outputs[ind]=self.layers[ind](self.input_cached)
                grad_outputs = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                                   inputs=[self.layers[ind].weight,self.layers[ind].bias],
                                                   # inputs=self.intermediate_outputs[0],
                                                   grad_outputs=self.intermediate_outputs_grad[ind]
                                                   , retain_graph=True)
                self.layers[ind].weight.grad = grad_outputs[0]
                self.layers[ind].bias.grad=grad_outputs[1]

                #self.layers[ind].weight = model2.layers[ind].weight#.clone().detach()
                #self.layers[ind].bias = model2.layers[ind].bias#.clone().detach()

                #self.layers[ind].weight.requires_grad=False
                #self.layers[ind].bias.requires_grad=False
                #self.layers[ind].weight[:]=1#torch.abs(self.layers[ind].weight-model2.layers[ind].weight)*torch.sign(self.layers[ind].weight.grad)
                #self.layers[ind].bias[:]=1#torch.abs(self.layers[ind].bias-model2.layers[ind].bias)*torch.sign(self.layers[ind].bias.grad)
                # self.layers[ind].weight.requires_grad = True
                # self.layers[ind].bias.requires_grad = True

                #model2.layers[ind].weight.grad=grad_outputs[0].detach()#.clone()
                #model2.layers[ind].bias.grad=grad_outputs[1].detach()

                # if ind > 0:
                #     grad_layers = torch.autograd.grad(outputs=self.layers[ind].weight,
                #                                       inputs=self.intermediate_outputs[ind - 1],
                #                                       # inputs=self.intermediate_outputs[0],
                #                                       grad_outputs=self.layers[ind].weight.grad
                #                                       , retain_graph=True)
                #     self.intermediate_outputs[ind - 1].grad = grad_layers[0]
            if ind > 0:
                #self.intermediate_outputs[ind]+=0.
                #self.intermediate_outputs[ind - 1]+=0.
                # Propagate the gradient one layer down: d(output[ind]) / d(output[ind - 1]).
                grad_layers = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                                  inputs=self.intermediate_outputs[ind - 1],
                                                  # inputs=self.intermediate_outputs[0],
                                                  grad_outputs=self.intermediate_outputs_grad[ind]
                                                  , retain_graph=True,
                                                  # allow_unused=True
                                                  )
                self.intermediate_outputs_grad[ind - 1] = grad_layers[0]

                # if ind==len(self.intermediate_outputs)-1:
                #     self.gradient_calc(self.intermediate_outputs[ind - 1],self.intermediate_outputs[ind - 1],
                #                        self.intermediate_outputs[ind],self.layers[ind])
                # pass

                #enhanced=average_gradient_of_nonlinear_layers_enhancement
                if average_gradient_of_nonlinear_layers_enhancement:
                    # The helpers' return values are ignored here.
                    # NOTE(review): presumably they adjust intermediate_outputs_grad[ind - 1]
                    # in place using model2's activations -- confirm in gradient_calc*.
                    if isinstance(self.layers[ind], self.nonlinear_complex):
                        # self.gradient_calc(model2.intermediate_outputs[ind - 1], self.intermediate_outputs[ind - 1],
                        #                    model2.intermediate_outputs[ind],self.intermediate_outputs[ind].grad, self.layers[ind])

                        # self.gradient_calc(model2.intermediate_outputs[ind - 1], self.intermediate_outputs[ind - 1],
                        #                    self.intermediate_outputs_grad[ind - 1],
                        #                    model2.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                        #                    self.layers[ind])
                        # Debug switch: set to True to print the gradient before and after the helper call.
                        log=False
                        if log:
                            print("----")
                            print(self.intermediate_outputs_grad[ind - 1])
                        self.gradient_calc(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                                           self.intermediate_outputs_grad[ind - 1],
                                           self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                                           self.layers[ind])
                        # self.gradient_calc_last_layer2(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                        #                    self.intermediate_outputs_grad[ind - 1],
                        #                    self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                        #                    self.layers[ind],loss_fn,target)
                        # self.intermediate_outputs_grad[ind - 1]=0.5*self.intermediate_outputs_grad[ind - 1]+\
                        #                                         model2.intermediate_outputs[ind - 1].grad * 0.5
                        # self.gradient_calc_last_layer(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                        #                    self.intermediate_outputs_grad[ind - 1],
                        #                    self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                        #                    self.layers[ind],loss_fn,target)
                        if log:
                            print(self.intermediate_outputs_grad[ind - 1])
                        #self.intermediate_outputs_grad[ind - 1]=self.intermediate_outputs_grad[ind - 1]*torch.norm(model2.intermediate_outputs_grad[ind - 1],1)/torch.norm(self.intermediate_outputs_grad[ind - 1],1)

                    elif isinstance(self.layers[ind], self.nonlinear_simple):
                        self.gradient_calc_simple(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                                           self.intermediate_outputs_grad[ind - 1],
                                           self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                                           self.layers[ind])
                        #self.intermediate_outputs_grad[ind - 1]=self.intermediate_outputs_grad[ind - 1]*torch.norm(model2.intermediate_outputs_grad[ind - 1],1)/torch.norm(self.intermediate_outputs_grad[ind - 1],1)
                #self.intermediate_outputs[ind - 1].grad.detach()

            if hasattr(self.layers[ind], 'weight'):
                #with torch.no_grad():
                # requires_grad is switched off around the in-place slice assignment because
                # autograd rejects in-place modification of a leaf tensor that requires grad.
                self.layers[ind].weight.requires_grad=False
                self.layers[ind].bias.requires_grad=False
                # Step size is the element-wise distance to model2's parameter; the direction
                # is the negative sign of this model's gradient.
                self.layers[ind].weight[:]-=torch.abs(self.layers[ind].weight.detach()-model2.layers[ind].weight.detach())*torch.sign(self.layers[ind].weight.grad.detach())
                self.layers[ind].bias[:]-=torch.abs(self.layers[ind].bias.detach()-model2.layers[ind].bias.detach())*torch.sign(self.layers[ind].bias.grad.detach())
                #self.layers[ind].weight[:] =model2.layers[ind].weight.detach()
                # self.layers[ind].bias[:] = model2.layers[ind].bias.detach()
                self.layers[ind].weight.requires_grad = True
                self.layers[ind].bias.requires_grad = True

    def backward_grad_correction(self,loss,model2,loss_fn,target):
        """Manually backpropagate ``loss`` layer by layer and mirror the gradients into ``model2``.

        The layers are visited from the last cached output down to 0. For every
        layer that has a ``weight`` attribute the layer output is recomputed
        from the cached previous activation (``self.input_cached`` for layer
        0), the weight/bias gradients are obtained with
        ``torch.autograd.grad``, stored in this model's ``.grad`` and also
        assigned (detached) to the matching layer of ``model2``. The gradient
        w.r.t. the previous activation is stored in
        ``self.intermediate_outputs_grad[ind - 1]``. No weights are changed.

        Reads the module-level flag
        ``average_gradient_of_nonlinear_layers_enhancement``.

        Args:
            loss: Loss tensor to differentiate; must be a scalar because
                ``grad_outputs=None`` is passed to ``torch.autograd.grad``.
            model2: Second network; its ``intermediate_outputs`` are read and
                the ``.grad`` of its layers' weight/bias is overwritten.
            loss_fn: Not used by the active code (only by commented-out calls).
            target: Not used by the active code (only by commented-out calls).
        """
        ind = len(self.intermediate_outputs) - 1
        # Gradient of the loss w.r.t. the last cached output.
        grad = torch.autograd.grad(outputs=loss,
                                   inputs=self.intermediate_outputs[len(self.intermediate_outputs) - 1],
                                   # inputs=self.intermediate_outputs[0],
                                   grad_outputs=None
                                   , retain_graph=True)
        self.intermediate_outputs_grad[ind] = grad[0]
        for ind in range(len(self.intermediate_outputs) - 1, -1, -1):
            if hasattr(self.layers[ind], 'weight'):
                # Re-run the layer so the cached output is connected to the layer's
                # current weight/bias before differentiating w.r.t. them.
                if ind>0:
                    self.intermediate_outputs[ind] = self.layers[ind](self.intermediate_outputs[ind-1])
                else:
                    self.intermediate_outputs[ind]=self.layers[ind](self.input_cached)
                grad_outputs = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                                   inputs=[self.layers[ind].weight,self.layers[ind].bias],
                                                   # inputs=self.intermediate_outputs[0],
                                                   grad_outputs=self.intermediate_outputs_grad[ind]
                                                   , retain_graph=True)
                self.layers[ind].weight.grad = grad_outputs[0]
                self.layers[ind].bias.grad=grad_outputs[1]

                # detach() does not copy, so model2's .grad shares storage with this model's .grad.
                model2.layers[ind].weight.grad=grad_outputs[0].detach()#.clone()
                model2.layers[ind].bias.grad=grad_outputs[1].detach()

                # if ind > 0:
                #     grad_layers = torch.autograd.grad(outputs=self.layers[ind].weight,
                #                                       inputs=self.intermediate_outputs[ind - 1],
                #                                       # inputs=self.intermediate_outputs[0],
                #                                       grad_outputs=self.layers[ind].weight.grad
                #                                       , retain_graph=True)
                #     self.intermediate_outputs[ind - 1].grad = grad_layers[0]
            if ind > 0:
                # Propagate the gradient one layer down: d(output[ind]) / d(output[ind - 1]).
                grad_layers = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                                  inputs=self.intermediate_outputs[ind - 1],
                                                  # inputs=self.intermediate_outputs[0],
                                                  grad_outputs=self.intermediate_outputs_grad[ind]
                                                  , retain_graph=True,
                                                  #allow_unused=True
                                                  )
                self.intermediate_outputs_grad[ind - 1] = grad_layers[0]

                # if ind==len(self.intermediate_outputs)-1:
                #     self.gradient_calc(self.intermediate_outputs[ind - 1],self.intermediate_outputs[ind - 1],
                #                        self.intermediate_outputs[ind],self.layers[ind])
                # pass

                #enhanced=True
                if average_gradient_of_nonlinear_layers_enhancement:
                    # The helpers' return values are ignored here.
                    # NOTE(review): presumably they adjust intermediate_outputs_grad[ind - 1]
                    # in place using model2's activations -- confirm in gradient_calc*.
                    if isinstance(self.layers[ind], self.nonlinear_complex):
                        # self.gradient_calc(model2.intermediate_outputs[ind - 1], self.intermediate_outputs[ind - 1],
                        #                    model2.intermediate_outputs[ind],self.intermediate_outputs[ind].grad, self.layers[ind])

                        # self.gradient_calc(model2.intermediate_outputs[ind - 1], self.intermediate_outputs[ind - 1],
                        #                    self.intermediate_outputs_grad[ind - 1],
                        #                    model2.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                        #                    self.layers[ind])
                        # Debug switch: set to True to print the gradient before and after the helper call.
                        log=False
                        if log:
                            print("----")
                            print(self.intermediate_outputs_grad[ind - 1])
                        self.gradient_calc(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                                           self.intermediate_outputs_grad[ind - 1],
                                           self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                                           self.layers[ind])
                        # self.gradient_calc_last_layer2(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                        #                    self.intermediate_outputs_grad[ind - 1],
                        #                    self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                        #                    self.layers[ind],loss_fn,target)
                        # self.intermediate_outputs_grad[ind - 1]=0.5*self.intermediate_outputs_grad[ind - 1]+\
                        #                                         model2.intermediate_outputs[ind - 1].grad * 0.5
                        # self.gradient_calc_last_layer(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                        #                    self.intermediate_outputs_grad[ind - 1],
                        #                    self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                        #                    self.layers[ind],loss_fn,target)
                        if log:
                            print(self.intermediate_outputs_grad[ind - 1])
                        #self.intermediate_outputs_grad[ind - 1]=self.intermediate_outputs_grad[ind - 1]*torch.norm(model2.intermediate_outputs_grad[ind - 1],1)/torch.norm(self.intermediate_outputs_grad[ind - 1],1)

                    elif isinstance(self.layers[ind], self.nonlinear_simple):
                        self.gradient_calc_simple(self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                                           self.intermediate_outputs_grad[ind - 1],
                                           self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                                           self.layers[ind])
                        #self.intermediate_outputs_grad[ind - 1]=self.intermediate_outputs_grad[ind - 1]*torch.norm(model2.intermediate_outputs_grad[ind - 1],1)/torch.norm(self.intermediate_outputs_grad[ind - 1],1)
                #self.intermediate_outputs[ind - 1].grad.detach()

    def backward(self,loss):
        """Backpropagate ``loss`` by hand, one layer at a time.

        Starting at the last cached layer output, the gradient of ``loss`` is
        stored in that output's ``.grad``. Walking the layers in reverse, each
        layer with a ``weight`` attribute gets ``weight.grad`` / ``bias.grad``
        filled in, and the gradient is pushed down into the ``.grad`` of the
        previous cached output. The graph is retained throughout.

        When the module-level flag ``average_gradient_of_loss`` is truthy the
        loss gradient is additionally stored in
        ``self.intermediate_outputs_grad`` at the last index.

        Args:
            loss: Loss tensor; must be a scalar because no ``grad_outputs`` is
                supplied for the first differentiation.
        """
        cached = self.intermediate_outputs
        last = len(cached) - 1

        loss_grad = torch.autograd.grad(
            outputs=loss,
            inputs=cached[last],
            grad_outputs=None,
            retain_graph=True,
        )[0]
        cached[last].grad = loss_grad
        if average_gradient_of_loss:
            self.intermediate_outputs_grad[last] = loss_grad

        for idx in reversed(range(last + 1)):
            layer = self.layers[idx]

            if hasattr(layer, 'weight'):
                # Parameter gradients of this layer given the gradient of its output.
                weight_grad, bias_grad = torch.autograd.grad(
                    outputs=cached[idx],
                    inputs=(layer.weight, layer.bias),
                    grad_outputs=cached[idx].grad,
                    retain_graph=True,
                )
                layer.weight.grad = weight_grad
                layer.bias.grad = bias_grad

            if idx > 0:
                # Hand the gradient to the previous layer's cached output.
                cached[idx - 1].grad = torch.autograd.grad(
                    outputs=cached[idx],
                    inputs=cached[idx - 1],
                    grad_outputs=cached[idx].grad,
                    retain_graph=True,
                )[0]


    # def train(self, mode: bool = True):
    #     super().train(mode)
    def forward(self, x):
        """Run ``x`` through the layers, caching the input and every layer output.

        The input is stored in ``self.input_cached`` and the output of layer
        ``i`` in ``self.intermediate_outputs[i]`` so the manual backward
        routines of this class can differentiate between consecutive
        activations.

        In training mode a trailing ``torch.nn.Softmax`` layer is not executed
        and the output of the layer before it is returned.
        NOTE(review): presumably because the training loops feed the result to
        ``F.cross_entropy`` -- confirm.

        Args:
            x: Input tensor accepted by the first layer.

        Returns:
            The output of the last executed layer (``x`` itself when there are
            no layers).
        """
        self.input_cached = x
        layer_count = len(self.layers)

        for i, layer in enumerate(self.layers):
            # The original had two branches (cache slot empty / already filled) that
            # did the same thing, except that the first also set a misspelled
            # ``require_grad`` attribute. That only attached a stray Python attribute
            # and never affected autograd, so it is dropped here.
            x = layer(x)
            self.intermediate_outputs[i] = x

            # Stop right before the final layer if it is a Softmax and we are training.
            if (i + 1 == layer_count - 1
                    and self.training
                    and isinstance(self.layers[layer_count - 1], torch.nn.Softmax)):
                break
        return x

    def copy_grad_from(self,model):
        """Copy the weight and bias gradients of ``model``'s layers into this network.

        Layers are paired by position. For each pair, and for each of
        ``weight`` and ``bias``:

        * if this network already holds a gradient tensor, it is overwritten
          in place, so the existing tensor object is kept;
        * otherwise a detached clone of the source gradient is assigned.

        A parameter is skipped when it is missing or ``None`` on either side
        (e.g. ``nn.Linear(..., bias=False)``) or when ``model`` has no gradient
        for it yet; the original code raised ``AttributeError`` in those cases.

        Args:
            model: Network with the same layer layout whose gradients are read.
        """
        for own_layer, other_layer in zip(self.layers, model.layers):
            for attr in ('weight', 'bias'):
                own_param = getattr(own_layer, attr, None)
                other_param = getattr(other_layer, attr, None)
                # hasattr() is True for a registered-but-None parameter, so test for None.
                if own_param is None or other_param is None:
                    continue
                if other_param.grad is None:
                    continue
                if own_param.grad is not None:
                    own_param.grad.copy_(other_param.grad)
                else:
                    own_param.grad = other_param.grad.detach().clone()

    def distance_from(self,model,l=1):
        """Return the L-``l`` distance between this network's parameters and ``model``'s.

        Sums ``|a - b| ** l`` over the ``weight`` and ``bias`` of every
        position-paired layer and returns the ``l``-th root of that sum.

        A parameter that is missing or ``None`` on either side (e.g.
        ``nn.Linear(..., bias=False)``) is skipped; the original code raised
        ``TypeError`` for a ``None`` bias.

        Args:
            model: Network with the same layer layout to compare against.
            l: Order of the norm; ``l == 0`` raises ``ZeroDivisionError``.

        Returns:
            A 0-dim tensor when at least one parameter pair was compared,
            otherwise the Python float ``0.0``. The tensor is not detached, so
            it stays attached to the autograd graph of the parameters.
        """
        total = float(0)
        for own_layer, other_layer in zip(self.layers, model.layers):
            for attr in ('weight', 'bias'):
                own_param = getattr(own_layer, attr, None)
                other_param = getattr(other_layer, attr, None)
                # hasattr() is True for a registered-but-None parameter, so test for None.
                if own_param is None or other_param is None:
                    continue
                total += (own_param - other_param).abs().pow(l).sum()
        return total ** (1. / l)

def init_weights(layers,init_method=torch.nn.init.xavier_uniform_,avg_gain=1.):
    """Initialise the weights of ``layers`` in place and zero their biases.

    Changes relative to the previous version:

    * the default initialiser is the in-place ``xavier_uniform_`` instead of
      its deprecated alias ``xavier_uniform`` (same distribution);
    * a ``weight`` or ``bias`` that is ``None`` (e.g.
      ``nn.Linear(..., bias=False)``) is skipped instead of raising
      ``AttributeError``;
    * in-place edits run under ``torch.no_grad()`` instead of toggling
      ``requires_grad``, so a frozen parameter is no longer switched back to
      ``requires_grad=True`` as a side effect.

    Args:
        layers: Iterable of layers; those without ``weight`` / ``bias`` are
            left untouched.
        init_method: Callable applied to each weight tensor, expected to
            modify it in place.
        avg_gain: Factor the freshly initialised weights are multiplied by;
            ``1.`` means no scaling.
    """
    for layer in layers:
        weight = getattr(layer, 'weight', None)
        if weight is not None:
            init_method(weight)
            if avg_gain != 1.:
                with torch.no_grad():
                    weight.mul_(avg_gain)

        bias = getattr(layer, 'bias', None)
        if bias is not None:
            with torch.no_grad():
                bias.zero_()

def _log_train_progress(epoch, batch_idx, train_loader, loss_sum, sample_count):
    """Print the mean per-sample loss accumulated since the previous report."""
    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * batch_size, len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss_sum / sample_count))


def _accumulate_backprop_gradients(net, device, train_loader, epoch, log_interval):
    """Accumulate summed cross-entropy gradients of ``net`` over all batches (no step).

    Returns:
        ``(total_loss_sum, total_sample_count)`` over the whole loader.
    """
    window_loss = 0.
    window_count = 0
    total_loss_sum = 0.
    total_loss_count = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss = F.cross_entropy(net(data), target, reduction='sum')
        loss.backward()

        batch_loss = loss.item()
        batch_len = data.size()[0]
        window_loss += batch_loss
        window_count += batch_len
        total_loss_sum += batch_loss
        total_loss_count += batch_len
        if (batch_idx + 1) % log_interval == 0:
            _log_train_progress(epoch, batch_idx, train_loader, window_loss, window_count)
            window_loss = 0.
            window_count = 0
    return total_loss_sum, total_loss_count


def _record_onthefly_train_loss(model, epoch, total_loss_sum, total_loss_count):
    """Append the epoch's mean training loss to ``model.stats`` when enabled.

    Only active when ``calculate_train_stats`` is falsy.
    """
    if calculate_train_stats:
        return
    if 'train_loss_onthefly' not in model.stats:
        model.stats['train_loss_onthefly'] = {}
    if epoch not in model.stats['train_loss_onthefly']:
        model.stats['train_loss_onthefly'][epoch] = []
    if total_loss_count != 0:
        model.stats['train_loss_onthefly'][epoch].append(total_loss_sum / total_loss_count)


def train(model, device, train_loader, optimizer, epoch,model2=None,opt2=None,log_interval=10,epoch_num=-1):
    """Run one full-batch training epoch (gradients accumulated over all batches).

    Behaviour is selected by the module-level settings ``method`` and
    ``iter_count``; ``show_grad``, ``calculate_train_stats`` and
    ``batch_size`` are also read.

    * ``method == 0``: plain backpropagation. Gradients of the summed
      cross-entropy loss are accumulated over the whole loader and
      ``optimizer`` steps once.
    * ``method == 1`` and ``iter_count < 3``: ``model2`` is synchronised with
      ``model`` and takes one plain full-batch step with ``opt2``. Then
      ``model.backward_grad_correction_with_weight_change2`` is called for
      every batch, with ``weight_change=True`` on the batch that completes
      the dataset.
    * ``method == 1`` and ``iter_count >= 3``: as above, repeated
      ``iter_count - 1`` times, with ``model`` restored from a snapshot and
      ``model2`` reloaded from ``model`` before each repetition.
      Bug fix: ``weight_change`` used to be evaluated before the sample
      counter was incremented, so it was never ``True`` and the weight
      update never fired. The counter is now updated first, as in the
      ``iter_count < 3`` branch.
    * any other ``method``: only the setup below runs.

    Args:
        model: Network being trained (an ``NN`` instance).
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch: Epoch number, used for logging and as the stats key.
        model2: Optional auxiliary network; created as a copy of ``model``
            when it or ``opt2`` is ``None``.
        opt2: Optional optimizer for ``model2``; created alongside ``model2``.
        log_interval: Progress is printed every ``log_interval`` batches.
        epoch_num: Unused.
    """
    model.train()

    if model2 is None or opt2 is None:
        model2 = NN(copy.deepcopy(model.layers))
        model2.train()
        model2.load_state_dict(model.state_dict().copy())
        opt2 = type(optimizer)(model2.parameters())
        opt2.load_state_dict(optimizer.state_dict().copy())

    # Keep the auxiliary optimizer on the same learning rate as the main one.
    opt2.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']

    # Snapshot network, only needed by the multi-iteration variant of method 1.
    model3 = None
    if method == 1 and iter_count >= 3:
        model3 = NN(copy.deepcopy(model.layers))
        model3.train()
        model3.load_state_dict(model.state_dict().copy())

    if method == 0:
        optimizer.zero_grad()
        total_loss_sum, total_loss_count = _accumulate_backprop_gradients(
            model, device, train_loader, epoch, log_interval)
        optimizer.step()
        if show_grad:
            print(model.layers[0].weight.grad[0][0][0])
        _record_onthefly_train_loss(model, epoch, total_loss_sum, total_loss_count)
    elif method == 1:
        if iter_count >= 3:
            model2.load_state_dict(model.state_dict())
            model3.load_state_dict(model.state_dict())
            opt2.zero_grad()
            _accumulate_backprop_gradients(model2, device, train_loader, epoch, log_interval)
            if show_grad:
                print(model2.layers[0].weight.grad[0][0][0])
            opt2.step()

            for _ in range(iter_count - 1):
                # NOTE(review): this reload overwrites the step model2 has just taken
                # on the first repetition -- confirm this is intended.
                model2.load_state_dict(model.state_dict())
                model.load_state_dict(model3.state_dict())

                loss_sum_log = 0.
                loss_count = 0
                total_loss_count = 0
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)

                    _output = model(data)
                    loss = F.cross_entropy(_output, target, reduction='sum')

                    # Forward pass kept for its side effect: NN.forward caches model2's
                    # activations (presumably read by the correction below -- confirm).
                    model2(data)

                    # Count this batch before testing for the last one; the previous
                    # order made ``weight_change`` permanently False.
                    loss_sum_log += loss.item()
                    loss_count += data.size()[0]
                    total_loss_count += data.size()[0]

                    weight_change = len(train_loader.dataset) == total_loss_count
                    if weight_change:
                        print('weight change')
                    model.backward_grad_correction_with_weight_change2(loss, model2, weight_change=weight_change,
                                                                       accumulate_gradients=True)

                    if (batch_idx + 1) % log_interval == 0:
                        _log_train_progress(epoch, batch_idx, train_loader, loss_sum_log, loss_count)
                        loss_sum_log = 0.
                        loss_count = 0
                if show_grad:
                    print(model.layers[0].weight.grad[0][0][0])
        else:
            model2.load_state_dict(model.state_dict())
            opt2.zero_grad()
            total_loss_sum, total_loss_count = _accumulate_backprop_gradients(
                model2, device, train_loader, epoch, log_interval)
            if show_grad:
                print(model2.layers[0].weight.grad[0][0][0])
            _record_onthefly_train_loss(model, epoch, total_loss_sum, total_loss_count)
            opt2.step()

            model.copy_grad_from(model2)  # just to initialize tensors to hold gradients
            # NOTE(review): zero_grad() sets gradients to None by default in recent
            # PyTorch releases, which would undo the line above -- confirm version.
            optimizer.zero_grad()  # initialize gradients to 0.
            loss_sum_log = 0.
            loss_count = 0
            total_loss_count = 0
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)

                _output = model2(data)
                loss = F.cross_entropy(_output, target, reduction='sum')  # this line is just to log a loss

                output1 = model(data)
                loss1 = F.cross_entropy(output1, target, reduction='sum')

                loss_sum_log += loss.item()
                loss_count += data.size()[0]
                total_loss_count += data.size()[0]

                # True only for the batch that completes the dataset.
                weight_change = len(train_loader.dataset) == total_loss_count
                if weight_change:
                    print('weight change')
                model.backward_grad_correction_with_weight_change2(loss1, model2, weight_change=weight_change,
                                                                   accumulate_gradients=True)

                if (batch_idx + 1) % log_interval == 0:
                    _log_train_progress(epoch, batch_idx, train_loader, loss_sum_log, loss_count)
                    loss_sum_log = 0.
                    loss_count = 0
            if show_grad:
                print(model.layers[0].weight.grad[0][0][0])



# Module-level accumulators; train_minibatch() declares both `global` and resets them at epoch 1.
total_relative_loss_improvement_denominator = 0.0
total_loss_improvement = 0.0
def _relative_improvement(improvement, baseline_change):
    """Return ``improvement / abs(baseline_change)``, or 0.0 for a zero baseline.

    The baseline change is a sum of loss differences and becomes exactly zero
    once the reference step no longer changes the loss (e.g. the loss has
    saturated). Dividing would raise ZeroDivisionError and abort the epoch, so
    the ratio is reported as 0.0 in that case, which also keeps the recorded
    statistic finite.
    """
    if baseline_change == 0:
        return 0.0
    return improvement / abs(baseline_change)


def _append_epoch_stat(stats, key, epoch, value):
    """Append ``value`` to ``stats[key][epoch]``, creating the containers on demand."""
    stats.setdefault(key, {}).setdefault(epoch, []).append(value)


def train_minibatch(model, device, train_loader, optimizer, epoch, model2=None, opt2=None, log_interval=10):
    """Train ``model`` for one epoch, performing one update per mini-batch.

    The module-level ``method`` flag selects the update rule:

    * 1 -- ``model2`` is reset to ``model``'s weights and takes a plain
      backprop step with ``opt2``; ``model`` is then updated through
      ``model.backward_grad_correction_with_weight_change2(loss1, model2)``
      (``optimizer.step()`` is not called in this branch). When the global
      ``iter_count`` is >= 3 the correction is repeated, alternating the roles
      of ``model`` and ``model2``.
    * 2 -- ``model.backward_grad_correction(...)`` followed by
      ``optimizer.step()``; ``model.loss_signal`` is told whether the result
      was worse than the plain step.
    * 3 -- experimental variant, marked "not ready yet" by the author.
    * anything else -- ordinary backprop on ``model`` with ``optimizer``.

    For methods 1-3 the function counts, per batch, whether the corrected
    update ended with a higher or lower loss than the plain backprop step and
    appends the per-epoch ratios and average improvements to ``model.stats``.

    Args:
        model: network exposing ``layers``, ``stats``, ``step_factor``,
            ``gradient_factor`` and the correction methods used below.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute (used for progress logging).
        optimizer: optimizer bound to ``model``.
        epoch: epoch number; ``1`` resets the module-level running totals.
        model2: optional scratch copy of ``model``; created together with
            ``opt2`` when either of the two is ``None``.
        opt2: optional optimizer bound to ``model2``.
        log_interval: progress is printed every ``log_interval`` batches.

    Side effects:
        Updates ``model``, ``model.stats`` and the module globals
        ``total_loss_improvement`` /
        ``total_relative_loss_improvement_denominator``. Method 2 also
        overwrites the module global ``iter_count`` with 1.
    """
    global total_relative_loss_improvement_denominator
    global total_loss_improvement
    global iter_count
    global with_optimizer_parameter_copy

    model.train()

    if model2 is None or opt2 is None:
        model2 = NN(copy.deepcopy(model.layers))
        model2.train()
        model2.load_state_dict(model.state_dict().copy())
        opt2 = type(optimizer)(model2.parameters())
        opt2.load_state_dict(optimizer.state_dict().copy())

    # Snapshot model used by the multi-iteration variant of method 1 and by
    # the (currently disabled) additional_upgrade path of method 2.
    model3 = NN(copy.deepcopy(model.layers))
    model3.train()
    model3.load_state_dict(model.state_dict().copy())
    # NOTE(review): opt3 is built over model.parameters(), not
    # model3.parameters(). It is only touched by the disabled
    # additional_upgrade path -- confirm the intent before enabling that path.
    opt3 = type(optimizer)(model.parameters())
    opt3.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']

    model2.step_factor = model.step_factor
    model2.gradient_factor = model.gradient_factor

    if epoch == 1:
        total_relative_loss_improvement_denominator = 0.
        total_loss_improvement = 0.
    relative_loss_improvement_denominator = 0.
    loss_improvement = 0.
    higher_loss_batch_counter = 0
    lower_loss_batch_counter = 0
    batch_counter = 0

    # Keep the reference optimizer on the same learning rate as the main one.
    opt2.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']

    # Experiment toggles, edited by hand.
    common_optimizer_parameters_enhancement = False
    if common_optimizer_parameters_enhancement and method == 1:
        opt2.param_groups[0]['alpha'] = optimizer.param_groups[0]['alpha'] ** 0.5

    take_optimal_update = False

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        additional_upgrade = False
        loss = None

        if method == 3:  # !!!not ready yet
            model2.load_state_dict(model.state_dict())

            # Loss before any update (same weights as model).
            output = model2(data)
            _loss = F.cross_entropy(output, target)

            opt2.zero_grad()
            _loss.backward()

            opt2.step()

            # Loss after the plain backprop step: the comparison baseline.
            _output = model2(data)
            loss = F.cross_entropy(_output, target)  # todo check what this line changes

            output1 = model(data)
            loss1 = F.cross_entropy(output1, target)
            model.backward_grad_correction_with_weight_change2(loss1, model2)

            if common_optimizer_parameters_enhancement:
                model2.copy_grad_from(model)
                opt2.step()

            # Loss of model after the corrected update.
            output2 = model(data)
            loss2 = F.cross_entropy(output2, target)

            batch_counter += 1
            if loss < loss2:
                higher_loss_batch_counter += 1
            elif loss > loss2:
                lower_loss_batch_counter += 1
            loss_improvement += float(loss.item() - loss2.item())

            relative_loss_improvement_denominator += float(loss.item() - _loss.item())

            total_loss_improvement += loss.item() - loss2.item()
            total_relative_loss_improvement_denominator += loss.item() - _loss.item()

        elif method == 2:
            model.gradient_factor_simple_layers = model.gradient_factor
            high_loss = False
            # NOTE: this assigns the module-level iter_count (declared global above).
            iter_count = 1
            for i in range(iter_count):
                model2.load_state_dict(model.state_dict())
                if with_optimizer_parameter_copy:
                    opt2.load_state_dict(optimizer.state_dict())

                output = model2(data)
                loss = F.cross_entropy(output, target)

                if i == 0:
                    opt2.zero_grad()
                    loss.backward(retain_graph=True)
                else:
                    model2.copy_grad_from(model)

                opt2.step()

                output1 = model(data)
                loss1 = F.cross_entropy(output1, target)

                # Loss after the plain backprop step: the comparison baseline.
                output = model2(data)
                loss = F.cross_entropy(output, target)  # todo check what this line changes

                model.backward_grad_correction(loss1, model2, F.cross_entropy, target)

                if False and additional_upgrade:
                    model2.load_state_dict(model.state_dict())
                    opt2.load_state_dict(optimizer.state_dict())

                additional_upgrade = False
                if additional_upgrade:
                    model3.load_state_dict(model.state_dict().copy())
                    opt3.load_state_dict(optimizer.state_dict())

                optimizer.step()
                output2 = model(data)
                loss2 = F.cross_entropy(output2, target)

                batch_counter += 1
                if loss < loss2:
                    higher_loss_batch_counter += 1
                    high_loss = True
                elif loss > loss2:
                    lower_loss_batch_counter += 1
                loss_improvement += float(loss.item() - loss2.item())

                relative_loss_improvement_denominator += float(loss.item() - loss1.item())

                total_loss_improvement += float(loss.item() - loss2.item())
                total_relative_loss_improvement_denominator += float(loss.item() - loss1.item())

                if additional_upgrade and loss1 < loss2:
                    model.load_state_dict(model3.state_dict())
                    optimizer.load_state_dict(opt3.state_dict())
                    break

            model.loss_signal(high_loss)
            model2.gradient_factor = model.gradient_factor
            model2.gradient_factor_simple_layers = model.gradient_factor_simple_layers
            model2.step_factor = model.step_factor

        elif method == 1:
            if iter_count >= 3:
                model2.load_state_dict(model.state_dict())
                model3.load_state_dict(model.state_dict())

                output = model2(data)
                _loss = F.cross_entropy(output, target)

                opt2.zero_grad()
                _loss.backward()

                opt2.step()
                _output = model2(data)
                # Loss after the plain backprop step: the comparison baseline.
                loss_backprop = F.cross_entropy(_output, target)

                # Alternate which of model / model2 receives the correction;
                # the other one is restored from the model3 snapshot first.
                for i in range(iter_count - 1):
                    if i % 2 == 0:
                        if i != 0:
                            model.load_state_dict(model3.state_dict())
                            _output = model2(data)
                            loss = F.cross_entropy(_output, target)
                        else:
                            loss = loss_backprop

                        output1 = model(data)
                        loss1 = F.cross_entropy(output1, target)

                        model.backward_grad_correction_with_weight_change2(loss1, model2)
                    else:
                        model2.load_state_dict(model3.state_dict())

                        _output = model(data)
                        loss = F.cross_entropy(_output, target)

                        output1 = model2(data)
                        loss1 = F.cross_entropy(output1, target)

                        model2.backward_grad_correction_with_weight_change2(loss1, model)

                # With an odd iter_count the last correction landed in model2.
                if (iter_count) % 2 == 1:
                    model.load_state_dict(model2.state_dict())

                output2 = model(data)
                loss2 = F.cross_entropy(output2, target)

                batch_counter += 1
                if loss_backprop < loss2:
                    higher_loss_batch_counter += 1
                elif loss_backprop > loss2:
                    lower_loss_batch_counter += 1
                loss_improvement += float(loss_backprop.item() - loss2.item())

                relative_loss_improvement_denominator += float(loss_backprop.item() - _loss.item())

                total_loss_improvement += float(loss_backprop.item() - loss2.item())
                total_relative_loss_improvement_denominator += float(loss_backprop.item() - _loss.item())
            else:
                model2.load_state_dict(model.state_dict())

                # Loss before any update (same weights as model).
                output = model2(data)
                _loss = F.cross_entropy(output, target)

                opt2.zero_grad()

                _loss.backward()

                opt2.step()

                if show_grad:
                    print(model2.layers[0].weight.grad[0][0][0])

                # Loss after the plain backprop step: the comparison baseline.
                _output = model2(data)
                loss = F.cross_entropy(_output, target)  # todo check what this line changes

                output1 = model(data)
                loss1 = F.cross_entropy(output1, target)
                model.backward_grad_correction_with_weight_change2(loss1, model2)

                if show_grad:
                    print(model.layers[0].weight.grad[0][0][0])
                if common_optimizer_parameters_enhancement:
                    model2.copy_grad_from(model)
                    opt2.step()

                if False and additional_upgrade:
                    model2.load_state_dict(model.state_dict())
                    opt2.load_state_dict(optimizer.state_dict())

                # Loss of model after the corrected update.
                output2 = model(data)
                loss2 = F.cross_entropy(output2, target)

                batch_counter += 1
                if loss < loss2:
                    higher_loss_batch_counter += 1
                    if take_optimal_update:
                        # Fall back to the plain backprop result when it was better.
                        model.load_state_dict(model2.state_dict())
                elif loss > loss2:
                    lower_loss_batch_counter += 1
                loss_improvement += float(loss.item() - loss2.item())

                relative_loss_improvement_denominator += float(loss.item() - _loss.item())

                total_loss_improvement += float(loss.item() - loss2.item())
                total_relative_loss_improvement_denominator += float(loss.item() - _loss.item())
        else:
            # Ordinary backprop step.
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)

            loss.backward()

            optimizer.step()
            if show_grad:
                print(model.layers[0].weight.grad[0][0][0])

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if batch_counter != 0:
                print(
                    "Higher loss batches: " + str(higher_loss_batch_counter) + "/" + str(batch_counter)
                    + "=" + str(higher_loss_batch_counter / batch_counter)
                    + "   gradient_factor: " + str(model.gradient_factor)
                    + "   step factor:" + str(model.step_factor)
                    + " Lower loss batch ratio: " + str(lower_loss_batch_counter / batch_counter)
                    + " Avg loss improvement: " + str(loss_improvement / batch_counter)
                    + " Avg relative loss improvement: "
                    + str(_relative_improvement(loss_improvement, relative_loss_improvement_denominator))
                    + " Total avg relative loss improvement: "
                    + str(_relative_improvement(total_loss_improvement,
                                                total_relative_loss_improvement_denominator)))

    # Comparison statistics exist only for the methods that run a reference step.
    if batch_counter != 0:
        if model.stats is None:
            model.stats = {}
        same_loss_batch_counter = batch_counter - lower_loss_batch_counter - higher_loss_batch_counter
        _append_epoch_stat(model.stats, 'train_higher_loss_batch_ratio', epoch,
                           higher_loss_batch_counter / batch_counter)
        _append_epoch_stat(model.stats, 'train_lower_loss_batch_ratio', epoch,
                           lower_loss_batch_counter / batch_counter)
        _append_epoch_stat(model.stats, 'train_same_loss_batch_ratio', epoch,
                           same_loss_batch_counter / batch_counter)
        _append_epoch_stat(model.stats, 'train_batch_avg_loss_improvement', epoch,
                           loss_improvement / batch_counter)
        _append_epoch_stat(model.stats, 'train_batch_avg_relative_loss_improvement', epoch,
                           _relative_improvement(loss_improvement, relative_loss_improvement_denominator))
        _append_epoch_stat(model.stats, 'train_avg_relative_loss_improvement', epoch,
                           _relative_improvement(total_loss_improvement,
                                                 total_relative_loss_improvement_denominator))

def _sum_loss_and_correct(model, device, loader):
    """Run ``model`` over ``loader`` without gradients.

    Returns:
        Tuple ``(loss_sum, correct, samples)``: the summed cross-entropy over
        all samples, the number of argmax predictions equal to the target and
        the number of samples seen (counted from the batches themselves).
    """
    loss_sum = 0
    correct = 0
    samples = 0
    with torch.no_grad():
        for data, target in loader:
            samples += data.shape[0]
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
            correct += pred.eq(target.view_as(pred)).sum().item()
    return loss_sum, correct, samples


def test(model, device, test_loader, train_loader, val_loader=None, epoch=1):
    """Evaluate ``model`` and append loss / accuracy for ``epoch`` to ``model.stats``.

    Which splits are evaluated:

    * test set -- when the module flag ``calculate_test_stats`` is true;
    * training set -- when the module flag ``calculate_train_stats`` is true;
    * validation set -- whenever ``val_loader`` is given. (It used to be
      nested under ``calculate_train_stats``, which left the ``val_*`` lists
      empty when train statistics were disabled.)

    The ``train_*`` containers are created even when training statistics are
    disabled, so downstream code can rely on the keys being present.

    Args:
        model: network with a ``stats`` attribute (``None`` or a dict).
        device: device the batches are moved to.
        test_loader: loader of the test split; its ``dataset`` length is used
            as the divisor for the test averages.
        train_loader: loader of the training split.
        val_loader: optional loader of the validation split.
        epoch: key under which the results are appended.
    """
    if model.stats is None:
        model.stats = {}

    splits = ['train']
    if calculate_test_stats:
        splits.append('test')
    if val_loader is not None:
        splits.append('val')

    # Create all accuracy containers before the loss containers so that the
    # key order of model.stats stays the same as before.
    for metric in ('accuracy', 'loss'):
        for split in splits:
            model.stats.setdefault(split + '_' + metric, {})
    for metric in ('accuracy', 'loss'):
        for split in splits:
            model.stats[split + '_' + metric].setdefault(epoch, [])

    model.eval()

    if calculate_test_stats:
        loss_sum, correct, _ = _sum_loss_and_correct(model, device, test_loader)
        test_size = len(test_loader.dataset)
        test_loss = loss_sum / test_size

        print('\nTest set: Avg loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
            test_loss, correct, test_size,
            100. * correct / test_size))
        model.stats['test_accuracy'][epoch].append(100. * correct / test_size)
        model.stats['test_loss'][epoch].append(test_loss)

    if calculate_train_stats:
        loss_sum, correct, length = _sum_loss_and_correct(model, device, train_loader)
        train_loss = loss_sum / length

        print('\nTraining set: Avg loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
            train_loss, correct, length,
            100. * correct / length))
        model.stats['train_accuracy'][epoch].append(100. * correct / length)
        model.stats['train_loss'][epoch].append(train_loss)

    if val_loader is not None:
        loss_sum, correct, length = _sum_loss_and_correct(model, device, val_loader)
        val_loss = loss_sum / length

        print('\nValidation set: Avg loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
            val_loss, correct, length,
            100. * correct / length))
        model.stats['val_accuracy'][epoch].append(100. * correct / length)
        model.stats['val_loss'][epoch].append(val_loss)



def evaluate_stats(stats):
    """Score processed statistics; higher is better.

    The score is the negated mean of the per-training minimum losses. Which
    loss is used depends on the module flags:

    * ``calculate_train_stats`` false -> ``stats['train_loss_onthefly_min']``
    * otherwise, ``cross_val`` true   -> ``stats['val_loss_min']``
    * otherwise                       -> ``stats['train_loss_min']``

    Args:
        stats: dict in the format produced by ``process_stats``.

    Returns:
        ``numpy.inf`` when no training loss has been recorded yet (neither
        ``train_loss_onthefly_aggregated`` nor ``train_loss`` has entries) or
        when the selected list of minima is empty; otherwise the negated mean.
    """
    def has_no_entries(key):
        return key not in stats or len(stats[key].items()) == 0

    # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
    if has_no_entries('train_loss_onthefly_aggregated') and has_no_entries('train_loss'):
        return np.inf

    if not calculate_train_stats:
        key = 'train_loss_onthefly_min'
    elif cross_val:
        key = 'val_loss_min'
    else:
        key = 'train_loss_min'

    per_training_minima = stats[key]
    if len(per_training_minima) == 0:
        # Nothing to average yet: same answer as the "no data" guard above
        # instead of a ZeroDivisionError.
        return np.inf
    return -sum(per_training_minima) / len(per_training_minima)

def stop_criteria(stats):
    """Return True when the score from ``evaluate_stats`` is below -3.

    ``evaluate_stats`` returns a negated mean loss, so this fires when that
    mean loss is above 3.
    """
    score = evaluate_stats(stats)
    return score < -3.
def _per_training_extrema(lists_of_scores):
    """Compute per-training minima / maxima across epochs.

    Args:
        lists_of_scores: one list per epoch; element ``t`` of each list is the
            score of training run ``t`` in that epoch. The number of training
            runs is taken from the first epoch's list.

    Returns:
        Tuple ``(minima, maxima, argmin_positions)`` with one entry per
        training run. ``argmin_positions`` holds the position (index into
        ``lists_of_scores``) of the first epoch reaching the minimum, or -1
        when no epoch has a score for that run.
    """
    minima, maxima, argmin_positions = [], [], []
    if len(lists_of_scores) == 0:
        return minima, maxima, argmin_positions
    for training_num in range(len(lists_of_scores[0])):
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        val_max = -np.inf
        val_min = np.inf
        min_position = -1
        for position, scores in enumerate(lists_of_scores):
            # This run has no score for this epoch (yet). `<=` rather than
            # `==` so that lists shorter by more than one entry are skipped
            # instead of raising IndexError.
            if len(scores) <= training_num:
                continue
            val = scores[training_num]
            if val_min > val:
                val_min = val
                min_position = position
            if val_max < val:
                val_max = val
        minima.append(val_min)
        maxima.append(val_max)
        argmin_positions.append(min_position)
    return minima, maxima, argmin_positions


def process_stats(model):
    """Aggregate ``model.stats`` into a flat dictionary of summaries.

    ``model.stats`` maps a metric name to ``{epoch: [score per training run]}``;
    the special key ``'params'`` is deep-copied unchanged. For every other
    metric ``key`` the result contains:

    * ``key``              -- deep copy of the non-empty per-epoch lists
    * ``key_aggregated``   -- per-epoch mean
    * ``key_avg``          -- mean over all scores (only if any exist)
    * ``key_max`` / ``key_min`` -- per-training extrema across epochs
    * ``key_max_avg`` / ``key_min_avg`` -- means of those extrema (only if any)

    ``gradient_factor`` and ``step_factor`` are copied from ``model``. When the
    module flag ``cross_val`` is true, ``validation_optimal_epoch`` lists, per
    training run, the epoch key with the lowest validation loss (-1 if none),
    ``validation_optimal_epoch_aggregated`` is its mean, and -- if
    ``calculate_test_stats`` is true -- ``test_loss_validation_optimal`` is the
    test-loss list of the optimal epoch of the last training run.

    Args:
        model: object with ``stats``, ``gradient_factor`` and ``step_factor``.

    Returns:
        The aggregated statistics dictionary.
    """
    stats = {}
    for key, value in model.stats.items():
        if key == 'params':
            stats[key] = copy.deepcopy(model.stats[key])
            continue
        stats[key + "_aggregated"] = {}
        stats[key] = {}
        all_scores = []
        for epoch_num, list_of_scores in value.items():
            if len(list_of_scores) != 0:
                stats[key + "_aggregated"][epoch_num] = sum(list_of_scores) / len(list_of_scores)
                all_scores += list_of_scores
                stats[key][epoch_num] = copy.deepcopy(list_of_scores)
        if len(all_scores) != 0:
            stats[str(key) + '_avg'] = sum(all_scores) / len(all_scores)

        minima, maxima, _ = _per_training_extrema(list(value.values()))
        stats[key + "_max"] = maxima
        stats[key + "_min"] = minima
        if len(maxima) != 0:
            stats[key + "_max_avg"] = sum(maxima) / len(maxima)
            stats[key + "_min_avg"] = sum(minima) / len(minima)

    stats['gradient_factor'] = model.gradient_factor
    stats['step_factor'] = model.step_factor

    if cross_val:
        val_loss = model.stats['val_loss']
        # Map positions back to the real epoch keys instead of assuming that
        # the keys are exactly 1..N in insertion order.
        epoch_keys = list(val_loss.keys())
        _, _, best_positions = _per_training_extrema(list(val_loss.values()))
        for position in best_positions:
            epoch_min = epoch_keys[position] if position >= 0 else -1
            if calculate_test_stats and position >= 0:
                stats['test_loss_validation_optimal'] = model.stats['test_loss'][epoch_min]
            stats.setdefault('validation_optimal_epoch', []).append(epoch_min)
        if 'validation_optimal_epoch' in stats:
            optimal_epochs = stats['validation_optimal_epoch']
            stats['validation_optimal_epoch_aggregated'] = sum(optimal_epochs) / len(optimal_epochs)

    return stats
def write_to_file(name, text, mode='a'):
    """Write ``text`` followed by a newline to the file ``name``.

    Args:
        name: path of the file.
        text: string to write.
        mode: mode passed to ``open``; appends by default.
    """
    with open(name, mode) as handle:
        line = text + '\n'
        handle.write(line)

def save_model(model, name):
    """Serialize ``model`` to the path ``name`` with ``torch.save``."""
    torch.save(obj=model, f=name)

def load_model(name):
    """Load an object previously written with ``torch.save`` from ``name``.

    ``save_model`` pickles the whole model object, so full unpickling is
    required. Since PyTorch 2.6 ``torch.load`` defaults to
    ``weights_only=True``, which rejects such files; the flag is therefore
    disabled explicitly.

    SECURITY: full unpickling can execute arbitrary code. Only load files
    written by this program.

    Args:
        name: path (or file-like object) to load from.

    Returns:
        The deserialized object.
    """
    try:
        return torch.load(name, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the `weights_only` argument and
        # always perform a full unpickle.
        return torch.load(name)

def rand_between(a, b):
    """Draw one sample from the uniform distribution between ``a`` and ``b``.

    Returns:
        A 0-dimensional float32 CPU tensor, drawn with torch's global random
        generator.
    """
    sample = torch.empty(1, dtype=torch.float32, device='cpu')
    sample.uniform_(a, b)
    return sample[0]

# NOTE(review): the code using this lives outside this section; -1 presumably
# means "no hyperparameter set selected yet" -- confirm against the caller.
hyperparameter_index = -1

# class State:
#     hyperparameter_index=-1
#     actual_stats={}
#     trainings_different_model=0
#     trainings_same_model=0
def save_state(file_name, data):
    """Atomically write ``str(data)`` to ``file_name`` in a form ``read_state`` can evaluate.

    Tensors are printed in full (no ``...`` truncation) and every occurrence
    of ``tensor`` is rewritten to ``torch.tensor`` so the text can be
    evaluated back into objects.

    NOTE(review): the replacement is purely textual, so any string or key in
    ``data`` containing the word "tensor" is rewritten as well.

    The text is first written and fsync-ed to ``tmp_<basename>`` in the same
    directory as the target and then moved over it with ``os.replace``, so a
    crash never leaves a half-written state file.

    Args:
        file_name: target path; may contain directory components.
        data: object whose ``str()`` representation is stored.
    """
    # Render first so the print options are restored even if str() fails and
    # no temporary file is left behind in that case.
    torch.set_printoptions(profile="full")  # otherwise "default"
    try:
        text = str(data).replace("tensor", 'torch.tensor')
    finally:
        torch.set_printoptions(profile="default")

    # Prefix only the base name: prefixing the whole path would point into a
    # different (usually non-existent) directory.
    directory, base_name = os.path.split(file_name)
    tmp_name = os.path.join(directory, 'tmp_' + base_name)
    with open(tmp_name, 'w') as myfile:
        myfile.write(text)
        myfile.flush()
        os.fsync(myfile.fileno())
    os.replace(tmp_name, file_name)  # atomic operation
def read_state(file_name):
    """Read a state file written by ``save_state`` and evaluate it back into objects.

    The file content is evaluated with the module's global names (so
    ``torch.tensor(...)`` resolves) plus ``inf`` and ``nan``, which is how
    ``str()`` prints float infinities and NaN. Without these two names a
    state containing such a value could not be read back.

    SECURITY: the content is passed to ``eval`` and can execute arbitrary
    code. Only read files written by this program.

    Args:
        file_name: path of the state file.

    Returns:
        The evaluated object, or ``None`` when the file cannot be read or
        evaluated (the error is printed).
    """
    try:
        with open(file_name, 'r') as f:
            s = f.read()
        namespace = dict(globals())
        namespace.setdefault('inf', float('inf'))
        namespace.setdefault('nan', float('nan'))
        return eval(s, namespace)
    except Exception as e:  # best effort: a missing or corrupt file means "no state"
        print(repr(e))
    return None

def main():
    #device = torch.device("cpu")
    global method
    global batch_size
    lr=1#0.007 #todo: change
    gamma=0.7 #todo: set optimal gamma
    train_kwargs = {'batch_size': batch_size}
    test_kwargs = {'batch_size': batch_size}

    # seed=23
    torch.manual_seed(seed)

    # epochs=2
    # if model_nr == 3:
    #     epochs=2
    # elif model_nr == 4:
    #     epochs=2
    # elif model_nr == 5:
    #     epochs=30
    # elif model_nr == 6:
    #     epochs=40

    #global trainings_same_model
    #global trainings_different_model
    train_loader:torch.utils.data.DataLoader=None
    test_loader:torch.utils.data.DataLoader=None

    convert_to_datatype=lambda x: x.to(d_type)
    if dataset=='mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.5,), (0.5,))
            convert_to_datatype
        ])
        dataset1 = datasets.MNIST('../data', train=True, download=True,
                           transform=transform)
        dataset2 = datasets.MNIST('../data', train=False,
                           transform=transform)
        train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, **train_kwargs)
        test_loader = torch.utils.data.DataLoader(dataset2, shuffle=True, **test_kwargs)
    elif dataset=='fashion_mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.5,), (0.5,))
            convert_to_datatype
        ])
        dataset1 = datasets.FashionMNIST('../data', train=True, download=True,
                                  transform=transform)
        dataset2 = datasets.FashionMNIST('../data', train=False,
                                  transform=transform)
        train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, **train_kwargs)
        test_loader = torch.utils.data.DataLoader(dataset2, shuffle=True, **test_kwargs)
    elif dataset=='imdb':
        train_loader, test_loader = imdb_utils.get_preprocessed_IMDB(d_type=d_type)
        train_loader=torch.utils.data.DataLoader(train_loader.dataset, shuffle=True, **train_kwargs)
        test_loader=torch.utils.data.DataLoader(test_loader.dataset, shuffle=True, **test_kwargs)

    def get_layers():
        """Instantiate the layer list of the network selected by the global ``model_nr``.

        Every call creates fresh modules.  Within each architecture the
        layers are constructed strictly front to back, i.e. in the order in
        which they appear in the returned list.

        :return: list of ``nn.Module`` objects for ``model_nr`` 3..10, the
            object returned by ``imdb_utils.get_model_layers`` for
            ``model_nr`` 11 and 12, and ``None`` for any other value.
        """
        def conv3(c_in, c_out):
            # 3x3 convolution, stride 1, no padding
            return nn.Conv2d(c_in, c_out, 3, 1, dtype=d_type)

        def conv3_padded():
            # 3x3 convolution 8 -> 8 channels, stride 1, padding 1
            return nn.Conv2d(8, 8, 3, 1, 1, dtype=d_type)

        def conv5_strided(c_in, c_out):
            # 5x5 convolution, stride 2, padding 2
            return nn.Conv2d(c_in, c_out, kernel_size=5, stride=2, padding=2, dtype=d_type)

        def dense(n_in, n_out):
            # fully connected layer
            return nn.Linear(n_in, n_out, dtype=d_type)

        def padded_chain(count, act):
            # `count` repetitions of (padded 3x3 conv, activation)
            chain = []
            for _ in range(count):
                chain.append(conv3_padded())
                chain.append(act())
            return chain

        def strided_stack(act):
            # Back end shared by models 4, 6, 7 and 8: five convolutions
            # (two of them strided), each followed by `act`, then the classifier.
            return [conv3(8, 8), act(),
                    conv5_strided(8, 16), act(),
                    conv3(16, 16), act(),
                    conv3(16, 16), act(),
                    conv5_strided(16, 16), act(),
                    nn.Flatten(),
                    dense(256, 10)]

        def pooled_front():
            # Front end shared by models 5 and 9: two (conv, ELU, max-pool) stages.
            return [conv3(1, 8), nn.ELU(), nn.MaxPool2d(2),
                    conv3(8, 16), nn.ELU(), nn.MaxPool2d(2),
                    nn.Flatten()]

        if model_nr == 3:
            return [conv3(1, 16), nn.ReLU(),
                    conv3(16, 16), nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Flatten(),
                    dense(2304, 32), nn.ReLU(),
                    dense(32, 10)]

        if model_nr == 4:
            return [conv3(1, 8), nn.ReLU()] + strided_stack(nn.ReLU)

        if model_nr == 5:
            return pooled_front() + [dense(400, 32), nn.ELU(), dense(32, 10)]

        if model_nr == 6:
            return [conv3(1, 8), nn.ELU()] + strided_stack(nn.ELU)

        if model_nr == 7:
            # 15 padded Hardswish stages in front of the ELU back end
            return ([conv3(1, 8), nn.Hardswish()]
                    + padded_chain(15, nn.Hardswish)
                    + strided_stack(nn.ELU))

        if model_nr == 8:
            # same as model 7, but with ELU in the padded stages
            return ([conv3(1, 8), nn.ELU()]
                    + padded_chain(15, nn.ELU)
                    + strided_stack(nn.ELU))

        if model_nr == 9:
            # convolutional front end followed by a deep, narrow Tanh MLP:
            # Linear(400,10), then 27 Linear(10,10), with Tanh between all of them
            layers = pooled_front() + [dense(400, 10), nn.Tanh()]
            for _ in range(26):
                layers += [dense(10, 10), nn.Tanh()]
            layers.append(dense(10, 10))
            init_weights(layers, nn.init.xavier_uniform_)
            return layers

        if model_nr == 10:
            slope = 0.1

            def leaky():
                return nn.LeakyReLU(slope)

            layers = [conv3(1, 8), leaky(),
                      conv3(8, 8), leaky(),
                      conv5_strided(8, 8), leaky(),
                      conv3(8, 8), leaky()]
            layers += padded_chain(32, leaky)
            layers += [conv3(8, 8), leaky(),
                       conv5_strided(8, 8), leaky(),
                       nn.Flatten(),
                       dense(128, 30),
                       nn.Sigmoid(),
                       dense(30, 10)]
            return layers

        if model_nr == 11:
            layers = imdb_utils.get_model_layers(d_type=d_type, model_nr=1)
            # NOTE(review): two consecutive initialisations; the first one is
            # presumably overwritten by the second - kept as is, confirm.
            init_weights(layers, nn.init.xavier_uniform_)
            init_weights(layers, nn.init.xavier_uniform_, avg_gain=1.66666667)
            return layers

        if model_nr == 12:
            layers = imdb_utils.get_model_layers(d_type=d_type, model_nr=2)
            init_weights(layers, nn.init.xavier_uniform_, avg_gain=1.66666667)#2.85)
            return layers

        return None
        # return [nn.Conv2d(1, 16, 3, 1),
        #         nn.ReLU(),
        #         nn.BatchNorm2d(16),
        #         nn.Conv2d(16, 16, 3, 1),
        #         nn.ReLU(),
        #         nn.BatchNorm2d(16),
        #         nn.MaxPool2d(2),
        #         # nn.Dropout(0.25),
        #         nn.Flatten(),
        #         nn.Linear(2304, 32),
        #         nn.ReLU(),
        #         nn.BatchNorm1d(32),
        #         # nn.Dropout(0.5),
        #         nn.Linear(32, 10),
        #         nn.LogSoftmax(dim=1)]
        # return [nn.Conv2d(1, 32, 3, 1),
        #  nn.ReLU(),
        #  nn.Conv2d(32, 64, 3, 1),
        #  nn.ReLU(),
        #  nn.MaxPool2d(2),
        #  # nn.Dropout(0.25),
        #  nn.Flatten(),
        #  nn.Linear(9216, 64),
        #  nn.ReLU(),
        #  # nn.Dropout(0.5),
        #  nn.Linear(64, 10),
        #  nn.LogSoftmax(dim=1)]
        # return [nn.Conv2d(1, 32, 3, 1),
        #             nn.Tanh(),
        #             nn.Conv2d(32, 64, 3, 1),
        #             nn.Tanh(),
        #             nn.MaxPool2d(2),
        #             # nn.Dropout(0.25),
        #             nn.Flatten(),
        #             nn.Linear(9216, 64),
        #             nn.Tanh(),
        #             # nn.Dropout(0.5),
        #             nn.Linear(64, 10),
        #             nn.LogSoftmax(dim=1)]

    # model_layers=[nn.Conv2d(1, 32, 3, 1),
    #                 nn.ReLU(),
    #                  nn.Conv2d(32, 64, 3, 1),
    #                  nn.ReLU(),
    #                 nn.MaxPool2d(2),
    #                  #nn.Dropout(0.25),
    #                 nn.Flatten(),
    #                  nn.Linear(9216, 64),
    #                  nn.ReLU(),
    #                  #nn.Dropout(0.5),
    #                  nn.Linear(64, 10),
    #                  nn.LogSoftmax(dim=1)]

    # model_layers = [nn.Conv2d(1, 32, 3, 1),
    #                 nn.Tanh(),
    #                 nn.Conv2d(32, 64, 3, 1),
    #                 nn.Tanh(),
    #                 nn.MaxPool2d(2),
    #                 # nn.Dropout(0.25),
    #                 nn.Flatten(),
    #                 nn.Linear(9216, 64),
    #                 nn.Tanh(),
    #                 # nn.Dropout(0.5),
    #                 nn.Linear(64, 10),
    #                 nn.LogSoftmax(dim=1)]
    #lr=0.1

    save_name='results'
    save_name+='_model'+str(model_nr)
    if method==0:
        save_name+="_not"
    save_name+="_improved_training"
    if method>1:
        save_name += str(method)
    if method==1:
        save_name+="_iter"+str(iter_count)
    save_name+="_seed"+str(seed)
    save_name+="_"+str(trainings_same_model)+"trainings"
    save_name += "_" + str(epochs) + "e"
    #save_name+="_const_step"
    #if not mnist:
    save_name+="_"+dataset
    if cross_val:
        save_name+="_cross_validation"+str(folds_num)
    if lr_mul!=1:
        save_name+='_'+str(lr_mul)+'lr'
    if not mini_batch:
        save_name += '_batch'
    if average_gradient_of_loss:
        save_name+='__loss_avggrad'
    if d_type!=torch.float:
        save_name+='_'+str(d_type)
    if pretraining is not None:
        save_name+='_with_pretraining'+str(pretraining['method'])
    save_name+='.txt'
    model=None

    best_stats=None
    processed_stats=None
    actual_stats = None
    #preserve_stats=False

    model_name='model'+str(model_nr)+'_'+(dataset)+'.pt'

    def set_params(model,params):
        """Attach ``params`` to ``model`` as its ``gradient_factor_simple_layers`` attribute.

        :param model: object that receives the attribute.
        :param params: value stored unchanged on the model.
        """
        # Earlier variant (disabled): params was indexed and split into two separate factors.
        #model.gradient_factor = params[0]
        #model.step_factor = params[1]
        model.gradient_factor_simple_layers = params

    def generate_params():
        global hyperparameter_index
        hyperparameter_index += 1
        # return (rand_between(0,1),rand_between(0,1))
        # return (rand_between(-1, 1)*0.1+0.53, rand_between(-1, 1)*0.1+0.39)
        # return (rand_between(-1, 1)*0.1+0.621, rand_between(-1, 1)*0.1+0.336)
        if cross_val:
            if dataset == 'mnist':
                if model_nr == 6:
                    # return {'gamma': float(0.9329), 'lr': float(0.02787)}
                    pass
                elif model_nr == 5:
                    #########return {'gamma': float(0.9145), 'lr': float(0.006728)}
                    # return {'gamma': float(rand_between(0.9145-0.02,0.9145+0.04)), 'lr': float(0.006728*0.7**rand_between(-1,1))}
                    # return {'gamma': float(rand_between(0.9145-0.01,0.9145+0.05)), 'lr': float(0.003297*0.7**rand_between(-1,1))}
                    # return {'gamma': float(rand_between(0.9145,0.9145+0.06)), 'lr': float(0.001615*0.7**rand_between(-1,1))}
                    # return {'gamma': float(rand_between(0.9145+0.01,0.9145+0.07)), 'lr': float(0.0007915*0.7**rand_between(-1,1))}
                    # return {'gamma': float(rand_between(0.915+0.025,1)), 'lr': float(0.0003878*0.7**rand_between(-1,1))}

                    # return {'gamma': float(rand_between(0.8848, 0.9048)),'lr': float(0.002897 * 0.9 ** rand_between(-2, 0))}#6*100
                    # return {'gamma': float(rand_between(0.8848, 0.9048)),'lr': float(0.002897 * 0.9 ** rand_between(0, 2))}#6*100
                    return {'gamma': float(rand_between(0.8649, 0.8849)),
                            'lr': float(0.003098 * 0.9 ** rand_between(-1.7, 0.3))}
            elif dataset == 'fashion_mnist':
                if model_nr == 6:
                    # return {'gamma': float(rand_between(0.955 - 0.01, 0.955 + 0.01)),'lr': float(0.04535 * (0.9 ** rand_between(-1, 1)))}
                    # return {'gamma': float(0.9581), 'lr': float(0.04678)}
                    pass
                elif model_nr == 5:
                    # return {'gamma': float(0.9384), 'lr': float(0.008208)}
                    pass
        else:
            if dataset == 'mnist':
                if model_nr == 10:
                    # lr = (0.00005 + 0.00005 * hyperparameter_index)
                    # lr = 0.0005  # temporarily
                    lr = 0.00055
                    return {'gamma': float(1), 'lr': float(lr)}
                elif model_nr == 9:
                    # # lr=(0.001+0.001*hyperparameter_index) if hyperparameter_index<=10 else (0.0001+0.0001*(hyperparameter_index-11))
                    # lr=(0.00005+0.00005*hyperparameter_index) if hyperparameter_index<=8 else (0.0005+0.0001*(hyperparameter_index-9))
                    # if method>0 and separate_hyperparameters_for_optimized:
                    #     # lr=0.0003+0.0003*hyperparameter_index if hyperparameter_index<=4 else (0.002+0.001*(hyperparameter_index-5))
                    #     lr = 0.0009 + 0.0006 * hyperparameter_index if hyperparameter_index <= 1 else (
                    #             0.002 + 0.001 * (hyperparameter_index - 2))
                    # return {'gamma': float(1), 'lr': float(lr)}

                    # if method>0 and separate_hyperparameters_for_optimized:
                    #     if method == 1:
                    #         if iter_count == 2:
                    #             return {'gamma': float(1), 'lr': float(0.0015)}
                    #         elif iter_count > 2:
                    #             return {'gamma': float(1), 'lr': float(0.0015)}#exception
                    if method > 0 and separate_hyperparameters_for_optimized:
                        if method == 1:
                            # if iter_count == 2:
                            #     return {'gamma': float(1), 'lr': float(0.00075)}
                            # elif iter_count > 2:
                            #     return {'gamma': float(1), 'lr': float(0.00075)}

                            # lr = (0.0005 + 0.000125 * hyperparameter_index)
                            # lr = 0.000625
                            # #return {'gamma': float(1), 'lr': float(lr)}
                            # return {'gamma': float(1),
                            #         'lr': float(lr * np.e ** ((hyperparameter_index - 4) / (4) * (np.log(3))))}
                            return {'gamma': float(1),
                                    'lr': float([0.00006, 0.00008, 0.00009, 0.0001, 0.00011, 0.00012, 0.00014][
                                                    hyperparameter_index])}

                    # return {'gamma': float(1), 'lr': float(0.00025)}
                    # lr = (0.00015 + 0.00005 * hyperparameter_index)
                    # lr = 0.00035
                    # #return {'gamma': float(1), 'lr': float(lr)}
                    # return {'gamma': float(1), 'lr': float(lr * np.e ** ((hyperparameter_index - 4) / (4) * (np.log(3))))}
                    return {'gamma': float(1),
                            'lr': float([0.00015, 0.0002, 0.00025,0.000275, 0.0003, 0.00035, 0.0004][
                                            hyperparameter_index])}
                elif model_nr == 8:
                    return {'gamma': float(1), 'lr': float(0.001)}
                elif model_nr == 7:
                    # if hyperparameter_index < 3:
                    #     return {'gamma': float(1), 'lr': float(0.00005 + hyperparameter_index * 0.00005)}
                    # return {'gamma': float(1), 'lr': float(0.0002 + (hyperparameter_index - 3) * 0.0001)}
                    # return {'gamma': float(1), 'lr': float(0.0009)}
                    return {'gamma': float(1), 'lr': float(0.0001 + hyperparameter_index * 0.0001)}

                elif model_nr == 6:
                    # return {'gamma':float(rand_between(0.9, 1)),'lr':float(0.1**rand_between(1,3))}
                    # return {'gamma':float(rand_between(0.906-0.1, 0.906+0.05)),'lr':float(0.0152*(0.3**rand_between(-1,1)))}
                    ###################
                    # return {'gamma':float(rand_between(0.908-0.05, 0.908+0.05)),'lr':float(0.0202*(0.7**rand_between(-1,1)))}
                    # return {'gamma':float(rand_between(0.925-0.01, 0.925+0.01)),'lr':float(0.02115*(0.95**rand_between(-1,1)))}
                    # return {'gamma':float(rand_between(0.929-0.01, 0.929+0.01)),'lr':float(0.02204*(0.9**rand_between(-1,0)))}
                    # return {'gamma': float(rand_between(0.933 - 0.01, 0.933 + 0.01)),'lr': float(0.02424 * (0.9 ** rand_between(-1, 0)))}
                    # return {'gamma': float(rand_between(0.937 - 0.01, 0.937 + 0.01)),'lr': float(0.02666 * (0.9 ** rand_between(-1, 0)))}
                    # return {'gamma': float(0.9329), 'lr': float(0.02787)}
                    # lr = (0.001 + 0.0005 * hyperparameter_index)
                    # return {'gamma': float(1), 'lr': float(lr)}
                    # if method > 0 and separate_hyperparameters_for_optimized:
                    #     lr = 0.0008 + 0.0004 * hyperparameter_index
                    #     return {'gamma': float(1), 'lr': float(lr)}
                    if method == 1:
                        if iter_count == 2:
                            return {'gamma': float(1), 'lr': float(0.0008)}
                        elif iter_count > 2:
                            return {'gamma': float(1), 'lr': float(0.0008)}

                    return {'gamma': float(1), 'lr': float(0.0008)}
                elif model_nr == 5:
                    if method == 1 and separate_hyperparameters_for_optimized:
                        # return {'gamma': float(1), 'lr': float(0.001*0.3**rand_between(-1,1))}
                        return {'gamma': float(1), 'lr': float(0.001023 * 0.9 ** rand_between(-1, 1))}

                    # return {'gamma':float(rand_between(0.906-0.1, 0.906+0.05)),'lr':float(0.0152*(0.3**rand_between(-1,1)))}
                    # return {'gamma':float(rand_between(0.896-0.05, 0.896+0.05)),'lr':float(0.00514*(0.3**rand_between(-0.15,1)))}
                    ####################
                    # return {'gamma':float(rand_between(0.896-0.05, 0.896+0.05)),'lr':float(0.00514*(0.7**rand_between(-1,1)))}
                    # return {'gamma':float(rand_between(0.9182-0.01, 0.9182+0.01)),'lr':float(0.00654*(0.95**rand_between(-1,1)))}
                    # return {'gamma':float(0.9145),'lr':float(0.006728)}
                    # return {'gamma': float(1), 'lr': float(0.00009 + hyperparameter_index * 0.00001)}
                    return {'gamma': float(1), 'lr': float(0.0001 + hyperparameter_index * 0.0001)}

            elif dataset == 'fashion_mnist':
                if model_nr == 10:
                    # lr = (0.00005 + 0.00005 * hyperparameter_index)
                    # lr = 0.0005#temporarily
                    lr = 0.00055
                    return {'gamma': float(1), 'lr': float(lr)}
                elif model_nr == 9:
                    # lr = (0.00005 + 0.00005 * hyperparameter_index) if hyperparameter_index <= 8 else (
                    #             0.0005 + 0.0001 * (hyperparameter_index - 9))
                    # if method > 0 and separate_hyperparameters_for_optimized:
                    #     # lr = 0.0003 + 0.0003 * hyperparameter_index if hyperparameter_index <= 4 else (
                    #     #             0.002 + 0.001 * (hyperparameter_index - 5))
                    #     lr = 0.0015 + 0.0005 * hyperparameter_index if hyperparameter_index <= 3 else (
                    #             0.004 + 0.001 * (hyperparameter_index - 4))
                    # return {'gamma': float(1), 'lr': float(lr)}
                    if method > 0 and separate_hyperparameters_for_optimized:
                        if method == 1:
                            # if iter_count == 2:
                            #     return {'gamma': float(1), 'lr': float(0.0009)}
                            # elif iter_count > 2:
                            #     return {'gamma': float(1), 'lr': float(0.0009)}

                            # lr = (0.0006 + 0.00015 * hyperparameter_index)
                            # lr = 0.00105
                            # #return {'gamma': float(1), 'lr': float(lr)}
                            # return {'gamma': float(1), 'lr': float(lr * np.e ** ((hyperparameter_index - 4) / (4) * (np.log(3))))}
                            return {'gamma': float(1), 'lr': float([0.0006,0.0008,0.0009,0.001,0.0011,0.0012,0.0014][hyperparameter_index])}
                    # return {'gamma': float(1), 'lr': float(0.0003)}
                    # lr = (0.0002 + 0.00005 * hyperparameter_index)
                    # lr = 0.0004
                    # #return {'gamma': float(1), 'lr': float(lr)}
                    # return {'gamma': float(1), 'lr': float(lr * np.e ** ((hyperparameter_index - 4) / (4) * (np.log(3))))}
                    return {'gamma': float(1),
                            'lr': float([0.0003,0.00035,0.0004, 0.00045, 0.0005, 0.00055, 0.0006][hyperparameter_index])}
                if model_nr == 8:
                    return {'gamma': float(1), 'lr': float(0.001)}
                elif model_nr == 7:
                    # if hyperparameter_index<3:
                    #     return {'gamma': float(1), 'lr': float(0.00005+hyperparameter_index*0.00005)}
                    # return {'gamma': float(1), 'lr': float(0.0002 + (hyperparameter_index-3) * 0.0001)}
                    # return {'gamma': float(1), 'lr': float(0.0002)}
                    return {'gamma': float(1), 'lr': float(0.00002 + hyperparameter_index * 0.00003)}

                elif model_nr == 6:
                    # return {}
                    # return {'gamma': float(rand_between(0.937 - 0.03, 0.937 + 0.03)),'lr': float(0.02666 * (0.7 ** rand_between(-2, 0)))}
                    # return {'gamma': float(rand_between(0.955 - 0.01, 0.955 + 0.01)),'lr': float(0.04535 * (0.9 ** rand_between(-1, 1)))}
                    ############## fixed
                    # return {'gamma': float(0.9581), 'lr': float(0.04678)}
                    # lr = (0.001 + 0.0005 * hyperparameter_index)
                    # return {'gamma': float(1), 'lr': float(lr)}
                    if method > 0 and separate_hyperparameters_for_optimized:
                        # lr = 0.0015 + 0.0005 * hyperparameter_index
                        # return {'gamma': float(1), 'lr': float(lr)}
                        if method == 1:
                            if iter_count == 2:
                                return {'gamma': float(1), 'lr': float(0.0019)}
                            elif iter_count > 2:
                                return {'gamma': float(1), 'lr': float(0.0015)}

                    return {'gamma': float(1), 'lr': float(0.0015)}

                elif model_nr == 5:
                    # return {'gamma': float(rand_between(0.9145 - 0.03, 0.9145 + 0.03)),'lr': float(0.006728 * (0.7 ** rand_between(-1, 1)))}
                    # return {'gamma': float(rand_between(0.930 - 0.01, 0.930 + 0.01)),'lr': float(0.00750 * (0.9 ** rand_between(-1, 1)))}
                    # return {'gamma': float(rand_between(0.939 - 0.01, 0.939 + 0.01)),'lr': float(0.008130 * (0.9 ** rand_between(-1, 0)))}
                    # return {'gamma': float(0.9384), 'lr': float(0.008208)}

                    # return {'gamma': float(1), 'lr': float(0.0001)}
                    # return {'gamma': float(1), 'lr': float(0.0001*3**rand_between(-1,1))}
                    # return {'gamma': float(1), 'lr': float(0.00009+hyperparameter_index*0.00001)}
                    return {'gamma': float(1), 'lr': float(0.0001 + hyperparameter_index * 0.0001)}
            elif dataset=='imdb':
                if model_nr==11:
                    if method == 1:
                        if iter_count == 2:
                            return {'gamma': float(1), 'lr': float(0.0004)}
                        elif iter_count > 2:
                            return {'gamma': float(1), 'lr': float(0.0004)}
                    #return {'gamma': float(1), 'lr': float(0.0001)}
                    #return {'gamma': float(1), 'lr': float(0.0005)}
                    #return {'gamma': float(1), 'lr': float(0.00003**(1+hyperparameter_index/(10-1)*(np.log(0.0003)/np.log(0.00003)-1)))}
                    return {'gamma': float(1), 'lr': float(0.0004)}
                if model_nr==12:
                    if method == 1:
                        if iter_count == 2:
                            #return {'gamma': float(1), 'lr': float(0.0001)}
                            #return {'gamma': float(1), 'lr': float(0.0025 ** (1 + (hyperparameter_index - 4) / (4) * (np.log(0.0003) / np.log(0.0001) - 1)))}
                            #return {'gamma': float(1), 'lr': float(0.0003641*1.3 ** (hyperparameter_index+1))}
                            return {'gamma': 1.0, 'lr': 0.00047333}
                            # return {'gamma': 1.0, 'lr': 0.0006906444697974225}#outlier of method 0
                        elif iter_count > 2:
                            #return {'gamma': float(1), 'lr': float(0.0001)}
                            #return {'gamma': float(1), 'lr': float(0.0025 ** (1 + (hyperparameter_index - 4) / (4) * (np.log(0.0003) / np.log(0.0001) - 1)))}
                            #return {'gamma': float(1), 'lr': float(0.0003641 * 1.3 ** (hyperparameter_index + 1))}
                            return {'gamma': 1.0, 'lr': 0.00047333}
                            # return {'gamma': 1.0, 'lr': 0.0006906444697974225}#outlier of method 0
                    #return {'gamma': float(1), 'lr': float(0.0001)}
                    #return {'gamma': float(1), 'lr': float(0.0015 ** (1 + (hyperparameter_index - 4) / (4) * (np.log(0.0003) / np.log(0.0001) - 1)))}
                    #return {'gamma': float(1), 'lr': float(0.0005689123460087089*1.25**-hyperparameter_index)}
                    #return {'gamma': 1.0, 'lr': 0.0006906444697974225}#outlier
                    return {'gamma': 1.0, 'lr': 0.0003641}
        return {'gamma': float(1), 'lr': float(0.0001)}

    summary=[]

    backup_name='_backup_'+save_name

    state=read_state(backup_name)
    trainings_different_model_start, trainings_same_model_start=(0,0)
    if state:
        global hyperparameter_index
        hyperparameter_index,actual_stats,summary,trainings_different_model_start,trainings_same_model_start,r_state=read_state(backup_name)
        torch.set_rng_state(r_state)

    for _test_params in range(trainings_different_model_start,trainings_different_model):
        params=generate_params()
        params['lr']=lr_mul*params['lr']
        if trainings_same_model_start==0:
            actual_stats= {}
        for _test_same_model in range(trainings_same_model_start,trainings_same_model):
            trainings_same_model_start=0
            #model = NN(copy.deepcopy(model_layers)).to(device)

            if cross_val:
                train_loader=None
                #test_loader=None

                skf=StratifiedKFold(n_splits=folds_num,shuffle=True,random_state=1)
                for i, (train_index, val_index) in enumerate(skf.split(dataset1.train_data,dataset1.train_labels)):
                    print("Test number [of different models . of the same model . split]: " + str(_test_params) + "." + str(
                        _test_same_model)+"."+str(i))
                    train_subsampler = torch.utils.data.SubsetRandomSampler(train_index)
                    val_subsampler = torch.utils.data.SubsetRandomSampler(val_index)
                    train_loader = torch.utils.data.DataLoader(
                        dataset1,
                        batch_size=batch_size, sampler=train_subsampler)
                    val_loader = torch.utils.data.DataLoader(
                        dataset1,
                        batch_size=batch_size, sampler=val_subsampler)

                    model = NN(get_layers()).to(device)
                    if model_load:
                        model=load_model(model_name)
                    print("Params: " + str(params))

                    model.stats = actual_stats
                    model.stats['params'] = params

                    lr = params['lr']
                    optimizer = optim.RMSprop(model.parameters(), lr)

                    gamma = params['gamma']
                    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

                    model2 = NN(copy.deepcopy(model.layers)).to(device)
                    model2.train()
                    model2.load_state_dict(model.state_dict())
                    # opt2 = type(optimizer)(model2.parameters(),lr=0.01,momentum=0,nesterov=False)
                    # opt2 = type(optimizer)(model2.parameters())
                    opt2 = type(optimizer)(model2.parameters(), lr=optimizer.param_groups[0]['lr'])
                    opt2.load_state_dict(optimizer.state_dict())

                    for epoch in range(1, epochs + 1):
                        if mini_batch:
                            train_minibatch(model, device, train_loader, optimizer, epoch, model2, opt2)
                        else:
                            train(model, device, train_loader, optimizer, epoch, model2, opt2)
                        # train(args, model, device, train_loader, optimizer, epoch)#train(args, model, device, train_loader, optimizer, epoch,model2,opt2)
                        test(model, device, test_loader, train_loader,val_loader=val_loader, epoch=epoch)
                        scheduler.step()


                        actual_stats = model.stats
                        print(model.stats)
                        processed_stats = process_stats(model)
                        processed_stats['evaluation']=evaluate_stats(processed_stats)

                        if model_save:
                            save_model(model, model_name)
                        if stop_criteria(processed_stats):
                            break
                    if stop_criteria(processed_stats):
                        break
            else:
                print("Test number [of different models . of the same model]: " + str(_test_params) + "." + str(
                    _test_same_model))

                model = NN(get_layers()).to(device)
                if model_load:
                    model = load_model(model_name)

                # model_copy=None
                # if calculate_dist:
                #     model_copy=copy.deepcopy(model).to('cpu')

                #set_params(model,params)
                #print("Model parameters: \ngradient_factor: "+str(float(model.gradient_factor))+"   step_factor: "+str(float(model.step_factor))+"    gradient_factor_simple_layers: "+str(float(model.gradient_factor_simple_layers)))
                print("Params: "+str(params))

                #if preserve_stats:
                model.stats = actual_stats
                model.stats['params']=params
                #optimizer = optim.Adadelta(model.parameters(), lr=lr)
                # lr = 0.007
                # if model_nr == 3 or model_nr == 4:
                #     lr = 0.007
                # elif model_nr == 5:
                #     lr = 0.0001
                # elif model_nr == 6:
                #     lr = 0.001
                # optimizer = optim.Adam(model.parameters(), lr)
                # optimizer = optim.Adam(model.parameters(), lr, betas=(0., 0.999))
                # if model_nr == 3 or model_nr == 4:
                #     lr = 0.07
                # elif model_nr == 5:
                #     lr = 0.001
                # elif model_nr == 6:
                #     lr = 0.01
                lr=params['lr']
                optimizer=optim.RMSprop(model.parameters(), lr)


                #optimizer=optim.SGD(model.parameters())
                #optimizer=optim.SGD(model.parameters(),lr=0.3,momentum=0,nesterov=False)

                gamma=params['gamma']
                scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

                model2 = NN(copy.deepcopy(model.layers)).to(device)
                model2.train()
                model2.load_state_dict(model.state_dict())
                #opt2 = type(optimizer)(model2.parameters(),lr=0.01,momentum=0,nesterov=False)
                #opt2 = type(optimizer)(model2.parameters())
                opt2 = type(optimizer)(model2.parameters(),lr=optimizer.param_groups[0]['lr'])
                opt2.load_state_dict(optimizer.state_dict())

                show_time = True
                start_time = time.time()
                for epoch in range(1, epochs + 1):
                    if pretraining is not None and not pretraining['finished']:
                        # global method
                        if pretraining['pretraining_finish_criterion'](actual_stats,epoch-1):
                            pretraining['finished'] = True
                            method = pretraining['method_after_pretraining']
                        else:
                            method = pretraining['method']
                        hyperparameter_index-=1
                        new_params=generate_params()
                        optimizer.param_groups[0]['lr']=new_params['lr']
                        gamma = params['gamma']
                        scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
                        actual_stats['params']=new_params
                        print("New params: " + str(actual_stats['params']))

                    if mini_batch:
                        train_minibatch(model, device, train_loader, optimizer, epoch, model2, opt2)
                    else:
                        train(model, device, train_loader, optimizer, epoch, model2, opt2)
                    if show_time:
                        print("--- %s seconds ---" % (time.time() - start_time))
                    #train(args, model, device, train_loader, optimizer, epoch)#train(args, model, device, train_loader, optimizer, epoch,model2,opt2)
                    test(model, device, test_loader,train_loader,epoch=epoch)
                    scheduler.step()

                    #model.stats['evaluation'] = evaluate_stats(process_stats(model))
                    actual_stats = model.stats
                    print(model.stats)
                    processed_stats=process_stats(model)
                    processed_stats['evaluation'] = evaluate_stats(processed_stats)

                    if model_save:
                        save_model(model,model_name)

                    if stop_criteria(processed_stats):
                        break
                if stop_criteria(processed_stats):
                    break

            save_state(backup_name,
                       (hyperparameter_index-1, actual_stats, summary, _test_params, _test_same_model + 1,torch.get_rng_state()))

        if best_stats is None or evaluate_stats(processed_stats)>=evaluate_stats(best_stats):
            best_stats = processed_stats
            print("New best stats:")
            print(best_stats)
            write_to_file(save_name,"New best stats!")
            write_to_file(save_name,str(model.stats))
        write_to_file(save_name, "Actual processed stats:")
        write_to_file(save_name, str(processed_stats))
        write_to_file(save_name, "Best stats so far:")
        write_to_file(save_name, str(best_stats))
        mean,mean_err=plots.utils.mean_and_sem_err(processed_stats['train_loss_min'])
        median, median_err = plots.utils.median_and_sem_err(processed_stats['train_loss_min'])
        summary.append((processed_stats['params'],processed_stats['evaluation'],{'avg_train_loss_min':mean,'pm_avg_train_loss':mean_err,'median_train_loss_min':median,'pm_median_train_loss':median_err}))
        summary=sorted(summary,key=lambda x: x[1])
        write_to_file('summary_'+save_name, str(summary),mode='w')
        write_to_file('short_summary_'+save_name, str([(s[0],s[1]) for s in summary]),mode='w')

        save_state(backup_name,(hyperparameter_index,actual_stats,summary,_test_params+1,0,torch.get_rng_state()))

# Script entry point: start the experiment via main() only when this file is
# executed directly, so that importing the module does not launch a run.
if __name__ == "__main__":
    main()
