'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from resnet import ResNet18

from utils import progress_bar
import numpy as np


# Command-line configuration. Options are registered in the same order as they
# appear in --help.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', type=float, default=0.1,
                    help='learning rate')
# NOTE(review): args.resume does not appear to be read anywhere in this script.
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
# Probability that a training label is replaced by a different class (label noise).
parser.add_argument('--wrong-rate', type=float, default=0,
                    help='wrong rate')
parser.add_argument('--checkpoint-name', type=str, default='ckptddp.pth',
                    help='checkpoint name')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train')
args = parser.parse_args()


def maybe_flip_labels(c, flip_rate, num_classes):
    """Return *c*, or with probability *flip_rate* a different class label.

    The replacement is drawn uniformly from the ``num_classes - 1`` labels other
    than *c*, using NumPy's global RNG.

    Args:
        c: original class index, expected in ``[0, num_classes)``.
        flip_rate: probability of replacing the label (``<= 0`` never flips,
            ``>= 1`` always flips).
        num_classes: total number of classes. With fewer than two classes there
            is no wrong label to choose, so *c* is returned unchanged.

    Returns:
        *c* unchanged when no flip happens, otherwise the replacement label as a
        plain ``int``.

    Raises:
        ValueError: if a flip happens and *c* is outside ``[0, num_classes)``.
    """
    if num_classes < 2 or not (np.random.rand() < flip_rate):
        return c
    if not 0 <= c < num_classes:
        raise ValueError(f"label {c} is outside [0, {num_classes})")
    # Draw one of the num_classes - 1 other labels without building a list:
    # draws at or above c are shifted up by one so that c itself is skipped.
    wrong_label = int(np.random.randint(num_classes - 1))
    return wrong_label + 1 if wrong_label >= c else wrong_label

# Setup DDP:
dist.init_process_group("nccl")
# NOTE(review): each rank's DataLoader below uses batch_size=128, so this check
# against 256 does not match the loader configuration — confirm the intended
# global batch size.
assert 256 % dist.get_world_size() == 0, "Batch size must be divisible by world size."
rank = dist.get_rank()
# GPU index for this process on its node (assumes one process per GPU and the
# same GPU count on every node).
device = rank % torch.cuda.device_count()
# ``local_rank`` is the node-local GPU index (what DDP's device_ids expects),
# not the global rank.
local_rank = device
local_seed = 1334 + rank
torch.cuda.set_device(device)
# Per-rank seeds: runs are reproducible and each rank draws its own label-noise
# stream from NumPy's global RNG; torch's RNG (e.g. weight init) is seeded too.
torch.manual_seed(local_seed)
np.random.seed(local_seed)

# Running best test accuracy (updated by test()) and the epoch the main loop
# starts from. NOTE(review): args.resume does not appear to be used, so both
# always start at zero.
best_acc, start_epoch = 0, 0

# Data
if dist.get_rank() == 0:
    print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Rank 0 downloads/extracts CIFAR-10 while every other rank waits at the
# barrier, so several processes never write the same files in ./data at once.
# NOTE(review): the waiting ranks then reuse rank 0's files, which assumes ./data
# is on storage shared by all ranks (always true on a single node); without
# shared storage they simply download for themselves after the barrier.
if dist.get_rank() != 0:
    dist.barrier()
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
if dist.get_rank() == 0:
    dist.barrier()

# The DistributedSampler shards (and, for training, shuffles) the data per rank,
# so the DataLoader itself must not shuffle.
train_sampler = DistributedSampler(trainset)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=False, num_workers=2, sampler=train_sampler)

test_sampler = DistributedSampler(testset)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2, sampler=test_sampler)

# Model
if dist.get_rank() == 0:
    print('==> Building model..')
net = ResNet18()
net = net.to(device)
# device_ids takes this process's GPU index on its node, i.e. ``device``; the
# global rank only coincides with it on a single node.
net = DDP(net, device_ids=[device])


# Calculate the number of parameters
def count_parameters(model):
    """Return the total number of scalar elements in *model*'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    n_elements = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        n_elements += param.numel()
    return n_elements

num_params = count_parameters(net)
if dist.get_rank() == 0:
    print(f'Number of parameters: {num_params}')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Anneal over the number of epochs actually run (scheduler.step() is called once
# per epoch in the main loop). A fixed T_max=200 under-decays shorter runs and
# lets the LR climb back up in longer ones. max(..., 1) avoids a zero period.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(args.epochs, 1))
# Loss scaler used together with the autocast regions in train().
scaler = GradScaler()

# Training
def train(epoch):
    if dist.get_rank() == 0:
        print('\nEpoch: %d' % epoch)
    net.train()
    train_sampler.set_epoch(epoch)
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # targets_flipped = torch.tensor([maybe_flip_labels(targets.item(), flip_rate=args.wrong_rate, num_classes=10) for label in targets])
        targets_flipped = torch.tensor([maybe_flip_labels(label.item(), flip_rate=args.wrong_rate, num_classes=10) for label in targets], device=device)
        with autocast():
            outputs = net(inputs)
            # loss = criterion(outputs, targets)
            loss = criterion(outputs, targets_flipped)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if dist.get_rank() == 0:
            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch):
    """Evaluate ``net`` on the test set and checkpoint it when accuracy improves.

    Each rank evaluates its DistributedSampler shard. The loss sum, batch count,
    correct count and sample count are then summed across ranks with a single
    ``all_reduce``, so every rank computes the same global metrics. Only rank 0
    prints the summary and writes the checkpoint (and updates ``best_acc``).

    NOTE(review): DistributedSampler pads its index list so that it splits evenly
    across ranks. When the world size does not divide the test-set size, a few
    samples are therefore counted twice.

    Args:
        epoch: current epoch index, stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_sampler.set_epoch(epoch)
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            with autocast():
                outputs = net(inputs)
                loss = criterion(outputs, targets)

            test_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if dist.get_rank() == 0:
                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                             % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Sum [loss, batches, correct, samples] over all ranks in one collective;
    # float64 keeps the integer counts exact.
    stats = torch.tensor([test_loss, num_batches, correct, total],
                         dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    loss_sum, batch_count, reduced_correct, reduced_total = stats.tolist()
    reduced_correct = int(reduced_correct)
    reduced_total = int(reduced_total)
    if reduced_total == 0:
        # Nothing was evaluated on any rank; every rank sees the same sums.
        return
    # Mean per-batch loss over all ranks (same quantity as the progress bar),
    # rather than a per-rank sum of batch losses.
    reduced_loss = loss_sum / batch_count

    acc = 100. * reduced_correct / reduced_total
    if dist.get_rank() == 0:
        print(f'Test Loss: {reduced_loss:.3f} | Test Acc: {acc:.3f}% ({reduced_correct}/{reduced_total})')

    # Save checkpoint.
    if acc > best_acc and dist.get_rank() == 0:
        print('Saving..')
        state = {
            'net': net.module.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, os.path.join('./checkpoint', args.checkpoint_name))
        best_acc = acc


# Main loop: train for one epoch, evaluate, then advance the per-epoch LR schedule.
for epoch in range(start_epoch, args.epochs + start_epoch):
    train(epoch)
    test(epoch)
    scheduler.step()

# Clean up: tear down the process group created by init_process_group.
dist.destroy_process_group()