import torch
from PureLearner_AWA2 import PUTrainer
from core.AWA2DataLoader import AWA2DataLoader
from core.MSDN import MSDN
from global_setting import NFS_path
import numpy as np

# Pick the CUDA device with index `idx_GPU` when CUDA is available,
# otherwise fall back to the CPU.
idx_GPU = 0
if torch.cuda.is_available():
    device = torch.device(f"cuda:{idx_GPU}")
else:
    device = torch.device("cpu")
# The AWA2 loader receives `device`; NFS_path comes from global_setting.
dataloader = AWA2DataLoader(NFS_path, device)
# Let cuDNN auto-tune convolution algorithms for the observed input shapes.
torch.backends.cudnn.benchmark = True


def get_lr(optimizer):
    """Collect the current learning rate of each parameter group.

    Args:
        optimizer: Any object exposing ``param_groups``, an iterable of
            mappings that each contain an ``'lr'`` key (e.g. a
            ``torch.optim.Optimizer``).

    Returns:
        A list with one learning rate per parameter group, in group order.
    """
    return [group['lr'] for group in optimizer.param_groups]


# Fix the NumPy and PyTorch (CPU + all CUDA devices) random generators.
# NOTE(review): seeding happens after the dataloader above is constructed;
# if its constructor draws random numbers, those draws are not covered.
seed = 87778
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- Training schedule ------------------------------------------------------
batch_size = 50
nepoches = 50
# Total optimisation steps for `nepoches` passes over the training split.
niters = (dataloader.ntrain * nepoches) // batch_size

# --- Feature / word-vector dimensionalities ---------------------------------
dim_f, dim_v = 2048, 300

# --- Class attribute information from the loader ----------------------------
init_w2v_att = dataloader.w2v_att
att = dataloader.att
# Clip negative attribute values to zero. NOTE(review): `att` is the same
# object as `dataloader.att` (unless that attribute returns a copy), so this
# in-place assignment also changes what the loader/trainer see.
att[att < 0] = 0
normalize_att = dataloader.normalize_att

# --- MSDN switches ----------------------------------------------------------
trainable_w2v = True
lambda_ = 0.12
bias = 0  # not referenced in the visible part of this script
prob_prune = 0
uniform_att_1 = uniform_att_2 = False

# --- Class splits and reporting cadence -------------------------------------
seenclass, unseenclass = dataloader.seenclasses, dataloader.unseenclasses
desired_mass = 1
report_interval = niters // nepoches  # not referenced in the visible part of this script

# Build the MSDN model from the hyper-parameters above and move it to `device`.
model = MSDN(dim_f, dim_v, init_w2v_att, att, normalize_att,
             seenclass, unseenclass,
             lambda_,
             trainable_w2v, normalize_V=True, normalize_F=True, is_conservative=True,
             uniform_att_1=uniform_att_1, uniform_att_2=uniform_att_2,
             prob_prune=prob_prune, desired_mass=desired_mass, is_conv=False,
             is_bias=True).to(device)

# Restore pretrained weights. `map_location` remaps tensors that were saved on
# a GPU onto `device`, so loading also works when `device` fell back to CPU.
checkpoint_path = 'AWA2_MSDN_GZSL.pth'
load_result = model.load_state_dict(
    torch.load(checkpoint_path, map_location=device), strict=False)
# strict=False tolerates key mismatches; surface them instead of silently
# training from partially-initialised weights.
if load_result.missing_keys or load_result.unexpected_keys:
    print("Warning: checkpoint '{}' does not match the model exactly.".format(checkpoint_path))
    print("  missing keys:", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)
model.eval()

# Initialise the PU-learning trainer around the pretrained model.
trainer = PUTrainer(model=model, dataloader=dataloader, device=device)

# Train and evaluate.
# NOTE(review): the meaning of the returned `w` and `z` is defined by
# PUTrainer.train_and_evaluate — confirm there.
print("Starting PU learning optimization...")
w, z, results = trainer.train_and_evaluate()
