import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import vit
import preprocess_cifar_dataset
from util import AverageMeter, accuracy, RandomMixup, RandomCutmix

import time

# --- Optimisation schedule ---------------------------------------------------
epochs = 100            # total number of training epochs
warmup_epochs = 5       # epochs of linear LR warm-up before cosine annealing
lr = 0.001              # peak learning rate reached at the end of warm-up
weight_decay = 0.1      # passed to AdamW
batch_size = 100        # training mini-batch size
clip_grad_norm = None   # max total gradient norm; None disables clipping

# --- Regularisation / augmentation -------------------------------------------
label_smoothing = 0.1
random_erase_prob = 0.1
mixup_alpha = 0.2       # 0.0 disables mixup
cutmix_alpha = 1.0      # 0.0 disables cutmix
amp = True              # enable mixed precision (GradScaler + autocast)

print(f'vit {epochs} epochs {warmup_epochs} warmup epochs lr {lr} weight decay {weight_decay} batch_size {batch_size}')

# Transformer width / attention configuration.
dim = 256
heads = 8
dim_head = dim // heads   # per-head width, so heads * dim_head == dim
mlp_dim = dim * 4         # MLP hidden width: 4x the embedding dimension
print(f'hidden dimension {dim}, {heads} heads')

model = vit.ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=dim,
    mlp_dim=mlp_dim,
    depth=8,
    heads=heads,
    dim_head=dim_head,
).cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# LR schedule: ramp linearly from 0.2 * lr up to lr over `warmup_epochs`
# scheduler steps, then cosine-anneal over the remaining epochs.
s1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.2, total_iters=warmup_epochs)
s2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[s1, s2],
    milestones=[warmup_epochs],
)

# Per-channel normalisation mean / std for CIFAR-10 inputs.
mean = [0.49139967861519745, 0.4821584083946076, 0.44653091444546616]
std = [0.2470322324632823, 0.24348512800005553, 0.2615878417279641]

# Image-level augmentation on PIL images, conversion to a normalised float
# tensor, then tensor-level random erasing.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.TrivialAugmentWide(interpolation=transforms.functional.InterpolationMode.BILINEAR),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=mean, std=std),
        transforms.RandomErasing(p=random_erase_prob),
    ]
)

# Batch-level mixing. When at least one of mixup / cutmix is enabled, the
# loader's collate function picks one of them at random (RandomChoice) for
# every batch. Otherwise collate_fn stays None and DataLoader uses its default.
collate_fn = None
mixup_transforms = []
if mixup_alpha > 0.0:
    mixup_transforms.append(RandomMixup(num_classes=10, p=1.0, alpha=mixup_alpha))
if cutmix_alpha > 0.0:
    mixup_transforms.append(RandomCutmix(num_classes=10, p=1.0, alpha=cutmix_alpha))

if mixup_transforms:
    mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

    def collate_fn(batch):
        """Stack samples with the default collate, then apply mixup or cutmix to the batch."""
        collated = torch.utils.data.dataloader.default_collate(batch)
        return mixupcutmix(*collated)

# Augmented, shuffled training stream. download is not requested, so the
# CIFAR-10 files must already exist under ../data.
train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Evaluation sets built with the project's PreProcessCIFAR10 wrapper, iterated
# in fixed order with large batches. NOTE(review): pin_memory=False here
# presumably means these datasets already yield ready-made (possibly on-GPU)
# tensors; confirm against preprocess_cifar_dataset.
train_dataset2 = preprocess_cifar_dataset.PreProcessCIFAR10(root='../data', train=True)
train_loader2 = torch.utils.data.DataLoader(
    train_dataset2,
    batch_size=10000,
    shuffle=False,
    pin_memory=False,
)
test_dataset = preprocess_cifar_dataset.PreProcessCIFAR10(root='../data', train=False)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=10000,
    shuffle=False,
    pin_memory=False,
)

# Loss scaler for mixed-precision training; None selects the full-precision path.
if amp:
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None


def train(epoch):
    """Run one epoch over ``train_loader``, updating ``model`` in place.

    On completion, prints (without a trailing newline) the epoch index, the
    average training loss, the average top-1 accuracy and the average gradient
    norm. The gradient-norm meter is only updated when ``clip_grad_norm`` is set.
    """
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc@1', ':6.2f')
    norm_meter = AverageMeter('GradNorm', ':6.2f')
    use_amp = scaler is not None

    model.train()

    for batch, labels in train_loader:
        batch = batch.cuda()
        labels = labels.cuda()
        count = batch.size(0)

        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(batch)
            loss = F.cross_entropy(logits, labels, label_smoothing=label_smoothing)

        # NOTE(review): with mixup/cutmix enabled the collate_fn presumably
        # produces soft (mixed one-hot) labels; confirm `accuracy` handles them.
        acc_meter.update(accuracy(logits, labels)[0], count)
        loss_meter.update(loss.item(), count)

        optimizer.zero_grad()
        (scaler.scale(loss) if use_amp else loss).backward()

        if clip_grad_norm is not None:
            if use_amp:
                # Unscale first so the clip threshold applies to the true gradients.
                scaler.unscale_(optimizer)
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
            norm_meter.update(total_norm.item(), count)

        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

    print(f'{epoch} {loss_meter.avg} {acc_meter.avg.item():.2f} {norm_meter.avg}', end=' ')


def validate(val_loader, flag):
    """Evaluate ``model`` on ``val_loader`` and print loss, top-1 accuracy and logit spread.

    The printed fields are the average cross-entropy loss, the average top-1
    accuracy, and the mean per-sample standard deviation of the output logits.
    Each is a sample-weighted average over every batch in the loader.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        flag: when truthy, the output ends with a space and is not flushed,
            so the caller can continue the same log line. Otherwise the line
            is terminated with a newline and stdout is flushed.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    logit_std = AverageMeter('LogitStd', ':.4e')
    num_samples = 0

    model.eval()

    with torch.no_grad():
        for images, target in val_loader:
            # .cuda() returns the tensor unchanged when it is already on the
            # current GPU. Otherwise it moves the batch to the model's device.
            images, target = images.cuda(), target.cuda()
            batch = images.size(0)
            output = model(images)
            loss = F.cross_entropy(output, target)
            losses.update(loss.item(), batch)
            acc1 = accuracy(output, target)
            top1.update(acc1[0], batch)
            # Accumulate over all batches rather than reporting only the last
            # one, so this statistic covers the same samples as loss/accuracy.
            logit_std.update(output.std(dim=1).mean().item(), batch)
            num_samples += batch

    if num_samples == 0:
        raise ValueError('validate() received an empty loader')

    print(
        f'{losses.avg} {top1.avg.item():.2f} {logit_std.avg}',
        end=' ' if flag else '\n',
        flush=not flag,
    )


if __name__ == '__main__':
    # The guard stops DataLoader worker processes from re-running the training
    # loop. Under the "spawn"/"forkserver" start methods (spawn is the default
    # on Windows and macOS), those workers re-import this module.
    for epoch in range(epochs):
        # perf_counter is monotonic and high-resolution, so epoch timings are
        # not skewed by system clock adjustments the way time.time() can be.
        start = time.perf_counter()
        train(epoch)
        print(time.perf_counter() - start, end=' ')

        validate(train_loader2, True)   # evaluation pass over the training split
        validate(test_loader, False)    # test split; ends the per-epoch log line
        scheduler.step()                # LR schedule advances once per epoch
