import sys
#from tomita_rnn import *
import numpy as np
import os
from sklearn import cluster #MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
from sklearn_extra.cluster import KMedoids
from sklearn import metrics
import DFA
#import kmc2
import time
import operator
#import dill
import pickle
from sklearn_som.som import SOM
from sklearn.manifold import TSNE
import matplotlib
# Select the non-interactive Agg backend *before* pyplot is imported so the
# choice reliably takes effect; this module only saves figures to files.
matplotlib.use('agg')
import matplotlib.pyplot as plt
#from tomita_rnn import confusion_matrix
import pickle
import graphviz as gv
# dtype used for hidden-state arrays and metric counters in this module.
data_type = 'float64'
#input_dim = 5
#alphabet = [str(bit) for bit in range(1,input_dim+1)]
#

class DFACOntainer:
    """Bundle an extracted DFA with its state count and training accuracy."""

    def __init__(self, dfa, nstates, accuracy):
        # Plain attribute storage: the DFA object, its number of states, and
        # the accuracy it scored on the training data.
        self.dfa, self.nstates, self.accuracy = dfa, nstates, accuracy


def numpy_floatX(data):
    """Return *data* as a NumPy array of the module-wide ``data_type`` dtype."""
    converted = np.asarray(data, dtype=data_type)
    return converted


def perf_measure(logger, y_true, y_pred, ep = 0.5, use_self=False):
    """Compute binary classification metrics for predictions against labels.

    Args:
        logger: logger used to report the raw confusion counts (``use_self`` only).
        y_true: ground-truth labels; 1.0 is positive and 0.0 negative
            (presumably a 1-D array — confirm with callers).
        y_pred: predicted scores or hard 0/1 predictions, aligned with ``y_true``.
        ep: decision margin.
            - With ``use_self``, a positive sample is a TP when its score is
              ``>= 1 - ep``, and a negative sample is an FP when its score is
              ``>= ep``. Samples whose label is neither 1.0 nor 0.0 are ignored.
            - Otherwise, scores ``< 1 - ep`` become 0 and all other scores
              become 1 before scikit-learn scores them.
        use_self: compute the metrics by hand instead of with scikit-learn.

    Returns:
        ``(precision, recall, accuracy, f1, FP, FN)`` when ``use_self`` is true,
        otherwise ``(precision, recall, accuracy, f1)``.
    """
    if use_self:
        pos_label = 1.0
        neg_label = 0.0

        pos_id = np.where(y_true == pos_label)[0]
        TP = numpy_floatX(np.where(y_pred[pos_id] >= 1.0 - ep)[0].shape[0])
        FN = numpy_floatX(np.where(y_pred[pos_id] < 1.0 - ep)[0].shape[0])

        neg_id = np.where(y_true == neg_label)[0]
        FP = numpy_floatX(np.where(y_pred[neg_id] >= ep)[0].shape[0])
        TN = numpy_floatX(np.where(y_pred[neg_id] < ep)[0].shape[0])

        # Small epsilons keep precision/recall/F1 finite when a class is empty.
        precision = TP / (TP + FP + 1e-5)
        recall = TP / (TP + FN + 1e-5)
        accuracy = (TP + TN) / (TP + FP + FN + TN)
        f1 = (2 * precision * recall) / (precision + recall + 1e-6)

        logger.info("TP: %s FP: %s TN: %s FN: %s" % (TP, FP, TN, FN))
        return (precision, recall, accuracy, f1, FP, FN)

    # Threshold scores into hard 0/1 labels. 'int32' replaces the former
    # 'int2' dtype string, which NumPy does not recognise (TypeError).
    y_pred_int = np.ones_like(y_pred, dtype='int32')
    y_pred_int[np.where(y_pred < 1.0 - ep)[0]] = 0

    precision = metrics.precision_score(y_true=y_true, y_pred=y_pred_int)
    recall = metrics.recall_score(y_true=y_true, y_pred=y_pred_int)
    accuracy = metrics.accuracy_score(y_true=y_true, y_pred=y_pred_int)
    f1 = metrics.f1_score(y_true=y_true, y_pred=y_pred_int)

    return (precision, recall, accuracy, f1)


def test_dfa(logger, data, inp_len, y_true, dfa, alphabet):
    """Run *dfa* on every sequence in *data* and score acceptance against *y_true*.

    Row ``i`` of ``data`` is truncated to ``inp_len[i]`` symbol ids. Each id is
    mapped through ``alphabet`` and fed to the DFA, and an accepting final
    status counts as a positive prediction.

    Returns:
        ``[precision, recall, accuracy, f1, fp, fn]`` from
        ``perf_measure(..., use_self=True)``.
    """
    logger.info('\n')
    logger.info("==Start Testing Extracted DFA===")
    n_seqs = data.shape[0]
    y_pred = np.zeros((n_seqs,), dtype='int32')
    for row in range(n_seqs):
        symbols = [alphabet[tok] for tok in data[row, :inp_len[row]]]
        # Start each sequence from the DFA's start state, and leave it reset.
        dfa.reset()
        dfa.input_sequence(symbols)
        y_pred[row] = 1 if dfa.status() else 0
        dfa.reset()

    logger.info("y_true : {}".format(y_true))
    logger.info("y_pred : {}".format(y_pred))
    precision, recall, accuracy, f1, n_fp, n_fn = perf_measure(
        logger, y_true=y_true, y_pred=y_pred, use_self=True)
    logger.info("Precision: %s Recall: %s Accuracy: %s F1: %s" % (precision, recall, accuracy, f1))
    return [precision, recall, accuracy, f1, n_fp, n_fn]


def plot_and_save(data, pred, fname):
    """Scatter-plot 2-D points coloured by cluster id and save the figure to *fname*.

    Args:
        data: array whose first two columns are the x/y coordinates.
        pred: per-point colour values (cluster ids).
        fname: output image path. Missing parent directories are created.
    """
    out_dir = os.path.dirname(fname)
    # dirname() is '' for a bare filename; os.makedirs('') would raise.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure(figsize=(8, 6), dpi=200)
    try:
        plt.scatter(data[:, 0], data[:, 1], c=pred)
        plt.title('Clustering of Hidden States')
        plt.xlabel('d0')
        plt.ylabel('d1')
        plt.savefig(fname)
    finally:
        # Release the figure; callers may plot repeatedly (e.g. once per
        # candidate state count), and pyplot keeps open figures alive.
        plt.close(fig)



# def remove_fp_fn(x, x_len, h_log, y, preds):
#     fp_mask = get_fp_mask(y, preds)
#     fn_mask = get_fn_mask(y, preds)
#     #logger.info(fp_mask.shape, fp_mask.nonzero().squeeze().shape, fn_mask.shape, fn_mask.nonzero().squeeze().shape)
#     FP_id = fp_mask.nonzero().squeeze()
#     FN_id = fp_mask.nonzero().squeeze()

#     indexes = []

#     if len(FP_id.shape) == 0 or len(FP_id) == 0:
#         indexes = FN_id.numpy()
#     elif len(FN_id.shape) == 0  or len(FN_id) == 0:
#         indexes = FP_id.numpy()
#     else:
#         indexes = torch.cat([FP_id, FN_id], axis = 0).numpy()


#     x = x.numpy()
#     x_len = x_len.numpy()
#     y = y.numpy()

#     x = np.delete(x, indexes , axis = 0)
#     x_len = np.delete(x_len, indexes, axis = 0)
#     y = np.delete(y, indexes, axis = 0)

#     x = x[:, 1:] # remove <START>
#     x_len = x_len - 2  # remove <END> with padding
#     h_log = h_log[:, :-2] # use hstates accordingly 

#     return x, x_len, h_log, y, y_preds


def _make_cluster_model(cluster_method, n_states, h_list):
    """Return an unfitted clustering model producing ``n_states`` clusters.

    Raises:
        ValueError: if ``cluster_method`` is not "kmeans", "kmedoids" or "som".
    """
    if cluster_method == "kmeans":
        return cluster.MiniBatchKMeans(n_clusters=n_states)
    if cluster_method == "kmedoids":
        return KMedoids(n_clusters=n_states)
    if cluster_method == "som":
        # An n_states x 1 SOM grid over the hidden-state dimension.
        return SOM(m=n_states, n=1, dim=h_list.shape[1])
    raise ValueError(
        "Unknown cluster_method {!r}; expected 'kmeans', 'kmedoids' or 'som'".format(cluster_method))


def _make_delta(table):
    """Return a DFA transition function ``delta(state, symbol)`` that reads *table*.

    The factory binds *table* when it is called. A lambda written directly in a
    loop would late-bind the loop variable instead, so every DFA would end up
    reading the table built in the final iteration.
    """
    return lambda s, a: table[s][a]


def extract_dfas(logger, dfa_min_state, dfa_max_state, dfa_state_step, cluster_method, 
                    x, x_len, h_log, y, input_dim, input_range, alphabet = None, normalize_states = False, 
                    pos_data_only = False, tsne = False, tsne_fname = None, use_prune = True,
                    prune_min_count = 10):
    """Extract one DFA per candidate cluster count from recorded RNN hidden states.

    For each ``n_states`` in ``range(dfa_min_state, dfa_max_state, dfa_state_step)``:
      1. Cluster the hidden states into groups q1..qN.
      2. Build a transition table from the observed
         (state, symbol) -> next-state moves. Symbols in ``input_range`` are
         decided by majority vote, and unseen moves become self-loops.
      3. Use the most frequent initial cluster as the start state.
      4. Accept the end clusters of positive sequences, optionally pruned.
      5. Minimize the DFA and record its training accuracy.

    Args:
        logger: logger for progress and metrics.
        dfa_min_state, dfa_max_state, dfa_state_step: sweep bounds and step
            (``range`` semantics, so ``dfa_max_state`` is exclusive).
        cluster_method: "kmeans", "kmedoids" or "som".
        x: 2-D array of input symbol ids, one row per sequence.
        x_len: number of symbols used from each row of ``x``.
        h_log: hidden states indexed ``h_log[seq, t]``. Positions
            ``0..x_len[seq]`` are used, i.e. ``x_len[seq] + 1`` states per sequence.
        y: per-sequence labels (presumably 0/1 — confirm with callers).
        input_dim: number of symbol ids (size of the count table's symbol axis).
        input_range: symbol ids whose transitions are decided by majority vote.
        alphabet: maps symbol ids to DFA symbols (list or dict). Defaults to
            ``["1", ..., str(input_dim)]``.
        normalize_states: standardize hidden states before clustering.
        pos_data_only: keep only sequences with ``y == 1``.
        tsne: save a t-SNE scatter plot of the clustered states to ``tsne_fname``.
        tsne_fname: output path of the t-SNE plot; required when ``tsne`` is true.
        use_prune: drop accept states reached by fewer than ``prune_min_count``
            positive sequences.
        prune_min_count: pruning threshold (default 10, the former hard-coded value).

    Returns:
        ``(dfa_list, alpha)``: a dict mapping each ``n_states`` to a DFACOntainer,
        and the symbol sequence given to ``DFA.DFA``.

    Raises:
        ValueError: for an unknown ``cluster_method``.
    """
    if alphabet is None:
        alphabet = [str(bit) for bit in range(1, input_dim + 1)]
    # DFA.DFA gets the symbols themselves; for a dict alphabet those are its values.
    # Computed up front so it is defined even if the state sweep is empty.
    alpha = list(alphabet.values()) if isinstance(alphabet, dict) else alphabet

    if pos_data_only:
        logger.info('Remove negative samples')
        pos_idx = np.where(y == 1)[0]
        x = x[pos_idx]
        x_len = x_len[pos_idx]
        y = y[pos_idx]
        h_log = h_log[pos_idx]
        logger.info("input shape after removing negative samples : {}".format(x.shape))

    y = y.astype('int32')
    sample_num, _ = x.shape
    seq_end_list = x_len

    x_list = [x[x_id, :seq_end_list[x_id]] for x_id in range(sample_num)]
    # Flatten the hidden states in sequence order. Each sequence contributes
    # x_len + 1 states: the one before the first symbol, plus one per symbol.
    h_list = np.array([h_log[x_id, seq_id]
                       for x_id in range(sample_num)
                       for seq_id in range(seq_end_list[x_id] + 1)], dtype=data_type)

    if normalize_states:
        h_list = StandardScaler().fit_transform(h_list)

    n_states_range = range(dfa_min_state, dfa_max_state, dfa_state_step)
    accuracy_train_log = np.zeros((len(n_states_range),), dtype=data_type)
    n_states_log = np.zeros((len(n_states_range),), dtype='int32')
    dfa_list = {}
    # The t-SNE embedding depends only on h_list, so it is computed at most once.
    h_embedded = None

    for range_idx, n_states in enumerate(n_states_range):
        logger.debug("=" * 80)
        logger.info('\n')
        logger.info('Begin DFA extraction with n_state:%d' % (n_states))

        states = ["q" + str(ind) for ind in range(1, n_states + 1)]

        start_clustering_time = time.time()
        logger.info("Num of states : {}, clustering method : {}".format(n_states, cluster_method))
        cluster_model = _make_cluster_model(cluster_method, n_states, h_list)
        cluster_model.fit(h_list)
        logger.info('Done fitting takes time: %f' % (time.time() - start_clustering_time))

        # Shift cluster ids to start at 1 so that cluster k maps to state "qk".
        if hasattr(cluster_model, 'labels_'):
            h_pred = cluster_model.labels_.astype(int) + 1
        else:
            h_pred = cluster_model.predict(h_list) + 1

        logger.info("h_list : {} {}".format(h_list.shape, h_list.dtype))
        if tsne:
            if h_embedded is None:
                h_embedded = TSNE(n_components=2, init='random').fit_transform(h_list)
            plot_and_save(h_embedded, h_pred, tsne_fname)

        transitions = {"q" + str(idx): {} for idx in range(1, n_states + 1)}
        start_state = np.zeros(sample_num, 'int32')
        end_states = np.zeros(sample_num, 'int32')
        # transit_states_cnt[s, a, t]: number of observed moves q(s+1) --a--> q(t+1).
        transit_states_cnt = np.zeros((n_states, input_dim, n_states), dtype='int32')

        mask_id_pre = 0
        for x_id in range(sample_num):
            seq_end = seq_end_list[x_id]
            h_pred_one = h_pred[mask_id_pre:mask_id_pre + seq_end + 1]
            mask_id_pre += seq_end + 1
            start_state[x_id] = h_pred_one[0]
            end_states[x_id] = h_pred_one[-1]

            for seq_id in range(seq_end):
                current_state = h_pred_one[seq_id]
                symbol_id = x_list[x_id][seq_id]
                next_state = h_pred_one[seq_id + 1]
                transitions["q" + str(current_state)][alphabet[symbol_id]] = "q" + str(next_state)
                transit_states_cnt[current_state - 1, symbol_id, next_state - 1] += 1

        unique_start_state, start_state_count = np.unique(start_state, return_counts=True)
        logger.info('Start states')
        logger.info(unique_start_state)
        logger.info('\n')
        logger.info(start_state_count)
        # The most frequent initial cluster becomes the DFA start state.
        unique_start_state = unique_start_state[np.argmax(start_state_count)]
        logger.info('Final start state: {}'.format(unique_start_state))

        # States are >= 1, so multiplying by y zeroes exactly the y == 0 samples.
        end_states = end_states * y
        unique_end_state, end_state_count = np.unique(end_states, return_counts=True)
        # Remove the 0 placeholder by value. Always dropping the first entry
        # would discard a real accept state when no negatives are present.
        keep = unique_end_state != 0
        unique_end_state = unique_end_state[keep]
        end_state_count = end_state_count[keep]

        logger.info('End states')
        logger.info(unique_end_state)
        logger.info('\n')
        logger.info(end_state_count)

        if use_prune:
            unique_end_state = unique_end_state[end_state_count >= prune_min_count]

        logger.debug("{} {}".format(unique_end_state, unique_end_state.dtype))
        accepts = ["q" + str(ind) for ind in unique_end_state]

        for state_idx in range(n_states):
            src = "q" + str(state_idx + 1)
            for input_idx in input_range:
                counts = transit_states_cnt[state_idx, input_idx, :]
                if not counts.any():
                    # Never observed from this state: make it a self-loop.
                    transitions[src][alphabet[input_idx]] = src
                else:
                    # Majority vote; ties go to the lowest-numbered successor.
                    transitions[src][alphabet[input_idx]] = "q" + str(np.argmax(counts) + 1)

        logger.info('Final accept states: {}'.format(accepts))
        sys.stdout.flush()

        start = "q" + str(unique_start_state)
        logger.debug("transition : {}".format(transitions))
        logger.debug("alphabet : {}".format(alphabet))

        d0 = DFA.DFA(states=states, start=start, accepts=accepts, alphabet=alpha,
                     delta=_make_delta(transitions))

        d0.reset()
        logger.info("test before minimize")
        logger.info(" Original DFA without minimize: ")
        d0.pretty_print()
        test_dfa(logger, data=x, inp_len=x_len, y_true=y, dfa=d0, alphabet=alphabet)

        logger.info("==Minimized===")
        d0.minimize()
        d0.pretty_print()
        logger.info("This DFA has %s states" % len(d0.states))
        logger.info('\n')
        n_states_log[range_idx] = len(d0.states)
        logger.info('Evaluate on training set')
        [precision, recall, accuracy, f1, fp, fn] = test_dfa(
            logger, data=x, inp_len=x_len, y_true=y, dfa=d0, alphabet=alphabet)
        accuracy_train_log[range_idx] = accuracy

        dfa_list[n_states] = DFACOntainer(d0, len(d0.states), accuracy_train_log[range_idx])

    for ns, dfac in dfa_list.items():
        logger.info("{} : {}".format(ns, dfac.accuracy))

    return dfa_list, alpha

def create_diagraph(dfa, alphabet, name, directory="../dfa_diagrams/"):
    """Render *dfa* as a Graphviz PNG and return the ``graphviz.Digraph``.

    Nodes:
      - States are relabelled S1..Sn in ``dfa.states`` order.
      - The start state is drawn as a hexagon labelled "Start".
      - Accept states are outlined in green.

    One labelled edge is drawn per (state, symbol) pair.

    Args:
        dfa: DFA exposing ``states``, ``start``, ``accepts`` and ``delta``.
        alphabet: symbol list, or dict whose values are the symbols.
        name: graph name passed to ``graphviz.Digraph``.
        directory: output directory used by ``render()``. The default keeps the
            previous hard-coded "../dfa_diagrams/" (relative to the CWD).
    """
    dot = gv.Digraph(name=name, directory=directory, format="png")
    named_states = {s: "S{}".format(i + 1) for i, s in enumerate(dfa.states)}
    for state in dfa.states:
        shape = "oval"
        label = named_states[state]
        color = "black"
        if state == dfa.start:
            shape = "hexagon"
            label = "Start"
        if state in dfa.accepts:
            color = "green"
        dot.node(named_states[state], color=color, shape=shape, label=label)

    labels = list(alphabet.values()) if isinstance(alphabet, dict) else alphabet

    for c in labels:
        for s, ns in named_states.items():
            # graphviz expects string labels; symbols may be non-str (e.g. ints).
            dot.edge(ns, named_states[dfa.delta(s, c)], label=str(c))

    dot.render()
    return dot


def get_best_dfa(logger, dfa_list, val_x, val_x_len, val_y, alphabet):
    """Pick the DFA with the highest validation accuracy.

    Args:
        logger: logger for per-candidate metrics.
        dfa_list: dict mapping a clustered state count to a DFACOntainer.
        val_x, val_x_len, val_y: validation sequences, their lengths and labels.
        alphabet: maps symbol ids in ``val_x`` to DFA symbols.

    Returns:
        ``(best_container, best_key)``. On ties the first candidate seen wins.
        Returns ``(None, None)`` only when ``dfa_list`` is empty or every
        accuracy is NaN.
    """
    best_dfa = None
    best_k = None
    # Start below any real accuracy so that an all-zero-accuracy candidate
    # set still yields a DFA instead of (None, None).
    best_acc = float('-inf')
    for k, dfac in dfa_list.items():
        [precision, recall, accuracy, f1, fp, fn] = test_dfa(
            logger, data=val_x, inp_len=val_x_len, y_true=val_y, dfa=dfac.dfa, alphabet=alphabet)
        logger.info("Clustered States : {} Final States : {}".format(k, dfac.nstates))
        logger.info("Train Accuracy : {}, Validation Accuracy : {}".format(dfac.accuracy, accuracy))
        # Five placeholders, five values: accuracy is already logged above.
        logger.info("Validation: \nprecision :{} \nrecall : {} \nf1 : {} \nfp : {} \nfn : {}".format(
            precision, recall, f1, fp, fn))
        if accuracy > best_acc:
            best_acc = accuracy
            best_dfa = dfac
            best_k = k

    return best_dfa, best_k


def get_transition_from_delta(delta, alphabet, states):
    """Materialize a transition function into a nested dict.

    Returns ``{state: {symbol: delta(state, symbol)}}`` for every state in
    *states* and every symbol in *alphabet*. Unlike a function, the dict is
    picklable.
    """
    return {
        state: {symbol: delta(state, symbol) for symbol in alphabet}
        for state in states
    }


def save_dfa(dfa, path):
    """Pickle *dfa* to *path* as a plain dict (restorable with ``load_dfa``).

    ``delta`` is a function, which pickle cannot store, so it is materialized
    into an explicit transition table over ``dfa.states`` x ``dfa.alphabet``.
    """
    table = get_transition_from_delta(dfa.delta, dfa.alphabet, dfa.states)
    doc = dict(
        states=dfa.states,
        start=dfa.start,
        transitions=table,
        accepts=dfa.accepts,
        alphabet=dfa.alphabet,
    )
    with open(path, "wb") as fh:
        pickle.dump(doc, fh)


def load_dfa(path):
    """Rebuild a ``DFA.DFA`` from a file written by ``save_dfa``.

    The saved transition table is looked up each time ``delta`` is called.

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only load files
    from a trusted source.
    """
    with open(path, "rb") as fh:
        saved = pickle.load(fh)

    return DFA.DFA(
        states=saved["states"],
        start=saved["start"],
        accepts=saved["accepts"],
        alphabet=saved["alphabet"],
        delta=lambda s, a: saved["transitions"][s][a],
    )
