
import copy
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

from hmix_args import process_args
from hmix_image_dataset_new import load_image_dataset
from nnPU_loss import (PNCEloss, PUCEloss, ABSPUloss, DistPUloss, PULBloss, BalancedPUCEloss, BalancedPULBloss,
                       ScalePULoss, ScalePULossVarianceLambda, FocalScalePULoss, GeometricScalePULoss)
from statistic import test
from hmix_model import MultiLayerPerceptron
from model.alexnet import alexnet
from model.cnn7 import cnn_cifar, cnn_stl
from model.lenet5 import lenet_fmnist, LeNet, LeNet_FMNIST
from model.loss import loss_entropy, loss_ft
from model.ResNet_Zoo import ResNet18, ResNet50
from utils.misc import multi_class_accuracy

# Build the run configuration once at import time; the functions below read
# this module-level `args` (presumably parsed from the command line — see
# hmix_args.process_args).
args = process_args()
# Select the GPU(s) before the CUDA availability check below.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_torch(seed=args.seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with `seed`.

    Also exports the seed as PYTHONHASHSEED and switches cuDNN to its
    deterministic mode.

    Args:
        seed: Seed value; defaults to `args.seed` as read when this function
            was defined.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.random.manual_seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    # cudnn.enabled and cudnn.benchmark are left at their current values.
    torch.backends.cudnn.deterministic = True


# Seed all RNGs at import time with the default seed (args.seed).
seed_torch()


def select_loss(loss_name, gamma=1, tau=0.5, eta=1.0, pi=0.5, a0=2.0, b0=1.0, b1=0.8, k1=8.0, k2=8.0, t=0.5):
    """Return the element-wise surrogate loss registered under `loss_name`.

    Args:
        loss_name: One of "zero-one", "hinge", "exponential", "perceptron",
            "logistic" or "sigmoid".
        gamma, tau, eta, pi, a0, b0, b1, k1, k2, t: Accepted for interface
            compatibility; none of the losses defined here reads them.

    Returns:
        A callable mapping a tensor of margins to a tensor of losses.

    Raises:
        KeyError: If `loss_name` is not one of the names above.
    """

    def zero_one(margin):
        return -0.5 * torch.sign(margin) + 0.5

    def hinge(margin):
        return torch.clamp(1 - margin, min=0)

    def exponential(margin):
        return torch.exp(-margin)

    def perceptron(margin):
        # Predecessor of the logistic loss.
        return torch.clamp(-margin, min=0)

    def logistic(margin):
        return F.softplus(-margin)

    def sigmoid(margin):
        return torch.sigmoid(-margin)

    # Common classification losses, keyed by their public names.
    registry = {
        "zero-one": zero_one,
        "hinge": hinge,
        "exponential": exponential,
        "perceptron": perceptron,
        "logistic": logistic,
        "sigmoid": sigmoid,
    }
    return registry[loss_name]


def calc_alpha(pi, a0=2.0, k1=8.0):
    """Compute alpha(pi) = a0 / (1 + exp(k1 * (pi - 0.5))).

    Args:
        pi: Class prior, given as a Python float.
        a0: Numerator of the expression.
        k1: Steepness applied to (pi - 0.5).

    Returns:
        A scalar float32 tensor.
    """
    # Lift the float into a scalar tensor so torch.exp can be applied to it.
    prior = torch.tensor(pi, dtype=torch.float)
    exponent = k1 * (prior - 0.5)
    denominator = 1 + torch.exp(exponent)
    return a0 / denominator


def calc_beta(pi, b0=1.0, b1=0.8, k2=8.0, t=0.5):
    """Compute beta(pi) = b0 + b1 * sigmoid(k2 * (pi - t)).

    Args:
        pi: Class prior, given as a Python float.
        b0: Constant offset.
        b1: Scale of the sigmoid term.
        k2: Steepness applied to (pi - t).
        t: Value of pi at which the sigmoid term equals b1 / 2.

    Returns:
        A scalar float32 tensor.
    """
    # Lift the float into a scalar tensor so torch.sigmoid can be applied.
    prior = torch.tensor(pi, dtype=torch.float)
    gate = torch.sigmoid(k2 * (prior - t))
    return b0 + b1 * gate


def select_loss_function(method_name, loss_type):
    """Build the PU loss object for `method_name`.

    Only the requested loss is constructed. Every entry is a zero-argument
    factory, so hyper-parameters that belong to other methods (for example
    `args.momentum` or `args.focal_mode`) are not read unless the selected
    method actually needs them.

    Args:
        method_name: Key of the PU method, e.g. "uPU", "nnPU-out", "PULB" or
            "ScalePU-Geometric".
        loss_type: Element-wise surrogate loss passed to the loss object as
            its `loss` argument.

    Returns:
        The loss object for `method_name`, configured from the module-level
        `args`.

    Raises:
        KeyError: If `method_name` is not a registered method.
    """
    # Keyword arguments shared by every loss except the geometric one.
    common = dict(loss=loss_type, gamma=args.gamma, beta=args.beta)

    factories = {
        "uPU": lambda: PUCEloss(
            args.prior, nnpu=False, objective=False, **common),
        "nnPU-objective": lambda: PUCEloss(
            args.prior, nnpu=True, objective=True, **common),
        "nnPU-out": lambda: PUCEloss(
            args.prior, nnpu=True, objective=False, **common),
        "absPU": lambda: ABSPUloss(
            args.prior, objective=False, **common),
        "DistPU-brief": lambda: DistPUloss(
            args.prior, **common),
        "PULB": lambda: PULBloss(
            args.prior,
            objective=False,
            momentum=args.momentum,
            **common),
        "BalancePU": lambda: BalancedPUCEloss(
            args.prior,
            objective=False,
            balance_weight=args.balance_weight,
            **common),
        "PULB+BalancePU": lambda: BalancedPULBloss(
            args.prior,
            objective=False,
            momentum=args.momentum,
            balance_weight=args.balance_weight,
            **common),
        "ScalePU-brief": lambda: ScalePULoss(
            args.prior,
            lambda_reg=args.lambda_reg,
            **common),
        "ScalePU-VarianceLambda": lambda: ScalePULossVarianceLambda(
            args.prior,
            lambda_reg=args.lambda_reg,
            var_threshold=args.var_threshold,
            smooth_lambda=True,
            **common),
        "ScalePU-Focal": lambda: FocalScalePULoss(
            args.prior,
            lambda_reg=args.lambda_reg,
            var_threshold=args.var_threshold,
            smooth_lambda=True,
            focal_mode=args.focal_mode,
            focal_ratio=args.focal_ratio,
            **common),
        # The geometric loss takes the nnPU beta/gamma under different
        # keyword names, so `common` is not used for it.
        "ScalePU-Geometric": lambda: GeometricScalePULoss(
            args.prior,
            loss=loss_type,
            lambda_reg=args.lambda_reg,
            gamma_geo=getattr(args, 'gamma_geo', 0.01),
            beta_sep=getattr(args, 'beta_sep', 1.0),
            margin=getattr(args, 'margin', 1.0),
            var_threshold=args.var_threshold,
            smooth_lambda=True,
            focal_mode=args.focal_mode,
            focal_ratio=args.focal_ratio,
            beta_nnpu=args.beta,
            gamma_nnpu=args.gamma,
        ),
    }
    return factories[method_name]()



class ModelWithFeatures(nn.Module):

    def __init__(self, base_model):
        super(ModelWithFeatures, self).__init__()
        self.base_model = base_model
        self.num_classifier = base_model.num_classifier if hasattr(base_model, 'num_classifier') else 1

        if hasattr(base_model, 'fc') and isinstance(base_model.fc, nn.Linear):
            self.fc = base_model.fc
            base_model.fc = nn.Identity()
            self.feature_extractor = base_model
        elif hasattr(base_model, 'classifier') and isinstance(base_model.classifier, nn.Linear):
            self.fc = base_model.classifier
            base_model.classifier = nn.Identity()
            self.feature_extractor = base_model
        else:
            modules = list(base_model.children())
            if len(modules) > 0:
                last_module = modules[-1]
                if isinstance(last_module, nn.Linear):
                    self.fc = last_module
                    self.feature_extractor = nn.Sequential(*modules[:-1])
                else:
                    self.fc = nn.Identity()
                    self.feature_extractor = base_model
            else:
                self.fc = nn.Identity()
                self.feature_extractor = base_model

    def forward(self, x, return_features=False):
        features = self.feature_extractor(x)

        if features.dim() > 2:
            features = features.view(features.size(0), -1)

        logits = self.fc(features)

        if return_features:
            return logits, features
        return logits


def select_model(model_name):
    """Return the model constructor registered under `model_name`.

    Raises:
        KeyError: If `model_name` is not a registered model.
    """
    registry = dict(
        cnn_cifar=cnn_cifar,
        cnn_stl=cnn_stl,
        lenet5=lenet_fmnist,
        resnet50=ResNet50,
        mlp=MultiLayerPerceptron,
    )
    return registry[model_name]


def make_optimizer(model, optim, stepsize):
    """Create the optimizer named `optim` for the parameters of `model`.

    Args:
        model: Module whose parameters are optimized.
        optim: One of "adam", "adagrad", "adamw", "sgd" or "adadelta".
        stepsize: Learning rate.

    Returns:
        The configured `torch.optim` optimizer. Every optimizer except
        Adadelta uses `args.weight_decay`.

    Raises:
        ValueError: If `optim` is not a supported optimizer name.
    """
    params = model.parameters()

    if optim == "adam":
        return torch.optim.Adam(params,
                                lr=stepsize,
                                betas=(0.5, 0.99),
                                weight_decay=args.weight_decay)
    if optim == "adagrad":
        return torch.optim.Adagrad(params,
                                   lr=stepsize,
                                   weight_decay=args.weight_decay)
    if optim == "adamw":
        return torch.optim.AdamW(params,
                                 lr=stepsize,
                                 weight_decay=args.weight_decay)
    if optim == "sgd":
        return torch.optim.SGD(params,
                               lr=stepsize,
                               momentum=0.9,
                               weight_decay=args.weight_decay)
    if optim == "adadelta":
        return torch.optim.Adadelta(params,
                                    lr=stepsize,
                                    rho=0.95)

    # Previously an unknown name fell through to `return optimizer` and
    # surfaced as an UnboundLocalError.
    raise ValueError(
        "Unsupported optimizer {!r}; expected one of: "
        "adam, adagrad, adamw, sgd, adadelta".format(optim))


def make_scheduler(optim):
    """Return a MultiStepLR for `optim`.

    The milestones and decay factor are read from `args.milestones` and
    `args.scheduler_gamma`.
    """
    return MultiStepLR(
        optim,
        milestones=args.milestones,
        gamma=args.scheduler_gamma,
    )


class trainer():
    """Trains one or more PU models and evaluates them after every epoch.

    `models`, `loss_funcs`, `optimizers` and `schedulers` are dicts that are
    iterated in lockstep, so their entries must be in the same order.
    """

    # Methods whose loss takes the running bound as a fifth argument and
    # returns the updated bound as an eighth value.
    _BOUND_METHODS = ("PULB", "PULB+BalancePU", "PULB2", "PULB2+BalancePU2")

    def __init__(self, models, loss_funcs, optimizers, schedulers,
                 XYPtrainLoader, XYUtrainLoader, XYvalidLoader, XYtestLoader,
                 prior, loss_func_pn):
        """Store the models, their training components and the data loaders.

        Args:
            models: Dict mapping a model name to its network.
            loss_funcs: Dict of PU loss objects, in the order of `models`.
            optimizers: Dict of optimizers, in the order of `models`.
            schedulers: Dict of LR schedulers, in the order of `models`.
            XYPtrainLoader: Loader yielding labeled-positive batches.
            XYUtrainLoader: Loader yielding unlabeled batches.
            XYvalidLoader: Validation loader (stored, not used in this class).
            XYtestLoader: Test loader passed to `test`.
            prior: Class prior (stored, not used in this class).
            loss_func_pn: Loss passed to `test` for the test-set evaluation.
        """
        self.models = models
        self.loss_funcs = loss_funcs.values()
        self.optimizers = optimizers.values()
        self.schedulers = schedulers.values()

        self.XYPtrainLoader = XYPtrainLoader
        self.XYUtrainLoader = XYUtrainLoader
        self.XYvalidLoader = XYvalidLoader
        self.XYtestLoader = XYtestLoader
        self.prior = prior
        self.loss_func_pn = loss_func_pn

        # Running bound threaded through the losses in _BOUND_METHODS.
        self.bound = 0

        # Geometric methods need the features as well as the logits.
        self.use_geometric = "Geometric" in args.method

    def train(self, epoch):
        """Run `args.val_iterations` training steps per model, then evaluate.

        Args:
            epoch: Index of the current epoch (used only for periodic logging
                of the geometric loss terms).

        Returns:
            A 5-tuple of dicts keyed by model name: overall and per-class
            metrics on the training batches, overall and per-class metrics on
            the test set, and a list of nine loss values (three test losses
            followed by six training losses).
        """
        results_overall_val = {name: [] for name in self.models}
        results_class_val = {name: [] for name in self.models}
        results_overall_test = {name: [] for name in self.models}
        results_class_test = {name: [] for name in self.models}
        loss_all = {name: [] for name in self.models}

        uses_bound = args.method in self._BOUND_METHODS

        # NOTE(review): `scheduler` is never stepped in this class, so the
        # configured MultiStepLR milestones have no effect on the learning
        # rate. Confirm whether a per-epoch `scheduler.step()` was intended.
        for net, opt, scheduler, loss_func, model_name in zip(
                self.models.values(), self.optimizers, self.schedulers,
                self.loss_funcs, self.models.keys()):
            XYUtrainLoader_iter = iter(self.XYUtrainLoader)
            XYPtrainLoader_iter = iter(self.XYPtrainLoader)
            net.train()

            n_outputs = 1 if net.num_classifier == 1 else 2
            outputs_p_all = np.array([]).reshape(0, n_outputs)
            outputs_u_all = np.array([]).reshape(0, n_outputs)
            targets_p_all = np.array([])
            targets_u_all = np.array([])

            targets_all = np.array([])
            predicts_all = np.array([])
            probs_all = np.array([])

            for i in range(args.val_iterations):
                # Load positive data, restarting the loader once exhausted.
                # Only StopIteration is handled so that genuine data-loading
                # errors and KeyboardInterrupt are not swallowed.
                try:
                    data_p, t_p = next(XYPtrainLoader_iter)
                except StopIteration:
                    XYPtrainLoader_iter = iter(self.XYPtrainLoader)
                    data_p, t_p = next(XYPtrainLoader_iter)

                # Load unlabeled data, restarting the loader once exhausted.
                try:
                    data_u, t_u = next(XYUtrainLoader_iter)
                except StopIteration:
                    XYUtrainLoader_iter = iter(self.XYUtrainLoader)
                    data_u, t_u = next(XYUtrainLoader_iter)

                # PU targets: +1 for labeled positives, -1 for unlabeled.
                target_p = torch.ones(data_p.shape[0]).to(device,
                                                          non_blocking=True)
                target_u = - torch.ones(data_u.shape[0]).to(device,
                                                            non_blocking=True)

                # The labels delivered by the loaders are only accumulated
                # for the metrics computed at the end of the epoch.
                t = torch.cat((t_p, t_u), dim=0)
                targets_all = np.hstack((targets_all, t.detach().cpu().numpy()))

                data_p = data_p.to(device, non_blocking=True)
                data_u = data_u.to(device, non_blocking=True)

                if self.use_geometric:
                    data = torch.cat((data_p, data_u), dim=0)
                    output, features = net(data, return_features=True)

                    len_p = data_p.shape[0]
                    output_p = output[:len_p]
                    output_u = output[len_p:]
                    features_p = features[:len_p]
                    features_u = features[len_p:]
                else:
                    data = torch.cat((data_p, data_u), dim=0)
                    idx_p = slice(0, len(data_p))
                    idx_u = slice(len(data_p), len(data))
                    output = net(data)

                size = len(t)
                p = np.reshape(torch.sigmoid(output).detach().cpu().numpy(), size)
                probs_all = np.hstack((probs_all, p))
                o = np.where(p > 0.5, 1, -1)
                predicts_all = np.hstack((predicts_all, o))

                # Compute loss
                if uses_bound:
                    # Losses that also take and return the running bound.
                    if self.use_geometric:
                        loss, risk, risk_p, risk_u, risk_p_pos, risk_u_neg, risk_p_neg, current_bound = loss_func(
                            output_p, output_u, target_p, target_u, self.bound)
                    else:
                        loss, risk, risk_p, risk_u, risk_p_pos, risk_u_neg, risk_p_neg, current_bound = loss_func(
                            output[idx_p], output[idx_u], target_p, target_u, self.bound)
                    self.bound = current_bound
                elif self.use_geometric:
                    # Geometric regularization also consumes the features.
                    loss, risk_p, risk_u, risk_var, risk_geo, loss_dict = loss_func(
                        output_p, output_u, target_p, target_u,
                        features_p, features_u
                    )

                    if i == 0 and epoch % 10 == 0:
                        print(f"Epoch {epoch}: Compactness={loss_dict['compactness']:.4f}, "
                              f"Separation={loss_dict['separation']:.4f}, "
                              f"FocalVar={loss_dict['focal_var']:.4f}")
                elif "ScalePU" in args.method or "Variance" in args.method:
                    # ScalePU-style losses return four values.
                    loss, risk_p, risk_u, risk_var \
                        = loss_func(output[idx_p], output[idx_u], target_p, target_u)
                else:
                    loss, risk, risk_p, risk_u, risk_p_pos, risk_u_neg, risk_p_neg = loss_func(
                        output[idx_p], output[idx_u], target_p, target_u)

                if self.use_geometric:
                    outputs_p_all = np.vstack((outputs_p_all, output_p.detach().cpu().numpy()))
                    outputs_u_all = np.vstack((outputs_u_all, output_u.detach().cpu().numpy()))
                else:
                    outputs_p_all = np.vstack((outputs_p_all, output[idx_p].detach().cpu().numpy()))
                    outputs_u_all = np.vstack((outputs_u_all, output[idx_u].detach().cpu().numpy()))
                targets_p_all = np.hstack((targets_p_all, target_p.detach().cpu().numpy()))
                targets_u_all = np.hstack((targets_u_all, target_u.detach().cpu().numpy()))

                opt.zero_grad()
                loss.backward()
                opt.step()

            # Re-evaluate the loss on every output collected during the epoch.
            outputs_p_all = torch.from_numpy(outputs_p_all).to(device)
            outputs_u_all = torch.from_numpy(outputs_u_all).to(device)
            targets_p_all = torch.from_numpy(targets_p_all).to(device)
            targets_u_all = torch.from_numpy(targets_u_all).to(device)

            if uses_bound:
                # Same set of methods as in the per-step branch above, so the
                # call signature and the number of returned values agree.
                _, train_loss, train_loss_p, train_loss_u, train_loss_p_pos, train_loss_u_neg, train_loss_p_neg, _ \
                    = loss_func(outputs_p_all, outputs_u_all, targets_p_all, targets_u_all, self.bound)
            elif self.use_geometric:
                # Features are not kept across the epoch, so none are passed.
                train_loss, train_loss_p, train_loss_u, train_loss_var, _, _ \
                    = loss_func(outputs_p_all, outputs_u_all, targets_p_all, targets_u_all, None, None)
                train_loss_p_pos, train_loss_u_neg, train_loss_p_neg = train_loss, train_loss, train_loss
            elif "ScalePU" in args.method or "Variance" in args.method:
                train_loss, train_loss_p, train_loss_u, train_loss_var \
                    = loss_func(outputs_p_all, outputs_u_all, targets_p_all, targets_u_all)
                # These losses do not return the finer-grained terms; the
                # total loss is reported in their place.
                train_loss_p_pos, train_loss_u_neg, train_loss_p_neg = train_loss, train_loss, train_loss
            else:
                _, train_loss, train_loss_p, train_loss_u, train_loss_p_pos, train_loss_u_neg, train_loss_p_neg \
                    = loss_func(outputs_p_all, outputs_u_all, targets_p_all, targets_u_all)

            val_overall_metrics, val_class_metrics = multi_class_accuracy(probs_all, predicts_all, targets_all)

            test_overall_metrics, test_class_metrics, test_loss, test_loss_p, test_loss_n = test(args,
                                                                                                 self.XYtestLoader, net,
                                                                                                 self.loss_func_pn,
                                                                                                 device)
            results_overall_test[model_name].extend(test_overall_metrics)
            results_class_test[model_name].extend(test_class_metrics)

            results_overall_val[model_name].extend(val_overall_metrics)
            results_class_val[model_name].extend(val_class_metrics)

            loss_all[model_name].extend([test_loss, test_loss_p, test_loss_n, train_loss, train_loss_p, train_loss_u,
                                         train_loss_p_pos, train_loss_u_neg, train_loss_p_neg])

        return results_overall_val, results_class_val, results_overall_test, results_class_test, loss_all

    def run(self, Epochs):
        """Train for `Epochs` epochs, printing the test metrics of each one.

        Args:
            Epochs: Number of epochs to run.

        Returns:
            `(results_val, results_test, loss_all)`, dicts keyed by model
            name. Each metric dict holds one list per metric column (overall
            metrics first, then every per-class metric for every class) with
            one entry per epoch; `loss_all` holds nine such lists.
        """
        results_val = {}
        results_test = {}
        loss_all = {}
        metrics_overall = ['ACC', 'AUC', 'Macro_F1', 'Micro_F1', 'Precision', 'Recall', 'ErrorRate']
        metrics_class = ['Class_F1', 'Class_Precision', 'Class_Recall', 'Class_NPP']
        n_columns = len(metrics_overall) + len(metrics_class) * args.num_classifier

        for model_name in self.models.keys():
            results_val[model_name] = [[] for _ in range(n_columns)]
            results_test[model_name] = [[] for _ in range(n_columns)]
            loss_all[model_name] = [[] for _ in range(9)]

        print("Epoch\t" + "".join([
            "".join([
                # Overall metrics come first.
                "{}/{}\t".format(model_name, metric)
                for metric in metrics_overall
            ]) +
            "".join([
                # Then, for each per-class metric, every class in turn.
                "".join([
                    "{}/{}_{}\t".format(model_name, metric, j + 1)
                    for j in range(args.num_classifier)
                ])
                for metric in metrics_class
            ])
            for model_name in sorted(results_test.keys())
        ]) + "\n")

        for epoch in range(Epochs):
            _results_overall_val, _results_class_val, _results_overall_test, _results_class_test, _loss_all = self.train(
                epoch)

            print("{}\t".format(epoch) + "".join([
                # Overall metrics.
                "".join(["{:-8}\t".format(round(_results_overall_test[model_name][i], 4))
                         for i in range(len(metrics_overall))]) +
                # Per-class metrics: all values of each metric in turn.
                "".join([
                    "".join(["{:-8}\t".format(round(value, 4))
                             for value in _results_class_test[model_name][j]])
                    for j in range(len(_results_class_test[model_name]))])
                for model_name in sorted(results_test.keys())
            ]) + "\n")

            for model_name in self.models.keys():
                for i in range(len(metrics_overall)):
                    results_test[model_name][i].append(_results_overall_test[model_name][i])
                    results_val[model_name][i].append(_results_overall_val[model_name][i])
                for i in range(len(metrics_class)):
                    for j in range(args.num_classifier):
                        column = args.num_classifier * i + len(metrics_overall) + j
                        results_test[model_name][column].append(_results_class_test[model_name][i][j])
                        results_val[model_name][column].append(_results_class_val[model_name][i][j])

                # Convert each of the nine loss values to a Python number.
                for history, value in zip(loss_all[model_name], _loss_all[model_name]):
                    history.append(value.item())

        return results_val, results_test, loss_all


def main():
    """Build the data loaders, model, loss and optimizer, then train and save.

    The configuration is read from the module-level `args`. Test metrics are
    written under "./results_scalePU_geometric/".

    Raises:
        ValueError: If `args.dataset` is not one of the supported image
            datasets.
    """
    print("using:", device)

    image_datasets = ['mnist', 'fashionmnist', 'cifar10', 'stl10', 'alzheimer', 'imagenet']

    # dataset setup
    # Only image datasets have a loader wired up in this script. Fail fast
    # with a clear message instead of reaching the trainer with undefined
    # loaders, which used to surface as a NameError.
    if args.dataset not in image_datasets:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of: {}".format(
                args.dataset, ", ".join(image_datasets)))

    # The prior returned by the loader is not used; `args.prior` is used
    # everywhere below.
    (XYPtrainLoader, XYUtrainLoader, _, XYvalidLoader, XYtestLoader,
     _, dim,
     _) = load_image_dataset(args.dataset, args.labeled,
                             args.batchsize, args.positive_label_list, prior=args.prior)

    # model setup
    loss_type = select_loss(args.loss, args.loss_gamma, args.loss_tau, args.loss_eta,
                            pi=args.prior, a0=args.a0, b0=args.b0, b1=args.b1, k1=args.k1, k2=args.k2, t=0.5)
    selected_model = select_model(args.model)

    if "Geometric" in args.method:
        # The geometric loss needs features, so wrap the model to expose them.
        base_model = selected_model()
        model = ModelWithFeatures(base_model).to(device)
        print("Using ModelWithFeatures wrapper for geometric regularization")
        print(f"Geometric parameters: γ={getattr(args, 'gamma_geo', 0.01)}, "
              f"β={getattr(args, 'beta_sep', 1.0)}")
    else:
        model = selected_model()

    models = {
        args.method: copy.deepcopy(model).to(device),
    }
    # NOTE(review): the PN model, optimizer and scheduler below are never
    # used in this function. They are kept because constructing the model may
    # consume random numbers, so removing them could change seeded results.
    models_pn = {
        "PN": copy.deepcopy(selected_model()).to(device),
    }

    loss_funcs = {
        args.method: select_loss_function(args.method, loss_type),
    }
    loss_func_pn = PNCEloss(args.prior, loss=loss_type)

    # trainer setup
    optimizers = {
        k: make_optimizer(v, args.optim, args.stepsize)
        for k, v in models.items()
    }
    schedulers = {k: make_scheduler(v) for k, v in optimizers.items()}

    optimizers_pn = {
        k: make_optimizer(v, args.optim, args.stepsize)
        for k, v in models_pn.items()
    }
    schedulers_pn = {k: make_scheduler(v) for k, v in optimizers_pn.items()}

    print("input dim: {}".format(dim))
    print("prior: {}".format(args.prior))
    print("loss: {}".format(args.loss))
    print("batchsize: {}".format(args.batchsize))
    print("model: {}".format(args.model))
    print("beta: {}".format(args.beta))
    print("gamma: {}".format(args.gamma))
    print("")

    # run training
    PUtrainer = trainer(models, loss_funcs, optimizers, schedulers,
                        XYPtrainLoader, XYUtrainLoader, XYvalidLoader,
                        XYtestLoader, args.prior, loss_func_pn)
    # Only the test metrics are saved; the other two results are unused.
    _, results_pu_test, _ = PUtrainer.run(args.epoch)

    save_path = "./results_scalePU_geometric/"
    save_result(save_path, results_pu_test, saved_mode="test")


def save_result(save_path, results_pu, saved_mode="test"):
    """Write the per-epoch metrics in `results_pu` to a tab-separated file.

    The file name is `save_path` followed by the sorted model names, a set of
    hyper-parameters from `args`, the seed, the loss name and `saved_mode`.
    The directory part of that path is created when it does not exist.

    Args:
        save_path: Prefix of the output path, normally a directory ending in
            a path separator.
        results_pu: Dict mapping a model name to a list of metric columns,
            each column holding one value per epoch.
        saved_mode: Tag appended to the file name.
    """
    model_names = sorted(results_pu.keys())

    # gamma_geo and beta_sep are treated as optional elsewhere in this file;
    # use the same fallbacks here rather than failing with an AttributeError
    # once training has already finished.
    gamma_geo = getattr(args, 'gamma_geo', 0.01)
    beta_sep = getattr(args, 'beta_sep', 1.0)

    filename_result = save_path + "".join(["{}".format(model_name) for model_name in model_names]) \
        + "_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(args.preset, args.model, args.stepsize, args.weight_decay,
                                               args.labeled, args.focal_ratio, args.lambda_reg,
                                               gamma_geo, beta_sep)
    filename_result += "_{}_{}_{}.txt".format(args.seed, args.loss, saved_mode)

    # Create the output directory so the results of a finished run are not
    # lost to a FileNotFoundError.
    target_dir = os.path.dirname(filename_result)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    metrics_overall = ['ACC', 'AUC', 'Macro_F1', 'Micro_F1', 'Precision', 'Recall', 'ErrorRate']
    metrics_class = ['Class_F1', 'Class_Precision', 'Class_Recall', 'Class_NPP']

    with open(filename_result, 'w') as file_result:
        # Header: overall metrics first, then every per-class metric for
        # every class, repeated for each model.
        file_result.write("Epoch\t" + "".join([
            "".join([
                "{}/{}\t".format(model_name, metric) for metric in metrics_overall
            ]) +
            "".join([
                "".join([
                    "{}/{}_{}\t".format(model_name, metric, j + 1) for j in range(args.num_classifier)
                ])
                for metric in metrics_class
            ])
            for model_name in model_names
        ]) + "\n")

        # One row per epoch, with the columns in the order of the header.
        for epoch in range(args.epoch):
            file_result.write("{}\t".format(epoch) + "".join([
                "".join(["{:-8}\t".format(round(column[epoch], 4))
                         for column in results_pu[model_name]])
                for model_name in model_names
            ]) + "\n")


# Script entry point.
if __name__ == '__main__':
    import os
    import sys

    # Switch to sys.path[0] (the script's directory when run as a script) so
    # that relative paths such as the results directory resolve from there.
    os.chdir(sys.path[0])
    print("working dir: {}".format(os.getcwd()))
    main()