import numpy  as np

from IsserlistPartitions import get_isserlis_partitions

from monomial_coefficients import get_monomial_exponents
from monomial_coefficients import exp2index_mapping
from monomial_coefficients import combine_count_split
from monomial_coefficients import get_multinomial_exponents_and_coefficients

class PASS(object):
    """Precompute index tables from monomial exponents and Isserlis partitions.

    NOTE(review): the overall purpose is inferred from names and the visible
    degree bookkeeping; confirm against the consumers of these tables.

    The monomial exponents returned by ``get_monomial_exponents(d, M)``
    (``M`` fixed to 2) are combined with each other, with the degree-3
    multinomial exponents and with the unit vectors of ``np.eye(d)``.

    - For a combination of total degree 2, the output of
      ``exp2index_mapping(..., 1)`` is stored directly in ``t_m``, ``t_m_m``
      or ``t_m_I``.
    - For total degree 4 or 6, the mapped indices are gathered through the
      partitions in ``i_partitions_4`` / ``i_partitions_6``. The result is then
      reduced with ``combine_count_split`` and collected into the
      ``i_partitions_*`` arrays.
    """

    def _partition_indices(self, exponents, n_pairs, partitions, shape):
        """Map one exponent vector and reduce it over a set of partitions.

        Calls ``exp2index_mapping(exponents, n_pairs)``, gathers the result
        with ``partitions`` and converts it to a string array. It then returns
        only the first value produced by ``combine_count_split(arr, shape)``;
        the second value (counts) is not used by this class.

        At every call site ``n_pairs`` equals half the total degree of
        ``exponents`` (2 for degree 4, 3 for degree 6).
        """
        mapped = exp2index_mapping(exponents, n_pairs)[partitions]
        combined, _counts = combine_count_split(np.array(mapped, dtype="str"), shape)
        return combined

    def get_all_moments_exponents3(self):
        """Build every index table from the precomputed exponent arrays.

        Reads ``monomial_exponents``, ``degrees_m``, ``multinomial_exponents3``,
        ``mon_terms``, ``d``, ``i_partitions_4``, ``i_partitions_6``,
        ``shape_4`` and ``shape_6``.

        Writes ``t_m``, ``t_m_m``, ``t_m_I`` and the int32 arrays
        ``i_partitions_2_m_I``, ``i_partitions_4_m_m``, ``i_partitions_4_m_mu``,
        ``i_partitions_4_m_m_I`` and ``i_partitions_6_m_m_mu``.

        All accumulators are rebuilt from scratch, so calling the method again
        recomputes the same tables. Previously a second call failed on the
        arrays left behind by the first.

        Row 0 of ``monomial_exponents`` is skipped (presumably the constant
        monomial; confirm). Only ordered pairs ``1 <= j <= i`` are visited.
        """
        # Degree-2 monomial i -> mapped index pair.
        self.t_m = np.zeros(shape=(self.mon_terms, 2), dtype=np.int32)
        # Monomial pair (i, j) of total degree 2 -> mapped index pair.
        self.t_m_m = np.zeros(shape=(self.mon_terms, self.mon_terms, 2), dtype=np.int32)
        # Degree-1 monomial i plus unit vector k -> mapped index pair.
        self.t_m_I = np.zeros(shape=(self.mon_terms, self.d, 2), dtype=np.int32)

        # Local accumulators; converted to int32 arrays once the loops finish.
        parts_2_m_I = []     # copies of t_m_I[i, k] in (i, k) order (not used in Stan)
        parts_4_m_m = []     # monomial pairs of total degree 4
        parts_4_m_mu = []    # degree-1 monomial + degree-3 multinomial exponent
        parts_4_m_m_I = []   # monomial pair of total degree 3 + unit vector
        parts_6_m_m_mu = []  # monomial pair of total degree 3 + degree-3 multinomial exponent

        identity = np.eye(self.d, dtype=np.int32)

        for i in range(1, self.monomial_exponents.shape[0]):
            e1 = self.monomial_exponents[i]
            deg_i = self.degrees_m[i]

            if deg_i == 2:
                self.t_m[i] = exp2index_mapping(e1, 1)
            elif deg_i == 1:
                for m3 in self.multinomial_exponents3:
                    parts_4_m_mu.append(
                        self._partition_indices(e1 + m3, 2, self.i_partitions_4, self.shape_4)
                    )
                for k, unit in enumerate(identity):
                    self.t_m_I[i, k] = exp2index_mapping(e1 + unit, 1)
                    parts_2_m_I.append(self.t_m_I[i, k])

            for j in range(1, i + 1):
                e1e2 = e1 + self.monomial_exponents[j]
                total_degree = deg_i + self.degrees_m[j]

                if total_degree == 2:
                    self.t_m_m[i, j] = exp2index_mapping(e1e2, 1)
                elif total_degree == 4:
                    # n_pairs=2 is passed explicitly, matching every other
                    # degree-4 mapping here (the original call omitted it).
                    parts_4_m_m.append(
                        self._partition_indices(e1e2, 2, self.i_partitions_4, self.shape_4)
                    )
                elif total_degree == 3:
                    for m3 in self.multinomial_exponents3:
                        parts_6_m_m_mu.append(
                            self._partition_indices(
                                e1e2 + m3, 3, self.i_partitions_6, self.shape_6
                            )
                        )
                    for unit in identity:
                        parts_4_m_m_I.append(
                            self._partition_indices(
                                e1e2 + unit, 2, self.i_partitions_4, self.shape_4
                            )
                        )

        self.i_partitions_2_m_I = np.array(parts_2_m_I, dtype=np.int32)
        self.i_partitions_4_m_m = np.array(parts_4_m_m, dtype=np.int32)
        self.i_partitions_4_m_mu = np.array(parts_4_m_mu, dtype=np.int32)
        self.i_partitions_4_m_m_I = np.array(parts_4_m_m_I, dtype=np.int32)
        self.i_partitions_6_m_m_mu = np.array(parts_6_m_m_mu, dtype=np.int32)

    def __init__(self, d):
        """Precompute all index tables for ``d`` variables.

        Parameters
        ----------
        d : int
            Number of variables. Each exponent vector must have length ``d``,
            since unit vectors from ``np.eye(d)`` are added to them.
        """
        self.d = d
        # Maximum monomial degree. get_all_moments_exponents3 only handles
        # monomials of degree 1 and 2.
        self.M = 2

        # Shapes passed to combine_count_split. The leading 3 and 15 presumably
        # match the number of pairings of 4 and 6 elements; confirm.
        self.shape_4 = (3, 2, 3)
        self.shape_6 = (15, 3, 3)

        self.monomial_exponents = get_monomial_exponents(self.d, self.M)
        self.mon_terms = self.monomial_exponents.shape[0]

        self.multinomial_exponents3, self.multinomial_coefficients3 = (
            get_multinomial_exponents_and_coefficients(self.d, 3)
        )
        self.mult_terms3 = self.multinomial_exponents3.shape[0]

        # Total degree of each exponent vector.
        self.degrees_m = np.array(
            [exps.sum() for exps in self.monomial_exponents], dtype=np.int32
        )
        self.degrees_mu3 = np.array(
            [exps.sum() for exps in self.multinomial_exponents3], dtype=np.int32
        )

        # Isserlis partitions of {1, ..., n}, shifted by -1 to 0-based positions.
        self.i_partitions_2 = get_isserlis_partitions(list(range(1, 3))) - 1
        self.i_partitions_4 = get_isserlis_partitions(list(range(1, 5))) - 1
        self.i_partitions_6 = get_isserlis_partitions(list(range(1, 7))) - 1

        self.get_all_moments_exponents3()
        

