import numpy as np
import pandas as pd
from scipy.special import comb
try:
    import cupy as cp
except:
    pass
import matplotlib.pyplot as plt
from PIL import Image
from skimage.transform import rescale
import os
import sys

import time
from sklearn import decomposition

from latent_linear_model import latent_linear_model

def activate_nodes(C_vec, C, ci, N, depth):
    """
    Recursively enumerate boolean masks with ``depth`` more nodes switched on.

    Starting from the partial mask ``C``, every way of setting ``depth``
    additional positions (all at indices ``>= ci``) to True is generated in
    lexicographic order of the chosen indices, and each completed mask is
    appended to ``C_vec``.

    Parameters:
    -----------
    C_vec: list
        Output accumulator; completed masks are appended to it in place.
    C: numpy.ndarray
        Boolean mask of length ``N`` holding the nodes activated so far.
    ci: int
        Smallest index that may still be activated.
    N: int
        Length of the mask.
    depth: int
        Number of nodes that still have to be activated (recursion depth).

    Returns:
    --------
    None
    """
    if depth == 0:
        C_vec.append(C)
        return

    # The upper bound leaves room for the remaining ``depth - 1`` activations.
    for c in range(ci, N - depth + 1):
        # Copy so sibling branches of the recursion never share a mask.
        C_current = np.array(C, dtype=bool)
        C_current[c] = True
        activate_nodes(C_vec, C_current, c + 1, N, depth - 1)

def get_higher_order_nodes(N, order):
    """
    Return every boolean node mask with between 1 and ``order`` active nodes.

    Rows are grouped by the number of active nodes (1, 2, ..., order). Within a
    group they follow the lexicographic order of the active indices, which is
    the same order the recursive ``activate_nodes`` helper produces.

    Parameters:
    -----------
    N: int
        Number of nodes (or elements)
    order: int
        Maximum order of interaction

    Returns:
    --------
    C_vec: numpy.ndarray
        Boolean array of shape ``(K, N)``, one combination per row, with
        ``K = sum_{d=1}^{order} C(N, d)`` (terms with ``d > N`` contribute 0).
        When ``K`` is 0 (e.g. ``order <= 0``) the array still has ``N`` columns.
    """
    import itertools

    N = int(N)
    rows = []
    for d in range(1, int(order) + 1):
        for active in itertools.combinations(range(N), d):
            mask = np.zeros(N, dtype=bool)
            mask[list(active)] = True
            rows.append(mask)
    # Reshape so an empty result is (0, N) rather than a shapeless (0,) array.
    return np.array(rows, dtype=bool).reshape(len(rows), N)

def get_scaled_image(path, num_pixels):
    """
    Load an image and rescale it so its first (height) axis has ``num_pixels`` entries.

    The last axis is treated as the colour-channel axis, so only the spatial
    axes are rescaled.

    Parameters:
    -----------
    path: str
        Path to an image file readable by PIL.
    num_pixels: int
        Target size of the first axis.

    Returns:
    --------
    X_scaled: numpy.ndarray
        The rescaled image as returned by ``skimage.transform.rescale``.
    """
    # The context manager releases the file handle PIL keeps open.
    with Image.open(path) as X_PIL:
        X = np.array(X_PIL)
    scale = num_pixels / X.shape[0]
    try:
        # Current scikit-image spells the channel option ``channel_axis``.
        X_scaled = rescale(X, scale, channel_axis=-1)
    except TypeError:
        # Older scikit-image only understands the ``multichannel`` flag.
        X_scaled = rescale(X, scale, multichannel=True)
    return X_scaled

def mixed_signal(X, A, N, order):
    """
    Form higher-order interaction terms of the sources and mix them with ``A``.

    For every node mask from ``get_higher_order_nodes(N, order)``, the selected
    rows of ``X`` are multiplied element-wise. The stack of these products is
    left-multiplied by ``A``, each mixture is reshaped to a square 3-channel
    image, and the whole result is divided by its global maximum.

    Parameters:
    -----------
    X: numpy.ndarray
        One flattened source per row, i.e. shape ``(N, P)``. ``P`` must equal
        ``3 * side**2`` for some integer ``side``.
    A: numpy.ndarray
        Mixing matrix with one column per interaction term.
    N: int
        Number of sources.
    order: int
        Maximum interaction order.

    Returns:
    --------
    M: numpy.ndarray
        Mixed images of shape ``(A.shape[0], side, side, 3)``.

    Raises:
    -------
    ValueError
        If ``P`` is not ``3 * side**2`` for an integer ``side``.
    """
    higher_order_list = get_higher_order_nodes(N, order)
    # Size the term matrix from the masks themselves, so the row count always
    # matches the number of masks actually generated.
    M = np.zeros((len(higher_order_list), X.shape[1]))
    for i, hoe in enumerate(higher_order_list):
        M[i, :] = np.prod(X[hoe, :], axis=0)
    M = np.dot(A, M)

    n_values = M.shape[-1]
    side = int(round((n_values / 3) ** 0.5))
    if 3 * side * side != n_values:
        raise ValueError(
            'Each mixture must hold 3 * side**2 values, got {}'.format(n_values))
    M = M.reshape(M.shape[0], side, side, 3)
    M /= np.max(M)
    return M

def preprocess_output(X):
    """
    Convert array data to ``uint8`` pixel values for saving as an image.

    The input is copied and, if it contains negative values, shifted so its
    minimum is 0. It is then scaled so its maximum maps to 254 and truncated
    to ``uint8``. An input whose (shifted) maximum is 0 yields an all-zero
    image rather than the NaN garbage a division by zero would produce.

    Parameters:
    -----------
    X: array_like
        Image data. The caller's array is not modified.

    Returns:
    --------
    numpy.ndarray
        ``uint8`` array with the same shape as ``X``.
    """
    X = np.array(X)  # copy, so the caller's data is never modified in place
    if np.min(X) < 0:
        X -= np.min(X)
    max_value = np.max(X)
    if max_value == 0:
        # Nothing to normalise; avoid 0/0 -> NaN -> undefined uint8 values.
        return np.zeros(X.shape, dtype=np.uint8)
    X = np.array((X / max_value) * 254, dtype=np.uint8)
    return X

# Ground-truth source images, in the order they are assigned to sources 0, 1, ...
_SOURCE_IMAGE_PATHS = (
    './BSS_data/4.2.05.tiff',  # Airplane (F-16)   (alternative: 4.2.03.tiff Mandrill)
    './BSS_data/4.2.06.tiff',  # Lake              (alternative: 4.2.01.tiff Splash)
    './BSS_data/4.2.07.tiff',  # Capsicum          (alternative: 4.1.08.tiff Jelly beans)
    './BSS_data/4.2.03.tiff',  # Mandrill
    './BSS_data/4.2.01.tiff',  # Splash
    './BSS_data/4.1.08.tiff',  # Jelly beans
    './BSS_data/4.1.07.tiff',
)


def create_mixed_signal(num_pixels, A, N, order, save_path):
    """
    Load the ground-truth source images, save them, and return their mixtures.

    The first ``N`` entries of ``_SOURCE_IMAGE_PATHS`` are rescaled to
    ``num_pixels`` rows. Files written under ``save_path``:

    * ``mixing.csv`` -- the mixing matrix ``A``;
    * ``Ground_Truth/Ground_Truth_<i>.png`` -- each rescaled source image.

    Parameters:
    -----------
    num_pixels: int
        Target height of each source image.
    A: numpy.ndarray
        Mixing matrix passed on to ``mixed_signal``.
    N: int
        Number of sources; must be between 1 and ``len(_SOURCE_IMAGE_PATHS)``.
    order: int
        Maximum interaction order.
    save_path: str
        Output directory (created if missing).

    Returns:
    --------
    M: numpy.ndarray
        Result of ``mixed_signal`` on the stacked, flattened sources.

    Raises:
    -------
    ValueError
        If ``N`` is outside the supported range.
    """
    if not 1 <= N <= len(_SOURCE_IMAGE_PATHS):
        raise ValueError('N must be between 1 and {}, got {}'.format(
            len(_SOURCE_IMAGE_PATHS), N))

    os.makedirs(save_path, exist_ok=True)
    pd.DataFrame(A).to_csv("{}/mixing.csv".format(save_path), index=False)

    sources = [get_scaled_image(path, num_pixels) for path in _SOURCE_IMAGE_PATHS[:N]]

    # Derive the directory and the file names from the same path, so they
    # agree even when save_path is absolute.
    ground_truth_dir = os.path.join(save_path, 'Ground_Truth')
    os.makedirs(ground_truth_dir, exist_ok=True)
    for i, source in enumerate(sources):
        Image.fromarray(preprocess_output(source)).save(
            os.path.join(ground_truth_dir, 'Ground_Truth_{}.png'.format(i)), format='png')

    X = np.array([source.flatten() for source in sources])
    return mixed_signal(X, A, N, order)

def run_experiment(num_pixels, A, N, order, natural_gradient, save_path, n_proc):
    """
    Build the mixed images, fit the ICA latent linear model, and save the result.

    Natural-gradient training runs 50 iterations with progress reported every
    iteration. Plain gradient training runs 100000 iterations with progress
    reported every 100 iterations. The model is fitted with intermediate
    results stored at tolerances 1e-2 down to 1e-5.
    """
    mixed = create_mixed_signal(num_pixels, A, N, order, save_path)

    iterations, verbose_step = (50, 1) if natural_gradient else (100000, 100)

    ica = latent_linear_model(
        mixed,
        N,
        model='ica',
        order=order,
        multi_channel=True,
        compute_partition_function=True,
        D_eps=0,
        device='gpu',
        n_proc=n_proc,
    )
    ica.fit(
        lr=1,
        iterations=iterations,
        natural_gradient=natural_gradient,
        tol=1e-7,
        verbose=True,
        verbose_step=verbose_step,
        store_values=True,
        save_at_tol=[1e-2, 1e-3, 1e-4, 1e-5],
        save_path=save_path,
    )
    ica.save_result(save_path)

def create_random_mixing_matrix(N, order, NM, seed=None):
    """
    Draw a random mixing matrix with entries uniformly distributed in [1, 6).

    Parameters:
    -----------
    N: int
        Number of sources.
    order: int
        Maximum interaction order. The matrix has one column per interaction
        term, i.e. ``sum_{d=1}^{order} C(N, d)`` columns.
    NM: int
        Number of mixtures (rows).
    seed: int or None
        If given, NumPy's global RNG is seeded with this value first, so the
        matrix is reproducible. ``None`` leaves the global RNG state untouched.

    Returns:
    --------
    A: numpy.ndarray
        Array of shape ``(NM, C)``.
    """
    C = int(np.sum(comb(N, np.arange(order) + 1)))
    if seed is not None:
        # Honour the caller's seed; callers passing seed=1 get the same matrix as before.
        np.random.seed(seed)
    A = 1 + np.random.uniform(0, 5, (NM, C))
    return A

def igbss(num_pixels, N, order, NM, natural_gradient, n_proc, name_string):
    """
    Run one ICA blind-source-separation experiment and print where it is saved.

    The output directory is named after the experiment parameters. It lives
    directly under ``./results/`` when ``name_string`` is the empty string,
    otherwise under ``./results/run<name_string>/``. The mixing matrix is
    always drawn with seed 1.
    """
    experiment_args = (num_pixels, N, order, NM, natural_gradient)
    if name_string != '':
        save_path = './results/run{}/ICA_experiment_pixel{}_N{}_order{}_NM{}_ng{}'.format(
            name_string, *experiment_args)
    else:
        save_path = './results/ICA_experiment_pixel{}_N{}_order{}_NM{}_ng{}'.format(
            *experiment_args)
    print(save_path)

    mixing = create_random_mixing_matrix(N=N, order=order, NM=NM, seed=1)
    run_experiment(num_pixels, mixing, N, order, natural_gradient, save_path, n_proc)

def _random_state():
    """Return a time-derived seed in ``[0, 2**32)`` so each estimator gets a new initialisation."""
    return time.time_ns() % 2**32


def comparison(num_pixels, N, order, NM, name_string):
    """
    Fit baseline scikit-learn decompositions to the mixed images and save their components.

    The mixtures are built exactly as for the ICA experiments (mixing matrix
    seeded with 1) and flattened to one row per mixture. They are then
    decomposed with PCA, NMF (two initialisations), FastICA, factor analysis
    and two mini-batch dictionary-learning variants. The first ``N``
    components of each method are saved as PNGs under
    ``./results/run<name_string>/comparisons/<method>_pixel.._N.._order.._NM../``.

    Estimator seeds come from the current time, so repeated calls use
    different, non-reproducible initialisations.

    Parameters:
    -----------
    num_pixels: int
        Height of the source images after rescaling.
    N: int
        Number of sources, which is also the number of components requested.
    order: int
        Maximum interaction order used to build the mixtures.
    NM: int
        Number of mixtures.
    name_string: str
        Run identifier used in the output path.
    """
    save_path = './results/run{}/comparisons/'.format(name_string)
    print(save_path)
    os.makedirs(save_path, exist_ok=True)

    A = create_random_mixing_matrix(N=N, order=order, NM=NM, seed=1)
    M = create_mixed_signal(num_pixels, A, N, order, save_path)
    image_shape = M[0].shape
    M = np.array([m.flatten() for m in M])  # one flattened mixture per row

    estimators = [
        ('Eigenfaces - PCA using randomized SVD',
         decomposition.PCA(n_components=N, svd_solver='randomized',
                           random_state=_random_state(), whiten=True),
         'PCA'),
        ('Non-negative components - NMF',
         decomposition.NMF(n_components=N, init=None, max_iter=1000, tol=5e-3,
                           random_state=_random_state()),
         'NMF'),
        ('Non-negative components - NMF',
         decomposition.NMF(n_components=N, init='nndsvda', max_iter=1000, tol=5e-3,
                           random_state=_random_state()),
         'NMF-NNDSVDA'),
        ('Independent components - FastICA',
         decomposition.FastICA(n_components=N, algorithm='deflation', whiten=True,
                               max_iter=1000, tol=1e-7, random_state=_random_state()),
         'FastICA'),
        ('Factor Analysis components - FA',
         decomposition.FactorAnalysis(n_components=N, max_iter=2000,
                                      random_state=_random_state()),
         'FA'),
        # NOTE(review): ``n_iter`` is deprecated in newer scikit-learn in favour
        # of ``max_iter`` -- confirm against the installed version.
        ('Dictionary learning',
         decomposition.MiniBatchDictionaryLearning(n_components=N, alpha=0.1,
                                                   n_iter=50, batch_size=3,
                                                   random_state=_random_state()),
         'dl'),
        ('Dictionary learning - positive dictionary & code',
         decomposition.MiniBatchDictionaryLearning(n_components=N, alpha=0.1,
                                                   n_iter=50, fit_algorithm='cd',
                                                   batch_size=3,
                                                   random_state=_random_state(),
                                                   positive_dict=True,
                                                   positive_code=True),
         'dl_positive'),
    ]

    for name, estimator, save_string in estimators:
        t0 = time.time()
        estimator.fit(M)
        print('{}: fitted in {:.2f}s'.format(name, time.time() - t0))

        if hasattr(estimator, 'cluster_centers_'):
            components_ = estimator.cluster_centers_
        else:
            components_ = estimator.components_

        # Work on a float copy so the fitted estimator's attributes stay intact.
        components_ = np.array(components_, dtype=float)
        if np.min(components_) < 0:
            components_ -= np.min(components_)
        max_value = np.max(components_)
        if max_value > 0:
            components_ /= max_value

        directory = '{}/{}_pixel{}_N{}_order{}_NM{}/'.format(
            save_path, save_string, num_pixels, N, order, NM)
        os.makedirs(directory, exist_ok=True)
        for i, c in enumerate(components_[:N]):
            Image.fromarray(preprocess_output(c.reshape(image_shape))).save(
                '{}/Components_Reconstructed_{}.png'.format(directory, i), format='png')

if __name__ == '__main__':
    import itertools

    def _parse_common_args():
        """Parse ``power_of_2_pixel N order NM`` from argv[2:6] into ``(num_pixels, N, order, NM)``."""
        try:
            power_of_2_pixel = int(sys.argv[2])  # e.g. 5
            N = int(sys.argv[3])  # e.g. 3
            order = int(sys.argv[4])  # e.g. 1
            NM = int(sys.argv[5])  # e.g. 3
        except (IndexError, ValueError) as err:
            raise SystemExit(
                'Expected integer arguments: power_of_2_pixel N order NM') from err
        return 2**power_of_2_pixel, N, order, NM

    def _sweep(pixel_powers, Ns, orders, NMs, name_strings):
        """Run ``comparison`` once for every combination of the given settings."""
        for power, N, order, NM, name_string in itertools.product(
                pixel_powers, Ns, orders, NMs, name_strings):
            comparison(2**power, N, order, NM, name_string)

    if len(sys.argv) < 2:
        raise SystemExit(
            'Usage: {} igbss|comparison|comparison_all [arguments]'.format(sys.argv[0]))
    model = sys.argv[1]

    if model == 'igbss':
        num_pixels, N, order, NM = _parse_common_args()
        natural_gradient = sys.argv[6]  # "True" or "False"
        n_proc = int(sys.argv[7])  # e.g. 20
        name_string = sys.argv[8] if len(sys.argv) > 8 else ''

        if natural_gradient == "True":
            natural_gradient = True
        elif natural_gradient == "False":
            natural_gradient = False
        else:
            raise ValueError('Please specify "True" or "False" for natural gradient')

        igbss(num_pixels, N, order, NM, natural_gradient, n_proc, name_string)
    elif model == 'comparison':
        num_pixels, N, order, NM = _parse_common_args()
        name_string = sys.argv[6]
        comparison(num_pixels, N, order, NM, name_string)
    elif model == 'comparison_all':
        # Sweep 1: vary the interaction order (N = NM = 6, 32-pixel images).
        _sweep([5], [6], [1, 2, 3, 4, 5, 6], [6], ['order'])
        # Sweep 2: vary the image height from 2**3 to 2**7 pixels.
        _sweep([3, 4, 5, 6, 7], [3], [1], [3], ['scale'])
        # Sweep 3: repeat every (order, NM) setting 40 times (runs var0..var39).
        _sweep([5], [3], [1, 2, 3], [3, 4, 5, 6, 9, 15, 22, 30],
               ['var{}'.format(i) for i in range(40)])
    else:
        raise ValueError(
            'For the model, please choose "igbss", "comparison" or "comparison_all"')
