import numpy as np
import torch
from PureLearner_CUB import PUTrainer
from core.MSDN import MSDN
from core.CUBDataLoader import CUBDataLoader,CUBDataLoaderCached

# Root of the raw (uncached) CUB data. Only the uncached loader variant,
# CUBDataLoader(NFS_path, device, is_unsupervised_attr=False, is_balance=False),
# would consume it; the cached loader below does not.
NFS_path = ''

# Pick GPU `idx_GPU` when CUDA is present, otherwise run on the CPU.
idx_GPU = 0
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(idx_GPU))
else:
    device = torch.device("cpu")

# Load the dataset through the on-disk cache under ./Cached/.
dataloader = CUBDataLoaderCached(
    "./Cached/",
    device,
    is_unsupervised_attr=False,
    is_balance=False,
    use_cache=True,
)

# Let cuDNN auto-tune convolution algorithms for the observed input shapes.
torch.backends.cudnn.benchmark = True


def get_lr(optimizer):
    """Return the current learning rate of every param group in *optimizer*.

    Args:
        optimizer: any object exposing ``param_groups`` as an iterable of
            dicts that each carry an ``'lr'`` key (e.g. a
            ``torch.optim.Optimizer``).

    Returns:
        list: one learning rate per param group, in param-group order.
    """
    return [param_group['lr'] for param_group in optimizer.param_groups]


# Seed torch (CPU and all CUDA devices) and NumPy so model initialisation and
# training are repeatable.
# NOTE(review): cudnn.benchmark = True (set above) may still select
# non-deterministic kernels — confirm whether bit-exact runs are required.
seed = 214
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

batch_size = 50
nepoches = 30
# Total number of iterations for `nepoches` passes over the training set.
niters = dataloader.ntrain * nepoches // batch_size
dim_f = 2048
dim_v = 300
init_w2v_att = dataloader.w2v_att
att = dataloader.att
normalize_att = dataloader.normalize_att

trainable_w2v = True
lambda_ = 0.1
bias = 0
prob_prune = 0
uniform_att_1 = False
uniform_att_2 = False

seenclass = dataloader.seenclasses
unseenclass = dataloader.unseenclasses
desired_mass = 1
# Iterations per epoch (niters spread evenly over nepoches).
report_interval = niters // nepoches

model = MSDN(dim_f, dim_v, init_w2v_att, att, normalize_att,
             seenclass, unseenclass,
             lambda_,
             trainable_w2v, normalize_V=False, normalize_F=True, is_conservative=True,
             uniform_att_1=uniform_att_1, uniform_att_2=uniform_att_2,
             prob_prune=prob_prune, desired_mass=desired_mass, is_conv=False,
             is_bias=True).to(device)

# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved from a CUDA run still loads when `device` fell back to the CPU.
# strict=False tolerates key mismatches; report them instead of silently
# ignoring a stale or mismatched checkpoint.
load_result = model.load_state_dict(
    torch.load('CUB_MSDN_GZSL.pth', map_location=device), strict=False
)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Checkpoint key mismatch -- missing: {}, unexpected: {}".format(
        load_result.missing_keys, load_result.unexpected_keys))
# NOTE(review): the model is put in eval mode before training; presumably
# PUTrainer switches modes itself as needed — confirm.
model.eval()

# Initialize the PU-learning trainer.
trainer = PUTrainer(
    model=model,
    dataloader=dataloader,
    device=device,
)

# Train and evaluate.
# NOTE(review): epochs=1000 here is independent of `nepoches` above — confirm
# this is intended.
print("Starting PU learning optimization...")
trainer.train_and_evaluate(epochs=1000)
