import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.models as models
import torch.distributed as dist
from torch.multiprocessing import Process
import os
from tensorboardX import SummaryWriter
import argparse
import torch.nn as nn
import math
import time
import optimizer_DM as optimizer
import copy
import random
import mnist


class Lenet(torch.nn.Module):
    """LeNet-style CNN for single-channel images, producing 10 raw logits.

    The flattened feature size fed to ``fc1`` is 256 = 16 channels * 4 * 4,
    which is what the conv/pool stack yields for a 28x28 input
    (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4).
    ``relu5`` is constructed but never applied in ``forward``; it has no
    parameters, so it does not affect ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two (conv -> ReLU -> 2x2 max-pool) stages.
        # Creation order is significant: it fixes the order in which the
        # parameterised layers draw their random initial weights.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Classifier head: 256 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, 10)``."""
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))
        flat = features.view(features.size(0), -1)
        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        # No activation on the output: the caller's CrossEntropyLoss expects logits.
        return self.fc3(hidden)
        

def change_lr(optim, decay):
    """Multiply the learning rate of every parameter group by ``decay``, in place.

    Args:
        optim: any object exposing ``param_groups``, an iterable of dicts
            that each hold an ``'lr'`` entry (e.g. a ``torch.optim.Optimizer``).
        decay: multiplicative factor applied to each group's ``'lr'``.
    """
    groups = optim.param_groups
    for group in groups:
        group['lr'] *= decay

def _build_train_loaders(args, transform_train, path):
    """Create one shuffled training DataLoader per node, plus a live iterator over each.

    A single permutation of ``0..49999`` is drawn with NumPy and handed to
    ``mnist.MNIST_train`` together with the node index. After node 0's
    dataset is built, later nodes receive ``trainset.tmp_list`` in place of
    the original permutation. NOTE(review): the exact partitioning semantics
    live in the project-local ``mnist`` module — confirm there.

    Returns:
        (trainloader_list, loader_iter), both lists of length ``args.N``.
    """
    trainloader_list = []
    loader_iter = []
    index_seq = np.arange(0, 50000)
    np.random.shuffle(index_seq)
    for i in range(args.N):
        trainset = mnist.MNIST_train(root=path, train=True, download=True,
                                     transform=transform_train,
                                     random_list=index_seq, index=i)
        if i == 0:
            index_seq = trainset.tmp_list
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
        trainloader_list.append(trainloader)
        loader_iter.append(iter(trainloader))
    return trainloader_list, loader_iter


def _next_batch(loader_iter, trainloader_list, node):
    """Return the next batch for ``node``, restarting its loader when it is exhausted.

    ``loader_iter[node]`` is replaced in place with a fresh iterator over
    ``trainloader_list[node]`` on StopIteration, so every node cycles through
    its data indefinitely.
    """
    try:
        return next(loader_iter[node])
    except StopIteration:
        loader_iter[node] = iter(trainloader_list[node])
        return next(loader_iter[node])


def _sample_scores(n_nodes, join_p):
    """Draw per-node uniform scores with node 0 forced to 0.

    Scores are redrawn until at least one is ``>= join_p``. Callers compare
    the result against ``join_p`` to obtain the participation mask.
    """
    scores = torch.rand(n_nodes)
    scores[0] = 0
    while torch.sum(scores >= join_p) == 0:
        scores = torch.rand(n_nodes)
        scores[0] = 0
    return scores


def _grad_pair_step(model, optim, criterion, x, y, node):
    """Apply ``optim.step`` after ``switch_w`` and ``optim.step0`` after ``switch_w0`` on one batch.

    Each half zeroes the gradients, runs a forward/backward pass on the same
    ``(x, y)`` and then calls the optimizer step for ``node``.
    Returns the loss tensor of the second forward pass.
    """
    optim.switch_w()
    model.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optim.step(node=node)

    optim.switch_w0()
    model.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optim.step0(node=node)
    return loss


def _local_pair_step(model, optim, criterion, x, y, node):
    """Hand un-backpropagated losses to ``local_step_w1`` / ``local_step_w10`` for ``node``.

    Same switching pattern as ``_grad_pair_step``, but the optimizer receives
    the loss tensor itself (no ``backward`` is called here).
    Returns the loss tensor of the second forward pass.
    """
    optim.switch_w()
    model.zero_grad()
    loss = criterion(model(x), y)
    optim.local_step_w1(loss, node=node)

    optim.switch_w0()
    model.zero_grad()
    loss = criterion(model(x), y)
    optim.local_step_w10(loss, node=node)
    return loss


def _evaluate(model, testloader, criterion, device):
    """Return ``(accuracy, mean per-batch loss)`` of ``model`` over ``testloader``.

    A deep copy is evaluated, so the training model's mode is left untouched.
    Gradients are disabled, so no autograd graph is built for test batches.
    """
    test_model = copy.deepcopy(model)
    test_model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for tx, ty in testloader:
            ty = ty.to(device)
            tx = tx.to(device)
            logits = test_model(tx)
            total_loss += criterion(logits, ty).item()
            pred = torch.argmax(logits, dim=1)
            correct += torch.sum(pred == ty).item()
            total += pred.size(0)
            num_batches += 1
    # Average over the actual number of test batches (the previous hard-coded
    # divisor of 100 did not match batch_size=1000 on the test loader).
    return correct / max(total, 1), total_loss / max(num_batches, 1)


def train(args):
    """Train ``Lenet`` on MNIST split across ``args.N`` nodes with ``optimizer_DM.SGD``.

    Reads from ``args``: ``gpu`` (sequence; first entry is the CUDA index),
    ``log_dir``, ``N``, ``batch_size``, ``lr``, ``momentum``,
    ``weight_decay``, ``xlr``, ``lamblr``, ``join_p``, ``update_x``,
    ``max_iter`` and ``filename``.

    Each of ``args.max_iter`` steps does the following:

    1. A paired gradient step on nodes 1..N-1.
    2. A draw of the participation scores.
    3. A paired gradient step on node 0, then ``pre_step_x``.
    4. ``args.update_x`` rounds of paired local steps on nodes 1..N-1,
       each followed by ``update_mul`` with a fresh participation mask.
    5. ``gen_update_x`` on nodes 1..N-1, then ``step_x`` / ``step_w`` with
       the mask from step 2.

    Every 10 steps (starting at step 0) the model is evaluated on the MNIST
    test set and accuracy / test loss / last training loss are logged to
    TensorBoard. The final ``state_dict`` is saved to ``args.filename``.
    """
    transform_train = transforms.Compose([transforms.ToTensor()])
    transform = transforms.Compose([transforms.ToTensor()])

    device = torch.device("cuda:%s" % (args.gpu[0]))

    print(args.max_iter)
    # makedirs also creates missing parent directories and tolerates an
    # existing one (os.mkdir did neither).
    os.makedirs("%s" % (args.log_dir), exist_ok=True)
    writer = SummaryWriter("%s" % (args.log_dir))
    try:
        path = 'data'
        N = args.N
        trainloader_list, loader_iter = _build_train_loaders(args, transform_train, path)

        criterion = nn.CrossEntropyLoss()
        model = Lenet().to(device)
        init_x = None  # forwarded unchanged to optimizer_DM.SGD
        optim = optimizer.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay, N=N, x_lr=args.xlr,
                              lamb_lr=args.lamblr, init_x=init_x, device=device)

        testset = torchvision.datasets.MNIST(root=path, train=False, download=True,
                                             transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=1000,
                                                 shuffle=False, num_workers=8)

        loss = None
        for step in range(args.max_iter):
            for i in range(1, N):
                x, y = _next_batch(loader_iter, trainloader_list, i)
                loss = _grad_pair_step(model, optim, criterion,
                                       x.to(device), y.to(device), i)

            # Drawn here (before node 0's batch) to keep the original RNG order.
            active = _sample_scores(args.N, args.join_p)

            x, y = _next_batch(loader_iter, trainloader_list, 0)
            loss = _grad_pair_step(model, optim, criterion,
                                   x.to(device), y.to(device), 0)

            optim.pre_step_x()

            for _ in range(args.update_x):
                for i in range(1, N):
                    x, y = _next_batch(loader_iter, trainloader_list, i)
                    loss = _local_pair_step(model, optim, criterion,
                                            x.to(device), y.to(device), i)
                active1 = _sample_scores(args.N, args.join_p)
                optim.update_mul(active1 >= args.join_p)
            for i in range(1, N):
                optim.gen_update_x(node=i)

            optim.step_x(active >= args.join_p)
            optim.step_w(active >= args.join_p)

            if step % 10 == 0:
                acc, test_loss = _evaluate(model, testloader, criterion, device)
                writer.add_scalar('acc', acc, step)
                writer.add_scalar('test_loss', test_loss, step)
                print("step:", step, " test loss:", test_loss, " acc:", acc)
                print("step:", step, "loss:", loss)
                print(optim.x)
                writer.add_scalar('loss', loss.item(), step)

        torch.save(model.state_dict(), args.filename)
    finally:
        # Flush and close the event file even if training raises.
        writer.close()
def parse():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below. ``--gpu`` is
        returned as the raw comma-separated string.
    """
    parser = argparse.ArgumentParser()

    # (flag, type, default), registered in the same order as before.
    options = (
        # Output / device.
        ("--filename", str, "model.pth"),
        ("--log_dir", str, "logs/"),
        ("--gpu", str, "0"),
        # Optimizer hyper-parameters.
        ("--lr", float, 0.1),
        ("--momentum", float, 0.9),
        ("--xlr", float, 1e-2),
        ("--lamblr", float, 1e-1),
        ("--Gamma", float, 1),        # not read by train() in this file
        ("--update_x", int, 1),
        ("--join_p", float, 0),
        ("--N", int, 11),
        # Schedule / data.
        ("--epoch", int, 1000),       # not read by train() in this file
        ("--batch_size", int, 64),
        ("--decay_time", int, 12000), # not read by train() in this file
        ("--weight_decay", float, 0),
        ("--decay_lr", float, 0.2),   # not read by train() in this file
        ("--max_iter", int, 2000),
    )
    for flag, kind, default in options:
        parser.add_argument(flag, type=kind, default=default)
    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: parse CLI options, normalise the GPU list, train.
    args = parse()
    print(args)
    # --gpu is a comma-separated string such as "0,1"; train() uses only the
    # first entry. Strip whitespace so "0, 1" yields clean device ids.
    args.gpu = [gpu_id.strip() for gpu_id in args.gpu.split(',')]
    train(args)
