# train.py
import argparse
import math
import os

import pyvww
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

from loader import TinyImageNetDataset
from models.vit import get_vit_tiny
from models.vit_base import get_vit_base
from models.vit_multi_branch import get_multibranch_vit



def load_checkpoint(path, map_location="cpu"):
    """Load a checkpoint file and normalize it to a dict with a "model" entry.

    A load with torch's default (``weights_only`` restricted unpickler on recent
    torch versions) is tried first. If that fails for any reason other than the
    file being unreadable, the file is re-loaded with ``weights_only=False``.

    SECURITY: ``weights_only=False`` runs the full pickle machinery and can
    execute arbitrary code embedded in the file -- only load checkpoints from
    trusted sources.

    Args:
        path: Path to the checkpoint file.
        map_location: Forwarded to ``torch.load``.

    Returns:
        A dict with at least a "model" key. A bare state dict is wrapped as
        ``{"model": state_dict}``; a pickled ``nn.Module`` (top-level or stored
        under "model") is replaced by its ``state_dict()``.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError); this
            is not retried because an unsafe reload would fail the same way.
    """
    try:
        checkpoint = torch.load(path, map_location=map_location)
    except OSError:
        raise
    except Exception as e:
        print(f"[WARN] Safe load failed ({e}), retrying with weights_only=False")
        checkpoint = torch.load(path, map_location=map_location, weights_only=False)

    # A whole pickled model: keep only its weights so callers can use .items().
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()

    # A bare state_dict (typical for .pth files): wrap it under "model".
    if isinstance(checkpoint, dict) and "model" not in checkpoint:
        checkpoint = {"model": checkpoint}

    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("model"), nn.Module):
        checkpoint["model"] = checkpoint["model"].state_dict()

    return checkpoint


class WarmupCosine:
    """Learning-rate multiplier: linear warm-up, then cosine decay to zero.

    Intended as the ``lr_lambda`` of ``torch.optim.lr_scheduler.LambdaLR``,
    which calls it with the scheduler's epoch counter and scales each group's
    initial lr by the returned factor.

    * ``epoch < warmup_epochs``: ``(epoch + 1) / warmup_epochs`` (reaches 1.0
      on the last warm-up epoch).
    * afterwards: ``0.5 * (1 + cos(pi * t))`` with
      ``t = (epoch - warmup) / (max_epochs - warmup)``, i.e. 1.0 right after
      warm-up and 0.0 at ``max_epochs``.
    * ``epoch > max_epochs``: ``t`` is clamped to 1, so the factor stays at 0.0
      instead of climbing back up along the cosine.
    """

    def __init__(self, warmup_epochs: int, max_epochs: int):
        self.warmup = max(0, int(warmup_epochs))
        self.max_epochs = max_epochs

    def __call__(self, epoch: int) -> float:
        if self.warmup > 0 and epoch < self.warmup:
            return float(epoch + 1) / float(self.warmup)
        # max(1, ...) guards against max_epochs <= warmup (zero/negative span).
        span = max(1, self.max_epochs - self.warmup)
        t = min(1.0, (epoch - self.warmup) / span)
        return 0.5 * (1.0 + math.cos(math.pi * t))


def get_layerwise_lr_decay_params(model, lr_base=1e-4, decay_rate=0.8, weight_decay=0.01):
    """Build optimizer param groups with layer-wise learning-rate decay for a ViT.

    Expects ``model`` to expose ``blocks`` (a sequence of transformer blocks),
    ``patch_embed``, ``norm`` and ``head``. With ``L = len(model.blocks)``:

    * block ``i`` (0-based) gets ``lr_base * decay_rate ** (L - 1 - i)``;
    * ``patch_embed`` gets the smallest rate, ``lr_base * decay_rate ** L``;
    * ``norm`` + ``head`` get ``lr_base``.

    Every parameter of ``model`` lands in exactly one group. Top-level
    ``cls_token`` / ``pos_embed`` join the patch-embedding group. Any other
    parameter not covered above joins the ``norm`` + ``head`` group. Without
    this, such parameters would be silently left out of the optimizer.
    A parameter shared between submodules is listed only once, because torch
    optimizers reject a tensor that appears in two groups.

    Args:
        model: The unwrapped (non-DDP) model.
        lr_base: Learning rate of the last block and of norm/head.
        decay_rate: Multiplicative lr factor per layer of depth.
        weight_decay: Weight decay applied to every group.

    Returns:
        ``len(model.blocks) + 2`` dicts with "params", "lr" and "weight_decay",
        ordered: blocks (first to last), patch embedding, norm + head.
    """
    embed_names = ("cls_token", "pos_embed")
    num_layers = len(model.blocks)
    assigned = set()

    def _unassigned(params):
        # Keep only parameters not already placed in a group (identity-based).
        fresh = []
        for p in params:
            if id(p) not in assigned:
                assigned.add(id(p))
                fresh.append(p)
        return fresh

    param_groups = []
    for i, layer in enumerate(model.blocks):
        param_groups.append({
            "params": _unassigned(layer.parameters()),
            "lr": lr_base * (decay_rate ** (num_layers - 1 - i)),
            "weight_decay": weight_decay,
        })

    embed_params = _unassigned(model.patch_embed.parameters())
    head_params = _unassigned(list(model.norm.parameters()) + list(model.head.parameters()))

    # Sweep up everything else (cls_token, pos_embed, extra modules) so no
    # parameter is silently excluded from training.
    for name, p in model.named_parameters():
        if id(p) in assigned:
            continue
        assigned.add(id(p))
        if name.split(".", 1)[0] in embed_names:
            embed_params.append(p)
        else:
            head_params.append(p)

    # Patch embedding (plus cls/pos tokens): smallest lr.
    param_groups.append({
        "params": embed_params,
        "lr": lr_base * (decay_rate ** num_layers),
        "weight_decay": weight_decay,
    })

    # Final norm + head (plus any remaining parameters): full lr.
    param_groups.append({
        "params": head_params,
        "lr": lr_base,
        "weight_decay": weight_decay,
    })

    return param_groups


def _build_datasets(dataset_name, data_dir, transform_train, transform_val):
    """Construct ``(trainset, valset, num_classes)`` for a lower-cased dataset name.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported keys.
    """
    if dataset_name == 'tiny-imagenet':
        trainset = TinyImageNetDataset(data_dir, "train", transform_train)
        valset = TinyImageNetDataset(data_dir, "val", transform_val)
        num_classes = 200

    elif dataset_name == 'cifar10':
        trainset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 10

    elif dataset_name == 'cifar100':
        trainset = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 100

    elif dataset_name == 'vww':
        trainset = pyvww.pytorch.VisualWakeWordsClassification(
            root=os.path.join(data_dir, 'coco2014/all2014'),
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_train.json'),
            transform=transform_train)
        valset = pyvww.pytorch.VisualWakeWordsClassification(
            root=os.path.join(data_dir, 'coco2014/all2014'),
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_val.json'),
            transform=transform_val)
        num_classes = 2

    elif dataset_name == 'imagenet':
        train_dir = os.path.join(data_dir, 'train')
        val_dir = os.path.join(data_dir, 'val')
        trainset = datasets.ImageFolder(train_dir, transform=transform_train)
        valset = datasets.ImageFolder(val_dir, transform=transform_val)
        num_classes = 1000

    elif dataset_name == 'flowers':
        trainset = datasets.Flowers102(data_dir, split='train', download=True, transform=transform_train)
        valset = datasets.Flowers102(data_dir, split='test', download=True, transform=transform_val)
        num_classes = 102

    elif dataset_name == 'pets':
        trainset = datasets.OxfordIIITPet(data_dir, split='trainval', download=True, transform=transform_train)
        valset = datasets.OxfordIIITPet(data_dir, split='test', download=True, transform=transform_val)
        num_classes = 37

    elif dataset_name == 'aircraft':
        trainset = datasets.FGVCAircraft(data_dir, split='train', download=True, transform=transform_train)
        valset = datasets.FGVCAircraft(data_dir, split='test', download=True, transform=transform_val)
        num_classes = 100

    elif dataset_name == 'cars':
        trainset = datasets.StanfordCars(data_dir, split='train', transform=transform_train)
        valset = datasets.StanfordCars(data_dir, split='test', transform=transform_val)
        num_classes = 196

    elif dataset_name == 'inat18':
        # iNaturalist 2018 has many classes (8,142) and very large files.
        # NOTE(review): train and val are built with identical arguments (only
        # the transform differs), so "validation" re-uses the training images.
        # TODO: build a held-out split from the official val annotations.
        trainset = datasets.INaturalist(data_dir, version='2018', transform=transform_train)
        valset = datasets.INaturalist(data_dir, version='2018', transform=transform_val)
        num_classes = 8142

    else:
        raise ValueError(f"Unsupported dataset {dataset_name}")

    return trainset, valset, num_classes


def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size, num_workers=8):
    """Build distributed train/val DataLoaders for ``dataset_name``.

    Training images get RandomResizedCrop(img_size) + random horizontal flip.
    Validation images are resized to ``int(img_size * 1.14)`` and then
    center-cropped. Both are normalized with the ImageNet mean/std.

    When a process group is initialized, rank 0 constructs (and, where
    ``download=True``, downloads) the datasets first while the other ranks wait
    at a barrier. This keeps several processes from writing the same files
    into ``data_dir`` concurrently.

    Args:
        dataset_name: Case-insensitive key: tiny-imagenet, cifar10, cifar100,
            vww, imagenet, flowers, pets, aircraft, cars, inat18.
        data_dir: Dataset root directory.
        batch_size: Per-process batch size.
        rank: This process's rank (used by DistributedSampler).
        world_size: Number of processes (used by DistributedSampler).
        img_size: Output crop size.
        num_workers: DataLoader worker processes per loader (default 8).

    Returns:
        ``(trainloader, valloader, train_sampler, num_classes)``. Call
        ``train_sampler.set_epoch(epoch)`` every epoch so shuffling differs.

    Raises:
        ValueError: Unsupported dataset name.
    """
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize(int(img_size * 1.14)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    dataset_name = dataset_name.lower()

    distributed = dist.is_available() and dist.is_initialized()
    if distributed and rank != 0:
        dist.barrier()  # wait until rank 0 has prepared the data
    try:
        trainset, valset, num_classes = _build_datasets(
            dataset_name, data_dir, transform_train, transform_val)
    finally:
        # Release the waiting ranks even if rank 0 failed, so nobody deadlocks.
        if distributed and rank == 0:
            dist.barrier()

    # Distributed samplers. The val sampler pads the dataset to a multiple of
    # world_size (drop_last=False), so a few val samples may be seen twice.
    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=num_workers, pin_memory=True)
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=num_workers, pin_memory=True)

    return trainloader, valloader, train_sampler, num_classes


def _build_model(args, num_classes, device):
    """Create the classifier selected by ``args.parallel`` and move it to ``device``.

    ``--parallel`` selects the multi-branch ViT; otherwise ViT-Tiny with
    ``pretrained=False``. Both are constructed with ``dropout=0``.
    """
    if args.parallel:
        model = get_multibranch_vit(num_classes=num_classes, dropout=0)
    else:
        model = get_vit_tiny(num_classes=num_classes, pretrained=False, dropout=0)
    return model.to(device)


def _extract_encoder_state_dict(state_dict):
    """Return only the encoder entries of ``state_dict``, with wrapper prefixes removed.

    Accepted key layouts, tried in order: ``module.model.<name>`` (the layout
    this script was originally written for), ``model.<name>``,
    ``module.<name>``, and bare ``<name>`` (a plain ViT state dict, such as the
    ones ``load_checkpoint`` wraps under "model"). An entry is kept only when,
    after stripping, its first component is patch_embed, blocks, cls_token,
    pos_embed or norm. Everything else is dropped, e.g. decoder, tch_model,
    distill_criterion, adapter or head.
    """
    encoder_names = ("patch_embed", "blocks", "cls_token", "pos_embed", "norm")
    prefixes = ("module.model.", "model.", "module.")
    encoder_state_dict = {}
    for key, value in state_dict.items():
        stripped = key
        for prefix in prefixes:
            if key.startswith(prefix):
                stripped = key[len(prefix):]
                break
        if stripped.split(".", 1)[0] in encoder_names:
            encoder_state_dict[stripped] = value
    return encoder_state_dict


def _load_pretrained_encoder(model, path, rank):
    """Load encoder weights from ``path`` into ``model`` on every rank.

    Only rank 0 reads the file. The checkpoint object is then shared with all
    ranks through ``dist.broadcast_object_list``, so every rank must call this.

    Raises:
        RuntimeError: If no checkpoint object arrives from rank 0.
    """
    checkpoint = None
    if rank == 0:
        checkpoint = load_checkpoint(path, map_location="cpu")

    # Every rank needs the list; rank 0's entry is the payload that gets broadcast.
    obj_list = [checkpoint]
    dist.broadcast_object_list(obj_list, src=0)
    checkpoint = obj_list[0]

    if checkpoint is None:
        raise RuntimeError("Failed to load checkpoint")

    encoder_state_dict = _extract_encoder_state_dict(checkpoint["model"])
    missing, unexpected = model.load_state_dict(encoder_state_dict, strict=False)

    if rank == 0:
        start_epoch = checkpoint.get("start_epoch", 0)
        print(f"=> Loaded checkpoint '{path}' (only encoder) at epoch '{start_epoch}'")
        if missing:
            print("Missing keys:", missing)
        if unexpected:
            print("Unexpected keys:", unexpected)


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, rank):
    """Run one mixed-precision (autocast + GradScaler) training epoch.

    Prints the current batch loss every 20 steps on rank 0.

    Returns:
        ``(batch-size-weighted loss sum, #correct, #samples)`` for this rank only.
    """
    model.train()
    running_loss, total, correct = 0.0, 0, 0
    num_steps = len(loader)

    for step, (imgs, labels) in enumerate(loader):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast():
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * imgs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        if step % 20 == 0 and rank == 0:
            print(f"[Step {step}/{num_steps}] Train Loss: {loss.item()}")

    return running_loss, correct, total


def train(rank, world_size, args):
    """Per-process DDP fine-tuning entry point (launched by ``mp.spawn``).

    Args:
        rank: Process rank; also used as the CUDA device index.
        world_size: Total number of processes in the NCCL group.
        args: Parsed command-line namespace from ``main``.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        trainloader, valloader, train_sampler, num_classes = get_dataloaders(
            args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size)

        model = _build_model(args, num_classes, device)

        if args.resume:
            if os.path.isfile(args.resume):
                _load_pretrained_encoder(model, args.resume, rank)
            elif rank == 0:
                print(f"[WARN] Checkpoint '{args.resume}' not found; training without pretrained weights")

        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

        criterion = nn.CrossEntropyLoss().to(device)

        # Layer-wise lr decay reads .blocks/.patch_embed/..., so pass the unwrapped model.
        param_groups = get_layerwise_lr_decay_params(
            model.module,
            lr_base=args.lr,
            decay_rate=args.decay_rate,
            weight_decay=0,
        )
        optimizer = optim.SGD(param_groups, momentum=0.9, nesterov=False)
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=WarmupCosine(args.warmup_epochs, args.epochs)
        )
        scaler = GradScaler()

        last_epoch = args.epochs - 1
        for epoch in range(args.epochs):
            train_sampler.set_epoch(epoch)
            running_loss, correct, total = _train_one_epoch(
                model, trainloader, criterion, optimizer, scaler, device, rank)

            if rank == 0:
                seen = max(total, 1)  # avoid ZeroDivisionError on an empty shard
                train_acc = correct / seen
                print(f"[Epoch {epoch+1}] Train Loss: {running_loss/seen:.4f}, Acc: {train_acc:.4f}")
            scheduler.step()
            # Periodic eval; skipped on the last epoch because the final eval below covers it.
            if epoch % 5 == 0 and epoch != last_epoch:
                evaluate(model, valloader, device, rank)
        evaluate(model, valloader, device, rank)
    finally:
        dist.destroy_process_group()


def evaluate(model, dataloader, device, rank):
    """Compute top-1 accuracy of ``model`` over ``dataloader`` and print it on rank 0.

    When a process group is initialized, the per-rank correct/total counts are
    summed with ``all_reduce``. The reported accuracy therefore covers the whole
    validation set rather than rank 0's shard alone, and every rank must call
    this function. Because DistributedSampler pads its shards, a few samples
    may be counted twice.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 if no samples were seen.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if dist.is_available() and dist.is_initialized():
        counts = torch.tensor([correct, total], dtype=torch.long, device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        correct, total = counts.tolist()

    acc = correct / total if total else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """Parse command-line options and spawn one DDP training process per GPU."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='tiny-imagenet')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--warmup_epochs', type=int, default=5, help="linear warmup epochs")
    parser.add_argument('--decay_rate', type=float, default=0.75)

    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    parser.add_argument('--resume', type=str, default='', help="Path to checkpoint to resume")
    parser.add_argument('--master_addr', type=str, default='localhost', help="Master node address")
    parser.add_argument('--master_port', type=str, default='12355', help="Master node port")

    args = parser.parse_args()

    # With world_size 0 (e.g. no visible CUDA devices) mp.spawn would launch no
    # training process at all; fail loudly instead.
    if args.world_size < 1:
        parser.error(f"--world_size must be at least 1 (got {args.world_size}); are any CUDA devices visible?")

    # Rendezvous address read by dist.init_process_group in each worker.
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    # Each spawned process receives its rank as the first argument of train().
    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


if __name__ == '__main__':
    # Script entry point: parse CLI arguments and launch the DDP workers.
    main()
