# train.py
import os
import argparse
import pyvww

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms import RandAugment
from timm.data.transforms import RandomResizedCropAndInterpolation

from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms import RandAugment
from torch.cuda.amp import autocast, GradScaler

from models.vit_base import get_vit_base
from models.vit import get_vit_tiny
from models.vit_parallel_repa import get_parallel_vit, LambdaScheduler, CosineSimilarityMonitor
from loader import TinyImageNetDataset


# ========== V1: similarity over all parameters of a branch ==========
def flatten_params(module: torch.nn.Module):
    """Return all parameters of ``module`` concatenated into one flat vector.

    Each parameter is detached from the autograd graph and flattened before
    concatenation, in the order yielded by ``module.parameters()``.
    """
    flat_chunks = []
    for param in module.parameters():
        flat_chunks.append(param.detach().flatten())
    return torch.cat(flat_chunks)


def cosine_similarity_matrix(branches):
    """Compute pairwise cosine similarity between whole-branch parameter vectors.

    Args:
        branches: sized iterable of modules; each one is flattened with
            ``flatten_params``.

    Returns:
        An ``(n, n)`` tensor whose ``[i, j]`` entry is the cosine similarity
        between the flattened parameters of branch ``i`` and branch ``j``.
    """
    flat_vectors = [flatten_params(branch) for branch in branches]
    count = len(branches)
    sims = torch.zeros((count, count))
    for row, vec_a in enumerate(flat_vectors):
        for col, vec_b in enumerate(flat_vectors):
            sims[row, col] = F.cosine_similarity(vec_a, vec_b, dim=0)
    return sims


# ========== V2: similarity per Linear layer across branches ==========
def flatten_linear_weights(linear_layer: torch.nn.Linear):
    """Flatten a Linear layer's weight (plus its bias, when present) into one vector.

    Both tensors are detached; the weight entries come first, followed by the
    bias entries.
    """
    flat_weight = linear_layer.weight.detach().flatten()
    if linear_layer.bias is None:
        return torch.cat([flat_weight])
    flat_bias = linear_layer.bias.detach().flatten()
    return torch.cat([flat_weight, flat_bias])


def get_linear_layers_from_branch(branch):
    """Collect the Linear layers of a branch as ``(name, layer)`` pairs.

    A branch exposing a ``q_proj`` attribute is treated as an attention branch
    and contributes its ``q_proj`` / ``k_proj`` / ``v_proj`` / ``out_proj``
    members. Any other branch is iterated (e.g. an ``nn.Sequential`` MLP) and
    every ``torch.nn.Linear`` child is reported as ``linear_<position>``.
    """
    if not hasattr(branch, 'q_proj'):
        # MLP-style branch: keep only the Linear children, named by position.
        return [
            (f'linear_{position}', child)
            for position, child in enumerate(branch)
            if isinstance(child, torch.nn.Linear)
        ]

    # Attention-style branch: a fixed set of projection layers.
    projection_names = ('q_proj', 'k_proj', 'v_proj', 'out_proj')
    return [(proj_name, getattr(branch, proj_name)) for proj_name in projection_names]


def cosine_similarity_matrix_per_layer(branches, layer_type="attention"):
    """Compute, for each Linear layer, the similarity matrix across branches.

    Linear layers are extracted from every branch with
    ``get_linear_layers_from_branch`` and matched *by position*; the layer
    names are taken from the first branch.

    Args:
        branches: sized iterable of branch modules.
        layer_type: currently unused by the implementation; kept so existing
            call sites remain valid.

    Returns:
        A dict mapping each layer name to an ``(n_branches, n_branches)``
        tensor of pairwise cosine similarities between the flattened weights
        (and bias, if any) of that layer in each branch. An empty dict is
        returned when ``branches`` is empty.

    Raises:
        ValueError: if a branch exposes fewer Linear layers than the first
            branch, so positional matching is impossible.
    """
    n_branches = len(branches)
    if n_branches == 0:
        # Nothing to compare; avoids indexing the (missing) first branch below.
        return {}

    # Linear layers of every branch, in extraction order.
    all_branch_layers = [get_linear_layers_from_branch(branch) for branch in branches]

    # The first branch defines the reference layer structure.
    layer_names = [name for name, _ in all_branch_layers[0]]

    # Fail early with a clear message instead of an IndexError mid-computation.
    for branch_idx, branch_layers in enumerate(all_branch_layers):
        if len(branch_layers) < len(layer_names):
            raise ValueError(
                f"Branch {branch_idx} has {len(branch_layers)} Linear layers, "
                f"expected at least {len(layer_names)} to match branch 0"
            )

    similarity_results = {}
    for layer_idx, layer_name in enumerate(layer_names):
        # Flattened weights of the layer at this position in every branch.
        layer_weights = [
            flatten_linear_weights(branch_layers[layer_idx][1])
            for branch_layers in all_branch_layers
        ]

        # Pairwise similarity of this layer between branches.
        sims = torch.zeros((n_branches, n_branches))
        for i, weights_i in enumerate(layer_weights):
            for j, weights_j in enumerate(layer_weights):
                sims[i, j] = F.cosine_similarity(weights_i, weights_j, dim=0)

        similarity_results[layer_name] = sims

    return similarity_results


class WarmupCosine:
    """
    Linear warm-up for `warmup_epochs`, then cosine decay until `max_epochs`.
    Works with PyTorch's LambdaLR.

    The instance is a callable mapping an epoch index to an LR multiplier:

    * during warm-up (``epoch < warmup_epochs``) it returns
      ``(epoch + 1) / warmup_epochs``, reaching 1.0 on the last warm-up epoch;
    * afterwards it follows half a cosine from 1 down to 0 at ``max_epochs``;
    * past ``max_epochs`` it stays at 0 instead of rising again.
    """
    def __init__(self, warmup_epochs: int, max_epochs: int):
        # Negative warm-up lengths are treated as "no warm-up".
        self.warmup = max(0, int(warmup_epochs))
        self.max_epochs = max_epochs

    def __call__(self, epoch: int):
        import math  # kept local so the class does not rely on a file-level import

        if epoch < self.warmup and self.warmup > 0:
            return float(epoch + 1) / float(self.warmup)
        # Fraction of the decay phase completed; the span is at least 1 to
        # avoid dividing by zero when max_epochs <= warmup.
        decay_span = max(1, (self.max_epochs - self.warmup))
        # Clamp at 1 so the cosine does not swing back up after max_epochs.
        t = min(1.0, (epoch - self.warmup) / decay_span)
        # cosine from 1 -> 0
        return 0.5 * (1.0 + math.cos(math.pi * t))


def get_layerwise_lr_decay_params(model, lr_base=1e-4, decay_rate=0.8, weight_decay=0.01):
    """Build optimizer ``param_groups`` with layer-wise learning-rate decay.

    Written for a custom ViT exposing ``model.blocks.blocks``,
    ``model.patch_embed``, ``model.norm`` and ``model.head``.

    Group layout (in order):

    1. one group per transformer block; the last block gets ``lr_base`` and
       each earlier block is scaled by a further factor of ``decay_rate``;
    2. ``patch_embed`` with the smallest rate, ``lr_base * decay_rate ** num_layers``;
    3. ``norm`` + ``head`` with ``lr_base``;
    4. (only if needed) every remaining parameter of ``model`` that none of
       the groups above covers, at the same rate as ``patch_embed``.
       Previously such parameters were silently left out of the optimizer and
       therefore never updated.

    NOTE: group 4 changes the number of param groups when it is non-empty, so
    optimizer state saved before this fix will not load into the new layout.

    Args:
        model: the unwrapped model (not the DDP wrapper).
        lr_base: learning rate of the top-most layers.
        decay_rate: multiplicative decay applied per layer towards the input.
        weight_decay: weight decay stored in every group.

    Returns:
        A list of param-group dicts with ``params``, ``lr`` and ``weight_decay``.
    """
    param_groups = []
    num_layers = len(model.blocks.blocks)
    print(num_layers)

    # Transformer blocks: deeper (later) blocks get larger learning rates.
    for i, layer in enumerate(model.blocks.blocks):
        param_groups.append({
            "params": list(layer.parameters()),
            "lr": lr_base * (decay_rate ** (num_layers - 1 - i)),
            "weight_decay": weight_decay,
        })

    # Patch embedding (smallest lr).
    embed_lr = lr_base * (decay_rate ** num_layers)
    param_groups.append({
        "params": list(model.patch_embed.parameters()),
        "lr": embed_lr,
        "weight_decay": weight_decay,
    })

    # Final norm + head (largest lr).
    param_groups.append({
        "params": list(model.norm.parameters()) + list(model.head.parameters()),
        "lr": lr_base,
        "weight_decay": weight_decay,
    })

    # Catch parameters that belong to none of the groups above so they are
    # still optimized; compared by identity because tensors define
    # element-wise ``==``.
    covered = {id(p) for group in param_groups for p in group["params"]}
    leftover = [p for p in model.parameters() if id(p) not in covered]
    if leftover:
        param_groups.append({
            "params": leftover,
            "lr": embed_lr,
            "weight_decay": weight_decay,
        })

    return param_groups


def to_rgb(img):
    """Return ``img`` converted to RGB mode via its ``convert`` method."""
    rgb_image = img.convert("RGB")
    return rgb_image


def get_dataloaders(dataset_name, data_dir, batch_size, rank, world_size, img_size, num_workers=10):
    """Build distributed train/validation loaders for the requested dataset.

    Args:
        dataset_name: one of ``tiny-imagenet``, ``cifar10``, ``cifar100``,
            ``VWW``, ``imagenet`` or ``inat18`` (matched case-insensitively).
        data_dir: dataset root; the expected sub-layout depends on the dataset.
        batch_size: per-process batch size.
        rank: rank of the current process, passed to ``DistributedSampler``.
        world_size: number of processes, passed to ``DistributedSampler``.
        img_size: side length of the crops fed to the model.
        num_workers: worker processes per ``DataLoader``.

    Returns:
        ``(trainloader, valloader, train_sampler, num_classes)``. The train
        sampler is returned so the caller can call ``set_epoch`` on it.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        RandomResizedCropAndInterpolation(img_size),
        to_rgb,
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=10, magnitude=5),  # torchvision implementation
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),  # color jitter = 0.3
        transforms.ToTensor(),
        normalize,
    ])
    # Validation: resize to 1.14x the target size, then center-crop
    # (the usual 256 -> 224 recipe when img_size is 224).
    transform_val = transforms.Compose([
        transforms.Resize(int(img_size * 1.14)),
        to_rgb,
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    # Normalise once so every dataset name is matched the same way
    # ('VWW' and 'inat18' used to be case-sensitive, the others were not).
    name = dataset_name.lower()

    if name == 'tiny-imagenet':
        trainset = TinyImageNetDataset(data_dir, "train", transform_train)
        valset = TinyImageNetDataset(data_dir, 'val', transform_val)
        num_classes = 200
    elif name == 'cifar10':
        trainset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 10
    elif name == 'cifar100':
        trainset = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train)
        valset = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_val)
        num_classes = 100
    elif name == 'vww':
        # Images come from COCO 2014; annotations from the Visual Wake Words split files.
        image_root = os.path.join(data_dir, 'coco2014/all2014')
        trainset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_train.json'),
            transform=transform_train)
        valset = pyvww.pytorch.VisualWakeWordsClassification(
            root=image_root,
            annFile=os.path.join(data_dir, 'visualwakewords/annotations/instances_val.json'),
            transform=transform_val)
        num_classes = 2
    elif name == 'imagenet':
        # Assumes the standard ImageFolder layout: <data_dir>/train and <data_dir>/val.
        trainset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform_train)
        valset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform_val)
        num_classes = 1000  # Standard ImageNet has 1000 classes
    elif name == 'inat18':
        # iNaturalist 2018 has a very large label space (8,142 classes) and is a big download.
        # NOTE(review): train and val are built from identical root/version and
        # differ only in the transform, so validation presumably runs on the
        # same images as training - confirm and add a held-out split.
        trainset = datasets.INaturalist(data_dir, version='2018', transform=transform_train)
        valset = datasets.INaturalist(data_dir, version='2018', transform=transform_val)
        num_classes = 8142
    else:
        raise ValueError(
            f"Unsupported dataset '{dataset_name}'. Choose from 'tiny-imagenet', "
            "'cifar10', 'cifar100', 'VWW', 'imagenet' or 'inat18'."
        )

    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    # drop_last=True keeps every training batch full.
    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=num_workers, pin_memory=True, drop_last=True)
    # Validation must see every sample, so the last partial batch is kept.
    # The sampler gives every rank the same number of samples, hence the same
    # number of batches on every rank.
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=num_workers, pin_memory=True, drop_last=False)
    return trainloader, valloader, train_sampler, num_classes


def _remap_encoder_state_dict(state_dict):
    """Keep only encoder weights and rename their keys to the ParallelViT layout.

    Keys that match none of the known prefixes (decoder, teacher model,
    distillation criterion, adapters, ...) are dropped.
    """
    prefix_map = (
        ("module.model.patch_embed", "module.patch_embed"),
        # ParallelViT nests its blocks under module.blocks.blocks
        ("module.model.blocks", "module.blocks.blocks"),
        ("module.model.cls_token", "module.cls_token"),
        ("module.model.pos_embed", "module.pos_embed"),
        ("module.model.norm", "module.norm"),
    )
    remapped = {}
    for key, value in state_dict.items():
        for old_prefix, new_prefix in prefix_map:
            if key.startswith(old_prefix):
                remapped[key.replace(old_prefix, new_prefix)] = value
                break
    return remapped


def _load_encoder_checkpoint(args, model, rank):
    """Load only the encoder weights of ``args.resume`` into ``model`` (non-strict)."""
    checkpoint = None
    if dist.get_rank() == 0:
        checkpoint = torch.load(args.resume, map_location="cpu")

    # The list must exist on every rank, otherwise broadcast_object_list has
    # nothing to fill in on the receiving side.
    obj_list = [checkpoint]
    dist.broadcast_object_list(obj_list, src=0)
    checkpoint = obj_list[0]

    encoder_state_dict = _remap_encoder_state_dict(checkpoint["model"])
    missing, unexpected = model.load_state_dict(encoder_state_dict, strict=False)
    if rank == 0:
        start_epoch = checkpoint.get("start_epoch", 0)
        print(f"=> Loaded checkpoint '{args.resume}' (only encoder) at epoch '{start_epoch}'")
        if missing:
            print("Missing keys:", missing)
        if unexpected:
            print("Unexpected keys:", unexpected)


def _load_sft_checkpoint(args, model, optimizer, scaler):
    """Restore model, optimizer and grad-scaler state from ``args.resume``.

    The scheduler state and the saved epoch are deliberately not restored
    here (matching the previous behaviour), so training restarts at epoch 0
    with a fresh LR schedule.
    """
    ckpt = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scaler.load_state_dict(ckpt["scaler"])


def _print_branch_similarity(model, rank):
    """Print the per-Linear-layer branch similarity of every block (rank 0 only)."""
    if rank != 0:
        # Only rank 0 prints, so the other ranks skip the computation entirely.
        return
    for layer_idx, block in enumerate(model.module.blocks.blocks):
        for attr, label, layer_type in (("attn", "Attention", "attention"), ("mlp", "MLP", "mlp")):
            if not hasattr(block, attr):
                continue
            sims_per_layer = cosine_similarity_matrix_per_layer(getattr(block, attr).branches, layer_type)
            print(f"\nLayer {layer_idx} {label} - Per Linear Layer Similarity:")
            for layer_name, sim_matrix in sims_per_layer.items():
                print(f"  {layer_name}:")
                print(f"    {sim_matrix}")


def _save_checkpoint(args, epoch, model, optimizer, scheduler, scaler):
    """Write ``latest_ckpt.pth.tar`` into ``args.save`` (current directory if empty)."""
    ckpt = {
        "start_epoch": epoch + 1,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict()
    }
    # os.makedirs('') raises, so an empty --save falls back to the current directory.
    save_dir = args.save or "."
    os.makedirs(save_dir, exist_ok=True)
    torch.save(ckpt, os.path.join(save_dir, "latest_ckpt.pth.tar"))


def train(rank, world_size, args):
    """Per-process training entry point, launched once per GPU by ``mp.spawn``.

    Sets up the NCCL process group, builds the data loaders, the DDP-wrapped
    parallel ViT, the AdamW optimizer with layer-wise LR decay and a
    warm-up + cosine schedule, optionally resumes from ``args.resume``, then
    trains for ``args.epochs`` epochs with Mixup and mixed precision.
    Validation runs every 20 epochs and once at the end; rank 0 writes a
    checkpoint every 10 epochs.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes.
        args: parsed command-line namespace (see ``main``).
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    trainloader, valloader, train_sampler, num_classes = get_dataloaders(
        args.dataset, args.data_dir, args.batch_size, rank, world_size, args.img_size)

    # Build the model and wrap it for distributed training.
    model = get_parallel_vit(num_classes=num_classes, dropout=0,
                             attn_branches=args.attn_branches, mlp_branches=args.mlp_branches,
                             depth=args.depth).to(device)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    param_groups = get_layerwise_lr_decay_params(
        model.module,  # the helper needs the unwrapped model, not the DDP wrapper
        lr_base=args.lr,
        decay_rate=0.85,
        weight_decay=0.05
    )

    optimizer = optim.AdamW(
        param_groups,
        betas=(0.9, 0.999),  # per the original author's note: "Table A1"
        weight_decay=0.05
    )
    # The schedule is stepped once per epoch (see the end of the epoch loop).
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=WarmupCosine(args.warmup_epochs, args.epochs)
    )

    # Mixup/CutMix yields soft labels, hence the soft-target loss;
    # label smoothing is disabled (label_smoothing=0.0).
    criterion = SoftTargetCrossEntropy().to(device)
    mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, prob=1.0, label_smoothing=0.0, num_classes=num_classes)
    scaler = GradScaler()
    start_sft_epoch = 0

    if args.lambda_policy:
        # Only created when a policy is given; otherwise args.lambda_off is used as-is.
        lambda_scheduler = LambdaScheduler(warmup_steps=args.warmup_step, mode=args.lambda_policy)

    if args.resume and os.path.isfile(args.resume):
        if not args.resume_from_sft:
            _load_encoder_checkpoint(args, model, rank)
        else:
            _load_sft_checkpoint(args, model, optimizer, scaler)
            print("Resumed and start with epoch:", start_sft_epoch)
    elif args.resume and rank == 0:
        # Previously a wrong path was ignored silently.
        print(f"WARNING: checkpoint '{args.resume}' not found; training from scratch")

    # Similarity monitoring inside the forward pass is disabled; construct a
    # CosineSimilarityMonitor() here to re-enable it.
    monitor = None

    # Report the initial per-Linear-layer similarity between branches (V2).
    _print_branch_similarity(model, rank)

    for epoch in range(start_sft_epoch, args.epochs):
        # Reseed the sampler so every epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        model.train()
        running_loss, total, correct = 0.0, 0, 0

        for step, (imgs, labels) in enumerate(trainloader):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            # Mixup / CutMix: labels become soft targets.
            imgs, labels = mixup_fn(imgs, labels)

            optimizer.zero_grad()
            if args.lambda_policy:
                # Fractional epoch so lambda changes within an epoch as well.
                lambda_off = lambda_scheduler.get_lambda(epoch + step/len(trainloader))
            else:
                lambda_off = args.lambda_off

            # Mixed-precision forward pass. The module is called directly
            # (not via .forward) so registered hooks run.
            with autocast():
                if args.get_diversity_loss:
                    outputs, diversity_loss = model(imgs, lambda_off, get_diversity_loss=True, monitor=monitor)
                else:
                    diversity_loss = torch.tensor(0.0, device=device)
                    outputs = model(imgs, lambda_off, monitor=monitor)
                loss = criterion(outputs, labels)

            loss = loss + args.gamma * diversity_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * imgs.size(0)
            total += labels.size(0)
            _, predicted = outputs.max(1)
            # With soft (2-D) targets, accuracy is measured against the dominant class.
            acc_targets = labels.argmax(dim=1) if labels.ndim == 2 else labels
            correct += predicted.eq(acc_targets).sum().item()

            if step % 20 == 0 and rank == 0:
                print(f"[Step {step}/{len(trainloader)}] Train Loss: {loss.item()}, Div Loss:{diversity_loss.item()}, Lambda: {lambda_off}")

        if rank == 0:
            # Statistics of rank 0's shard only.
            train_acc = correct / total
            print(f"[Epoch {epoch+1}] Train Loss: {running_loss/total:.4f}, Acc: {train_acc:.4f}")
        scheduler.step()
        if epoch % 20 == 0:
            evaluate(model, valloader, device, rank, args.multi_branch)

        if epoch % 10 == 0 and rank == 0:
            _save_checkpoint(args, epoch, model, optimizer, scheduler, scaler)

    evaluate(model, valloader, device, rank, args.multi_branch)
    dist.destroy_process_group()


def evaluate(model, dataloader, device, rank, multi_branch=False):
    """Compute top-1 accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode and is NOT switched back; the caller
    is expected to call ``model.train()`` before resuming training.

    When a process group is initialised, the correct/total counts are summed
    over all ranks, so the reported value covers every rank's shard rather
    than only rank 0's. All ranks must therefore call this function together.

    Args:
        model: callable as ``model(imgs, flag)`` and returning class logits.
        dataloader: iterable of ``(imgs, labels)`` batches with hard labels.
        device: device the batches are moved to.
        rank: rank of the calling process; only rank 0 prints.
        multi_branch: selects the second positional argument of the forward
            call: ``0`` when True, ``1`` otherwise (presumably the
            ``lambda_off`` value used in training - confirm against the model).

    Returns:
        The accuracy as a float in ``[0, 1]``; ``0.0`` if no sample was seen.
    """
    model.eval()
    branch_flag = 0 if multi_branch else 1
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(imgs, branch_flag)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Aggregate over ranks so the number reflects the whole validation set.
    if dist.is_available() and dist.is_initialized():
        counts = torch.tensor([correct, total], dtype=torch.long, device=device)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        correct, total = counts.tolist()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    acc = correct / total if total else 0.0
    if rank == 0:
        print(f"Validation Acc: {acc:.4f}")
    return acc


def main():
    """Parse command-line arguments and spawn one training process per GPU.

    ``MASTER_ADDR`` / ``MASTER_PORT`` are exported before spawning so the
    workers can set up their process group. The parser exits with an error
    when ``--world_size`` is smaller than 1 (e.g. no CUDA device is visible);
    spawning zero workers would otherwise return without doing anything.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='tiny-imagenet')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--warmup_epochs', type=int, default=5, help="linear warmup epochs")

    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    # '--multi_branch' is accepted as an alias so the flag matches the
    # underscore style of the other options; both set args.multi_branch.
    parser.add_argument('--multi-branch', '--multi_branch', action='store_true',
                        help="Fine tuning on baseline multi-branch model")
    parser.add_argument('--lambda_off', type=float, default=0)
    parser.add_argument('--lambda_policy', type=str, default=None, help="Lambda policy, including linear/cosine/exponential/sqrt/sine/smoothstep")
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    parser.add_argument('--warmup_step', type=int, default=30)
    parser.add_argument('--attn_branches', type=int, default=2)
    parser.add_argument('--mlp_branches', type=int, default=2)
    parser.add_argument('--depth', type=int, default=6)
    parser.add_argument('--get_diversity_loss', action='store_true')
    parser.add_argument('--gamma', type=float, default=0.05)

    parser.add_argument('--resume', type=str, default='', help="Path to checkpoint to resume")
    parser.add_argument('--resume_from_sft', action='store_true')
    parser.add_argument('--save', type=str, default='', help="Path to checkpoint to save")
    parser.add_argument('--master_addr', type=str, default='localhost', help="Master node address")
    parser.add_argument('--master_port', type=str, default='12355', help="Master node port")

    args = parser.parse_args()

    if args.world_size < 1:
        # The default is torch.cuda.device_count(), which is 0 without GPUs.
        parser.error("--world_size must be >= 1 (no CUDA devices detected?)")

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    mp.spawn(train, args=(args.world_size, args), nprocs=args.world_size, join=True)


# Script entry point: launch training only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
