import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.models as models
import torch.distributed as dist
from torch.multiprocessing import Process
import os
from tensorboardX import SummaryWriter
import argparse
import torch.nn as nn
import math
import time
import optimizer_BSA as optimizer
import copy
import random
import mnist


class Lenet(torch.nn.Module):
    """LeNet-style convolutional classifier with 10 output logits.

    Two conv/ReLU/max-pool stages are followed by three fully connected
    layers. ``fc1`` expects 256 flattened features, which is what the two
    5x5 convolutions and 2x2 poolings produce for a 1x28x28 input
    (16 channels * 4 * 4).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=256, out_features=120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        # Registered but not applied in forward(): the network returns raw logits.
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Return the (batch, 10) logits for the image batch ``x``."""
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))
        # Collapse channel and spatial dimensions, keeping the batch dimension.
        flat = features.view(features.size(0), -1)
        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        return self.fc3(hidden)
        

def change_lr(optim, decay):
    """Scale the learning rate of every parameter group of ``optim``.

    The ``'lr'`` entry of each dict in ``optim.param_groups`` is multiplied
    by ``decay`` in place; nothing is returned.
    """
    groups = optim.param_groups
    for group in groups:
        group['lr'] *= decay

def _next_batch(loader_iter, trainloader_list, node):
    """Return the next batch for ``node``, restarting its loader once exhausted."""
    try:
        return next(loader_iter[node])
    except StopIteration:
        loader_iter[node] = iter(trainloader_list[node])
        return next(loader_iter[node])


def _batch_loss(model, criterion, batch, device):
    """Clear the model gradients, run ``batch`` forward and return the loss."""
    x, y = batch
    model.zero_grad()
    y = y.to(device)
    x = x.to(device)
    return criterion(model(x), y)


def _draw_active(n, join_p):
    """Draw ``n`` uniform samples (entry 0 forced to 0) with at least one >= ``join_p``.

    The draw is repeated until some entry reaches ``join_p``; callers threshold
    the returned tensor themselves to obtain the participation mask.
    """
    while True:
        draw = torch.rand(n)
        draw[0] = 0
        if torch.sum(draw >= join_p) != 0:
            return draw


def _evaluate(model, testloader, criterion, device):
    """Return ``(mean_loss, accuracy)`` of a snapshot of ``model`` on ``testloader``.

    The loss is averaged per sample (each batch's mean loss is weighted by its
    size), so the result does not depend on the test batch size.
    """
    snapshot = copy.deepcopy(model)
    snapshot.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for tx, ty in testloader:
            ty = ty.to(device)
            tx = tx.to(device)
            logits = snapshot(tx)
            batch_size = ty.size(0)
            loss_sum += criterion(logits, ty).item() * batch_size
            correct += torch.sum(torch.argmax(logits, dim=1) == ty).item()
            seen += batch_size
    return loss_sum / seen, correct / seen


def train(args):
    """Train a ``Lenet`` on MNIST splits with the ``optimizer_BSA`` SGD variant.

    One data loader is built per node (``args.N`` in total). Every outer step
    runs ``args.update_x`` rounds over nodes ``1..N-1`` followed by
    ``optim.step_w``, one update on node 0 plus ``optim.pre_step_x``, a second
    set of ``args.update_x`` rounds using ``optim.local_step_w1`` and
    ``optim.update_mul``, and finally ``optim.gen_update_x`` / ``optim.step_x``.
    Every 10 steps the model is evaluated on the MNIST test set and the
    metrics are written to TensorBoard. The final ``state_dict`` is saved to
    ``args.filename``.

    Args:
        args: namespace providing ``gpu`` (sequence; the first entry selects
            the CUDA device), ``log_dir``, ``filename``, ``N``, ``batch_size``,
            ``lr``, ``momentum``, ``weight_decay``, ``xlr``, ``lamblr``,
            ``update_x``, ``join_p`` and ``max_iter``.

    Raises:
        ValueError: if iterations are requested with ``update_x < 1`` (the
            final ``step_x`` would have no mask to use) or with a ``join_p``
            that no node can ever reach (the mask sampling would never end).
    """
    N = args.N
    if args.max_iter > 0:
        if args.update_x < 1:
            raise ValueError("update_x must be >= 1, got %s" % (args.update_x,))
        # torch.rand samples from [0, 1) and entry 0 is forced to 0, so these
        # settings could never produce an active node.
        if args.join_p >= 1 or (N < 2 and args.join_p > 0):
            raise ValueError(
                "join_p=%s can never be reached with N=%s" % (args.join_p, N))

    to_tensor = transforms.Compose([transforms.ToTensor()])
    device = torch.device("cuda:%s" % (args.gpu[0]))

    print(args.max_iter)
    # makedirs also handles nested log directories and an existing directory.
    os.makedirs(args.log_dir, exist_ok=True)
    writer = SummaryWriter(args.log_dir)

    try:
        path = 'data'

        # One training split, loader and live iterator per node.
        trainloader_list = []
        loader_iter = []
        index_seq = np.arange(0, 50000)
        np.random.shuffle(index_seq)
        for i in range(N):
            trainset = mnist.MNIST_train(root=path, train=True, download=True,
                                         transform=to_tensor,
                                         random_list=index_seq, index=i)
            if i == 0:
                # Later splits reuse the index list produced by the first one.
                index_seq = trainset.tmp_list
            trainloader = torch.utils.data.DataLoader(
                trainset, batch_size=args.batch_size, shuffle=True,
                num_workers=4)
            trainloader_list.append(trainloader)
            loader_iter.append(iter(trainloader))

        criterion = nn.CrossEntropyLoss()
        model = Lenet().to(device)
        # No explicit initial x is supplied to the optimizer.
        init_x = None
        optim = optimizer.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay, N=N, x_lr=args.xlr,
            lamb_lr=args.lamblr, init_x=init_x, device=device)

        testset = torchvision.datasets.MNIST(root=path, train=False,
                                             download=True, transform=to_tensor)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=1000, shuffle=False, num_workers=8)

        for step in range(args.max_iter):
            # Phase 1: gradient steps on nodes 1..N-1, then step_w.
            for local_step in range(args.update_x):
                for i in range(1, N):
                    batch = _next_batch(loader_iter, trainloader_list, i)
                    loss = _batch_loss(model, criterion, batch, device)
                    loss.backward()
                    optim.step(node=i)

                active = _draw_active(N, args.join_p)
                optim.step_w(active >= args.join_p)

            # Phase 2: gradient step on node 0.
            batch = _next_batch(loader_iter, trainloader_list, 0)
            loss = _batch_loss(model, criterion, batch, device)
            loss.backward()
            optim.step(node=0)
            optim.pre_step_x()

            # Phase 3: the optimizer receives the loss itself (no backward here).
            for local_step in range(args.update_x):
                for i in range(1, N):
                    batch = _next_batch(loader_iter, trainloader_list, i)
                    loss = _batch_loss(model, criterion, batch, device)
                    optim.local_step_w1(loss, node=i)

                active1 = _draw_active(N, args.join_p)
                optim.update_mul(active1 >= args.join_p)

            for i in range(1, N):
                optim.gen_update_x(node=i)
            # NOTE(review): this uses the last phase-1 draw (`active`), not
            # `active1` -- kept as in the original; confirm this is intended.
            optim.step_x(active >= args.join_p)

            if step % 10 == 0:
                test_loss, acc = _evaluate(model, testloader, criterion, device)
                writer.add_scalar('acc', acc, step)
                writer.add_scalar('test_loss', test_loss, step)
                print("step:", step, " test loss:", test_loss, " acc:", acc)
                print("step:", step, "loss:", loss)
                print(optim.x)
                writer.add_scalar('loss', loss.item(), step)

        torch.save(model.state_dict(), args.filename)
    finally:
        # Flush and close the event file even when training aborts.
        writer.close()
def parse():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with one attribute per option listed below.
    """
    # (flag, type, default) -- registered in this order.
    option_table = (
        # Output locations and device selection.
        ("--filename", str, "model.pth"),
        ("--log_dir", str, "logs/"),
        ("--gpu", str, "0"),
        # Optimizer hyper-parameters and node settings.
        ("--lr", float, 0.1),
        ("--momentum", float, 0.9),
        ("--xlr", float, 1e-1),
        ("--lamblr", float, 1e-1),
        ("--Gamma", float, 1),
        ("--update_x", int, 1),
        ("--join_p", float, 0),
        ("--N", int, 11),
        # Schedule and batching.
        ("--epoch", int, 1000),
        ("--batch_size", int, 64),
        ("--decay_time", int, 12000),
        ("--weight_decay", float, 0),
        ("--decay_lr", float, 0.2),
        ("--max_iter", int, 2000),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: read the CLI options, echo them, then train.
    processes = []
    args = parse()
    print(args)
    # --gpu is a comma-separated string; train() reads the first entry.
    args.gpu = args.gpu.split(",")
    train(args)
