import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.models as models
import torch.distributed as dist
from torch.multiprocessing import Process
import os
from tensorboardX import SummaryWriter
import argparse
import torch.nn as nn
import math
import time
import optimizer_UR as optimizer
import copy
import random
import cifar

        


def _validate_args(args):
    """Reject settings that would crash or hang the training loop in train()."""
    if args.N < 1:
        raise ValueError("--N must be at least 1, got %r" % (args.N,))
    if args.update_x < 1:
        # The participation draw only happens inside the local-round loop but
        # is read again after it, so zero rounds would leave it undefined.
        raise ValueError(
            "--update_x must be at least 1, got %r" % (args.update_x,))
    if args.join_p > 0 and (args.join_p >= 1 or args.N < 2):
        # Draws come from torch.rand (values in [0, 1)) and entry 0 is forced
        # to 0, so the redraw loop could never find an entry >= join_p.
        raise ValueError(
            "--join_p=%r can never be reached with --N=%r; use join_p < 1 "
            "and, for join_p > 0, N >= 2" % (args.join_p, args.N))


def _next_batch(loader_iter, trainloader_list, node):
    """Return the next batch for `node`, restarting its loader when exhausted."""
    try:
        return next(loader_iter[node])
    except StopIteration:
        loader_iter[node] = iter(trainloader_list[node])
        return next(loader_iter[node])


def _forward_loss(model, criterion, batch, device):
    """Clear the model's gradients, run one forward pass, and return the loss."""
    inputs, targets = batch
    model.zero_grad()
    targets = targets.to(device)
    inputs = inputs.to(device)
    return criterion(model(inputs), targets)


def _draw_active(num_nodes, join_p):
    """Draw one uniform value per node until at least one is >= join_p.

    Entry 0 is always set to 0 before the check.
    """
    while True:
        active = torch.rand(num_nodes)
        active[0] = 0
        if torch.sum(active >= join_p) != 0:
            return active


def _evaluate(model, testloader, criterion, device):
    """Return (mean per-sample loss, accuracy) of `model` over `testloader`.

    The per-batch loss is weighted by the batch size, which yields the
    per-sample mean as long as `criterion` averages over the batch (train()
    builds nn.CrossEntropyLoss() with its default settings).
    """
    # Evaluate a copy so the training model's train/eval mode is untouched.
    eval_model = copy.deepcopy(model)
    eval_model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, targets in testloader:
            targets = targets.to(device)
            inputs = inputs.to(device)
            logits = eval_model(inputs)
            batch_len = targets.size(0)
            loss_sum += criterion(logits, targets).item() * batch_len
            correct += torch.sum(torch.argmax(logits, dim=1) == targets).item()
            seen += batch_len
    if seen == 0:
        raise ValueError("test loader produced no samples")
    return loss_sum / seen, correct / seen


def train(args):
    """Train VGG19 on CIFAR-10 with the project's `optimizer.SGD` and log to TensorBoard.

    Reads from `args`: gpu (sequence of ids, the first one is used), max_iter,
    batch_size, log_dir, N, lr, momentum, weight_decay, xlr, lamblr, update_x,
    join_p and filename.

    Each outer step runs `update_x` local rounds over nodes 1..N-1, then one
    plain step on node 0, then the optimizer's x update. Every 200 steps the
    model is evaluated on the CIFAR-10 test split and the metrics are written
    to `log_dir`. The final state_dict is saved to `args.filename`.

    Raises:
        ValueError: if N < 1, update_x < 1, or join_p can never be reached
            (see _validate_args).
    """
    _validate_args(args)

    norm_mean = [0.507, 0.487, 0.441]
    norm_std = [0.267, 0.256, 0.276]
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    device = torch.device("cuda:%s" % (args.gpu[0]))

    print(args.max_iter)
    log_dir = "%s" % (args.log_dir)
    # makedirs also handles nested log directories and an existing directory.
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    try:
        data_root = 'data'
        num_nodes = args.N
        eval_interval = 200

        index_seq = np.arange(0, 50000)
        np.random.shuffle(index_seq)
        trainloader_list = []
        loader_iter = []
        for node in range(num_nodes):
            # NOTE(review): presumably `index` selects this node's share of
            # `random_list` -- confirm against cifar.CIFAR10.
            trainset = cifar.CIFAR10(root=data_root, train=True, download=True,
                                     transform=transform_train,
                                     random_list=index_seq, index=node)
            if node == 0:
                # The first dataset's tmp_list replaces the index sequence
                # handed to the remaining nodes.
                index_seq = trainset.tmp_list
            trainloader = torch.utils.data.DataLoader(
                trainset, batch_size=args.batch_size, shuffle=True,
                num_workers=4)
            trainloader_list.append(trainloader)
            loader_iter.append(iter(trainloader))

        criterion = nn.CrossEntropyLoss()
        model = torchvision.models.vgg19(num_classes=10)
        model = model.to(device)
        # None is forwarded as-is; what the optimizer does with it is defined
        # in optimizer_UR, not here.
        init_x = None
        optim = optimizer.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay, N=num_nodes, x_lr=args.xlr,
            lamb_lr=args.lamblr, init_x=init_x, device=device)

        testset = torchvision.datasets.CIFAR10(
            root=data_root, train=False, download=True,
            transform=transform_test)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=1000, shuffle=False, num_workers=8)

        for step in range(0, args.max_iter):
            optim.pre_step_x()
            for _ in range(args.update_x):
                # Local gradient step on every node except node 0.
                for node in range(1, num_nodes):
                    batch = _next_batch(loader_iter, trainloader_list, node)
                    loss = _forward_loss(model, criterion, batch, device)
                    loss.backward()
                    optim.step(node=node)

                active = _draw_active(num_nodes, args.join_p)
                optim.pre_step_w1()
                for node in range(1, num_nodes):
                    batch = _next_batch(loader_iter, trainloader_list, node)
                    loss = _forward_loss(model, criterion, batch, device)
                    # No backward() here: the loss itself is handed to the
                    # optimizer together with the raw draws.
                    optim.local_step_w1(loss, node=node, active=active)
                # A fresh mask per call, in case the optimizer modifies it.
                optim.update_mul(active >= args.join_p)
                optim.step_w(active >= args.join_p)

            # One plain step on node 0 per outer iteration.
            batch = _next_batch(loader_iter, trainloader_list, 0)
            loss = _forward_loss(model, criterion, batch, device)
            loss.backward()
            optim.step(node=0)

            for node in range(1, num_nodes):
                optim.gen_update_x(node=node)
            # Uses the draw from the last local round of this outer step.
            optim.step_x(active >= args.join_p)

            if step % eval_interval == 0:
                test_loss, acc = _evaluate(model, testloader, criterion,
                                           device)
                writer.add_scalar('acc', acc, step)
                writer.add_scalar('test_loss', test_loss, step)
                for rs in range(1, num_nodes):
                    writer.add_scalar('x_%d' % (rs), optim.x[rs], step)
                print("step:", step, " test loss:", test_loss, " acc:", acc)
                # `loss` is the node-0 training loss of this step.
                print("step:", step, "loss:", loss)
                print(optim.x)
                writer.add_scalar('loss', loss.item(), step)

        torch.save(model.state_dict(), args.filename)
    finally:
        # Flush and close the event file even when training is interrupted.
        writer.close()
def parse(argv=None):
    """Build the command-line parser for this script and parse `argv`.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the arguments from sys.argv, as before.

    Returns:
        argparse.Namespace holding one attribute per option below. `gpu` is
        returned as the raw string; the caller splits it on commas.
    """
    parser = argparse.ArgumentParser()

    # Output and device.
    parser.add_argument("--filename", type=str, default="model.pth",
                        help="path the final model state_dict is saved to")
    parser.add_argument("--log_dir", type=str, default="logs/",
                        help="directory for the TensorBoard event files")
    parser.add_argument("--gpu", type=str, default="0",
                        help="comma-separated GPU ids; training uses the first")

    # Optimizer settings forwarded to optimizer.SGD, and loop control.
    parser.add_argument("--lr", type=float, default=0.1,
                        help="learning rate passed to the optimizer")
    parser.add_argument("--momentum", type=float, default=0.9,
                        help="momentum passed to the optimizer")
    parser.add_argument("--xlr", type=float, default=1e-3,
                        help="passed to the optimizer as x_lr")
    parser.add_argument("--lamblr", type=float, default=1e-3,
                        help="passed to the optimizer as lamb_lr")
    parser.add_argument("--Gamma", type=float, default=1,
                        help="not read by train() in this file")
    parser.add_argument("--update_x", type=int, default=1,
                        help="number of local rounds per outer training step")
    parser.add_argument("--join_p", type=float, default=0,
                        help="a node counts as active when its uniform "
                             "random draw is >= this value")
    parser.add_argument("--N", type=int, default=11,
                        help="number of training data loaders / nodes, "
                             "node 0 included")

    # General training settings.
    parser.add_argument("--epoch", type=int, default=1000,
                        help="not read by train() in this file")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="batch size of every training loader")
    parser.add_argument("--decay_time", type=int, default=12000,
                        help="not read by train() in this file")
    parser.add_argument("--weight_decay", type=float, default=0,
                        help="weight decay passed to the optimizer")
    parser.add_argument("--decay_lr", type=float, default=0.2,
                        help="not read by train() in this file")
    parser.add_argument("--max_iter", type=int, default=20000,
                        help="number of outer training steps")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the command line, echo the configuration,
    # then start training.
    args = parse()
    print(args)
    # train() selects its device from args.gpu[0], so turn "0,1" into a list.
    args.gpu = args.gpu.split(',')
    train(args)
