import os
import csv
import torch
from torch import nn
from torchvision import transforms, datasets
import resnet
import argparse
from copy import deepcopy
import time
import math

# Command-line interface of the training script.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--scale_lr',
    action='store_true',
    help='scale the learning rate of every parameter tensor by a multiplier '
         'estimated from its gradient magnitude at initialization',
)
parser.add_argument(
    '--lr',
    type=float,
    default=0.1,
    help='base SGD learning rate (default: %(default)s)',
)

# Per-channel statistics used to normalize CIFAR-100 images in both the
# training and the test pipeline.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

def get_train_loader(batch_size, num_workers=4, shuffle=True, train=True, download=True):
    """Build a DataLoader over CIFAR-100 with training-time augmentation.

    Every image is randomly cropped to 32x32 after 4 pixels of padding,
    randomly flipped horizontally, converted to a tensor and normalized with
    CIFAR100_TRAIN_MEAN / CIFAR100_TRAIN_STD. A trailing incomplete batch is
    dropped (``drop_last=True``).

    Args:
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes.
        shuffle: whether the loader reshuffles the data.
        train: forwarded to ``datasets.CIFAR100`` to select the split.
        download: forwarded to ``datasets.CIFAR100``; the dataset root is
            the relative directory ``dataset``.

    Returns:
        torch.utils.data.DataLoader over the configured dataset.
    """
    # AutoAugment (transforms.AutoAugmentPolicy.CIFAR10) is intentionally not
    # part of the pipeline.
    pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
    ])
    dataset = datasets.CIFAR100(
        root='dataset',
        train=train,
        download=download,
        transform=pipeline,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
    )
    return loader

def get_test_loader(batch_size, num_workers=4, shuffle=False, train=False, download=True):
    """Build a DataLoader over CIFAR-100 without augmentation.

    Images are only converted to tensors and normalized with
    CIFAR100_TRAIN_MEAN / CIFAR100_TRAIN_STD. Incomplete final batches are
    kept.

    Args:
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes.
        shuffle: whether the loader reshuffles the data.
        train: forwarded to ``datasets.CIFAR100`` to select the split.
        download: forwarded to ``datasets.CIFAR100``; the dataset root is
            the relative directory ``dataset``.

    Returns:
        torch.utils.data.DataLoader over the configured dataset.
    """
    preprocessing = [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
    ]
    dataset = datasets.CIFAR100(
        root='dataset',
        train=train,
        download=download,
        transform=transforms.Compose(preprocessing),
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return loader
class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a scalar.

    State (all set to 0 by ``reset``):
        val         -- value passed to the latest ``update`` call
        sum         -- accumulated ``val * n``
        sum_square  -- accumulated ``val**2 * n``
        count       -- accumulated weight ``n``
        avg         -- ``sum / count``; stays 0 while ``count`` is 0
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.sum_square = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: the new observation.
            n: weight of the observation (e.g. the number of samples it
                was averaged over).
        """
        self.val = val
        self.sum += val * n
        self.sum_square += (val**2) * n
        self.count += n
        # A zero total weight (e.g. update(x, 0) on a fresh meter) has no
        # defined mean; report 0 instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

def reset_parameters(model):
    """Re-initialize the convolution, linear and normalization layers of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.Linear``: weights are drawn from a zero-mean normal
      distribution with standard deviation ``1 / sqrt(fan_in)``; biases, when
      present, are set to zero.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weights are set to one and biases
      to zero. Layers built with ``affine=False`` have neither and are left
      untouched.

    Args:
        model: the ``nn.Module`` whose sub-modules are re-initialized.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
            nn.init.normal_(module.weight, std=1 / fan_in ** 0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            # With affine=False the layer has weight = bias = None.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

def get_lrscale(model, optimizer, trainloader, loss_function, batchsize):
    """Estimate one learning-rate multiplier per parameter tensor from gradients.

    Makes one pass over ``trainloader`` without taking any optimizer step,
    accumulates the element-wise absolute gradient of every parameter and
    averages it over the batches. Each parameter tensor then gets the raw
    scale ``1 / sqrt(mean(|grad|))``. The scales are divided by their
    element-count-weighted mean, so the returned multipliers average to 1
    over all scalar parameters.

    The multipliers are always computed and returned; whether they are
    applied is the caller's decision.

    Side effects: ``model`` is put in train mode, ``optimizer.zero_grad()``
    is called before every batch, and the parameters keep the gradients of
    the last batch on return. Two progress lines are printed.

    Args:
        model: network whose parameters are analysed; batches are moved to
            the device of its first parameter.
        optimizer: only used to clear gradients between batches.
        trainloader: iterable of ``(images, labels)`` batches exposing a
            ``dataset`` attribute (its length is printed).
        loss_function: callable ``loss_function(outputs, labels)`` returning
            a scalar loss tensor.
        batchsize: weight given to each batch loss in the printed average.

    Returns:
        list[float]: multipliers in ``model.parameters()`` order.

    Raises:
        ValueError: if the model has no parameters or the loader yields no
            batches.
    """
    start = time.time()
    params = list(model.parameters())
    if not params:
        raise ValueError('get_lrscale needs a model with at least one parameter')
    # Follow the model's device instead of assuming CUDA.
    device = params[0].device
    losses = AverageMeter()
    # One |grad| accumulator per parameter; plain tensors are enough, no need
    # to deep-copy the whole model as scratch space.
    abs_grad = [torch.zeros_like(p.detach()) for p in params]

    model.train()
    num_batches = 0
    for images, labels in trainloader:
        labels = labels.to(device)
        images = images.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        with torch.no_grad():
            for p, acc in zip(params, abs_grad):
                if p.grad is not None:
                    acc.add_(torch.abs(p.grad))
        losses.update(loss.item(), batchsize)
        num_batches += 1
    if num_batches == 0:
        raise ValueError('trainloader yielded no batches; cannot estimate gradients')

    print('Initialization Gradient of {total_samples} samples\tLoss: {:0.4f}'.format(
        losses.avg,
        total_samples=len(trainloader.dataset)
    ))

    # Raw per-tensor scale: 1 / sqrt(batch-averaged mean |grad|).
    # NOTE(review): a parameter that never received a gradient has mean 0 and
    # therefore an infinite raw scale, which makes the normalization below
    # infinite and the multipliers 0/NaN. Consider excluding such parameters.
    raw_scale = []
    numel = []
    with torch.no_grad():
        for acc in abs_grad:
            acc.mul_(1 / num_batches)
            raw_scale.append((1.0 / torch.pow(torch.mean(acc), 0.5)).item())
            numel.append(acc.numel())

    # Normalize by the mean scale over all scalar parameters, i.e. each
    # tensor's scale weighted by its number of elements.
    pwratio = sum(s * n for s, n in zip(raw_scale, numel)) / sum(numel)
    lwlr = [s / pwratio for s in raw_scale]

    finish = time.time()
    print('Initialization Gradient calculating time consumed: {:.2f}s'.format(finish - start))
    return lwlr

args = parser.parse_args()

model = resnet.resnet50()
reset_parameters(model)

batchsize = 256
weight_decay = 5e-4

# Number of parameter elements in the model, printed for reference.
print(sum(p.numel() for p in model.parameters()))
model.cuda()
loss_function = nn.CrossEntropyLoss()
# One optimizer group per parameter tensor so that each tensor can carry its
# own learning rate.
param_list = [{'params': [p], 'lr': args.lr * 1.0, 'weight_decay': weight_decay}
              for p in model.parameters()]
optimizer = torch.optim.SGD(param_list, momentum=0.9)

trainloader = get_train_loader(batchsize)
testloader = get_test_loader(batchsize)
if args.scale_lr:
    lwlr = get_lrscale(model, optimizer, trainloader, loss_function, batchsize)
    # Groups were created one per parameter in model.parameters() order, the
    # same order in which get_lrscale returns its multipliers.
    for group, multiplier in zip(optimizer.param_groups, lwlr):
        group['lr'] *= multiplier

opt_configs = ['lr', args.lr, 'batch size', batchsize, 'weight decay', weight_decay,
               'layerwise lr init', args.scale_lr]
# Effective learning-rate multiplier of every parameter relative to --lr.
lwlr = [group['lr'] / args.lr
        for group in optimizer.param_groups
        for _ in group['params']]
# Per-parameter name, size and standard deviation; only needed by the
# optional rows that are commented out below.
lw_name, lw_param, lw_std = [], [], []
for name, param in model.named_parameters():
    lw_name.append(name)
    lw_param.append(torch.numel(param))
    lw_std.append(torch.std(param).item())

# newline='' is what the csv module expects for files handed to csv.writer;
# without it, platforms that translate line endings emit blank rows.
with open("results.csv", 'a', newline='') as file:
    writer = csv.writer(file, delimiter=',')
    writer.writerow(opt_configs)
#    writer.writerow(lw_name)
#    writer.writerow(lw_param)
    writer.writerow(lwlr)

total_epochs = 200
warmupepochs = 1
# Take the number of optimizer steps per epoch from the loader itself rather
# than hard-coding the dataset size (50000 // batchsize for the current
# drop_last loader).
iters_per_epoch = len(trainloader)
warmup_iters = warmupepochs * iters_per_epoch
# Linear warm-up over the first `warmupepochs` epochs, then cosine annealing
# over the remaining ones. Both schedulers are stepped once per batch.
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0 / warmup_iters,
    end_factor=1.0,
    total_iters=warmup_iters,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=(total_epochs - warmupepochs) * iters_per_epoch,
)

model.train()
start = time.time()

#lrcheck = []
for epoch in range(0, total_epochs):
    # Fresh meter per epoch: after the loop `losses` holds the last epoch only.
    losses = AverageMeter()
    for batch_index, (images, labels) in enumerate(trainloader):
        labels = labels.cuda()
        images = images.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
#        if batch_index == 0:
#            lrcheck.extend([warmup.get_last_lr()[0], scheduler.get_last_lr()[0]])
        optimizer.step()
        losses.update(loss.item(), batchsize)
        # LinearLR is documented to stop changing the rate after total_iters
        # steps, so stepping it past the warm-up phase is harmless.
        warmup.step()
        if epoch >= warmupepochs:
            scheduler.step()

# Evaluate the trained model on the test loader and append the results to
# results.csv.
correct = 0
total = 0
model.eval()
with torch.no_grad():
    testloss = AverageMeter()
    for images, labels in testloader:
        labels = labels.cuda()
        images = images.cuda()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        # Weight each batch loss by the batch's real size: the test loader
        # keeps its final, smaller batch, so weighting by `batchsize` would
        # bias the average.
        num_samples = labels.size(0)
        testloss.update(loss.item(), num_samples)
        _, predicted = torch.max(outputs, 1)
        total += num_samples
        correct += (predicted == labels).sum().item()

finish = time.time()
print('Train and Test consumed: {:.2f}s, Train loss: {}'.format(finish - start, losses.avg))
print('Accuracy of the network on the 10000 test images:', 100 * correct / total)

# newline='' is what the csv module expects for files handed to csv.writer.
with open("results.csv", 'a', newline='') as file:
    csv.writer(file, delimiter=',').writerow(['train loss', losses.avg, 'test loss', testloss.avg, 'test accuracy', 100*correct/total])
#    csv.writer(file, delimiter=',').writerow(lrcheck)
