import torch
import torch.nn as nn
import scipy.io as sio
import wandb
from timm import create_model
from torch.utils.data import DataLoader
from Dataset import DataLoader, trainTransform, testTransform
from PureLearner_AWA2 import ZSLViTPUTrainer
from zslvit import ZSLViT
import numpy as np


def _dataset_paths(dataset):
    """Return ``(image_root, data_dir, attr_length)`` for a configured dataset.

    Raises:
        ValueError: if ``dataset`` has no configured paths, so the script fails
            immediately instead of hitting an undefined-name error later on.
    """
    if dataset == 'AWA2':
        root = 'H:/Models/AWA2/Animals_with_Attributes2/JPEGImages/'
        data_dir = f'H:/Models/xlsa17/data/{dataset}'
        attr_length = 85  # number of AWA2 attributes (output width of mlp_g)
        return root, data_dir, attr_length
    raise ValueError(f"Unsupported dataset {dataset!r}; only 'AWA2' is configured")


def _relative_image_paths(image_files, dataset):
    """Strip the absolute prefix from the image paths stored in ``res101.mat``.

    Each entry of the ``loadmat`` cell array is unwrapped with ``[0][0]`` to
    get the path string. Everything up to and including the dataset's image
    directory marker is dropped, so the result is relative to the image root.
    """
    marker = 'JPEGImages/' if dataset == 'AWA2' else 'images/'
    return np.array([im_f[0][0].split(marker)[-1] for im_f in image_files])


def main():
    """Load a saved ZSLViT checkpoint and run PU training/evaluation on AWA2.

    Steps: read the xlsa17 split files, build a class-balanced trainval loader
    and seen/unseen test loaders, rebuild the model (ViT backbone + linear
    attribute head), load the checkpoint, and hand everything to
    ``ZSLViTPUTrainer.train_and_evaluate``.
    """
    # Configuration
    DATASET = 'AWA2'
    gamma = 0.95
    num_workers = 8
    batch_size = 32        # training loader batch size
    test_batch_size = 256  # seen/unseen evaluation loader batch size

    # Fail fast on an unconfigured dataset before touching any files.
    ROOT, DATA_DIR, attr_length = _dataset_paths(DATASET)

    # Logging is disabled; the model hyperparameters used below are read from
    # wandb.config, which this script relies on being populated from the yaml.
    wandb.init(project='VIT-ZSL', config='wandb_config/awa2_zslvit.yaml', mode='disabled')
    args = wandb.config

    # Load the xlsa17 split files: image paths/labels and attribute splits.
    data = sio.loadmat(f'{DATA_DIR}/res101.mat')
    attrs_mat = sio.loadmat(f'{DATA_DIR}/att_splits.mat')
    image_files = _relative_image_paths(data['image_files'], DATASET)

    # Labels and split indices are shifted down by one (MATLAB-style 1-based).
    labels = data['labels'].squeeze().astype(np.int64) - 1
    trainval_idx = attrs_mat['trainval_loc'].squeeze() - 1
    test_seen_idx = attrs_mat['test_seen_loc'].squeeze() - 1
    test_unseen_idx = attrs_mat['test_unseen_loc'].squeeze() - 1

    # Class-attribute matrix, transposed so that rows index classes.
    attrs_mat_numpy = attrs_mat["att"].astype(np.float32).T
    if attrs_mat_numpy.shape[1] != attr_length:
        raise ValueError(
            f"attribute matrix has {attrs_mat_numpy.shape[1]} attributes, "
            f"expected {attr_length} for {DATASET}")

    # Training (trainval) split. Labels are remapped to a contiguous 0..K-1
    # range over the seen classes for the training loader.
    trainval_files = image_files[trainval_idx]
    trainval_labels = labels[trainval_idx]
    uniq_trainval_labels, trainval_labels_based0, counts_trainval_labels = np.unique(
        trainval_labels, return_inverse=True, return_counts=True)

    # Seen and unseen test splits (labels keep their global class indices).
    test_seen_files = image_files[test_seen_idx]
    test_seen_labels = labels[test_seen_idx]
    test_unseen_files = image_files[test_unseen_idx]
    test_unseen_labels = labels[test_unseen_idx]
    uniq_test_unseen_labels = np.unique(test_unseen_labels)

    # NOTE: `DataLoader` here is the project dataset class imported from
    # `Dataset`; it shadows torch's DataLoader imported at the top of the file,
    # so torch's loader is always referenced through its full path.
    trainval_data = DataLoader(ROOT, trainval_files, trainval_labels_based0, transform=trainTransform)
    # Class-balanced sampling: each sample is weighted by 1 / (its class count).
    weights = (1. / counts_trainval_labels)[trainval_labels_based0]
    trainval_sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=trainval_labels_based0.shape[0], replacement=True)
    train_data_loader = torch.utils.data.DataLoader(
        trainval_data, batch_size=batch_size, sampler=trainval_sampler, num_workers=num_workers)

    test_seen_data = DataLoader(ROOT, test_seen_files, test_seen_labels, transform=testTransform)
    test_seen_data_loader = torch.utils.data.DataLoader(
        test_seen_data, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    test_unseen_data = DataLoader(ROOT, test_unseen_files, test_unseen_labels, transform=testTransform)
    test_unseen_data_loader = torch.utils.data.DataLoader(
        test_unseen_data, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    # Load weights on CPU first; the model is moved to the target device below.
    checkpoint = torch.load(f'H:/Models/ZSLViT/saved_model/ZSLViT_{DATASET}_GZSL.pth',
                            map_location=torch.device('cpu'))

    zslvit = create_model(
        args.model,
        base_keep_rate=args.base_keep_rate,
        # SECURITY NOTE(review): eval() executes arbitrary code from the config
        # file. If drop_loc is always a literal (e.g. a tuple of ints),
        # ast.literal_eval is a safer drop-in; confirm against the yaml.
        drop_loc=eval(args.drop_loc),
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        fuse_token=args.fuse_token,
        img_size=(args.input_size, args.input_size),
        dataset=args.dataset
    )

    # Linear head from 768-d backbone features to the attribute space
    # (assumes the backbone's feature width is 768; TODO confirm for args.model).
    mlp_g = nn.Linear(768, attr_length, bias=True)
    model = nn.ModuleDict({
        "vit": zslvit,
        "mlp_g": mlp_g
    })
    # Strict load: checkpoint keys must carry the "vit." / "mlp_g." prefixes.
    model.load_state_dict(checkpoint)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    seenclasses = uniq_trainval_labels        # global ids of seen classes
    unseenclasses = uniq_test_unseen_labels   # global ids of unseen classes
    test_seen_labels = torch.tensor(test_seen_labels)
    test_unseen_labels = torch.tensor(test_unseen_labels)

    dataset_loaders = {
        'train': train_data_loader,
        'test_seen': test_seen_data_loader,
        'test_unseen': test_unseen_data_loader,
        'test_seen_labels': test_seen_labels,
        'test_unseen_labels': test_unseen_labels,
        'seenclasses': seenclasses,
        'unseenclasses': unseenclasses
    }

    trainer = ZSLViTPUTrainer(
        model=model,
        dataset_loaders=dataset_loaders,
        attrs_mat=attrs_mat_numpy,
        device=device
    )

    # NOTE(review): batch_size=64 differs from the train loader's batch size
    # (32); check which value ZSLViTPUTrainer actually uses.
    trainer.train_and_evaluate(dataset=DATASET, epochs=300, batch_size=64, lr=5e-5, gamma=gamma)

# Script entry point. The guard also keeps DataLoader worker processes
# (num_workers > 0 in main) from re-running main() when they re-import this
# module under the 'spawn' start method, which is the default on Windows.
if __name__ == '__main__':
    main()