
import numpy as np
from sklearn import metrics

def auc_solam_decentralized(x_tr, y_tr, x_te, y_te, options):
    """Run SOLAM-style stochastic AUC maximization on several nodes with mixing.

    At every iteration each node takes one stochastic primal-dual step on a
    single sample of its own data. Afterwards every node replaces its
    iterate by a weighted combination of all nodes' iterates, with weights
    taken from the matrix returned by ``generate_P``. At the iterations
    listed in ``options['res_idx']`` the node-averaged model is scored on the
    test set and on the pooled training set.

    Args:
        x_tr: Sequence with one 2-D feature array per node, indexed as
            ``x_tr[i][sample, :]``. The feature dimension ``dim`` is read
            from ``x_tr[0]``.
        y_tr: Sequence with one label array per node. A label equal to ``1``
            is treated as positive, anything else as negative.
        x_te: Test features; must support ``x_te.dot(w)`` for a weight
            vector ``w`` of length ``dim``.
        y_te: Test labels, scored with ``pos_label=1``.
        options: Mapping with the keys

            * ``'topology_type'``, ``'num_nodes'``: forwarded to
              ``generate_P`` to build the mixing matrix.
            * ``'ids'``: sample index used at each iteration; the same index
              is used on every node.
            * ``'etas'``: step size for each iteration.
            * ``'beta'``: shrinkage strength (called ``R`` in the original
              comments); after each step the weight part of the iterate is
              scaled by ``1 / (1 + beta * eta)``.
            * ``'res_idx'``: 1-based iteration numbers at which results are
              recorded. Entries are consumed in order, so an entry that
              never matches the iteration counter blocks all later ones.

    Returns:
        Tuple ``(ws, gens)``.

        * ``ws``: array of shape ``(len(res_idx), dim + 3)``. Each recorded
          row holds the node-averaged primal iterate (``dim`` weights
          followed by two auxiliary scalars) and, in the last column, the
          node-averaged dual variable ``alpha``.
        * ``gens``: array of shape ``(len(res_idx),)`` holding test AUC
          minus pooled-training AUC of the averaged weights.

        If the averaged weights are not finite at a checkpoint, the run
        stops: that row and all later rows of ``ws`` receive the current
        averages, and the matching entries of ``gens`` repeat the last
        recorded value (``0`` when nothing was recorded yet). Rows whose
        checkpoint is never reached stay zero.
    """
    num_nodes = options['num_nodes']
    # Row i of the mixing matrix holds the weights node i applies to the
    # iterates of all nodes.
    connect_matrix = generate_P(options['topology_type'], num_nodes)

    # Pooled training data, used only for scoring the averaged model.
    x_tr_all = np.concatenate(x_tr, axis=0)
    y_tr_all = np.concatenate(y_tr)

    ids = options['ids']
    beta = options['beta']
    etas = options['etas']
    res_idx = options['res_idx']
    _, dim = x_tr[0].shape

    # Result storage.
    n_idx = len(res_idx)
    i_res = 0
    gens = np.zeros(n_idx)
    ws = np.zeros((n_idx, dim + 3))

    # Per-node state: primal iterate (dim weights + 2 auxiliary scalars),
    # dual variable, and the number of positive samples seen so far.
    v_list = [np.zeros(dim + 2) for _ in range(num_nodes)]
    alpha_list = [0] * num_nodes
    sp_list = [0] * num_nodes

    def auc_of(w, x, y):
        # AUC of the linear scores x.w against labels y (positive class 1).
        scores = x.dot(w).ravel()
        fpr, tpr, _ = metrics.roc_curve(y, scores, pos_label=1)
        return metrics.auc(fpr, tpr)

    # Scratch buffer; every entry is overwritten before each use.
    grad = np.zeros(dim + 2)
    for t in range(1, len(ids) + 1):
        if i_res >= n_idx:
            # Every requested checkpoint is recorded; further iterations
            # cannot change the returned arrays.
            break

        sample = ids[t - 1]
        eta = etas[t - 1]

        # ---- local stochastic step on every node ----
        w_local = []
        for i in range(num_nodes):
            x_t = x_tr[i][sample, :]
            y_t = y_tr[i][sample]
            v = v_list[i]
            alpha = alpha_list[i]
            wx = np.inner(v[:dim], x_t)
            if y_t == 1:
                sp_list[i] += 1
                p = sp_list[i] / t  # running estimate of P(y = 1)
                grad[:dim] = (1 - p) * (wx - v[dim] - 1 - alpha) * x_t
                grad[dim] = (p - 1) * (wx - v[dim])
                grad[dim + 1] = 0
                grad_alpha = (p - 1) * (wx + p * alpha)
            else:
                p = sp_list[i] / t
                grad[:dim] = p * (wx - v[dim + 1] + 1 + alpha) * x_t
                grad[dim] = 0
                grad[dim + 1] = p * (v[dim + 1] - wx)
                grad_alpha = p * (wx + (p - 1) * alpha)

            # Descent on the primal iterate, ascent on the dual variable.
            v = v - eta * grad
            # Shrink only the weight part of the iterate.
            v[:dim] = 1 / (1 + beta * eta) * v[:dim]
            v_list[i] = v
            alpha_list[i] = alpha + eta * grad_alpha
            # Weights as they are *before* mixing; these are what get scored.
            w_local.append(v[:dim])

        # ---- mixing: each node combines all nodes' iterates ----
        mixed_v = []
        mixed_alpha = []
        for i in range(num_nodes):
            v_mix = np.zeros(dim + 2)
            alpha_mix = 0
            for j in range(num_nodes):
                v_mix += connect_matrix[i][j] * v_list[j]
                alpha_mix += connect_matrix[i][j] * alpha_list[j]
            mixed_v.append(v_mix)
            mixed_alpha.append(alpha_mix)
        v_list, alpha_list = mixed_v, mixed_alpha

        if res_idx[i_res] != t:
            continue

        # ---- checkpoint: average over nodes and score ----
        v_avg = np.zeros(dim + 2)
        alpha_avg = 0
        w_avg = np.zeros(dim)
        for i in range(num_nodes):
            v_avg += v_list[i]
            alpha_avg += alpha_list[i]
            w_avg += w_local[i]
        v_avg /= num_nodes
        alpha_avg /= num_nodes
        w_avg /= num_nodes

        if not np.all(np.isfinite(w_avg)):
            # Diverged: fill the remaining slots and stop. v_avg has
            # dim + 2 entries while a row of ws has dim + 3, so the primal
            # and dual parts must be written separately.
            gens[i_res:] = gens[i_res - 1] if i_res > 0 else 0.0
            ws[i_res:, :dim + 2] = v_avg
            ws[i_res:, dim + 2] = alpha_avg
            break

        test_auc = auc_of(w_avg, x_te, y_te)
        train_auc = auc_of(w_avg, x_tr_all, y_tr_all)
        gens[i_res] = test_auc - train_auc
        ws[i_res, :dim + 2] = v_avg
        ws[i_res, dim + 2] = alpha_avg
        i_res += 1

    return ws, gens

def generate_P(mode, size):
    """Return the ``size`` x ``size`` mixing matrix for a named topology.

    Entry ``[i][j]`` is the weight node ``i`` gives to node ``j``.

    Args:
        mode: Topology name, one of

            * ``"all"``: every entry is ``1 / size``.
            * ``"single"``: identity matrix (each node keeps its own value).
            * ``"ring"``: each node puts weight 1/3 on itself and on each of
              its two cyclic neighbours. When neighbours coincide
              (``size`` < 3) their weights are added, so rows still sum
              to 1.
            * ``"star"``: node 0 and every other node weight each other by
              ``1 / size``; every other node keeps ``1 - 1 / size`` on
              itself.
            * ``"meshgrid"``: nodes are laid out row-major on an
              ``nrow x ncol`` grid, where ``nrow`` is the largest divisor of
              ``size`` not exceeding ``sqrt(size)``. Horizontally and
              vertically adjacent nodes are linked with weight
              ``1 / max(n_i, n_j)``, ``n_k`` being the number of neighbours
              of node ``k`` counting itself; the diagonal takes the rest so
              each row sums to 1.
            * ``"exponential"``: node ``i`` puts equal weight on nodes
              ``i + k (mod size)`` for ``k`` in 0, 1, 2, 4, 8, ... below
              ``size``.
        size: Number of nodes.

    Returns:
        Float ``numpy.ndarray`` of shape ``(size, size)``.

    Raises:
        ValueError: If ``mode`` is not one of the names above, or if
            ``mode`` is ``"meshgrid"`` and ``size`` is not positive.
    """
    if mode == "all":
        return np.ones((size, size)) / size

    if mode == "single":
        return np.eye(size)

    if mode == "ring":
        result = np.zeros((size, size))
        for i in range(size):
            # Accumulate instead of assign: for size 1 or 2 the neighbour
            # indices coincide, and plain assignment would drop weight.
            for j in (i, (i - 1) % size, (i + 1) % size):
                result[i][j] += 1 / 3
        return result

    if mode == "star":
        result = np.zeros((size, size))
        for i in range(size):
            result[i][i] = 1 - 1 / size
            # For i == 0 these overwrite the diagonal with 1 / size, which
            # makes the hub's row uniform.
            result[0][i] = 1 / size
            result[i][0] = 1 / size
        return result

    if mode == "meshgrid":
        if size <= 0:
            raise ValueError(
                "meshgrid topology needs size > 0, got {!r}".format(size))
        # Most square factorisation nrow * ncol == size.
        nrow = int(np.sqrt(size))
        while size % nrow != 0:
            nrow -= 1
        ncol = size // nrow

        # 0/1 adjacency (with self-loops) of the row-major grid.
        topo = np.zeros((size, size))
        for i in range(size):
            topo[i][i] = 1.0
            if (i + 1) % ncol != 0:  # not in the last column: link right
                topo[i][i + 1] = 1.0
                topo[i + 1][i] = 1.0
            if i + ncol < size:  # not in the last row: link down
                topo[i][i + ncol] = 1.0
                topo[i + ncol][i] = 1.0

        neighbours = [np.nonzero(topo[i])[0] for i in range(size)]
        for i in range(size):
            for j in neighbours[i]:
                if i != j:
                    topo[i][j] = 1.0 / max(len(neighbours[i]),
                                           len(neighbours[j]))
            # The diagonal still holds 1.0 here, so this evaluates to
            # 1 - (sum of off-diagonal weights in row i).
            topo[i][i] = 2.0 - topo[i].sum()
        return topo

    if mode == "exponential":
        # Offsets 0, 1 and every power of two satisfy i & (i - 1) == 0.
        x = np.array([1.0 if i & (i - 1) == 0 else 0 for i in range(size)])
        x /= x.sum()
        topo = np.empty((size, size))
        for i in range(size):
            topo[i] = np.roll(x, i)
        return topo

    # An all-zero mixing matrix would silently erase every node's iterate.
    raise ValueError("unknown topology mode: {!r}".format(mode))