import numpy as np
import math, time, collections, os, errno, sys, code, random, pickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import mixture
from sklearn.cluster import KMeans
import pandas as pd
from multiprocessing import Pool

from mtticc.src.TICC_helper import *
from mtticc.src.admm_solver import ADMMSolver
from mtticc.General_utils import *

from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool

from functools import partial
from multiprocessing.dummy import Pool as ThreadPool


class MTTICC:
    def __init__(self, fixed_window=10, number_of_clusters=5, lambda_parameter=11e-2,
                 beta=400, maxIters=1000, threshold=2e-5, write_out_file=False,
                 prefix_string="", num_proc=1, compute_BIC=False, cluster_reassignment=20, biased=False,
                 decay_func='1/log(e+x)', input_pattern='single', window_pattern='fixed',
                 dynamic_attention='none'):
        """
        Store the solver configuration and set the process-wide display / RNG options.

        Parameters:
            - fixed_window: size of the sliding window (stored as self.window_size)
            - number_of_clusters: number of clusters
            - lambda_parameter: sparsity parameter
            - beta: temporal consistency parameter (stored as self.switch_penalty)
            - maxIters: maximum number of training iterations
            - threshold: convergence threshold (only stored here)
            - write_out_file: (bool) if true, write_plot saves the training figure
            - prefix_string: prefix of the output directory name
            - num_proc: number of worker processes of the pool created in fit
            - compute_BIC: (bool) if true, fit calls computeBIC at the end
            - cluster_reassignment: number of points to reassign to an empty cluster
            - biased: passed as `bias` to np.cov (biased or unbiased covariance)
            - decay_func: name of the decay function forwarded to the interval helpers
            - input_pattern: 'single' / 'multiple'; only selects the message printed by load_data
            - window_pattern: 'fixed' / 'dynamic', how the context of a timestamp is built
            - dynamic_attention: forwarded to GetPreviousContextArray for the 'dynamic' pattern
        """
        # Clustering model
        self.number_of_clusters = number_of_clusters
        self.window_size = fixed_window
        self.num_blocks = fixed_window + 1
        self.lambda_parameter = lambda_parameter
        self.switch_penalty = beta
        self.biased = biased

        # Training loop
        self.maxIters = maxIters
        self.threshold = threshold
        self.cluster_reassignment = cluster_reassignment
        self.num_proc = num_proc
        self.compute_BIC = compute_BIC

        # Input layout and context construction
        self.input_pattern = input_pattern
        self.window_pattern = window_pattern
        self.decay_func = decay_func
        self.dynamic_attention = dynamic_attention

        # Output
        self.write_out_file = write_out_file
        self.prefix_string = prefix_string

        # Global settings: wide pandas display, 4-decimal float printing, fixed NumPy seed
        pd.set_option('display.max_columns', 500)
        np.set_printoptions(formatter={'float': "{0:0.4f}".format})
        np.random.seed(102)

    def fit(self, input_file):
        """
        Main method for TICC solver.

        Alternates between fitting the parameters of every cluster (train_clusters /
        optimize_clusters) and re-assigning every stacked sample to a cluster
        (predict_clusters) until the assignment equals the one of the previous
        iteration or maxIters iterations have been run.

        Parameters:
            - input_file: location of the data file (read by load_data)

        Returns a tuple with:
            - clustered_points_list: per-sequence labels returned by the last predict_clusters call
            - label_list: per-sequence labels as ints; for the 'fixed' pattern the first label is
              repeated window_size-1 times in front so the truncated head of the sequence is labelled too
            - conv_list: [summed LLE per iteration, primal residuals, dual residuals]
            - bic: value returned by computeBIC, or 0 when compute_BIC is false
            - train_cluster_inverse: dictionary with the inverse covariance of each cluster
            - empirical_covariances: dictionary with the empirical covariance of each cluster
            - complete_D_list: the stacked training data of each sequence

        Raises:
            - ValueError: if window_pattern is neither 'fixed' nor 'dynamic'
        """
        assert self.maxIters > 0  # must have at least one iteration
        # Reject unsupported patterns up front; otherwise the failure only shows up later
        # as an unbound local variable.
        if self.window_pattern not in ('fixed', 'dynamic'):
            raise ValueError("window_pattern must be 'fixed' or 'dynamic', got {!r}".format(self.window_pattern))
        # self.log_parameters() # Display the critical parameters

        # The basic folder to be created
        str_NULL = self.prepare_out_directory()

        # Get data into proper format:
        # - Return the data array list, row number list, column number, time intervals
        seq_arr_list, seq_rows_size_list, seq_col_size, seq_intervals = self.load_data(input_file)

        # Reshape the time intervals so that there is one entry per stacked sample;
        # the first entry of every sequence is set to 0
        if self.window_pattern == 'fixed':
            intervals = [[0] + list(interv[self.window_size:]) for interv in seq_intervals]
        else:  # 'dynamic'
            intervals = [[0] + list(interv[1:]) for interv in seq_intervals]
        # Convert the time intervals to weights
        w_intervals = [[FromIntervaltoWeight(x, func=self.decay_func) for x in tmp] for tmp in intervals]

        # Organize the data with context information
        # - Stack the features inside a window into a vector
        complete_D_list, complete_D_all = self.stack_training_data(seq_arr_list, seq_rows_size_list,
                                                                   seq_intervals, seq_col_size)

        # Initialization
        # - Gaussian Mixture
        gmm = mixture.GaussianMixture(n_components=self.number_of_clusters, covariance_type="full")
        gmm.fit(complete_D_all)
        clustered_points_all = gmm.predict(complete_D_all)

        # Per-cluster state, keyed by (number_of_clusters, cluster) except empirical_covariances
        train_cluster_inverse = {}
        log_det_values = {}
        computed_covariance = {}
        cluster_mean_info = {}
        cluster_mean_stacked_info = {}
        old_clustered_points = None
        empirical_covariances = {}

        # Check the convergence of the algorithm
        lle_list, res_pri_list, res_dual_list = [], [], []
        # Whether an empty cluster was seen in each of the most recent iterations (bounded history)
        recent_empty_flags = collections.deque(maxlen=5)
        iters_times = []

        # PERFORM TRAINING ITERATIONS
        pool = Pool(processes=self.num_proc)
        try:
            for iters in range(self.maxIters):
                start = time.time()
                print("\n\n\nITERATION ###", iters)
                # Get the index of events for each cluster from the stacked data (complete_D_all)
                train_clusters_arr = collections.defaultdict(list)
                for point, cluster_num in enumerate(clustered_points_all):
                    train_clusters_arr[cluster_num].append(point)

                # Get the size of each cluster
                len_train_clusters = {k: len(train_clusters_arr[k]) for k in range(self.number_of_clusters)}

                opt_res = self.train_clusters(cluster_mean_info, cluster_mean_stacked_info, empirical_covariances,
                                              complete_D_all, train_clusters_arr, len_train_clusters,
                                              seq_col_size, pool)

                # Update the cluster parameters based on the ADMM learned results
                res_pri, res_dual = self.optimize_clusters(opt_res, log_det_values, computed_covariance,
                                                           train_cluster_inverse, len_train_clusters)
                res_pri_list.append(res_pri)
                res_dual_list.append(res_dual)
                print("UPDATED THE OLD COVARIANCE")

                # Set up the parameters for the E-step
                self.trained_model = {'complete_D_all': complete_D_all,
                                      'complete_D_list': complete_D_list,
                                      'cluster_mean_info': cluster_mean_info,
                                      'cluster_mean_stacked_info': cluster_mean_stacked_info,
                                      'log_det_values': log_det_values,
                                      'computed_covariance': computed_covariance,
                                      'train_cluster_inverse': train_cluster_inverse,
                                      'seq_col_size': seq_col_size}

                # Predict the cluster belongings for each timestamp
                (clustered_points_list, clustered_points_all,
                 sum_lle_list, sum_lle_all, tmp_time) = self.predict_clusters(w_intervals)
                lle_list.append(sum_lle_all)

                # Update the lengths
                new_train_clusters = collections.defaultdict(list)
                for point, cluster in enumerate(clustered_points_all):
                    new_train_clusters[cluster].append(point)
                # Get the size of each cluster
                len_new_train_clusters = {k: len(new_train_clusters[k]) for k in range(self.number_of_clusters)}
                # Make a copy of the cluster belonging before the reassignment
                before_empty_cluster_assign = clustered_points_all.copy()

                # Reassign the data to empty clusters (skipped in the very first iteration);
                # clustered_points_all and the per-cluster dictionaries are modified in place
                if iters != 0:
                    self._reassign_empty_clusters(clustered_points_all, new_train_clusters,
                                                  len_new_train_clusters, computed_covariance,
                                                  cluster_mean_info, cluster_mean_stacked_info,
                                                  complete_D_all, seq_col_size, recent_empty_flags)

                # Check the convergence of the clustering results
                # If the clustering results in current iter is the same as previous iter, break
                if np.array_equal(old_clustered_points, clustered_points_all):
                    print("\n\n\n\nCONVERGED!!! BREAKING EARLY!!!")
                    break
                old_clustered_points = before_empty_cluster_assign
                # end of training
                iters_times.append(time.time() - start)
        finally:
            # Always release the worker processes, also when an iteration raised
            pool.close()
            pool.join()

        print('*************** Avg Iter Time: ', np.mean(iters_times), '***************')

        # Plot the clustering results in a figure
        # Write the clustering results to a npy file
        self.write_plot(clustered_points_all, str_NULL)
        # NOTE(review): with sequences of different lengths this is a ragged list of arrays;
        # recent NumPy versions may refuse to convert it implicitly - confirm for multi-sequence input.
        np.save(str_NULL+'clusteringResults.npy', clustered_points_list)

        # Compute the Bayesian inference criterion
        bic = 0
        if self.compute_BIC:
            bic = computeBIC(clustered_points_list, train_cluster_inverse, empirical_covariances)

        # Intermediate data to check the convergence
        conv_list = [lle_list, res_pri_list, res_dual_list]

        # Convert window labels to the label for each point
        if self.window_pattern == 'fixed':
            label_list = [[int(tmp[0])]*(self.window_size-1) + [int(i) for i in tmp]
                          for tmp in clustered_points_list]
        else:  # 'dynamic'
            label_list = [[int(i) for i in tmp] for tmp in clustered_points_list]

        return clustered_points_list, label_list, conv_list, bic, train_cluster_inverse, empirical_covariances, complete_D_list #, np.mean(iters_times)

    def _reassign_empty_clusters(self, clustered_points_all, new_train_clusters, len_new_train_clusters,
                                 computed_covariance, cluster_mean_info, cluster_mean_stacked_info,
                                 complete_D_all, seq_col_size, recent_empty_flags):
        """
        Give every empty cluster a run of consecutive points taken from a non-empty cluster.

        Everything is updated in place: the labels in clustered_points_all, the per-cluster
        dictionaries, and recent_empty_flags (a bounded deque that records, per iteration,
        whether at least one cluster was empty).

        Parameters:
            - clustered_points_all: array with the cluster label of each stacked sample
            - new_train_clusters: dictionary cluster -> indices of its samples
            - len_new_train_clusters: dictionary cluster -> number of samples
            - computed_covariance: dictionary with the covariance for each cluster
            - cluster_mean_info / cluster_mean_stacked_info: dictionaries with the cluster means
            - complete_D_all: stacked data of all sequences
            - seq_col_size: number of features
            - recent_empty_flags: collections.deque with a maxlen
        """
        nc = self.number_of_clusters

        # Non-empty clusters, sorted by decreasing norm of their covariance
        cluster_norms = [(np.linalg.norm(computed_covariance[nc, i]), i) for i in range(nc)]
        norms_sorted = sorted(cluster_norms, reverse=True)
        valid_clusters = [cp[1] for cp in norms_sorted if len_new_train_clusters[cp[1]] != 0]

        # Record whether this iteration produced an empty cluster; the deque drops the oldest entry itself
        recent_empty_flags.append(0 in len_new_train_clusters.values())

        # Reassign data to the empty clusters (assuming more non empty clusters than empty ones)
        # NOTE(review): the history never holds more than maxlen entries, so this condition is always
        # true; if the intent is to stop after maxlen consecutive reassignments it should probably
        # be '<'. Left unchanged to keep the clustering results identical - confirm.
        reassign_flag = False
        if sum(recent_empty_flags) <= recent_empty_flags.maxlen:
            counter = 0  # counter index for the non-empty clusters
            for cluster_num in range(nc):
                if len_new_train_clusters[cluster_num] != 0:
                    continue
                reassign_flag = True
                cluster_selected = valid_clusters[counter]  # a cluster that is not len 0
                counter = (counter + 1) % len(valid_clusters)  # select non-empty cluster in a loop
                print("cluster that is zero is:", cluster_num, "selected cluster instead is:", cluster_selected)
                # select a random starting point from the selected non-empty cluster
                start_point = np.random.choice(new_train_clusters[cluster_selected])
                # move cluster_reassignment consecutive samples, starting at start_point
                for i in range(0, self.cluster_reassignment):
                    point_to_move = start_point + i
                    if point_to_move >= len(clustered_points_all):  # ran past the end of the data
                        break
                    # Reset the cluster belonging labels for the reassigned points
                    clustered_points_all[point_to_move] = cluster_num
                    # Reset the covariance, mean_stacked_info, mean_info of the reassigned cluster
                    computed_covariance[nc, cluster_num] = computed_covariance[nc, cluster_selected]
                    cluster_mean_stacked_info[nc, cluster_num] = complete_D_all[point_to_move, :]
                    cluster_mean_info[nc, cluster_num] = complete_D_all[point_to_move, :][
                        (self.window_size - 1) * seq_col_size:self.window_size * seq_col_size]

        # Display the cluster size after reassignment
        if reassign_flag:
            print('** Cluster belonging after reassignment ** ')
            for cluster_num in range(nc):
                print("length of cluster #", cluster_num, "-------->",
                      sum([x == cluster_num for x in clustered_points_all]))

    # Compute the number of matched points for each method
    def compute_matches(self, train_confusion_matrix_EM, train_confusion_matrix_GMM, train_confusion_matrix_kmeans):
        """
        Sum, for each method, the confusion-matrix entries selected by its cluster matching.

        Parameters:
            - train_confusion_matrix_EM / _GMM / _kmeans: confusion matrices indexed as [cluster, matched]

        Returns:
            - (correct_EM, correct_GMM, correct_kmeans): summed matched entries per method
            - (matching_EM, matching_GMM, matching_kmeans): the matchings returned by find_matching
        """
        # The matchings are computed in the order k-means, GMM, EM
        kmeans_matching = find_matching(train_confusion_matrix_kmeans)
        gmm_matching = find_matching(train_confusion_matrix_GMM)
        em_matching = find_matching(train_confusion_matrix_EM)

        def matched_total(confusion, matching):
            # Add up confusion[cluster, matching[cluster]] over all clusters
            return sum(confusion[cluster, matching[cluster]] for cluster in range(self.number_of_clusters))

        totals = (matched_total(train_confusion_matrix_EM, em_matching),
                  matched_total(train_confusion_matrix_GMM, gmm_matching),
                  matched_total(train_confusion_matrix_kmeans, kmeans_matching))
        return totals, (em_matching, gmm_matching, kmeans_matching)

    # Save the clustering result to display in a figure
    # Input:
    # clustered_points_all: clustering result labels
    # str_NULL: folder to save the figure
    def write_plot(self, clustered_points_all, str_NULL):
        """
        Draw the cluster label of every point and, if write_out_file is set, save the figure.

        Parameters:
            - clustered_points_all: clustering result labels
            - str_NULL: folder (with trailing separator) to save the figure in
        """
        positions = np.arange(len(clustered_points_all))
        plt.figure()
        plt.plot(positions, clustered_points_all, color="r")
        plt.ylim((-0.5, self.number_of_clusters + 0.5))

        if self.write_out_file:
            figure_name = "".join(["TRAINING_EM_lam_sparse=", str(self.lambda_parameter),
                                   "switch_penalty = ", str(self.switch_penalty), ".jpg"])
            plt.savefig(str_NULL + figure_name)

        # Closes every open pyplot figure, not only the one created above
        plt.close("all")
        print("Done writing the figure")

    # Calculate the LLE for each timestamp in sequences
    # Input:
    # - n: number of features
    # - complete_D_list: a list of array sequences
    # - cluster_mean_stacked_info: dictionary with the average event with context info
    # - log_det_values: dictionary with the log-det for each cluster
    # - train_cluster_inverse: dictionary with the inverse-covariance for each cluster
    def smoothen_clusters(self, n, complete_D_list, cluster_mean_stacked_info,
                          log_det_values, train_cluster_inverse):
        """
        Calculate the LLE of every timestamp under every cluster, sequence by sequence.

        The value stored for a (point, cluster) pair is
        (x - mean)^T * inverse_covariance * (x - mean) + log_det, where x is the stacked sample.

        Parameters:
            - n: number of features
            - complete_D_list: a list of array sequences (rows are stacked samples)
            - cluster_mean_stacked_info: dictionary with the average event with context info
            - log_det_values: dictionary with the log-det for each cluster
            - train_cluster_inverse: dictionary with the inverse-covariance for each cluster
            (the dictionaries are keyed by (number_of_clusters, cluster))

        Returns:
            - a list with one [points, number_of_clusters] array per sequence
        """
        print("Beginning the smoothing ALGORITHM")
        num_clusters = self.number_of_clusters
        stacked_dim = self.window_size * n

        LLE_all_list = []
        for sequence in complete_D_list:
            num_points = len(sequence)
            # Initialize the LLE matrix for the current sequence
            LLE_all_points_clusters = np.zeros([num_points, num_clusters])

            # The cluster parameters are only looked up when the sequence has points to score
            if num_points > 0:
                for cluster in range(num_clusters):
                    cluster_mean_stacked = cluster_mean_stacked_info[num_clusters, cluster]
                    inv_cov_matrix = train_cluster_inverse[num_clusters, cluster]
                    log_det_cov = log_det_values[num_clusters, cluster]
                    for point in range(num_points):
                        x = sequence[point, :] - cluster_mean_stacked
                        # TODO: CHECK THIS FORMULA:
                        # log_det_cov --> - log_det_cov according to the original paper
                        # LLE_all_list is the negative of the true LLE actually!
                        quad = np.dot(x.reshape([1, stacked_dim]),
                                      np.dot(inv_cov_matrix, x.reshape([stacked_dim, 1])))
                        # quad is a (1, 1) array: take the scalar explicitly instead of relying on
                        # the implicit array-to-scalar conversion, which NumPy has deprecated
                        LLE_all_points_clusters[point, cluster] = (quad + log_det_cov)[0, 0]

            LLE_all_list.append(LLE_all_points_clusters)
        return LLE_all_list

    # Update the cluster parameters based on the ADMM learned results
    # Input variables:
    # - optRes: dictionary to store the ADMM learned results
    # - log_det_values: dictionary with the log-det for each cluster
    # - computed_covariance: dictionary with the covariance for each cluster
    # - train_cluster_inverse: dictionary with the inverse-covariance for each cluster
    # - len_train_clusters: dictionary with the size for each cluster
    def optimize_clusters(self, optRes, log_det_values, computed_covariance, train_cluster_inverse, len_train_clusters):
        """
        Update the cluster parameters based on the ADMM learned results.

        Parameters:
            - optRes: list with one pending result per cluster (an object with .get() returning
              (val, res_pri, res_dual)), or None for clusters that were not trained
            - log_det_values: dictionary with the log-det for each cluster (updated in place)
            - computed_covariance: dictionary with the covariance for each cluster (updated in place)
            - train_cluster_inverse: dictionary with the inverse-covariance for each cluster (updated in place)
            - len_train_clusters: dictionary with the size for each cluster (only printed)

        Returns:
            - (cluster_res_pri, cluster_res_dual): the primal and dual residuals of the trained clusters
        """
        # Primal and dual residual values for checking the optimization convergence for each cluster
        cluster_res_pri, cluster_res_dual = [], []

        # For each cluster, update the parameters given the ADMM learned results
        for cluster in range(self.number_of_clusters):
            pending = optRes[cluster]
            if pending is None:  # cluster was not trained
                continue
            val, res_pri, res_dual = pending.get()  # Get the ADMM returned values
            cluster_res_pri.append(res_pri)
            cluster_res_dual.append(res_dual)
            print("OPTIMIZATION for Cluster #", cluster, "DONE!!!")

            # THIS IS THE SOLUTION
            X2 = upperToFull(val, 0)  # Get a symmetrical matrix based on val (inverse covariance)
            cov_out = np.linalg.inv(X2)  # inverse of X2 (inverse covariance) --> covariance

            # TODO: CHECK THIS FORMULA:
            # log(det(cov_out)) --> log(det(X2)) according to the original paper
            # slogdet avoids the overflow/underflow of det() for large matrices. A negative
            # determinant has no real logarithm, which np.log(det) reported as nan.
            sign, log_abs_det = np.linalg.slogdet(cov_out)
            log_det_cov = log_abs_det if sign >= 0 else np.nan

            # Store the log-det, covariance, inverse covariance
            log_det_values[self.number_of_clusters, cluster] = log_det_cov  # log-det of covariance
            computed_covariance[self.number_of_clusters, cluster] = cov_out  # covariance
            train_cluster_inverse[self.number_of_clusters, cluster] = X2  # inverse covariance

        # Display the cluster size
        for cluster in range(self.number_of_clusters):
            print("length of the cluster ", cluster, "------>", len_train_clusters[cluster])

        return cluster_res_pri, cluster_res_dual

    # Learn the cluster parameters
    # Input:
    # - cluster_mean_info: dictionary with the average event without context info
    # - cluster_mean_stacked_info: dictionary with the average event with context info
    # - empirical_covariances: dictionary with the empirical covariance for each cluster
    # - complete_D_stacked: stacked sequences with the context info for each timestamp
    # - train_clusters_arr: dictionary with the index of event in the stacked data for each cluster
    # - len_train_clusters: dictionary with the size of each cluster
    # - n: number of features
    def train_clusters(self, cluster_mean_info, cluster_mean_stacked_info, empirical_covariances,
                       complete_D_stacked, train_clusters_arr, len_train_clusters, n, pool):
        """
        Learn the cluster parameters: submit one ADMM problem per non-empty cluster to the pool.

        Parameters:
            - cluster_mean_info: dictionary with the average of the last block of the window (updated in place)
            - cluster_mean_stacked_info: dictionary with the average event with context info (updated in place)
            - empirical_covariances: dictionary with the empirical covariance for each cluster (updated in place)
            - complete_D_stacked: stacked sequences with the context info for each timestamp
            - train_clusters_arr: dictionary with the index of event in the stacked data for each cluster
            - len_train_clusters: dictionary with the size of each cluster
            - n: number of features
            - pool: pool whose apply_async runs the solver

        Returns:
            - a list with, per cluster, the pending result of apply_async, or None for an empty cluster
        """
        num_clusters = self.number_of_clusters
        probSize = self.window_size * n
        # Columns of the last block of the stacked window
        # NOTE(review): the original comment calls this the "current timestamp" block, but the
        # 'fixed' layout of stack_training_data puts the current sample in block 0 - confirm.
        last_block = slice((self.window_size - 1) * n, self.window_size * n)

        optRes = [None] * num_clusters
        for cluster in range(num_clusters):
            cluster_length = len_train_clusters[cluster]
            if cluster_length == 0:
                continue  # nothing to fit, the entry stays None

            # Slice the events from the stacked data belonging to the current cluster
            members = train_clusters_arr[cluster]
            rows = [members[i] for i in range(cluster_length)]
            D_train = np.zeros([cluster_length, probSize])
            D_train[:] = complete_D_stacked[rows, :]

            # Column means over the cluster: full stacked vector and (as a separate array) its last block
            stacked_mean = np.mean(D_train, axis=0)
            cluster_mean_info[num_clusters, cluster] = stacked_mean[last_block].copy().reshape([1, n])
            cluster_mean_stacked_info[num_clusters, cluster] = stacked_mean

            # Fit a model - OPTIMIZATION
            lamb = np.zeros((probSize, probSize)) + self.lambda_parameter
            S = np.cov(D_train.T, bias=self.biased)
            empirical_covariances[cluster] = S

            rho = 1
            solver = ADMMSolver(lamb, self.window_size, n, rho, S)
            # apply to process pool
            optRes[cluster] = pool.apply_async(solver, (1000, 1e-6, 1e-6, False,))
        return optRes

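    # Reshape the input sequences
    # Input variables:
    # - Data: a list of sequences, each of shape timestamps * feature_num
    # - m_list: a list of sequence lengths
    # - interval_list: a list with the time intervals of each sequence
    # - n: feature_num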
    def stack_training_data(self, Data, m_list, interval_list, n):
        """
        Build, for every sequence, the samples that carry their context information.

        Parameters:
            - Data: a list of sequences, each indexed as [timestamp][feature]
            - m_list: a list of sequence lengths
            - interval_list: a list with the time intervals of each sequence (used by 'dynamic' only)
            - n: number of features

        Returns:
            - complete_D_list: a list with one array per sequence
            - complete_D_all: all arrays of complete_D_list concatenated along the rows

        Raises:
            - ValueError: if window_pattern is neither 'fixed' nor 'dynamic'
            - AttributeError: if window_pattern is 'dynamic' and self.dynamic_window_list is not set
            - IndexError: if a sequence is shorter than its entry in m_list ('fixed' only)
        """
        # Generate the data with context info for each sequence
        complete_D_list = []

        # If window_pattern == 'fixed', each timestamp i takes its preceding window_size-1 timestamps
        # as context: block k of row i (columns k*n:(k+1)*n) holds the features of timestamp i-k,
        # so block 0 is the timestamp itself.
        if self.window_pattern == 'fixed':
            for idx in range(len(Data)):
                num_rows = m_list[idx]
                sequence = np.asarray(Data[idx])
                if len(sequence) < num_rows:
                    raise IndexError("sequence {} has {} rows but m_list says {}".format(
                        idx, len(sequence), num_rows))
                complete_D_train = np.zeros([num_rows, self.window_size * n])
                # Rows k.. get the sequence shifted down by k; rows before k keep zeros in block k
                for k in range(min(self.window_size, num_rows)):
                    complete_D_train[k:, k * n:(k + 1) * n] = sequence[:num_rows - k, 0:n]
                # Truncate the first window_size-1 samples (their window is incomplete)
                complete_D_list.append(complete_D_train[self.window_size - 1:, :])

        # If window_pattern == 'dynamic', each timestamp keeps its own features and gets one
        # context column per feature, computed by GetPreviousContextArray
        elif self.window_pattern == 'dynamic':
            # dynamic_window_list is not set by __init__; it has to be assigned by the caller
            window_per_column = getattr(self, 'dynamic_window_list', None)
            if window_per_column is None:
                raise AttributeError("dynamic_window_list must be set before using window_pattern='dynamic'")

            # Group the column indices by their customized window size (same for every sequence)
            columns_by_window = [
                (tmp_window, [idx for idx in range(len(window_per_column))
                              if window_per_column[idx] == tmp_window])
                for tmp_window in np.unique(window_per_column)
            ]

            for seqIdx in range(len(Data)):
                # Initialize the context array
                context_array = np.zeros(np.shape(Data[seqIdx]))
                # Get the context info for features with different window size
                for tmp_window, col_idx in columns_by_window:
                    context_array[:, col_idx] = GetPreviousContextArray(
                        Data[seqIdx][:, col_idx], interval_list[seqIdx],
                        self.decay_func, tmp_window, self.dynamic_attention)

                complete_D_list.append(np.concatenate((Data[seqIdx], context_array), axis=1))

        else:
            raise ValueError("window_pattern must be 'fixed' or 'dynamic', got {!r}".format(self.window_pattern))

        # Concatenate all arrays into a single one
        complete_D_all = np.concatenate(complete_D_list)
        return complete_D_list, complete_D_all

    # Generate the folder name (with critical parameters) to write the clustering results
    def prepare_out_directory(self):
        """
        Build the output folder name from the critical parameters and create the folder if missing.

        Returns:
            - the folder name, ending with "/"
        """
        str_NULL = "".join([self.prefix_string,
                            "lam_sparse=", str(self.lambda_parameter),
                            "maxClusters=", str(self.number_of_clusters + 1),
                            "/"])
        target_dir = os.path.dirname(str_NULL)

        # Nothing to do when the path is already there
        if os.path.exists(target_dir):
            return str_NULL

        try:
            os.makedirs(target_dir)
        except OSError as exc:
            # Another process may have created the folder in the meantime; anything else is an error
            if exc.errno != errno.EEXIST:
                raise
        return str_NULL

    # Load the sequential data from file
    # The data can be a single concatenated sequence or a set of sequences
    def load_data(self, input_file):
        """
        Load the sequential data from a pickle file holding the pair (Data, intervals).

        Parameters:
            - input_file: location of the pickle file

        Returns:
            - Data: the sequences as stored in the file
            - m_list: a list with the length of each sequence
            - n: number of columns of the first sequence
            - intervals: the intervals as stored in the file
        """
        # NOTE: unpickling can execute arbitrary code - only load files from a trusted source
        with open(input_file, 'rb') as filehandle:
            payload = pickle.load(filehandle)
        Data, intervals = payload

        # input_pattern only selects the message that is printed
        if self.input_pattern == 'multiple':
            print('Loading data with multiple sequences..')
        elif self.input_pattern == 'single':
            print('Loading data with a single sequence..')

        m_list = [len(Data[pos]) for pos in range(len(Data))]
        n = np.shape(Data[0])[1]
        print("completed getting the data")
        return Data, m_list, n, intervals

    # Display the parameters
    def log_parameters(self):
        """Print the critical parameters, one "label value" pair per line."""
        labelled_attributes = (("lam_sparse", "lambda_parameter"),
                               ("switch_penalty", "switch_penalty"),
                               ("num_cluster", "number_of_clusters"),
                               ("num stacked", "window_size"))
        for label, attribute in labelled_attributes:
            print(label, getattr(self, attribute))
    #
    # Predict the cluster belongings: Given the current trained model, predict clusters.
    # If the cluster segmentation has not been optimized yet, then this will be part of the iterative process.
    # - Args: a list of numpy arrays of data for which to predict clusters.
    # For each array, columns are dimensions of the data, each row is a different timestamp
    # - Returns: a list of vectors of predicted cluster for the points
    def predict_clusters(self, w_intervals_list, test_data_list = None):
        """
        Predict the cluster of every timestamp with the parameters stored in self.trained_model.

        Parameters:
            - w_intervals_list: one list of interval weights per sequence, passed to updateClusters
            - test_data_list: a list of stacked data arrays, one per sequence; defaults to the
              training data stored in self.trained_model['complete_D_list']

        Returns:
            - clustered_points_list: the path returned by updateClusters for each sequence
            - clustered_points_all: all paths concatenated
            - neg_lle_list: per sequence, the summed value of the selected cluster at each timestamp
            - neg_lle_all: sum of neg_lle_list
            - tmp_time: seconds spent updating the paths (the smoothing step is not included)
        """
        model = self.trained_model
        # If test data is not given, set it as the training data
        if test_data_list is None:
            test_data_list = model['complete_D_list']

        # SMOOTHING: Calculate the LLE for each timestamp in sequences
        lle_all_list = self.smoothen_clusters(model['seq_col_size'],
                                              test_data_list,
                                              model['cluster_mean_stacked_info'],
                                              model['log_det_values'],
                                              model['train_cluster_inverse'])

        # Update cluster points - using NEW smoothing
        start = time.time()
        clustered_points_list = []
        neg_lle_list = []
        for sIdx, lle_matrix in enumerate(lle_all_list):
            # Find the path for the sequence
            path = updateClusters(lle_matrix, w_intervals_list[sIdx],
                                  switch_penalty=self.switch_penalty)
            clustered_points_list.append(path)

            # Sum, over the timestamps, the value of the cluster the path selected
            neg_lle_list.append(sum(lle_matrix[idx][int(path[idx])]
                                    for idx in range(len(lle_matrix))))
        tmp_time = time.time() - start
        print('CHECK TIME: ', tmp_time)

        clustered_points_all = np.concatenate(clustered_points_list)
        neg_lle_all = sum(neg_lle_list)

        return clustered_points_list, clustered_points_all, neg_lle_list, neg_lle_all, tmp_time
