import numpy as np
import torch
import torch.nn as nn
from utils.utils_loss import logistic_loss
import torch.nn.functional as F
from utils.utils_algo import accuracy_check, update_ema
from torch import linalg as LA
from utils.utils_mixup import mixup_two_pairs
from utils.utils_ramps import sigmoid_rampdown
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, LinearLR, StepLR

def pretrainLR(model, given_train_loader, test_loader, train_eval_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Pre-train ``model`` on ordinary +1 / -1 labelled data.

    For every mini-batch the loss is

        (sum_{y=+1} loss_fn(f(x)) + sum_{y=-1} loss_fn(-f(x))) / n

    where ``f(x)`` is the first output column of ``model`` and ``n`` counts the
    samples labelled +1 or -1.  Adam (``args.pretrain_lr``, ``args.wd``) runs for
    ``args.pretrain_ep`` epochs.

    After each epoch the last mini-batch loss, the accuracy on
    ``train_eval_loader`` and the accuracy on ``test_loader`` are printed; when
    ``if_write`` is true the same values are appended to ``save_path`` as
    ``epoch,train_loss,train_acc,test_acc``.

    Returns:
        tuple: (mean test accuracy over the final (up to) 10 epochs, model).
    """
    initial_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', initial_acc)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.pretrain_lr, weight_decay=args.wd)
    last_accs = []

    for epoch_idx in range(args.pretrain_ep):
        model.train()
        for inputs, labels in given_train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            scores = model(inputs)[:, 0]

            is_pos = labels == 1
            is_neg = labels == -1
            n_labelled = is_pos.sum() + is_neg.sum()

            # A label group missing from the batch contributes a plain 0.0.
            pos_sum = loss_fn(scores[is_pos]).sum() if is_pos.sum() > 0 else 0.0
            neg_sum = loss_fn(-scores[is_neg]).sum() if is_neg.sum() > 0 else 0.0

            batch_loss = (pos_sum + neg_sum) / n_labelled
            batch_loss.backward()
            optimizer.step()

        model.eval()
        train_acc = accuracy_check(loader=train_eval_loader, model=model, device=device)
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        loss_value = batch_loss.data.item()
        epoch_no = epoch_idx + 1
        print('#epoch', epoch_no, ': train_loss', loss_value, ' train_accuracy', train_acc, ' test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as log_file:
                log_file.write("{},{:.6f},{:.6f},{:.6f}\n".format(epoch_no, loss_value, train_acc, test_acc))
        # Only the final 10 epochs feed the reported accuracy.
        if epoch_idx >= args.pretrain_ep - 10:
            last_accs.append(test_acc)

    return np.mean(last_accs), model


def PcompUnbiased(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train ``model`` on Pcomp data with the unbiased risk estimator.

    ``given_train_loader`` yields ``(X, y)`` with ``y`` in {+1, -1}; the score is
    the first output column ``f(x)`` of ``model``.  With ``pi = args.prior`` and
    ``l = loss_fn`` the mini-batch loss is

        mean_{y=+1}[ l(f(x)) - pi * l(-f(x)) ] + mean_{y=-1}[ l(-f(x)) - (1 - pi) * l(f(x)) ]

    where a label group absent from the batch contributes 0.0.  Adam
    (``args.lr``, ``args.wd``) runs for ``args.ep`` epochs; the last mini-batch
    loss and the test accuracy are printed after every epoch.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    baseline_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', baseline_acc)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    pi = args.prior
    tail_accs = []

    for epoch in range(args.ep):
        model.train()
        for batch_x, batch_y in given_train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            scores = model(batch_x)[:, 0]
            is_pos = batch_y == 1
            is_neg = batch_y == -1

            pos_risk = 0.0
            neg_risk = 0.0
            if is_pos.sum() > 0:
                pos_scores = scores[is_pos]
                pos_risk = (loss_fn(pos_scores) - pi * loss_fn(-pos_scores)).mean()
            if is_neg.sum() > 0:
                neg_scores = scores[is_neg]
                neg_risk = (loss_fn(-neg_scores) - (1 - pi) * loss_fn(neg_scores)).mean()

            batch_loss = pos_risk + neg_risk
            batch_loss.backward()
            optimizer.step()

        model.eval()
        epoch_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch + 1, ': train_loss', batch_loss.data.item(), 'test_accuracy', epoch_acc)
        if epoch >= args.ep - 10:
            tail_accs.append(epoch_acc)

    return np.mean(tail_accs)


def PcompTeacher(model, ema_model, given_train_loader, test_loader, args, loss_fn, device):
    """Train ``model`` on Pcomp data with small-loss selection plus a mean-teacher consistency term.

    Pcomp labels (``y`` in {+1, -1}) are treated as noisy ordinary labels.  In
    each batch the per-sample losses of the ``y == 1`` group (resp. ``y == -1``)
    are ranked and only the smallest ``1 - inverse_noise_rate`` fraction is kept.
    The mean loss over the kept samples is rescaled by ``1 / (1 - noise_rate)``.
    The rates are derived from ``args.prior`` below.

    A consistency term ``args.ema_weight * MSE(f(x), f_ema(x))`` ties the
    student's first output column to that of ``ema_model``.  ``ema_model`` is
    passed to ``update_ema`` with decay ``args.ema_alpha`` every step.

    Args:
        model: student network; column 0 of its output is the score.
        ema_model: teacher network, refreshed through ``update_ema``.
        given_train_loader: yields ``(X, y)`` batches.
        test_loader: evaluated with ``accuracy_check`` after every epoch.
        args: reads ``lr``, ``wd``, ``prior``, ``ema_weight``, ``ema_alpha``, ``ep``.
        loss_fn: per-sample surrogate loss applied to scores.
        device: torch device for the batches.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    global_step = 0
    cons_weight = args.ema_weight
    ema_decay = args.ema_alpha
    # p(y=0 | y~=1) = (1-prior)^2 / (1 - prior*(1-prior))
    inverse_pos_noise_rate = ((1 - prior) * (1 - prior)) / (1 - prior * (1 - prior))
    # p(y=1 | y~=0) = prior^2 / (1 - prior*(1-prior))
    inverse_neg_noise_rate = (prior * prior) / (1 - prior * (1 - prior))
    # transition matrix: p(y~=0 | y=1) = prior / (1 + prior)
    pos_noise_rate = prior / (1 + prior)
    # transition matrix: p(y~=1 | y=0) = (1-prior) / (1 + (1-prior))
    neg_noise_rate = (1 - prior) / (1 + (1 - prior))

    def _small_loss_mean(per_sample_loss, inverse_noise_rate, noise_rate):
        """Rescaled mean loss over the smallest-loss ``1 - inverse_noise_rate`` fraction.

        Returns 0.0 when that fraction rounds down to zero samples.  Taking
        ``.mean()`` of an empty selection would give NaN, and a NaN loss
        corrupts every parameter on ``backward``.
        """
        num_keep = int((1 - inverse_noise_rate) * len(per_sample_loss))
        if num_keep == 0:
            return 0.0
        # The ranking only selects samples; it needs no gradient.
        keep_idx = torch.topk(per_sample_loss.detach(), num_keep, largest=False).indices
        return 1 / (1 - noise_rate) * per_sample_loss[keep_idx].mean()

    for epoch in range(args.ep):
        model.train()
        for (X, y) in given_train_loader:
            X, y = X.to(device), y.to(device)
            global_step += 1
            optimizer.zero_grad()
            outputs = model(X)[:, 0]
            # Called before optimizer.step(), so the teacher sees the student
            # weights from before this step's update (update_ema is defined in
            # utils.utils_algo; its exact rule is not visible here).
            update_ema(model, ema_model, ema_decay, global_step)
            with torch.no_grad():
                ema_outputs = ema_model(X)[:, 0]
            pos_index, neg_index = (y == 1), (y == -1)
            pos_train_loss, neg_train_loss = 0.0, 0.0
            if pos_index.sum() > 0:
                pos_train_loss = _small_loss_mean(
                    loss_fn(outputs[pos_index]), inverse_pos_noise_rate, pos_noise_rate)
            if neg_index.sum() > 0:
                neg_train_loss = _small_loss_mean(
                    loss_fn(-outputs[neg_index]), inverse_neg_noise_rate, neg_noise_rate)
            label_loss = pos_train_loss + neg_train_loss
            cons_loss = cons_weight * F.mse_loss(outputs, ema_outputs)
            train_loss = label_loss + cons_loss
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch + 1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if epoch >= (args.ep - 10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)


def PcompReLU(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train ``model`` on Pcomp data with the ReLU-corrected (non-negative) risk.

    The unbiased estimator is split into two partial risks, each clipped from
    below at 0 with ``torch.max(., 0)`` (``pi = args.prior``, ``l = loss_fn``):

        R+ = max(mean_{y=+1} l(f(x)) - (1 - pi) * mean_{y=-1} l(f(x)), 0)
        R- = max(mean_{y=-1} l(-f(x)) - pi * mean_{y=+1} l(-f(x)), 0)

    When one label group is missing from a batch, its mean is dropped from both
    partial risks.  A batch with neither label leaves the loss as the float 0.0,
    on which ``.backward()`` cannot be called.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    start_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', start_acc)
    history = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    zero = torch.tensor([0.0]).to(device)

    for epoch in range(args.ep):
        model.train()
        for data, target in given_train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            scores = model(data)[:, 0]
            pos_mask, neg_mask = target == 1, target == -1
            has_pos = pos_mask.sum() > 0
            has_neg = neg_mask.sum() > 0

            pos_part, neg_part = 0.0, 0.0
            if has_pos and has_neg:
                pos_part = torch.max(
                    loss_fn(scores[pos_mask]).mean() - (1 - prior) * loss_fn(scores[neg_mask]).mean(), zero)
                neg_part = torch.max(
                    loss_fn(-scores[neg_mask]).mean() - prior * loss_fn(-scores[pos_mask]).mean(), zero)
            elif has_pos:
                pos_part = torch.max(loss_fn(scores[pos_mask]).mean(), zero)
                neg_part = torch.max(-prior * loss_fn(-scores[pos_mask]).mean(), zero)
            elif has_neg:
                pos_part = torch.max(-(1 - prior) * loss_fn(scores[neg_mask]).mean(), zero)
                neg_part = torch.max(loss_fn(-scores[neg_mask]).mean(), zero)

            loss = pos_part + neg_part
            loss.backward()
            optimizer.step()

        model.eval()
        epoch_acc = accuracy_check(loader=test_loader, model=model, device=device)

        print('#epoch', epoch + 1, ': train_loss', loss.data.item(), 'test_accuracy', epoch_acc)
        if epoch >= args.ep - 10:
            history.append(epoch_acc)
    return np.mean(history)


def PcompABS(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train ``model`` on Pcomp data with the ABS-corrected risk.

    Same two partial risks as the unbiased estimator (``pi = args.prior``,
    ``l = loss_fn``), each passed through ``torch.abs``:

        R+ = |mean_{y=+1} l(f(x)) - (1 - pi) * mean_{y=-1} l(f(x))|
        R- = |mean_{y=-1} l(-f(x)) - pi * mean_{y=+1} l(-f(x))|

    A label group missing from the batch is dropped from both partial risks.
    The loss is ``R+ + R-``; it is printed with the test accuracy every epoch.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    init_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', init_acc)
    collected = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    for epoch in range(args.ep):
        model.train()
        for feats, labels in given_train_loader:
            feats, labels = feats.to(device), labels.to(device)
            optimizer.zero_grad()
            score = model(feats)[:, 0]
            pos, neg = labels == 1, labels == -1
            pos_present = pos.sum() > 0
            neg_present = neg.sum() > 0

            if pos_present and neg_present:
                risk_pos = torch.abs(
                    loss_fn(score[pos]).mean() - (1 - prior) * loss_fn(score[neg]).mean())
                risk_neg = torch.abs(
                    loss_fn(-score[neg]).mean() - prior * loss_fn(-score[pos]).mean())
            elif pos_present:
                risk_pos = torch.abs(loss_fn(score[pos]).mean())
                risk_neg = torch.abs(-prior * loss_fn(-score[pos]).mean())
            elif neg_present:
                risk_pos = torch.abs(-(1 - prior) * loss_fn(score[neg]).mean())
                risk_neg = torch.abs(loss_fn(-score[neg]).mean())
            else:
                risk_pos, risk_neg = 0.0, 0.0

            objective = risk_pos + risk_neg
            objective.backward()
            optimizer.step()

        model.eval()
        acc = accuracy_check(loader=test_loader, model=model, device=device)

        print('#epoch', epoch + 1, ': train_loss', objective.data.item(), 'test_accuracy', acc)
        if epoch >= args.ep - 10:
            collected.append(acc)
    return np.mean(collected)


def ConfDiffUnbiased(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Train ``model`` on ConfDiff pairs with the unbiased risk estimator.

    Each batch yields ``(x1, x2, c, y1, y2)``; the true labels ``y1``/``y2`` are
    ignored.  With ``pi = args.prior``, ``l = loss_fn`` and ``f`` the first
    output column, the per-pair risk is

        (pi - c) * l(f(x1)) + (1 - pi + c) * l(-f(x1))
      + (pi + c) * l(f(x2)) + (1 - pi - c) * l(-f(x2))

    and the batch loss is half its mean.  The learning rate stays constant (no
    scheduler is stepped).  After each epoch the last batch loss and test
    accuracy are printed and, if ``if_write`` is set, appended to ``save_path``
    as ``epoch,loss,acc``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    acc0 = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', acc0)
    final_accs = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    pi = args.prior

    for epoch in range(args.ep):
        model.train()
        for x1, x2, c, _y1, _y2 in given_train_loader:
            x1 = x1.to(device)
            x2 = x2.to(device)
            c = c.to(device)
            optimizer.zero_grad()
            f1 = model(x1)[:, 0]
            f2 = model(x2)[:, 0]

            term1 = (pi - c) * loss_fn(f1)
            term2 = (1 - pi + c) * loss_fn(-f1)
            term3 = (pi + c) * loss_fn(f2)
            term4 = (1 - pi - c) * loss_fn(-f2)
            pair_risk = term1 + term2 + term3 + term4

            loss = 0.5 * pair_risk.mean()
            loss.backward()
            optimizer.step()

        model.eval()
        acc = accuracy_check(loader=test_loader, model=model, device=device)
        loss_val = loss.data.item()
        print('#epoch', epoch+1, ': train_loss', loss_val, 'test_accuracy', acc)
        if if_write:
            with open(save_path, "a") as out:
                out.write("{},{:.6f},{:.6f}\n".format(epoch + 1, loss_val, acc))
        if epoch >= args.ep - 10:
            final_accs.append(acc)
    return np.mean(final_accs)

def ConfDiffReLU(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Train ``model`` on ConfDiff pairs with the ReLU-corrected risk.

    The unbiased ConfDiff risk is split into four partial risks.  Each is
    averaged over the batch and clipped from below at 0 (``torch.max(., 0)``):

        max(mean (pi - c) l(f(x1)), 0) + max(mean (1 - pi + c) l(-f(x1)), 0)
      + max(mean (pi + c) l(f(x2)), 0) + max(mean (1 - pi - c) l(-f(x2)), 0)

    The batch loss is half their sum, with ``pi = args.prior``, ``c`` the
    confidence difference and ``l = loss_fn``.  The true labels the loader
    yields are ignored.  After each epoch the last batch loss and test accuracy
    are printed and, if ``if_write`` is set, appended to ``save_path`` as
    ``epoch,loss,acc``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', acc)
    recent_accs = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    pi = args.prior
    floor = torch.tensor([0.0]).to(device)

    for epoch in range(args.ep):
        model.train()
        for x1, x2, c, _, _ in given_train_loader:
            x1, x2, c = x1.to(device), x2.to(device), c.to(device)
            optimizer.zero_grad()
            f1 = model(x1)[:, 0]
            f2 = model(x2)[:, 0]
            weighted_scores = (
                (pi - c, f1),
                (1 - pi + c, -f1),
                (pi + c, f2),
                (1 - pi - c, -f2),
            )
            parts = [torch.max((w * loss_fn(s)).mean(), floor) for w, s in weighted_scores]
            loss = 0.5 * (parts[0] + parts[1] + parts[2] + parts[3])
            loss.backward()
            optimizer.step()

        model.eval()
        acc = accuracy_check(loader=test_loader, model=model, device=device)
        loss_value = loss.data.item()
        print('#epoch', epoch+1, ': train_loss', loss_value, 'test_accuracy', acc)
        if if_write:
            with open(save_path, "a") as fh:
                fh.write("{},{:.6f},{:.6f}\n".format(epoch + 1, loss_value, acc))

        if epoch >= args.ep - 10:
            recent_accs.append(acc)
    return np.mean(recent_accs)

def ConfDiffABS(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Train ``model`` on ConfDiff pairs with the ABS-corrected risk.

    Same four partial risks as the unbiased ConfDiff estimator: weights
    ``pi - c``, ``1 - pi + c``, ``pi + c``, ``1 - pi - c`` applied to
    ``l(f(x1))``, ``l(-f(x1))``, ``l(f(x2))``, ``l(-f(x2))``.  Each batch mean
    is passed through ``torch.abs`` before the four are summed and halved.
    ``pi = args.prior`` and ``l = loss_fn``; the true labels the loader yields
    are ignored.

    After every epoch the last batch loss and test accuracy are printed and,
    when ``if_write`` is set, appended to ``save_path`` as ``epoch,loss,acc``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    first_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', first_acc)
    accs_at_end = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    pi = args.prior

    for epoch in range(args.ep):
        model.train()
        for inst_a, inst_b, diff, _label_a, _label_b in given_train_loader:
            inst_a = inst_a.to(device)
            inst_b = inst_b.to(device)
            diff = diff.to(device)
            optimizer.zero_grad()
            score_a = model(inst_a)[:, 0]
            score_b = model(inst_b)[:, 0]

            total = None
            for weight, score in ((pi - diff, score_a), (1 - pi + diff, -score_a),
                                  (pi + diff, score_b), (1 - pi - diff, -score_b)):
                term = torch.abs((weight * loss_fn(score)).mean())
                total = term if total is None else total + term
            batch_loss = 0.5 * total
            batch_loss.backward()
            optimizer.step()

        model.eval()
        acc = accuracy_check(loader=test_loader, model=model, device=device)

        loss_value = batch_loss.data.item()
        print('#epoch', epoch+1, ': train_loss', loss_value, 'test_accuracy', acc)
        if if_write:
            with open(save_path, "a") as handle:
                handle.write("{},{:.6f},{:.6f}\n".format(epoch + 1, loss_value, acc))
        if epoch >= args.ep - 10:
            accs_at_end.append(acc)
    return np.mean(accs_at_end)


def Normal(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Confidence-free baseline on ConfDiff pairs.

    Uses the same four-term absolute-value objective as ``ConfDiffABS`` but
    replaces every confidence difference with 0.  Each instance of a pair is
    therefore weighted only by the class prior ``pi = args.prior``:

        0.5 * (|mean pi l(f(x1))| + |mean (1-pi) l(-f(x1))|
             + |mean pi l(f(x2))| + |mean (1-pi) l(-f(x2))|)

    After every epoch the last batch loss and test accuracy are printed and,
    when ``if_write`` is set, appended to ``save_path`` as ``epoch,loss,acc``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    for epoch in range(args.ep):
        model.train()
        for (X1, X2, conf, _, _) in given_train_loader:
            X1, X2 = X1.to(device), X2.to(device)
            # Build a fresh zero tensor rather than zeroing the loaded batch in
            # place.  When `.to(device)` is a no-op, the in-place write would hit
            # the very tensor the loader produced.
            conf = torch.zeros_like(conf, device=device)
            optimizer.zero_grad()
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]
            train_loss1 = torch.abs(((prior - conf) * loss_fn(outputs1)).mean())
            train_loss2 = torch.abs(((1 - prior + conf) * loss_fn(-outputs1)).mean())
            train_loss3 = torch.abs(((prior + conf) * loss_fn(outputs2)).mean())
            train_loss4 = torch.abs(((1 - prior - conf) * loss_fn(-outputs2)).mean())
            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4)
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch + 1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.data.item(), test_acc))
        if epoch >= (args.ep - 10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)



# def nadam(grad, x, m, v, i, step_size=0.001, b1=0.9, b2=0.999, eps=10**-8):
#     """Nadam as described in `Incorporating Nesterov Momentum into Adam`.
#     It combines Adam and NAG.
#     """
#     m = (1 - b1) * grad + b1 * m
#     v = (1 - b2) * (grad**2) + b2 * v
#     mhat = m / (1 - b1**(i + 1))
#     vhat = v / (1 - b2**(i + 1))
#     x = x - step_size * (b1 * mhat + (1 - b1) * grad /
#                          (1 - b1**(i + 1))) / (np.sqrt(vhat) + eps)
#     return x, m, v


def ConfDiffABS_updateC(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """ConfDiff training with the ABS-corrected risk on refined confidences, plus mixup.

    For every mini-batch the observed confidence differences ``conf`` are
    replaced by a refined vector ``C``.  ``C`` comes from 50 iterations of
    augmented-Lagrangian-style alternating updates ("method-2" below).  These
    couple ``C`` to ``conf`` through a residual ``E`` (constraint
    ``conf = C + E``) and an auxiliary ``Q`` (constraint ``C = Q``).

    ``C`` then weights the same four absolute-value risk terms used by
    ``ConfDiffABS``.  From epoch index 5 onward, the pairs with ``|C| >= 0.75``
    are mixed with the pairs with ``|C| < 0.25`` via ``mixup_two_pairs``.  The
    resulting risk is added with weight ``0.5 * 0.01``.

    Args:
        model, given_train_loader, test_loader, loss_fn, device: as in the
            other ConfDiff trainers in this file.
        args: reads ``lr``, ``wd``, ``prior``, ``ep``, ``alpha`` (threshold
            scale for the ``Q`` update) and ``beta`` (threshold scale for the
            ``E`` update).
        if_write, save_path: when ``if_write`` is true, append
            ``epoch,train_loss,test_acc`` to ``save_path`` every epoch.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    # NOTE(review): `lda` is not used anywhere below.
    lda = torch.tensor([0.0]).to(device)

    # given_train_dataset = given_train_loader.dataset
    # X1, X2, conf, y1, y2 = given_train_dataset.data1, given_train_dataset.data2, given_train_dataset.confidence, \
    #                        given_train_dataset.true_label1, given_train_dataset.true_label2
    # X1, X2, conf, y1, y2 = X1.to(device), X2.to(device), conf.to(device), y1.to(device), y2.to(device)
    #
    # C = torch.ones_like(conf)
    # E = torch.ones_like(conf)
    #
    # Y1 = conf / LA.norm(conf)
    # mu1_reg = conf.shape[0] / 4 / LA.norm(conf, 1)
    # rho = 6

    for epoch in range(args.ep):
        model.train()

        for (X1, X2, conf, _, _) in given_train_loader:
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            optimizer.zero_grad()
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]

            # Per-batch initialisation of the refinement variables:
            #   C      - refined confidence differences (starts at conf)
            #   E      - residual, constraint conf = C + E (multiplier Y1)
            #   Q      - auxiliary copy, constraint C = Q (multiplier Y2)
            #   mu*_reg - penalty weights, multiplied by rho each iteration (capped at 1e6)
            # NOTE(review): if every entry of conf is 0, LA.norm(conf) is 0 and
            # Y1 / mu1_reg become non-finite (0/0, x/0) - confirm this cannot happen.
            C = conf.clone()
            E = torch.zeros_like(conf)
            Q = C.clone()
            Y1 = conf / LA.norm(conf)
            Y2 = C / LA.norm(C)
            mu1_reg = conf.shape[0] / 2 / LA.norm(conf, 1)
            mu2_reg = mu1_reg
            rho = 100   # 100 for method-1

            alpha_reg = args.alpha
            beta_reg = args.beta

            for iter in range(50):  # 25 for method-1
                # ===== method-1 =====
                # # E fixed, update C
                # C = conf - E + Y1/mu1_reg - 0.5/mu1_reg * (-loss_fn(outputs1) - loss_fn(-outputs1) + loss_fn(outputs2) + loss_fn(-outputs2))
                # # C fixed, update E
                # E = conf - C + Y1/mu1_reg
                #
                # # Lag
                # Y1 = Y1 + mu1_reg * (conf - C - E)
                # mu1_reg = min(rho * mu1_reg, 1e6)

                # ===== method-2 =====
                # shrinkage(data, value) = max(data - value, 0) + min(data + value, 0),
                # i.e. soft-thresholding of `data` by `value` when value >= 0.
                def shrinkage(data, value):
                    S1 = data - value
                    S2 = data + value
                    S1[S1 < 0] = 0
                    S2[S2 > 0] = 0
                    B = S1 + S2
                    return B

                # E, Q fixed, update C
                # NOTE(review): this uses loss_fn(outputs*) without .detach(), so C
                # carries the autograd graph of all 50 iterations and train_loss
                # below back-propagates through C as well.  Confirm that is
                # intended.  The bracketed loss term is also identical on every
                # iteration and could be computed once before the loop.
                C = 1/(mu1_reg+mu2_reg) * ( -0.5 * (-loss_fn(outputs1) - loss_fn(-outputs1) + loss_fn(outputs2) + loss_fn(-outputs2))
                                             + mu1_reg * (conf - E + Y1 / mu1_reg) - mu2_reg * (-Q + Y2 / mu2_reg))
                # C, E fixed, update Q
                # NOTE(review): the threshold is -alpha_reg/mu2_reg, which is negative
                # for alpha_reg > 0.  With a negative threshold, shrinkage moves
                # entries away from zero instead of toward it; confirm the sign.
                Q = shrinkage(C + Y2 / mu2_reg, -alpha_reg/mu2_reg)

                # C, Q fixed, update E
                E = shrinkage(conf - C + Y1 / mu1_reg, beta_reg / mu1_reg)

                # Lag
                # Multiplier ascent on the constraint residuals, then grow the penalties.
                Y1 = Y1 + mu1_reg * (conf - C - E)
                Y2 = Y2 + mu2_reg * (C - Q)
                mu1_reg = min(rho * mu1_reg, 1e6)
                mu2_reg = min(rho * mu2_reg, 1e6)

            # ABS-corrected ConfDiff risk weighted by the refined C instead of conf.
            train_loss1 = torch.abs(((prior - C) * loss_fn(outputs1)).mean())
            train_loss2 = torch.abs(((1 - prior + C) * loss_fn(-outputs1)).mean())
            train_loss3 = torch.abs(((prior + C) * loss_fn(outputs2)).mean())
            train_loss4 = torch.abs(((1 - prior - C) * loss_fn(-outputs2)).mean())


            # mixup
            # From the 6th epoch (index 5) on: mix "confident" pairs (|C| >= 0.75)
            # with "ambiguous" pairs (|C| < 0.25).
            # NOTE(review): if either selection is empty in a batch, behaviour depends
            # on mixup_two_pairs (utils.utils_mixup), which is not visible here.  The
            # trailing 6 is passed straight through - presumably a mixup Beta
            # parameter; confirm.
            if epoch >= 5:
                # one_index = torch.where((C <= -0.5) | (C >= 0.5))
                # zero_index = torch.where((C >= -0.5) & (C <= 0.5))
                one_index = torch.where((C <= -0.75) | (C >= 0.75))  # all for 1/-1
                zero_index = torch.where((C > -0.25) & (C < 0.25))  # all for 0
                train_one_data1 = X1[one_index]
                train_one_data2 = X2[one_index]
                train_one_conf = C[one_index]
                train_zero_data1 = X1[zero_index]
                train_zero_data2 = X2[zero_index]
                train_zero_conf = C[zero_index]

                mixed_inputs1, mixed_inputs2, mixed_conf, lam = mixup_two_pairs((train_one_data1, train_one_data2), (train_zero_data1, train_zero_data2),
                                                                                train_one_conf, train_zero_conf, 6)
                mixed_outputs1 = model(mixed_inputs1)[:, 0]
                mixed_outputs2 = model(mixed_inputs2)[:, 0]
                loss_mixup = torch.abs(((prior - mixed_conf) * loss_fn(mixed_outputs1)).mean()) + \
                             torch.abs(((1 - prior + mixed_conf) * loss_fn(-mixed_outputs1)).mean()) + \
                             torch.abs(((prior + mixed_conf) * loss_fn(mixed_outputs2)).mean()) + \
                             torch.abs(((1 - prior - mixed_conf) * loss_fn(-mixed_outputs2)).mean())
            else:
                loss_mixup = 0


            # Total objective: half the ABS risk plus the mixup risk scaled by 0.5 * 0.01.
            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4)
            train_loss += 0.5 * loss_mixup * 0.01

            train_loss.backward()
            optimizer.step()

        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        print('#epoch', epoch + 1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.writelines("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.data.item(), test_acc))
        if epoch >= (args.ep - 10):
            test_acc_list.extend([test_acc])
    return np.mean(test_acc_list)


def ConfDiffUnbiased_new(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Unbiased ConfDiff training with a separate term for near-zero confidence differences.

    Pairs are split by the confidence difference ``c`` against ``args.bound``:

    * ``|c| >= bound``: the unbiased ConfDiff per-pair risk
      ``(pi-c) l(f(x1)) + (1-pi+c) l(-f(x1)) + (pi+c) l(f(x2)) + (1-pi-c) l(-f(x2))``,
      averaged and halved;
    * ``|c| < bound``: ``mean((1 - |c|) * (f(x2) - f(x1)))``, weighted by ``args.beta``.

    Either group may be empty in a given batch.  Its term then contributes an
    exact 0 instead of the NaN that ``.mean()`` of an empty tensor produces; a
    NaN loss would propagate into every parameter on ``backward``.

    After every epoch the last batch loss and test accuracy are printed and,
    when ``if_write`` is set, appended to ``save_path`` as ``epoch,loss,acc``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    def _mean_or_zero(values):
        # sum() of an empty tensor is an exact 0 that stays attached to the graph.
        return values.mean() if values.numel() > 0 else values.sum()

    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    for epoch in range(args.ep):
        model.train()
        for (X1, X2, conf, y1, y2) in given_train_loader:
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            optimizer.zero_grad()
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]
            one_index = torch.where((conf <= -args.bound) | (conf >= args.bound))  # all for 1/-1
            zero_index = torch.where((conf > -args.bound) & (conf < args.bound))  # all for 0
            pair_risk = (prior - conf[one_index]) * loss_fn(outputs1[one_index]) + (1 - prior + conf[one_index]) * loss_fn(-outputs1[one_index]) + \
                        (prior + conf[one_index]) * loss_fn(outputs2[one_index]) + (1 - prior - conf[one_index]) * loss_fn(-outputs2[one_index])
            train_loss5 = _mean_or_zero(
                (1 - torch.abs(conf[zero_index])) * (outputs2[zero_index] - outputs1[zero_index]))

            # Alternative (disabled): alpha_reg = sigmoid_rampdown(epoch, args.ep) * args.beta
            alpha_reg = args.beta
            train_loss = 0.5 * _mean_or_zero(pair_risk) + train_loss5 * alpha_reg
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch+1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.data.item(), test_acc))
        if epoch >= (args.ep-10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)


def ConfDiffReLU_new(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """ReLU-corrected ConfDiff training with a separate term for near-zero confidence differences.

    Pairs with ``|c| >= args.bound`` feed four partial risks ``max(mean w_k l(.), 0)``.
    The weights ``pi - c``, ``1 - pi + c``, ``pi + c``, ``1 - pi - c`` apply to
    ``f(x1)``, ``-f(x1)``, ``f(x2)``, ``-f(x2)`` respectively.  Pairs with
    ``|c| < args.bound`` feed ``max(mean (1 - |c|) (f(x2) - f(x1)), 0)``.
    The loss is ``0.5 * (sum of the four) + args.beta * fifth``.

    A group that is empty in a batch contributes 0 rather than NaN.  ``.mean()``
    of an empty tensor is NaN, ``torch.max`` propagates it, and ``backward``
    would then write NaN into every parameter.

    After every epoch the last batch loss and test accuracy are printed and,
    when ``if_write`` is set, appended to ``save_path`` as ``epoch,loss,acc``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    def _mean_or_zero(values):
        # sum() of an empty tensor is an exact 0 that stays attached to the graph.
        return values.mean() if values.numel() > 0 else values.sum()

    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    lda = torch.tensor([0.0]).to(device)

    for epoch in range(args.ep):
        model.train()
        for (X1, X2, conf, y1, y2) in given_train_loader:
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            optimizer.zero_grad()
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]
            one_index = torch.where((conf <= -args.bound) | (conf >= args.bound))  # all for 1/-1
            zero_index = torch.where((conf > -args.bound) & (conf < args.bound))  # all for 0

            train_loss1 = torch.max(_mean_or_zero((prior - conf[one_index]) * loss_fn(outputs1[one_index])), lda)
            train_loss2 = torch.max(_mean_or_zero((1 - prior + conf[one_index]) * loss_fn(-outputs1[one_index])), lda)
            train_loss3 = torch.max(_mean_or_zero((prior + conf[one_index]) * loss_fn(outputs2[one_index])), lda)
            train_loss4 = torch.max(_mean_or_zero((1 - prior - conf[one_index]) * loss_fn(-outputs2[one_index])), lda)
            train_loss5 = torch.max(_mean_or_zero(
                (1 - torch.abs(conf[zero_index])) * (outputs2[zero_index] - outputs1[zero_index])), lda)

            # Alternative (disabled): alpha_reg = sigmoid_rampdown(epoch, args.ep) * args.beta
            alpha_reg = args.beta
            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4) + train_loss5 * alpha_reg
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch+1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.data.item(), test_acc))

        if epoch >= (args.ep-10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)



def ConfDiffABS_new(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """ABS-corrected ConfDiff training with a separate term for near-zero confidence differences.

    Pairs with ``|c| >= args.bound`` feed four partial risks ``|mean w_k l(.)|``.
    The weights ``pi - c``, ``1 - pi + c``, ``pi + c``, ``1 - pi - c`` apply to
    ``f(x1)``, ``-f(x1)``, ``f(x2)``, ``-f(x2)`` respectively.  Pairs with
    ``|c| < args.bound`` feed ``mean (1 - |c|) |f(x2) - f(x1)|``.
    The optimized loss is ``0.5 * (sum of the four) + args.beta * fifth``.

    Gradients come from a single backward pass on that combined loss, so the
    optimizer minimizes exactly the ``train_loss`` that is printed.  A group
    that is empty in a batch contributes 0 instead of NaN.

    Each epoch prints the five components of the last batch, then the epoch
    summary.  When ``if_write`` is set, ``epoch,loss,acc`` is appended to
    ``save_path``.

    Returns:
        float: mean test accuracy over the final (up to) 10 epochs.
    """
    def _mean_or_zero(values):
        # sum() of an empty tensor is an exact 0 that stays attached to the graph.
        return values.mean() if values.numel() > 0 else values.sum()

    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    for epoch in range(args.ep):
        model.train()
        for (X1, X2, conf, y1, y2) in given_train_loader:
            optimizer.zero_grad()
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]

            one_index = torch.where((conf <= -args.bound) | (conf >= args.bound))  # all for 1/-1
            zero_index = torch.where((conf > -args.bound) & (conf < args.bound))  # all for 0

            train_loss1 = torch.abs(_mean_or_zero((prior - conf[one_index]) * loss_fn(outputs1[one_index])))
            train_loss2 = torch.abs(_mean_or_zero((1 - prior + conf[one_index]) * loss_fn(-outputs1[one_index])))
            train_loss3 = torch.abs(_mean_or_zero((prior + conf[one_index]) * loss_fn(outputs2[one_index])))
            train_loss4 = torch.abs(_mean_or_zero((1 - prior - conf[one_index]) * loss_fn(-outputs2[one_index])))
            train_loss5 = _mean_or_zero(
                (1 - torch.abs(conf[zero_index])) * torch.abs(outputs2[zero_index] - outputs1[zero_index]))

            # Alternative (disabled): alpha_reg = sigmoid_rampdown(epoch, args.ep) * args.beta
            alpha_reg = args.beta
            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4) + train_loss5 * alpha_reg

            # Exactly one backward pass.  Back-propagating train_loss and then
            # train_loss1..5 individually (retain_graph=True) would accumulate
            # 1.5x the gradient of the four risk terms plus (1 + beta)x that of
            # train_loss5, and keep every graph alive.
            train_loss.backward()

            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        print('train_loss', train_loss.data.item(), 'train_loss1', train_loss1.data.item(), 'train_loss2', train_loss2.data.item(),
              'train_loss3', train_loss3.data.item(), 'train_loss4', train_loss4.data.item(), 'train_loss5', train_loss5.data.item())
        print('#epoch', epoch + 1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.data.item(), test_acc))
        if epoch >= (args.ep - 10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)


def CRCR_Unbiased(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Train ``model`` with the uncorrected CRCR objective and report test accuracy.

    Every batch from ``given_train_loader`` is ``(X1, X2, conf, y1, y2)``;
    only ``X1``, ``X2`` and ``conf`` are used.

    - Pairs with ``|conf| >= args.bound`` feed four confidence-weighted risk
      terms (``train_loss1``..``train_loss4``). They are used as-is, with no
      clamping or absolute value.
    - Pairs with ``|conf| < args.bound`` feed ``train_loss5``. It is the mean
      of the min-max-scaled gap ``|outputs2 - outputs1|``, weighted by the
      min-max-scaled ``1 / log(1.1 + |conf|)``. ``args.lam`` times a
      sign-mismatch term is added to it.

    Args:
        model: network; column 0 of its output is used as the score.
        given_train_loader: iterable of ``(X1, X2, conf, y1, y2)`` batches.
        test_loader: evaluated with ``accuracy_check`` before training and
            after every epoch.
        args: namespace read for ``lr``, ``wd``, ``prior``, ``ep``,
            ``bound``, ``lam`` and ``beta``.
        loss_fn: surrogate loss applied to a 1-D score tensor. It must return
            per-element losses, because they are multiplied by per-pair weights.
        device: device the batch tensors are moved to.
        if_write: if True, append ``epoch,train_loss,test_acc`` per epoch
            to ``save_path``.
        save_path: log file path used when ``if_write`` is True.

    Returns:
        Mean of the test accuracies recorded for epochs with
        ``epoch >= args.ep - 10``. That is the last 10 epochs, or all epochs
        if there are fewer.
    """
    def _mean(t):
        # torch yields NaN for the mean of an empty selection, which would
        # poison every gradient; a batch without such pairs contributes 0.
        return t.mean() if t.numel() > 0 else t.new_zeros(())

    def _minmax(t):
        # Min-max scale to [0, 1]. Clamping the range keeps a constant
        # (e.g. single-element) selection at 0 instead of 0/0 = NaN; for any
        # non-degenerate selection the result is unchanged.
        t_min = torch.min(t)
        return (t - t_min) / torch.clamp(torch.max(t) - t_min, min=torch.finfo(t.dtype).tiny)

    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    alpha_reg = args.beta

    for epoch in range(args.ep):
        model.train()
        for X1, X2, conf, _y1, _y2 in given_train_loader:
            optimizer.zero_grad()
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]

            one_mask = (conf <= -args.bound) | (conf >= args.bound)  # all for 1/-1
            zero_mask = (conf > -args.bound) & (conf < args.bound)  # all for 0

            conf_one = conf[one_mask]
            out1_one, out2_one = outputs1[one_mask], outputs2[one_mask]
            train_loss1 = _mean((prior - conf_one) * loss_fn(out1_one))
            train_loss2 = _mean((1 - prior + conf_one) * loss_fn(-out1_one))
            train_loss3 = _mean((prior + conf_one) * loss_fn(out2_one))
            train_loss4 = _mean((1 - prior - conf_one) * loss_fn(-out2_one))

            if zero_mask.any():
                conf_zero = conf[zero_mask]
                gap = outputs2[zero_mask] - outputs1[zero_mask]
                weight = _minmax(1 / torch.log(1.1 + torch.abs(conf_zero)))
                loss5_1 = (weight * _minmax(torch.abs(gap))).mean()
                # torch.sign has zero gradient, so loss5_2 changes the reported
                # loss value but not the parameter update.
                loss5_2 = torch.abs(torch.sign(conf_zero) - torch.sign(gap)).mean()
                train_loss5 = loss5_1 + args.lam * loss5_2
            else:
                train_loss5 = outputs1.new_zeros(())

            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4) + train_loss5 * alpha_reg

            # Gradients are taken from train_loss PLUS each term again, so the
            # effective objective is 1.5 * (loss1 + ... + loss4) + (1 + beta) * loss5,
            # while only train_loss is reported. A single backward over the sum
            # gives exactly that accumulated gradient without retaining the graph.
            # NOTE(review): confirm this extra weighting is intended.
            objective = train_loss + train_loss1 + train_loss2 + train_loss3 + train_loss4 + train_loss5
            if objective.requires_grad:  # False only if no term depends on trainable parameters
                objective.backward()
                optimizer.step()

        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        # The loss values below are those of the epoch's last batch.
        print('train_loss', train_loss.item(), 'train_loss1', train_loss1.item(), 'train_loss2',
              train_loss2.item(),
              'train_loss3', train_loss3.item(), 'train_loss4', train_loss4.item(), 'train_loss5',
              train_loss5.item())
        print('#epoch', epoch + 1, ': train_loss', train_loss.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.item(), test_acc))
        if epoch >= (args.ep - 10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)



def CRCR_ReLU(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Train ``model`` with the CRCR objective, clamping each risk term at 0.

    Every batch from ``given_train_loader`` is ``(X1, X2, conf, y1, y2)``;
    only ``X1``, ``X2`` and ``conf`` are used.

    - Pairs with ``|conf| >= args.bound`` feed four confidence-weighted risk
      terms. Each batch mean is passed through ``torch.max(term, lda)`` with
      ``lda`` a one-element zero tensor, so it never drops below 0. The
      resulting terms have shape ``[1]``.
    - Pairs with ``|conf| < args.bound`` feed ``train_loss5``. It is the mean
      of the min-max-scaled gap ``|outputs2 - outputs1|``, weighted by the
      min-max-scaled ``1 / log(1.1 + |conf|)``. ``args.lam`` times a
      sign-mismatch term is added to it.

    Args:
        model: network; column 0 of its output is used as the score.
        given_train_loader: iterable of ``(X1, X2, conf, y1, y2)`` batches.
        test_loader: evaluated with ``accuracy_check`` before training and
            after every epoch.
        args: namespace read for ``lr``, ``wd``, ``prior``, ``ep``,
            ``bound``, ``lam`` and ``beta``.
        loss_fn: surrogate loss applied to a 1-D score tensor. It must return
            per-element losses, because they are multiplied by per-pair weights.
        device: device the batch tensors are moved to.
        if_write: if True, append ``epoch,train_loss,test_acc`` per epoch
            to ``save_path``.
        save_path: log file path used when ``if_write`` is True.

    Returns:
        Mean of the test accuracies recorded for epochs with
        ``epoch >= args.ep - 10``. That is the last 10 epochs, or all epochs
        if there are fewer.
    """
    def _mean(t):
        # torch yields NaN for the mean of an empty selection, which would
        # poison every gradient; a batch without such pairs contributes 0.
        return t.mean() if t.numel() > 0 else t.new_zeros(())

    def _minmax(t):
        # Min-max scale to [0, 1]. Clamping the range keeps a constant
        # (e.g. single-element) selection at 0 instead of 0/0 = NaN; for any
        # non-degenerate selection the result is unchanged.
        t_min = torch.min(t)
        return (t - t_min) / torch.clamp(torch.max(t) - t_min, min=torch.finfo(t.dtype).tiny)

    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    alpha_reg = args.beta
    lda = torch.tensor([0.0]).to(device)  # lower bound for each risk term

    for epoch in range(args.ep):
        model.train()
        for X1, X2, conf, _y1, _y2 in given_train_loader:
            optimizer.zero_grad()
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]

            one_mask = (conf <= -args.bound) | (conf >= args.bound)  # all for 1/-1
            zero_mask = (conf > -args.bound) & (conf < args.bound)  # all for 0

            conf_one = conf[one_mask]
            out1_one, out2_one = outputs1[one_mask], outputs2[one_mask]
            train_loss1 = torch.max(_mean((prior - conf_one) * loss_fn(out1_one)), lda)
            train_loss2 = torch.max(_mean((1 - prior + conf_one) * loss_fn(-out1_one)), lda)
            train_loss3 = torch.max(_mean((prior + conf_one) * loss_fn(out2_one)), lda)
            train_loss4 = torch.max(_mean((1 - prior - conf_one) * loss_fn(-out2_one)), lda)

            if zero_mask.any():
                conf_zero = conf[zero_mask]
                gap = outputs2[zero_mask] - outputs1[zero_mask]
                weight = _minmax(1 / torch.log(1.1 + torch.abs(conf_zero)))
                loss5_1 = (weight * _minmax(torch.abs(gap))).mean()
                # torch.sign has zero gradient, so loss5_2 changes the reported
                # loss value but not the parameter update.
                loss5_2 = torch.abs(torch.sign(conf_zero) - torch.sign(gap)).mean()
                train_loss5 = loss5_1 + args.lam * loss5_2
            else:
                train_loss5 = outputs1.new_zeros(())

            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4) + train_loss5 * alpha_reg

            # Gradients are taken from train_loss PLUS each term again, so the
            # effective objective is 1.5 * (loss1 + ... + loss4) + (1 + beta) * loss5,
            # while only train_loss is reported. A single backward over the sum
            # gives exactly that accumulated gradient without retaining the graph.
            # NOTE(review): confirm this extra weighting is intended.
            objective = train_loss + train_loss1 + train_loss2 + train_loss3 + train_loss4 + train_loss5
            if objective.requires_grad:  # False only if no term depends on trainable parameters
                objective.backward()
                optimizer.step()

        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        # The loss values below are those of the epoch's last batch.
        print('train_loss', train_loss.item(), 'train_loss1', train_loss1.item(), 'train_loss2', train_loss2.item(),
              'train_loss3', train_loss3.item(), 'train_loss4', train_loss4.item(), 'train_loss5', train_loss5.item())
        print('#epoch', epoch + 1, ': train_loss', train_loss.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.item(), test_acc))
        if epoch >= (args.ep - 10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)



def CRCR_ABS(model, given_train_loader, test_loader, args, loss_fn, device, if_write=False, save_path=""):
    """Train ``model`` with the CRCR objective, taking |.| of each risk term.

    Every batch from ``given_train_loader`` is ``(X1, X2, conf, y1, y2)``;
    only ``X1``, ``X2`` and ``conf`` are used.

    - Pairs with ``|conf| >= args.bound`` feed four confidence-weighted risk
      terms. ``torch.abs`` is applied to each batch mean.
    - Pairs with ``|conf| < args.bound`` feed ``train_loss5``. It is the mean
      of the min-max-scaled gap ``|outputs2 - outputs1|``, weighted by the
      min-max-scaled ``1 / log(1.1 + |conf|)``. ``args.lam`` times a
      sign-mismatch term is added to it.

    Args:
        model: network; column 0 of its output is used as the score.
        given_train_loader: iterable of ``(X1, X2, conf, y1, y2)`` batches.
        test_loader: evaluated with ``accuracy_check`` before training and
            after every epoch.
        args: namespace read for ``lr``, ``wd``, ``prior``, ``ep``,
            ``bound``, ``lam`` and ``beta``.
        loss_fn: surrogate loss applied to a 1-D score tensor. It must return
            per-element losses, because they are multiplied by per-pair weights.
        device: device the batch tensors are moved to.
        if_write: if True, append ``epoch,train_loss,test_acc`` per epoch
            to ``save_path``.
        save_path: log file path used when ``if_write`` is True.

    Returns:
        Mean of the test accuracies recorded for epochs with
        ``epoch >= args.ep - 10``. That is the last 10 epochs, or all epochs
        if there are fewer.
    """
    def _mean(t):
        # torch yields NaN for the mean of an empty selection, which would
        # poison every gradient; a batch without such pairs contributes 0.
        return t.mean() if t.numel() > 0 else t.new_zeros(())

    def _minmax(t):
        # Min-max scale to [0, 1]. Clamping the range keeps a constant
        # (e.g. single-element) selection at 0 instead of 0/0 = NaN; for any
        # non-degenerate selection the result is unchanged.
        t_min = torch.min(t)
        return (t - t_min) / torch.clamp(torch.max(t) - t_min, min=torch.finfo(t.dtype).tiny)

    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    alpha_reg = args.beta

    for epoch in range(args.ep):
        model.train()
        for X1, X2, conf, _y1, _y2 in given_train_loader:
            optimizer.zero_grad()
            X1, X2, conf = X1.to(device), X2.to(device), conf.to(device)
            outputs1 = model(X1)[:, 0]
            outputs2 = model(X2)[:, 0]

            one_mask = (conf <= -args.bound) | (conf >= args.bound)  # all for 1/-1
            zero_mask = (conf > -args.bound) & (conf < args.bound)  # all for 0

            conf_one = conf[one_mask]
            out1_one, out2_one = outputs1[one_mask], outputs2[one_mask]
            train_loss1 = torch.abs(_mean((prior - conf_one) * loss_fn(out1_one)))
            train_loss2 = torch.abs(_mean((1 - prior + conf_one) * loss_fn(-out1_one)))
            train_loss3 = torch.abs(_mean((prior + conf_one) * loss_fn(out2_one)))
            train_loss4 = torch.abs(_mean((1 - prior - conf_one) * loss_fn(-out2_one)))

            if zero_mask.any():
                conf_zero = conf[zero_mask]
                gap = outputs2[zero_mask] - outputs1[zero_mask]
                weight = _minmax(1 / torch.log(1.1 + torch.abs(conf_zero)))
                loss5_1 = (weight * _minmax(torch.abs(gap))).mean()
                # torch.sign has zero gradient, so loss5_2 changes the reported
                # loss value but not the parameter update.
                loss5_2 = torch.abs(torch.sign(conf_zero) - torch.sign(gap)).mean()
                train_loss5 = loss5_1 + args.lam * loss5_2
            else:
                train_loss5 = outputs1.new_zeros(())

            train_loss = 0.5 * (train_loss1 + train_loss2 + train_loss3 + train_loss4) + train_loss5 * alpha_reg

            # Gradients are taken from train_loss PLUS each term again, so the
            # effective objective is 1.5 * (loss1 + ... + loss4) + (1 + beta) * loss5,
            # while only train_loss is reported. A single backward over the sum
            # gives exactly that accumulated gradient without retaining the graph.
            # NOTE(review): confirm this extra weighting is intended.
            objective = train_loss + train_loss1 + train_loss2 + train_loss3 + train_loss4 + train_loss5
            if objective.requires_grad:  # False only if no term depends on trainable parameters
                objective.backward()
                optimizer.step()

        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        # The loss values below are those of the epoch's last batch.
        print('train_loss', train_loss.item(), 'train_loss1', train_loss1.item(), 'train_loss2', train_loss2.item(),
              'train_loss3', train_loss3.item(), 'train_loss4', train_loss4.item(), 'train_loss5', train_loss5.item())
        print('#epoch', epoch + 1, ': train_loss', train_loss.item(), 'test_accuracy', test_acc)
        if if_write:
            with open(save_path, "a") as f:
                f.write("{},{:.6f},{:.6f}\n".format(epoch + 1, train_loss.item(), test_acc))
        if epoch >= (args.ep - 10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)
