import torch
# from PureLearner_CUB import PSVMAPUTrainer
from PureLearner_AWA2 import PSVMAPUTrainer
# from PureLearner_SUN import PSVMAPUTrainer
from models.modeling import build_gzsl_pipeline
from data import build_dataloader
from models.config import cfg

# Load the AWA2 configuration and build the model from it.
cfg.merge_from_file("config/awa2.yaml")
model = build_gzsl_pipeline(cfg)
model_dict = model.state_dict()

# Alternative checkpoints; when switching, also change the config file above
# and the PureLearner_* import at the top of the file.
# saved_dict = torch.load('H:/PSVMA/checkpoints/best_model_cub.pth')
# saved_dict = torch.load('H:/PSVMA/checkpoints/best_model_sun.pth')
#
# map_location='cpu' makes deserialization independent of the device the
# checkpoint was saved from (e.g. a CUDA checkpoint on a CPU-only host);
# load_state_dict copies the values into the existing parameters and the
# model is moved to the configured device further below.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
saved_dict = torch.load('H:/PSVMA/checkpoints/best_model_awa2.pth', map_location='cpu')
checkpoint_weights = saved_dict['model']

# Keep only the checkpoint entries whose names exist in the current model.
saved_dict = {k: v for k, v in checkpoint_weights.items() if k in model_dict}

# Report partial restores instead of silently continuing: dropped checkpoint
# keys or uncovered model keys usually mean a key-prefix/architecture mismatch.
skipped_count = len(checkpoint_weights) - len(saved_dict)
uncovered_count = len(model_dict) - len(saved_dict)
if skipped_count or uncovered_count:
    print(
        f"[checkpoint] restored {len(saved_dict)}/{len(model_dict)} model tensors; "
        f"{skipped_count} checkpoint entries ignored, "
        f"{uncovered_count} model tensors left at their initial values"
    )

model_dict.update(saved_dict)
model.load_state_dict(model_dict)

# Build the data loaders. build_dataloader returns four values here; their
# exact roles are defined in data.build_dataloader (presumably train /
# unseen-test / seen-test loaders plus extra dataset info - TODO confirm).
tr_dataloader, tu_loader, ts_loader, res = build_dataloader(cfg, is_distributed=False)

# Move the model to the device named in the configuration.
device = torch.device(cfg.MODEL.DEVICE)
model = model.to(device)


if __name__ == '__main__':
    # Script entry point: run the PU-learning optimisation on the restored model.
    print("开始PSVMA模型的PU学习优化")

    # The trainer receives all four objects returned by build_dataloader
    # as a single tuple.
    loaders = (tr_dataloader, tu_loader, ts_loader, res)
    trainer = PSVMAPUTrainer(model=model, dataloader=loaders, device=device)

    # Train and evaluate on AWA2 for 300 epochs; the return value is kept in
    # `results` for inspection.
    results = trainer.train_and_evaluate(datasets="AWA2", epochs=300)