# train.py
import os
import argparse
import pyvww

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms import RandAugment
from timm.data.transforms import RandomResizedCropAndInterpolation

from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms import RandAugment
from torch.cuda.amp import autocast, GradScaler

from models.vit_base import get_vit_base
from models.vit import get_vit_tiny
from models.vit_parallel_repa import get_parallel_vit, LambdaScheduler
from loader import TinyImageNetDataset


def load_checkpoint(path, map_location="cpu"):
    """Load a checkpoint file and normalise it to a ``{"model": ...}`` mapping.

    The first attempt uses ``torch.load``'s default ``weights_only`` setting
    (``True`` on PyTorch >= 2.6, i.e. the restricted/safe unpickler). If that
    attempt raises, a warning is printed and the load is retried with
    ``weights_only=False``.

    SECURITY NOTE: ``weights_only=False`` unpickles arbitrary Python objects and
    can execute code embedded in the file -- only point this at trusted files.

    Args:
        path: Filesystem path of the checkpoint.
        map_location: Forwarded to ``torch.load`` (defaults to ``"cpu"``).

    Returns:
        The loaded object. A dict without a ``"model"`` key (e.g. a bare
        state_dict) is wrapped as ``{"model": <that dict>}``; any other object
        is returned unchanged.
    """
    try:
        loaded = torch.load(path, map_location=map_location)
    except Exception as err:
        print(f"[WARN] Safe load failed ({err}), retrying with weights_only=False")
        loaded = torch.load(path, map_location=map_location, weights_only=False)

    # A plain state_dict has no "model" entry: wrap it so callers can always index ["model"].
    is_bare_state_dict = isinstance(loaded, dict) and "model" not in loaded
    return {"model": loaded} if is_bare_state_dict else loaded

class WarmupCosine:
    """
    Linear warm-up for `warmup_epochs`, then cosine decay until `max_epochs`.
    Works with PyTorch's LambdaLR.

    Calling the instance with a 0-based epoch returns a multiplicative LR factor:
      * epochs ``0 .. warmup-1``: ``(epoch + 1) / warmup`` (reaches 1.0 on the last
        warm-up epoch);
      * afterwards: ``0.5 * (1 + cos(pi * t))`` where ``t`` is the elapsed fraction
        of the decay phase, clamped to ``[0, 1]``. At and beyond ``max_epochs`` the
        factor therefore stays at 0 instead of climbing back up along the cosine.

    Only ``warmup`` and ``max_epochs`` are stored on the instance (these are what
    ``LambdaLR.state_dict()`` captures for callable objects).
    """

    def __init__(self, warmup_epochs: int, max_epochs: int):
        self.warmup = max(0, int(warmup_epochs))
        self.max_epochs = max_epochs

    def __call__(self, epoch: int) -> float:
        import math

        if epoch < self.warmup and self.warmup > 0:
            return float(epoch + 1) / float(self.warmup)
        # Cosine from 1 -> 0 over the epochs after warm-up.
        t = (epoch - self.warmup) / max(1, (self.max_epochs - self.warmup))
        # Without the clamp, epochs past max_epochs give t > 1 and the factor rises
        # again (t == 2 yields 1.0, i.e. a full LR restart).
        t = min(1.0, max(0.0, t))
        return 0.5 * (1.0 + math.cos(math.pi * t))


def get_layerwise_lr_decay_params(model, lr_base=1e-4, decay_rate=0.8, weight_decay=0.01):
    """
    Build optimizer param_groups with layer-wise learning-rate decay.

    Expects the custom ViT layout used in this project: ``model.blocks.blocks`` (a
    sequence of transformer blocks), ``model.patch_embed``, ``model.norm`` and
    ``model.head``.

    Learning rates, with ``L = len(model.blocks.blocks)``:
      * block ``i``:              ``lr_base * decay_rate ** (L - 1 - i)``
      * patch embedding:          ``lr_base * decay_rate ** L`` (smallest)
      * final norm + head:        ``lr_base`` (largest)
      * any remaining parameters: same lr as the patch embedding.

    The last group collects parameters that are registered directly on the model,
    or on submodules other than the four above. For a ViT these are presumably
    things like ``cls_token`` / ``pos_embed``; this depends on the model definition.
    Without this group, such parameters would belong to no group and the optimizer
    would never update them. The group is only appended when non-empty, so models
    without extra parameters get exactly the same groups, in the same order.

    Every group uses the same ``weight_decay``.

    Args:
        model: the unwrapped model (not the DDP wrapper).
        lr_base: learning rate for the last block and the norm/head.
        decay_rate: per-layer multiplicative lr decay towards the input.
        weight_decay: weight decay applied to every group.

    Returns:
        list[dict]: groups with keys ``params`` (a list), ``lr`` and ``weight_decay``.
    """
    blocks = list(model.blocks.blocks)
    num_layers = len(blocks)
    embed_lr = lr_base * (decay_rate ** num_layers)

    def _group(params, lr):
        # Materialise to a list so membership can be checked for the leftover group.
        return {"params": list(params), "lr": lr, "weight_decay": weight_decay}

    param_groups = []

    # Transformer blocks: deeper blocks keep more of lr_base.
    for i, layer in enumerate(blocks):
        param_groups.append(_group(layer.parameters(), lr_base * (decay_rate ** (num_layers - 1 - i))))

    # Patch embedding layer: the smallest lr.
    param_groups.append(_group(model.patch_embed.parameters(), embed_lr))

    # Final norm + classification head: the full lr_base.
    param_groups.append(_group(list(model.norm.parameters()) + list(model.head.parameters()), lr_base))

    # Anything not covered above would otherwise be silently excluded from training.
    assigned = {id(p) for g in param_groups for p in g["params"]}
    leftover = [p for p in model.parameters() if id(p) not in assigned]
    if leftover:
        param_groups.append(_group(leftover, embed_lr))

    return param_groups


def to_rgb(img):
    """Return ``img.convert("RGB")``.

    Used as a plain callable inside ``transforms.Compose``. For PIL images this
    converts images in other modes (e.g. grayscale or RGBA) to 3-channel RGB
    before ``ToTensor``.
    """
    target_mode = "RGB"
    converted = img.convert(target_mode)
    return converted


def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size, num_workers=10):
    """
    Build distributed train/validation DataLoaders.

    Args:
        dataset_name: ``"imagenet"`` (ImageFolder layout with ``train/`` and ``val/``
            sub-directories under ``data_dir``) or ``"inat18"`` (torchvision
            ``INaturalist`` with ``version="2018"``). Matched case-insensitively.
        data_dir: dataset root directory.
        batch_size: per-process batch size.
        rank: this process's rank, forwarded to ``DistributedSampler``.
        world_size: number of processes, forwarded to ``DistributedSampler``.
        img_size: training random-resized-crop size and validation center-crop size.
        num_workers: DataLoader worker processes per loader (default 10).

    Returns:
        ``(trainloader, valloader, train_sampler, num_classes)``. ``train_sampler`` is
        returned so the caller can call ``set_epoch`` once per epoch.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))

    transform_train = transforms.Compose([
        RandomResizedCropAndInterpolation(img_size),
        to_rgb,
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=10, magnitude=5),  # torchvision implementation
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),  # color jitter = 0.3
        transforms.ToTensor(),
        normalize,
    ])
    # Validation: resize the shorter side to int(img_size * 1.14) (255 for 224), then
    # center-crop to img_size -- the usual "resize, then center crop" evaluation recipe.
    transform_val = transforms.Compose([
        transforms.Resize(int(img_size * 1.14)),
        to_rgb,
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    name = dataset_name.lower()
    if name == 'imagenet':
        # Assumes the standard ImageFolder structure: <data_dir>/train and <data_dir>/val.
        train_dir = os.path.join(data_dir, 'train')
        val_dir = os.path.join(data_dir, 'val')

        trainset = datasets.ImageFolder(train_dir, transform=transform_train)
        valset = datasets.ImageFolder(val_dir, transform=transform_val)
        num_classes = 1000  # Standard ImageNet has 1000 classes
    elif name == 'inat18':
        # iNaturalist 2018 has a very large label space (8,142 classes) and large files.
        # NOTE(review): torchvision's INaturalist exposes no separate train/val split for
        # version "2018", so trainset and valset presumably read the same images (only the
        # transforms differ), and validation accuracy would then be measured on training
        # data. Confirm, and supply a proper held-out split.
        trainset = datasets.INaturalist(data_dir, version='2018', transform=transform_train)
        valset = datasets.INaturalist(data_dir, version='2018', transform=transform_val)
        num_classes = 8142
    else:
        raise ValueError(f"Unsupported dataset '{dataset_name}'. Choose from 'imagenet' or 'inat18'.")

    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    # drop_last=True keeps every training batch full-sized (timm's Mixup asserts an even batch).
    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=num_workers, pin_memory=True, drop_last=True)
    # drop_last=False: dropping the final partial batch would silently exclude validation
    # samples from the reported accuracy.
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=num_workers, pin_memory=True, drop_last=False)
    return trainloader, valloader, train_sampler, num_classes


def train(rank, world_size, args):
    """
    Per-process DDP training entry point (launched by ``mp.spawn`` from ``main``).

    Steps:
      1. join the NCCL process group and bind this process to GPU ``rank``;
      2. build the data loaders and the parallel-branch ViT (SyncBatchNorm + DDP);
      3. AdamW over layer-wise-decayed param groups, plus a per-epoch warm-up/cosine
         ``LambdaLR``;
      4. optionally restore weights from ``args.resume``:
         - default: encoder-only load (classification-head keys skipped,
           ``strict=False``), read on rank 0 and broadcast to all ranks;
         - ``--resume_from_sft``: full state_dict load on every rank. Training still
           restarts at epoch 0, and optimizer/scheduler/scaler state is NOT restored;
      5. train with Mixup/CutMix soft targets and AMP. Evaluate every 20 epochs and,
         on rank 0 only when ``args.save`` is set, checkpoint every 10 epochs;
      6. run a final evaluation and destroy the process group.

    Args:
        rank: process index, also used as the CUDA device index.
        world_size: total number of processes.
        args: parsed ``argparse.Namespace`` from ``main``.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    trainloader, valloader, train_sampler, num_classes = get_dataloaders(
        args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size)

    # Load model
    model = get_parallel_vit(num_classes=num_classes, dropout=0,
                             attn_branches=args.attn_branches, mlp_branches=args.mlp_branches,
                             depth=args.depth).to(device)

    # # ====== Linear Probing for ImageNet (disabled) ======
    # if args.dataset.lower() == "imagenet":
    #     if rank == 0:
    #         print(">>> Using Linear Probing (freeze backbone, train classifier only) <<<")
    #     for name, param in model.named_parameters():
    #         if "head" not in name:   # only the classification head is trained
    #             param.requires_grad = False

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # NOTE(review): find_unused_parameters=True presumably because some branches may not
    # contribute to the loss for a given lambda_off -- confirm against the model.
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    param_groups = get_layerwise_lr_decay_params(
        model.module,  # pass the unwrapped model: the helper reads .blocks/.patch_embed/.norm/.head
        lr_base=args.lr,
        decay_rate=0.85,
        weight_decay=0.05,
    )

    optimizer = optim.AdamW(
        param_groups,
        betas=(0.9, 0.999),  # <<< Table A1
        weight_decay=0.05,
    )
    # Stepped once per epoch (see the end of the epoch loop).
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=WarmupCosine(args.warmup_epochs, args.epochs)
    )

    # Mixup/CutMix produce soft targets, hence SoftTargetCrossEntropy (label_smoothing=0.0 here).
    criterion = SoftTargetCrossEntropy().to(device)
    mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, prob=1.0, label_smoothing=0.0, num_classes=num_classes)
    scaler = GradScaler()
    start_sft_epoch = 0

    if args.lambda_policy:
        # NOTE(review): get_lambda is called below with a fractional epoch
        # (epoch + step / len(trainloader)), so warmup_steps is presumably in epochs -- confirm.
        lambda_scheduler = LambdaScheduler(warmup_steps=args.warmup_step, mode=args.lambda_policy)

    if rank == 0 and args.resume and not os.path.isfile(args.resume):
        print(f"[WARN] --resume '{args.resume}' is not a file; training from scratch")
    if rank == 0 and not args.save:
        print("[WARN] --save not set; checkpoints will not be written")

    if args.resume and os.path.isfile(args.resume):
        if not args.resume_from_sft:
            checkpoint = None
            if dist.get_rank() == 0:
                checkpoint = load_checkpoint(args.resume, map_location="cpu")

            # Every rank must pass the same-length list; non-zero ranks receive
            # rank 0's checkpoint into it.
            obj_list = [checkpoint]
            dist.broadcast_object_list(obj_list, src=0)
            checkpoint = obj_list[0]

            state_dict = checkpoint["model"]
            # Load encoder weights only. Strip any DDP "module." prefix BEFORE the head
            # check, so that head keys from checkpoints saved by this script ("module.head.*")
            # are skipped too. Otherwise a head with a different num_classes fails
            # load_state_dict with a size mismatch, even under strict=False.
            encoder_state_dict = {}
            for k, v in state_dict.items():
                bare_key = k[len("module."):] if k.startswith("module.") else k
                if bare_key.startswith("head."):  # skip the classification head
                    continue
                encoder_state_dict["module." + bare_key] = v

            missing, unexpected = model.load_state_dict(encoder_state_dict, strict=False)

            if rank == 0:
                start_epoch = checkpoint.get("start_epoch", 0)
                print(f"=> Loaded checkpoint '{args.resume}' (only encoder) at epoch '{start_epoch}'")
                if missing:
                    print("Missing keys:", missing)
                if unexpected:
                    print("Unexpected keys:", unexpected)
        else:
            ckpt = load_checkpoint(args.resume, map_location="cpu")
            model.load_state_dict(ckpt["model"])
            # Optimizer/scheduler/scaler state is intentionally not restored, and the epoch
            # counter restarts at 0: the checkpoint only seeds the weights.
            if rank == 0:
                print("Resumed and start with epoch:", ckpt.get("start_epoch"))
            start_sft_epoch = 0

    for epoch in range(start_sft_epoch, args.epochs):
        train_sampler.set_epoch(epoch)  # different shuffle each epoch
        model.train()
        running_loss, total, correct = 0.0, 0, 0

        for step, (imgs, labels) in enumerate(trainloader):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            # Mixup/CutMix: labels become soft (batch, num_classes) targets.
            imgs, labels = mixup_fn(imgs, labels)

            optimizer.zero_grad()
            if args.lambda_policy:
                lambda_off = lambda_scheduler.get_lambda(epoch + step / len(trainloader))
            else:
                lambda_off = args.lambda_off

            # Mixed-precision forward; lambda_off is forwarded to the model.
            with autocast():
                if args.get_diversity_loss:
                    outputs, diversity_loss = model.forward(imgs, lambda_off, get_diversity_loss=True)
                else:
                    diversity_loss = torch.tensor(0.0, device=device)
                    outputs = model.forward(imgs, lambda_off)
                loss = criterion(outputs, labels)

            loss = loss + args.gamma * diversity_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * imgs.size(0)
            total += labels.size(0)
            # Approximate train accuracy: compare against the dominant class of the soft target.
            _, predicted = outputs.max(1)
            acc_targets = labels.argmax(dim=1) if labels.ndim == 2 else labels
            correct += predicted.eq(acc_targets).sum().item()
            if step % 20 == 0 and rank == 0:
                print(f"[Step {step}/{len(trainloader)}] Train Loss: {loss.item()}, Div Loss:{diversity_loss.item()}, Lambda: {lambda_off}")

        if rank == 0:
            train_acc = correct / total
            print(f"[Epoch {epoch+1}] Train Loss: {running_loss/total:.4f}, Acc: {train_acc:.4f}")
        scheduler.step()
        if epoch % 20 == 0:
            # Called on every rank: evaluation may use collective ops.
            evaluate(model, valloader, device, rank, args.multi_branch)

        # args.save defaults to '' -- os.makedirs('') would raise, so skip saving when unset.
        if epoch % 10 == 0 and rank == 0 and args.save:
            ckpt = {
                "start_epoch": epoch + 1,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
            }
            os.makedirs(args.save, exist_ok=True)
            filename = os.path.join(args.save, "latest_ckpt.pth.tar")
            torch.save(ckpt, filename)

    evaluate(model, valloader, device, rank, args.multi_branch)
    dist.destroy_process_group()


def evaluate(model, dataloader, device, rank, multi_branch=False):
    """
    Compute top-1 accuracy of ``model`` over ``dataloader`` and print it on rank 0.

    When a ``torch.distributed`` process group is initialised, the per-rank correct/total
    counts are summed with ``all_reduce``. The printed/returned accuracy therefore covers
    every rank's shard, not just rank 0's. In that case this function must be called on
    all ranks, because ``all_reduce`` is a collective. Note that ``DistributedSampler``
    pads shards to equal length, so a few samples may be counted twice.

    Args:
        model: called as ``model(imgs, 0)`` when ``multi_branch`` is set, otherwise
            ``model(imgs, 1)`` (the same positional slot that ``train`` fills with lambda_off).
        dataloader: iterable of ``(imgs, labels)`` with integer class labels.
        device: device the batches (and the reduction buffer) are moved to.
        rank: process rank; only rank 0 prints.
        multi_branch: selects the second positional model argument (see above).

    Returns:
        float: accuracy in ``[0, 1]``; 0.0 when no samples were seen. The model is left
        in eval mode.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(imgs, 0 if multi_branch else 1)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Aggregate across ranks so the metric is global rather than shard-local.
    if dist.is_available() and dist.is_initialized():
        counts = torch.tensor([correct, total], dtype=torch.long, device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        correct, total = int(counts[0].item()), int(counts[1].item())

    acc = correct / total if total > 0 else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """
    Command-line entry point. Parses options, exports the DDP rendezvous address
    (``MASTER_ADDR`` / ``MASTER_PORT``), and spawns ``--world_size`` ``train``
    processes via ``torch.multiprocessing.spawn``.

    ``--world_size`` defaults to ``torch.cuda.device_count()``. When that is below 1
    (e.g. no visible GPUs), this exits with an argparse error instead of spawning zero
    workers and returning silently.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='imagenet')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--warmup_epochs', type=int, default=5, help="linear warmup epochs")

    # NOTE(review): --parallel is not read anywhere in this file.
    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--multi-branch', action='store_true', help="Fine tuning on baseline multi-branch model")
    parser.add_argument('--lambda_off', type=float, default=0)
    parser.add_argument('--lambda_policy', type=str, default=None, help="Lambda policy, including linear/cosine/exponential/sqrt/sine/smoothstep")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    parser.add_argument('--warmup_step', type=int, default=30)
    parser.add_argument('--attn_branches', type=int, default=2)
    parser.add_argument('--mlp_branches', type=int, default=2)
    parser.add_argument('--depth', type=int, default=6)
    parser.add_argument('--get_diversity_loss', action='store_true')
    parser.add_argument('--gamma', type=float, default=0.05)

    parser.add_argument('--resume', type=str, default='', help="Path to checkpoint to resume")
    parser.add_argument('--resume_from_sft', action='store_true')
    parser.add_argument('--save', type=str, default='', help="Path to checkpoint to save")
    parser.add_argument('--master_addr', type=str, default='localhost', help="Master node address")
    parser.add_argument('--master_port', type=str, default='12355', help="Master node port")

    args = parser.parse_args()

    # mp.spawn(nprocs=0) starts nothing and returns immediately, so fail loudly instead.
    if args.world_size < 1:
        parser.error(f"--world_size must be >= 1 (got {args.world_size}); are any CUDA devices visible?")

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


# Script entry point: parse CLI arguments and spawn the DDP worker processes via main().
if __name__ == '__main__':
    main()
