import argparse
import math
import os
import random
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

# timm is required by this script for Mixup/CutMix and the matching losses (you already import timm in your model file)
from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

# -----------------------------------------------------------------------------
# Import your model. Put the code you shared into `model.py` in the same folder.
# It must expose a function `get_vit_tiny(num_classes, pretrained=False, dropout=0.1)`
# -----------------------------------------------------------------------------
from model import get_vit_tiny  # noqa: E402


# --------------------------- Utils -------------------------------------------
class AverageMeter:
    """Accumulate a weighted running total and expose its mean.

    ``sum`` holds the weighted total of the recorded values and ``cnt`` the
    total weight (number of samples) seen since the last reset.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.sum, self.cnt = 0.0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples."""
        weighted = val * n
        self.sum += weighted
        self.cnt += n

    @property
    def avg(self):
        """Mean of the recorded values; the divisor is never smaller than 1."""
        divisor = self.cnt if self.cnt > 1 else 1
        return self.sum / divisor


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every ``k`` in ``topk``.

    Args:
        output: score tensor ranked along dim 1 (one row per sample).
        target: tensor of class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        in the same order.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        # Never request more ranks than there are classes; a larger k then
        # simply counts every sample whose target is a valid class.
        maxk = min(max(topk), output.size(1))
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # Guard the divisor so an empty batch yields 0.0 instead of raising.
        scale = 100.0 / max(1, batch_size)
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(scale))
        return res


def is_main_process():
    """Return True when the ``RANK`` environment variable denotes rank 0.

    A missing ``RANK`` variable is treated as rank 0.
    """
    rank_text = os.environ.get("RANK", "0")
    return rank_text == "0"


def init_distributed_mode():
    """Initialise the NCCL process group from launcher environment variables.

    Returns:
        ``(rank, world_size, local_rank)`` when both ``RANK`` and
        ``WORLD_SIZE`` are set, otherwise ``(None, None, None)`` without
        touching the process group. ``LOCAL_RANK`` defaults to 0.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        print("Not using distributed mode")
        return None, None, None

    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))

    dist.init_process_group(backend="nccl")
    # Bind this process to its own GPU.
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank


def set_seed(seed=42):
    """Seed Python's ``random``, the torch CPU RNG and all CUDA RNGs."""
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


# Warmup + Cosine scheduler helper
class WarmupCosine:
    """Per-step learning-rate schedule: linear warmup, then cosine decay.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts); the
            ``"lr"`` entry of every group is overwritten on each step.
        warmup_epochs: number of epochs spent ramping linearly up to
            ``base_lr``.
        max_epochs: total number of epochs the schedule spans.
        base_lr: learning rate reached at the end of warmup.
        min_lr: learning rate reached at the end of the cosine decay.
        steps_per_epoch: optimizer steps per epoch, used to convert the
            epoch counts into step counts.
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, base_lr, min_lr=1e-6, steps_per_epoch=1):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.steps_per_epoch = steps_per_epoch
        self.total_steps = max(1, max_epochs * steps_per_epoch)
        self.warmup_steps = warmup_epochs * steps_per_epoch
        # Number of completed step() calls; may be overwritten to fast-forward.
        self.step_num = 0

    def step(self):
        """Advance one step, write the new LR to every param group, return it."""
        self.step_num += 1
        if self.step_num <= self.warmup_steps and self.warmup_steps > 0:
            lr = self.base_lr * self.step_num / float(self.warmup_steps)
        else:
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            progress = (self.step_num - self.warmup_steps) / decay_steps
            # Clamp so stepping past the planned horizon holds the LR at
            # min_lr instead of riding the cosine back up.
            progress = min(1.0, progress)
            lr = self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        for pg in self.optimizer.param_groups:
            pg["lr"] = lr
        return lr


# --------------------------- Training / Eval ---------------------------------

def build_dataloaders(data_dir, img_size, batch_size, workers, world_size, rank, mixup_alpha, cutmix_alpha):
    """Create the distributed train/val loaders and the optional Mixup callable.

    Datasets are read with ``ImageFolder`` from ``<data_dir>/train`` and
    ``<data_dir>/val``. Both loaders use a ``DistributedSampler`` built from
    ``world_size`` and ``rank``; ``batch_size`` is the per-process batch size.

    Returns:
        ``(train_loader, val_loader, mixup_fn, num_classes)`` where
        ``mixup_fn`` is ``None`` unless ``mixup_alpha`` or ``cutmix_alpha`` is
        positive, and ``num_classes`` is the class count of the train set.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # DeiT-style training augmentations.
    train_steps = [
        transforms.RandomResizedCrop(img_size, scale=(0.08, 1.0), interpolation=bicubic),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.25, value="random"),
    ]
    # Evaluation: resize to img_size / 0.875, then centre-crop to img_size.
    val_steps = [
        transforms.Resize(int(img_size / 0.875), interpolation=bicubic),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ]

    train_dataset = datasets.ImageFolder(
        os.path.join(data_dir, "train"), transform=transforms.Compose(train_steps)
    )
    val_dataset = datasets.ImageFolder(
        os.path.join(data_dir, "val"), transform=transforms.Compose(val_steps)
    )

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    # Settings common to both loaders.
    shared = dict(
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=True,
        persistent_workers=workers > 0,
    )
    train_loader = DataLoader(train_dataset, sampler=train_sampler, drop_last=True, **shared)
    val_loader = DataLoader(val_dataset, sampler=val_sampler, drop_last=False, **shared)

    num_classes = len(train_dataset.classes)
    wants_mixing = mixup_alpha > 0 or cutmix_alpha > 0
    mixup_fn = (
        Mixup(mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha, num_classes=num_classes)
        if wants_mixing
        else None
    )

    return train_loader, val_loader, mixup_fn, num_classes


def save_checkpoint(state, path):
    """Write ``state`` to ``path`` via ``torch.save``, on the main process only."""
    if not is_main_process():
        return
    torch.save(state, path)


def train_one_epoch(model, train_loader, criterion, optimizer, scaler, mixup_fn, device, epoch, scheduler):
    """Run one training pass over ``train_loader``.

    Args:
        model: network to train; switched to train mode here.
        train_loader: iterable yielding ``(images, targets)`` batches.
        criterion: loss callable applied to ``(outputs, targets)``.
        optimizer: optimizer stepped once per batch through ``scaler``.
        scaler: ``GradScaler``; its enabled state also decides whether the
            forward pass runs under autocast.
        mixup_fn: optional callable transforming ``(images, targets)``;
            skipped when ``None``.
        device: device the batches are moved to.
        epoch: epoch index, used only for log output.
        scheduler: object whose ``step()`` is called once per batch and
            returns the current learning rate.

    Returns:
        ``(mean_loss, mean_top1)`` over the epoch. ``mean_top1`` stays 0.0
        when the criterion is a soft-target loss, because accuracy is only
        tracked for hard-label criteria.
    """
    model.train()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    # Autocast follows the scaler: a disabled GradScaler (--no-amp) must also
    # mean a full-precision forward pass, not just unscaled gradients.
    use_amp = scaler.is_enabled()
    # Track accuracy only when not using soft targets.
    track_acc = isinstance(criterion, (nn.CrossEntropyLoss, LabelSmoothingCrossEntropy))
    num_steps = len(train_loader)

    for i, (images, targets) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            images, targets = mixup_fn(images, targets)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, targets)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        current_lr = scheduler.step()

        if track_acc:
            # Targets with the same shape as the outputs are treated as
            # per-class scores and reduced to indices first.
            hard_targets = targets.argmax(dim=1) if outputs.shape == targets.shape else targets
            acc1, = accuracy(outputs, hard_targets, topk=(1,))
            top1_meter.update(acc1.item(), images.size(0))

        loss_meter.update(loss.item(), images.size(0))

        if is_main_process() and (i % 50 == 0):
            print(f"Epoch[{epoch}] Step[{i}/{num_steps}] loss={loss_meter.avg:.4f} lr={current_lr:.6f}")

    return loss_meter.avg, top1_meter.avg


def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return ``(loss, top1, top5)``.

    Each process evaluates only the batches its own loader yields. When a
    process group is initialised, the per-process weighted sums and sample
    counts are summed across all processes, so every process returns the same
    global averages rather than the metrics of its own shard. Because this
    issues a collective, every process must call this function.

    Args:
        model: network to evaluate; switched to eval mode here.
        val_loader: iterable yielding ``(images, targets)`` batches.
        device: device the batches (and the reduction buffer) live on.

    Returns:
        Tuple of mean cross-entropy loss, top-1 accuracy and top-5 accuracy
        (accuracies in percent).
    """
    model.eval()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            # NOTE(review): autocast is unconditional here; this function has
            # no way to learn whether mixed precision was disabled.
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            loss_meter.update(loss.item(), images.size(0))
            top1_meter.update(acc1.item(), images.size(0))
            top5_meter.update(acc5.item(), images.size(0))

    if not (dist.is_available() and dist.is_initialized()):
        return loss_meter.avg, top1_meter.avg, top5_meter.avg

    # Sum weighted totals and sample counts over all processes.
    # NOTE(review): a DistributedSampler with drop_last=False may pad shards
    # with repeated samples, which would be counted twice here - confirm.
    totals = torch.tensor(
        [loss_meter.sum, top1_meter.sum, top5_meter.sum, float(loss_meter.cnt)],
        dtype=torch.float64,
        device=device,
    )
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    loss_sum, top1_sum, top5_sum, count = totals.tolist()
    count = max(1.0, count)
    return loss_sum / count, top1_sum / count, top5_sum / count


# ------------------------------ Main -----------------------------------------

def main():
    """Parse the command line and run distributed training.

    Each epoch trains, validates, and (on the main process) writes
    ``checkpoint_latest.pth`` plus ``checkpoint_best.pth`` whenever top-1
    validation accuracy improves, both inside ``--output``.

    Raises:
        RuntimeError: if no distributed process group was initialised, i.e.
            the script was not started by a launcher that sets ``RANK`` and
            ``WORLD_SIZE``.
    """
    parser = argparse.ArgumentParser(description="DeiT-style DDP training on ImageNet for provided ViT model")
    parser.add_argument('--data', type=str, required=True, help='path to ImageNet root (containing train/ and val/)')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=128, help='per-GPU batch size')
    parser.add_argument('--img-size', type=int, default=224)
    parser.add_argument('--lr', type=float, default=5e-4, help='base LR for total batch=1024 (scaled linearly)')
    parser.add_argument('--min-lr', type=float, default=1e-6)
    parser.add_argument('--weight-decay', type=float, default=0.05)
    parser.add_argument('--warmup-epochs', type=int, default=5)
    parser.add_argument('--workers', type=int, default=8)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--label-smoothing', type=float, default=0.1)
    parser.add_argument('--mixup-alpha', type=float, default=0.8)
    parser.add_argument('--cutmix-alpha', type=float, default=1.0)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--output', type=str, default='./outputs')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--no-amp', action='store_true')

    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    rank, world_size, local_rank = init_distributed_mode()
    # Explicit raise rather than assert: asserts are stripped under `python -O`
    # and the code below cannot work with rank/local_rank set to None.
    if not dist.is_initialized():
        raise RuntimeError('Please launch with torchrun for DDP.')

    set_seed(args.seed + rank)
    cudnn.benchmark = True

    device = torch.device(f'cuda:{local_rank}')

    # ---------------- Data ----------------
    # Built before the model so the classifier head can match the dataset's
    # class count (the same count Mixup uses for its soft targets).
    train_loader, val_loader, mixup_fn, num_classes = build_dataloaders(
        args.data, args.img_size, args.batch_size, args.workers, world_size, rank, args.mixup_alpha, args.cutmix_alpha
    )

    # ---------------- Model ----------------
    model = get_vit_tiny(num_classes=num_classes, pretrained=False, dropout=args.dropout)
    model.cuda(local_rank)

    # ---------------- Optimizer ----------------
    # Linear LR scaling by global batch size (DeiT recipe).
    effective_batch = args.batch_size * world_size
    base_lr = args.lr * effective_batch / 1024.0

    # Weight decay only on tensors with >= 2 dims; 1-D tensors (biases, norms)
    # are left undecayed.
    trainable = [p for _, p in model.named_parameters() if p.requires_grad]
    param_groups = [
        {"params": [p for p in trainable if p.ndim >= 2], "weight_decay": args.weight_decay},
        {"params": [p for p in trainable if p.ndim < 2], "weight_decay": 0.0},
    ]
    optimizer = optim.AdamW(param_groups, lr=base_lr, betas=(0.9, 0.999))

    # ---------------- Loss ----------------
    # NOTE: --label-smoothing only takes effect when mixup and cutmix are both
    # disabled; with either enabled the soft-target loss is used instead.
    if mixup_fn is not None:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)

    # Wrapping in DDP after the optimizer was created is fine: the optimizer
    # holds references to the same parameter tensors.
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    # argparse stores `--no-amp` as `args.no_amp`; `args.no-amp` would be
    # parsed as the subtraction `args.no - amp`.
    scaler = torch.cuda.amp.GradScaler(enabled=not args.no_amp)

    steps_per_epoch = len(train_loader)
    scheduler = WarmupCosine(optimizer, args.warmup_epochs, args.epochs, base_lr=base_lr, min_lr=args.min_lr, steps_per_epoch=steps_per_epoch)

    start_epoch = 0
    best_acc1 = 0.0

    # ---------------- Resume ----------------
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location='cpu')
            model.module.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scaler.load_state_dict(checkpoint['scaler'])
            start_epoch = checkpoint.get('epoch', 0)
            best_acc1 = checkpoint.get('best_acc1', 0.0)
            # Fast-forward the scheduler to where the checkpoint left off.
            scheduler.step_num = checkpoint.get('sched_step', start_epoch * steps_per_epoch)
            if is_main_process():
                print(f"Resumed from {args.resume} at epoch {start_epoch}")
        elif is_main_process():
            # Keep starting from scratch (restart scripts may rely on it), but
            # make the situation visible instead of ignoring it silently.
            print(f"WARNING: --resume path '{args.resume}' not found; training from scratch")

    if is_main_process():
        print(f"Starting training for {args.epochs} epochs; world_size={world_size}, per-GPU batch={args.batch_size}, LR(base)={base_lr:.6f}")

    for epoch in range(start_epoch, args.epochs):
        # Reseed the sampler so every epoch gets a different shuffle.
        train_loader.sampler.set_epoch(epoch)

        train_loss, train_acc1 = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, mixup_fn, device, epoch, scheduler
        )

        # Validation runs on every rank (synchronize BN not required; ViT has LN only).
        val_loss, val_acc1, val_acc5 = validate(model, val_loader, device)

        if is_main_process():
            print(f"Epoch {epoch:03d}: train_loss={train_loss:.4f}, train_acc1={train_acc1:.2f} | val_loss={val_loss:.4f}, val@1={val_acc1:.2f}, val@5={val_acc5:.2f}")

            is_best = val_acc1 > best_acc1
            best_acc1 = max(best_acc1, val_acc1)
            ckpt = {
                'epoch': epoch + 1,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'best_acc1': best_acc1,
                'sched_step': scheduler.step_num,
                'args': vars(args),
            }
            save_checkpoint(ckpt, os.path.join(args.output, 'checkpoint_latest.pth'))
            if is_best:
                save_checkpoint(ckpt, os.path.join(args.output, 'checkpoint_best.pth'))

    if is_main_process():
        print(f"Training complete. Best top-1: {best_acc1:.2f}")

    # Release the process group on normal completion.
    dist.destroy_process_group()


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
