# -*- coding: utf-8 -*-

import numpy as np
from scipy import io as spio
import scipy.sparse as sp
from sklearn.model_selection import RepeatedKFold
from auc_solam import auc_solam, auc_solam_decentralized
from auc_data import get_res_idx
from concurrent.futures import ProcessPoolExecutor, wait, as_completed
import itertools
import os
import matplotlib.pyplot as plt
from auc_data import auc_data
from scipy.sparse import isspmatrix

def exp_stability(data, eta_p, topology_type, num_nodes, sample_size, work_dir):
    """Run the decentralized stability experiment and save a summary to disk.

    Loads ``data`` via ``auc_data`` and runs ``help_exp`` on every
    train/test split of a 5-fold ``RepeatedKFold`` in a process pool. It then
    saves the per-record mean/std (across splits) of the distance and
    generalization values returned by each split to
    ``<cwd>/<work_dir>/result.npy``.

    Args:
        data: dataset identifier forwarded to ``auc_data`` (also stored in the
            saved options).
        eta_p: step-size scale; every step size is ``eta_p / sqrt(n_iter)``.
        topology_type: stored in ``options['topology_type']`` (presumably read
            by the decentralized solver -- not visible here).
        num_nodes: number of partitions the training fold is split into.
        sample_size: rows per partition; 0 means split the fold evenly.
        work_dir: directory, relative to the current working directory, that
            receives ``result.npy``.
    """
    x, y = auc_data(data)
    if isspmatrix(x):
        # Densify once up front; the per-split work is faster on dense arrays.
        x = x.toarray()
    print('data size x {}, y {}'.format(x.shape, y.shape))

    n_rep = 10
    options = {
        'n_proc': 40,
        'n_repeates': n_rep,
        'n_pass': 2,
        'data': data,
        'n_cv': 10,
        'rec_log': .5,
        'rec': 100,
        'log_res': True,
        'beta': 0,
        'eta_p': eta_p,
        'num_nodes': num_nodes,
        'topology_type': topology_type,
        'sample_size': sample_size,
    }

    # A training fold of a 5-fold split holds roughly 80% of the samples.
    n_tr = int(len(y)*0.8)
    n_iter = options['n_pass'] * n_tr
    # get_res_idx sees options before etas/res_idx/n_iter are added.
    res_idx = get_res_idx(n_iter, options)
    options['etas'] = options['eta_p'] * np.ones(n_iter) / np.sqrt(n_iter)
    n_idx = len(res_idx)
    options['res_idx'] = res_idx
    options['n_iter'] = n_iter

    n_splits = 5
    n_runs = n_rep * n_splits
    splitter = RepeatedKFold(n_splits=n_splits, n_repeats=n_rep)
    with ProcessPoolExecutor(options['n_proc']) as pool:
        outcomes = pool.map(help_exp, itertools.repeat((x, y, options)), splitter.split(x))

    # One row per train/test split, one column per recorded index.
    dist_ret = np.zeros([n_runs, n_idx])
    gen_ret = np.zeros([n_runs, n_idx])
    for row, (dist_row, gen_row) in enumerate(outcomes):
        dist_ret[row] = dist_row
        gen_ret[row] = gen_row

    dist_diff = {'mean': np.mean(dist_ret, 0), 'std': np.std(dist_ret, 0)}
    gen_diff = {'mean': np.mean(gen_ret, 0), 'std': np.std(gen_ret, 0)}

    save_path = os.path.join(os.getcwd(), work_dir, 'result.npy')
    np.save(save_path, {'options': options, 'n_tr': n_tr, 'res_idx': res_idx,
                        'dist_diff': dist_diff, 'gen_diff': gen_diff})
    
def help_exp(arg1, arg2):
    """Adapt one process-pool work item to ``auc_exp_dis``.

    ``arg1`` is the shared ``(x, y, options)`` tuple and ``arg2`` is one
    train/test index pair, so this calls ``auc_exp_dis(x, y, options, arg2)``.
    It is a module-level function so ``ProcessPoolExecutor`` can pickle it.
    Swap the target for ``auc_exp_`` to run the centralized variant instead.
    """
    return auc_exp_dis(*arg1, arg2)

def get_idx(n_data, n_iter):
    """Draw ``n_iter`` sample indices uniformly from ``[0, n_data)``, with replacement.

    Args:
        n_data: number of samples to index into.
        n_iter: number of indices to draw.

    Returns:
        1-D integer array of length ``n_iter``. It is drawn from NumPy's
        global random state, so ``np.random.seed`` makes it reproducible.

    Raises:
        ValueError: if ``n_data`` is not positive while ``n_iter > 0``, or if
            ``n_iter`` is negative.
    """
    if n_iter == 0:
        # Nothing to draw; randint would reject an empty range even for size 0.
        return np.zeros(0, dtype=int)
    # A single vectorized draw instead of n_iter Python-level randint calls.
    return np.random.randint(0, n_data, size=n_iter, dtype=int)
    
def auc_exp_(x, y, options, idx):
    """Centralized stability experiment on one train/test split.

    Trains ``auc_solam`` twice:
      * on dataset A, the training fold minus its last sample;
      * on dataset B, a copy of A whose last row is replaced by that
        held-back sample, so A and B differ in exactly one example.
    Both runs share the same ``options['ids']``.

    Args:
        x, y: full feature matrix and labels, indexed by ``idx``.
        options: experiment options. Reads ``'n_pass'`` and, if present,
            ``'n_iter'``; overwrites ``'ids'`` with indices from ``get_idx``.
        idx: ``(train_indices, test_indices)`` pair.

    Returns:
        ``(ret_dist, ret_gen)``. ``ret_dist[i]`` is ``np.linalg.norm`` of the
        difference between the two runs' i-th ``ws`` entries, and ``ret_gen``
        is the element-wise mean of the two runs' ``gens``.
    """
    idx_tr, idx_te = idx[0], idx[1]
    x_tr, x_te = x[idx_tr], x[idx_te]
    y_tr, y_te = y[idx_tr], y[idx_te]

    # Dataset A: the training fold without its last sample.
    x_tr_a, y_tr_a = x_tr[:-1], y_tr[:-1]
    # Dataset B: an explicit copy of A with its last row swapped for the
    # held-back sample. Copying keeps dataset A untouched.
    x_tr_b, y_tr_b = np.copy(x_tr_a), np.copy(y_tr_a)
    x_tr_b[-1] = x_tr[-1]
    y_tr_b[-1] = y_tr[-1]

    n_tr = len(y_tr_a)
    # Draw one index per iteration, consistent with auc_exp_dis. Fall back to
    # n_pass passes over the data when n_iter is not provided.
    n_iter = options.get('n_iter', options['n_pass'] * n_tr)
    options['ids'] = get_idx(n_tr, n_iter)

    ws_a, gens_a = auc_solam(x_tr_a, y_tr_a, x_te, y_te, options)
    ws_b, gens_b = auc_solam(x_tr_b, y_tr_b, x_te, y_te, options)

    n_res = len(gens_a)
    ret_dist = np.zeros(n_res)
    for i in range(n_res):
        ret_dist[i] = np.linalg.norm(np.subtract(ws_a[i], ws_b[i]))
    # asarray so list-valued gens are averaged rather than concatenated.
    ret_gen = (np.asarray(gens_a) + np.asarray(gens_b)) / 2
    return ret_dist, ret_gen

##############################################################################
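'''
Distributed (decentralized) version: the training fold is partitioned across
nodes and trained with auc_solam_decentralized.
'''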
def split_data(matrix, labels, k, sample_size):
    """Split the rows of ``matrix``/``labels`` into up to ``k`` contiguous partitions.

    All partitions have the same number of rows:
      * ``len(matrix) // k`` when ``sample_size == 0``;
      * ``sample_size`` otherwise.
    Partitions are taken from the top. Stops early, returning fewer than
    ``k``, once the next partition would run past the last row.

    Args:
        matrix: array whose first axis indexes samples. Any rank is
            accepted; the original required exactly 2-D input.
        labels: per-sample labels aligned with ``matrix``'s first axis.
        k: maximum number of partitions.
        sample_size: rows per partition, or 0 for an even split.

    Returns:
        ``(result_matrix, result_labels)``: lists of row slices. These are
        views when the inputs are NumPy arrays.
    """
    n_rows = matrix.shape[0]
    rows_per_partition = n_rows // k if sample_size == 0 else sample_size

    result_matrix = []
    result_labels = []
    for i in range(k):
        start_row = i * rows_per_partition
        end_row = start_row + rows_per_partition
        if end_row > n_rows:
            break
        result_matrix.append(matrix[start_row:end_row])
        result_labels.append(labels[start_row:end_row])

    return result_matrix, result_labels

def contruct_two_datasets(result_matrix, result_labels):
    """Build two neighbouring per-partition datasets from node partitions.

    For every partition ``(X_i, y_i)``:
      * dataset A is ``X_i[:-1]`` (every row but the last);
      * dataset B is a copy of A whose last row is replaced by ``X_i[-1]``.
    Labels are handled the same way. A and B therefore differ in exactly one
    example per partition. Every returned array is a fresh copy, and the
    inputs are not modified.

    Args:
        result_matrix: list of feature arrays, one per partition (any rank;
            rows along the first axis).
        result_labels: list of label arrays aligned with ``result_matrix``.

    Returns:
        ``(matrix_a, labels_a, matrix_b, labels_b)``, each a list with one
        entry per partition.

    Raises:
        ValueError: if a partition has fewer than two rows, so B would have
            no row to replace.
    """
    result_matrix_a = []
    result_labels_a = []
    result_matrix_b = []
    result_labels_b = []

    for i, part_x in enumerate(result_matrix):
        part_y = result_labels[i]
        if len(part_x) < 2 or len(part_y) < 2:
            raise ValueError(
                'partition {} has fewer than 2 samples; cannot build a '
                'neighbouring dataset'.format(i))

        x_a = np.copy(part_x[:-1])
        y_a = np.copy(part_y[:-1])
        result_matrix_a.append(x_a)
        result_labels_a.append(y_a)

        # Same rows as A except the last, which becomes the held-back sample.
        x_b = np.copy(x_a)
        y_b = np.copy(y_a)
        x_b[-1] = part_x[-1]
        y_b[-1] = part_y[-1]
        result_matrix_b.append(x_b)
        result_labels_b.append(y_b)

    return result_matrix_a, result_labels_a, result_matrix_b, result_labels_b
        


def auc_exp_dis(x, y, options, idx):
    """Decentralized stability experiment on one train/test split.

    Steps:
      1. ``split_data`` partitions the training fold into up to
         ``options['num_nodes']`` equal-sized partitions.
      2. ``contruct_two_datasets`` builds per-partition datasets A and B that
         differ only in the last row of each partition.
      3. ``auc_solam_decentralized`` is run on both with the same
         ``options['ids']``.

    Args:
        x, y: full feature matrix and labels, indexed by ``idx``.
        options: experiment options. Reads ``'num_nodes'``, ``'sample_size'``
            and ``'n_iter'``; overwrites ``'ids'`` with indices from
            ``get_idx`` (presumably the solver's per-iteration sample order --
            not visible here).
        idx: ``(train_indices, test_indices)`` pair.

    Returns:
        ``(ret_dist, ret_gen)``. ``ret_dist[i]`` is ``np.linalg.norm`` of the
        difference between the two runs' i-th ``ws`` entries, and ``ret_gen``
        is the element-wise mean of the two runs' ``gens``.

    Raises:
        ValueError: if ``split_data`` yields no partitions (e.g.
            ``sample_size`` larger than the training fold).
    """
    idx_tr, idx_te = idx[0], idx[1]
    x_tr, x_te = x[idx_tr], x[idx_te]
    y_tr, y_te = y[idx_tr], y[idx_te]

    ## construct two neighbouring datasets ##
    x_tr_all, y_tr_all = split_data(x_tr, y_tr, options['num_nodes'], options['sample_size'])
    if not x_tr_all:
        raise ValueError(
            'split_data produced no partitions (num_nodes={}, sample_size={}, '
            'training fold size={})'.format(
                options['num_nodes'], options['sample_size'], len(y_tr)))
    x_tr_a, y_tr_a, x_tr_b, y_tr_b = contruct_two_datasets(x_tr_all, y_tr_all)

    # One shared index draw for both runs. Partitions are equal-sized, so
    # partition 0's size is every partition's size.
    n_iter = options['n_iter']
    n_tr = len(y_tr_a[0])
    options['ids'] = get_idx(n_tr, n_iter)

    ws_a, gens_a = auc_solam_decentralized(x_tr_a, y_tr_a, x_te, y_te, options)
    # The neighbouring dataset, trained with the same options/ids.
    ws_b, gens_b = auc_solam_decentralized(x_tr_b, y_tr_b, x_te, y_te, options)

    n_res = len(gens_a)
    ret_dist = np.zeros(n_res)
    for i in range(n_res):
        ret_dist[i] = np.linalg.norm(np.subtract(ws_a[i], ws_b[i]))
    # asarray so list-valued gens are averaged rather than concatenated.
    ret_gen = (np.asarray(gens_a) + np.asarray(gens_b)) / 2
    return ret_dist, ret_gen