import numpy as np
import os
import torch
from utils.utils_data import generate_pretrain_loaders, generate_binary_pretrain_data, gen_index_dataset, train_test_data_gen, gen_confdiff_train_loader, generate_pcomp_loaders
import argparse
from utils.utils_models import linear_model, mlp_model
from utils.utils_loss import logistic_loss
from utils.utils_algo import get_model, accuracy_check, train_data_confidence_gen
from cifar_models import resnet
from algorithms import pretrainLR, ConfDiffUnbiased, ConfDiffReLU, ConfDiffABS, ConfDiffABS_new, ConfDiffUnbiased_new, ConfDiffReLU_new, PcompABS, PcompReLU, PcompTeacher, PcompUnbiased, CRCR_ABS, CRCR_ReLU, CRCR_Unbiased
import scipy.io as sio
import matplotlib.pyplot as plt
import math


# Command-line options of the experiment script.  The order of the
# add_argument calls is the order in which options appear in `--help`.
parser = argparse.ArgumentParser()

# --- optimisation -----------------------------------------------------------
parser.add_argument(
    '-pretrain_lr',
    type=float,
    default=1e-3,
    help="pretrain optimizer's learning rate",
)
parser.add_argument(
    '-lr',
    type=float,
    default=1e-3,
    help="optimizer's learning rate",
)
parser.add_argument(
    '-pretrain_bs',
    type=int,
    default=256,
    help='batch_size of ordinary labels.',
)
parser.add_argument(
    '-bs',
    type=int,
    default=256,
    help='batch_size of ordinary labels.',
)

# --- dataset / model / method -----------------------------------------------
parser.add_argument(
    '-ds',
    type=str,
    default='mnist',
    help='specify a dataset',
)
parser.add_argument(
    '-mo',
    type=str,
    default='mlp',
    choices=['linear', 'mlp', 'resnet'],
    help='model name',
)
parser.add_argument(
    '-me',
    type=str,
    default='ConfDiffABS',
    choices=[
        'ConfDiffUnbiased', 'ConfDiffReLU', 'ConfDiffABS',
        'N_Unbiased', 'N_ReLU', 'N_ABS',
        'PcompUnbiased', 'PcompReLU', 'PcompABS', 'PcompTeacher',
        'CRCR_ABS', 'CRCR_ReLU', 'CRCR_Unbiased',
    ],
    help='specify a method',
)

# --- training schedule and data -----------------------------------------------
parser.add_argument(
    '-pretrain_ep',
    type=int,
    default=10,
    help='number of pretrain epochs',
)
parser.add_argument(
    '-ep',
    type=int,
    default=200,
    help='number of ConfDiff epochs',
)
parser.add_argument(
    '-n',
    type=int,
    default=15000,
    help='number of unlabeled data pairs',
)
parser.add_argument(
    '-prior',
    type=float,
    default=0.5,
    help='the class prior of the data set',
)
parser.add_argument(
    '-wd',
    type=float,
    default=1e-5,
    help='weight decay',
)
parser.add_argument(
    '-lo',
    type=str,
    default='logistic',
    choices=['logistic'],
    help='specify a loss function',
)
parser.add_argument(
    '-uci',
    type=int,
    default=0,
    choices=[0, 1],
    help='Is UCI datasets?',
)

# --- runtime ------------------------------------------------------------------
parser.add_argument(
    '-gpu',
    type=str,
    default='0',
    help='used gpu id',
)
parser.add_argument(
    '-seed',
    type=int,
    default=1,
    help='Random seed',
)
parser.add_argument(
    '-run_times',
    type=int,
    default=5,
    help='random run times',
)

# --- method hyper-parameters --------------------------------------------------
# `-alpha` is read further down in this script as the fraction of instances
# that receive Gaussian noise; the remaining values are only forwarded to the
# training routines through `args`.
parser.add_argument(
    '-alpha',
    type=float,
    default=1,
    help='para',
)
parser.add_argument(
    '-beta',
    type=float,
    default=1,
    help='para',
)
parser.add_argument(
    '-lam',
    type=float,
    default=0.1,
    help='para',
)
parser.add_argument(
    '-bound',
    type=float,
    default=0.4,
    help='Divide the boundaries between subsets close to 1 and close to 0',
)
parser.add_argument(
    '-ema_weight',
    type=float,
    default=10,
    help='consistency weight',
)
parser.add_argument(
    '-ema_alpha',
    type=float,
    default=0.97,
    help='ema variable decay rate',
)


args = parser.parse_args()

# Use the GPU selected with `-gpu` when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + args.gpu)
else:
    device = torch.device("cpu")

# `-lo` only admits 'logistic', so this is the single supported loss.
if args.lo == 'logistic':
    loss_fn = logistic_loss

# One slot per random run; filled with each run's final accuracy.
acc_run_list = torch.zeros(args.run_times)

# Output locations: per-epoch pretraining log, per-run summary, per-epoch
# training log, and a log directory.
save_pretrain_dir = "./result/pretrain"
save_total_dir = "./result/total"
save_detail_dir = "./result/detail"
save_log_dir = "./result/log"
# exist_ok=True removes the check-then-create race of the previous
# `if not os.path.exists(...)` guards (e.g. several jobs started together).
# NOTE(review): save_log_dir is deliberately not created here, as before;
# save_log_path is not used anywhere else in this script.
for result_dir in (save_pretrain_dir, save_total_dir, save_detail_dir):
    os.makedirs(result_dir, exist_ok=True)

# The file names encode the run configuration.  The fragments below are
# assembled so that the resulting names are identical to the earlier
# hand-written format strings.
config_tag = f"ds_{args.ds}_prior_{args.prior}_me_{args.me}_mo_{args.mo}_lr_{args.lr}_wd_{args.wd}"
batch_tag = f"bs_{args.bs}_ep_{args.ep}"
pretrain_tag = f"pretrain_bs_{args.pretrain_bs}_pretrain_ep_{args.pretrain_ep}"
run_tag = f"seed_{args.seed}_n_{args.n}.csv"

save_pretrain_name = f"Res_pretrain_{config_tag}_{pretrain_tag}_{run_tag}"
save_total_name = f"Res_total_{config_tag}_{batch_tag}_{pretrain_tag}_{run_tag}"
save_detail_name = f"Res_detail_{config_tag}_{batch_tag}_{pretrain_tag}_{run_tag}"
save_log_name = f"Res_log_{config_tag}_{run_tag}"

save_pretrain_path = os.path.join(save_pretrain_dir, save_pretrain_name)
save_total_path = os.path.join(save_total_dir, save_total_name)
save_detail_path = os.path.join(save_detail_dir, save_detail_name)
save_log_path = os.path.join(save_log_dir, save_log_name)

# Master switch for all CSV output of this script.
if_write = True

result_headers = (
    (save_pretrain_path, "epoch,train_loss,train_accuracy,test_accuracy\n"),
    (save_total_path, "run_idx,acc,std\n"),
    (save_detail_path, "epoch,train_loss,test_accuracy\n"),
)
for result_path, header in result_headers:
    # Results of an earlier run with the same configuration are discarded
    # even when writing is disabled.  A missing file is the normal case, so
    # that specific error is ignored instead of testing for existence first.
    try:
        os.remove(result_path)
    except FileNotFoundError:
        pass
    if if_write:
        with open(result_path, 'a') as f:
            f.write(header)

for run_idx in range(args.run_times):
    # Seed every RNG with the current seed and force deterministic cuDNN,
    # then advance the seed so that the next run uses a different one.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    args.seed += 1
    print(f'the {run_idx}-th random round')

    # Binary positive/negative splits of the chosen dataset and the loaders
    # built from them for pretraining.
    (positive_pretrain_data, negative_pretrain_data,
     positive_pretrain_label, negative_pretrain_label,
     positive_pretrain_test_data, negative_pretrain_test_data,
     positive_pretrain_test_label, negative_pretrain_test_label,
     dim) = generate_binary_pretrain_data(args.uci, args.ds)
    (pretrain_loader, pretrain_test_loader, pretrain_eval_loader,
     pretrain_data, pretrain_label) = generate_pretrain_loaders(
        positive_pretrain_data, negative_pretrain_data,
        positive_pretrain_label, negative_pretrain_label,
        positive_pretrain_test_data, negative_pretrain_test_data,
        positive_pretrain_test_label, negative_pretrain_test_label,
        args.pretrain_bs)

    # construct a weak classifier
    pretrain_model = get_model(args.ds, args.mo, dim, device)

    avg_pretrain_test_acc, pretrain_model = pretrainLR(
        pretrain_model, pretrain_loader, pretrain_test_loader, pretrain_eval_loader,
        args, loss_fn, device, if_write=if_write, save_path=save_pretrain_path)
    print("Average test accuracy for pretrain (weak): ", avg_pretrain_test_acc)

    # Two training sets (later paired element-wise) plus the test loader.
    # NOTE(review): the pretraining batch size is passed here, not args.bs.
    (train_data1, train_data2, train_label1, train_label2,
     data1_loader, data2_loader, test_loader) = train_test_data_gen(
        positive_pretrain_data, negative_pretrain_data,
        positive_pretrain_test_data, negative_pretrain_test_data,
        args.n, args.prior, args.pretrain_bs)

    # Let the weak classifier fill one confidence value per instance of each
    # training set; the buffers are allocated on `device` beforehand.
    data1_confidence = torch.zeros(train_data1.shape[0]).to(device)
    data1_confidence, start_idx1 = train_data_confidence_gen(
        data1_loader, pretrain_model, device, data1_confidence)


    data2_confidence = torch.zeros(train_data2.shape[0]).to(device)
    data2_confidence, start_idx2 = train_data_confidence_gen(
        data2_loader, pretrain_model, device, data2_confidence)

    # ================================================================================================================

    # Remove the instances the weak classifier got wrong: with labels mapped
    # from {-1, 1} to {0, 1}, an instance is kept only when its confidence
    # lies within 0.5 of its label.

    n_weak = data1_confidence.shape[0]
    plt.rcParams['font.size'] = 16
    eps = 1e-10  # keeps the logarithms taken further down finite

    binary_label1 = train_label1.clone()
    binary_label1[binary_label1 == -1] = 0
    keep1 = torch.abs(data1_confidence - binary_label1.to(device)) <= 0.5
    keep1 = keep1.cpu()   # can delete in 3090
    data1_confidence = data1_confidence[keep1]
    train_data1 = train_data1[keep1]
    train_label1 = train_label1[keep1]

    binary_label2 = train_label2.clone()
    binary_label2[binary_label2 == -1] = 0
    keep2 = torch.abs(data2_confidence - binary_label2.to(device)) <= 0.5
    keep2 = keep2.cpu()  # can delete in 3090
    data2_confidence = data2_confidence[keep2]
    train_data2 = train_data2[keep2]
    train_label2 = train_label2[keep2]

    # Both sets are later cut to the size of the smaller one.
    len_data1 = data1_confidence.shape[0]
    len_data2 = data2_confidence.shape[0]
    len_final = min(len_data1, len_data2)

    def gaussian_kernel(x, xi, bandwidth):
        """Evaluate a Gaussian kernel centred at ``xi`` on the points ``x``.

        Returns ``exp(-0.5 * ((x - xi) / bandwidth)**2) / (bandwidth * sqrt(2*pi))``,
        i.e. the density of a normal distribution with mean ``xi`` and
        standard deviation ``bandwidth``.
        """
        return torch.exp(-0.5 * ((x - xi) / bandwidth) ** 2) / (bandwidth * (2 * math.pi) ** 0.5)

    def kde(data, x_grid, bandwidth=0.5):
        """Kernel density estimate of ``data`` evaluated at ``x_grid``.

        One Gaussian kernel is placed on every sample of ``data`` and the
        kernels are averaged.  The result has the shape of ``x_grid``.
        """
        n = data.shape[0]
        pdf = torch.zeros_like(x_grid)

        # Accumulate one kernel per sample; only a vector of len(x_grid) is
        # alive at any time.
        for i in range(n):
            pdf += gaussian_kernel(x_grid, data[i], bandwidth)

        pdf /= n    # average of the kernels
        return pdf

    def Outliers(data):
        """Flag the samples of ``data`` that lie in low-density regions.

        The density is estimated at the samples themselves with a Gaussian
        KDE whose bandwidth is 1% of the sample standard deviation.  Samples
        whose density does not exceed the 2% quantile of all densities are
        flagged.

        Returns a boolean tensor shaped like ``data`` (True = outlier).
        Empty input gives an empty mask; a single sample or samples without
        any spread give an all-False mask.
        """
        # torch.quantile raises on empty input, and the std of fewer than two
        # samples is undefined, so there is nothing to estimate here.
        if data.numel() < 2:
            return torch.zeros_like(data, dtype=torch.bool)

        std = torch.std(data)
        # A zero (or NaN) std yields no usable bandwidth.  Previously this
        # case ended up all-False only through NaN propagation; make it
        # explicit and skip the pointless density computation.
        if not (std > 0):
            return torch.zeros_like(data, dtype=torch.bool)

        bandwidth = 0.01 * std  # kernel width, controls the smoothing
        pdf_values = kde(data, data, bandwidth)  # evaluate at the samples themselves

        # Densities at or below the 2% quantile mark the outliers.
        threshold = torch.quantile(pdf_values, 0.02).item()
        outliers = pdf_values <= threshold

        return outliers


    def rescale_confidence_half(confidence, upper):
        """Spread one half of ``confidence`` on a log scale, in place.

        ``upper=False`` handles the values ``<= 0.5``: ``log(c + eps)`` is
        min-max normalised into ``[0, 0.5]``.  ``upper=True`` handles the
        values ``> 0.5``: ``log(1 - c + eps)`` is min-max normalised and
        mirrored into ``[0.5, 1]``.

        Density outliers (see ``Outliers``) and values sitting exactly on the
        saturated end (0 for the lower half, 1 for the upper half) are left
        untouched and do not influence the normalisation range.  All four
        half/set combinations use this same rule; previously the upper half
        of the first set was rescaled without that filter.
        """
        if upper:
            in_half = confidence > 0.5
            saturated = 1
        else:
            in_half = confidence <= 0.5
            saturated = 0

        # squeeze(1) keeps the index vector 1-D even for a single match; a
        # bare squeeze() would collapse it to 0-d and break mask indexing.
        half_indices = torch.nonzero(in_half).squeeze(1)
        if half_indices.numel() == 0:
            return  # no value falls into this half

        half_values = confidence[half_indices]
        keep = torch.logical_and(Outliers(half_values) == 0, half_values != saturated)
        kept_indices = half_indices[keep]
        if kept_indices.numel() == 0:
            return  # everything was an outlier or saturated

        kept_values = confidence[kept_indices]
        if upper:
            kept_values = 1 - kept_values
        log_values = torch.log(kept_values + eps)

        lowest = torch.min(log_values)
        span = torch.max(log_values) - lowest
        if span == 0:
            # All selected values coincide; normalising would divide by zero
            # and write NaN, so they are left as they are.
            return

        scaled = (log_values - lowest) / span * 0.5
        confidence[kept_indices] = 1 - scaled if upper else scaled


    # close 0
    rescale_confidence_half(data1_confidence, upper=False)
    rescale_confidence_half(data2_confidence, upper=False)

    # close 1
    rescale_confidence_half(data1_confidence, upper=True)
    rescale_confidence_half(data2_confidence, upper=True)


    # cat
    # Shuffle each set on its own and cut both to the common length, so that
    # instance i of set 1 is paired with instance i of set 2.
    pick1 = torch.randperm(len_data1)[:len_final]
    data1_confidence = data1_confidence[pick1]
    train_data1 = train_data1[pick1]
    train_label1 = train_label1[pick1]

    pick2 = torch.randperm(len_data2)[:len_final]
    data2_confidence = data2_confidence[pick2]
    train_data2 = train_data2[pick2]
    train_label2 = train_label2[pick2]

    print(len_final)

    # Confidence difference of every pair (second minus first).
    pcomp_confidence = data2_confidence - data1_confidence

    # weights = np.ones_like(data1_confidence.cpu().numpy()) / len_final
    # plt.hist(data1_confidence.cpu().numpy(), bins=100, weights=weights)
    # plt.title('data1 confidence Distribution')
    # plt.show()
    #
    # weights = np.ones_like(pcomp_confidence.cpu().numpy()) / len_final
    # plt.hist(pcomp_confidence.cpu().numpy(), bins=100, weights=weights)
    # plt.title('pcomp_confidence Distribution by weak classifier')
    # plt.show()

    # ================================================================================================================

    # ========================= add noise ==================================

    # def generate_dual_Beta(n, alpha1, beta1, alpha2, beta2):
    #     # 生成第一个峰的数据
    #     data1 = np.random.beta(alpha1, beta1, int(n / 2))
    #     # 生成第二个峰的数据
    #     data2 = - np.random.beta(alpha2, beta2, n - int(n / 2))
    #     # 合并两个峰的数据
    #     bimodal_data = np.concatenate((data1, data2))
    #     return bimodal_data
    #
    # def generate_Noice(data_diff, lambda_item, alpha1, beta1, alpha2, beta2):
    #     data_noi = data_diff.clone().cpu().numpy()
    #
    #     # generate noice for easy error classified instances
    #     len_noice = int(len(data_noi)*lambda_item)
    #     noice = generate_dual_Beta(len_noice, alpha1, beta1, alpha2, beta2)*0.5
    #     noice = np.pad(noice, (0, len(data_noi) - len_noice), 'constant')
    #
    #     data_noi += noice
    #     data_noi[data_noi < 0] = 0 - data_noi[data_noi < 0]
    #     data_noi[data_noi > 1] = 2 - data_noi[data_noi > 1]
    #
    #     return data_noi
    #
    #
    # lambda_item = args.alpha
    # alpha1, beta1 = 2, 3
    # alpha2, beta2 = 2, 3
    # data1_confidence_noi = generate_Noice(data1_confidence, lambda_item, alpha1, beta1, alpha2, beta2)
    # data2_confidence_noi = generate_Noice(data2_confidence, lambda_item, alpha1, beta1, alpha2, beta2)


    def generate_Guassion(n, mean1, std1):
        """Draw ``n`` samples from a normal distribution with mean ``mean1``
        and standard deviation ``std1`` using NumPy's global RNG."""
        return np.random.normal(mean1, std1, n)

    def generate_Noice(data_diff, lambda_item, mean1, std1):
        """Return a noisy NumPy copy of the confidence tensor ``data_diff``.

        The first ``int(len(data_diff) * lambda_item)`` entries receive
        Gaussian noise (``N(mean1, std1**2)`` scaled by 0.5); the remaining
        entries are unchanged.  Values leaving ``[0, 1]`` are reflected back
        at the violated border.  ``data_diff`` itself is not modified.

        ``lambda_item`` is treated as a fraction: values below 0 or above 1
        are clamped, so at most every entry is perturbed.
        """
        data_noi = data_diff.clone().cpu().numpy()
        total = len(data_noi)

        # generate noice for easy error classified instances
        # Clamp the count: a fraction above 1 used to hand np.pad a negative
        # width (ValueError), a negative one a negative sample size.
        len_noice = min(max(int(total * lambda_item), 0), total)
        noice = generate_Guassion(len_noice, mean1, std1) * 0.5
        noice = np.pad(noice, (0, total - len_noice), 'constant')

        data_noi += noice
        # Reflect at 0 and at 1.
        data_noi[data_noi < 0] = 0 - data_noi[data_noi < 0]
        data_noi[data_noi > 1] = 2 - data_noi[data_noi > 1]
        # A single reflection cannot bring back a value pushed out by more
        # than 1; clip so the result is always a valid confidence.
        np.clip(data_noi, 0, 1, out=data_noi)

        return data_noi

    # `-alpha` is used as the fraction of instances that receive noise; the
    # noise itself is drawn with mean 0 and standard deviation 1/3.
    noise_fraction = args.alpha
    noise_mean = 0
    noise_std = 1/3
    # Set 1 is perturbed before set 2 (the order fixes the RNG draws).
    data1_confidence_noi = generate_Noice(data1_confidence, noise_fraction, noise_mean, noise_std)
    data2_confidence_noi = generate_Noice(data2_confidence, noise_fraction, noise_mean, noise_std)

    # Confidence difference of the noisy pairs (second minus first).
    pcomp_confidence_noi = data2_confidence_noi - data1_confidence_noi


    # n = len(pcomp_confidence.cpu().numpy())
    # plt.rcParams['font.size'] = 16

    # weights = np.ones_like(data1_confidence_noi) / n
    # plt.hist(data1_confidence_noi, bins=100, weights=weights)
    # plt.title('data1 confidence Distribution')
    # plt.show()

    # weights = np.ones_like(pcomp_confidence.cpu().numpy()) / n
    # plt.hist(pcomp_confidence.numpy(), bins=100, weights=weights)
    # plt.title('confidence difference Distribution')
    # plt.show()

    # weights = np.ones_like(pcomp_confidence_noi) / n
    # plt.hist(pcomp_confidence_noi, bins=100, weights=weights)
    # plt.title('confidence difference (noice) Distribution')
    # plt.show()

    # From here on the noisy differences replace the clean ones.
    pcomp_confidence = torch.tensor(pcomp_confidence_noi)

    # ================================= end add ==================================

    #print(pcomp_confidence[:10])
    # The Pcomp loader is built from its own copies of the data and the
    # confidence differences; the labels are shared by both loaders.
    pcomp_train_data1 = train_data1.clone()
    pcomp_train_data2 = train_data2.clone()
    pcomp_pcomp_confidence = pcomp_confidence.clone()
    confdiff_train_loader = gen_confdiff_train_loader(
        train_data1, train_data2, pcomp_confidence,
        train_label1, train_label2, args.bs)
    given_train_loader = generate_pcomp_loaders(
        pcomp_train_data1, pcomp_train_data2, pcomp_pcomp_confidence,
        train_label1, train_label2, args.bs)

    model = get_model(args.ds, args.mo, dim, device)
    # NOTE(review): model2 is not used anywhere below.  It is kept because
    # building it presumably draws from the RNG, so dropping it could change
    # the results of seeded runs - confirm before removing.
    model2 = get_model(args.ds, args.mo, dim, device)
    # Methods trained on the confidence-difference loader, each with the
    # label printed in front of its accuracy.  They all take the CSV options.
    confdiff_runs = {
        'ConfDiffUnbiased': (ConfDiffUnbiased, 'ConfDiffUnbiased_acc: '),
        'ConfDiffABS': (ConfDiffABS, 'ConfDiffABS_acc: '),
        'ConfDiffReLU': (ConfDiffReLU, 'ConfDiffReLU_acc: '),
        'N_ABS': (ConfDiffABS_new, 'ConfDiffABS_acc: '),
        'N_Unbiased': (ConfDiffUnbiased_new, 'ConfDiffUnbiased_acc: '),
        'N_ReLU': (ConfDiffReLU_new, 'ConfDiffReLU_acc: '),
        'CRCR_ABS': (CRCR_ABS, "CRCR_ABS Accuracy:"),
        'CRCR_ReLU': (CRCR_ReLU, "CRCR_ReLU Accuracy:"),
        'CRCR_Unbiased': (CRCR_Unbiased, "CRCR_Unbiased Accuracy:"),
    }
    # Methods trained on the Pcomp loader; they are called without the CSV
    # options.
    pcomp_runs = {
        'PcompUnbiased': (PcompUnbiased, "PcompUnbiased Accuracy:"),
        'PcompReLU': (PcompReLU, "PcompReLU Accuracy:"),
        'PcompABS': (PcompABS, "PcompABS Accuracy:"),
    }

    if args.me in confdiff_runs:
        train_fn, acc_label = confdiff_runs[args.me]
        res_acc = train_fn(model, confdiff_train_loader, test_loader, args, loss_fn, device,
                           if_write=if_write, save_path=save_detail_path)
        print(acc_label, res_acc)
    elif args.me in pcomp_runs:
        train_fn, acc_label = pcomp_runs[args.me]
        res_acc = train_fn(model, given_train_loader, test_loader, args, loss_fn, device)
        print(acc_label, res_acc)
    elif args.me == 'PcompTeacher':
        # The teacher variant additionally needs a second model of the same
        # architecture, created only for this method.
        ema_model = get_model(args.ds, args.mo, dim, device)
        res_acc = PcompTeacher(model, ema_model, given_train_loader, test_loader, args, loss_fn, device)
        print("PcompTeacher Accuracy:", res_acc)


    # Keep this run's accuracy and append it to the summary CSV; the std
    # column is only filled in by the final "in total" row.
    acc_run_list[run_idx] = res_acc
    print('\n')
    if if_write:
        with open(save_total_path, "a") as summary_file:
            summary_file.write(f"{run_idx + 1},{res_acc:.6f},None\n")

# Mean and standard deviation of the accuracy over all runs, printed and
# appended to the summary CSV, followed by a recap of the configuration.
mean_acc = acc_run_list.mean()
std_acc = acc_run_list.std()
print(f'Avg_acc:{mean_acc}    std_acc:{std_acc}')
if if_write:
    with open(save_total_path, "a") as summary_file:
        summary_file.write(f"in total,{mean_acc:.6f},{std_acc:.6f}\n")
print(f'method:{args.me}    lr:{args.lr}    wd:{args.wd}')
print(f'loss:{args.lo}    prior:{args.prior}')
print(f'model:{args.mo}    dataset:{args.ds}')
print(f'num of sample:{args.n}')
print('\n')
