import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import numpy as np
import torch


torch.cuda.empty_cache()
from utils.metrics import RankingLoss, Coverage, AveragePrecision, OneError, HammingLoss
from utils.ADMM_ import admm
from utils.admm_Q import admm_lasso
from utils.admm_O import admm_o
from scipy.io import loadmat
import argparse

import random
from numpy.linalg import svd, norm
from gen_similarity import get_similarity_matrix


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, the current CUDA device and all CUDA devices), then
    sets the cuDNN ``deterministic`` flag and clears the ``benchmark`` flag.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,                 # Python random module
        np.random.seed,              # NumPy module
        torch.manual_seed,           # random seed for the CPU
        torch.cuda.manual_seed,      # random seed for the current GPU
        torch.cuda.manual_seed_all,  # random seed for all GPUs (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# main setting
# NOTE: the command line is parsed at import time; the resulting `args`
# namespace is a module-level global read directly by the functions in this
# file (main, MVPML_CI, L_BFGS and GD).
parser = argparse.ArgumentParser(
    # The free-text summary belongs in `description`. Passing it as `prog` /
    # `usage` replaced the auto-generated usage synopsis in `--help` and
    # prefixed every argparse error message with a sentence.
    description='MVPML_CI demo file. Demo with Multi-View Partial Multi-Label Learning.',
    epilog='end',
    add_help=True
)
# hyper-param
parser.add_argument('--dataset', type=str, default='emotions_new.mat',
                    help='.mat file name inside the dataset/ directory')
parser.add_argument('--p', type=int, default=3,
                    help='P in the "p<P>r<R>_noise_target" noisy-label matrix to load')
parser.add_argument('--r', type=int, default=1,
                    help='R in the "p<P>r<R>_noise_target" noisy-label matrix to load')


parser.add_argument('--gamma1', type=float, default=5,
                    help='weight of the singular-value sum of P in the objective')
parser.add_argument('--gamma2', type=float, default=17.79,
                    help='weight of the 1-norm of Q in the objective')
parser.add_argument('--gamma3', type=float, default=17.79,
                    help='weight of the 1-norm of each I[i] in the objective')
parser.add_argument('--maxiter', type=int, default=30,
                    help='number of outer alternating-update iterations')
parser.add_argument('--lr_lbfgs', type=float, default=1e-5,
                    help='learning rate of the LBFGS optimizer')
parser.add_argument('--lr_gd', type=float, default=1e-2,
                    help='learning rate of the Adam optimizer')
parser.add_argument('--maxiter_bfgs', type=int, default=300,
                    help='max_iter of the LBFGS optimizer')
parser.add_argument('--folds', type=int, default=10,
                    help='number of cross-validation folds to run')
parser.add_argument('--dir', help='result save path', type=str, default='results_trainall-emo', required=False)
# gpu
parser.add_argument('--exp', type=str, default="trainall",
                    help='experiment tag used as prefix of the result file name')
parser.add_argument('--gpu', type=str, default="0",
                    help='value written to CUDA_VISIBLE_DEVICES')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed')
parser.add_argument('--admm_iter', type=int, default=100,
                    help='iteration count passed to the ADMM solvers')
args = parser.parse_args()
# set gpu idx
# NOTE(review): presumably relies on CUDA not having been initialized yet in
# this process -- confirm if GPU selection appears to be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# set random seed
set_seed(args.seed)
print(args)



def main():
    """Run the k-fold MVPML_CI experiment configured by the module-level ``args``.

    Steps:
      1. Load ``dataset/<args.dataset>`` and read the ``data``, ``target``,
         ``p<P>r<R>_noise_target`` and ``idx`` variables.
      2. Min-max normalize every view column-wise.
      3. For each fold, train ``MVPML_CI`` on the training split and evaluate
         RankingLoss, Coverage, AveragePrecision, OneError and HammingLoss
         on the held-out split.
      4. Print the mean/std of every metric over the folds and append them
         to a text file under ``./<args.dir>``.
    """
    data_path = "dataset/"
    dataset = args.dataset
    p = args.p
    r = args.r

    gamma1 = args.gamma1
    gamma2 = args.gamma2
    gamma3 = args.gamma3
    maxiter = args.maxiter
    maxiter_bfgs = args.maxiter_bfgs
    folds = args.folds

    # Take the corresponding data from .mat
    mat_data_dict = loadmat(os.path.join(data_path, dataset))
    data = mat_data_dict["data"]
    print("data shape: ", data.shape)
    target = mat_data_dict["target"]
    print("target shape: ", target.shape)
    weak_target = mat_data_dict["p{}r{}_noise_target".format(str(p), str(r))]
    print("weak target shape: ", weak_target.shape)
    idx = mat_data_dict["idx"]
    print("idx shape: ", idx.shape)
    # data is indexed as data[view, 0]; each cell is (n_sample, n_fea).
    v_num = data.shape[0]
    print("view number: ", v_num)

    # min_max normalization (column-wise, per view)
    for i in range(0, v_num):
        view_data = data[i, 0]
        n_view_sample, n_fea = view_data.shape
        print("For View {}, n_sample={} and n_fea={}. ".format(i, n_view_sample, n_fea))
        col_min = np.min(view_data, axis=0, keepdims=True)
        col_range = np.max(view_data, axis=0, keepdims=True) - col_min
        # A constant feature has zero range; dividing by it would fill the
        # whole column with NaN. Divide by 1 instead so the column becomes 0.
        col_range = np.where(col_range == 0, 1, col_range)
        data[i, 0] = (view_data - col_min) / col_range

    # The boolean training mask below indexes the rows of `target`, so its
    # length is taken from `target` instead of a variable leaked from the
    # normalization loop.
    n_sample = target.shape[0]

    # k-fold evaluation
    save_list = []
    for fold in range(0, folds):
        print("Current Fold: {}".format(str(fold)))
        # Convert the stored sample numbers into indices: subtract 1.
        # idx is (folds x 1); each cell holds one row of test-sample numbers.
        test_idx = idx[fold, 0][0] - 1
        train_idx = np.ones(n_sample, np.bool_)
        train_idx[test_idx] = False

        test_target = target[test_idx, :]
        train_target = target[train_idx, :]

        print("Train num: ", len(train_target))

        # build train and test multi-view features
        test_data = [data[i, 0][test_idx, :] for i in range(0, v_num)]
        train_data = [data[i, 0][train_idx, :] for i in range(0, v_num)]

        print("train_data_shape", len(train_data))

        # noisy label matrix L of train dataset
        train_weak_target = weak_target[train_idx, :]

        # core algorithm
        W = MVPML_CI(train_weak_target, train_data, gamma1, gamma2, gamma3, maxiter, maxiter_bfgs, fold)

        # Evaluation needs no gradients: detach W and predict under no_grad
        # so the stored W / Y_test do not keep an autograd graph alive.
        with torch.no_grad():
            W = W.detach()
            test_datall = torch.FloatTensor(np.concatenate(test_data, axis=1))
            print("Test datall shape", test_datall.shape)
            Y_test = test_datall @ W

            # calculate indexes
            test_target = torch.FloatTensor(test_target)
            RK = RankingLoss(Y_test, test_target)
            CV = Coverage(Y_test, test_target)
            AP = AveragePrecision(Y_test, test_target)
            OE = OneError(Y_test, test_target)

            Y_bin = perform_binarization(Y_test, thres=0.74)
            HM = HammingLoss(Y_bin, test_target)

            print("Ranking Loss: ", RK)
            print("Coverage: ", CV)
            print("AveragePrecision: ", AP)
            print("OneError: ", OE)
            save_list.append({
                "W": W,
                "Y_test": Y_test,
                "measures": {
                    "RankingLoss": RK,
                    "Coverage": CV,
                    "AveragePrecision": AP,
                    "OneError": OE,
                    "HammingLoss": HM
                }
            })

    # mean / std of every metric over the folds (insertion order is also the
    # order in which the metrics are written to the result file)
    metric_names = ("RankingLoss", "Coverage", "AveragePrecision", "OneError", "HammingLoss")
    totals = {}
    for name in metric_names:
        values = torch.FloatTensor([fold_result["measures"][name] for fold_result in save_list])
        totals[name] = (torch.mean(values), torch.std(values))

    # save
    save_dict = {
        "experiment_file": args.exp,
        "dataset": args.dataset,
        "args": str(args),
        "ten_fold": save_list,
        "total": totals
    }
    print(args)
    print(save_dict["total"])

    # result dir
    save_dir = './' + args.dir
    os.makedirs(save_dir, exist_ok=True)

    save_file = os.path.join(
        save_dir,
        args.exp + '_' + str(args.dataset) + '_p = ' + str(args.p) + '_r = ' + str(args.r) + '.txt')
    header = "".join([
        str(args.dataset),
        '_max = ', str(args.maxiter),
        '_p = ', str(args.p),
        '_r = ', str(args.r),
        '_gamma1 = ', str(args.gamma1),
        '_gamma2 = ', str(args.gamma2),
        '_gamma3 = ', str(args.gamma3),
        ' _lbfgs_ ', str(args.maxiter_bfgs),
        ' _seed_ ', str(args.seed),
        'admm = ', str(args.admm_iter),
        '\n',
    ])
    # Append, so repeated runs with the same exp/dataset/p/r accumulate in
    # one file; header and metrics are written through a single handle.
    with open(save_file, 'a') as file:
        file.write(header)
        for name in metric_names:
            mean, std = totals[name]
            file.write(name + ': ' + str(mean) + ', ' + str(std) + '\n')
        file.write('\n')




def perform_binarization(Y_test, thres=0.74):
    """Turn real-valued scores into 0/1 predictions, row by row.

    Each row of ``Y_test`` is min-max scaled to ``[0, 1]``; entries strictly
    greater than ``thres`` become 1 and every other entry becomes 0.

    Args:
        Y_test: 2-D tensor of scores (reductions run over ``dim=1``).
        thres: Threshold applied to the row-scaled scores.

    Returns:
        A new tensor with the shape of ``Y_test`` containing only 0 and 1.
        ``Y_test`` itself is not modified.
    """
    row_min = torch.min(Y_test, dim=1, keepdim=True)[0]
    row_max = torch.max(Y_test, dim=1, keepdim=True)[0]
    scaled = (Y_test - row_min) / (row_max - row_min)
    # One comparison covers all cases: a score exactly equal to `thres` maps
    # to 0 (previously it was left as the raw value), and a constant row,
    # whose scaling is 0/0 = NaN, maps to all zeros because NaN > thres is
    # False (previously the NaN leaked into the output).
    return (scaled > thres).to(scaled.dtype)


def binarization(d, t):
    """Unimplemented placeholder: ignores ``d`` and ``t`` and returns ``None``."""
    return None


def MVPML_CI(train_weak_target, train_data, gamma1, gamma2, gamma3, maxiter, maxiter_bfgs, fold):
    """Fit the MVPML_CI model by alternating updates and return the weights W.

    Every tensor is moved to CUDA with ``.cuda()``, so a GPU is required.

    Each outer iteration updates, in this order:
      1. ``C``    -- ``L_BFGS`` on ``C_loss_fun``;
      2. ``Q``    -- ``admm_lasso``;
      3. ``W``    -- ``GD`` on ``W_loss_fun``;
      4. ``I[i]`` -- ``admm_o`` for every view, keeping ``Inconsistency``
         (the mean of the ``I[i]``) in sync;
      5. ``P``    -- ``admm``.
    The objective value divided by ``n_num`` is printed before the first
    iteration and after every iteration.

    Args:
        train_weak_target: (n_num, c_num) noisy label matrix; converted with
            ``torch.FloatTensor``.
        train_data: list with one feature array per view, each with n_num
            rows; also concatenated along axis 1 into the full-view matrix.
        gamma1: weight of the sum of singular values of ``P`` in the printed
            objective; also passed to ``admm``.
        gamma2: weight of ``torch.norm(Q, 1)`` in the printed objective;
            also passed to ``admm_lasso``.
        gamma3: weight of ``torch.norm(I[i], 1)`` in the printed objective;
            also passed to ``admm_o``.
        maxiter: number of outer iterations.
        maxiter_bfgs: ``maxiter`` handed to ``L_BFGS`` for the C update.
        fold: not used inside this function.

    Returns:
        ``W`` of shape (f_num, c_num), moved to the CPU.
    """
    v_num = len(train_data)

    # full view: all per-view features concatenated column-wise
    train_datall = torch.FloatTensor(np.concatenate(train_data, axis=1)).cuda()
    train_weak_target = torch.FloatTensor(train_weak_target).cuda()
    # numbers: n_num samples, f_num total features, c_num labels
    n_num, f_num = train_datall.shape
    print("Train_datall shape: ", (n_num, f_num))
    c_num = train_weak_target.shape[-1]
    # initialize W (Xavier-normal), C (zeros), P and Q (uniform random)
    W = torch.zeros(f_num, c_num).cuda()
    W = torch.nn.Parameter(W.clone().detach())
    torch.nn.init.xavier_normal_(W)
    C = torch.zeros(n_num, n_num).cuda()

    P = torch.rand(n_num, c_num).cuda()
    Q = torch.rand(n_num, c_num).cuda()

    # U = torch.zeros((n_num, n_num)).cuda()
    # initialize A_v of each view
    train_data = [torch.FloatTensor(train_data[i]).cuda() for i in range(0, v_num)]
    # S = [initialize_A(train_data[i]) for i in range(0, v_num)]
    # S[i]: similarity matrix of view i from get_similarity_matrix (imported
    # from gen_similarity; its return type is not visible here -- the later
    # `.clone()` calls and `C - S[i]` suggest a CUDA tensor, TODO confirm).
    S = [get_similarity_matrix(train_data[i].cpu().numpy()) for i in range(v_num)]
    # t = [1 / torch.pow(torch.norm(A[i] - U, 2), 2) for i in range(0, v_num)]

    # C = sum(S) / v_num
    # I[i]: per-view (n_num, n_num) matrix, initialized to zeros
    I = [torch.zeros(n_num, n_num).cuda() for i in range(v_num)]


    # tao = [t[i] / sum(t) for i in range(0, v_num)]

    # Inconsistency is kept equal to the mean of the I[i]
    Inconsistency = sum(I) / v_num


    # tmp = svd(P.cpu().detach().numpy(), compute_uv=0)
    # tmp = torch.Tensor(tmp.reshape((len(tmp), 1)))
    # tmp: singular values of P
    _, tmp, _ = torch.svd(P)

    # Objective before any update (only used for the printout below)
    loss1 = torch.pow(torch.norm(train_weak_target - Inconsistency  @ Q - train_datall @ W, 2), 2) + \
            torch.pow(torch.norm(train_weak_target - C @ P - Inconsistency @ Q, 2), 2) + \
            torch.pow(torch.norm(W, 2), 2) + \
            gamma1 * torch.sum(tmp) + \
            gamma2 * torch.norm(Q, 1)

    for i in range(0, v_num):
        loss1 += torch.pow(torch.norm(train_data[i] - (C + I[i]) @ train_data[i], 2), 2) + \
                 torch.pow(torch.norm(C - S[i], 2), 2) + gamma3 * torch.norm(I[i], 1)


    print("loss1=", loss1.item() / n_num)

    for iter in range(0, maxiter):
        print("Current Iter: ", iter)

        # use BFGS to update C
        # Everything except C is passed as a detached copy so that only C
        # receives gradients.
        parameters = torch.nn.Parameter(C.clone().detach())

        train_data_for_C = list(map(lambda x: x.clone().detach(), train_data))
        # train_weak_target_for_C = map(lambda x: x.clone().detach(), train_weak_target)
        train_weak_target_for_C = train_weak_target.clone().detach()
        I_for_C = list(map(lambda x: x.clone().detach(), I))
        S_for_C = list(map(lambda x: x.clone().detach(), S))
        # A_for_C = list(map(lambda x: x.clone().detach(), A))
        P_for_C = P.clone().detach()
        Q_for_C = Q.clone().detach()
        Inconsistency_for_C = Inconsistency.clone().detach()
        # Order must match the positional parameters of C_loss_fun after C.
        non_parameters = [train_data_for_C, train_weak_target_for_C, S_for_C, I_for_C, Inconsistency_for_C, P_for_C,
                          Q_for_C]

        # L_BFGS returns the optimized nn.Parameter, so C is a Parameter
        # from here on.
        C = L_BFGS(parameters, non_parameters, C_loss_fun, maxiter=maxiter_bfgs, flag="C")
        # C = GD(parameters, non_parameters, C_loss_fun, maxiter=1000, flag="C")
        # C = normalize_(C)

        # use ADMM to update Q
        admm_iter = args.admm_iter
        admm_rho = 1
        # admm_lambda_max = torch.norm(train_datall.T @ train_weak_target, torch.inf)
        # admm_lambda1 = 0.01 * admm_lambda_max
        Q = admm_lasso(train_datall, train_weak_target, W, C, P, Inconsistency, admm_rho, admm_iter, gamma2,flag="Q")

        # use BFGS to update W
        # (the active code below uses GD; the L_BFGS variant is commented out)
        parameters = torch.nn.Parameter(W.clone().detach())
        train_datall_for_W = train_datall.clone().detach()
        train_weak_target_for_W = train_weak_target.clone().detach()
        Q_for_W = Q.clone().detach()
        Inconsistency_for_W = Inconsistency.clone().detach()

        # I_for_W = list(map(lambda x: x.clone().detach(), I))

        # Order must match the positional parameters of W_loss_fun after W.
        non_parameters = [train_datall_for_W, train_weak_target_for_W, Q_for_W, Inconsistency_for_W]
        # if args.dataset in ['emotions_new.mat']:
        #     W = L_BFGS(parameters, non_parameters, W_loss_fun, maxiter=maxiter_bfgs, flag="W")
        # else:
        # The step budget passed here is a literal, not a command-line option.
        W = GD(parameters, non_parameters, W_loss_fun, maxiter=1000, flag="W")


        # use BFGS to update I_v
        # (disabled alternative to the ADMM update of I[i] below)
        # for i in range(0, v_num):
        #     parameters = torch.nn.Parameter(I[i].clone().detach())
        #     train_data_for_Ii = train_data[i].clone().detach()
        #     train_weak_target_for_Ii = train_weak_target.clone().detach()
        #     train_datall_for_Ii = train_datall.clone().detach()
        #     W_for_Ii = W.clone().detach()
        #     C_for_Ii = C.clone().detach()
        #     P_for_Ii = P.clone().detach()
        #     Q_for_Ii = Q.clone().detach()
        #     Inconsistency_for_Ii = (Inconsistency - I[i] / v_num).clone().detach()
        #     temp = I[i] / v_num
        #     non_parameters = [train_data_for_Ii, train_datall_for_Ii, train_weak_target_for_Ii, W_for_Ii, C_for_Ii,
        #                       P_for_Ii, Q_for_Ii, Inconsistency_for_Ii]
        #     I[i] = L_BFGS(parameters, non_parameters, Ii_loss_fun, maxiter=maxiter_bfgs,
        #                   flag="I[{}]".format(i))
        #     Inconsistency = Inconsistency - temp + I[i] / v_num


        # use ADMM to update I_i
        admm_iter = args.admm_iter
        admm_rho = 1

        for i in range(0, v_num):
            # temp: Inconsistency without view i's contribution
            temp = Inconsistency - I[i]/v_num
            # Ii_old: view i's contribution before the update
            Ii_old = I[i]/v_num

            I[i] = admm_o(train_data[i], train_datall, train_weak_target, W, C, P, Q, Inconsistency, temp,  admm_rho, admm_iter, gamma3, v_num, flag="I[{}]".format(i))
            # Swap the old contribution of view i for the new one so that
            # Inconsistency stays the mean of the I[i].
            Inconsistency = Inconsistency - Ii_old+ I[i] / v_num


        # use ADMM to update P
        admm_iter = args.admm_iter
        admm_rho = 1
        mu = 1.0
        tol = 1e-4
        # admm_lambda_max = torch.norm(train_datall.T @ train_weak_target, torch.inf)
        # admm_lambda1 = 0.01 * admm_lambda_max
        # admm is called with NumPy copies of its inputs.
        Y = train_weak_target.cpu().detach().numpy()
        E = Inconsistency.cpu().detach().numpy()
        C_for_P = C.cpu().detach().numpy()
        Q_for_P = Q.cpu().detach().numpy()

        # NOTE(review): the result is used below as a tensor (`P.cpu()`) and
        # in `C @ P`, so admm presumably returns a CUDA tensor -- confirm.
        P = admm(Y, E, C_for_P, Q_for_P, gamma1, admm_rho, mu, admm_iter, tol)

        # tmp: singular values of P, computed with NumPy and wrapped in a
        # CPU tensor
        tmp = svd(P.cpu().detach().numpy(), compute_uv=0)
        tmp = torch.Tensor(tmp.reshape((len(tmp), 1)))

        # Objective after this iteration (same expression as before the loop)
        loss1 = torch.pow(torch.norm(train_weak_target - Inconsistency @ Q - train_datall @ W, 2), 2) + \
                torch.pow(torch.norm(train_weak_target - C @ P - Inconsistency @ Q, 2), 2) + \
                torch.pow(torch.norm(W, 2), 2) + \
                gamma1 * torch.sum(tmp) + \
                gamma2 * torch.norm(Q, 1)

        for i in range(0, v_num):
            loss1 += torch.pow(torch.norm(train_data[i] - (C + I[i]) @ train_data[i], 2), 2) + \
                    torch.pow(torch.norm(C - S[i], 2), 2) + gamma3 * torch.norm(I[i], 1)


        print("loss1=", loss1.item() / n_num)

    return W.cpu()


def Parameter(x):
    """Return a trainable CUDA copy of ``x``.

    Args:
        x: Source tensor; it is cloned and detached, never modified.

    Returns:
        A ``torch.nn.Parameter`` (leaf tensor with ``requires_grad=True``)
        holding a copy of ``x`` on the CUDA device.
    """
    # Move the data to the GPU first and wrap it last. Wrapping first and
    # calling `.cuda()` afterwards returns the result of a copy operation:
    # a plain, non-leaf tensor that optimizers refuse to optimize.
    return torch.nn.Parameter(x.clone().detach().cuda())


def L_BFGS(parameters, non_parameters, loss_fun, maxiter=100, flag=""):
    """Minimize ``loss_fun`` over ``parameters`` with one ``torch.optim.LBFGS`` step.

    Args:
        parameters: Leaf tensor (e.g. ``torch.nn.Parameter``) to optimize;
            it is updated in place.
        non_parameters: Sequence of fixed extra arguments; the loss is
            evaluated as ``loss_fun(parameters, *non_parameters)``.
        loss_fun: Callable returning a scalar tensor.
        maxiter: ``max_iter`` of the LBFGS optimizer.
        flag: Label used in the progress printout.

    Returns:
        The same ``parameters`` object, after optimization.

    The learning rate is read from the module-level ``args.lr_lbfgs``.
    """
    optim = torch.optim.LBFGS([parameters], lr=args.lr_lbfgs, max_iter=maxiter)

    def closure():
        # LBFGS re-evaluates the closure several times per step. Gradients
        # must be cleared on every evaluation, otherwise `backward()` keeps
        # adding to `.grad` and the optimizer sees the sum of all gradients
        # computed so far instead of the gradient at the current point.
        optim.zero_grad()
        loss = loss_fun(parameters, *non_parameters)
        # retain_graph is kept so loss functions whose fixed inputs carry
        # shared autograd history keep working across evaluations.
        loss.backward(retain_graph=True)
        return loss

    optim.step(closure=closure)
    # Reporting only: no autograd graph is needed for the final value.
    with torch.no_grad():
        final_loss = loss_fun(parameters, *non_parameters).item()
    print("As for {}, L_BFGS loss: ".format(flag), final_loss)
    return parameters


def GD(parameters, non_parameters, loss_fun, maxiter=100, flag=""):
    """Minimize ``loss_fun`` over ``parameters`` with ``maxiter`` Adam steps.

    Args:
        parameters: Leaf tensor (e.g. ``torch.nn.Parameter``) to optimize;
            it is updated in place.
        non_parameters: Sequence of fixed extra arguments; the loss is
            evaluated as ``loss_fun(parameters, *non_parameters)``.
        loss_fun: Callable returning a scalar tensor.
        maxiter: Number of Adam steps to run.
        flag: Label used in the progress printout.

    Returns:
        The same ``parameters`` object, after optimization.

    The learning rate is read from the module-level ``args.lr_gd``. The
    printed loss is the one evaluated just before the last step.
    """
    optim = torch.optim.Adam([parameters], lr=args.lr_gd)
    loss = None
    # Honour the caller's step budget; the loop used to run a fixed 100
    # steps regardless of `maxiter`.
    for _ in range(maxiter):
        optim.zero_grad()
        loss = loss_fun(parameters, *non_parameters)
        loss.backward()
        optim.step()
    if loss is None:
        # No step was taken (maxiter <= 0): still report the current loss
        # instead of failing on a missing value.
        with torch.no_grad():
            loss = loss_fun(parameters, *non_parameters)
    print("As for {}, GD loss: ".format(flag), loss.item())
    return parameters


def C_loss_fun(C, train_data, train_weak_target, S, I, E, P, Q):
    """Objective minimized when updating ``C``.

    For every view ``v`` it adds the squared norm of
    ``train_data[v] - (C + I[v]) @ train_data[v]`` and of ``C - S[v]``; a
    final term adds the squared norm of ``train_weak_target - C @ P - E @ Q``.
    Each squared norm is computed as ``torch.pow(torch.norm(x, 2), 2)``.

    Args:
        C: Matrix being optimized.
        train_data: List of per-view feature matrices.
        train_weak_target: Label matrix.
        S: List of per-view matrices compared against ``C``.
        I: List of per-view matrices added to ``C``.
        E: Matrix multiplied with ``Q``.
        P: Matrix multiplied with ``C``.
        Q: Matrix multiplied with ``E``.

    Returns:
        The scalar loss tensor.
    """
    def squared_norm(matrix):
        return torch.pow(torch.norm(matrix, 2), 2)

    per_view_total = sum(
        squared_norm(train_data[v] - (C + I[v]) @ train_data[v]) + squared_norm(C - S[v])
        for v in range(len(train_data))
    )
    label_term = squared_norm(train_weak_target - C @ P - E @ Q)
    return per_view_total + label_term



def W_loss_fun(W, train_datall, train_weak_target, Q, Inconsistency):
    """Objective minimized when updating ``W``.

    Sum of the squared norm of
    ``train_weak_target - Inconsistency @ Q - train_datall @ W`` and the
    squared norm of ``W``, each computed as ``torch.norm(x, 2)`` squared.

    Args:
        W: Weight matrix being optimized.
        train_datall: Feature matrix multiplied with ``W``.
        train_weak_target: Label matrix.
        Q: Matrix multiplied with ``Inconsistency``.
        Inconsistency: Matrix multiplied with ``Q``.

    Returns:
        The scalar loss tensor.
    """
    residual = train_weak_target - Inconsistency @ Q - train_datall @ W
    fit_term = torch.norm(residual, 2).pow(2)
    weight_term = torch.norm(W, 2).pow(2)
    return fit_term + weight_term

def normalize_(A_i):
    """Scale every row of ``A_i`` by the inverse of its row sum.

    Args:
        A_i: 2-D tensor.

    Returns:
        A new tensor with ``A_i[r, :] / sum(A_i[r, :])`` in row ``r``.

    Raises:
        RuntimeError: If any row sums to exactly zero. The previous
            implementation failed the same way, because the diagonal matrix
            of row sums was then singular.
    """
    row_sums = torch.sum(A_i, dim=1, keepdim=True)
    if bool((row_sums == 0).any()):
        # Keep the exception type of the former matrix inversion where the
        # installed torch provides it (it derives from RuntimeError).
        error_type = getattr(torch.linalg, "LinAlgError", RuntimeError)
        raise error_type("normalize_: a row of the input sums to zero")
    # Broadcasting division replaces inv(diag(row_sums)) @ A_i: same result
    # without building and inverting a dense n x n diagonal matrix.
    return A_i / row_sums



# Run the cross-validation experiment only when this file is executed as a
# script (not when it is imported).
if __name__ == "__main__":
    main()



