# train.py
import os
import argparse
import pyvww
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, DistributedSampler
from models.vit_base import get_vit_base
from models.vit import get_vit_tiny
from models.vit_parallel_repa import get_parallel_vit, LambdaScheduler
from loader import TinyImageNetDataset


def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size, num_workers=4):
    """Build distributed train/val data loaders for a supported dataset.

    Args:
        dataset_name: One of 'tiny-imagenet', 'cifar10', 'VWW', 'imagenet'
            (matched case-insensitively).
        data_dir: Root directory of the dataset; the expected layout differs
            per dataset (see the branches below).
        batch_size: Per-process batch size.
        rank: Rank of this process; DistributedSampler uses it to select a shard.
        world_size: Total number of processes sharing the data.
        img_size: Side length images are resized to.
        num_workers: DataLoader worker processes per loader (default 4).

    Returns:
        (trainloader, valloader, train_sampler, num_classes). train_sampler is
        returned so the caller can call ``set_epoch`` on it each epoch.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    name = dataset_name.lower()
    supported = ('tiny-imagenet', 'cifar10', 'vww', 'imagenet')
    # Validate first so a bad name fails before any transforms/datasets are built.
    if name not in supported:
        raise ValueError(
            "Unsupported dataset. Choose from 'tiny-imagenet', 'cifar10', 'VWW' or 'imagenet'.")

    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    if name == 'tiny-imagenet':
        trainset = TinyImageNetDataset(data_dir, "train", transform_train)
        valset = TinyImageNetDataset(data_dir, 'val', transform_val)
        num_classes = 200
    elif name == 'cifar10':
        trainset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 10
    elif name == 'vww':
        # Visual Wake Words: COCO 2014 images plus VWW annotation files under data_dir.
        image_root = os.path.join(data_dir, 'coco2014/all2014')
        trainset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_train.json'),
            transform=transform_train)
        valset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_val.json'),
            transform=transform_val)
        num_classes = 2
    else:  # 'imagenet': standard ImageFolder layout with train/ and val/ subdirectories
        trainset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform_train)
        valset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform_val)
        num_classes = 1000

    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=num_workers, pin_memory=True)
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=num_workers, pin_memory=True)
    return trainloader, valloader, train_sampler, num_classes


def train(rank, world_size, args):
    """Train the model on one process/GPU of a DistributedDataParallel job.

    Intended to be launched once per rank via ``mp.spawn`` (see ``main``).
    It does the following:

    * Joins an NCCL process group.
    * Builds this rank's data loaders and model.
    * Trains for ``args.epochs`` epochs with AdamW and a per-epoch cosine LR schedule.
    * Evaluates every 5 epochs and once at the end.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
        args: Parsed CLI namespace. The fields read are dataset, data_dir,
            batch_size, img_size, parallel, lr, epochs and warmup_step.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        trainloader, valloader, train_sampler, num_classes = get_dataloaders(
            args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size)

        # Load model
        if args.parallel:
            model = get_parallel_vit(num_classes=num_classes).to(device)
        else:
            # NOTE(review): get_vit_small is not among this file's visible imports
            # (get_vit_base and get_vit_tiny are). This branch raises NameError unless
            # it is defined elsewhere, so confirm the intended builder. Whatever is
            # used must accept forward(imgs, lambda_off), since it is called that way
            # below.
            model = get_vit_small(num_classes=num_classes, pretrained=False).to(device)

        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
        # scheduler.step() is called once per epoch, so T_max is measured in epochs.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
        # The lambda ramp is indexed by global optimizer step (computed below).
        lambda_scheduler = LambdaScheduler(warmup_steps=args.warmup_step, mode='linear')

        steps_per_epoch = len(trainloader)
        for epoch in range(args.epochs):
            # Reshuffle the distributed shards differently each epoch.
            train_sampler.set_epoch(epoch)
            model.train()
            running_loss, total, correct = 0.0, 0, 0

            for step, (imgs, labels) in enumerate(trainloader):
                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                optimizer.zero_grad()
                # Global step = completed epochs * batches per epoch + current batch.
                global_step = epoch * steps_per_epoch + step
                lambda_off = lambda_scheduler.get_lambda(global_step)
                # Call the module (not .forward) so nn.Module hooks run as usual.
                outputs = model(imgs, lambda_off)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # criterion uses the default 'mean' reduction, so scale back to a sum.
                running_loss += loss.item() * imgs.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                if step % 20 == 0 and rank == 0:
                    print(f"[Step {step}/{steps_per_epoch}] Train Loss: {loss.item()} Lambda: {lambda_off}")

            # Sum epoch statistics over all ranks so the log covers the whole
            # training set, not only rank 0's shard. Every rank must participate.
            stats = torch.tensor([running_loss, correct, total], dtype=torch.float64, device=device)
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            if rank == 0:
                sum_loss, sum_correct, sum_total = stats.tolist()
                if sum_total > 0:
                    train_acc = sum_correct / sum_total
                    print(f"[Epoch {epoch+1}] Train Loss: {sum_loss/sum_total:.4f}, Acc: {train_acc:.4f}")
            scheduler.step()
            # The final evaluation below covers the last epoch, so skip it here.
            if epoch % 5 == 0 and epoch != args.epochs - 1:
                evaluate(model, valloader, device, rank)
        evaluate(model, valloader, device, rank)
    finally:
        # Tear down the process group even if training raised.
        dist.destroy_process_group()


def evaluate(model, dataloader, device, rank):
    """Measure top-1 accuracy of ``model`` over ``dataloader`` and print it on rank 0.

    The model is called as ``model(imgs, 1)``. The second argument plays the
    same role as ``lambda_off`` in training; 1 is presumably the fully-ramped
    value (TODO confirm).

    When a ``torch.distributed`` process group is initialized, the correct and
    total counts are summed across ranks. The printed accuracy therefore
    covers every rank's shard, and all ranks must call this function. With
    ``DistributedSampler`` (``drop_last=False``), a few samples may be repeated
    to pad the shards evenly, which can shift the result slightly.

    The model is left in eval mode.

    Args:
        model: Module to evaluate.
        dataloader: Iterable of ``(imgs, labels)`` batches.
        device: Device the batches are moved to.
        rank: Process rank; only rank 0 prints.

    Returns:
        Accuracy as a float, or 0.0 if no samples were seen.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(imgs, 1)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Combine per-rank counts; otherwise rank 0 would report only its own shard.
    if dist.is_available() and dist.is_initialized():
        counts = torch.tensor([correct, total], dtype=torch.int64, device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        correct, total = counts.tolist()

    acc = correct / total if total else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """Parse command-line arguments and spawn one training process per GPU.

    MASTER_ADDR / MASTER_PORT default to localhost:12355 for single-node runs.
    Values already present in the environment (e.g. set by a launcher) are
    kept rather than overwritten.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='tiny-imagenet', choices=['tiny-imagenet', 'cifar10', 'VWW', 'imagenet'])
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=5e-3)
    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    parser.add_argument('--warmup_step', type=int, default=25000)
    args = parser.parse_args()

    # The default world size is the visible CUDA device count, which can be 0;
    # fail loudly instead of spawning no workers.
    if args.world_size < 1:
        parser.error("--world_size must be >= 1 (no CUDA devices detected?)")

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # mp.spawn passes the process index as the first argument: train(rank, world_size, args).
    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


if __name__ == "__main__":
    # Start training only when this file runs as a script, not when it is imported.
    main()
