import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
from models.vit import get_vit_tiny
from models.vit_multi_branch import get_multibranch_vit
import numpy as np
from tqdm import tqdm


class FeatureExtractor(nn.Module):
    """Expose a ViT's pre-head output by swapping its classification head for a no-op.

    The ``head`` attribute of ``vit_model`` is overwritten in place with
    ``nn.Identity``, so the caller's model object is modified as well.

    Args:
        vit_model: Module whose ``head`` attribute is replaced.
    """

    def __init__(self, vit_model):
        super().__init__()
        # The model object is shared, so replacing the head before registering
        # it as a submodule gives the same result as doing it afterwards.
        vit_model.head = nn.Identity()
        self.model = vit_model

    def forward(self, x):
        """Run the wrapped model; with the head removed this returns the head's input."""
        features = self.model(x)
        return features


class LinearClassifier(nn.Module):
    """Linear probe: a single fully connected layer mapping features to class logits.

    Args:
        input_dim: Size of each input feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # The attribute name ``classifier`` fixes the state_dict keys
        # (``classifier.weight`` / ``classifier.bias``).
        layer = nn.Linear(in_features=input_dim, out_features=num_classes)
        self.classifier = layer

    def forward(self, x):
        """Return ``num_classes`` logits for each feature vector in ``x``."""
        logits = self.classifier(x)
        return logits


def get_imagenet_dataloaders(data_dir, batch_size, rank, world_size, img_size=224, num_workers=4):
    """Build distributed train/val loaders for an ImageNet-style folder dataset.

    Expects ``<data_dir>/train`` and ``<data_dir>/val`` to each hold one
    sub-directory per class (the ``torchvision.datasets.ImageFolder`` layout).

    Args:
        data_dir: Root directory containing ``train`` and ``val``.
        batch_size: Per-process batch size.
        rank: Rank of this process; selects its DistributedSampler shard.
        world_size: Total number of processes.
        img_size: Side length that images are resized to.
        num_workers: DataLoader worker processes per loader.

    Returns:
        ``(trainloader, valloader, train_sampler, num_classes)``, where
        ``num_classes`` is the number of class folders under ``train``
        (1000 for ImageNet-1k).

    Raises:
        ValueError: If ``train`` and ``val`` do not contain the same class folders.
    """
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    # Padded random crop after resizing to the final size (CIFAR-style), plus a random flip.
    transform_train = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_val = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    trainset = datasets.ImageFolder(train_dir, transform=transform_train)
    valset = datasets.ImageFolder(val_dir, transform=transform_val)

    # ImageFolder assigns integer labels from the sorted folder names. If the
    # two splits have different class folders, their labels silently disagree.
    if trainset.class_to_idx != valset.class_to_idx:
        raise ValueError(
            f"Class folders differ between {train_dir} ({len(trainset.classes)} classes) "
            f"and {val_dir} ({len(valset.classes)} classes)"
        )

    # With the default drop_last=False, DistributedSampler pads every shard to
    # the same length by repeating samples. Every rank therefore runs the same
    # number of batches, and a few samples are seen twice.
    train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank, shuffle=False)

    trainloader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                             num_workers=num_workers, pin_memory=True)
    valloader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                           num_workers=num_workers, pin_memory=True)

    return trainloader, valloader, train_sampler, len(trainset.classes)


def _remap_encoder_state_dict(state_dict):
    """Select the ViT encoder weights from a checkpoint's ``model`` state dict.

    Keeps keys of the form ``module.model.<part>...`` where ``<part>`` is
    ``patch_embed``, ``blocks``, ``cls_token``, ``pos_embed`` or ``norm``, and
    strips the ``module.model.`` prefix. Every other key (for example a
    classification head) is dropped.
    """
    prefix = "module.model."
    encoder_prefixes = tuple(
        prefix + part for part in ("patch_embed", "blocks", "cls_token", "pos_embed", "norm")
    )
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(encoder_prefixes)
    }


def _extract_split(feature_extractor, loader, device, rank, world_size):
    """Compute features for one data split and gather them onto rank 0.

    Every rank must call this, because each batch issues ``all_gather``.

    Returns:
        ``(features, labels)`` as concatenated CPU tensors on rank 0, and
        ``(None, None)`` on every other rank.
    """
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, disable=(rank != 0)):
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # NCCL collectives need contiguous tensors. A ViT that pools with
            # x[:, 0] can return a strided view (depends on the model);
            # .contiguous() is a no-op otherwise.
            features = feature_extractor(imgs).contiguous()

            # all_gather needs identical shapes on every rank. DistributedSampler
            # (drop_last=False) pads every rank to the same sample count, so the
            # per-batch sizes match across ranks.
            gathered_features = [torch.empty_like(features) for _ in range(world_size)]
            gathered_labels = [torch.empty_like(labels) for _ in range(world_size)]
            dist.all_gather(gathered_features, features)
            dist.all_gather(gathered_labels, labels)

            if rank == 0:  # only rank 0 keeps the results
                feature_chunks.extend(feat.cpu() for feat in gathered_features)
                label_chunks.extend(lab.cpu() for lab in gathered_labels)

    if rank != 0:
        return None, None
    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)


def extract_features(rank, world_size, args):
    """Stage 1: encode ImageNet with the ViT encoder and save the features.

    Runs in one spawned process per GPU. Each rank encodes its
    DistributedSampler shard and the results are gathered onto rank 0. Rank 0
    then writes ``<feature_dir>/features.pth`` with the keys
    ``train_features``, ``train_labels``, ``val_features``, ``val_labels``,
    ``feature_dim`` and ``num_classes``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Number of processes in the group.
        args: Parsed CLI namespace. Reads ``data_dir``, ``batch_size``,
            ``img_size``, ``parallel``, ``resume`` and ``feature_dir``.

    Raises:
        FileNotFoundError: If ``args.resume`` is set but is not a file.
        RuntimeError: If no encoder weights were received from rank 0.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    if rank == 0:
        print("=== Stage 1: Feature Extraction ===")

    # NOTE(review): the train loader applies random crop/flip. Each stored
    # training feature therefore comes from one random augmented view.
    trainloader, valloader, _, num_classes = get_imagenet_dataloaders(
        args.data_dir, args.batch_size, rank, world_size, args.img_size
    )

    if args.parallel:
        model = get_multibranch_vit(num_classes=num_classes)
    else:
        model = get_vit_tiny(num_classes=num_classes, pretrained=False)

    if args.resume:
        # Every rank checks the path before any collective, so all ranks fail
        # together. Skipping a mistyped path would silently produce features
        # from a randomly initialised encoder.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(f"Checkpoint not found: {args.resume}")

        # Rank 0 reads the checkpoint and broadcasts only the encoder weights,
        # not the whole checkpoint object.
        encoder_state_dict = None
        if rank == 0:
            checkpoint = torch.load(args.resume, map_location="cpu")
            encoder_state_dict = _remap_encoder_state_dict(checkpoint["model"])
            del checkpoint

        obj_list = [encoder_state_dict]
        dist.broadcast_object_list(obj_list, src=0)
        encoder_state_dict = obj_list[0]

        if encoder_state_dict is None:
            raise RuntimeError("Failed to load checkpoint")

        # strict=False: the filtered dict holds only encoder weights, so model
        # keys outside it (e.g. the head) are reported in `missing`.
        missing, unexpected = model.load_state_dict(encoder_state_dict, strict=False)
        if rank == 0:
            print(f"Loaded checkpoint '{args.resume}' (encoder only)")
            if missing:
                print("Missing keys:", missing)
            if unexpected:
                print("Unexpected keys:", unexpected)
    elif rank == 0:
        print("Warning: no --resume checkpoint given; extracting features from a randomly initialised encoder")

    # DDP broadcasts rank 0's parameters at construction, so every rank encodes
    # with identical weights.
    feature_extractor = FeatureExtractor(model).to(device)
    feature_extractor = nn.parallel.DistributedDataParallel(feature_extractor, device_ids=[rank])
    feature_extractor.eval()

    if rank == 0:
        print("Extracting training features...")
    train_features, train_labels = _extract_split(feature_extractor, trainloader, device, rank, world_size)

    if rank == 0:
        print("Extracting validation features...")
    val_features, val_labels = _extract_split(feature_extractor, valloader, device, rank, world_size)

    if rank == 0:
        print(f"Train features shape: {train_features.shape}")
        print(f"Val features shape: {val_features.shape}")

        os.makedirs(args.feature_dir, exist_ok=True)
        torch.save({
            'train_features': train_features,
            'train_labels': train_labels,
            'val_features': val_features,
            'val_labels': val_labels,
            'feature_dim': train_features.shape[1],
            'num_classes': num_classes,
        }, os.path.join(args.feature_dir, 'features.pth'))

        print("Features saved successfully!")

    dist.destroy_process_group()


def train_linear_classifier(rank, world_size, args):
    """Stage 2: train a linear classifier on the features saved by stage 1.

    Runs in one spawned process per GPU. Every rank loads the full feature
    file, and ``DistributedSampler`` gives each rank its own shard, so all
    ranks take part in every DDP step.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Number of processes in the group.
        args: Parsed CLI namespace. Reads ``feature_dir``,
            ``classifier_batch_size``, ``classifier_lr`` and ``classifier_epochs``.

    Raises:
        FileNotFoundError: If ``<feature_dir>/features.pth`` does not exist.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    if rank == 0:
        print("=== Stage 2: Linear Classifier Training ===")

    feature_path = os.path.join(args.feature_dir, 'features.pth')
    if not os.path.exists(feature_path):
        raise FileNotFoundError(f"Features not found at {feature_path}. Run feature extraction first.")

    # Every rank loads the whole file; the existence check above already
    # requires every rank to see it. Each rank needs the real dataset so its
    # DistributedSampler shard is non-empty. If only rank 0 held data, its DDP
    # gradient all-reduce would have no peers and training would hang.
    # Cost: one CPU copy of the features per rank.
    data = torch.load(feature_path, map_location='cpu')
    feature_dim = data['feature_dim']
    # Feature files without a 'num_classes' entry fall back to ImageNet-1k.
    num_classes = data.get('num_classes', 1000)
    if rank == 0:
        print(f"Loaded features with dimension: {feature_dim}")

    train_dataset = TensorDataset(data['train_features'], data['train_labels'])
    val_dataset = TensorDataset(data['val_features'], data['val_labels'])

    # With the default drop_last=False every shard is padded to the same
    # length. All ranks therefore run the same number of batches, which keeps
    # the DDP steps in lockstep.
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=args.classifier_batch_size,
                              sampler=train_sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.classifier_batch_size,
                            sampler=val_sampler, num_workers=2, pin_memory=True)

    classifier = LinearClassifier(feature_dim, num_classes).to(device)
    classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[rank])

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(classifier.parameters(), lr=args.classifier_lr,
                          momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.classifier_epochs)

    best_acc = 0.0
    for epoch in range(args.classifier_epochs):
        train_sampler.set_epoch(epoch)  # reshuffle the shards every epoch
        classifier.train()

        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (features, labels) in enumerate(train_loader):
            features = features.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * features.size(0)
            total += labels.size(0)
            correct += outputs.argmax(dim=1).eq(labels).sum().item()

            if batch_idx % 100 == 0 and rank == 0:
                print(f"Epoch [{epoch + 1}/{args.classifier_epochs}] "
                      f"Batch [{batch_idx}/{len(train_loader)}] Loss: {loss.item():.4f}")

        # Sum the per-rank statistics so the log covers the whole training set,
        # not just rank 0's shard. This is a collective: every rank must reach it.
        epoch_stats = torch.tensor([running_loss, correct, total], dtype=torch.float64, device=device)
        dist.all_reduce(epoch_stats, op=dist.ReduceOp.SUM)
        running_loss, correct, total = epoch_stats.tolist()

        if rank == 0 and total > 0:
            train_acc = correct / total
            print(f"Epoch [{epoch + 1}/{args.classifier_epochs}] "
                  f"Train Loss: {running_loss / total:.4f}, Train Acc: {train_acc:.4f}")

        scheduler.step()

        if (epoch + 1) % 5 == 0 or epoch == args.classifier_epochs - 1:
            # evaluate_classifier all-reduces its counts, so val_acc is the same
            # on every rank and best_acc stays consistent across ranks.
            val_acc = evaluate_classifier(classifier, val_loader, device, rank)
            # Track the best accuracy (no checkpoint is written) so the final
            # summary reports a real value.
            if val_acc > best_acc:
                best_acc = val_acc
                if rank == 0:
                    print(f"New best validation accuracy: {best_acc:.4f}")

    if rank == 0:
        print(f"Training completed! Best validation accuracy: {best_acc:.4f}")

    dist.destroy_process_group()


def evaluate_classifier(model, dataloader, device, rank):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, aggregated over all ranks.

    Each rank scores the batches it receives. The sample and hit counts are
    summed with ``dist.all_reduce``, so every rank returns the same value.
    This is a collective: every rank in the process group must call it.

    Args:
        model: Classifier to evaluate; it is put into (and left in) eval mode.
        dataloader: Yields ``(features, labels)`` batches.
        device: Device the batches are moved to.
        rank: Rank of this process; only rank 0 prints the accuracy.

    Returns:
        Accuracy as a float in [0, 1], or 0.0 if no rank saw any sample.
    """
    model.eval()
    n_seen, n_hit = 0, 0

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            if not batch_x.size(0):  # nothing to score in an empty batch
                continue

            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)

            top1 = model(batch_x).max(dim=1).indices
            n_seen += batch_y.size(0)
            n_hit += (top1 == batch_y).sum().item()

    # One collective for both counters: [samples seen, correct predictions].
    counts = torch.tensor([n_seen, n_hit], device=device)
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    global_seen, global_hit = counts.tolist()

    acc = global_hit / global_seen if global_seen > 0 else 0.0

    if rank == 0:
        print(f"Validation Accuracy: {acc:.4f}")

    return acc


def main():
    """Parse command-line options and launch the requested stage(s).

    Each stage is started with ``mp.spawn`` as ``args.world_size`` processes,
    one per GPU. ``--mode all`` runs feature extraction and then classifier
    training.
    """
    parser = argparse.ArgumentParser()

    # Data options
    parser.add_argument('--data_dir', type=str, required=True, help="ImageNet dataset directory")
    parser.add_argument('--feature_dir', type=str, default='./features', help="Directory to save/load features")
    parser.add_argument('--img_size', type=int, default=224, help="Input image size")

    # Model options
    parser.add_argument('--parallel', action='store_true', help="Use parallel ViT model")
    parser.add_argument('--resume', type=str, default='', help="Path to checkpoint to resume")

    # Feature-extraction options
    parser.add_argument('--batch_size', type=int, default=64, help="Batch size for feature extraction")

    # Classifier-training options
    parser.add_argument('--classifier_epochs', type=int, default=50, help="Epochs for linear classifier training")
    parser.add_argument('--classifier_batch_size', type=int, default=1024, help="Batch size for classifier training")
    parser.add_argument('--classifier_lr', type=float, default=1e-2, help="Learning rate for classifier")

    # Distributed options
    parser.add_argument('--world_size', type=int, default=torch.cuda.device_count())
    parser.add_argument('--master_addr', type=str, default='localhost', help="Master node address")
    parser.add_argument('--master_port', type=str, default='12355', help="Master node port")

    # Run mode
    parser.add_argument('--mode', type=str, default='all', choices=['extract', 'train', 'all'],
                        help="Run mode: extract features, train classifier, or both")

    args = parser.parse_args()

    # Each spawned rank binds to cuda:<rank>, so every process needs its own
    # visible GPU. A world_size of 0 (the default on a CPU-only machine) would
    # make mp.spawn start no workers, and nothing would run.
    num_gpus = torch.cuda.device_count()
    if not 1 <= args.world_size <= num_gpus:
        parser.error(f"--world_size must be between 1 and the number of visible GPUs "
                     f"({num_gpus}), got {args.world_size}")

    # Rendezvous address used by dist.init_process_group in the spawned workers.
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    if args.mode in ['extract', 'all']:
        print("Starting feature extraction...")
        mp.spawn(extract_features, args=(args.world_size, args), nprocs=args.world_size, join=True)

    if args.mode in ['train', 'all']:
        print("Starting linear classifier training...")
        mp.spawn(train_linear_classifier, args=(args.world_size, args), nprocs=args.world_size, join=True)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()