import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from scipy import stats
from scipy import signal
import matplotlib.pyplot as plt

import matplotlib
# Default font for every matplotlib figure: bold, 15 pt.
# No 'family' key: 'normal' is not a font family matplotlib can resolve (it
# only logs "findfont: Font family ['normal'] not found" and falls back to the
# default family), so omitting it renders identically without the warnings.
font = {'weight': 'bold',
        'size': 15}
matplotlib.rc('font', **font)

from IGPPP import IGPPP

from simulate_poisson_process import *

import os

class get_powerset:
    """
    Enumerate the non-empty subsets of {0, ..., N-1} with at most `order`
    elements, each encoded as a boolean mask of length N.

    After construction, `C_vec` is a boolean array of shape (n_subsets, N).
    Subsets are grouped by size (1, 2, ..., order). Within one size they are
    ordered lexicographically by their sorted index tuples. If `order` < 1 the
    array is empty with shape (0, N).

    Parameters:
    -----------
    N : int
        Number of elements (length of every mask).
    order : int
        Largest subset size to include; sizes above N contribute nothing.
    """

    def __init__(self, N, order):
        self.N = N
        self.order = order
        self._get_nodes()

    def _activate_nodes(self, C, ci, N, depth):
        """
        Append to `self.C_vec` every mask obtained from `C` by switching on
        exactly `depth` additional entries chosen from indices ci, ..., N-1.
        Masks are appended in lexicographic order of the chosen indices.
        Parameters:
        -----------
        C : array_like of bool, length N
            Starting mask. It is copied and never modified.
        ci : int
            Smallest index that may be switched on.
        N : int
            Length of the mask.
        depth : int
            Number of additional entries to switch on.
        Returns:
        --------
        None
        """
        from itertools import combinations

        base = np.array(C, dtype=bool)
        for chosen in combinations(range(ci, N), depth):
            mask = base.copy()
            mask[list(chosen)] = True
            self.C_vec.append(mask)

    def _get_nodes(self):
        """
        Populate `self.C_vec` with the masks of all subsets of size
        1..`self.order`.
        Parameters:
        -----------
        None
        Returns:
        --------
        None
        """
        N = self.N
        self.C_vec = list()
        for depth in range(1, self.order + 1):
            self._activate_nodes(np.zeros(N, dtype=bool), 0, N, depth)
        if self.C_vec:
            self.C_vec = np.array(self.C_vec, dtype=bool)
        else:
            # Keep a 2-D (0, N) shape so callers can still rely on .shape[1].
            self.C_vec = np.zeros((0, N), dtype=bool)

# Root output directories, relative to the current working directory.
# The CSV files written by this script go under store_data_path.
# makedirs(exist_ok=True) avoids the check-then-create race and also creates
# any missing parent directories.
results_path = '../results/'
os.makedirs(results_path, exist_ok=True)

store_data_path = '../store_data/'
os.makedirs(store_data_path, exist_ok=True)

from simulate_poisson_process_v2 import *

# Seed NumPy's global RNG so the simulated observations are reproducible.
np.random.seed(1)
# Simulation settings passed to the simulate_poisson_process_v2 helpers below.
T_max = 10
d = 4
N = 50

# Not used further in this script. Kept because constructing it may draw
# random numbers; removing it could change the samples drawn after it
# (TODO confirm whether get_scipy_multivar_gauss consumes RNG state).
scipy_multivar_norm = get_scipy_multivar_gauss(T_max, d)
scipy_mixture = create_mixture_of_pdf(T_max, d, N, get_scipy_multivar_gauss)
samples = draw_samples(scipy_mixture, T_max, 100000)
# Larger sample for an empirical pdf estimate (disabled):
# samples_for_pdf = draw_samples(scipy_mixture, T_max, int(1e8))

higher_dimension_sparse_path = results_path + 'higher_dimension_sparse/'
os.makedirs(higher_dimension_sparse_path, exist_ok=True)

higher_dimension_sparse_store_path = store_data_path + 'higher_dimension_sparse/'
os.makedirs(higher_dimension_sparse_store_path, exist_ok=True)

# One 'FeatureK' column (K is 1-based) per column of `samples`. Derived from
# the data instead of hard-coded, so a different d does not break the export.
feature_columns = ['Feature{}'.format(k + 1) for k in range(np.shape(samples)[1])]
pd.DataFrame(samples, columns=feature_columns).to_csv(
    '{}Observations.csv'.format(higher_dimension_sparse_store_path), index=False)


# Kernel bandwidths to sweep; a fresh set of models is trained for each one.
h_list = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
          0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4,
          1.5, 1.6, 1.7, 1.8, 1.9, 2]

# Model orders compared at every bandwidth: 1..d (orders 1-4 for d = 4).
model_orders = list(range(1, d + 1))

# Boolean masks of every non-empty subset of the d features. They do not
# depend on h, so they are built once rather than inside the sweep.
C_vec = get_powerset(d, d).C_vec

# intensity.csv header: the time axis followed by one column per model order.
intensity_columns = ['Time'] + ['Order{}'.format(order) for order in model_orders]

for h in h_list:
    # One model per order, identical apart from `order`, trained in
    # increasing order of `order`.
    models = {}
    for order in model_orders:
        model = IGPPP(samples, 0.1, 10, order=order, kernel_estimator=True, h=h)
        model.train(100, natural_gradient=True, verbose=True, verbose_step=1)
        models[order] = model
    # Observation counts come from the highest-order model; the time axis
    # comes from the lowest-order one.
    count_model = models[model_orders[-1]]
    time_model = models[model_orders[0]]

    # Per-bandwidth output directories. Nothing is currently written to the
    # results directory (presumably reserved for figures -- TODO drop if unused).
    higher_dimension_sparse_path = results_path + 'higher_dimension_sparse/h{}/'.format(h)
    os.makedirs(higher_dimension_sparse_path, exist_ok=True)

    higher_dimension_sparse_store_path = store_data_path + 'higher_dimension_sparse/h{}/'.format(h)
    os.makedirs(higher_dimension_sparse_store_path, exist_ok=True)

    for idx in C_vec:
        # Subset directory named by its 1-based feature indices,
        # e.g. mask [True, False, True, False] -> 'p13/'.
        process_name = ''.join(np.array(np.where(idx)[0] + 1, dtype=str))
        process_store_path = higher_dimension_sparse_store_path + 'p{}/'.format(process_name)
        os.makedirs(process_store_path, exist_ok=True)

        C_event_count_list, C_time_list = count_model._count_observations_in_bin(samples, 0.1, [idx])

        order_intensities = [np.squeeze(models[order].get_intensity(idx)) for order in model_orders]
        intensity_list = [np.squeeze(time_model.bin_times)] + order_intensities

        pd.DataFrame(np.array(intensity_list).T, columns=intensity_columns).to_csv(
            '{}intensity.csv'.format(process_store_path), index=False)
        event_table = np.array([np.squeeze(C_time_list), np.squeeze(C_event_count_list)]).T
        pd.DataFrame(event_table, columns=['Event Time', 'Event Count']).to_csv(
            '{}event_time.csv'.format(process_store_path), index=False)