import numpy as np

from math import log10, floor
from numpy import linalg


class MarkovChain:
    """Markov chain whose states emit random vectors.

    State ``i`` moves according to row ``transition_matrix[i]`` and emits
    vectors built from ``means[i]`` and ``covariance_matrices[i]`` (see
    :meth:`generate_batch_samples`).  Construction computes the stationary
    distribution, the stationary-weighted mixture of the per-state covariance
    matrices and its spectrum.  :meth:`update_params` then replaces the
    covariance, mean and spectrum with empirical estimates from a simulated
    trajectory.
    """

    def __init__(self, transition_matrix, means, covariance_matrices, initial_distribution, seed=777):
        self.transition_matrix = transition_matrix
        self.num_states = len(transition_matrix)
        # Check that the length of means is equal to the number of states
        assert (len(means) == self.num_states)
        self.means = means
        self.num_dimensions = len(self.means[0])
        # Check that the length of covariance matrices is equal to the number of states
        assert (len(covariance_matrices) == self.num_states)
        # Check that every covariance matrix has one row per dimension of the means
        assert all(len(cov) == self.num_dimensions for cov in covariance_matrices)
        self.covariance_matrices = covariance_matrices
        # Check that the length of the initial distribution is equal to the number of states
        assert (len(initial_distribution) == self.num_states)
        self.initial_distribution = initial_distribution
        # NOTE(review): ``seed`` is stored but never used by this class; all
        # randomness comes from the global ``np.random`` state.
        self.seed = seed
        self.stationary_distribution = None
        self.true_covariance_matrix = None
        # Subtracted from every emitted sample; must exist before initialize().
        self.true_mean = np.zeros(self.num_dimensions)
        self.transition_eigengap = None  # Eigen gap of the transition matrix
        self.cov_eigengap = None  # Eigen gap of the covariance matrix
        self.principal_components = None  # Principal components of the covariance matrix
        self.largest_eigenvector = None  # Largest eigenvector of the covariance matrix
        self.eigenvalues = None  # Sorted eigenvalues of the covariance matrix
        self.lambda1 = None  # Largest eigenvalue of the covariance matrix
        self.lambda2 = None  # Second-largest eigenvalue of the covariance matrix
        self.lambda_star = None  # Largest eigenvalue amongst all covariance matrices
        self.sample_cache = {}  # state index -> (num_samples, num_dimensions) pre-drawn batch
        self.sample_index = {}  # state index -> next unused row of sample_cache[state]

        self.initialize()  # Initialize all member variables

    def initialize(self):
        """Compute the stationary distribution, analytic covariance, then empirical estimates."""
        self.get_stationary_distribution()
        self.get_true_covariance_matrix()
        self.update_params()

    @staticmethod
    def _state_index(state):
        """Return ``state`` as a Python int.

        Accepts a plain integer or a size-1 array such as the ones returned by
        ``np.random.choice(..., 1)``.  Calling ``int()`` on an ndim > 0 array is
        deprecated in recent NumPy, hence ``.item()``.
        """
        return int(np.asarray(state).item())

    def get_sym_mat_sqrt(self, M, tol=1e-10):
        """Return the symmetric PSD square root ``R`` of ``M`` (so ``R @ R == M``).

        ``M`` is assumed symmetric; it is passed to ``np.linalg.eigh``, which
        reads only one triangle.  Eigenvalues that are negative only by
        round-off (no smaller than ``-tol * max(1, max|eigenvalue|)``) are
        clamped to zero.

        Raises:
            AssertionError: if ``M`` has a genuinely negative eigenvalue.
        """
        eigenvalues, eigenvectors = np.linalg.eigh(M)
        # A singular PSD matrix routinely yields eigenvalues like -1e-17.
        scale = max(1.0, float(np.abs(eigenvalues).max(initial=0.0)))
        assert (eigenvalues >= -tol * scale).all()
        eigenvalues = np.clip(eigenvalues, 0.0, None)
        # V @ diag(sqrt(w)) @ V^T; the diagonal is applied by column broadcasting.
        return (eigenvectors * np.sqrt(eigenvalues)) @ np.transpose(eigenvectors)

    def _set_spectrum(self, matrix):
        """Store the eigen-decomposition of the symmetric ``matrix``, largest eigenvalue first.

        Sets ``eigenvalues``, ``principal_components`` (eigenvectors as columns),
        the unit-norm ``largest_eigenvector`` and ``lambda1``.  It also sets
        ``lambda2`` and ``cov_eigengap`` when there are at least two eigenvalues.
        """
        # eigh is the solver for symmetric matrices: real values, orthonormal vectors,
        # returned in ascending order (hence the reversal).
        eigenvalues, eigenvectors = linalg.eigh(matrix)
        self.eigenvalues = eigenvalues[::-1]
        self.principal_components = eigenvectors[:, ::-1]

        leading = self.principal_components[:, 0]
        self.largest_eigenvector = leading / linalg.norm(leading)

        self.lambda1 = self.eigenvalues[0]
        if len(self.eigenvalues) > 1:
            self.lambda2 = self.eigenvalues[1]
            self.cov_eigengap = self.lambda1 - self.lambda2

    def update_params(self, num_samples=100000):
        """Re-estimate mean, covariance and spectrum from a simulated trajectory.

        Runs the chain for ``num_samples`` steps, starting from
        :meth:`get_initial_state`.  It replaces ``true_covariance_matrix`` with
        the (1/n-normalised) sample covariance and shifts ``true_mean`` by the
        sample mean.  It then refreshes the spectrum attributes and prints a
        summary.

        Raises:
            ValueError: if ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")
        print("Updating parameters")
        samples = np.empty((num_samples, self.num_dimensions))
        state = self.get_initial_state()
        for t in range(num_samples):
            samples[t] = self.get_sample(state)
            state = self.get_next_state(state)

        # get_sample already subtracts the current true_mean, so this mean is
        # relative to it; it is added back below so repeated calls stay correct.
        relative_mean = samples.mean(axis=0)
        print("Sample Mean calculated")

        deviations = samples - relative_mean
        sample_covariance_matrix = deviations.T @ deviations / num_samples
        print("Sample Covariance calculated")

        sample_covariance_matrix = (sample_covariance_matrix + np.transpose(sample_covariance_matrix)) / 2
        self.true_covariance_matrix = sample_covariance_matrix
        self.true_mean = self.true_mean + relative_mean

        self._set_spectrum(self.true_covariance_matrix)
        self.print()

    def generate_batch_samples(self, state, num_samples=1000):
        """Draw ``num_samples`` vectors for integer ``state`` into ``sample_cache``.

        Each vector is ``sqrt(covariance_matrices[state]) @ b``, where ``b`` has
        i.i.d. Bernoulli(``means[state][0]``) entries.  The cache cursor for
        ``state`` is reset to 0.
        """
        # NOTE(review): only the first coordinate of means[state] is used, as the
        # success probability for every dimension.  Earlier revisions also tried
        # Gaussian and uniform draws here (removed; see version history).
        mu = self.means[state][0]
        bernoulli = np.random.binomial(1, mu, (self.num_dimensions, num_samples))
        mixed = self.get_sym_mat_sqrt(self.covariance_matrices[state]) @ bernoulli
        self.sample_cache[state] = np.transpose(mixed)
        self.sample_index[state] = 0

    def get_stationary_distribution(self):
        """Return (and cache) the stationary distribution pi with pi @ P == pi.

        Also sets ``transition_eigengap`` to ``1 - |lambda|``.  Here ``lambda``
        is the transition eigenvalue with the second-largest real part.  A
        one-state chain has no such eigenvalue, so its gap is set to 1.
        """
        if self.stationary_distribution is not None:
            return self.stationary_distribution
        # Left eigenvectors of P are the right eigenvectors of P^T.
        eigenvalues, eigenvectors = linalg.eig(np.transpose(self.transition_matrix))
        real_parts = np.real(eigenvalues)
        # For a stochastic matrix the eigenvalue 1 has the largest real part.
        index = np.argsort(real_parts)[-1]
        stationary = np.real(eigenvectors[:, index])

        sorted_real = np.sort(real_parts)
        if len(sorted_real) > 1:
            self.transition_eigengap = 1 - abs(sorted_real[-2])
        else:
            self.transition_eigengap = 1.0

        self.stationary_distribution = stationary / np.sum(stationary)

        # Verify that stationary distribution is correct
        assert np.allclose(self.stationary_distribution,
                           np.matmul(self.stationary_distribution, self.transition_matrix),
                           1e-5 * max(self.stationary_distribution))

        return self.stationary_distribution

    def get_true_covariance_matrix(self):
        """Return (and cache) the stationary-weighted mix of the per-state covariances.

        The result is ``sum_i pi_i * covariance_matrices[i]``, symmetrised.
        Also sets ``lambda_star`` (the largest real eigenvalue part over all
        states' covariance matrices) and the spectrum attributes.
        """
        if self.true_covariance_matrix is not None:
            return self.true_covariance_matrix
        weights = self.get_stationary_distribution()  # cached after the first call
        mixture = np.zeros((self.num_dimensions, self.num_dimensions))
        for weight, state_cov in zip(weights, self.covariance_matrices):
            mixture += weight * np.asarray(state_cov)
            state_top = np.max(np.real(linalg.eigvals(state_cov)))
            if self.lambda_star is None:
                self.lambda_star = state_top
            else:
                self.lambda_star = max(self.lambda_star, state_top)

        mixture = (mixture + np.transpose(mixture)) / 2
        assert (mixture == np.transpose(mixture)).all()
        self.true_covariance_matrix = mixture

        self._set_spectrum(self.true_covariance_matrix)
        return self.true_covariance_matrix

    def get_initial_state(self):
        """Draw a starting state as a shape-(1,) array.

        The draw uses the stationary distribution, not ``initial_distribution``
        (the ``initial_distribution`` variant was left commented out in the
        original).
        """
        return np.random.choice(np.arange(self.num_states), 1, p=self.stationary_distribution)

    def get_next_state(self, state):
        """Sample the successor of ``state`` (int or size-1 array) as a shape-(1,) array."""
        next_state_probabilities = np.asarray(self.transition_matrix)[self._state_index(state)]
        assert (len(next_state_probabilities) == self.num_states)
        return np.random.choice(np.arange(self.num_states), 1, p=next_state_probabilities)

    def get_sample(self, state, num_samples=1000):
        """Return the next cached sample for ``state``, minus ``true_mean``.

        A fresh batch of ``num_samples`` is generated when the state has no
        cache yet or its cache has been used up.
        """
        index = self._state_index(state)
        if index not in self.sample_index or self.sample_index[index] == len(self.sample_cache[index]):
            self.generate_batch_samples(index, num_samples)
        sample = self.sample_cache[index][self.sample_index[index]]
        self.sample_index[index] += 1
        assert (len(sample) == self.num_dimensions)
        return sample - self.true_mean

    def special_round(self, x, sig):
        """Round ``x`` to ``sig`` significant digits; zero and non-finite values pass through."""
        if x == 0 or not np.isfinite(x):
            return x
        return round(x, sig - int(floor(log10(abs(x)))) - 1)

    def print(self):
        """Print a summary of the transition and covariance spectra."""
        print("============================================")
        print("Transition Matrix Eigengap : ", self.transition_eigengap)
        lambda_unicode = '\u03BB'  # Unicode for the Greek letter lambda
        subscript_2 = '\u2082'  # Unicode for subscript 2
        lambda2 = abs(1 - self.transition_eigengap)
        if lambda2 == 0:
            lambda2 = 1e-5
        l = lambda_unicode + subscript_2 + " : " + str(self.special_round(lambda2, 3))
        print(l)
        print("True Covariance Matrix Eigengap : ", self.cov_eigengap)
        print("Largest Eigenvalue of True Covariance Matrix : ", self.lambda1)
        print("Second-largest Eigenvalue of True Covariance Matrix : ", self.lambda2)
        print("Largest Eigenvalue amongst Covariance Matrices of all states : ", self.lambda_star)
        print("============================================")
