from __future__ import print_function
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import copy

# Model variant selector; the original note allows the values 3 or 4.
# Presumably read where the network is built elsewhere in the file — confirm.
model_nr = 4
class NN(nn.Module):
    """Sequential network with a hand-written, layer-by-layer backward pass.

    ``forward`` applies ``layers`` in order and records the network input in
    ``input_cached`` and every layer output in ``intermediate_outputs``.
    ``backward`` differentiates through those stored outputs one layer at a
    time with ``torch.autograd.grad``. ``backward_grad_correction`` does the
    same, and also blends a finite-difference estimate into the gradient that
    flows into layers listed in ``nonlinear_complex`` / ``nonlinear_simple``.
    That estimate is taken towards a second model's activations.
    """

    # Layers whose input gradient is corrected column by column (gradient_calc).
    # TODO: drop classes that never occur in model.layers.
    nonlinear_complex = (torch.nn.Softmax, torch.nn.LogSoftmax)
    # Element-wise layers whose input gradient is corrected in one shot
    # (gradient_calc_simple).
    nonlinear_simple = (torch.nn.Sigmoid, torch.nn.Tanh)

    # Per-epoch metric dictionaries; test() creates and fills this lazily.
    stats = None
    # Network input of the most recent forward() call.
    input_cached = None

    # Steps whose magnitude does not exceed this are treated as zero by the
    # finite-difference corrections. TODO: derive it from the weight dtype.
    small_num = torch.finfo(torch.float32).eps

    # Fraction of the way from an activation towards its candidate used when
    # probing one input column (gradient_calc, gradient_calc_last_layer2).
    step_factor = 0.347
    # Weight of the finite-difference estimate when it is blended into the
    # analytic gradient (gradient_calc, gradient_calc_last_layer2).
    gradient_factor = 0.622
    # Same blending weight for gradient_calc_simple.
    gradient_factor_simple_layers = 0.354

    # backward_grad_correction switches: apply the corrections at all, and
    # print the gradient before/after each nonlinear_complex correction.
    enhanced_gradients = True
    log_gradient_correction = False

    # Layer attributes that receive gradients from backward(),
    # backward_grad_correction() and copy_grad_from().
    _grad_param_names = ('weight', 'bias')

    def __init__(self, layers):
        """Wrap ``layers`` (applied in the given order).

        Args:
            layers: sequence of modules. If it is a plain list it is kept as-is
                in ``self.layers``; the ``nn.ModuleList`` copy registers the
                modules so ``parameters()``/``state_dict()`` see them.
        """
        super().__init__()
        self.layers = layers
        self.layer_list = nn.ModuleList(self.layers)
        # Filled by forward() / the backward variants, one slot per layer.
        self.intermediate_outputs = [None] * len(self.layers)
        self.intermediate_outputs_grad = [None] * len(self.layers)

    @classmethod
    def _trainable_params(cls, layer):
        """Return ``[(name, tensor)]`` for the layer's weight/bias that exist and require grad."""
        found = []
        for name in cls._grad_param_names:
            param = getattr(layer, name, None)
            if isinstance(param, torch.Tensor) and param.requires_grad:
                found.append((name, param))
        return found

    @staticmethod
    def _report_small_diffs(usable):
        """Print one warning per element whose probe step was too small to use."""
        for _ in range(int((~usable).sum())):
            print("Diff<=very small number")

    @torch.no_grad()
    def _sweep_columns(self, layer_input, layer_input_candidate, layer_input_candidate_grad,
                       layer, step_factor, gradient_factor, estimate):
        """Blend a finite-difference estimate into ``layer_input_candidate_grad`` column by column.

        For each flattened input feature ``col``, a copy of ``layer_input`` has
        just that column moved ``step_factor`` of the way towards
        ``layer_input_candidate``, and ``layer`` is re-evaluated on it.
        ``estimate(out, diff)`` turns the probed output and the per-sample step
        ``diff`` (shape ``(batch,)``) into a per-sample gradient estimate. The
        estimate is mixed in as ``grad * (1 - gradient_factor) + gradient_factor * estimate``.
        Samples whose step magnitude is not larger than ``small_num`` keep their
        gradient, and a warning is printed for each.

        ``layer_input_candidate_grad`` is updated in place through a reshaped
        view. NOTE(review): this assumes it is contiguous; otherwise ``reshape``
        copies and the update is lost.
        """
        batch = layer_input.size(0)
        base = layer_input.detach().reshape(batch, -1)
        goal = layer_input_candidate.detach().reshape(batch, -1)
        grads = layer_input_candidate_grad.detach().reshape(layer_input_candidate_grad.size(0), -1)

        # A contiguous clone guarantees probe_flat is a view, so writes reach ``probe``.
        probe = layer_input.detach().clone(memory_format=torch.contiguous_format)
        probe_flat = probe.view(batch, -1)
        for col in range(probe_flat.size(1)):
            if col > 0:
                # Restore the previously probed column: only one column differs at a time.
                probe_flat[:, col - 1] = base[:, col - 1]
            probe_flat[:, col] = goal[:, col] * step_factor + base[:, col] * (1. - step_factor)
            out = layer(probe)

            diff = probe_flat[:, col] - base[:, col]
            usable = diff.abs() > self.small_num
            blended = grads[:, col] * (1. - gradient_factor) + gradient_factor * estimate(out, diff)
            grads[:, col] = torch.where(usable, blended, grads[:, col])
            self._report_small_diffs(usable)

    @torch.no_grad()
    def _loss_difference_calc(self, layer_input, layer_input_candidate, layer_input_candidate_grad,
                              layer, loss_fn, target, step_factor, gradient_factor):
        """Column sweep whose estimate is the change of the batch loss ``loss_fn(layer(.), target)``."""
        # Every probe equals layer_input before its column is moved, so the
        # unperturbed loss is the same for all columns and is computed once.
        base_loss = loss_fn(layer(layer_input.detach()), target)

        def estimate(out, diff):
            return (loss_fn(out, target) - base_loss) / diff

        self._sweep_columns(layer_input, layer_input_candidate, layer_input_candidate_grad,
                            layer, step_factor, gradient_factor, estimate)

    @torch.no_grad()
    def gradient_calc_last_layer2(self, layer_input, layer_input_candidate, layer_input_candidate_grad, layer_out,
                                  layer_out_candidate_gradient, layer, loss_fn, target):
        """Loss-difference correction of ``layer_input_candidate_grad`` (in place).

        Uses ``step_factor`` and ``gradient_factor``. The estimate for sample
        ``j`` is ``(loss(probe) - loss(layer_input)) / diff[j]``, where the loss
        is the whole-batch ``loss_fn(layer(.), target)``. ``layer_out`` and
        ``layer_out_candidate_gradient`` are accepted for signature parity with
        gradient_calc and are not used.
        """
        self._loss_difference_calc(layer_input, layer_input_candidate, layer_input_candidate_grad,
                                   layer, loss_fn, target, self.step_factor, self.gradient_factor)

    @torch.no_grad()
    def gradient_calc_last_layer(self, layer_input, layer_input_candidate, layer_input_candidate_grad, layer_out,
                                 layer_out_candidate_gradient, layer, loss_fn, target):
        """Same as gradient_calc_last_layer2 but with a fixed 0.5 step and 0.5 blending weight.

        Noted by the author as not working well. ``layer_out`` and
        ``layer_out_candidate_gradient`` are not used.
        """
        self._loss_difference_calc(layer_input, layer_input_candidate, layer_input_candidate_grad,
                                   layer, loss_fn, target, 0.5, 0.5)

    @torch.no_grad()
    def gradient_calc_simple(self, layer_input, layer_input_candidate, layer_input_candidate_grad, layer_out,
                             layer_out_candidate_gradient, layer):
        """Secant correction of ``layer_input_candidate_grad`` for element-wise layers (in place).

        Where ``layer_input_candidate`` differs from ``layer_input`` by more than
        ``small_num``, the gradient becomes
        ``grad * (1 - g) + g * (layer(candidate) - layer_out) * out_grad / (candidate - input)``
        with ``g = gradient_factor_simple_layers``. Elsewhere it is unchanged.
        Intended for the ``nonlinear_simple`` layers, whose output has the
        input's shape.
        """
        out = layer(layer_input_candidate)
        step = layer_input_candidate - layer_input
        blended = (layer_input_candidate_grad * (1. - self.gradient_factor_simple_layers)
                   + self.gradient_factor_simple_layers * (out - layer_out) * layer_out_candidate_gradient / step)
        layer_input_candidate_grad[:] = torch.where(torch.abs(step) > self.small_num,
                                                    blended,
                                                    layer_input_candidate_grad)

    @torch.no_grad()
    def gradient_calc(self, layer_input, layer_input_candidate, layer_input_candidate_grad, layer_out,
                      layer_out_candidate_gradient, layer):
        """Correct the gradient flowing into ``layer`` with a first-order output-change estimate.

        For sample ``j`` and input column ``i`` the estimate is
        ``sum((layer(probe) - layer_out)[j] * layer_out_candidate_gradient[j]) / diff[j]``.
        This is the change of the layer output along the upstream gradient, per
        unit of input change. Outputs of any rank are flattened per sample.
        ``layer_input_candidate_grad`` is modified in place.
        """
        reference = layer_out.detach()
        out_grad = layer_out_candidate_gradient.detach().reshape(layer_out_candidate_gradient.size(0), -1)

        def estimate(out, diff):
            delta = (out - reference).reshape(out.size(0), -1)
            return (delta * out_grad / diff.unsqueeze(1)).sum(dim=1)

        self._sweep_columns(layer_input, layer_input_candidate, layer_input_candidate_grad,
                            layer, self.step_factor, self.gradient_factor, estimate)

    def _correct_input_grad(self, ind, model2):
        """Apply the finite-difference correction to ``intermediate_outputs_grad[ind - 1]`` in place."""
        layer = self.layers[ind]
        args = (self.intermediate_outputs[ind - 1], model2.intermediate_outputs[ind - 1],
                self.intermediate_outputs_grad[ind - 1],
                self.intermediate_outputs[ind], self.intermediate_outputs_grad[ind],
                layer)
        if isinstance(layer, self.nonlinear_complex):
            if self.log_gradient_correction:
                print("----")
                print(self.intermediate_outputs_grad[ind - 1])
            self.gradient_calc(*args)
            if self.log_gradient_correction:
                print(self.intermediate_outputs_grad[ind - 1])
        elif isinstance(layer, self.nonlinear_simple):
            self.gradient_calc_simple(*args)

    def backward_grad_correction(self, loss, model2, loss_fn=None, target=None):
        """Backpropagate ``loss`` layer by layer, correcting activation gradients on the way.

        Unlike backward(), this method:
        * stores the gradients w.r.t. the layer outputs in ``intermediate_outputs_grad``;
        * re-runs each parameterised layer on its stored input (replacing its
          entry in ``intermediate_outputs``) before differentiating;
        * assigns the weight/bias gradients to ``self`` and also, as detached
          tensors sharing storage, to the same-named parameters of
          ``model2.layers[ind]``;
        * when ``enhanced_gradients`` is set, corrects the gradient flowing into
          each ``nonlinear_complex`` / ``nonlinear_simple`` layer with
          gradient_calc / gradient_calc_simple, using
          ``model2.intermediate_outputs`` as candidate activations.

        ``model2`` is assumed to mirror this model's layer structure and to have
        run forward() on the same batch. ``loss_fn`` and ``target`` are not used
        here; they are optional so two-argument calls work.
        """
        last = len(self.intermediate_outputs) - 1
        self.intermediate_outputs_grad[last] = torch.autograd.grad(
            outputs=loss, inputs=self.intermediate_outputs[last], retain_graph=True)[0]

        for ind in range(last, -1, -1):
            layer = self.layers[ind]
            params = self._trainable_params(layer)
            if params:
                layer_input = self.intermediate_outputs[ind - 1] if ind > 0 else self.input_cached
                self.intermediate_outputs[ind] = layer(layer_input)
                grads = torch.autograd.grad(outputs=self.intermediate_outputs[ind],
                                            inputs=[param for _, param in params],
                                            grad_outputs=self.intermediate_outputs_grad[ind],
                                            retain_graph=True)
                twin = model2.layers[ind]
                for (name, param), grad in zip(params, grads):
                    param.grad = grad
                    twin_param = getattr(twin, name, None)
                    if twin_param is not None:
                        twin_param.grad = grad.detach()
            if ind == 0:
                continue
            self.intermediate_outputs_grad[ind - 1] = torch.autograd.grad(
                outputs=self.intermediate_outputs[ind],
                inputs=self.intermediate_outputs[ind - 1],
                grad_outputs=self.intermediate_outputs_grad[ind],
                retain_graph=True)[0]
            if self.enhanced_gradients:
                self._correct_input_grad(ind, model2)

    def backward(self, loss):
        """Backpropagate ``loss`` through the stored layer outputs, one layer at a time.

        For a purely sequential layer stack this yields the same weight and bias
        gradients as ``loss.backward()``. They are assigned (overwriting any
        previous ``.grad``) instead of accumulated. The gradient w.r.t. each
        intermediate output is stored in that tensor's ``.grad`` attribute.
        Requires a preceding forward() run with autograd enabled.
        """
        last = len(self.intermediate_outputs) - 1
        self.intermediate_outputs[last].grad = torch.autograd.grad(
            outputs=loss, inputs=self.intermediate_outputs[last], retain_graph=True)[0]

        for ind in range(last, -1, -1):
            output = self.intermediate_outputs[ind]
            params = self._trainable_params(self.layers[ind])
            if params:
                grads = torch.autograd.grad(outputs=output,
                                            inputs=[param for _, param in params],
                                            grad_outputs=output.grad,
                                            retain_graph=True)
                for (_, param), grad in zip(params, grads):
                    param.grad = grad
            if ind > 0:
                self.intermediate_outputs[ind - 1].grad = torch.autograd.grad(
                    outputs=output,
                    inputs=self.intermediate_outputs[ind - 1],
                    grad_outputs=output.grad,
                    retain_graph=True)[0]

    def forward(self, x):
        """Apply the layers in order, caching the input and every layer output."""
        self.input_cached = x
        for i, layer in enumerate(self.layers):
            x = layer(x)
            self.intermediate_outputs[i] = x
        return x

    @torch.no_grad()
    def copy_grad_from(self, model):
        """Copy weight/bias gradients from the corresponding layers of ``model``.

        An existing gradient tensor is overwritten in place; otherwise a detached
        clone is attached. If ``model`` has no gradient for a parameter, this
        model's gradient for it is cleared (set to None) rather than raising.
        """
        for own_layer, other_layer in zip(self.layers, model.layers):
            for name in self._grad_param_names:
                own = getattr(own_layer, name, None)
                other = getattr(other_layer, name, None)
                if not isinstance(own, torch.Tensor) or not isinstance(other, torch.Tensor):
                    continue
                if other.grad is None:
                    own.grad = None
                elif own.grad is not None:
                    own.grad.copy_(other.grad)
                else:
                    own.grad = other.grad.detach().clone()



# Training-strategy switches read by train(). train() tests opt_loop2,
# opt_loop, opt_1_iter_optimized, opt_1_iter, test_of_delegated_step and
# copy_grad_experiment in that order and uses the first one that is True.
# Only when all of them are False does it look at opt_test, then copy_test,
# and otherwise it runs the plain training step.
opt_loop2 = False  # still in development
iter_count = 1  # inner iterations used by the opt_loop2 / opt_loop strategies
opt_loop = False  # noted as not working well
opt_1_iter_optimized = False  # noted as not working well
opt_1_iter = False
test_of_delegated_step = False
copy_test = False
opt_test = False
copy_grad_experiment = False
def _sync_twin(model, optimizer, model2, opt2):
    """Overwrite model2/opt2 with the current weights of ``model`` and the state of ``optimizer``."""
    model2.load_state_dict(model.state_dict())
    opt2.load_state_dict(optimizer.state_dict())


def _step_opt_loop2(model, optimizer, model2, opt2, data, target):
    """Strategy ``opt_loop2`` (in development): ``iter_count`` corrected twin rounds, then step ``model``.

    NOTE(review): ``binary_cross_entropy_with_logits`` needs a float ``target``
    shaped like the model output. With class-index targets (which test() uses)
    it raises, so confirm the intended loss before enabling this flag.
    Returns the last twin loss, or None when ``iter_count`` < 1.
    """
    loss = None
    for _ in range(iter_count):
        _sync_twin(model, optimizer, model2, opt2)
        loss = F.binary_cross_entropy_with_logits(model2(data), target)
        # model2's gradients are not zeroed first, so backward() adds onto them.
        loss.backward(retain_graph=True)
        opt2.step()

        loss1 = F.binary_cross_entropy_with_logits(model(data), target)
        # Re-running model2 refreshes its intermediate_outputs after the step;
        # backward_grad_correction uses them as candidate activations.
        loss = F.binary_cross_entropy_with_logits(model2(data), target)
        model.backward_grad_correction(loss1, model2, F.nll_loss, target)
    optimizer.step()
    return loss


def _step_opt_loop(model, optimizer, model2, opt2, data, target):
    """Strategy ``opt_loop``: regular backward of ``model``, then ``iter_count`` corrected rounds."""
    loss = F.nll_loss(model(data), target)
    # model's gradients are not zeroed first, so backward() adds onto them.
    loss.backward(retain_graph=True)
    for round_idx in range(iter_count):
        _sync_twin(model, optimizer, model2, opt2)
        F.nll_loss(model2(data), target).backward()
        opt2.step()
        if round_idx > 0:
            loss = F.nll_loss(model(data), target)
        # Refresh model2.intermediate_outputs with the stepped weights.
        model2(data)
        model.backward_grad_correction(loss, model2, F.nll_loss, target)
    optimizer.step()
    return loss


def _step_opt_1_iter_optimized(model, optimizer, model2, opt2, data, target):
    """Strategy ``opt_1_iter_optimized``: step the twin with ``model``'s gradients, then correct."""
    loss = F.nll_loss(model(data), target)
    loss.backward(retain_graph=True)
    _sync_twin(model, optimizer, model2, opt2)
    model2.copy_grad_from(model)
    opt2.step()
    # Refresh model2.intermediate_outputs with the stepped weights.
    model2(data)
    model.backward_grad_correction(loss, model2, F.nll_loss, target)
    optimizer.step()
    return loss


def _step_opt_1_iter(model, optimizer, model2, opt2, data, target):
    """Strategy ``opt_1_iter``: one ordinary twin step, then a corrected step of ``model``."""
    _sync_twin(model, optimizer, model2, opt2)
    opt2.zero_grad()
    optimizer.zero_grad()
    loss = F.nll_loss(model2(data), target)
    loss.backward(retain_graph=True)
    opt2.step()

    loss1 = F.nll_loss(model(data), target)
    # Also refreshes model2.intermediate_outputs with the stepped weights.
    loss = F.nll_loss(model2(data), target)
    model.backward_grad_correction(loss1, model2, F.nll_loss, target)
    optimizer.step()
    return loss


def _step_delegated(model, optimizer, model2, opt2, data, target):
    """Strategy ``test_of_delegated_step``: compute corrected grads on ``model``, apply them via the twin."""
    _sync_twin(model, optimizer, model2, opt2)
    opt2.zero_grad()
    optimizer.zero_grad()
    loss = F.nll_loss(model2(data), target)
    loss.backward(retain_graph=True)
    opt2.step()

    loss1 = F.nll_loss(model(data), target)
    model.backward_grad_correction(loss1, model2, F.nll_loss, target)

    # Reset the twin to model's weights, step it with model's gradients, copy back.
    _sync_twin(model, optimizer, model2, opt2)
    model2.copy_grad_from(model)
    opt2.step()
    model.load_state_dict(model2.state_dict())
    optimizer.load_state_dict(opt2.state_dict())
    return loss


def _step_copy_grad_experiment(model, optimizer, model2, opt2, data, target):
    """Strategy ``copy_grad_experiment``: correct on the twin, step ``model`` with the shared grads."""
    _sync_twin(model, optimizer, model2, opt2)
    opt2.zero_grad()
    optimizer.zero_grad()
    loss = F.nll_loss(model2(data), target)
    # backward_grad_correction reads model.intermediate_outputs as candidates;
    # run model on this batch so they are neither stale nor None.
    model(data)
    model2.backward_grad_correction(loss, model, F.nll_loss, target)
    optimizer.step()
    return loss


def _step_default(model, optimizer, model2, opt2, data, target):
    """Fallback when no strategy flag is set: ``opt_test``, else ``copy_test``, else a plain step."""
    if opt_test:
        # One twin step, then a corrected step of model (model2 activations not refreshed).
        _sync_twin(model, optimizer, model2, opt2)
        opt2.zero_grad()
        optimizer.zero_grad()
        loss = F.nll_loss(model2(data), target)
        loss.backward(retain_graph=True)
        opt2.step()
        loss1 = F.nll_loss(model(data), target)
        model.backward_grad_correction(loss1, model2, F.nll_loss, target)
        optimizer.step()
    elif copy_test:
        # Step the twin through the manual backward() and copy the result back.
        _sync_twin(model, optimizer, model2, opt2)
        opt2.zero_grad()
        loss = F.nll_loss(model2(data), target)
        model2.backward(loss)
        opt2.step()
        model.load_state_dict(model2.state_dict())
        optimizer.load_state_dict(opt2.state_dict())
    else:
        # Baseline: ordinary step using the manual layer-by-layer backward().
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        model.backward(loss)
        optimizer.step()
    return loss


def _select_step():
    """Return the per-batch step function picked by the module-level flags (first match wins)."""
    choices = (
        (opt_loop2, _step_opt_loop2),
        (opt_loop, _step_opt_loop),
        (opt_1_iter_optimized, _step_opt_1_iter_optimized),
        (opt_1_iter, _step_opt_1_iter),
        (test_of_delegated_step, _step_delegated),
        (copy_grad_experiment, _step_copy_grad_experiment),
    )
    for enabled, step in choices:
        if enabled:
            return step
    return _step_default


def train(model, device, train_loader, optimizer, epoch, model2=None, opt2=None, log_interval=10):
    """Train ``model`` for one epoch over ``train_loader``.

    The per-batch update is chosen by the module-level strategy flags. Several
    strategies use a twin network ``model2`` with its optimizer ``opt2``. If
    either is missing, they are built from a deep copy of ``model.layers``,
    ``model``'s state dict and ``optimizer``'s state dict.

    Args:
        model: the NN being trained.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches with a ``.dataset``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
        model2, opt2: optional twin network and its optimizer.
        log_interval: print the loss every ``log_interval`` batches.
    """
    model.train()

    if model2 is None or opt2 is None:
        model2 = NN(copy.deepcopy(model.layers))
        model2.train()
        model2.load_state_dict(model.state_dict())
        opt2 = type(optimizer)(model2.parameters())
        opt2.load_state_dict(optimizer.state_dict())
    # Keep the twin's correction hyper-parameters in sync with model.
    model2.step_factor = model.step_factor
    model2.gradient_factor = model.gradient_factor
    model2.gradient_factor_simple_layers = model.gradient_factor_simple_layers

    step = _select_step()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss = step(model, optimizer, model2, opt2, data, target)

        if batch_idx % log_interval == 0 and loss is not None:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))



def test(model, device, test_loader,train_loader,epoch=1):
    if model.stats is None:
        model.stats={}
        model.stats['train_accuracy']={}
        model.stats['test_accuracy'] = {}
        model.stats['train_loss'] = {}
        model.stats['test_loss'] = {}
    if epoch  not in model.stats['train_accuracy']:
        model.stats['train_accuracy'][epoch]=[]
    if epoch  not in model.stats['train_loss']:
        model.stats['train_loss'][epoch]=[]
    if epoch  not in model.stats['test_accuracy']:
        model.stats['test_accuracy'][epoch]=[]
    if epoch  not in model.stats['test_loss']:
        model.stats['test_loss'][epoch]=[]
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target,reduction='sum').item()
            #test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Avg loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    model.stats['test_accuracy'][epoch].append(100. * correct / len(test_loader.dataset))
    model.stats['test_loss'][epoch].append(test_loss)

    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            #train_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    train_loss /= len(train_loader.dataset)

    print('\nTraining set: Avg loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    model.stats['train_accuracy'][epoch].append(100. * correct / len(train_loader.dataset))
    model.stats['train_loss'][epoch].append(train_loss)

def evaluate_stats(stats):
    """Score a run by its lowest per-epoch training loss (higher score is better).

    Args:
        stats: processed stats as built by ``process_stats``. ``stats['train_loss']``
            maps epoch number -> average training loss for that epoch.

    Returns:
        The negated minimum training loss. The negation makes the score
        maximizable (callers compare with ``>=``). Epochs whose loss is NaN or
        +inf are ignored. If no epoch qualifies the result is ``-inf``.
    """
    # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
    min_loss = np.inf
    for score in stats['train_loss'].values():
        # `<` is False for NaN, so a NaN epoch never replaces the running minimum.
        if score < min_loss:
            min_loss = score
    return -min_loss

def stop_criteria(stats):
    """Return True when even the best epoch's training loss is above 3.

    ``evaluate_stats`` returns the negated minimum training loss, so comparing
    its score against ``-3.`` is the same as checking ``min loss > 3``.
    """
    threshold = -3.
    best_score = evaluate_stats(stats)
    return best_score < threshold
    
# Default hyper-parameters recorded by process_stats(); main() overwrites both
# globals for every sampled trial.
gamma = 0.7
lr = 1


def process_stats(model):
    """Average the raw per-epoch score lists in ``model.stats``.

    For each metric ``m`` in ``model.stats`` (a dict epoch -> list of scores) the
    result contains:
      * ``m``: dict epoch -> mean of that epoch's scores;
      * ``m + '_avg'``: mean over every score of every epoch.
    The current global ``gamma`` and ``lr`` are stored under ``'gamma'``/``'lr'``.
    """
    summary = {}
    for metric, per_epoch in model.stats.items():
        summary[metric] = {
            epoch: sum(scores) / len(scores)
            for epoch, scores in per_epoch.items()
        }
        pooled = [score for scores in per_epoch.values() for score in scores]
        summary[f"{metric}_avg"] = sum(pooled) / len(pooled)
    summary['gamma'] = gamma
    summary['lr'] = lr
    return summary

def write_to_file(name, text):
    """Append ``text`` followed by a newline to the file at path ``name``.

    The file is created if it does not exist.
    """
    with open(name, "a") as out_file:
        out_file.writelines((text, '\n'))

def rand_between(a, b):
    """Draw one float32 sample uniformly between ``a`` and ``b``.

    Uses torch's global CPU generator, so results follow ``torch.manual_seed``.
    Returns a 0-dim tensor, not a Python float.
    """
    sample = torch.empty(1, dtype=torch.float32, device='cpu')
    sample.uniform_(a, b)
    return sample[0]

def show_model_parameter_numbers(model):
    """Print the element count of each trainable parameter tensor, then their sum.

    Parameters with ``requires_grad=False`` are excluded from both lines.
    """
    counts = [param.numel() for param in model.parameters() if param.requires_grad]
    listing = "".join(str(count) + "  " for count in counts)
    print("Model parameters: " + listing)
    print("Sum: " + str(sum(counts)))

def main():
    """Random-search the Adam learning rate and StepLR decay factor on MNIST.

    For up to 100 sampled ``(gamma, lr)`` pairs, trains up to 3 freshly initialised
    copies of the architecture selected by the module-level ``model_nr``. Each copy
    trains for up to ``epochs`` epochs. Runs sharing a pair append to one shared
    stats dict. After each pair, the processed stats and the best stats so far
    (ranked by ``evaluate_stats``) are appended to ``results_model<model_nr>.txt``.
    Both inner loops stop early when ``stop_criteria`` returns True.

    Raises:
        ValueError: if ``model_nr`` is neither 3 nor 4.
    """
    device = torch.device("cpu")

    batch_size = 64
    # process_stats() reads these module-level globals to record each trial's settings.
    global gamma
    global lr
    lr = 1

    train_kwargs = {'batch_size': batch_size}
    test_kwargs = {'batch_size': batch_size}

    torch.manual_seed(1)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, shuffle=True, **test_kwargs)

    def get_layers():
        """Build a fresh layer list for the architecture selected by ``model_nr``.

        Raises:
            ValueError: if ``model_nr`` is neither 3 nor 4.
        """
        if model_nr == 3:
            return [nn.Conv2d(1, 16, 3, 1),
                    nn.ReLU(),
                    nn.BatchNorm2d(16),
                    nn.Conv2d(16, 16, 3, 1),
                    nn.ReLU(),
                    nn.BatchNorm2d(16),
                    nn.MaxPool2d(2),
                    nn.Flatten(),
                    nn.Linear(2304, 32),
                    nn.ReLU(),
                    nn.BatchNorm1d(32),
                    nn.Linear(32, 10),
                    nn.LogSoftmax(dim=1)]
        if model_nr == 4:
            return [nn.Conv2d(1, 8, 3, 1),
                    nn.ReLU(),
                    nn.BatchNorm2d(8),
                    nn.Conv2d(8, 8, 3, 1),
                    nn.ReLU(),
                    nn.BatchNorm2d(8),
                    nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2),
                    nn.ReLU(),
                    nn.BatchNorm2d(16),

                    nn.Conv2d(16, 16, 3, 1),
                    nn.ReLU(),
                    nn.BatchNorm2d(16),
                    nn.Conv2d(16, 16, 3, 1),
                    nn.ReLU(),
                    nn.BatchNorm2d(16),
                    nn.Conv2d(16, 16, kernel_size=5, stride=2, padding=2),
                    nn.ReLU(),
                    nn.BatchNorm2d(16),

                    nn.Flatten(),
                    nn.Linear(256, 10),
                    nn.Softmax(dim=1)]
        # Returning None here used to surface only later, as an opaque TypeError inside NN().
        raise ValueError("model_nr must be 3 or 4, got {!r}".format(model_nr))

    epochs = 20
    save_name = 'results_model' + str(model_nr) + '.txt'

    best_stats = None
    processed_stats = None
    actual_stats = None

    def generate_params():
        """Sample StepLR ``gamma`` between 0.7 and 1 and Adam ``lr`` between 0.004 and 0.01."""
        return {'gamma': rand_between(0.7, 1), 'lr': rand_between(0.004, 0.01)}

    for _test_params in range(100):
        params = generate_params()
        # rand_between() yields 0-dim tensors. Convert once to plain floats:
        # PyTorch optimizers only partially support a Tensor lr, and the stats
        # written to the results file should show numbers, not "tensor(...)".
        gamma = float(params['gamma'])
        lr = float(params['lr'])
        # Fresh stats dict per hyper-parameter pair. Repeated runs of the same
        # pair then share it, so test() appends one score per run to each epoch's list.
        actual_stats = None
        for _test_same_model in range(3):
            print("Test number [of different models . of the same model]: "+str(_test_params)+"."+str(_test_same_model))
            model = NN(get_layers()).to(device)

            show_model_parameter_numbers(model)
            print("Gamma = "+str(gamma)+"   lr="+str(lr))

            model.stats = actual_stats

            optimizer = optim.Adam(model.parameters(), lr=lr)
            scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

            # Second network and optimizer, initialised from model/optimizer and
            # passed to train() alongside them.
            model2 = NN(copy.deepcopy(model.layers)).to(device)
            model2.train()
            model2.load_state_dict(model.state_dict())
            opt2 = type(optimizer)(model2.parameters())
            opt2.load_state_dict(optimizer.state_dict())

            for epoch in range(1, epochs + 1):
                train(model, device, train_loader, optimizer, epoch, model2, opt2)
                test(model, device, test_loader, train_loader, epoch)
                scheduler.step()

                actual_stats = model.stats
                print(model.stats)
                processed_stats = process_stats(model)
                if stop_criteria(processed_stats):
                    break
            if stop_criteria(processed_stats):
                break

        if best_stats is None or evaluate_stats(processed_stats) >= evaluate_stats(best_stats):
            best_stats = processed_stats
            print("New best stats:")
            print(best_stats)
            write_to_file(save_name, "New best stats!")
            write_to_file(save_name, str(model.stats))
        write_to_file(save_name, "Actual processed stats:")
        write_to_file(save_name, str(processed_stats))
        write_to_file(save_name, "Best stats so far:")
        write_to_file(save_name, str(best_stats))

# Run the hyper-parameter search only when executed as a script, not on import.
if __name__ == '__main__':
    main()
