from PureLearner_CUB import PUTrainer
# from PureLearner_SUN import PUTrainer
# from PureLearner_AWA2 import PUTrainer
from model import TransZero
from dataset import *
import wandb

# Script entry: load a pretrained TransZero GZSL model and optimise it with the
# PU-learning trainer. To switch dataset (CUB / SUN / AWA2), change together:
# the PureLearner_* import at the top of the file, the wandb config file, the
# data loader and the checkpoint path below.
import torch  # imported explicitly instead of relying on `from dataset import *`

# wandb is used here only as a config source; mode="disabled" turns off logging.
wandb.init(project='TransZero', config='wandb_config/cub_gzsl.yaml',mode="disabled")
# wandb.init(project='TransZero', config='wandb_config/sun_gzsl.yaml',mode="disabled")

config = wandb.config
print('Config file from wandb:', config)

dataloader = CUBDataLoaderOptimized('.', config.device, use_cache=True)
# dataloader = SUNDataLoader('.', config.device)
# dataloader = AWA2DataLoader('.', config.device)

model = TransZero(config, dataloader.att, dataloader.w2v_att,
                  dataloader.seenclasses, dataloader.unseenclasses).to(config.device)

# Alternatives: 'TransZero_SUN_GZSL.pth', 'TransZero_AWA2_GZSL.pth'
checkpoint_path = 'TransZero_CUB_GZSL.pth'
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint
# loaded on a CPU-only machine); assumes config.device is a valid torch device
# string such as 'cuda' or 'cpu' — TODO confirm.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(checkpoint_path, map_location=config.device)
# strict=False tolerates key mismatches, but report them so that loading a
# checkpoint from the wrong dataset/architecture does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print('Missing keys when loading checkpoint:', load_result.missing_keys)
if load_result.unexpected_keys:
    print('Unexpected keys in checkpoint:', load_result.unexpected_keys)
model.eval()

# Initialise the PU-learning trainer
trainer = PUTrainer(
    model=model,
    dataloader=dataloader,
    device=config.device,
)

# Train and evaluate
num_epochs = 300
print("Starting PU learning optimization...")
trainer.train_and_evaluate(epochs=num_epochs)