import numpy as np
import argparse
import time
import os
import collections
import json
import queue
import time

from utils.data_utils import load_dataset_numpy

import scipy.spatial.distance

from scipy.sparse import csr_matrix, coo_matrix
from scipy.sparse.csgraph import maximum_flow
from utils.flow import _make_edge_pointers

from utils.io_utils import init_dirs, model_naming, test_argparse

from cvxopt import solvers, matrix, spdiag, log, mul, sparse, spmatrix

def minll(G,h,p):
    """Minimise the weighted negative log-likelihood ``-sum(p_i * log(x_i))``
    subject to the linear inequalities ``G x <= h`` with ``cvxopt.solvers.cp``.

    Args:
        G: cvxopt matrix of inequality coefficients. Its column count defines
            the number of optimisation variables.
        h: cvxopt matrix holding the right-hand side of ``G x <= h``.
        p: cvxopt column matrix of weights, one per variable (assumed shape
            ``(n_vars, 1)`` -- TODO confirm against callers).

    Returns:
        Whatever ``cvxopt.solvers.cp`` returns for this problem.
    """
    # G.size is (rows, cols); the number of variables is the column count.
    # (Previously the start point used an undefined name ``v`` -> NameError.)
    _, n_vars = G.size

    def F(x=None, z=None):
        if x is None:
            # cvxopt convention: (number of nonlinear constraints, start point).
            # All-ones lies in the domain x > 0 of the logarithm.
            return 0, matrix(1.0, (n_vars, 1))
        if min(x) <= 0.0:
            # Outside the domain of log(x).
            return None
        f = -sum(mul(p, log(x)))
        # Gradient of -sum(p_i log x_i) is -p_i / x_i, as a row vector.
        Df = mul(p.T, -(x**-1).T)
        if z is None:
            return f, Df
        # Hessian is diag(p_i / x_i**2), weighted by z[0].
        H = spdiag(z[0] * mul(p, x**-2))
        return f, Df, H

    return solvers.cp(F, G=G, h=h)


def find_remaining_cap_edges(edge_ptr, capacities, heads, tails, source, sink):
    """Breadth-first search from ``source`` over edges with positive capacity.

    The graph is given in CSR form. The out-edges of vertex ``v`` are the
    positions ``edge_ptr[v]`` .. ``edge_ptr[v + 1] - 1`` of ``heads`` (edge
    targets) and ``capacities``. An edge is traversed only when its capacity
    is strictly greater than 0. Each vertex is discovered at most once, and
    the source is never re-entered.

    Args:
        edge_ptr: CSR row-pointer array of length ``n_verts + 1``.
        capacities: per-edge capacity array aligned with ``heads``.
        heads: per-edge target-vertex array.
        tails: unused; accepted only to keep the call signature unchanged.
        source: vertex the search starts from.
        sink: the search stops as soon as this vertex is discovered.

    Returns:
        List of ``(u, v)`` tuples, one per edge that discovered a new vertex,
        in discovery order. This is the whole BFS tree explored so far, not
        only a source-to-sink path. If ``sink`` is reached, the edge that
        discovered it is the last element.
    """
    n_verts = edge_ptr.shape[0] - 1
    # The source counts as discovered so no edge back into it is recorded.
    discovered = np.zeros(n_verts, dtype=bool)
    discovered[source] = True

    tree_edges = []
    frontier = collections.deque([source])
    while frontier:
        cur = frontier.popleft()
        for e in range(edge_ptr[cur], edge_ptr[cur + 1]):
            t = heads[e]
            # Usable only if unseen and carrying strictly positive capacity.
            if not discovered[t] and capacities[e] > 0:
                discovered[t] = True
                tree_edges.append((cur, t))
                if t == sink:
                    return tree_edges
                frontier.append(t)
    return tree_edges

def create_graph_rep(edge_matrix,n_1,n_2):
    """Build the dense capacity matrix of the source / bipartite / sink network.

    Vertex layout (``n_1 + n_2 + 2`` vertices):
        0                      source
        1 .. n_1               left-hand-side vertices
        n_1 + 1 .. n_1 + n_2   right-hand-side vertices
        n_1 + n_2 + 1          sink

    Capacities (every other entry is 0):
        source -> each left vertex:                            n_2
        left i -> right j where edge_matrix[i, j] is truthy:   n_1 * n_2
        each right vertex -> sink:                             n_1

    Args:
        edge_matrix: 2-D array-like. The truthiness of entry ``[i, j]`` decides
            whether left vertex ``i`` connects to right vertex ``j``. Only the
            top-left ``n_1 x n_2`` corner is read.
        n_1: number of left-hand-side vertices.
        n_2: number of right-hand-side vertices.

    Returns:
        Integer ``np.ndarray`` of shape ``(n_1 + n_2 + 2, n_1 + n_2 + 2)``.
    """
    n_verts = n_1 + n_2 + 2
    sink = n_verts - 1
    lhs = slice(1, n_1 + 1)
    rhs = slice(n_1 + 1, n_1 + n_2 + 1)

    # Vectorised fill instead of visiting each of the n_verts**2 entries in Python.
    graph_rep_array = np.zeros((n_verts, n_verts), dtype=int)
    graph_rep_array[0, lhs] = n_2
    connected = np.asarray(edge_matrix)[:n_1, :n_2].astype(bool)
    graph_rep_array[lhs, rhs] = np.where(connected, n_1 * n_2, 0)
    graph_rep_array[rhs, sink] = n_1
    return graph_rep_array

def set_classifier_prob_full_flow(top_level_vertices,n_1_curr,n_2_curr):
    """Give every non-terminal vertex the mixed class distribution.

    Each vertex ``v`` in ``top_level_vertices``, other than the source (0) and
    the module-level ``sink_idx``, gets row ``v - 1`` of the module-level
    ``classifier_probs`` set to
    ``[n_1_curr / (n_1_curr + n_2_curr), n_2_curr / (n_1_curr + n_2_curr)]``.
    """
    total = n_1_curr + n_2_curr
    for vertex in top_level_vertices:
        if vertex == 0 or vertex == sink_idx:
            continue
        row = vertex - 1
        classifier_probs[row, 0] = n_1_curr / total
        classifier_probs[row, 1] = n_2_curr / total

def set_classifier_prob_no_flow(top_level_vertices):
    """Give every non-terminal vertex a one-hot distribution on its own side.

    Vertices ``1 .. n_1`` (module-level ``n_1``) get ``[1, 0]``. Vertices above
    ``n_1`` get ``[0, 1]``. The source (0) and the module-level ``sink_idx`` are
    skipped. Values go into row ``v - 1`` of the module-level
    ``classifier_probs``.
    """
    for vertex in top_level_vertices:
        if vertex == 0 or vertex == sink_idx:
            continue
        row = vertex - 1
        if vertex <= n_1:
            classifier_probs[row, 0], classifier_probs[row, 1] = 1, 0
        elif vertex > n_1:
            classifier_probs[row, 0], classifier_probs[row, 1] = 0, 1

def graph_rescale(graph_rep_curr,top_level_indices):
    """Rescale a sub-network's capacities to its own vertex counts.

    ``graph_rep_curr`` holds capacities built for the module-level sizes
    ``n_1`` / ``n_2``, restricted to ``top_level_indices``. The indices are
    expected to start with the source (0), list left vertices (``<= n_1``)
    before right ones, and end with the sink. The capacities are converted to
    the sub-network counts ``n_1_curr`` / ``n_2_curr``:

        row 0 (source):            / n_2          * n_2_curr
        rows 1 .. n_1_curr:        / (n_1 * n_2)  * (n_1_curr * n_2_curr)
        rows n_1_curr + 1 onward:  / n_1          * n_1_curr

    The array is modified in place and also returned. Divided values are
    assigned back into the array's own dtype.

    Returns:
        ``(graph_rep_curr, n_1_curr, n_2_curr)``.
    """
    # Each count includes one terminal (source resp. sink), hence the - 1.
    n_1_curr = np.count_nonzero(top_level_indices <= n_1) - 1
    n_2_curr = np.count_nonzero(top_level_indices > n_1) - 1

    first_tail_row = n_1_curr + 1
    source_row = graph_rep_curr[0, :]
    lhs_rows = graph_rep_curr[1:first_tail_row, :]
    tail_rows = graph_rep_curr[first_tail_row:, :]

    # Views share memory with graph_rep_curr, so these updates land in it.
    source_row[:] = source_row / n_2
    source_row *= n_2_curr
    lhs_rows[:] = lhs_rows / (n_1 * n_2)
    lhs_rows *= (n_1_curr * n_2_curr)
    tail_rows[:] = tail_rows / n_1
    tail_rows *= n_1_curr
    return graph_rep_curr, n_1_curr, n_2_curr

def find_flow_and_split(top_level_indices):
    """Solve max-flow on a sub-network and split it into two smaller ones.

    ``top_level_indices`` selects rows/columns of the module-level
    ``graph_rep_array``, with the source at local index 0 and the sink last.
    The capacities are first rescaled with ``graph_rescale``. Then max-flow
    from local 0 to the last local index decides what happens:

    * flow == ``n_1_curr * n_2_curr``: ``set_classifier_prob_full_flow`` is
      applied and there is no split.
    * flow == 0: ``set_classifier_prob_no_flow`` is applied and there is no
      split.
    * otherwise: the vertices touched by ``find_remaining_cap_edges`` (a BFS
      from the source over edges with remaining capacity) form the first part,
      with the sink appended. The remaining vertices, with the source and sink
      added back, form the second part.

    Returns:
        ``(indices_1, indices_2, flow_result)``. Each ``indices_*`` is a
        top-level index array for a sub-problem, or ``None``. ``flow_result``
        is the ``scipy.sparse.csgraph.maximum_flow`` result.
    """
    top_level_indices_1 = None
    top_level_indices_2 = None

    # Sub-network induced by the selected vertices.
    graph_rep_curr = graph_rep_array[top_level_indices]
    graph_rep_curr = graph_rep_curr[:, top_level_indices]
    graph_rep_curr, n_1_curr, n_2_curr = graph_rescale(graph_rep_curr, top_level_indices)
    graph_curr = csr_matrix(graph_rep_curr)
    sink_local = len(top_level_indices) - 1
    flow_curr = maximum_flow(graph_curr, 0, sink_local)

    if flow_curr.flow_value == n_1_curr * n_2_curr:
        # Full flow occurred, so there is no need to split.
        set_classifier_prob_full_flow(top_level_indices, n_1_curr, n_2_curr)
        return top_level_indices_1, top_level_indices_2, flow_curr
    if flow_curr.flow_value == 0:
        set_classifier_prob_no_flow(top_level_indices)
        return top_level_indices_1, top_level_indices_2, flow_curr

    # SciPy >= 1.8 exposes the flow matrix as ``flow``. ``residual`` is the
    # older, deprecated name, so use it only when ``flow`` is absent.
    flow_matrix = getattr(flow_curr, 'flow', None)
    if flow_matrix is None:
        flow_matrix = flow_curr.residual
    # Remaining capacity on every edge: capacity minus flow.
    remainder_array = graph_curr - flow_matrix

    _rev_edge_ptr, tails = _make_edge_pointers(remainder_array)
    edge_list_curr = find_remaining_cap_edges(
        remainder_array.indptr, remainder_array.data, remainder_array.indices,
        tails, 0, sink_local)

    if edge_list_curr:
        # Local indices touched by the search (sorted by np.unique), mapped
        # back to top-level indices, with the sink appended.
        reached_local = np.unique(np.array(edge_list_curr).ravel())
        top_level_gz_idx = top_level_indices[reached_local]
        top_level_gz_idx = np.insert(top_level_gz_idx, len(top_level_gz_idx), sink_idx)
        top_level_indices_1 = top_level_gz_idx
    else:
        top_level_gz_idx = np.array([0, sink_idx])

    # Vertices not reached from the source form the second sub-problem.
    top_level_z_idx = np.setdiff1d(top_level_indices, top_level_gz_idx)
    if len(top_level_z_idx) > 0:
        # Add source and sink back to the zero-flow index array.
        top_level_z_idx = np.insert(top_level_z_idx, 0, 0)
        top_level_z_idx = np.insert(top_level_z_idx, len(top_level_z_idx), sink_idx)
        top_level_indices_2 = top_level_z_idx

    return top_level_indices_1, top_level_indices_2, flow_curr

# def main():

parser = test_argparse()

# Lower bound compute args, registered from a table: (flags, keyword options).
_LOWER_BOUND_ARGS = (
    (("--intersect_norm",), dict(default='l2', help="norm to be used")),
    (('--use_test',), dict(dest='use_test', action='store_true')),
    (('--num_reps',), dict(type=int, default=1)),
    (('--alg_type',), dict(type=str)),
    (('--subsample_size',), dict(type=int, default=0)),
    (('--closest',), dict(action='store_true')),
    (('--start_idx',), dict(type=int, default=0)),
    (('--end_idx',), dict(type=int, default=10)),
    (('--input_eps',), dict(type=float, default=0.0)),
)
for _flags, _options in _LOWER_BOUND_ARGS:
    parser.add_argument(*_flags, **_options)

args = parser.parse_args()

# Only the two-class setting is handled, with the class pair fixed to 3 vs 7.
if args.n_classes != 2:
    raise ValueError('Unsupported number of classes')
args.class_1, args.class_2 = 3, 7

train_data, test_data, data_details = load_dataset_numpy(
    args, data_dir='/data/datasets', training_time=False)

# Flattened input size: channels * height * width.
data_dim = data_details['n_channels'] * data_details['h_in'] * data_details['w_in']


# Getting model name details
args.trial_num = 1
args.eps_step = args.epsilon * args.gamma / args.attack_iter
model_dir_name, log_dir_name, figure_dir_name, _ = init_dirs(args, train=False)
_, model_name = model_naming(args)


num_samples = args.num_samples

# Loading preordered data:
# input_data/<dataset>/<c1>_<c2>/<c1>_<c2>_<dataset>_{X,Y}.npy
_pair_dir = 'input_data/%s/%s_%s/' % (args.dataset_in, args.class_1, args.class_2)
_file_stem = '%s_%s_%s' % (args.class_1, args.class_2, args.dataset_in)
x_filename = _pair_dir + _file_stem + '_X.npy'
y_filename = _pair_dir + _file_stem + '_Y.npy'

if not os.path.exists(x_filename):
    raise ValueError('Where is the data?')
X_curr = np.load(x_filename)
Y_curr = np.load(y_filename)

# Per-class design matrices (label 0 -> class_1, label 1 -> class_2), each
# expected to hold exactly args.num_samples flattened points.
X_c1=X_curr[np.where(Y_curr==0)].reshape(args.num_samples,data_dim)
X_c2=X_curr[np.where(Y_curr==1)].reshape(args.num_samples,data_dim)

zero_indices=np.where(Y_curr==0)[0]
one_indices=np.where(Y_curr==1)[0]

# "<dataset>/<class_1>_<class_2>" sub-path shared by the result files.
pair_subdir = args.dataset_in + '/' + str(args.class_1) + '_' + str(args.class_2)
results_dir = 'cost_results/' + pair_subdir
timing_dir = 'cost_results/timing_results/' + pair_subdir

# Create every directory a file below is opened in. Previously only the
# top-level 'cost_results' folder was created, so the nested opens could fail.
for _dir_name in ('distances', results_dir, timing_dir):
    os.makedirs(_dir_name, exist_ok=True)


rng = np.random.default_rng(77)

# Output-file stems for the log-loss and 0-1 loss results.
save_file_name = 'logloss_'
save_file_name_01 = 'zero_one_'

if 'all' in args.model:
    # NOTE(review): args.norm is not defined by this file's parser; presumably
    # it comes from test_argparse() -- confirm.
    _save_suffix = 'source' + '_' + str(args.subsample_size) + '_' + args.norm
elif args.closest:
    _save_suffix = model_name + '_' + str(args.input_eps) + '_' + str(args.start_idx) + '_' + str(args.end_idx) + '_' + args.alg_type
else:
    _save_suffix = model_name + '_' + str(args.input_eps) + '_subsample' + str(args.subsample_size) + '_' + args.alg_type
# Both stems get the same suffix. The --closest branch used to leave the 0-1
# file name as the bare prefix 'zero_one_'.
save_file_name += _save_suffix
save_file_name_01 += _save_suffix

f = open(results_dir + '/' + save_file_name + '.txt', 'a')
f_time = open(timing_dir + '/' + save_file_name + '.txt', 'a')
f2 = open(results_dir + '/' + save_file_name_01 + '.txt', 'a')

intersect_eps_list=np.arange(2.0,6.2,0.2)

# "<dataset>/<class_1>_<class_2>" sub-path shared by inputs and outputs.
pair_path = args.dataset_in + '/' + str(args.class_1) + '_' + str(args.class_2)

# Distance-matrix file. It depends only on the CLI arguments, so build it once.
if 'all' in args.model:
    dist_mat_name = 'input_data/' + pair_path + '/' + str(args.class_1) + '_' + str(args.class_2) + '_' + args.dataset_in + '_dists.npy'
elif args.closest:
    dist_mat_name = 'dl_output/' + pair_path + '/results/' + model_name + '_' + str(args.input_eps) + '_' + str(args.start_idx) + '_' + str(args.end_idx) + '_' + args.alg_type + '.npy'
else:
    dist_mat_name = 'dl_output/' + pair_path + '/results/' + model_name + '_' + str(args.input_eps) + '_subsample' + str(args.subsample_size) + '_' + args.alg_type + '.npy'

# Optimal classifier probabilities are cached here per epsilon. Create the
# folder up front so np.savetxt below cannot fail on a missing directory.
probs_prefix = 'graph_data/optimal_probs/' + pair_path + '/'
os.makedirs(probs_prefix, exist_ok=True)

for intersect_eps in intersect_eps_list:
    loss_list = []
    time_list = []
    num_edges_list = []
    probs_file = probs_prefix + save_file_name + '_' + str(intersect_eps) + '.txt'
    # NOTE(review): probs_file depends only on the epsilon, and it is written
    # at the end of each computed repetition. Every later repetition for the
    # same epsilon therefore takes the cached branch below, so num_reps > 1
    # only re-reports the 0-1 loss. Confirm this is intended.
    for rep in range(args.num_reps):
        print(save_file_name)
        if os.path.exists(probs_file):
            # Cached probabilities: only compute the 0-1 loss. The first
            # subsample_size rows (left side) are wrong when column 0 < column 1.
            # The rest are wrong when column 0 >= column 1, so ties go to column 0.
            prob_matrix = np.loadtxt(probs_file)

            zero_one_loss = 0.0
            for i, item in enumerate(prob_matrix):
                if i < args.subsample_size:
                    if item[0] < item[1]:
                        zero_one_loss += 1.0
                elif item[0] >= item[1]:
                    zero_one_loss += 1.0
            final_zero_one_loss = zero_one_loss / (2 * args.subsample_size)
            print('0-1 loss is %s' % final_zero_one_loss)
            f2.write(str(intersect_eps) + ',' + str(final_zero_one_loss) + '\n')
            continue

        print(dist_mat_name)

        if os.path.exists(dist_mat_name) and 'all' not in args.model:
            print('Loading distances')
            D_12 = np.load(dist_mat_name)
        elif 'all' in args.model:
            # Subsample (with replacement) points from each class.
            indices_1 = rng.integers(num_samples, size=args.subsample_size)
            indices_2 = rng.integers(num_samples, size=args.subsample_size)

            X_c1_curr = X_c1[indices_1]
            X_c2_curr = X_c2[indices_2]

            if args.intersect_norm == 'l2':
                D_12 = scipy.spatial.distance.cdist(X_c1_curr, X_c2_curr, metric='euclidean')
            elif args.intersect_norm == 'linf':
                D_12 = scipy.spatial.distance.cdist(X_c1_curr, X_c2_curr, metric='chebyshev')
            else:
                # Previously fell through with D_12 undefined (or stale).
                raise ValueError('Unsupported intersect_norm: %s' % args.intersect_norm)
            # Halved before comparing against eps (presumably because each of
            # the two points may move by eps -- confirm).
            D_12 /= 2.0
        else:
            # Previously this exception was constructed but never raised.
            raise ValueError('Compute dists first!')

        print(intersect_eps)
        # Add an edge between a pair whose (halved) distance is within eps.
        edge_matrix = (D_12 <= intersect_eps).astype(float)

        num_edges = np.count_nonzero(edge_matrix)
        print(num_edges)
        num_edges_list.append(num_edges)

        # Module-level names read by the graph helper functions.
        n_1 = args.subsample_size
        n_2 = args.subsample_size

        # Create graph representation
        graph_rep_array = create_graph_rep(edge_matrix, n_1, n_2)

        # time.clock() was removed in Python 3.8; perf_counter is its replacement.
        time1 = time.perf_counter()
        q = queue.Queue()
        # Start from the full network: source, all bipartite vertices, sink.
        q.put(np.arange(n_1 + n_2 + 2))
        sink_idx = n_1 + n_2 + 1
        classifier_probs = np.zeros((n_1 + n_2, 2))
        while not q.empty():
            print('Current queue size at eps %s is %s' % (intersect_eps, q.qsize()))
            curr_idx_list = q.get()
            list_1, list_2, flow_curr = find_flow_and_split(curr_idx_list)
            if list_1 is not None:
                q.put(list_1)
            if list_2 is not None:
                q.put(list_2)
        time2 = time.perf_counter()

        # Mean negative log-likelihood: rows < n_1 use column 0, others column 1.
        loss = 0.0
        for i, (p_0, p_1) in enumerate(classifier_probs):
            loss += np.log(p_0 if i < n_1 else p_1)
        loss = -1 * loss / len(classifier_probs)
        print('Log loss for eps %s is %s' % (intersect_eps, loss))

        loss_list.append(loss)
        time_list.append(time2 - time1)

        # Running statistics over the repetitions completed so far.
        loss_avg = np.mean(loss_list)
        loss_var = np.var(loss_list)
        time_avg = np.mean(time_list)
        time_var = np.var(time_list)
        num_edges_avg = np.mean(num_edges_list)

        f.write(str(intersect_eps) + ',' + str(loss_avg) + ',' + str(loss_var) + '\n')
        f_time.write(str(intersect_eps) + ',' + str(time_avg) + ',' + str(time_var) + ',' + str(num_edges_avg) + '\n')
        np.savetxt(probs_file, classifier_probs, fmt='%.5f')


# if __name__ == "__main__":
#     main()