import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.models as models
import torch.distributed as dist
from torch.multiprocessing import Process
import os
from tensorboardX import SummaryWriter
import argparse
import torch.nn as nn
import math
import time
import optimizer_DM as optimizer
import copy
import random
import cifar



def change_lr(optim, decay):
    """Scale the learning rate of every parameter group of ``optim`` in place.

    Args:
        optim: object exposing ``param_groups``, a sequence of dicts that
            each hold an ``'lr'`` entry (e.g. a ``torch.optim.Optimizer``).
        decay: multiplicative factor applied to each group's ``'lr'``.
    """
    groups = optim.param_groups
    for group in groups:
        group['lr'] *= decay

def _next_batch(loader_iter, trainloader_list, node):
    """Return the next batch for ``node``, restarting its loader once exhausted.

    When ``loader_iter[node]`` raises ``StopIteration`` it is replaced in
    place with a fresh iterator over ``trainloader_list[node]``.
    """
    try:
        return next(loader_iter[node])
    except StopIteration:
        loader_iter[node] = iter(trainloader_list[node])
        return next(loader_iter[node])


def _sample_active_mask(n_nodes, join_p):
    """Draw a boolean participation mask over ``n_nodes`` nodes.

    Each node gets a ``torch.rand`` score in [0, 1) and is active when its
    score is ``>= join_p``. Node 0's score is pinned to 0, so node 0 is
    active only when ``join_p <= 0``. Draws repeat until at least one node
    is active.

    Raises:
        ValueError: if no draw could ever activate a node. The retry loop
            would otherwise spin forever.
    """
    if join_p > 0 and (n_nodes < 2 or join_p >= 1):
        raise ValueError(
            "join_p=%s with N=%s can never activate a node" % (join_p, n_nodes))
    while True:
        scores = torch.rand(n_nodes)
        scores[0] = 0
        mask = scores >= join_p
        if torch.sum(mask) > 0:
            return mask


def _forward_loss(model, criterion, x, y):
    """Clear ``model``'s gradients and return ``criterion(model(x), y)``."""
    model.zero_grad()
    return criterion(model(x), y)


def _paired_step(optim, model, criterion, x, y, node):
    """Backprop and step ``node`` after ``switch_w``, then again after ``switch_w0``.

    ``optim.step(node=...)`` follows the first backward pass and
    ``optim.step0(node=...)`` follows the second. What the two switches
    select is defined in optimizer_DM (presumably two weight copies).
    Returns the loss from the second pass.
    """
    optim.switch_w()
    loss = _forward_loss(model, criterion, x, y)
    loss.backward()
    optim.step(node=node)

    optim.switch_w0()
    loss = _forward_loss(model, criterion, x, y)
    loss.backward()
    optim.step0(node=node)
    return loss


def _paired_local_step(optim, model, criterion, x, y, node):
    """Hand ``node``'s losses to ``local_step_w1`` then ``local_step_w10``.

    Unlike ``_paired_step``, no ``backward`` is called here. The loss
    tensor itself is passed to the optimizer. Returns the second loss.
    """
    optim.switch_w()
    loss = _forward_loss(model, criterion, x, y)
    optim.local_step_w1(loss, node=node)

    optim.switch_w0()
    loss = _forward_loss(model, criterion, x, y)
    optim.local_step_w10(loss, node=node)
    return loss


def _evaluate(model, testloader, criterion, device):
    """Return ``(accuracy, mean per-batch loss)`` of ``model`` over ``testloader``.

    The model runs in eval mode under ``torch.no_grad()``, so no autograd
    graph is built and no copy of the model is needed. Its previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    try:
        with torch.no_grad():
            for tx, ty in testloader:
                tx = tx.to(device)
                ty = ty.to(device)
                logits = model(tx)
                loss_sum += criterion(logits, ty).item()
                correct += torch.sum(torch.argmax(logits, dim=1) == ty).item()
                seen += ty.size(0)
    finally:
        model.train(was_training)
    return correct / seen, loss_sum / len(testloader)


def train(args):
    """Train a VGG-19 (10 classes) on CIFAR-10 with ``args.N`` loaders and optimizer_DM.

    Each outer step does the following:
      1. nodes ``1..N-1`` each run a paired ``step``/``step0`` on one batch;
      2. a participation mask is sampled, and node 0 runs a paired step;
      3. ``optim.pre_step_x()`` is called, followed by ``args.update_x``
         rounds of paired local steps on nodes ``1..N-1``. Each round ends
         with ``optim.update_mul`` on a freshly sampled mask;
      4. ``optim.gen_update_x`` runs for nodes ``1..N-1``, then
         ``optim.step_x`` / ``optim.step_w`` run with the mask from step 2.

    Every 200 steps the model is evaluated on the CIFAR-10 test split.
    Accuracy, test loss, the latest training loss and ``optim.x[1:]`` are
    logged to TensorBoard in ``args.log_dir``. Final weights are saved to
    ``args.filename``.

    Args:
        args: namespace from ``parse()`` with ``args.gpu`` already split
            into a list; ``args.gpu[0]`` selects the CUDA device.
    """
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                     std=[0.267, 0.256, 0.276])
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        normalize,
    ])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    device = torch.device("cuda:%s" % (args.gpu[0]))

    print(args.max_iter)
    # makedirs (not mkdir) so nested log directories such as "runs/exp1" work.
    os.makedirs(args.log_dir, exist_ok=True)

    path = 'data'

    # Shuffle the 50,000 training indices once and hand them to
    # cifar.CIFAR10 together with the node index. Presumably each node
    # selects its own shard from them; see cifar.CIFAR10 to confirm.
    index_seq = np.arange(0, 50000)
    np.random.shuffle(index_seq)
    N = args.N
    trainloader_list = []
    loader_iter = []
    for i in range(N):
        trainset = cifar.CIFAR10(root=path, train=True, download=True,
                                 transform=transform_train,
                                 random_list=index_seq, index=i)
        if i == 0:
            # Later nodes receive the index list as returned by node 0's dataset.
            index_seq = trainset.tmp_list
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
        trainloader_list.append(trainloader)
        loader_iter.append(iter(trainloader))

    criterion = nn.CrossEntropyLoss()
    model = torchvision.models.vgg19(num_classes=10).to(device)
    optim = optimizer.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, N=N, x_lr=args.xlr,
        lamb_lr=args.lamblr, init_x=None, device=device)

    testset = torchvision.datasets.CIFAR10(root=path, train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=4)

    # NOTE(review): args.decay_time / args.decay_lr and change_lr() are never
    # used, so the learning rate is never decayed. Confirm this is intended.
    writer = SummaryWriter(args.log_dir)
    try:
        for step in range(args.max_iter):
            for i in range(1, N):
                x, y = _next_batch(loader_iter, trainloader_list, i)
                _paired_step(optim, model, criterion,
                             x.to(device), y.to(device), i)

            active = _sample_active_mask(N, args.join_p)

            x, y = _next_batch(loader_iter, trainloader_list, 0)
            loss = _paired_step(optim, model, criterion,
                                x.to(device), y.to(device), 0)

            optim.pre_step_x()

            for _ in range(args.update_x):
                for i in range(1, N):
                    x, y = _next_batch(loader_iter, trainloader_list, i)
                    loss = _paired_local_step(optim, model, criterion,
                                              x.to(device), y.to(device), i)
                optim.update_mul(_sample_active_mask(N, args.join_p))

            for i in range(1, N):
                optim.gen_update_x(node=i)

            optim.step_x(active)
            optim.step_w(active)

            if step % 200 == 0:
                acc, test_loss = _evaluate(model, testloader, criterion, device)
                writer.add_scalar('acc', acc, step)
                writer.add_scalar('test_loss', test_loss, step)
                print("step:", step, " test loss:", test_loss, " acc:", acc)
                print("step:", step, "loss:", loss)
                print(optim.x)
                for rs in range(1, N):
                    writer.add_scalar('x_%d' % (rs), optim.x[rs], step)
                writer.add_scalar('loss', loss.item(), step)

        torch.save(model.state_dict(), args.filename)
    finally:
        # Flush TensorBoard events even if training raises.
        writer.close()
def parse():
    """Build the command-line interface and return the parsed namespace.

    ``--gpu`` is returned as the raw comma-separated string. The
    ``__main__`` block splits it into a list before ``train`` reads
    ``args.gpu[0]``. ``--Gamma``, ``--epoch``, ``--decay_time`` and
    ``--decay_lr`` are parsed but not read by ``train`` in this file.
    """
    options = (
        # (flag, type, default). Declared in the original order so that
        # ``--help`` output is unchanged.
        ("--filename", str, "model.pth"),
        ("--log_dir", str, "logs/"),
        ("--gpu", str, "0"),
        ("--lr", float, 0.1),
        ("--momentum", float, 0.9),
        ("--xlr", float, 1e-2),
        ("--lamblr", float, 1e-2),
        ("--Gamma", float, 1),
        ("--update_x", int, 1),
        ("--join_p", float, 0),
        ("--N", int, 11),
        ("--epoch", int, 1000),
        ("--batch_size", int, 64),
        ("--decay_time", int, 12000),
        ("--weight_decay", float, 0),
        ("--decay_lr", float, 0.2),
        ("--max_iter", int, 20000),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, default in options:
        parser.add_argument(flag, type=kind, default=default)
    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: parse options, echo them, then train.
    args = parse()
    print(args)
    # --gpu arrives as a comma-separated string; train() uses args.gpu[0].
    args.gpu = args.gpu.split(',')
    train(args)
