# train.py
import os
import argparse
import pyvww
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, DistributedSampler
from models.vit_base import get_vit_base
from models.vit import get_vit_tiny
from models.vit_multi_branch import get_multibranch_vit
from loader import TinyImageNetDataset


def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size):
    """Build distributed train/validation loaders for a supported dataset.

    Args:
        dataset_name: One of 'tiny-imagenet', 'cifar10' or 'VWW'
            (matched case-insensitively).
        data_dir: Root directory handed to the dataset constructors.
        batch_size: Per-process batch size for both loaders.
        rank: Rank of the calling process, forwarded to the samplers.
        world_size: Total number of processes, forwarded to the samplers.
        img_size: Side length images are resized/cropped to.

    Returns:
        Tuple ``(trainloader, valloader, train_sampler, num_classes)``. The
        train sampler is returned so the caller can call ``set_epoch`` on it.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Normalise once so every branch is matched the same way, and reject
    # unknown names before any transform or dataset object is built.
    name = dataset_name.lower()
    if name not in ('tiny-imagenet', 'cifar10', 'vww'):
        raise ValueError("Unsupported dataset. Choose from 'tiny-imagenet' or 'cifar10' or 'VWW'.")

    # Per-channel mean/std; the values match the widely used ImageNet statistics.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    if name == 'tiny-imagenet':
        trainset = TinyImageNetDataset(data_dir, "train", transform_train)
        valset = TinyImageNetDataset(data_dir, 'val', transform_val)
        num_classes = 200
    elif name == 'cifar10':
        # NOTE(review): download=True is requested by every rank that calls
        # this function; confirm concurrent downloads into data_dir are safe.
        trainset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 10
    else:  # name == 'vww'
        image_root = os.path.join(data_dir, 'coco2014/all2014')
        trainset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_train.json'),
            transform=transform_train,
        )
        valset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_val.json'),
            transform=transform_val,
        )
        num_classes = 2

    # Each process gets its own shard; only the training shard is shuffled.
    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)
    return trainloader, valloader, train_sampler, num_classes


def train(rank, world_size, args):
    """Train a ViT model on one GPU as part of a distributed (NCCL) job.

    Intended to be launched once per process (e.g. via ``mp.spawn``), with
    ``rank`` doubling as the CUDA device index.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        world_size: Total number of participating processes.
        args: Namespace providing ``dataset``, ``data_dir``, ``batch_size``,
            ``img_size``, ``parallel``, ``lr`` and ``epochs``.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        trainloader, valloader, train_sampler, num_classes = get_dataloaders(
            args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size
        )

        # Load model
        if args.parallel:
            model = get_multibranch_vit(num_classes=num_classes).to(device)
        else:
            # NOTE(review): this branch previously called `get_vit_small`, which
            # is not imported in this file and raised NameError. `get_vit_tiny`
            # is the factory imported from models.vit; confirm it is the
            # intended baseline and that it accepts `pretrained`.
            model = get_vit_tiny(num_classes=num_classes, pretrained=False).to(device)

        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

        for epoch in range(args.epochs):
            # Re-seed the sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
            model.train()
            running_loss, total, correct = 0.0, 0, 0

            for step, (imgs, labels) in enumerate(trainloader):
                imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                optimizer.zero_grad()
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * imgs.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                if step % 20 == 0 and rank == 0:
                    print(f"[Step {step}/{len(trainloader)}] Train Loss: {loss.item()}")

            # Sum the per-rank statistics so the epoch summary covers every
            # shard instead of only the samples seen by rank 0.
            stats = torch.tensor([running_loss, correct, total], dtype=torch.float64, device=device)
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            loss_sum, n_correct, n_seen = stats.tolist()

            if rank == 0:
                denom = max(n_seen, 1.0)  # guard against an empty loader
                train_acc = n_correct / denom
                print(f"[Epoch {epoch+1}] Train Loss: {loss_sum/denom:.4f}, Acc: {train_acc:.4f}")
            scheduler.step()

            # Periodic validation; skipped on the last epoch because the final
            # evaluation below would otherwise repeat it immediately.
            if epoch % 5 == 0 and epoch != args.epochs - 1:
                evaluate(model, valloader, device, rank)
        evaluate(model, valloader, device, rank)
    finally:
        # Always tear down the process group, even when training fails.
        dist.destroy_process_group()


def evaluate(model, dataloader, device, rank):
    """Compute top-1 accuracy of ``model`` over ``dataloader``.

    When a ``torch.distributed`` process group is initialised, the correct and
    total counts are summed across all ranks, so the reported accuracy covers
    every shard rather than only the caller's. Every rank must therefore call
    this function.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        dataloader: Iterable yielding ``(imgs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        rank: Rank of the calling process; only rank 0 prints the result.

    Returns:
        The accuracy as a float in ``[0, 1]``; ``0.0`` if no samples were seen.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Aggregate across processes only when running inside a distributed job,
    # so the function also works in a plain single-process setting.
    if dist.is_available() and dist.is_initialized():
        counts = torch.tensor([correct, total], dtype=torch.long, device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        correct, total = counts.tolist()

    # Avoid ZeroDivisionError on an empty loader.
    acc = correct / total if total else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """Parse command-line options and spawn one training process per rank.

    Exits with an argparse error (status 2) when ``--world_size`` is below 1
    or exceeds the number of visible CUDA devices, since each rank is used as
    a CUDA device index by ``train``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='tiny-imagenet', choices=['tiny-imagenet', 'cifar10', 'VWW'])
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=5e-3)
    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    args = parser.parse_args()

    # The default world size is the CUDA device count, which is 0 on a machine
    # without GPUs; fail with a clear message instead of spawning nothing.
    available_gpus = torch.cuda.device_count()
    if args.world_size < 1:
        parser.error("--world_size must be at least 1 (no CUDA devices were detected for the default)")
    if args.world_size > available_gpus:
        parser.error(
            f"--world_size={args.world_size} exceeds the {available_gpus} visible CUDA device(s)"
        )

    # Respect a rendezvous address/port already provided by the environment
    # and fall back to the local defaults only when they are unset.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
