import torch
from PureLearner_SUN import PUTrainer
from core.MSDN import MSDN
from core.SUNDataLoader import SUNDataLoader
from global_setting import NFS_path
import numpy as np

# Evaluate / fine-tune a pretrained MSDN model on SUN with PU learning.

# --- Device and data ---------------------------------------------------------
idx_GPU = 0
# Fall back to CPU when CUDA is unavailable. The checkpoint load below maps
# tensors onto this same device, so a GPU-saved checkpoint also loads on CPU.
device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu")
dataloader = SUNDataLoader(NFS_path, device, is_scale=False, is_balance=True)
# cuDNN autotuning: usually faster, but GPU results may be non-deterministic
# even with the seeds set below.
torch.backends.cudnn.benchmark = True

# --- Reproducibility ---------------------------------------------------------
# Seed torch (CPU and all CUDA devices) and numpy. Note this runs after the
# dataloader is constructed, so it only affects randomness from here on.
seed = 2339  # previously tried: 214
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

print('Randomize seed {}'.format(seed))
#%%
# --- Hyperparameters ---------------------------------------------------------
batch_size = 50
nepoches = 70
# Total optimizer iterations implied by epochs and batch size.
niters = dataloader.ntrain * nepoches // batch_size
dim_f = 2048  # visual feature dimension passed to MSDN
dim_v = 300   # attribute word-vector dimension passed to MSDN
init_w2v_att = dataloader.w2v_att
att = dataloader.att
normalize_att = dataloader.normalize_att

trainable_w2v = True
lambda_ = 0.0001
bias = 0.  # NOTE(review): not referenced later in this script
prob_prune = 0
uniform_att_1 = False
uniform_att_2 = True

seenclass = dataloader.seenclasses
unseenclass = dataloader.unseenclasses
desired_mass = 1
# Iterations per epoch. NOTE(review): not referenced later in this script.
report_interval = niters // nepoches

# --- Model -------------------------------------------------------------------
model = MSDN(dim_f, dim_v, init_w2v_att, att, normalize_att,
             seenclass, unseenclass,
             lambda_,
             trainable_w2v, normalize_V=False, normalize_F=True, is_conservative=True,
             uniform_att_1=uniform_att_1, uniform_att_2=uniform_att_2,
             prob_prune=prob_prune, desired_mass=desired_mass, is_conv=False,
             is_bias=True, non_linear_act=False).to(device)

# map_location keeps the load working when the checkpoint was saved on a GPU
# but this run fell back to CPU (or uses a different GPU index).
# strict=False tolerates key mismatches; report them instead of silently
# leaving parameters at their random initialization.
load_result = model.load_state_dict(
    torch.load('SUN_MSDN_GZSL.pth', map_location=device), strict=False)
if load_result.missing_keys:
    print('Warning: parameters missing from checkpoint (left at init): {}'.format(
        load_result.missing_keys))
if load_result.unexpected_keys:
    print('Warning: checkpoint entries not used by the model: {}'.format(
        load_result.unexpected_keys))
model.eval()

# --- PU learning -------------------------------------------------------------
# Initialize the PU-learning trainer.
trainer = PUTrainer(
    model=model,
    dataloader=dataloader,
    device=device,
)

# Train and evaluate.
print("Starting PU learning optimization...")
trainer.train_and_evaluate(epochs=800)
