from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
#import sys


# def hook(module, grad_input, grad_output):
#     pass
class Net(nn.Module):
    """MNIST convolutional classifier that records every layer's output.

    The model is stored as a flat list of layers.  ``forward`` keeps the
    output of each layer in ``intermediate_outputs`` so that ``backward`` and
    ``backward_grad_correction`` can walk the layers in reverse and propagate
    gradients one layer at a time with ``torch.autograd.grad`` instead of a
    single ``loss.backward()`` call.

    Attributes:
        layers: plain Python list of the layers, in execution order.
        layer_list: ``nn.ModuleList`` over the same layers, so that their
            parameters are registered with the module.
        intermediate_outputs: output tensor of each layer from the most
            recent ``forward`` call (``None`` before the first call).
        intermediate_outputs_grad: per-layer output gradients filled in by
            ``backward_grad_correction``.
    """

    # Layer types that get the (currently disabled) "enhanced" gradient
    # blending in backward_grad_correction.
    many_to_many_nonlinear = (torch.nn.Softmax, torch.nn.LogSoftmax)

    # Perturbations whose magnitude is not above this threshold are skipped
    # by the finite-difference estimators.
    # TODO: check weights dtype (the threshold is derived from float32).
    small_num = (torch.finfo(torch.float32).eps) ** 0.75

    # Input tensor of the most recent forward() call.
    input_cached = None

    def __init__(self):
        super(Net, self).__init__()
        self.layers = [
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1),
        ]
        self.layer_list = nn.ModuleList(self.layers)

        self.intermediate_outputs = [None] * len(self.layers)
        self.intermediate_outputs_grad = [None] * len(self.layers)

        # NOTE: this is a process-wide autograd debugging switch, not a
        # per-model setting; it is enabled every time a Net is constructed.
        torch.autograd.set_detect_anomaly(True)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _flatten_per_sample(tensor, batch):
        """Return a detached ``(batch, -1)`` reshape of *tensor*."""
        return tensor.reshape((batch, -1)).detach()

    @staticmethod
    def _perturbation_buffer(layer_input):
        """Return ``(buffer, flat_view)`` for perturbing a copy of *layer_input*.

        The copy is forced to be contiguous so that ``flat_view`` is
        guaranteed to alias ``buffer``: writing a column of ``flat_view``
        changes what ``layer(buffer)`` sees.
        """
        buffer = layer_input.detach().clone(memory_format=torch.contiguous_format)
        return buffer, buffer.view(buffer.size()[0], -1)

    @staticmethod
    def _store_flat_grad(flat_grad, grad):
        """Copy *flat_grad* back into *grad* when the two do not share memory.

        ``reshape`` returns a copy for tensors that cannot be viewed in the
        requested shape; without this step updates made through such a copy
        would be lost.
        """
        if flat_grad.data_ptr() != grad.data_ptr():
            grad.copy_(flat_grad.reshape(grad.shape))

    @staticmethod
    def _named_trainable(layer):
        """Return ``[(name, parameter)]`` for the layer's own trainable parameters."""
        return [(name, param)
                for name, param in layer.named_parameters(recurse=False)
                if param.requires_grad]

    @staticmethod
    def _layer_param_grads(named_params, layer_output, output_grad):
        """Return ``{name: gradient}`` of *layer_output* w.r.t. *named_params*."""
        grads = torch.autograd.grad(outputs=layer_output,
                                    inputs=[param for _, param in named_params],
                                    grad_outputs=output_grad,
                                    retain_graph=True)
        return {name: grad for (name, _), grad in zip(named_params, grads)}

    # ------------------------------------------------------------------
    # finite-difference gradient estimators (experimental)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def gradient_calc_last_layer(self, layer_input, layer_input_candidate, layer_input_candidate_grad, layer_out,
                      layer_out_candidate_gradient, layer,loss_fn,target):
        """Blend a finite-difference loss gradient into ``layer_input_candidate_grad``.

        For each flattened input column ``i`` the column is moved half way
        from ``layer_input`` towards ``layer_input_candidate`` and the loss
        difference ``loss_fn(layer(perturbed), target) -
        loss_fn(layer(original), target)`` is divided by the per-sample
        step.  The stored gradient entry becomes the mean of its previous
        value and that quotient.  Entries whose step magnitude is not above
        ``small_num`` are left untouched and a message is printed.

        Args:
            layer_input: input of ``layer`` used as the reference point.
            layer_input_candidate: alternative input to step towards.
            layer_input_candidate_grad: gradient tensor updated in place.
            layer_out: unused.
            layer_out_candidate_gradient: unused.
            layer: callable applied to the (perturbed) input.
            loss_fn: callable ``loss_fn(output, target)`` returning a scalar.
            target: second argument passed to ``loss_fn``.
        """
        batch = layer_input.size()[0]
        input_mod, flat_mod = self._perturbation_buffer(layer_input)
        flat_input = self._flatten_per_sample(layer_input, batch)
        flat_candidate = self._flatten_per_sample(layer_input_candidate, batch)
        flat_grad = self._flatten_per_sample(layer_input_candidate_grad,
                                             layer_input_candidate_grad.size()[0])

        for i in range(flat_mod.size()[1]):
            if i > 0:
                # Undo the previous column's perturbation.
                flat_mod[:, i - 1] = flat_input[:, i - 1]

            loss1 = loss_fn(layer(input_mod), target)
            flat_mod[:, i] = flat_candidate[:, i] * 0.5 + flat_input[:, i] * 0.5
            loss2 = loss_fn(layer(input_mod), target)

            for j in range(flat_grad.size()[0]):
                diff = flat_mod[j, i] - flat_input[j, i]
                if torch.abs(diff) > self.small_num:
                    flat_grad[j, i] = flat_grad[j, i] * 0.5 + 0.5 * (loss2 - loss1) / diff
                else:
                    print("Diff<=very small number")

        self._store_flat_grad(flat_grad, layer_input_candidate_grad)

    @torch.no_grad()
    def gradient_calc(self,layer_input,layer_input_candidate,layer_input_candidate_grad,layer_out,layer_out_candidate_gradient,layer):
        """Blend a finite-difference input gradient into ``layer_input_candidate_grad``.

        For each flattened input column ``i`` the column is moved half way
        from ``layer_input`` towards ``layer_input_candidate``; the change
        of the layer output relative to ``layer_out`` is weighted by
        ``layer_out_candidate_gradient``, divided by the per-sample step and
        summed.  The stored gradient entry becomes the mean of its previous
        value and that estimate.  Entries whose step magnitude is not above
        ``small_num`` keep their value and a message is printed.

        Args:
            layer_input: input of ``layer`` used as the reference point.
            layer_input_candidate: alternative input to step towards.
            layer_input_candidate_grad: gradient tensor updated in place.
            layer_out: output of ``layer`` for ``layer_input``.
            layer_out_candidate_gradient: gradient w.r.t. the layer output.
            layer: callable applied to the perturbed input.
        """
        batch = layer_input.size()[0]
        input_mod, flat_mod = self._perturbation_buffer(layer_input)
        flat_input = self._flatten_per_sample(layer_input, batch)
        flat_candidate = self._flatten_per_sample(layer_input_candidate, batch)
        flat_grad = self._flatten_per_sample(layer_input_candidate_grad,
                                             layer_input_candidate_grad.size()[0])
        flat_out_grad = self._flatten_per_sample(layer_out_candidate_gradient,
                                                 layer_out_candidate_gradient.size()[0])
        flat_layer_out = self._flatten_per_sample(layer_out, layer_out.size()[0])

        for i in range(flat_mod.size()[1]):
            if i > 0:
                # Undo the previous column's perturbation.
                flat_mod[:, i - 1] = flat_input[:, i - 1]

            flat_mod[:, i] = flat_candidate[:, i] * 0.5 + flat_input[:, i] * 0.5

            out = layer(input_mod)
            # Flatten per sample so outputs of any rank line up with the
            # flattened output gradient.
            flat_out = self._flatten_per_sample(out, out.size()[0])

            for j in range(flat_grad.size()[0]):
                diff = flat_mod[j, i] - flat_input[j, i]
                if torch.abs(diff) > self.small_num:
                    estimate = torch.sum((flat_out[j] - flat_layer_out[j]) * flat_out_grad[j] / diff)
                    flat_grad[j, i] = flat_grad[j, i] * 0.5 + 0.5 * estimate
                else:
                    # The previous gradient value is kept.
                    print("Diff<=very small number")

        self._store_flat_grad(flat_grad, layer_input_candidate_grad)

    # ------------------------------------------------------------------
    # layer-by-layer backward passes
    # ------------------------------------------------------------------
    def backward_grad_correction(self,loss,model2,loss_fn,target):
        """Back-propagate ``loss`` layer by layer and mirror the result into ``model2``.

        Output gradients are stored in ``intermediate_outputs_grad``.  For
        every layer that owns trainable parameters the layer output is first
        recomputed from the stored input of that layer, then the gradient of
        every such parameter (weights and biases) is written to
        ``param.grad`` on this model and, detached, to the parameter of the
        same name on ``model2.layers[ind]``.

        Args:
            loss: scalar produced from the output of the last ``forward``.
            model2: model with a ``layers`` list matching this one.
            loss_fn: currently unused.
            target: currently unused.
        """
        last = len(self.intermediate_outputs) - 1
        self.intermediate_outputs_grad[last] = torch.autograd.grad(
            outputs=loss,
            inputs=self.intermediate_outputs[last],
            retain_graph=True)[0]

        for ind in range(last, -1, -1):
            layer = self.layers[ind]
            named_params = self._named_trainable(layer)
            if named_params:
                layer_input = self.intermediate_outputs[ind - 1] if ind > 0 else self.input_cached
                self.intermediate_outputs[ind] = layer(layer_input)
                param_grads = self._layer_param_grads(named_params,
                                                      self.intermediate_outputs[ind],
                                                      self.intermediate_outputs_grad[ind])
                for name, grad in param_grads.items():
                    getattr(layer, name).grad = grad
                    getattr(model2.layers[ind], name).grad = grad.detach()

            if ind > 0:
                self.intermediate_outputs_grad[ind - 1] = torch.autograd.grad(
                    outputs=self.intermediate_outputs[ind],
                    inputs=self.intermediate_outputs[ind - 1],
                    grad_outputs=self.intermediate_outputs_grad[ind],
                    retain_graph=True)[0]

                # Experimental blending with model2's input gradient; disabled.
                enhanced = False
                if enhanced and isinstance(layer, self.many_to_many_nonlinear):
                    self.intermediate_outputs_grad[ind - 1] = (
                        0.6 * self.intermediate_outputs_grad[ind - 1]
                        + model2.intermediate_outputs[ind - 1].grad * 0.4)

    def backward(self,loss):
        """Back-propagate ``loss`` layer by layer using the stored activations.

        Walks the layers from last to first.  The gradient of ``loss`` with
        respect to each stored layer output is written to that tensor's
        ``.grad``; the gradient of every trainable parameter of each layer
        (weights and biases) is written to ``param.grad``, overwriting any
        previous value.

        Args:
            loss: scalar produced from the output of the last ``forward``.
        """
        last = len(self.intermediate_outputs) - 1
        self.intermediate_outputs[last].grad = torch.autograd.grad(
            outputs=loss,
            inputs=self.intermediate_outputs[last],
            retain_graph=True)[0]

        for ind in range(last, -1, -1):
            layer = self.layers[ind]
            layer_output = self.intermediate_outputs[ind]
            output_grad = layer_output.grad

            named_params = self._named_trainable(layer)
            if named_params:
                param_grads = self._layer_param_grads(named_params, layer_output, output_grad)
                for name, grad in param_grads.items():
                    getattr(layer, name).grad = grad

            if ind > 0:
                self.intermediate_outputs[ind - 1].grad = torch.autograd.grad(
                    outputs=layer_output,
                    inputs=self.intermediate_outputs[ind - 1],
                    grad_outputs=output_grad,
                    retain_graph=True)[0]

    def forward(self, x):
        """Run ``x`` through all layers, caching the input and every layer output.

        Returns the output of the last layer.
        """
        self.input_cached = x
        for index, layer in enumerate(self.layers):
            x = layer(x)
            self.intermediate_outputs[index] = x
        return x


def train(args, model, device, train_loader, optimizer, epoch,model2=None,opt2=None,
          *, copy_test=False, opt_test=True, copy_grad_experiment=False):
    """Train ``model`` for one epoch using one of several experimental update rules.

    A second "shadow" model/optimizer pair (``model2``/``opt2``) is used by
    most of the experiments.  The keyword-only switches select the update
    rule; they are checked in this order:

    * ``copy_grad_experiment``: sync the shadow from ``model``, run the
      shadow forward pass, let ``model2.backward_grad_correction`` write
      gradients into ``model`` and step ``optimizer``.
    * ``opt_test`` (default): sync the shadow, take one shadow step using
      ``model2.backward``, then run ordinary ``backward()`` on both models
      and step ``optimizer`` with the mean of both models' ``weight``
      gradients (only ``weight`` gradients are averaged).
    * ``copy_test``: sync the shadow, take one shadow step using
      ``model2.backward`` and copy the shadow state back into ``model``.
    * otherwise: plain loop using ``model.backward`` in place of
      ``loss.backward()``.

    Args:
        args: namespace providing ``log_interval`` and ``dry_run``.
        model: model being trained; must expose ``layers`` and ``backward``.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute (used for progress logging).
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used for logging only.
        model2: optional shadow model.
        opt2: optional optimizer over ``model2``'s parameters.  If either
            ``model2`` or ``opt2`` is ``None`` a fresh pair is created.
        copy_test: see above.
        opt_test: see above.
        copy_grad_experiment: see above.
    """
    model.train()

    if model2 is None or opt2 is None:
        # The shadow must live on the same device as the batches it is fed.
        model2 = Net().to(device)
        model2.train()
        model2.load_state_dict(model.state_dict())
        opt2 = type(optimizer)(model2.parameters())
        opt2.load_state_dict(optimizer.state_dict())

    def sync_shadow():
        # Make the shadow pair an exact copy of the primary pair.
        model2.load_state_dict(model.state_dict())
        opt2.load_state_dict(optimizer.state_dict())

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        if copy_grad_experiment:
            sync_shadow()
            opt2.zero_grad()
            optimizer.zero_grad()

            output = model2(data)
            loss = F.nll_loss(output, target)

            # backward_grad_correction takes (loss, model2, loss_fn, target).
            model2.backward_grad_correction(loss, model, F.nll_loss, target)
            optimizer.step()
        elif opt_test:
            sync_shadow()
            opt2.zero_grad()
            optimizer.zero_grad()

            output = model2(data)
            loss = F.nll_loss(output, target)

            model2.backward(loss)
            opt2.step()  # todo error is here

            output1 = model(data)
            loss1 = F.nll_loss(output1, target)
            loss1.backward()

            # Shadow gradients are not zeroed here, so this accumulates on
            # top of what model2.backward() stored above.
            output2 = model2(data)
            loss2 = F.nll_loss(output2, target)
            loss2.backward()

            for ind in range(len(model.layers)):
                if hasattr(model.layers[ind], "weight"):
                    model.layers[ind].weight.grad[:] = (
                        model.layers[ind].weight.grad[:] * 0.5
                        + model2.layers[ind].weight.grad[:] * 0.5)

            optimizer.step()
        elif copy_test:
            sync_shadow()
            opt2.zero_grad()

            output = model2(data)
            loss = F.nll_loss(output, target)

            model2.backward(loss)
            opt2.step()
            model.load_state_dict(model2.state_dict())
            optimizer.load_state_dict(opt2.state_dict())
        else:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            model.backward(loss)
            optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader,train_loader):
    """Print average NLL loss and accuracy for the test set, then the training set.

    The model is switched to evaluation mode and both loaders are evaluated
    with gradient tracking disabled.  Each loader must yield
    ``(data, target)`` batches and expose a sized ``dataset`` attribute,
    which supplies the denominator for the averages.
    """
    model.eval()

    reports = (
        (test_loader, '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'),
        (train_loader, '\nTraining set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'),
    )

    for loader, template in reports:
        summed_loss = 0
        hits = 0
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Sum (not mean) per batch so the average below is per sample.
                summed_loss += F.nll_loss(output, target, reduction='sum').item()
                # Index of the largest log-probability is the predicted class.
                pred = output.argmax(dim=1, keepdim=True)
                hits += pred.eq(target.view_as(pred)).sum().item()

        sample_count = len(loader.dataset)
        summed_loss /= sample_count

        print(template.format(summed_loss, hits, sample_count,
                              100. * hits / sample_count))


def main():
    """Parse the command line, then train and evaluate ``Net`` on MNIST.

    A second "shadow" model and optimizer are created as copies of the
    primary ones and handed to ``train`` every epoch.  After each epoch the
    model is evaluated on both the test and the training set and the
    learning-rate scheduler is stepped.  With ``--save-model`` the final
    weights are written to ``mnist_cnn.pt``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # Guard the attribute lookup: not every torch build exposes the MPS backend.
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = not args.no_mps and mps_backend is not None and mps_backend.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # --batch-size and --lr are used as parsed; their defaults (64 and 1.0)
    # match the values that used to be hard-coded here.
    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # Shadow copy used by the experimental update rules in train(); it is
    # fed the same batches as `model`, so it must live on the same device.
    model2 = Net().to(device)
    model2.train()
    model2.load_state_dict(model.state_dict())
    opt2 = type(optimizer)(model2.parameters())
    opt2.load_state_dict(optimizer.state_dict())

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, model2, opt2)
        test(model, device, test_loader, train_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


# Run the training script only when executed directly, not when imported.
if __name__ == '__main__':
    main()
