import copy

import numpy as np

from numpy import linalg
from MarkovChains.MarkovChain import MarkovChain
from scipy.stats import ortho_group


def dfs(root, visited, adj_matrix):
    """Depth-first search from ``root``, marking every newly reached node in ``visited``.

    Edges are the entries of ``adj_matrix`` that equal 1. A node ``i`` is
    expanded only if ``visited[i]`` is falsy when it is first seen, so nodes
    already marked by an earlier call act as barriers.

    The search uses an explicit stack instead of recursion. Deep graphs (e.g.
    a long path, or a large dense graph) therefore cannot hit Python's
    recursion limit.

    :param root: index of the start node; it is always marked and counted.
    :param visited: mutable boolean-like sequence indexed by node; updated in place.
    :param adj_matrix: square adjacency matrix (2-D array or list of rows).
    :return: number of nodes marked by this call, including ``root``.
    """
    visited[root] = True
    num_visited = 1
    stack = [root]
    while stack:
        node = stack.pop()
        # Vectorised scan of the row for neighbours (entries equal to 1).
        for neighbour in np.flatnonzero(np.asarray(adj_matrix[node]) == 1):
            if not visited[neighbour]:
                # Mark on push so each node is counted exactly once.
                visited[neighbour] = True
                num_visited += 1
                stack.append(neighbour)
    return num_visited


def check_connected(adj_matrix):
    """Report whether every node of ``adj_matrix`` is reachable from node 0.

    Reachability follows matrix entries equal to 1, as tested by ``dfs``. For
    an asymmetric matrix this is reachability from node 0, not strong
    connectivity.

    :param adj_matrix: adjacency matrix; ``adj_matrix.shape[0]`` is taken as the
        node count (expected to be square).
    :return: ``True`` iff the search started at node 0 reaches all nodes.
    """
    total = adj_matrix.shape[0]
    reached = dfs(0, np.zeros(total), adj_matrix)
    return reached == total


def _sample_connected_adjacency(num_states, p, max_attempts=None):
    """Draw symmetric Bernoulli(p) adjacency matrices until one passes ``check_connected``.

    Raises ``RuntimeError`` once ``max_attempts`` draws have failed (never, if
    ``max_attempts`` is None).
    """
    attempt = 1
    while True:
        print("Erdos Renyi Matrix generation Iteration ", attempt)
        adj_matrix = np.random.binomial(1, p, (num_states, num_states))
        # Mirror the strict upper triangle into the lower one; the diagonal is
        # left as drawn (equivalent to adj[j][i] = adj[i][j] for all i < j).
        adj_matrix = np.triu(adj_matrix) + np.triu(adj_matrix, 1).T
        if check_connected(adj_matrix):
            return adj_matrix
        if max_attempts is not None and attempt >= max_attempts:
            raise RuntimeError(
                f"no connected Erdos-Renyi graph found after {attempt} attempts "
                f"(num_states={num_states}, p={p})"
            )
        attempt += 1


def _power_law_covariance(num_dimensions, c, beta):
    """Build the symmetric matrix with entries exp(-|x-y|*c) * 5(x+1)^-beta * 5(y+1)^-beta."""
    idx = np.arange(num_dimensions)
    scale = 5 * np.power(idx + 1, -beta)
    decay = np.exp(-np.abs(idx[:, None] - idx[None, :]) * c)
    cov = decay * scale[:, None] * scale[None, :]
    # Averaging with the transpose makes the result exactly symmetric,
    # because floating-point addition is commutative.
    return (cov + cov.T) / 2


def generate_erdosrenyi_mc(num_states=50, num_dimensions=100, seed=777, cov_eigengap_threshold=10, p=0.7,
                           *, smoothing=0.001, max_attempts=None):
    """Build a ``MarkovChain`` whose transition graph is a connected Erdos-Renyi graph.

    Symmetric Bernoulli(p) graphs are resampled until one is connected
    (reachable from node 0). Self-loops are then added, and each row is
    normalised into a transition distribution. Finally the chain is mixed with
    the uniform chain by ``smoothing``.

    Each state gets:
    - a zero mean vector of length ``num_dimensions``;
    - a covariance with entries ``exp(-|x-y|*c) * 5(x+1)^-beta * 5(y+1)^-beta``,
      where ``c`` and ``beta`` are drawn uniformly from [1, 10).

    The initial distribution is the normalised absolute value of a
    standard-normal draw.

    :param num_states: number of states (graph nodes).
    :param num_dimensions: dimensionality of each state's mean/covariance.
    :param seed: seeds NumPy's global RNG and is forwarded to ``MarkovChain``.
    :param cov_eigengap_threshold: currently unused; kept for backward compatibility.
    :param p: edge probability of the Erdos-Renyi graph.
    :param smoothing: weight of the uniform chain mixed into the transitions.
    :param max_attempts: maximum number of graphs to sample before raising
        ``RuntimeError``; ``None`` (default) retries indefinitely.
    :return: the constructed ``MarkovChain`` (its ``print()`` is called first).
    :raises RuntimeError: if ``max_attempts`` is set and no connected graph was found.
    """
    # Seeds NumPy's *global* RNG. All draws below happen in a fixed order,
    # so the result is reproducible for a given seed.
    np.random.seed(seed)

    adj_matrix = _sample_connected_adjacency(num_states, p, max_attempts)
    # Self-loops guarantee every row sum is >= 1, so normalisation never divides by 0.
    np.fill_diagonal(adj_matrix, 1)
    print("Erdos Renyi Matrix generation complete.")

    transition_matrix = adj_matrix / adj_matrix.sum(axis=1, keepdims=True)
    transition_matrix = transition_matrix * (1 - smoothing) + smoothing / num_states

    means = np.zeros((num_states, num_dimensions))

    covariance_matrices = []
    for _ in range(num_states):
        # Draw order (c, then beta) is part of the seeded reproducibility contract.
        c = np.random.uniform(1, 10)
        beta = np.random.uniform(1, 10)
        covariance_matrices.append(_power_law_covariance(num_dimensions, c, beta))

    initial_distribution = np.abs(np.random.randn(num_states))
    initial_distribution /= np.sum(initial_distribution)
    markov_chain = MarkovChain(transition_matrix=transition_matrix,
                               means=means,
                               covariance_matrices=covariance_matrices,
                               initial_distribution=initial_distribution,
                               seed=seed)

    markov_chain.print()
    return markov_chain
