import copy
import os
import random
from hmix_args import process_args

# Parse the command-line arguments at import time and export the requested
# GPU id(s) before torch is imported further down in this file.
args = process_args()
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

from hmix_image_dataset import load_image_dataset
from nnPU_loss import PNCEloss, PULLPLoss
from statistic import test
from hmix_model import MultiLayerPerceptron
from model.alexnet import alexnet
from model.cnn7 import cnn_cifar, cnn_stl
from model.lenet5 import lenet_fmnist, LeNet, LeNet_FMNIST
from model.ResNet_Zoo import ResNet18, ResNet50
from utils.misc import multi_class_accuracy
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

# Module-wide compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_torch(seed=args.seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and the torch CPU/CUDA generators with
    ``seed``, exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic
    mode.

    Args:
        seed: Integer seed; defaults to ``args.seed`` captured at definition
            time.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.random.manual_seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True


# Seed all RNGs once at import time, using the default seed (args.seed).
seed_torch()


def select_loss(loss_name, gamma=1, tau=0.5, eta=1.0, pi=0.5, a0=2.0, b0=1.0, b1=0.8, k1=8.0, k2=8.0, t=0.5):
    """Return the element-wise margin loss registered under ``loss_name``.

    Every returned callable takes a tensor of margins ``x`` and returns a
    tensor of per-element losses.

    Args:
        loss_name: Key of the loss (e.g. ``"hinge"``, ``"sigmoid"``,
            ``"rescaled-hinge"``, ``"mine"``).
        gamma: Slope used by ``"gsigmoid"``.
        tau: Asymmetry parameter used by ``"pinball"``.
        eta: Scale parameter used by ``"rescaled-hinge"``.
        pi, a0, b0, b1, k1, k2, t: Forwarded to ``calc_alpha`` / ``calc_beta``
            by the ``"mine"`` loss each time it is evaluated.

    Returns:
        A callable ``x -> loss(x)``.

    Raises:
        KeyError: If ``loss_name`` is not a registered loss.
    """
    losses = {
        # ===== common classification losses =====
        "zero-one": lambda x: -0.5 * torch.sign(x) + 0.5,
        "hinge": lambda x: torch.clamp(1 - x, min=0),
        "exponential": lambda x: torch.exp(-x),
        "perceptron": lambda x: torch.clamp(-x, min=0),  # predecessor of the logistic loss
        "logistic": lambda x: F.softplus(-x),
        "sigmoid": lambda x: torch.sigmoid(-x),

        # ===== variants =====
        "squared-hinge": lambda x: torch.clamp(1 - x, min=0) ** 2,  # squared version of hinge
        "smooth-hinge": lambda x: torch.where(x < 0, 0.5 - x,
                                              torch.where(x < 1, 0.5 * (1 - x) ** 2, torch.zeros_like(x))),  # smoothed hinge
        # Piecewise combination of hinge and squared loss.
        # NOTE(review): this differs from the textbook modified Huber loss
        # (max(0, 1 - x)^2 for x >= -1, -4x otherwise) -- confirm it is intended.
        "modified-huber": lambda x: torch.where(x >= -1, 0.25 * torch.clamp(-x, min=0) ** 2, -x),
        "squared": lambda x: (1 - x) ** 2,

        # ===== robust losses =====
        "unhinged": lambda x: 1 - x,
        "tangent": lambda x: (2 * torch.atan(x) - 1) ** 2,
        "savage": lambda x: 1 / (1 + torch.exp(x)) ** 2,
        # Logistic variant that puts more weight on low-probability samples.
        "focal": lambda x, gamma_focal=2.0: (1 - torch.sigmoid(x)) ** gamma_focal * F.softplus(-x),
        "gsigmoid": lambda x: torch.sigmoid(-gamma * x),  # sigmoid with a gamma-controlled slope
        "ramp": lambda x: torch.clamp(torch.min(torch.tensor(1.0, device=x.device), (1 - x) / 2), min=0),

        # ===== special designs =====
        "pinball": lambda x: torch.max(1 - x, -tau * (1 - x)),
        # torch.exp only accepts tensors, so the scalar normaliser is wrapped
        # with as_tensor; calling torch.exp(-eta) on a Python float raises
        # TypeError.
        "rescaled-hinge": lambda x: eta * (1 - torch.exp(-eta * torch.clamp(1 - x, min=0)))
                                    / (1 - torch.exp(torch.as_tensor(-eta, dtype=torch.float))),
        "double-hinge": lambda x: torch.max(torch.stack([
            torch.zeros_like(x), (1 - x) / 2, -x
        ]), dim=0).values,

        # ===== mine =====
        "mine": lambda x: 1 - torch.exp(
            -calc_alpha(pi, a0, k1) *
            torch.clamp(1 - x, min=0) ** calc_beta(pi, b0, b1, k2, t)
        ),
    }

    try:
        return losses[loss_name]
    except KeyError:
        raise KeyError(
            "unknown loss '{}'; available losses: {}".format(
                loss_name, ", ".join(sorted(losses)))
        ) from None


def calc_alpha(pi, a0=2.0, k1=8.0):
    """Compute alpha(pi) = a0 / (1 + exp(k1 * (pi - 0.5))).

    Args:
        pi: Class prior, given as a plain number.
        a0: Numerator scale.
        k1: Steepness of the exponential around pi = 0.5.

    Returns:
        A 0-dim float tensor holding alpha(pi).
    """
    # Wrap the scalar prior in a tensor so torch.exp can be applied to it.
    prior = torch.tensor(pi, dtype=torch.float)
    centred = prior - 0.5
    denominator = 1 + torch.exp(k1 * centred)
    return a0 / denominator

def calc_beta(pi, b0=1.0, b1=0.8, k2=8.0, t=0.5):
    """Compute beta(pi) = b0 + b1 * sigmoid(k2 * (pi - t)).

    Args:
        pi: Class prior, given as a plain number.
        b0: Base value.
        b1: Amplitude of the sigmoid term.
        k2: Steepness of the sigmoid.
        t: Prior value at which the sigmoid is centred.

    Returns:
        A 0-dim float tensor holding beta(pi).
    """
    # Wrap the scalar prior in a tensor so torch.sigmoid can be applied to it.
    prior = torch.tensor(pi, dtype=torch.float)
    gate = torch.sigmoid(k2 * (prior - t))
    return b0 + b1 * gate


def select_loss_function(method_name, loss_type):
    """Build the PU training objective registered under ``method_name``.

    Args:
        method_name: Key of the method; only ``'PULLP'`` is registered.
        loss_type: Element-wise loss callable handed to the objective.

    Returns:
        The loss object for ``method_name``.

    Raises:
        KeyError: If ``method_name`` is not registered.
    """
    pullp = PULLPLoss(args.prior, loss=loss_type)
    registry = {'PULLP': pullp}
    return registry[method_name]


def select_model(model_name):
    """Look up the model constructor registered under ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not one of the registered names.
    """
    registry = dict(
        cnn_cifar=cnn_cifar,
        cnn_stl=cnn_stl,
        lenet5=LeNet,
        resnet50=ResNet50,
        mlp=MultiLayerPerceptron,
    )
    return registry[model_name]

def make_optimizer(model, optim, stepsize):
    """Create the optimizer named ``optim`` for ``model``.

    Args:
        model: Module whose ``parameters()`` are optimised.
        optim: One of ``"adam"``, ``"adagrad"``, ``"adamw"``, ``"sgd"`` or
            ``"adadelta"``.
        stepsize: Learning rate.

    Returns:
        The configured ``torch.optim`` optimizer. Every optimizer except
        Adadelta uses ``args.weight_decay``.

    Raises:
        ValueError: If ``optim`` is not a supported name. Without this check
            the function would fail with an UnboundLocalError on return.
    """
    params = model.parameters()
    if optim == "adam":
        return torch.optim.Adam(params,
                                lr=stepsize,
                                betas=(0.5, 0.99),
                                weight_decay=args.weight_decay)
    if optim == "adagrad":
        return torch.optim.Adagrad(params,
                                   lr=stepsize,
                                   weight_decay=args.weight_decay)
    if optim == "adamw":
        return torch.optim.AdamW(params,
                                 lr=stepsize,
                                 weight_decay=args.weight_decay)
    if optim == "sgd":
        return torch.optim.SGD(params,
                               lr=stepsize,
                               momentum=0.9,
                               weight_decay=args.weight_decay)
    if optim == "adadelta":
        return torch.optim.Adadelta(params,
                                    lr=stepsize,
                                    rho=0.95)
    raise ValueError(
        "unsupported optimizer '{}'; expected one of: adam, adagrad, adamw, "
        "sgd, adadelta".format(optim))


def make_scheduler(optim):
    """Wrap ``optim`` in a MultiStepLR schedule configured from ``args``.

    The milestones come from ``args.milestones`` and the decay factor from
    ``args.scheduler_gamma``.
    """
    return MultiStepLR(optim,
                       milestones=args.milestones,
                       gamma=args.scheduler_gamma)


class trainer():
    """Train one or more PU models and collect per-epoch validation/test metrics.

    ``models``, ``loss_funcs``, ``optimizers`` and ``schedulers`` are dicts
    keyed by model name. They are iterated in lock-step, so they must share
    the same key order.
    """

    def __init__(self, models, loss_funcs, optimizers, schedulers,
                 XYPtrainLoader, XYUtrainLoader, XYvalidLoader, XYtestLoader,
                 prior, loss_func_pn):
        """Store the models, their training components and the data loaders.

        Args:
            models: Dict mapping model name to network.
            loss_funcs: Dict mapping model name to training objective.
            optimizers: Dict mapping model name to optimizer.
            schedulers: Dict mapping model name to LR scheduler.
            XYPtrainLoader: Loader yielding ``(data, label)`` positive batches.
            XYUtrainLoader: Loader yielding ``(data, label)`` unlabeled batches.
            XYvalidLoader: Validation loader (stored, not used in this class).
            XYtestLoader: Test loader passed to ``test``.
            prior: Class prior (stored, not used in this class).
            loss_func_pn: Loss passed to ``test`` for evaluation.
        """
        self.models = models
        self.loss_funcs = loss_funcs.values()
        self.optimizers = optimizers.values()
        self.schedulers = schedulers.values()

        self.XYPtrainLoader = XYPtrainLoader
        self.XYUtrainLoader = XYUtrainLoader
        self.XYvalidLoader = XYvalidLoader
        self.XYtestLoader = XYtestLoader
        self.prior = prior
        self.loss_func_pn = loss_func_pn

        # only for PULBLoss and BalancedPULBloss
        self.bound = 0

    def train(self, epoch):
        """Run ``args.val_iterations`` training steps per model, then evaluate.

        Args:
            epoch: Index of the current epoch (not used inside this method).

        Returns:
            A 4-tuple of dicts keyed by model name:
            ``(overall_val, class_val, overall_test, class_test)``. The
            "val" metrics are computed from predictions made on the training
            batches of this epoch; the "test" metrics come from ``test``.
        """
        results_overall_val = {}
        results_class_val = {}
        results_overall_test = {}
        results_class_test = {}

        for model_name in self.models.keys():
            results_overall_val[model_name] = []
            results_class_val[model_name] = []
            results_overall_test[model_name] = []
            results_class_test[model_name] = []

        # NOTE(review): `scheduler` is unpacked below but never stepped in this
        # class, so the MultiStepLR milestones have no effect unless the
        # scheduler is stepped elsewhere -- confirm whether a
        # `scheduler.step()` per epoch is intended.
        for net, opt, scheduler, loss_func, model_name in zip(
                self.models.values(), self.optimizers, self.schedulers,
                self.loss_funcs, self.models.keys()):
            XYUtrainLoader_iter = iter(self.XYUtrainLoader)
            XYPtrainLoader_iter = iter(self.XYPtrainLoader)
            net.train()

            # Per-iteration chunks, stacked once after the loop instead of
            # re-copying a growing array on every iteration.
            target_chunks = []
            predict_chunks = []
            prob_chunks = []

            for i in range(args.val_iterations):
                # Load positive data, restarting the loader once exhausted.
                # Only StopIteration triggers a restart, so genuine loading
                # errors propagate.
                try:
                    data_p, t_p = next(XYPtrainLoader_iter)
                except StopIteration:
                    XYPtrainLoader_iter = iter(self.XYPtrainLoader)
                    data_p, t_p = next(XYPtrainLoader_iter)

                # Load unlabeled data, restarting the loader once exhausted.
                try:
                    data_u, t_u = next(XYUtrainLoader_iter)
                except StopIteration:
                    XYUtrainLoader_iter = iter(self.XYUtrainLoader)
                    data_u, t_u = next(XYUtrainLoader_iter)

                # Every positive sample is trained with target +1.
                target_p = torch.ones(data_p.shape[0]).to(device,
                                                          non_blocking=True)

                # Labels from both loaders are kept only for the metrics.
                t = torch.cat((t_p, t_u), dim=0)
                target_chunks.append(t.detach().cpu().numpy())

                data_p = data_p.to(device, non_blocking=True)
                data_u = data_u.to(device, non_blocking=True)

                # One forward pass over positives followed by unlabeled data.
                data = torch.cat((data_p, data_u), dim=0)
                idx_p = slice(0, len(data_p))
                idx_u = slice(len(data_p), len(data))

                output = net(data)  # get output for every net

                size = len(t)
                p = np.reshape(torch.sigmoid(output).detach().cpu().numpy(), size)
                prob_chunks.append(p)
                # Threshold the probabilities at 0.5 to get {+1, -1} labels.
                o = np.where(p > 0.5, 1, -1)
                predict_chunks.append(o)

                # Trim the unlabeled logits to a multiple of the bag count and
                # split them into bags for the loss.
                # assumes `output` has shape (N, 1) -- TODO confirm
                logits_u = output[idx_u].squeeze(1)
                num_bags = args.n_bag
                bag_size = logits_u.size(0) // num_bags
                logits_u = logits_u[:num_bags * bag_size]
                logits_u_bags = torch.chunk(logits_u, num_bags)

                loss = loss_func(output[idx_p], target_p, logits_u_bags)

                opt.zero_grad()  # clear gradients for next train
                loss.backward()  # backpropagation, compute gradients
                opt.step()  # apply gradients

            # Leading with an empty float array keeps the float64 result and
            # the empty-array result for zero iterations.
            targets_all = np.hstack([np.array([])] + target_chunks)
            predicts_all = np.hstack([np.array([])] + predict_chunks)
            probs_all = np.hstack([np.array([])] + prob_chunks)

            val_overall_metrics, val_class_metrics = multi_class_accuracy(probs_all, predicts_all, targets_all)

            test_overall_metrics, test_class_metrics, test_loss, test_loss_p, test_loss_n = test(args, self.XYtestLoader, net, self.loss_func_pn, device)
            results_overall_test[model_name].extend(test_overall_metrics)
            results_class_test[model_name].extend(test_class_metrics)

            results_overall_val[model_name].extend(val_overall_metrics)
            results_class_val[model_name].extend(val_class_metrics)

        return results_overall_val, results_class_val, results_overall_test, results_class_test

    def run(self, Epochs):
        """Train for ``Epochs`` epochs, printing test metrics after each one.

        Args:
            Epochs: Number of epochs to run.

        Returns:
            ``(results_val, results_test)``: dicts keyed by model name. Each
            value is a list of per-metric histories (one list per overall
            metric, then one per (class metric, class) pair), with one entry
            appended per epoch.
        """
        results_val = {}
        results_test = {}
        metrics_overall = ['ACC', 'AUC', 'Macro_F1', 'Micro_F1', 'Precision', 'Recall', 'ErrorRate']
        metrics_class = ['Class_F1', 'Class_Precision', 'Class_Recall', 'Class_NPP']

        for model_name in self.models.keys():
            n_rows = len(metrics_overall) + len(metrics_class) * args.num_classifier
            results_val[model_name] = [[] for _ in range(n_rows)]
            results_test[model_name] = [[] for _ in range(n_rows)]

        print("Epoch\t" + "".join([
            "".join([
                # overall metrics first
                "{}/{}\t".format(model_name, metric)
                for metric in metrics_overall
            ]) +
            "".join([
                # then, for every class metric, one column per class
                "".join([
                    "{}/{}_{}\t".format(model_name, metric, j + 1)
                    for j in range(args.num_classifier)
                ])
                for metric in metrics_class
            ])
            for model_name in sorted(results_test.keys())
        ]) + "\n")

        for epoch in range(Epochs):
            _results_overall_val, _results_class_val, _results_overall_test, _results_class_test = self.train(epoch)

            print("{}\t".format(epoch) + "".join([
                # overall metrics
                "".join(["{:-8}\t".format(round(_results_overall_test[model_name][i], 4))
                         for i in range(len(metrics_overall))]) +
                # every value of every class-metric row
                "".join([
                    "".join(["{:-8}\t".format(round(value, 4))
                             for value in _results_class_test[model_name][j]])
                    for j in range(len(_results_class_test[model_name]))])
                for model_name in sorted(results_test.keys())
            ]) + "\n")

            for model_name in self.models.keys():
                for i in range(len(metrics_overall)):
                    results_test[model_name][i].append(_results_overall_test[model_name][i])
                    results_val[model_name][i].append(_results_overall_val[model_name][i])
                for i in range(len(metrics_class)):
                    for j in range(args.num_classifier):
                        # Class-metric rows follow the overall rows, grouped
                        # by metric and then by class.
                        row = args.num_classifier * i + len(metrics_overall) + j
                        results_test[model_name][row].append(_results_class_test[model_name][i][j])
                        results_val[model_name][row].append(_results_class_val[model_name][i][j])

        return results_val, results_test



def main():
    """Build the data loaders, model, loss and optimizer, train, and save results.

    All settings are read from the module-level ``args``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported image
            datasets. No other loader is wired up in this function, so
            continuing would fail later on undefined loaders.
    """
    print("using:", device)

    image_datasets = ['mnist', 'fashionmnist', 'cifar10', 'stl10', 'alzheimer']
    text_datasets = [
        'imdb', 'yelp_full', '20ng', 'yelp', 'amazon', 'amazon_full'
    ]

    # dataset setup
    if args.dataset not in image_datasets:
        hint = " (text datasets are listed but have no loader here)" \
            if args.dataset in text_datasets else ""
        raise ValueError(
            "unsupported dataset '{}'{}; expected one of: {}".format(
                args.dataset, hint, ", ".join(image_datasets)))

    # The prior returned by the loader is ignored; args.prior is used for
    # the losses and the trainer below.
    (XYPtrainLoader, XYUtrainLoader, XYtrainLoader, XYvalidLoader, XYtestLoader,
     XYPtrainSet, dim,
     _dataset_prior) = load_image_dataset(args.dataset, args.labeled,
                                          args.batchsize, args.positive_label_list)

    # model setup
    loss_type = select_loss(args.loss, args.loss_gamma, args.loss_tau, args.loss_eta,
                            pi=args.prior, a0=args.a0, b0=args.b0, b1=args.b1, k1=args.k1, k2=args.k2, t=0.5)
    selected_model = select_model(args.model)
    model = selected_model(int(dim))
    models = {
        args.method: copy.deepcopy(model).to(device),
    }
    # Only needed by the PN baseline, which is currently disabled below.
    models_pn = {
        "PN": copy.deepcopy(model).to(device),
    }

    loss_funcs = {
        args.method: select_loss_function(args.method, loss_type),
    }
    loss_func_pn = PNCEloss(args.prior, loss=loss_type)

    # trainer setup
    optimizers = {
        k: make_optimizer(v, args.optim, args.stepsize)
        for k, v in models.items()
    }
    schedulers = {k: make_scheduler(v) for k, v in optimizers.items()}

    optimizers_pn = {
        k: make_optimizer(v, args.optim, args.stepsize)
        for k, v in models_pn.items()
    }
    schedulers_pn = {k: make_scheduler(v) for k, v in optimizers_pn.items()}

    print("input dim: {}".format(dim))
    print("prior: {}".format(args.prior))
    print("loss: {}".format(args.loss))
    print("batchsize: {}".format(args.batchsize))
    print("model: {}".format(args.model))
    print("beta: {}".format(args.beta))
    print("gamma: {}".format(args.gamma))
    print("")

    # run training
    # training (nnPU, uPU)
    PUtrainer = trainer(models, loss_funcs, optimizers, schedulers,
                        XYPtrainLoader, XYUtrainLoader, XYvalidLoader,
                        XYtestLoader, args.prior, loss_func_pn)
    results_pu_val, results_pu_test = PUtrainer.run(args.epoch)

    # PNtrainer = trainer_pn(models_pn, loss_funcs_pn, optimizers_pn,
    #                        schedulers_pn, XYtrainLoader, XYvalidLoader,
    #                        XYtestLoader, prior)
    # results_pn_val, results_pn_test = PNtrainer.run(args.epoch)

    save_path = "./results_bag/"
    save_result(save_path, results_pu_val, saved_mode="val")
    save_result(save_path, results_pu_test, saved_mode="test")


def save_result(save_path, results_pu, saved_mode="test"):
    """Write the per-epoch metric histories to a tab-separated text file.

    The file name is built from ``save_path``, the sorted model names, and
    several run settings read from ``args``, including the parameters of the
    selected loss. The output directory is created if it does not exist.

    Args:
        save_path: Prefix (normally a directory ending in ``/``) for the file.
        results_pu: Dict mapping model name to a list of per-metric
            histories, each indexable by epoch.
        saved_mode: Tag appended to the file name (e.g. ``"val"``/``"test"``).
    """
    filename_result = save_path + "".join(["{}".format(model_name) for model_name in sorted(results_pu.keys())]) \
                          + "_{}_{}_{}_{}_{}_{}".format(args.preset, args.positive_label_list, args.model, args.stepsize, args.weight_decay, args.labeled)

    if args.loss == 'gsigmoid':
        filename_result += "_{}_{}{}_{}.txt".format(args.seed, args.loss, args.loss_gamma, saved_mode)
    elif args.loss == 'pinball':
        filename_result += "_{}_{}{}_{}.txt".format(args.seed, args.loss, args.loss_tau, saved_mode)
    elif args.loss in ('rescaled-hinge', 'rescaled_hinge'):
        # The loss is registered as "rescaled-hinge" in select_loss; the
        # underscore spelling is still accepted for compatibility.
        filename_result += "_{}_{}{}_{}.txt".format(args.seed, args.loss, args.loss_eta, saved_mode)
    elif args.loss == 'mine':
        filename_result += "_{}_{}_{}_{}_{}_{}_{}_{}.txt".format(args.seed, args.loss, args.a0, args.b0, args.b1, args.k1, args.k2, saved_mode)
    else:
        filename_result += "_{}_{}_{}_{}_{}.txt".format(args.seed, args.loss, args.batchsize, args.n_bag, saved_mode)

    metrics_overall = ['ACC', 'AUC', 'Macro_F1', 'Micro_F1', 'Precision', 'Recall', 'ErrorRate']
    metrics_class = ['Class_F1', 'Class_Precision', 'Class_Recall', 'Class_NPP']

    # Create the output directory up front so a fresh checkout does not fail
    # on open() after the whole training run has finished.
    out_dir = os.path.dirname(filename_result)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(filename_result, 'w') as file_result:
        # Header: overall metrics, then one column per (class metric, class).
        file_result.write("Epoch\t" + "".join([
            "".join([
                "{}/{}\t".format(model_name, metric) for metric in metrics_overall
            ]) +
            "".join([
                "".join([
                    "{}/{}_{}\t".format(model_name, metric, j + 1) for j in range(args.num_classifier)
                ])
                for metric in metrics_class
            ])
            for model_name in sorted(results_pu.keys())
        ]) + "\n")

        for epoch in range(args.epoch):
            # One row per epoch: every stored metric history at this epoch.
            file_result.write("{}\t".format(epoch) + "".join([
                    "".join(["{:-8}\t".format(round(results_pu[model_name][i][epoch], 4))
                             for i in range(len(results_pu[model_name]))])
            for model_name in sorted(results_pu.keys())
            ]) + "\n")



if __name__ == '__main__':
    import os
    import sys

    # Switch to the directory in sys.path[0] so the relative paths used by
    # this script resolve from there.
    script_dir = sys.path[0]
    os.chdir(script_dir)
    print("working dir: {}".format(os.getcwd()))
    main()
