# train.py
import os
import argparse
import pyvww
import time  # <<< 新增

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms import RandAugment
from torch.cuda.amp import autocast, GradScaler

from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data.transforms import RandomResizedCropAndInterpolation

from models.vit_base import get_vit_base
from models.vit import get_vit_tiny
from models.vit_multi_branch import get_multibranch_vit
from loader import TinyImageNetDataset


class WarmupCosine:
    """Learning-rate multiplier: linear warm-up, then cosine decay.

    Instances are callables meant to be passed as ``lr_lambda`` to
    ``torch.optim.lr_scheduler.LambdaLR``. For ``epoch < warmup_epochs`` the
    multiplier rises linearly from ``1 / warmup_epochs`` to ``1``; afterwards
    it follows half a cosine from ``1`` down to ``0`` at ``max_epochs``.

    Args:
        warmup_epochs: Number of warm-up epochs; negative values are treated
            as ``0`` (no warm-up).
        max_epochs: Epoch index at which the multiplier reaches ``0``.
    """

    def __init__(self, warmup_epochs: int, max_epochs: int):
        self.warmup = max(0, int(warmup_epochs))
        self.max_epochs = max_epochs

    def __call__(self, epoch: int) -> float:
        """Return the multiplier applied to the base learning rate at `epoch`."""
        # Function-scope import keeps the class self-contained.
        import math

        if epoch < self.warmup and self.warmup > 0:
            return float(epoch + 1) / float(self.warmup)

        # Guard against a zero-length decay phase (max_epochs <= warmup).
        decay_span = max(1, self.max_epochs - self.warmup)
        progress = (epoch - self.warmup) / decay_span
        # Clamp to [0, 1]: without this the cosine would swing back up and
        # raise the learning rate again for epochs beyond `max_epochs`.
        progress = min(1.0, max(0.0, progress))
        return 0.5 * (1.0 + math.cos(math.pi * progress))


def get_layerwise_lr_decay_params(model, lr_base=1e-4, decay_rate=0.8, weight_decay=0.01):
    """Build optimizer param groups with layer-wise learning-rate decay.

    Written for a ViT-style model exposing ``blocks``, ``patch_embed``,
    ``norm`` and ``head``. Groups are returned in this order:

    1. one group per entry of ``model.blocks``; block ``i`` gets
       ``lr_base * decay_rate ** (num_layers - 1 - i)``, so the last block
       trains at ``lr_base`` and earlier blocks at progressively smaller rates;
    2. the embedding group (``model.patch_embed`` plus every parameter not
       owned by any of the four sub-modules, e.g. a class token or positional
       embedding registered directly on the model) at
       ``lr_base * decay_rate ** num_layers``, the smallest rate;
    3. ``model.norm`` and ``model.head`` at ``lr_base``.

    Args:
        model: Unwrapped model (not the DDP wrapper).
        lr_base: Learning rate of the top-most layers.
        decay_rate: Multiplicative decay applied per layer going down.
        weight_decay: Weight decay stored in every group.

    Returns:
        A list of dicts with keys ``"params"``, ``"lr"`` and ``"weight_decay"``.
    """
    num_layers = len(model.blocks)

    # Materialize every generator: a list can be inspected and reused, whereas
    # a consumed ``parameters()`` generator silently yields nothing.
    block_params = [list(layer.parameters()) for layer in model.blocks]
    embed_params = list(model.patch_embed.parameters())
    top_params = list(model.norm.parameters()) + list(model.head.parameters())

    assigned = {id(p) for group in block_params for p in group}
    assigned.update(id(p) for p in embed_params)
    assigned.update(id(p) for p in top_params)

    # Parameters registered directly on the model would otherwise never reach
    # the optimizer and stay frozen at their initial values. They sit at the
    # input side of the network, so they share the embedding group's rate.
    leftover_params = [p for p in model.parameters() if id(p) not in assigned]

    param_groups = [
        {
            "params": params,
            "lr": lr_base * (decay_rate ** (num_layers - 1 - i)),
            "weight_decay": weight_decay,
        }
        for i, params in enumerate(block_params)
    ]

    # Embedding group (smallest lr).
    param_groups.append({
        "params": embed_params + leftover_params,
        "lr": lr_base * (decay_rate ** num_layers),
        "weight_decay": weight_decay,
    })

    # Final norm + classification head (largest lr).
    param_groups.append({
        "params": top_params,
        "lr": lr_base,
        "weight_decay": weight_decay,
    })

    return param_groups


def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size):
    """Create distributed train/validation loaders for the requested dataset.

    Args:
        dataset_name: One of ``'tiny-imagenet'``, ``'cifar10'``, ``'cifar100'``,
            ``'VWW'`` or ``'imagenet'`` (matched case-insensitively).
        data_dir: Root directory of the dataset on disk.
        batch_size: Per-process batch size.
        rank: Index of the current process, used to pick its data shard.
        world_size: Total number of processes sharing the data.
        img_size: Side length of the square crops fed to the model.

    Returns:
        ``(trainloader, valloader, train_sampler, num_classes)``. The train
        sampler is returned so the caller can call ``set_epoch`` on it.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        RandomResizedCropAndInterpolation(img_size),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=10, magnitude=5),  # torchvision implementation
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),  # color jitter = 0.3
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation preprocessing: resize to 1.14x the target size, then
    # center-crop to `img_size`.
    transform_val = transforms.Compose([
        transforms.Resize(int(img_size * 1.14)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    # Normalize once so every branch is matched case-insensitively.
    name = dataset_name.lower()

    if name == 'tiny-imagenet':
        trainset = TinyImageNetDataset(data_dir, "train", transform_train)
        valset = TinyImageNetDataset(data_dir, 'val', transform_val)
        num_classes = 200
    elif name == 'cifar10':
        trainset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 10
    elif name == 'cifar100':
        trainset = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 100
    elif name == 'vww':
        # Visual Wake Words: images under coco2014/all2014, annotations under
        # visualwakewords/annotations, both relative to `data_dir`.
        image_root = os.path.join(data_dir, 'coco2014/all2014')
        trainset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_train.json'),
            transform=transform_train)
        valset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_val.json'),
            transform=transform_val)
        num_classes = 2
    elif name == 'imagenet':
        # Expects the standard ImageFolder layout: <data_dir>/train and <data_dir>/val.
        trainset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform_train)
        valset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform_val)
        num_classes = 1000  # standard ImageNet-1k
    else:
        raise ValueError(
            f"Unsupported dataset '{dataset_name}'. Choose from 'tiny-imagenet', "
            "'cifar10', 'cifar100', 'VWW' or 'imagenet'."
        )

    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    # drop_last=True on the training loader: the training loop applies timm's
    # Mixup, which asserts an even batch size, so an odd-sized final batch
    # would abort training.
    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=10, pin_memory=True, drop_last=True)
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=10, pin_memory=True)
    return trainloader, valloader, train_sampler, num_classes


def _extract_encoder_state_dict(state_dict):
    """Keep only the encoder weights of a checkpoint and strip their wrapper prefix.

    Keys starting with ``module.model.`` followed by ``patch_embed``,
    ``blocks``, ``cls_token``, ``pos_embed`` or ``norm`` are kept with the
    ``module.model.`` prefix removed; every other key (decoder, teacher model,
    distillation criterion, adapters, ...) is dropped.
    """
    wrapper_prefix = "module.model."
    wanted = tuple(
        wrapper_prefix + name
        for name in ("patch_embed", "blocks", "cls_token", "pos_embed", "norm")
    )
    return {
        key[len(wrapper_prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(wanted)
    }


def train(rank, world_size, args):
    """Per-process training entry point (one process per GPU).

    Initializes the NCCL process group, builds the data loaders and the model,
    optionally initializes the encoder from ``args.resume``, then trains for
    ``args.epochs`` epochs with mixed precision, Mixup/CutMix and a
    warm-up + cosine learning-rate schedule. Validation runs every 20 epochs
    and once more at the end; rank 0 writes ``latest_ckpt.pth.tar`` every
    10 epochs.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
        args: Parsed command-line namespace. Reads ``dataset``, ``data_dir``,
            ``batch_size``, ``img_size``, ``parallel``, ``depth``, ``resume``,
            ``lr``, ``warmup_epochs``, ``epochs`` and ``save``.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    trainloader, valloader, train_sampler, num_classes = get_dataloaders(
        args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size)

    # Load model
    if args.parallel:
        model = get_multibranch_vit(num_classes=num_classes, dropout=0).to(device)
    else:
        model = get_vit_tiny(num_classes=num_classes, pretrained=False, dropout=0, depth=args.depth).to(device)

    if args.resume and os.path.isfile(args.resume):
        # Only rank 0 reads the file; the object is then broadcast to all ranks.
        checkpoint = None
        if dist.get_rank() == 0:
            checkpoint = torch.load(args.resume, map_location="cpu")

        # The list must exist on every rank for broadcast_object_list to fill it.
        obj_list = [checkpoint]
        dist.broadcast_object_list(obj_list, src=0)
        checkpoint = obj_list[0]

        if checkpoint is None:
            raise RuntimeError("Failed to load checkpoint")

        # Load encoder weights only; the classification head stays freshly
        # initialized, hence strict=False.
        encoder_state_dict = _extract_encoder_state_dict(checkpoint["model"])
        missing, unexpected = model.load_state_dict(encoder_state_dict, strict=False)

        if rank == 0:
            start_epoch = checkpoint.get("start_epoch", 0)
            print(f"=> Loaded checkpoint '{args.resume}' (only encoder) at epoch '{start_epoch}'")
            if missing:
                print("Missing keys:", missing)
            if unexpected:
                print("Unexpected keys:", unexpected)
    elif args.resume and rank == 0:
        # Make a mistyped path visible instead of silently training from scratch.
        print(f"=> No checkpoint found at '{args.resume}'; training from scratch")

    # # ====== Linear Probing for ImageNet ======
    # if args.dataset.lower() == "imagenet":
    #     if rank == 0:
    #         print(">>> Using Linear Probing (freeze backbone, train classifier only) <<<")
    #     for name, param in model.named_parameters():
    #         if "head" not in name:  # train the classification head only
    #             param.requires_grad = False
    #
    #     # wrap only the classification head in DDP
    #     model.head = nn.parallel.DistributedDataParallel(
    #         model.head, device_ids=[rank]
    #     )

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    param_groups = get_layerwise_lr_decay_params(
        model.module,  # pass the unwrapped model, not the DDP wrapper
        lr_base=args.lr,
        decay_rate=0.85,
        weight_decay=0.05
    )

    optimizer = optim.AdamW(
        param_groups,
        betas=(0.9, 0.999),  # per the original author's note: "Table A1"
        weight_decay=0.05
    )
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=WarmupCosine(args.warmup_epochs, args.epochs)
    )

    # Mixup/CutMix produce soft (mixed) targets, so a soft-target loss is
    # required; label smoothing is disabled (0.0) in the Mixup config.
    criterion = SoftTargetCrossEntropy().to(device)
    mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, prob=1.0, label_smoothing=0.0, num_classes=num_classes)
    scaler = GradScaler()

    # Wall-clock start of the whole run (only reported by rank 0).
    total_start = time.time()

    for epoch in range(args.epochs):
        epoch_start = time.time()
        # Reseed the sampler so each epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        model.train()
        running_loss, total, correct = 0.0, 0, 0
        for step, (imgs, labels) in enumerate(trainloader):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Mix images and turn the integer labels into soft targets.
            imgs, labels = mixup_fn(imgs, labels)

            optimizer.zero_grad()
            with autocast():
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * imgs.size(0)
            total += labels.size(0)
            _, predicted = outputs.max(1)
            # With soft targets, score against the dominant class of the mix.
            acc_targets = labels.argmax(dim=1) if labels.ndim == 2 else labels
            correct += predicted.eq(acc_targets).sum().item()
            if step % 20 == 0 and rank == 0:
                print(f"[Step {step}/{len(trainloader)}] Train Loss: {loss.item()}")

        if rank == 0:
            epoch_time = time.time() - epoch_start
            train_acc = correct / total
            print(f"[Epoch {epoch + 1}] Train Loss: {running_loss / total:.4f}, "
                  f"Acc: {train_acc:.4f}, "
                  f"Time: {epoch_time / 60:.2f} min")
        scheduler.step()
        if epoch % 20 == 0:
            evaluate(model, valloader, device, rank)

        if epoch % 10 == 0 and rank == 0:
            ckpt = {
                "start_epoch": epoch + 1,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict()
            }
            # An empty --save means "current directory"; os.makedirs('') would raise.
            if args.save:
                os.makedirs(args.save, exist_ok=True)
            filename = os.path.join(args.save, "latest_ckpt.pth.tar")
            torch.save(ckpt, filename)

    # Report total time once, after the last epoch (not once per epoch).
    if rank == 0:
        total_time = time.time() - total_start
        print(f"=== Training completed in {total_time / 60:.2f} min ({total_time / 3600:.2f} h) ===")

    evaluate(model, valloader, device, rank)
    dist.destroy_process_group()


def evaluate(model, dataloader, device, rank):
    """Compute top-1 accuracy of `model` over `dataloader`.

    When a ``torch.distributed`` process group is initialized, the per-process
    correct/total counts are summed across all processes, so every rank must
    call this function. The result is then the accuracy over all shards rather
    than over the calling rank's shard alone.

    Args:
        model: Module mapping an image batch to per-class scores.
        dataloader: Iterable yielding ``(imgs, labels)`` batches with integer
            class labels.
        device: Device the batches are moved to.
        rank: Process index; only rank 0 prints the result.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` if no samples were seen.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if dist.is_available() and dist.is_initialized():
        # Sum the counts over all ranks. NOTE(review): a DistributedSampler
        # with its default settings pads shards with repeated samples, so the
        # global total can slightly exceed the dataset size - confirm whether
        # that matters for reported numbers.
        counts = torch.tensor([correct, total], dtype=torch.long, device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        correct, total = counts.tolist()

    # Avoid ZeroDivisionError on an empty loader.
    acc = correct / total if total > 0 else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """Parse command-line arguments and spawn one training process per GPU.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT`` for ``torch.distributed`` and runs
    ``train`` in ``--world_size`` processes on this machine.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='tiny-imagenet', choices=['tiny-imagenet', 'cifar10', 'cifar100', 'VWW', 'imagenet'])
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--depth', type=int, default=12)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=4e-3)
    parser.add_argument('--warmup_epochs', type=int, default=5, help="linear warmup epochs")

    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    parser.add_argument('--resume', type=str, default='', help="Path to checkpoint to resume")
    parser.add_argument('--save', type=str, default='', help="Path to checkpoint to save")
    parser.add_argument('--master_addr', type=str, default='localhost', help="Master node address")
    parser.add_argument('--master_port', type=str, default='12355', help="Master node port")

    args = parser.parse_args()

    # The default is the number of visible CUDA devices, which may be 0.
    # Fail with a clear message rather than spawning zero workers.
    if args.world_size < 1:
        parser.error("--world_size must be >= 1 (no CUDA devices detected?)")

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    # mp.spawn passes the process index as the first argument (`rank`) of train.
    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


if __name__ == '__main__':
    main()
