import os
import shutil
import time
import argparse

import numpy
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.distributed as dist
import datetime
import random
from collections import Counter
import json
import math

from models import *

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--epochs', default=160, type=int,
                    help='number of training epochs')
parser.add_argument('--lr', default=0.1, type=float,
                    help='base learning rate')
parser.add_argument('--batch_size', default=256, type=int,metavar='bs', help='batchsize')
parser.add_argument('--lrdecay', default='stagewise', type=str,
                    help="learning-rate schedule: 'stagewise' or 'cosine' "
                         "(any other value keeps the rate constant)")
parser.add_argument('--m', default=0.0, type=float,
                    help='momentum coefficient applied by the coordinator')
parser.add_argument('--model', default="resnet20", type=str,
                    help="model architecture (only 'resnet20' is supported)")
parser.add_argument('--warmup', default=0, type=int,
                    help='number of linear learning-rate warmup epochs')

parser.add_argument('--ttag', default=1, type=int,
                    help='run tag appended to the output path')

# dataset selection
parser.add_argument('--data', default="cifar10", type=str,
                    metavar='DATA', help='name of dataset')

# Weight-decay coefficient: workers add wd * w to every gradient they send.
wd = 0.0001
# Epoch milestones; not referenced elsewhere in this script (the stagewise
# schedule in adjust_learning_rate uses fractions of --epochs instead).
adjust = [80, 120]

def coordinate(rank, world_size):
    """Run the coordinator (parameter server) for asynchronous momentum SGD.

    The coordinator owns the authoritative weight vector.

    1. It broadcasts the initial model.
    2. It serves ``num_ite_epoch * args.epochs`` requests. Each request is a
       (time stamp, gradient) pair from some worker, and the reply is the new
       time stamp plus the updated weights.
    3. At every epoch boundary it also sends a weight snapshot (tag 555) to
       the requesting worker, which saves it as a checkpoint.
    4. After training it answers one final request per worker with a
       termination stamp.
    5. It evaluates every saved checkpoint on the training and validation
       sets and writes one result line per epoch to ``<output_path>.txt``.

    Args:
        rank: this process's rank (not used inside this function).
        world_size: number of worker processes. The coordinator is the
            broadcast source and is therefore expected to be rank
            ``world_size``.
    """
    args = parser.parse_args()

    output_path = ('Loss-OrMo_reshuffle_wd' + str(wd) + "_" + str(args.model) + "_" + str(args.data)
                   + "_m_" + str(args.m) + "_epochs_" + str(args.epochs) + "_" + str(world_size)
                   + "_bs_" + str(args.batch_size) + "_" + str(args.lrdecay) + '_warmuplr_' + str(args.warmup)
                   + '_lr_' + str(args.lr) + '_' + str(args.ttag))

    # Created before the broadcast so the directory exists by the time any
    # worker (which blocks on the broadcast) tries to save a checkpoint.
    os.mkdir(output_path)

    model = modelload(args)

    model_flat = flatten_all(model)
    w_flat = flatten(model)           # authoritative weights
    g_flat = torch.zeros_like(w_flat)  # receive buffer for gradients
    u_flat = torch.zeros_like(w_flat)  # momentum buffer
    w_store_flat = torch.zeros_like(w_flat)

    dist.broadcast(model_flat, world_size)
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    # data loading code (used only for the post-training evaluation)
    trainset, valset = dataload(args.data)
    num_ite_epoch = math.ceil(len(trainset) / args.batch_size / world_size) * world_size
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size * 10, pin_memory=True, shuffle=False, num_workers=2)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=100, pin_memory=True, shuffle=False, num_workers=2)

    t1 = time.time()
    time_spend = []  # cumulative wall-clock seconds at each epoch boundary
    lr_list = []
    delay_list = []  # staleness (in updates) of every received gradient

    for ite in range(num_ite_epoch * args.epochs):
        epoch = ite // num_ite_epoch
        current_lr = adjust_learning_rate(epoch, args)

        g_flat.zero_()
        # Time stamps travel as float32 tensors (exact only up to 2**24).
        time_stamp_old = torch.FloatTensor([0])

        # Accept the next request from whichever worker sends first.
        src = dist.recv(time_stamp_old, tag=111)
        dist.recv(g_flat, src, tag=222)

        stamp_old = int(time_stamp_old)
        delay_list.append(ite - stamp_old)

        # Updates are grouped into buckets of world_size iterations.
        # Staleness is measured in buckets.
        head_bucket = math.ceil(ite / world_size)
        index_bucket = math.ceil(stamp_old / world_size)
        bucket_delay = head_bucket - index_bucket

        # Once per bucket (whenever ceil(ite / world_size) increments),
        # apply the pending momentum step and decay the buffer.
        if head_bucket != math.ceil((ite - 1) / world_size):
            w_flat.add_(u_flat, alpha=-args.m)
            u_flat.mul_(args.m)

        u_flat.add_(g_flat, alpha=current_lr * (args.m ** bucket_delay))

        # Step size is lr * sum_{j=0}^{bucket_delay} m**j. The closed form
        # divides by (1 - m), so the m == 1 limit (bucket_delay + 1) is
        # handled explicitly.
        if args.m == 1:
            step = current_lr * (bucket_delay + 1)
        else:
            step = (current_lr * (1 - (args.m ** (bucket_delay + 1)))) / (1 - args.m)
        w_flat.add_(g_flat, alpha=-step)

        time_stamp_t = torch.FloatTensor([ite + 1])
        dist.send(time_stamp_t, src, tag=333)
        dist.send(w_flat, src, tag=444)

        if (ite + 1) % num_ite_epoch == 0:
            t2 = time.time()

            # Epoch boundary: the requesting worker saves this snapshot.
            w_store_flat.copy_(w_flat)
            dist.send(w_store_flat, src, tag=555)

            time_spend.append(t2 - t1)
            lr_list.append(current_lr)

    # Each worker has exactly one outstanding request. Answer it with the
    # termination stamp the workers test for.
    for k in range(world_size):
        g_flat.zero_()
        time_stamp_old = torch.FloatTensor([0])

        dist.recv(time_stamp_old, k, tag=111)
        dist.recv(g_flat, k, tag=222)

        time_stamp_t = torch.FloatTensor([num_ite_epoch * (args.epochs + 100)])
        dist.send(time_stamp_t, k, tag=333)
    print('training finished.')

    # test saved models, deleting each checkpoint once it has been scored
    with open(output_path + '.txt', "w") as output_file:
        for ite in range(len(time_spend)):
            checkpoint_path = output_path + '/' + str(ite + 1) + '.pth'
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model'])

            loss_train, prec_train = validate(train_loader, model, criterion)
            loss_val, prec_tval = validate(val_loader, model, criterion)
            os.remove(checkpoint_path)
            output_file.write('%d %3f %3f %3f %3f %3f %3f\n' % (ite, time_spend[ite], lr_list[ite], loss_train, prec_train, loss_val, prec_tval))
            output_file.flush()
        output_file.write('%d\n' % (num_ite_epoch * args.epochs))
        output_file.flush()
        # Histogram of gradient staleness as sorted [delay, count] pairs.
        json.dump(sorted(Counter(delay_list).items()), output_file)

    os.removedirs(output_path)

def run(rank, world_size):
    """Worker entry point: train on this rank's data shard until told to stop.

    The worker receives the initial model from the coordinator (rank
    ``world_size``) and then calls ``train`` once per local epoch, reshuffling
    the shard each time. It stops when the coordinator sends its termination
    stamp. There is deliberately no epoch cap: exiting early would leave the
    coordinator blocked waiting for this worker's next request.

    Args:
        rank: this worker's rank (selects its shard of the training set).
        world_size: number of worker processes (excluding the coordinator).

    Raises:
        RuntimeError: if the local training loader yields no batches, which
            would otherwise loop forever.
    """
    args = parser.parse_args()

    print('Start node: %d  Total: %3d' % (rank, world_size))

    output_path = ('Loss-OrMo_reshuffle_wd' + str(wd) + "_" + str(args.model) + "_" + str(args.data)
                   + "_m_" + str(args.m) + "_epochs_" + str(args.epochs) + "_" + str(world_size)
                   + "_bs_" + str(args.batch_size) + "_" + str(args.lrdecay) + '_warmuplr_' + str(args.warmup)
                   + '_lr_' + str(args.lr) + '_' + str(args.ttag))

    model = modelload(args)
    model_flat = flatten_all(model)

    # Stamp of the weights currently loaded; 0 = the broadcast initial model.
    time_stamp = torch.FloatTensor([0])
    dist.broadcast(model_flat, world_size)
    unflatten_all(model, model_flat)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    # One-element lists so train() can accumulate compute time / batch count.
    ttime = [0]
    iite = [0]

    # Data loading code
    trainset, _ = dataload(args.data)
    num_ite_epoch = math.ceil(len(trainset) / args.batch_size / world_size) * world_size

    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, num_replicas=world_size, rank=rank)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, pin_memory=True, shuffle=False, num_workers=2, sampler=train_sampler)

    epoch = 0
    while True:
        train_sampler.set_epoch(epoch)
        batches_before = iite[0]
        next_epoch = train(train_loader, model, criterion, epoch, rank, world_size, output_path, num_ite_epoch, args, time_stamp, ttime, iite)
        if next_epoch == 0:
            break
        if iite[0] == batches_before:
            raise RuntimeError('worker %d: training loader produced no batches' % rank)
        epoch += 1

    # Average per-batch compute time (train() returns 0 only after >= 1 batch).
    print(str(rank) + ':' + str(ttime[0] / iite[0]))
  
def train(train_loader, model, criterion, epoch, rank, world_size, output_path, num_ite_epoch, args, time_stamp, ttime, iite):
    """Make one pass over ``train_loader`` as an asynchronous worker.

    For every batch the worker:

    1. Computes the gradient on its current weights and adds ``wd * w``
       weight decay.
    2. Sends ``time_stamp`` (tag 111) and the gradient (tag 222) to the
       coordinator, which is rank ``world_size``.
    3. Receives a new stamp (tag 333). If the stamp equals
       ``num_ite_epoch * (args.epochs + 100)``, training is over.
    4. Otherwise receives fresh weights (tag 444).
    5. When the stamp is a multiple of ``num_ite_epoch``, also receives a
       snapshot (tag 555) and saves it as
       ``<output_path>/<stamp // num_ite_epoch>.pth``.

    Args:
        epoch, rank: not used inside this function.
        time_stamp: one-element float tensor holding the stamp of the weights
            currently loaded; updated in place.
        ttime, iite: one-element lists accumulating compute time (seconds)
            and processed-batch count across calls.

    Returns:
        0 if the coordinator signalled termination, 1 if the loader ran out.
    """
    # switch to train mode
    model.train()
    w_flat = flatten(model)
    g_flat = torch.zeros_like(w_flat)
    w_store_flat = torch.zeros_like(w_flat)
    stop_stamp = num_ite_epoch * (args.epochs + 100)

    for inputs, target in train_loader:
        t1 = time.time()

        inputs = inputs.cuda()
        target = target.cuda()

        # compute output
        output = model(inputs)
        loss = criterion(output, target)

        # compute gradient
        model.zero_grad()
        loss.backward()

        g_flat.zero_()
        flatten_g(model, g_flat)

        # flatten() copies the parameters into a fresh vector. The new
        # weights from the coordinator are received into it below.
        w_flat = flatten(model)
        g_flat.add_(w_flat, alpha=wd)  # weight decay: g += wd * w

        t2 = time.time()

        ttime[0] += t2 - t1
        iite[0] += 1

        time_stamp_new = torch.FloatTensor([0])

        # communicate
        dist.send(time_stamp, world_size, tag=111)
        dist.send(g_flat, world_size, tag=222)
        dist.recv(time_stamp_new, world_size, tag=333)

        cur_time_stamp = int(time_stamp_new[0].item())
        if cur_time_stamp == stop_stamp:
            return 0

        dist.recv(w_flat, world_size, tag=444)
        time_stamp.copy_(time_stamp_new)

        # save model for testing (the stamp is an exact multiple here)
        if cur_time_stamp % num_ite_epoch == 0:
            dist.recv(w_store_flat, world_size, tag=555)
            unflatten(model, w_store_flat)
            saved_epoch = cur_time_stamp // num_ite_epoch
            print(str(saved_epoch) + ' save')
            state = {
                'model': model.state_dict(),
            }
            torch.save(state, output_path + '/' + str(saved_epoch) + '.pth')
        # Parameters become views of w_flat for the next forward pass.
        unflatten(model, w_flat)

    return 1

def adjust_learning_rate(epoch, args):
    """Return the learning rate for ``epoch`` under the configured schedule.

    - Warmup: during the first ``args.warmup`` epochs the rate ramps linearly
      as ``(epoch + 1) * args.lr / args.warmup``. A missing ``warmup``
      attribute is treated as 0.
    - ``'cosine'``: cosine annealing from ``args.lr`` over the post-warmup
      epochs.
    - ``'stagewise'``: multiply by 0.1 at 50% and again at 75% of
      ``args.epochs``.
    - Any other ``args.lrdecay`` keeps ``args.lr`` constant.
    """
    # getattr keeps the cosine branch working when ``warmup`` is absent,
    # instead of raising AttributeError there.
    warmup = getattr(args, 'warmup', 0)
    if epoch < warmup:
        return (epoch + 1) * args.lr / warmup
    if args.lrdecay == 'cosine':
        return args.lr * (0.5 * (1. + math.cos(math.pi * (epoch - warmup) / (args.epochs - warmup))))
    if args.lrdecay == 'stagewise':
        count = sum(1 for frac in (0.5, 0.75) if epoch >= args.epochs * frac)
        return args.lr * (0.1 ** count)
    return args.lr


def modelload(args):
    """Build the model selected by ``args.model`` / ``args.data`` on the GPU.

    Only ResNet-20 on CIFAR-10 (10 output classes) is supported.

    Raises:
        ValueError: for any other model/dataset combination.
    """
    if args.model == 'resnet20' and args.data == 'cifar10':
        return resnet20(num_classes=10).cuda()
    raise ValueError('unsupported model/dataset combination: model=%r, data=%r'
                     % (args.model, args.data))


def dataload(dataname):
    """Return ``(trainset, valset)`` for the named dataset.

    Only ``'cifar10'`` is supported. It is read from ``./data/cifar10`` and
    not downloaded. Training images are augmented with a padded random crop
    and horizontal flip; both splits are normalized with the same per-channel
    mean/std.

    Raises:
        ValueError: if ``dataname`` is not a supported dataset.
    """
    if dataname != 'cifar10':
        raise ValueError('unsupported dataset: %r' % (dataname,))

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_transform = transforms.Compose(
        [transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         normalize])
    trainset = datasets.CIFAR10(root='./data/cifar10', train=True, download=False, transform=train_transform)

    val_transform = transforms.Compose(
        [transforms.ToTensor(),
         normalize])
    valset = datasets.CIFAR10(root='./data/cifar10', train=False, download=False, transform=val_transform)

    return trainset, valset

def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Leaves the model in eval mode.

    Returns:
        ``(avg_loss, avg_top1)``: the sample-weighted mean loss (a float) and
        the sample-weighted mean top-1 accuracy in percent. ``avg_top1`` is a
        one-element tensor, as produced by ``accuracy``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for inputs, target in val_loader:
            inputs = inputs.cuda()
            target = target.cuda()

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1 = accuracy(output, target, topk=(1,))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))

    return losses.avg, top1.avg


class AverageMeter(object):
    """Keep the most recent value and a count-weighted running mean.

    Attributes:
        val: last value passed to ``update``.
        sum: running total of ``val * n``.
        count: total weight ``n`` seen so far.
        avg: ``sum / count`` (0 until the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded observation."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as ``n`` observations and refresh the mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor indexed as ``[sample, class]`` (top-k is taken
            along dim 1).
        target: class-index tensor with one entry per sample.
        topk: the k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # [maxk, batch]; non-contiguous after the transpose
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: correct[:k] is not viewable as 1-D for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def flatten_all(model):
    """Concatenate every parameter and buffer of ``model`` into one 1-D tensor.

    Parameters come first, in ``model.parameters()`` order, followed by
    buffers in ``model.buffers()`` order. Buffers are cast to float so that
    integer buffers can share the vector. The result is a copy.
    """
    pieces = [p.data.view(-1) for p in model.parameters()]
    pieces.extend(buf.data.float().view(-1) for buf in model.buffers())
    return torch.cat(pieces)


def unflatten_all(model, vec):
    """Load parameters and buffers of ``model`` from a vector made by ``flatten_all``.

    Parameters become views of ``vec`` in ``model.parameters()`` order.
    Buffers follow in ``model.buffers()`` order. ``flatten_all`` upcasts every
    buffer to float, so each buffer is cast back to its original dtype here:
    float buffers still alias ``vec``, while integer buffers (e.g. a
    batch-norm ``num_batches_tracked`` counter) get a converted copy instead
    of silently becoming float.
    """
    pointer = 0
    for param in model.parameters():
        count = param.numel()
        param.data = vec[pointer:pointer + count].view(param.size())
        pointer += count
    for buf in model.buffers():
        count = buf.numel()
        chunk = vec[pointer:pointer + count].view(buf.size())
        buf.data = chunk.to(buf.dtype)  # no-op (still a view) when dtypes match
        pointer += count


def flatten(model):
    """Return a new 1-D tensor holding all parameters of ``model``.

    Parameters are concatenated in ``model.parameters()`` order. Buffers are
    not included. The result is a copy, not a view.
    """
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def unflatten(model, vec):
    """Make each parameter of ``model`` a view into the flat vector ``vec``.

    Slices are taken in ``model.parameters()`` order. Later in-place writes
    to ``vec`` are therefore visible through the parameters.
    """
    offset = 0
    for param in model.parameters():
        size = param.numel()
        param.data = vec[offset:offset + size].view(param.size())
        offset += size


def flatten_g(model, vec):
    """Copy the gradients of ``model`` into the flat vector ``vec`` in place.

    Slots follow ``model.parameters()`` order. A parameter whose ``.grad``
    is None (it did not take part in the backward pass) contributes zeros
    instead of raising AttributeError.
    """
    pointer = 0
    for param in model.parameters():
        count = param.numel()
        if param.grad is None:
            vec[pointer:pointer + count] = 0
        else:
            vec[pointer:pointer + count] = param.grad.data.reshape(-1)
        pointer += count


def unflatten_g(model, vec):
    """Set each parameter's gradient to a view of the flat vector ``vec``.

    Slices are taken in ``model.parameters()`` order. Parameters that have no
    gradient yet (``.grad`` is None) receive the view as a new gradient
    instead of raising AttributeError.
    """
    pointer = 0
    for param in model.parameters():
        count = param.numel()
        chunk = vec[pointer:pointer + count].view(param.size())
        if param.grad is None:
            param.grad = chunk
        else:
            param.grad.data = chunk
        pointer += count


if __name__ == '__main__':
    # Every process joins the MPI group and learns its rank.
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    num_workers = dist.get_world_size() - 1

    # The highest rank is the coordinator; every lower rank is a worker.
    # Both roles are given the worker count, so the coordinator's own rank
    # equals that count.
    if rank != num_workers:
        run(rank, num_workers)
    else:
        coordinate(rank, num_workers)
