import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from scipy import stats
from scipy import signal
import matplotlib.pyplot as plt

import matplotlib
# Global matplotlib font settings: bold, 15 pt.
# Matplotlib has no font family named 'normal'. Requesting it only produced
# "findfont: Font family ['normal'] not found" warnings before falling back
# to the default sans-serif font, so ask for 'sans-serif' directly.
font = {'family': 'sans-serif',
        'weight': 'bold',
        'size': 15}
matplotlib.rc('font', **font)

from IGPPP import IGPPP

from simulate_poisson_process_v2 import *

import os

class get_powerset:
    """
    Enumerate the non-empty subsets of ``N`` items up to a given size.

    Each subset is encoded as a length-``N`` boolean mask. After
    construction, ``self.C_vec`` holds every mask with between 1 and
    ``order`` True entries. The masks are grouped by subset size (all
    size-1 masks, then all size-2 masks, ...). Within one size they are in
    lexicographic order of the selected indices.

    Parameters:
    -----------
    N : int
        Number of items (length of every mask).
    order : int
        Largest subset size to generate.

    Attributes:
    -----------
    C_vec : np.ndarray of bool, shape (num_subsets, N)
        One row per generated subset. Its shape is (0, N) when nothing is
        generated (e.g. ``order < 1``).
    """

    def __init__(self, N, order):
        self.N = N
        self.order = order
        self._get_nodes()

    def _activate_nodes(self, C, ci, N, depth):
        """
        Recursively collect every mask that extends ``C`` with exactly
        ``depth`` more True entries, all placed at indices >= ``ci``.

        Parameters:
        -----------
        C : array-like of bool, length N
            Mask built so far. It is copied, never modified in place.
        ci : int
            Smallest index that may still be switched on. Keeping indices
            strictly increasing ensures each subset is produced once.
        N : int
            Length of the mask.
        depth : int
            Number of entries still to switch on.
        Returns:
        --------
        None
            Completed masks are appended to ``self.C_vec``.
        """
        if depth == 0:
            self.C_vec.append(C)
            return

        # Stop early enough to leave room for the remaining depth-1 indices.
        for c in range(ci, N - depth + 1):
            # Copy so that sibling branches do not see this branch's choice.
            C_current = np.array(C, dtype=bool)
            C_current[c] = True
            self._activate_nodes(C_current, c + 1, N, depth - 1)

    def _get_nodes(self):
        """
        Populate ``self.C_vec`` with all masks of size 1..``self.order``.

        Parameters:
        -----------
        None
        Returns:
        --------
        None
            Sets ``self.C_vec`` to a boolean array of shape
            (num_subsets, N).
        """
        N = self.N
        self.C_vec = list()
        for depth in range(1, self.order + 1):
            self._activate_nodes(np.zeros(N, dtype=bool), 0, N, depth)
        if self.C_vec:
            self.C_vec = np.array(self.C_vec, dtype=bool)
        else:
            # np.array([]) would be 1-D; keep the (num_subsets, N) shape
            # so callers can always iterate rows / index columns.
            self.C_vec = np.zeros((0, N), dtype=bool)

# Output roots, relative to the current working directory.
# os.makedirs(exist_ok=True) creates each directory on first run. It is also
# a no-op when the directory exists, which avoids the race between a separate
# os.path.exists check and os.mkdir.
results_path = '../results/'
os.makedirs(results_path, exist_ok=True)

store_data_path = '../store_data/'
os.makedirs(store_data_path, exist_ok=True)

from simulate_poisson_process_v2 import *

# Seed NumPy's global RNG so that runs are reproducible
# (assumes the sampling helpers below draw from np.random -- TODO confirm).
np.random.seed(1)

# Simulation settings.
T_max = 10  # passed to every sampling helper; presumably the time horizon
d = 2       # presumably the sample dimension (observations are exported as Feature1/Feature2)
N = 20      # passed to create_mixture_of_pdf; presumably the number of mixture components

n_draws = 100000  # number of samples in each of the two draws below

# NOTE(review): scipy_multivar_norm is not used elsewhere in this file.
# It is kept because the call may consume random numbers, and removing it
# would then change every draw that follows.
scipy_multivar_norm = get_scipy_multivar_gauss(T_max, d)
scipy_mixture = create_mixture_of_pdf(T_max, d, N, get_scipy_multivar_gauss)

# Two separate draws from the same mixture. ``samples`` is the training data.
# In this file ``samples_for_pdf`` is only used to build the empirical
# reference pdf.
samples = draw_samples(scipy_mixture, T_max, n_draws)
samples_for_pdf = draw_samples(scipy_mixture, T_max, n_draws)

# Experiment-specific output directories under the results / store roots.
# exist_ok=True makes reruns a no-op and avoids a check-then-create race.
two_dimension_dense_path = results_path + 'two_dimension_dense/'
os.makedirs(two_dimension_dense_path, exist_ok=True)

two_dimension_dense_store_path = store_data_path + 'two_dimension_dense/'
os.makedirs(two_dimension_dense_store_path, exist_ok=True)

# Persist the raw observations, one column per sample dimension
# (Feature1, Feature2, ...). Column names are derived from the width of
# ``samples`` rather than hard-coded, so the export does not break if the
# sample dimension changes.
feature_columns = ['Feature{}'.format(i + 1) for i in range(np.shape(samples)[1])]
pd.DataFrame(samples, columns=feature_columns).to_csv(
    '{}Observations.csv'.format(two_dimension_dense_store_path), index=False)

# Sweep over the IGPPP ``h`` argument. It is passed together with
# kernel_estimator=True, so it is presumably the kernel bandwidth -- TODO confirm.
# For each h, fit an order-1 and an order-2 model. Then, for every feature
# subset, write the two models' intensities and the binned event counts to CSV.
h_list = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2]

# This value is passed to both IGPPP models and to _count_observations_in_bin.
# The order-1 model's bin_times are also used as the time axis for the
# order-2 intensity, so all uses must share one value.
# NOTE(review): presumably the time-bin width -- confirm against IGPPP.
BIN_SIZE = 0.1


def _fit_igppp(order, h):
    """Build and train an IGPPP model of the given ``order`` for this ``h``."""
    model = IGPPP(samples, BIN_SIZE, 10, order=order, kernel_estimator=True, h=h)
    model.train(100, natural_gradient=True, verbose=True, verbose_step=1)
    return model


def _subset_label(mask):
    """Return the 1-based indices of the True entries of ``mask`` joined
    into a string, e.g. [True, True] -> '12', [False, True] -> '2'."""
    return ''.join(str(i + 1) for i in np.flatnonzero(mask))


# Neither of these depends on h or on the feature subset, so compute them once.
C_vec = get_powerset(2, 2).C_vec  # boolean masks over 2 features: {1}, {2}, {1, 2}
b, a = signal.butter(3, 0.05)     # 3rd-order low-pass used to smooth the empirical pdf

for h in h_list:
    IGPPP_model_order1 = _fit_igppp(1, h)
    IGPPP_model_order2 = _fit_igppp(2, h)

    two_dimension_dense_path = results_path + 'two_dimension_dense/h{}/'.format(h)
    os.makedirs(two_dimension_dense_path, exist_ok=True)

    two_dimension_dense_store_path = store_data_path + 'two_dimension_dense/h{}/'.format(h)
    os.makedirs(two_dimension_dense_store_path, exist_ok=True)

    for idx in C_vec:
        plt.clf()
        process_store_path = two_dimension_dense_store_path + 'p{}/'.format(_subset_label(idx))
        os.makedirs(process_store_path, exist_ok=True)
        ground_truth_list = list()
        intensity_list = list()

        # Reference curve: smoothed empirical pdf for this subset. It is
        # sub-sampled every 10th point and rescaled so it sums to the
        # order-2 model's count for the subset.
        pdf, bins = get_emperical_pdf(samples_for_pdf, T_max, int(100000) + 1, idx)
        pdf = signal.filtfilt(b, a, pdf)
        counts = IGPPP_model_order2.get_count(idx)[0]
        if counts == 0:
            # A zero count would make the reference curve identically zero.
            counts = 1
        ground_truth_intensity = pdf[4::10] / np.sum(pdf[4::10]) * counts
        ground_truth_bin = bins[4::10]

        C_event_count_list, C_time_list = IGPPP_model_order2._count_observations_in_bin(samples, BIN_SIZE, [idx])

        order1_intensity = IGPPP_model_order1.get_intensity(idx)
        order2_intensity = IGPPP_model_order2.get_intensity(idx)

        ground_truth_list.append(np.squeeze(ground_truth_bin))
        ground_truth_list.append(np.squeeze(ground_truth_intensity))
        intensity_list.append(np.squeeze(IGPPP_model_order1.bin_times))
        intensity_list.append(np.squeeze(order1_intensity))
        intensity_list.append(np.squeeze(order2_intensity))

        # NOTE(review): the ground truth is computed above but its export is disabled.
        # pd.DataFrame(np.array(ground_truth_list).T, columns=['Time', 'Ground Truth']).to_csv('{}ground_truth.csv'.format(process_store_path), index=False)
        pd.DataFrame(np.array(intensity_list).T, columns=['Time', 'Order1', 'Order2']).to_csv('{}intensity.csv'.format(process_store_path), index=False)
        pd.DataFrame(np.array([np.squeeze(C_time_list), np.squeeze(C_event_count_list)]).T, columns=['Event Time', 'Event Count']).to_csv('{}event_time.csv'.format(process_store_path), index=False)

        # NOTE(review): the plotting code below is disabled. If re-enabled, note
        # that `event_time` is not defined anywhere in this script; the event
        # times written above are in C_time_list.
        # plt.plot(IGPPP_model_order1.bin_times, order1_intensity, linewidth=3)
        # plt.plot(IGPPP_model_order2.bin_times, order2_intensity, '--', linewidth=3)

        # pdf, bins = get_emperical_pdf(samples_for_pdf, T_max, int(1e3) + 1, idx)
        # b, a = signal.butter(3, 0.05)
        # pdf = signal.filtfilt(b, a, pdf)
        # plt.plot(ground_truth_bin, ground_truth_intensity, 'k')
        # # plt.plot(samples[:,2], np.zeros(len(samples[:,2])), '.')
        # plt.plot(event_time, np.zeros(len(event_time)), '.k')
        # plt.legend(['Order: 1', 'Order: 2', 'Ground Truth'], ncol=2)
        # plt.xlabel('Time', fontsize=16, fontweight='bold')
        # plt.ylabel('Intensity', fontsize=16, fontweight='bold')
        # print('idx:', idx)
        # plt.tight_layout()
        # plt.savefig('{}two_dimension_dense_process{}_h{}.eps'.format(two_dimension_dense_path, ''.join(np.array(np.where(idx)[0] + 1, dtype='str')), str(h).replace(".", "")), format='eps')
        # plt.savefig('{}two_dimension_dense_process{}_h{}.png'.format(two_dimension_dense_path, ''.join(np.array(np.where(idx)[0] + 1, dtype='str')), str(h).replace(".", "")), format='png', dpi=1200)
        # # plt.show()


