# train.py
import argparse
import math
import os

import pyvww
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

from loader import TinyImageNetDataset
from models.vit import get_vit_tiny
from models.vit_base import get_vit_base
from models.vit_parallel_repa import LambdaScheduler, get_parallel_vit


class WarmupCosine:
    """Epoch-indexed LR multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    For the first ``warmup_epochs`` epochs the multiplier ramps linearly from
    ``1 / warmup_epochs`` up to 1. After that it follows a half cosine from 1
    down to 0, reaching 0 at ``max_epochs``. Past ``max_epochs`` the multiplier
    stays at 0 instead of following the periodic cosine back up.
    """

    def __init__(self, warmup_epochs: int, max_epochs: int):
        # Negative warm-up lengths are treated as "no warm-up".
        self.warmup = max(0, int(warmup_epochs))
        self.max_epochs = max_epochs

    def __call__(self, epoch: int) -> float:
        """Return the LR multiplier for the 0-based ``epoch`` passed by LambdaLR."""
        if self.warmup > 0 and epoch < self.warmup:
            # epoch 0 -> 1/warmup, epoch warmup-1 -> 1.0
            return float(epoch + 1) / float(self.warmup)
        # Length of the decay phase; at least 1 to avoid dividing by zero.
        span = max(1, self.max_epochs - self.warmup)
        # Clamp progress so cos() cannot climb back up once epoch > max_epochs.
        t = min(max((epoch - self.warmup) / span, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))


def get_layerwise_lr_decay_params(model, lr_base=1e-4, decay_rate=0.8, weight_decay=0.01):
    """Build optimizer param groups with layer-wise learning-rate decay.

    Targets the custom ViT used in this script. It keeps its transformer blocks in
    ``model.blocks.blocks`` and exposes ``patch_embed``, ``norm`` and ``head``.

    With ``L`` blocks, block ``i`` gets ``lr_base * decay_rate ** (L - 1 - i)``,
    so the last block trains at ``lr_base`` and earlier blocks progressively
    slower. ``patch_embed`` gets the smallest rate, ``lr_base * decay_rate ** L``.
    ``norm`` and ``head`` get ``lr_base``.

    Any remaining parameter of ``model`` not covered above is added to the
    ``patch_embed`` group so it is still optimized instead of silently never
    being stepped. An example would be a learnable class token or position
    embedding attached directly to the model.

    Args:
        model: the unwrapped (non-DDP) model.
        lr_base: learning rate of the top block, ``norm`` and ``head``.
        decay_rate: multiplicative LR factor applied per layer going down.
        weight_decay: weight decay used for every group.

    Returns:
        list[dict]: param groups for a ``torch.optim`` optimizer, ordered as
        [block 0, ..., block L-1, patch_embed (+ leftovers), norm + head].
        Each ``"params"`` entry is a list, and every parameter of ``model``
        appears in exactly one group.
    """
    blocks = model.blocks.blocks
    num_layers = len(blocks)
    seen = set()

    def _claim(params):
        # Return the parameters not yet placed in a group and mark them as placed,
        # so no tensor ends up in two groups (torch.optim rejects that).
        claimed = []
        for p in params:
            if id(p) not in seen:
                seen.add(id(p))
                claimed.append(p)
        return claimed

    def _group(params, lr):
        return {"params": params, "lr": lr, "weight_decay": weight_decay}

    # Transformer blocks: deeper blocks get larger learning rates.
    param_groups = [
        _group(_claim(layer.parameters()), lr_base * (decay_rate ** (num_layers - 1 - i)))
        for i, layer in enumerate(blocks)
    ]

    # Patch embedding: smallest learning rate.
    embed_params = _claim(model.patch_embed.parameters())
    # Final norm + classification head: full base learning rate.
    head_params = _claim(list(model.norm.parameters()) + list(model.head.parameters()))
    # Whatever is left (e.g. cls_token / pos_embed if they are parameters) joins the
    # embedding group, treating it like the input-side embedding layer.
    embed_params.extend(_claim(model.parameters()))

    param_groups.append(_group(embed_params, lr_base * (decay_rate ** num_layers)))
    param_groups.append(_group(head_params, lr_base))
    return param_groups



def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size, num_workers=8):
    """Build distributed train/val DataLoaders for one of the supported datasets.

    Args:
        dataset_name: case-insensitive key, one of 'tiny-imagenet', 'cifar10',
            'cifar100', 'vww', 'imagenet', 'flowers', 'pets', 'aircraft', 'cars'
            or 'inat18'.
        data_dir: dataset root directory (expected layout depends on the dataset).
        batch_size: per-process batch size.
        rank: this process's rank; used to shard both splits.
        world_size: number of processes sharing the data.
        img_size: side length of the square images fed to the model.
        num_workers: DataLoader worker processes per loader (default 8, the
            value that was previously hard-coded).

    Returns:
        tuple: ``(trainloader, valloader, train_sampler, num_classes)``. The train
        sampler is returned so the caller can call ``set_epoch`` every epoch.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Resize slightly larger than the crop, then center-crop
    # (for img_size=224 this resizes to int(224 * 1.14) == 255).
    transform_val = transforms.Compose([
        transforms.Resize(int(img_size * 1.14)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    dataset_name = dataset_name.lower()

    if dataset_name == 'tiny-imagenet':
        trainset = TinyImageNetDataset(data_dir, "train", transform_train)
        valset = TinyImageNetDataset(data_dir, "val", transform_val)
        num_classes = 200

    elif dataset_name == 'cifar10':
        trainset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 10

    elif dataset_name == 'cifar100':
        trainset = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 100

    elif dataset_name == 'vww':
        # Train and val share the COCO image root; only the annotation file differs.
        coco_root = os.path.join(data_dir, 'coco2014/all2014')
        trainset = pyvww.pytorch.VisualWakeWordsClassification(
            root=coco_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_train.json'),
            transform=transform_train)
        valset = pyvww.pytorch.VisualWakeWordsClassification(
            root=coco_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_val.json'),
            transform=transform_val)
        num_classes = 2

    elif dataset_name == 'imagenet':
        train_dir = os.path.join(data_dir, 'train')
        val_dir = os.path.join(data_dir, 'val')
        trainset = datasets.ImageFolder(train_dir, transform=transform_train)
        valset = datasets.ImageFolder(val_dir, transform=transform_val)
        num_classes = 1000

    elif dataset_name == 'flowers':
        trainset = datasets.Flowers102(data_dir, split='train', download=True, transform=transform_train)
        valset = datasets.Flowers102(data_dir, split='test', download=True, transform=transform_val)
        num_classes = 102

    elif dataset_name == 'pets':
        trainset = datasets.OxfordIIITPet(data_dir, split='trainval', download=True, transform=transform_train)
        valset = datasets.OxfordIIITPet(data_dir, split='test', download=True, transform=transform_val)
        num_classes = 37

    elif dataset_name == 'aircraft':
        trainset = datasets.FGVCAircraft(data_dir, split='train', download=True, transform=transform_train)
        valset = datasets.FGVCAircraft(data_dir, split='test', download=True, transform=transform_val)
        num_classes = 100

    elif dataset_name == 'cars':
        trainset = datasets.StanfordCars(data_dir, split='train', transform=transform_train)
        valset = datasets.StanfordCars(data_dir, split='test', transform=transform_val)
        num_classes = 196

    elif dataset_name == 'inat18':
        # iNaturalist 2018 has very many classes (8,142) and very large files.
        # NOTE(review): train and val below are constructed with identical arguments
        # (only the transform differs). They therefore appear to contain the same images,
        # i.e. validation would score training data. Confirm and point valset at the
        # real 2018 validation split.
        trainset = datasets.INaturalist(data_dir, version='2018', transform=transform_train)
        valset = datasets.INaturalist(data_dir, version='2018', transform=transform_val)
        num_classes = 8142

    else:
        raise ValueError(f"Unsupported dataset {dataset_name}")

    # Distributed samplers: each rank sees a disjoint shard; only training is shuffled.
    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=num_workers, pin_memory=True)
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=num_workers, pin_memory=True)

    return trainloader, valloader, train_sampler, num_classes


def _remap_encoder_keys(state_dict):
    """Keep only encoder weights of a pretraining checkpoint, renamed for the DDP-wrapped ParallelViT.

    Keys under ``module.model.{patch_embed,blocks,cls_token,pos_embed,norm}`` are
    renamed to ``module.<part>``. The one exception is the transformer blocks,
    which ParallelViT keeps under ``module.blocks.blocks``. Every other key
    (decoder / tch_model / distill_criterion / adapter, ...) is dropped.
    """
    prefix_map = (
        ("module.model.patch_embed", "module.patch_embed"),
        ("module.model.blocks", "module.blocks.blocks"),
        ("module.model.cls_token", "module.cls_token"),
        ("module.model.pos_embed", "module.pos_embed"),
        ("module.model.norm", "module.norm"),
    )
    remapped = {}
    for key, value in state_dict.items():
        for old, new in prefix_map:
            if key.startswith(old):
                remapped[new + key[len(old):]] = value
                break
    return remapped


def _load_encoder_checkpoint(model, path, rank):
    """Load only the encoder weights stored at ``path`` into the DDP-wrapped ``model``.

    Rank 0 reads the file and broadcasts it with ``dist.broadcast_object_list``,
    so every rank must call this function (it is a collective). Missing and
    unexpected keys are reported on rank 0, because loading uses ``strict=False``.
    """
    # torch.load unpickles the file: only point --resume at trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu") if rank == 0 else None
    # Every rank must pass a list of the same length to broadcast_object_list.
    obj_list = [checkpoint]
    dist.broadcast_object_list(obj_list, src=0)
    checkpoint = obj_list[0]

    encoder_state_dict = _remap_encoder_keys(checkpoint["model"])
    missing, unexpected = model.load_state_dict(encoder_state_dict, strict=False)

    if rank == 0:
        start_epoch = checkpoint.get("start_epoch", 0)
        print(f"=> Loaded checkpoint '{path}' (only encoder) at epoch '{start_epoch}'")
        if missing:
            print("Missing keys:", missing)
        if unexpected:
            print("Unexpected keys:", unexpected)


def train(rank, world_size, args):
    """Per-process training entry point, launched by ``mp.spawn`` in ``main``.

    Steps:
    - Join an NCCL process group and use GPU ``rank``.
    - Build the distributed loaders and a ``get_parallel_vit`` model, converted
      to SyncBatchNorm and wrapped in DDP.
    - Optionally load encoder weights from ``args.resume``.
    - Train for ``args.epochs`` epochs with SGD (layer-wise LR decay, warm-up +
      cosine schedule) under mixed precision.
    - Validate after every epoch whose index is a multiple of 5, and once more
      at the end.

    When ``args.dataset`` is "imagenet" (case-insensitive), only parameters whose
    name contains "head" are trained (linear probing).

    Resuming loads encoder weights only. Training still starts at epoch 0, and
    optimizer/scheduler state is not restored.

    Args:
        rank: process index; also used as the CUDA device index.
        world_size: total number of processes.
        args: parsed command-line namespace built in ``main``.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    trainloader, valloader, train_sampler, num_classes = get_dataloaders(
        args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size)

    model = get_parallel_vit(num_classes=num_classes, dropout=0,
                             attn_branches=args.attn_branches, mlp_branches=args.mlp_branches,
                             depth=args.depth).to(device)

    # ====== Linear probing for ImageNet ======
    if args.dataset.lower() == "imagenet":
        if rank == 0:
            print(">>> Using Linear Probing (freeze backbone, train classifier only) <<<")
        for name, param in model.named_parameters():
            if "head" not in name:  # only the classification head stays trainable
                param.requires_grad = False

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    criterion = nn.CrossEntropyLoss().to(device)
    param_groups = get_layerwise_lr_decay_params(
        model.module,  # pass the unwrapped model: it owns .blocks / .patch_embed / .head
        lr_base=args.lr,
        decay_rate=args.decay_rate,
        weight_decay=0,
    )
    # Every group carries its own "lr", so no optimizer-level lr is given.
    optimizer = optim.SGD(param_groups, momentum=0.9, nesterov=False)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=WarmupCosine(args.warmup_epochs, args.epochs)
    )
    scaler = GradScaler()

    lambda_scheduler = None
    if args.lambda_policy:
        # NOTE(review): get_lambda is fed a fractional epoch count below; presumably
        # warmup_steps/start_steps are in the same unit -- confirm in LambdaScheduler.
        lambda_scheduler = LambdaScheduler(warmup_steps=args.warmup_step, mode=args.lambda_policy,
                                           start_steps=args.start_step)

    if args.resume:
        if os.path.isfile(args.resume):
            _load_encoder_checkpoint(model, args.resume, rank)
        elif rank == 0:
            # Previously a mistyped path silently trained from scratch.
            print(f"=> WARNING: checkpoint '{args.resume}' not found; training from scratch")

    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)  # reshuffle shards differently every epoch
        model.train()
        running_loss, total, correct = 0.0, 0, 0
        num_steps = len(trainloader)

        for step, (imgs, labels) in enumerate(trainloader):
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()

            # Second positional model argument: scheduled (per fractional epoch) or constant.
            if lambda_scheduler is not None:
                lambda_off = lambda_scheduler.get_lambda(epoch + step / num_steps)
            else:
                lambda_off = args.lambda_off

            # Mixed-precision forward pass.
            with autocast():
                if args.get_diversity_loss:
                    outputs, diversity_loss = model(imgs, lambda_off, get_diversity_loss=True)
                else:
                    diversity_loss = torch.tensor(0.0, device=device)
                    outputs = model(imgs, lambda_off)
                loss = criterion(outputs, labels)

            loss = loss + args.gamma * diversity_loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * imgs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            if step % 20 == 0 and rank == 0:
                print(f"[Step {step}/{num_steps}] Train Loss: {loss.item()}, Div Loss:{diversity_loss.item()}, Lambda: {lambda_off}")

        # Sum over all ranks so the epoch summary covers the whole training set,
        # not only rank 0's shard. This is a collective: every rank must reach it.
        stats = torch.tensor([running_loss, float(correct), float(total)],
                             dtype=torch.float64, device=device)
        dist.all_reduce(stats)
        if rank == 0:
            global_loss, global_correct, global_total = stats.tolist()
            denom = max(global_total, 1.0)  # guard against an empty loader
            train_acc = global_correct / denom
            print(f"[Epoch {epoch+1}] Train Loss: {global_loss/denom:.4f}, Acc: {train_acc:.4f}")
        scheduler.step()
        if epoch % 5 == 0:
            evaluate(model, valloader, device, rank, args.multi_branch)
    evaluate(model, valloader, device, rank, args.multi_branch)
    dist.destroy_process_group()


def evaluate(model, dataloader, device, rank, multi_branch=False):
    """Compute top-1 accuracy of ``model`` on ``dataloader`` and print it on rank 0.

    The model is called as ``model(imgs, 0)`` when ``multi_branch`` is set and as
    ``model(imgs, 1)`` otherwise. This is the same positional slot that ``train``
    fills with ``lambda_off``.

    When a default process group is initialized, the correct/total counts are
    summed across ranks. The reported accuracy then covers the whole validation
    set rather than only this rank's shard. This is a collective, so every rank
    must call ``evaluate`` together.

    Note: torch's DistributedSampler with ``drop_last=False`` pads shards by
    repeating samples, so a few samples may be counted twice.

    Args:
        model: model (possibly DDP-wrapped); switched to eval mode here.
        dataloader: iterable of ``(imgs, labels)`` batches.
        device: device the batches are moved to.
        rank: process rank; only rank 0 prints.
        multi_branch: selects the second model argument (0 if True, else 1).

    Returns:
        float: the accuracy (global when distributed), or 0.0 for an empty loader.
    """
    model.eval()
    branch_arg = 0 if multi_branch else 1
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(imgs, branch_arg)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if dist.is_available() and dist.is_initialized():
        # Aggregate shard counts from every rank before computing the accuracy.
        counts = torch.tensor([correct, total], dtype=torch.long, device=device)
        dist.all_reduce(counts)
        correct, total = (int(c) for c in counts.tolist())

    acc = correct / total if total > 0 else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """Parse command-line options and spawn one ``train`` process per GPU.

    ``--world_size`` defaults to the number of visible CUDA devices. Zero
    processes would make ``mp.spawn`` return without training anything, so
    that case is rejected with a usage error. ``MASTER_ADDR``/``MASTER_PORT``
    are exported so ``dist.init_process_group`` in each worker can rendezvous.
    """
    parser = argparse.ArgumentParser()
    # Data / optimization.
    parser.add_argument('--dataset', type=str, default='tiny-imagenet')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--warmup_epochs', type=int, default=5, help="linear warmup epochs")
    parser.add_argument('--decay_rate', type=float, default=0.75)

    # Model / branch mixing.
    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--multi-branch', action='store_true', help="Fine tuning on baseline multi-branch model")
    parser.add_argument('--lambda_off', type=float, default=0)
    parser.add_argument('--lambda_policy', type=str, default=None, help="Lambda policy, including linear/cosine/exponential/sqrt/sine/smoothstep")
    parser.add_argument('--start_step', type=int, default=0)
    parser.add_argument('--warmup_step', type=int, default=30)
    parser.add_argument('--attn_branches', type=int, default=2)
    parser.add_argument('--mlp_branches', type=int, default=2)
    parser.add_argument('--depth', type=int, default=6)
    parser.add_argument('--get_diversity_loss', action='store_true')
    parser.add_argument('--gamma', type=float, default=0.05)

    # Checkpointing / distributed launch.
    parser.add_argument('--resume', type=str, default='', help="Path to checkpoint to resume")
    parser.add_argument('--master_addr', type=str, default='localhost', help="Master node address")
    parser.add_argument('--master_port', type=str, default='12355', help="Master node port")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())

    args = parser.parse_args()

    if args.world_size < 1:
        # Without this check mp.spawn(nprocs=0) silently does nothing.
        parser.error(f"--world_size must be >= 1 (got {args.world_size}); is a CUDA device visible?")

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


if __name__ == '__main__':
    main()
