import copy
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

import ramps
from hmix_args import process_args
from hmix_image_dataset import load_image_dataset
from nnPU_loss2 import PNCEloss, PUCEloss, ABSPUloss, DistPUloss, FOPULoss, PULBloss, BalancedPUCEloss, BalancedPULBloss, \
                       PULBloss2, BalancedPUCEloss2, BalancedPULBloss2, \
                       AdaptivePULoss, CurvatureAdjustedPULoss, AdaptiveSmoothPULoss, AdaptiveRiskPULoss, CompositePULoss, \
                       AdaptiveWeightPULoss, InstanceWeightedPULoss
from statistic import test
from hmix_model import MultiLayerPerceptron
from model.alexnet import alexnet
from model.cnn7 import cnn_cifar, cnn_stl
from model.lenet5 import lenet_fmnist, LeNet, LeNet_FMNIST
from model.loss import loss_entropy, loss_ft
from model.ResNet_Zoo import ResNet18, ResNet50
from utils.misc import multi_class_accuracy

# Parse the run configuration once at import time; the rest of this module
# reads it through the module-level `args` object.
args = process_args()
# Restrict the CUDA devices visible to this process to the value of args.gpu.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_torch(seed=args.seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Seed value; defaults to ``args.seed`` as evaluated when this
            function is defined.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The seeding calls are independent of each other, so apply them in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.random.manual_seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


# Seed all RNGs at import time, before any model or data loader is created.
seed_torch()


def select_loss(loss_name, gamma=1, tau=0.5, eta=1.0, pi=0.5, a0=2.0, b0=1.0, b1=0.8, k1=8.0, k2=8.0, t=0.5):
    """Return the margin-based surrogate loss registered under ``loss_name``.

    Each loss is a callable that maps a tensor of margins ``x`` to an
    element-wise loss tensor of the same shape.

    Args:
        loss_name: Key of the loss, e.g. ``"sigmoid"``, ``"hinge"``, ``"mine"``.
        gamma: Slope used by ``"gsigmoid"``.
        tau: Asymmetry parameter used by ``"pinball"``.
        eta: Scale parameter used by ``"rescaled-hinge"``.
        pi: Class prior fed to ``calc_alpha`` / ``calc_beta`` by ``"mine"``.
        a0, k1: Forwarded to ``calc_alpha`` by ``"mine"``.
        b0, b1, k2, t: Forwarded to ``calc_beta`` by ``"mine"``.

    Returns:
        The selected loss callable.

    Raises:
        KeyError: If ``loss_name`` is not a registered loss.
    """
    losses = {
        # ===== Common classification losses =====
        "zero-one": lambda x: -0.5 * torch.sign(x) + 0.5,
        "hinge": lambda x: torch.clamp(1 - x, min=0),
        "exponential": lambda x: torch.exp(-x),
        "perceptron": lambda x: torch.clamp(-x, min=0),
        "logistic": lambda x: F.softplus(-x),
        "sigmoid": lambda x: torch.sigmoid(-x),

        # ===== Variants =====
        # Squared version of the hinge loss.
        "squared-hinge": lambda x: torch.clamp(1 - x, min=0) ** 2,
        # Smoothed hinge: linear for x < 0, quadratic on [0, 1), zero beyond.
        "smooth-hinge": lambda x: torch.where(x < 0, 0.5 - x,
                                              torch.where(x < 1, 0.5 * (1 - x) ** 2, torch.zeros_like(x))),
        # Modified Huber scaled by 1/4: quadratic hinge for x >= -1, linear
        # below. The quadratic part must use the hinge term (1 - x) so that
        # both pieces meet at x = -1 (value 1); clamping -x instead produced
        # a jump from 0.25 to 1 at that point.
        "modified-huber": lambda x: torch.where(x >= -1, 0.25 * torch.clamp(1 - x, min=0) ** 2, -x),
        "squared": lambda x: (1 - x) ** 2,

        # ===== Robust losses =====
        "unhinged": lambda x: 1 - x,
        "tangent": lambda x: (2 * torch.atan(x) - 1) ** 2,
        "savage": lambda x: 1 / (1 + torch.exp(x)) ** 2,
        # Logistic loss re-weighted by (1 - sigmoid(x)) ** gamma_focal.
        "focal": lambda x, gamma_focal=2.0: (1 - torch.sigmoid(x)) ** gamma_focal * F.softplus(-x),
        # Sigmoid loss with a slope controlled by gamma.
        "gsigmoid": lambda x: torch.sigmoid(-gamma * x),
        # (1 - x) / 2 clipped to [0, 1].
        "ramp": lambda x: torch.clamp((1 - x) / 2, min=0, max=1),

        # ===== Special designs =====
        "pinball": lambda x: torch.max(1 - x, -tau * (1 - x)),
        "rescaled-hinge": lambda x: eta * (1 - torch.exp(-eta * torch.clamp(1 - x, min=0))) / (1 - torch.exp(-eta)),
        "double-hinge": lambda x: torch.max(torch.stack([
            torch.zeros_like(x), (1 - x) / 2, -x
        ]), dim=0).values,

        # ===== Custom prior-adaptive loss =====
        # 1 - exp(-alpha(pi) * hinge(x) ** beta(pi)); alpha and beta are
        # evaluated by the module-level calc_alpha / calc_beta helpers.
        "mine": lambda x: 1 - torch.exp(
            -calc_alpha(pi, a0, k1) *
            torch.clamp(1 - x, min=0) ** calc_beta(pi, b0, b1, k2, t)
        ),
    }

    try:
        return losses[loss_name]
    except KeyError:
        raise KeyError(
            "unknown loss {!r}; available losses: {}".format(loss_name, ", ".join(sorted(losses)))
        ) from None


def calc_alpha(pi, a0=2.0, k1=8.0):
    """Compute alpha(pi) = a0 / (1 + exp(k1 * (pi - 0.5))).

    Args:
        pi: Class prior.
        a0: Numerator of the expression.
        k1: Steepness of the exponential around pi = 0.5.

    Returns:
        A scalar float tensor holding alpha(pi).
    """
    # Work on a scalar tensor so the torch ops below apply.
    prior = torch.tensor(pi, dtype=torch.float)
    shifted = prior - 0.5
    denominator = 1 + torch.exp(k1 * shifted)
    return a0 / denominator

def calc_beta(pi, b0=1.0, b1=0.8, k2=8.0, t=0.5):
    """Compute beta(pi) = b0 + b1 * sigmoid(k2 * (pi - t)).

    Args:
        pi: Class prior.
        b0: Constant offset.
        b1: Weight of the sigmoid term.
        k2: Steepness of the sigmoid.
        t: Value of pi at which the sigmoid is centred.

    Returns:
        A scalar float tensor holding beta(pi).
    """
    # Work on a scalar tensor so the torch ops below apply.
    prior = torch.tensor(pi, dtype=torch.float)
    gate = torch.sigmoid(k2 * (prior - t))
    return b0 + b1 * gate


def select_loss_function(method_name, loss_type):
    """Build the PU risk estimator registered under ``method_name``.

    Only the requested estimator is constructed. Each entry is a zero-argument
    factory, so selecting one method neither instantiates the other loss
    classes nor reads the ``args`` attributes that only they need.

    Args:
        method_name: Key of the PU method, e.g. ``"uPU"``, ``"nnPU-out"``,
            ``"PULB"``.
        loss_type: Surrogate loss forwarded to the estimator as ``loss``.

    Returns:
        The constructed loss object for ``method_name``.

    Raises:
        KeyError: If ``method_name`` is not a registered method.
    """
    def risk_kwargs():
        # Keyword arguments shared by the gamma/beta-parameterised estimators.
        return {"loss": loss_type, "gamma": args.gamma, "beta": args.beta}

    def shape_kwargs():
        # Keyword arguments shared by the a0/b0/b1/k1/k2-parameterised estimators.
        return {"loss": loss_type, "a0": args.a0, "b0": args.b0, "b1": args.b1,
                "k1": args.k1, "k2": args.k2}

    factories = {
        # Original author note: objective=False/True makes no difference for uPU.
        "uPU": lambda: PUCEloss(
            args.prior, nnpu=False, objective=False, **risk_kwargs()),
        "nnPU-objective": lambda: PUCEloss(
            args.prior, nnpu=True, objective=True, **risk_kwargs()),
        "nnPU-out": lambda: PUCEloss(
            args.prior, nnpu=True, objective=False, **risk_kwargs()),
        # Original author note: objective=False/True makes no difference for absPU.
        "absPU": lambda: ABSPUloss(
            args.prior, objective=False, **risk_kwargs()),
        "DistPU-brief": lambda: DistPUloss(
            args.prior, **risk_kwargs()),
        "FOPU": lambda: FOPULoss(
            args.prior, objective=False, lam_f=args.lam_f, **risk_kwargs()),
        "PULB": lambda: PULBloss(
            args.prior, objective=False, momentum=args.momentum, **risk_kwargs()),
        "BalancePU": lambda: BalancedPUCEloss(
            args.prior, objective=False, balance_weight=args.balance_weight, **risk_kwargs()),
        "PULB+BalancePU": lambda: BalancedPULBloss(
            args.prior, objective=False, momentum=args.momentum,
            balance_weight=args.balance_weight, **risk_kwargs()),
        "PULB2": lambda: PULBloss2(
            args.prior, objective=True, momentum=args.momentum, **risk_kwargs()),
        "BalancePU2": lambda: BalancedPUCEloss2(
            args.prior, objective=True, balance_weight=args.balance_weight, **risk_kwargs()),
        "PULB2+BalancePU2": lambda: BalancedPULBloss2(
            args.prior, objective=True, momentum=args.momentum,
            balance_weight=args.balance_weight, **risk_kwargs()),
        "AdaptivePU": lambda: AdaptivePULoss(
            args.prior, **shape_kwargs()),
        "CurvatureAdjustedPU": lambda: CurvatureAdjustedPULoss(
            args.prior, loss=loss_type),
        "AdaptiveSmoothPU": lambda: AdaptiveSmoothPULoss(
            args.prior, objective=False, **shape_kwargs()),
        "AdaptiveRiskPU": lambda: AdaptiveRiskPULoss(
            args.prior, objective=False, **shape_kwargs()),
        "CompositePU": lambda: CompositePULoss(
            args.prior, objective=False, **shape_kwargs()),
        'AdaptiveWeightPU': lambda: AdaptiveWeightPULoss(
            args.prior, loss=loss_type, objective=False),
        'InstanceWeightedPU': lambda: InstanceWeightedPULoss(
            args.prior, loss=loss_type, objective=False),
    }
    # The lookup raises KeyError for unknown methods before anything is built.
    return factories[method_name]()


def select_model(model_name):
    """Look up the model constructor registered under ``model_name``.

    Args:
        model_name: One of ``"cnn_cifar"``, ``"cnn_stl"``, ``"lenet5"``,
            ``"resnet50"`` or ``"mlp"``.

    Returns:
        The model class/factory (not an instance).

    Raises:
        KeyError: If ``model_name`` is not registered.
    """
    registry = dict(
        cnn_cifar=cnn_cifar,
        cnn_stl=cnn_stl,
        lenet5=LeNet,
        resnet50=ResNet50,
        mlp=MultiLayerPerceptron,
    )
    return registry[model_name]

def make_optimizer(model, optim, stepsize):
    """Create the optimizer named ``optim`` for ``model``.

    Args:
        model: Module whose ``parameters()`` are optimised.
        optim: One of ``"adam"``, ``"adagrad"``, ``"adamw"``, ``"sgd"`` or
            ``"adadelta"``.
        stepsize: Learning rate.

    Returns:
        The configured ``torch.optim`` optimizer. All optimizers except
        Adadelta use ``args.weight_decay``.

    Raises:
        ValueError: If ``optim`` is not a supported optimizer name
            (previously this surfaced as an UnboundLocalError).
    """
    if optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=stepsize,
                                     betas=(0.5, 0.99),
                                     weight_decay=args.weight_decay)
    elif optim == "adagrad":
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=stepsize,
                                        weight_decay=args.weight_decay)
    elif optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=stepsize,
                                      weight_decay=args.weight_decay)
    elif optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=stepsize,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif optim == "adadelta":
        # Adadelta is created without weight decay, as in the original setup.
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         lr=stepsize,
                                         rho=0.95)
    else:
        raise ValueError(
            "unsupported optimizer {!r}; expected one of: "
            "adam, adagrad, adamw, sgd, adadelta".format(optim))
    return optimizer


def make_scheduler(optim):
    """Wrap ``optim`` in a MultiStepLR schedule.

    The milestones and decay factor come from ``args.milestones`` and
    ``args.scheduler_gamma``.

    Args:
        optim: Optimizer whose learning rate is scheduled.

    Returns:
        The ``MultiStepLR`` scheduler.
    """
    return MultiStepLR(optim, milestones=args.milestones, gamma=args.scheduler_gamma)


class trainer():
    """Trains one or more PU models and collects per-epoch metrics and losses.

    ``models``, ``loss_funcs``, ``optimizers`` and ``schedulers`` are dicts
    that are iterated in lock-step, so they are expected to hold matching
    entries in the same order.
    """

    # Methods whose loss takes the running bound as an extra positional
    # argument and returns the updated bound as an extra trailing value.
    _BOUND_METHODS = ("PULB", "PULB+BalancePU", "PULB2", "PULB2+BalancePU2")
    # Methods whose loss is called with epoch / total_epochs during training.
    _EPOCH_METHODS = ("AdaptiveRiskPU", "CompositePU", "InstanceWeightedPU")

    def __init__(self, models, loss_funcs, optimizers, schedulers,
                 XYPtrainLoader, XYUtrainLoader, XYvalidLoader, XYtestLoader,
                 prior, loss_func_pn):
        """Store the models, their training components and the data loaders.

        Args:
            models: Dict mapping model name to network.
            loss_funcs: Dict of PU loss objects (only the values are kept).
            optimizers: Dict of optimizers (only the values are kept).
            schedulers: Dict of LR schedulers (only the values are kept).
            XYPtrainLoader: Loader yielding (data, label) positive batches.
            XYUtrainLoader: Loader yielding (data, label) unlabeled batches.
            XYvalidLoader: Validation loader (stored; not read by train/run).
            XYtestLoader: Test loader passed to ``test``.
            prior: Class prior (stored).
            loss_func_pn: Loss passed to ``test`` for evaluation.
        """
        self.models = models
        self.loss_funcs = loss_funcs.values()
        self.optimizers = optimizers.values()
        self.schedulers = schedulers.values()

        self.XYPtrainLoader = XYPtrainLoader
        self.XYUtrainLoader = XYUtrainLoader
        self.XYvalidLoader = XYvalidLoader
        self.XYtestLoader = XYtestLoader
        self.prior = prior
        self.loss_func_pn = loss_func_pn

        # Running bound, only used by the methods in _BOUND_METHODS.
        self.bound = 0

    def train(self, epoch):
        """Run one epoch (``args.val_iterations`` steps) for every model.

        Args:
            epoch: Zero-based epoch index, forwarded to epoch-aware losses.

        Returns:
            Tuple ``(results_overall_val, results_class_val,
            results_overall_test, results_class_test, loss_all)`` of dicts
            keyed by model name. The "val" metrics are computed from the
            predictions made on this epoch's training batches against the
            labels supplied by the loaders; the test metrics come from
            ``test`` on the test loader. ``loss_all`` holds, in order:
            test_loss, test_loss_p, test_loss_n, train_loss, train_loss_p,
            train_loss_u, train_loss_p_pos, train_loss_u_neg, train_loss_p_neg.
        """
        results_overall_val = {name: [] for name in self.models.keys()}
        results_class_val = {name: [] for name in self.models.keys()}
        results_overall_test = {name: [] for name in self.models.keys()}
        results_class_test = {name: [] for name in self.models.keys()}
        loss_all = {name: [] for name in self.models.keys()}

        uses_bound = args.method in self._BOUND_METHODS
        uses_epoch = args.method in self._EPOCH_METHODS

        for net, opt, scheduler, loss_func, model_name in zip(
                self.models.values(), self.optimizers, self.schedulers,
                self.loss_funcs, self.models.keys()):
            XYUtrainLoader_iter = iter(self.XYUtrainLoader)
            XYPtrainLoader_iter = iter(self.XYPtrainLoader)
            net.train()

            # Raw network outputs have one column for a single-output model
            # and two columns otherwise.
            n_columns = 1 if net.num_classifier == 1 else 2
            outputs_p_all = np.array([]).reshape(0, n_columns)
            outputs_u_all = np.array([]).reshape(0, n_columns)
            targets_p_all = np.array([])
            targets_u_all = np.array([])

            targets_all = np.array([])
            predicts_all = np.array([])
            probs_all = np.array([])

            for _ in range(args.val_iterations):
                # Load positive data; restart the loader once it is exhausted.
                # Only StopIteration is handled so that genuine loader errors
                # are not masked by a silent retry.
                try:
                    data_p, t_p = next(XYPtrainLoader_iter)
                except StopIteration:
                    XYPtrainLoader_iter = iter(self.XYPtrainLoader)
                    data_p, t_p = next(XYPtrainLoader_iter)

                # Load unlabeled data in the same way.
                try:
                    data_u, t_u = next(XYUtrainLoader_iter)
                except StopIteration:
                    XYUtrainLoader_iter = iter(self.XYUtrainLoader)
                    data_u, t_u = next(XYUtrainLoader_iter)

                # PU targets: +1 for the positive batch, -1 for the unlabeled one.
                target_p = torch.ones(data_p.shape[0]).to(device,
                                                          non_blocking=True)
                target_u = - torch.ones(data_u.shape[0]).to(device,
                                                           non_blocking=True)

                # Labels supplied by the loaders, kept for the metrics below.
                t = torch.cat((t_p, t_u), dim=0)
                targets_all = np.hstack((targets_all, t.detach().cpu().numpy()))

                data_p = data_p.to(device, non_blocking=True)
                data_u = data_u.to(device, non_blocking=True)

                data = torch.cat((data_p, data_u), dim=0)
                idx_p = slice(0, len(data_p))
                idx_u = slice(len(data_p), len(data))

                output = net(data)

                size = len(t)
                p = np.reshape(torch.sigmoid(output).detach().cpu().numpy(), size)
                probs_all = np.hstack((probs_all, p))
                # Threshold the sigmoid probability at 0.5 to get +1/-1 predictions.
                o = np.where(p > 0.5, 1, -1)
                predicts_all = np.hstack((predicts_all, o))

                if uses_bound:
                    loss, risk, risk_p, risk_u, risk_p_pos, risk_u_neg, risk_p_neg, current_bound = loss_func(
                        output[idx_p], output[idx_u], target_p, target_u, self.bound)
                    self.bound = current_bound
                elif uses_epoch:
                    loss, risk, risk_p, risk_u, risk_p_pos, risk_u_neg, risk_p_neg = loss_func(
                        output[idx_p], output[idx_u], target_p, target_u, epoch=epoch, total_epochs=args.epoch)
                else:
                    loss, risk, risk_p, risk_u, risk_p_pos, risk_u_neg, risk_p_neg = loss_func(
                        output[idx_p], output[idx_u], target_p, target_u)

                # Keep the raw outputs/targets to recompute the epoch-level
                # training risk after the last step.
                outputs_p_all = np.vstack((outputs_p_all, output[idx_p].detach().cpu().numpy()))
                outputs_u_all = np.vstack((outputs_u_all, output[idx_u].detach().cpu().numpy()))
                targets_p_all = np.hstack((targets_p_all, target_p.detach().cpu().numpy()))
                targets_u_all = np.hstack((targets_u_all, target_u.detach().cpu().numpy()))

                opt.zero_grad()  # clear gradients from the previous step
                loss.backward()  # backpropagation, compute gradients
                opt.step()  # apply gradients

            # Advance the LR schedule once per epoch, after the optimizer
            # steps. Without this call the configured milestones never take
            # effect because the scheduler is not stepped anywhere else.
            scheduler.step()

            outputs_p_all = torch.from_numpy(outputs_p_all).to(device)
            outputs_u_all = torch.from_numpy(outputs_u_all).to(device)
            targets_p_all = torch.from_numpy(targets_p_all).to(device)
            targets_u_all = torch.from_numpy(targets_u_all).to(device)

            # Epoch-level training risk over all outputs gathered above.
            # NOTE(review): epoch-aware losses are called here without
            # epoch/total_epochs, as before; this relies on their defaults.
            if uses_bound:
                _, train_loss, train_loss_p, train_loss_u, train_loss_p_pos, train_loss_u_neg, train_loss_p_neg, _ \
                    = loss_func(outputs_p_all, outputs_u_all, targets_p_all, targets_u_all, self.bound)
            else:
                _, train_loss, train_loss_p, train_loss_u, train_loss_p_pos, train_loss_u_neg, train_loss_p_neg \
                    = loss_func(outputs_p_all, outputs_u_all, targets_p_all, targets_u_all)

            val_overall_metrics, val_class_metrics = multi_class_accuracy(probs_all, predicts_all, targets_all)

            test_overall_metrics, test_class_metrics, test_loss, test_loss_p, test_loss_n = test(
                args, self.XYtestLoader, net, self.loss_func_pn, device)
            results_overall_test[model_name].extend(test_overall_metrics)
            results_class_test[model_name].extend(test_class_metrics)

            results_overall_val[model_name].extend(val_overall_metrics)
            results_class_val[model_name].extend(val_class_metrics)

            loss_all[model_name].extend([test_loss, test_loss_p, test_loss_n, train_loss, train_loss_p, train_loss_u,
                                         train_loss_p_pos, train_loss_u_neg, train_loss_p_neg])

        return results_overall_val, results_class_val, results_overall_test, results_class_test, loss_all

    def run(self, Epochs):
        """Train for ``Epochs`` epochs, printing test metrics after each one.

        Args:
            Epochs: Number of epochs to run.

        Returns:
            Tuple ``(results_val, results_test, loss_all)`` of dicts keyed by
            model name. Each result value is a list of per-metric series (the
            overall metrics first, then every class metric for every
            classifier output); each ``loss_all`` value holds nine series in
            the order produced by ``train``.
        """
        results_val = {}
        results_test = {}
        loss_all = {}
        metrics_overall = ['ACC', 'AUC', 'Macro_F1', 'Micro_F1', 'Precision', 'Recall', 'ErrorRate']
        metrics_class = ['Class_F1', 'Class_Precision', 'Class_Recall', 'Class_NPP']
        n_series = len(metrics_overall) + len(metrics_class) * args.num_classifier
        n_losses = 9

        for model_name in self.models.keys():
            results_val[model_name] = [[] for _ in range(n_series)]
            results_test[model_name] = [[] for _ in range(n_series)]
            loss_all[model_name] = [[] for _ in range(n_losses)]

        print("Epoch\t" + "".join([
            "".join([
                # Overall metrics first.
                "{}/{}\t".format(model_name, metric)
                for metric in metrics_overall
            ]) +
            "".join([
                # Then, for each class metric, one column per classifier output.
                "".join([
                    "{}/{}_{}\t".format(model_name, metric, j + 1)
                    for j in range(args.num_classifier)
                ])
                for metric in metrics_class
            ])
            for model_name in sorted(results_test.keys())
        ]) + "\n")

        for epoch in range(Epochs):
            _results_overall_val, _results_class_val, _results_overall_test, _results_class_test, _loss_all = self.train(epoch)

            print("{}\t".format(epoch) + "".join([
                # Overall metrics.
                "".join(["{:-8}\t".format(round(_results_overall_test[model_name][i], 4))
                         for i in range(len(metrics_overall))]) +
                # Class metrics: every value of every class-metric entry.
                "".join([
                    "".join(["{:-8}\t".format(round(value, 4))
                             for value in _results_class_test[model_name][j]])
                    for j in range(len(_results_class_test[model_name]))])
                for model_name in sorted(results_test.keys())
            ]) + "\n")

            for model_name in self.models.keys():
                for i in range(len(metrics_overall)):
                    results_test[model_name][i].append(_results_overall_test[model_name][i])
                    results_val[model_name][i].append(_results_overall_val[model_name][i])
                for i in range(len(metrics_class)):
                    for j in range(args.num_classifier):
                        # Class-metric series follow the overall ones,
                        # grouped by metric and then by classifier output.
                        slot = args.num_classifier * i + len(metrics_overall) + j
                        results_test[model_name][slot].append(_results_class_test[model_name][i][j])
                        results_val[model_name][slot].append(_results_class_val[model_name][i][j])

                for k in range(n_losses):
                    loss_all[model_name][k].append(_loss_all[model_name][k].item())

        return results_val, results_test, loss_all


def load_pretrained_vectors(vocab, fname):
    """Load pretrained word vectors into an embedding matrix.

    The file's first line holds two integers, the vector count and the
    dimension ``d``; every following line is a word followed by its
    space-separated vector components.

    Args:
        vocab (Dict): Maps each word to its row index; must contain
            ``'<pad>'``.
        fname (str): Path to the pretrained vector file.

    Returns:
        embeddings (np.array): float32 matrix of shape (len(vocab), d). Rows
            of words found in the file hold the pretrained vector, the
            ``'<pad>'`` row is all zeros, and all other rows are drawn
            uniformly from [-0.25, 0.25).
    """

    print("Loading pretrained vectors...")
    # The context manager guarantees the file is closed, including on error.
    with open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        n, d = map(int, fin.readline().split())

        # Initialize random embeddings; the padding row is zeroed.
        embeddings = np.random.uniform(-0.25, 0.25, (len(vocab), d))
        embeddings[vocab['<pad>']] = np.zeros((d, ))

        # Overwrite the rows of words that have a pretrained vector.
        count = 0
        for line in fin:
            tokens = line.rstrip().split(' ')
            word = tokens[0]
            if word in vocab:
                count += 1
                embeddings[vocab[word]] = np.array(tokens[1:], dtype=np.float32)

    print(f"There are {count} / {len(vocab)} pretrained vectors found.")

    return np.asarray(embeddings, dtype=np.float32)


def main():
    """Train the configured PU method and write its results to ./results/.

    Loads the image dataset named by ``args.dataset``, builds the model, PU
    loss, optimizer and scheduler selected by ``args``, trains for
    ``args.epoch`` epochs and saves validation metrics, test metrics and
    losses.

    Raises:
        ValueError: If ``args.dataset`` is not a supported image dataset.
    """
    print("using:", device)

    image_datasets = ['mnist', 'fashionmnist', 'cifar10', 'stl10', 'alzheimer']

    # dataset setup
    # Only image datasets have a loader in this script; fail early with a
    # clear message instead of an obscure error further down.
    if args.dataset not in image_datasets:
        raise ValueError(
            "unsupported dataset {!r}; expected one of: {}".format(
                args.dataset, ", ".join(image_datasets)))

    # The full-train loader, the positive set and the dataset prior are
    # returned by the loader but not used here; args.prior is used throughout.
    (XYPtrainLoader, XYUtrainLoader, _XYtrainLoader, XYvalidLoader, XYtestLoader,
     _XYPtrainSet, dim,
     _dataset_prior) = load_image_dataset(args.dataset, args.labeled,
                                          args.batchsize, args.positive_label_list)

    # Create the output directory before the (long) training run so a missing
    # directory cannot discard the results at save time.
    save_path = "./results/"
    os.makedirs(save_path, exist_ok=True)

    # model setup
    loss_type = select_loss(args.loss, args.loss_gamma, args.loss_tau, args.loss_eta,
                            pi=args.prior, a0=args.a0, b0=args.b0, b1=args.b1, k1=args.k1, k2=args.k2, t=0.5)
    selected_model = select_model(args.model)
    model = selected_model(int(dim))
    models = {
        args.method: copy.deepcopy(model).to(device),
    }

    loss_funcs = {
        args.method: select_loss_function(args.method, loss_type),
    }
    # Loss handed to the trainer for evaluation on the test loader.
    loss_func_pn = PNCEloss(args.prior, loss=loss_type)

    # trainer setup
    optimizers = {
        k: make_optimizer(v, args.optim, args.stepsize)
        for k, v in models.items()
    }
    schedulers = {k: make_scheduler(v) for k, v in optimizers.items()}

    print("input dim: {}".format(dim))
    print("prior: {}".format(args.prior))
    print("loss: {}".format(args.loss))
    print("batchsize: {}".format(args.batchsize))
    print("model: {}".format(args.model))
    print("beta: {}".format(args.beta))
    print("gamma: {}".format(args.gamma))
    print("")

    # run training
    PUtrainer = trainer(models, loss_funcs, optimizers, schedulers,
                        XYPtrainLoader, XYUtrainLoader, XYvalidLoader,
                        XYtestLoader, args.prior, loss_func_pn)
    results_pu_val, results_pu_test, pu_loss = PUtrainer.run(args.epoch)

    save_result(save_path, results_pu_val, saved_mode="val")
    save_result(save_path, results_pu_test, saved_mode="test")
    save_loss(save_path, pu_loss)


def save_result(save_path, results_pu, saved_mode="test"):
    """Write the per-epoch metric table of every model to a text file.

    The file name encodes the model name(s), the run configuration, the
    method-specific hyper-parameters, the seed, the loss (with its own
    hyper-parameter where it has one) and ``saved_mode``.

    Args:
        save_path: Directory prefix, expected to end with a path separator.
        results_pu: Dict mapping model name to a list of per-metric series,
            each indexable by epoch.
        saved_mode: Suffix distinguishing the table, e.g. "val" or "test".
    """
    model_tag = "".join(["{}".format(model_name) for model_name in sorted(results_pu.keys())])

    # Fields shared by every method, followed by the method-specific ones.
    run_fields = [args.preset, args.positive_label_list, args.model,
                  args.stepsize, args.weight_decay, args.labeled]
    if args.method == "FOPU":
        run_fields.append(args.lam_f)
    elif args.method in ("PULB", "PULB2"):
        run_fields.append(args.momentum)
    elif args.method in ("BalancePU", "BalancePU2"):
        run_fields.append(args.balance_weight)
    elif args.method in ("PULB+BalancePU", "PULB2+BalancePU2"):
        run_fields.extend([args.momentum, args.balance_weight])
    filename_result = save_path + model_tag + "".join("_{}".format(field) for field in run_fields)

    if args.loss == 'gsigmoid':
        filename_result += "_{}_{}{}_{}.txt".format(args.seed, args.loss, args.loss_gamma, saved_mode)
    elif args.loss == 'pinball':
        filename_result += "_{}_{}{}_{}.txt".format(args.seed, args.loss, args.loss_tau, saved_mode)
    elif args.loss in ('rescaled-hinge', 'rescaled_hinge'):
        # The loss is registered as "rescaled-hinge"; matching only the
        # underscore spelling left eta out of the name, so runs with
        # different eta overwrote each other.
        filename_result += "_{}_{}{}_{}.txt".format(args.seed, args.loss, args.loss_eta, saved_mode)
    elif args.loss == 'mine':
        filename_result += "_{}_{}_{}_{}_{}_{}_{}_{}.txt".format(args.seed, args.loss, args.a0, args.b0, args.b1, args.k1, args.k2, saved_mode)
    else:
        filename_result += "_{}_{}_{}.txt".format(args.seed, args.loss, saved_mode)

    metrics_overall = ['ACC', 'AUC', 'Macro_F1', 'Micro_F1', 'Precision', 'Recall', 'ErrorRate']
    metrics_class = ['Class_F1', 'Class_Precision', 'Class_Recall', 'Class_NPP']

    with open(filename_result, 'w') as file_result:
        # Header: overall metrics, then each class metric per classifier output.
        file_result.write("Epoch\t" + "".join([
            "".join([
                "{}/{}\t".format(model_name, metric) for metric in metrics_overall
            ]) +
            "".join([
                "".join([
                    "{}/{}_{}\t".format(model_name, metric, j + 1) for j in range(args.num_classifier)
                ])
                for metric in metrics_class
            ])
            for model_name in sorted(results_pu.keys())
        ]) + "\n")

        # One row per epoch with every series of every model.
        for epoch in range(args.epoch):
            file_result.write("{}\t".format(epoch) + "".join([
                "".join(["{:-8}\t".format(round(results_pu[model_name][i][epoch], 4))
                         for i in range(len(results_pu[model_name]))])
                for model_name in sorted(results_pu.keys())
            ]) + "\n")

def save_loss(save_path, pu_loss):
    """Write the per-epoch loss table of every model to a text file.

    The file name encodes the model name(s), the run configuration, the
    method-specific hyper-parameters, the seed and the loss (with its own
    hyper-parameter where it has one).

    Args:
        save_path: Directory prefix, expected to end with a path separator.
        pu_loss: Dict mapping model name to nine loss series, each indexable
            by epoch, in the column order written to the header below.
    """
    model_tag = "".join(["{}".format(model_name) for model_name in sorted(pu_loss.keys())])

    # Fields shared by every method, followed by the method-specific ones.
    run_fields = [args.preset, args.positive_label_list, args.model,
                  args.stepsize, args.weight_decay, args.labeled]
    if args.method == "FOPU":
        run_fields.append(args.lam_f)
    elif args.method in ("PULB", "PULB2"):
        run_fields.append(args.momentum)
    elif args.method in ("BalancePU", "BalancePU2"):
        run_fields.append(args.balance_weight)
    elif args.method in ("PULB+BalancePU", "PULB2+BalancePU2"):
        run_fields.extend([args.momentum, args.balance_weight])
    filename_result = save_path + model_tag + "".join("_{}".format(field) for field in run_fields)

    if args.loss == 'gsigmoid':
        filename_result += "_{}_{}{}_loss.txt".format(args.seed, args.loss, args.loss_gamma)
    elif args.loss == 'pinball':
        filename_result += "_{}_{}{}_loss.txt".format(args.seed, args.loss, args.loss_tau)
    elif args.loss in ('rescaled-hinge', 'rescaled_hinge'):
        # The loss is registered as "rescaled-hinge"; matching only the
        # underscore spelling left eta out of the name, so runs with
        # different eta overwrote each other.
        filename_result += "_{}_{}{}_loss.txt".format(args.seed, args.loss, args.loss_eta)
    elif args.loss == 'mine':
        # NOTE(review): unlike the other branches this name has no "_loss"
        # suffix; kept unchanged in case downstream scripts rely on it.
        filename_result += "_{}_{}_{}_{}_{}_{}_{}.txt".format(args.seed, args.loss, args.a0, args.b0, args.b1, args.k1, args.k2)
    else:
        filename_result += "_{}_{}_loss.txt".format(args.seed, args.loss)

    columns = ["test_loss", "test_loss_p", "test_loss_n", "train_loss", "train_loss_p",
               "train_loss_u", "R_p_pos", "R_u_neg", "R_p_neg"]

    with open(filename_result, 'w') as file_result:
        file_result.write("Epoch\t" + "".join([
            "".join(["{}/{}\t".format(model_name, column) for column in columns])
            for model_name in sorted(pu_loss.keys())
        ]) + "\n")

        # One row per epoch with the nine loss values of every model.
        for epoch in range(args.epoch):
            file_result.write("{}\t".format(epoch) + "".join([
                "".join(["{:-8}\t".format(round(pu_loss[model_name][k][epoch], 4))
                         for k in range(len(columns))])
                for model_name in sorted(pu_loss.keys())
            ]) + "\n")


if __name__ == '__main__':
    # `os` is already imported at module level; this re-import is redundant
    # but harmless.
    import os
    import sys
    # Switch to the directory in sys.path[0] so that relative paths such as
    # "./results/" resolve from there.
    os.chdir(sys.path[0])
    print("working dir: {}".format(os.getcwd()))
    main()
