import numpy as np
from scipy.stats import special_ortho_group

from src.hmm_belief_forward import HMMBeliefForward


class HMMSpectralDirectEstimation:
    """Online spectral estimator of a discrete hidden Markov model.

    Contexts (observations) arrive one at a time through ``run`` /
    ``run_one_iteration``. Each new context updates empirical moments of
    consecutive context triples (``update_p``). From iteration ``hot_start``
    onwards, the emission and transition matrices are re-estimated from those
    moments on every step (``update_spectral_estimation``). The estimate in
    force at each iteration is recorded in ``hist_emission`` /
    ``hist_transition``, so ``update_belief`` can later replay a forward
    filter round by round.

    Estimated matrices are normalized column-wise: ``current_emission`` is
    ``(n_contexts, n_regimes)`` and ``current_transition`` is
    ``(n_regimes, n_regimes)``, each with columns summing to 1.

    NOTE(review): the construction (pair/triple moments, SVD projections,
    random rotation THETA, simultaneous diagonalization with R1) looks like
    the Anandkumar-Hsu-Kakade method-of-moments estimator. Confirm against
    the reference the author used.
    """

    def __init__(self,
                 n_contexts: int,
                 n_regimes: int,
                 hot_start: int = 1000,
                 random_seed: int = 1989,
                 emission_order_larger: int = 1):
        """
        Args:
            n_contexts: Number of distinct contexts. Contexts are used
                directly as array indices, so they should be integers in
                ``[0, n_contexts)``.
            n_regimes: Number of hidden regimes. Passed as ``dim`` to
                ``scipy.stats.special_ortho_group``, which requires it to be
                greater than 1.
            hot_start: Value of ``curr_iter`` from which the spectral
                estimate is refreshed on every step. Before that, the uniform
                initial matrices are recorded.
            random_seed: Seed for NumPy's global RNG and for the random
                rotation ``current_THETA``.
            emission_order_larger: Regime index that should end up with the
                largest mean context index after each re-estimation (used to
                resolve the label-switching ambiguity).
        """
        self.n_contexts = n_contexts
        self.n_regimes = n_regimes
        self.hot_start = hot_start
        # Class (not instance) used by update_belief_one_iteration to build a
        # forward filter for each replayed round.
        self.belief_estimator = HMMBeliefForward
        self.random_seed = random_seed
        self.emission_order_larger = emission_order_larger
        self._init_constant()

    def _init_constant(self):
        """(Re)initialize all running statistics, estimates and histories."""
        # Seeds NumPy's *global* RNG, which also affects other np.random users.
        np.random.seed(self.random_seed)
        self.index_context = list(np.arange(self.n_contexts))
        self.curr_iter = 1
        n_c, n_r = self.n_contexts, self.n_regimes

        # Normalized moments: P31[x_t, x_{t-2}], P32[x_t, x_{t-1}],
        # P312[x_t, x_{t-2}, x_{t-1}] (see update_p).
        self.current_P31 = np.zeros(shape=(n_c, n_c))
        self.current_P32 = np.zeros(shape=(n_c, n_c))
        self.current_P312 = np.zeros(shape=(n_c, n_c, n_c))
        # Raw counts behind the normalized moments.
        self.current_P31_sum = np.zeros(shape=(n_c, n_c))
        self.current_P32_sum = np.zeros(shape=(n_c, n_c))
        self.current_P312_sum = np.zeros(shape=(n_c, n_c, n_c))

        # Cell-for-cell mirrors of the *_sum counters, kept in sync by update_p.
        self.current_P31_approx_sum = np.zeros(shape=(n_c, n_c))
        self.current_P32_approx_sum = np.zeros(shape=(n_c, n_c))
        self.current_P312_approx_sum = np.zeros(shape=(n_c, n_c, n_c))

        # Number of triples accumulated so far.
        self.current_counter = 0
        self.current_U1 = np.zeros(shape=(n_c, n_r))
        self.current_U2 = np.zeros(shape=(n_c, n_r))
        self.current_U3 = np.zeros(shape=(n_c, n_r))
        self.current_R1 = np.zeros(shape=(n_r, n_r))
        self.current_L = np.zeros(shape=(n_r, n_r))
        # Uniform, column-normalized estimates used until hot_start.
        self.current_transition = np.ones(shape=(n_r, n_r))
        self.current_transition /= self.current_transition.sum(axis=0)
        self.current_emission = np.ones(shape=(n_c, n_r))
        self.current_emission /= self.current_emission.sum(axis=0)
        self.current_belief = np.ones(shape=n_r)
        self.current_belief /= self.current_belief.sum()
        self.hist_context = []
        self.hist_transition = []
        self.hist_emission = []
        self.hist_belief = []
        # Every 100th replayed round -> full belief trajectory (update_belief).
        self.dict_hist_belief = {}
        self._sample_theta()

    def update_p(self):
        """Accumulate empirical moments from the latest context triple.

        Does nothing until ``curr_iter`` reaches 3. After that, the triple
        ``(x_{t-2}, x_{t-1}, x_t)`` taken from the end of ``hist_context``
        increments ``P31_sum[x_t, x_{t-2}]``, ``P32_sum[x_t, x_{t-1}]`` and
        ``P312_sum[x_t, x_{t-2}, x_{t-1}]``. The normalized moments are then
        refreshed as counts divided by the number of triples seen.
        """
        if self.curr_iter < 3:
            return

        self.current_counter += 1

        context_t = self.hist_context[-1]
        context_tm1 = self.hist_context[-2]
        context_tm2 = self.hist_context[-3]

        self.current_P31_sum[context_t, context_tm2] += 1
        self.current_P31 = self.current_P31_sum / self.current_counter

        self.current_P32_sum[context_t, context_tm1] += 1
        self.current_P32 = self.current_P32_sum / self.current_counter

        self.current_P312_sum[context_t, context_tm2, context_tm1] += 1
        self.current_P312 = self.current_P312_sum / self.current_counter

        # Adding the outer product of one-hot vectors changes exactly one cell
        # by +1. Incrementing that cell directly gives the same result without
        # allocating an n_contexts**3 temporary on every step.
        self.current_P31_approx_sum[context_t, context_tm2] += 1
        self.current_P32_approx_sum[context_t, context_tm1] += 1
        self.current_P312_approx_sum[context_t, context_tm2, context_tm1] += 1

    def _update_u(self):
        """Refresh the projection bases from the pair moments.

        ``U3``/``U1`` are the leading ``n_regimes`` left/right singular
        vectors of ``P31``. ``U2`` is the leading ``n_regimes`` right singular
        vectors of ``P32``. ``np.linalg.svd`` returns singular values in
        descending order.
        """
        U, _, Vh = np.linalg.svd(self.current_P31, full_matrices=False)
        self.current_U3 = U[:, :self.n_regimes]
        self.current_U1 = Vh.T[:, :self.n_regimes]

        _, _, Vh2 = np.linalg.svd(self.current_P32, full_matrices=False)
        self.current_U2 = Vh2.T[:, :self.n_regimes]

    def _sample_theta(self):
        """Draw the fixed random rotation THETA (n_regimes x n_regimes, det +1)."""
        self.current_THETA = special_ortho_group.rvs(dim=self.n_regimes, random_state=self.random_seed)

    def _update_r1(self):
        """Set R1 to the unit-norm eigenvectors of the first L component."""
        L_comp_1 = self._compute_l_comp(index_row=0)
        eig_val, eig_vec = np.linalg.eig(L_comp_1)

        # The negation and the reversal cancel: columns end up ordered by
        # ascending eigenvalue (up to tie order).
        idx = (-eig_val).argsort()[::-1]
        eig_vec = eig_vec[:, idx]
        eig_vec /= np.linalg.norm(eig_vec, axis=0, keepdims=True)

        # NOTE(review): np.linalg.eig can return complex eigenpairs for this
        # non-symmetric matrix. They then propagate into L and the emission
        # estimate. Consider taking the real part (TODO: confirm intent).
        self.current_R1 = eig_vec

    def _compute_l_comp(self, index_row: int):
        """Return ``U3^T P312(U2 theta) U1 (U3^T P31 U1)^{-1}`` for ``theta = THETA[index_row]``.

        ``P312 @ v`` contracts the last (``x_{t-1}``) axis of the triple
        moment with ``v``, producing an ``(n_contexts, n_contexts)`` matrix.
        Raises ``numpy.linalg.LinAlgError`` if ``U3^T P31 U1`` is singular.
        """
        theta_ = self.current_THETA[index_row, :]
        RL_comp = self.current_U3.T @ (self.current_P312 @ (self.current_U2 @ theta_)) @ self.current_U1
        RR_comp_inv = np.linalg.inv(self.current_U3.T @ self.current_P31 @ self.current_U1)
        return RL_comp @ RR_comp_inv

    def _update_l(self):
        """Diagonalize every L component in the R1 basis; row i of L is the diagonal for THETA row i."""
        current_L = np.zeros(shape=(self.n_regimes, self.n_regimes))
        R1_inv = np.linalg.inv(self.current_R1)
        for i in range(self.n_regimes):
            L_comp_i = self._compute_l_comp(index_row=i)   # (H x H)
            current_L[i, :] = np.diag(R1_inv @ L_comp_i @ self.current_R1)
        self.current_L = current_L

    def _update_emission_transition(self):
        """Recover column-normalized emission and transition estimates.

        Emission is ``U2 THETA^{-1} L`` with negative entries clipped to 0.
        Transition is ``|pinv(U3^T emission) R1|``. Both are normalized so
        that columns sum to 1. Regimes are then relabeled by reversing their
        order when the regime with the largest mean context index is not
        ``emission_order_larger``.
        """
        current_emission = self.current_U2 @ np.linalg.inv(self.current_THETA) @ self.current_L
        current_emission = np.maximum(current_emission, 0)
        # NOTE(review): a column clipped entirely to 0 produces NaN here.
        self.current_emission = current_emission / np.sum(current_emission, axis=0, keepdims=True)
        current_transition = np.abs(np.linalg.pinv(self.current_U3.T @ self.current_emission) @ self.current_R1)
        self.current_transition = current_transition / np.sum(current_transition, axis=0, keepdims=True)

        # Per regime: mean over contexts of P(c | regime) * (c + 1), a proxy
        # for the expected (1-based) context index emitted by that regime.
        context_weights = np.array(range(1, self.n_contexts + 1))
        regime_score = (self.current_emission.T * context_weights).mean(axis=1)
        is_flip = int(np.argmax(regime_score)) != self.emission_order_larger

        if is_flip:
            # Reverse the regime order: emission columns, and BOTH axes of the
            # transition matrix (np.flip with no axis) so rows and columns
            # stay consistent. For n_regimes > 2, a reversal does not
            # guarantee that emission_order_larger ends up with the top score.
            self.current_emission = np.flip(self.current_emission, axis=1)
            self.current_transition = np.flip(self.current_transition)

    def update_spectral_estimation(self):
        """Re-run the full spectral pipeline: U bases, R1, L, then emission/transition."""
        self._update_u()
        self._update_r1()
        self._update_l()
        self._update_emission_transition()

    def run_one_iteration(self, context):
        """Ingest one context, update moments and (after hot_start) the estimates.

        The estimates in force after this step are appended to
        ``hist_emission`` / ``hist_transition``. This method does not advance
        ``curr_iter``; ``run`` does that.
        """
        self.hist_context.append(context)
        self.update_p()
        if self.curr_iter >= self.hot_start:
            self.update_spectral_estimation()

        self.hist_emission.append(self.current_emission)
        self.hist_transition.append(self.current_transition)

    def update_belief_one_iteration(self, vec_init, number_round):
        """Run the forward filter over the first ``number_round`` contexts.

        Uses the matrices recorded at round ``number_round`` (1-based) and
        returns whatever ``belief_estimator.run_numba`` returns.
        """
        belief_estimator = self.belief_estimator(vec_init=vec_init,
                                                 transition_matrix=self.hist_transition[number_round - 1],
                                                 emission_matrix=self.hist_emission[number_round - 1])
        return belief_estimator.run_numba(self.hist_context[:number_round])

    def update_belief(self, vec_init, *, verbose: bool = True):
        """Replay the forward filter for every observed round.

        For each 1-based round ``r`` at or beyond ``hot_start``, the filter is
        run over the first ``r`` contexts with the matrices estimated at round
        ``r``, and its last belief becomes ``current_belief``. For every round
        ``r`` that is a multiple of 100, the full trajectory is stored in
        ``dict_hist_belief[r]``. ``current_belief`` is appended to
        ``hist_belief`` for every round; before ``hot_start`` this is the
        belief already held (uniform after ``_init_constant``).

        Args:
            vec_init: Initial belief forwarded to ``belief_estimator``.
            verbose: Print each round number (the default, as before).
        """
        for curr_iter_ in range(1, len(self.hist_context) + 1):
            if verbose:
                print(curr_iter_)
            if curr_iter_ >= self.hot_start:
                hist_belief = self.update_belief_one_iteration(vec_init, curr_iter_)
                self.current_belief = hist_belief[-1]
                # Key snapshots on the replayed round itself. self.curr_iter is
                # the (fixed) training counter, and self.curr_iter_ does not exist.
                if curr_iter_ % 100 == 0:
                    self.dict_hist_belief[curr_iter_] = hist_belief

            self.hist_belief.append(self.current_belief)

    def run(self, list_contexts, *, verbose: bool = True):
        """Feed each context through ``run_one_iteration`` and advance ``curr_iter``.

        Args:
            list_contexts: Iterable of context indices.
            verbose: Print ``curr_iter`` before each step (the default, as before).
        """
        for context_ in list_contexts:
            if verbose:
                print(self.curr_iter)
            self.run_one_iteration(context_)
            self.curr_iter += 1
