#%%
import numpy as np
from scipy import sparse
from math import sqrt
from numpy.linalg import norm
import networkx as nx
from utils import rank1_update
from numpy.random import default_rng
random = default_rng()
from tqdm import tqdm

class SolverADMM:
    """ADMM solver for a network-lasso-type problem on a weighted graph.

    Each nonzero entry of the adjacency matrix is a directed edge (row-major
    order). The solver alternates a node-wise proximal step, delegated to
    ``first_term.prox``, with a closed-form edge-wise update of the consensus
    variables ``z`` and scaled duals ``u`` (one row per directed edge).

    Call :meth:`get_external_params` once with the graph before :meth:`solve`;
    ``z`` and ``u`` are kept on the instance so successive solves warm-start.
    """

    def __init__(self, rho= 1.0, eps_rel= 1e-3, eps_abs= 1e-6, primal_thd= 1e-8, dual_thd= 1e-8, max_iter= 100, dtype= 'float32'):
        self.rho = rho  # ADMM penalty parameter
        self.eps_abs = eps_abs  # scaled into eps_abs_primal / eps_abs_dual by get_external_params
        self.eps_rel = eps_rel  # NOTE(review): not read anywhere in this class
        self.max_iter = max_iter  # maximum number of ADMM iterations per solve()
        self.primal_thd = primal_thd  # solve() stops when max |z - z_old| <= primal_thd
        self.dual_thd = dual_thd  # NOTE(review): not read by solve()
        self.dtype= dtype

    def get_external_params(self, adj, Theta):
        """Precompute graph-dependent quantities and reset the ADMM variables.

        Parameters
        ----------
        adj : numpy.ndarray
            (n_nodes, n_nodes) weighted adjacency matrix; entries > 0 are edges.
            The edge-transposition logic assumes ``adj`` is symmetric, so that
            every edge (i, j) has a partner (j, i).
        Theta : numpy.ndarray
            Only ``Theta.shape[1]`` (the per-node parameter dimension) is used.

        Returns
        -------
        SolverADMM
            ``self``, to allow chaining.
        """
        n_nodes = len(adj)
        dim = Theta.shape[1]

        # compressed graph information: directed edges are the nonzeros of adj in row-major order
        self.Adj_bool = adj > 0
        self.range_transp = range_transpose(self.Adj_bool)  # indices for the edge transposition
        self.weights = adj[self.Adj_bool].astype(self.dtype)
        self.n_edges_2 = np.sum(self.Adj_bool)
        degrees = self.Adj_bool.sum(axis=1)

        # ADMM variables (consensus z and scaled dual u), zero-initialized, one row per edge
        self.u = np.zeros((self.n_edges_2, dim), dtype=self.dtype)
        self.z = np.zeros((self.n_edges_2, dim), dtype=self.dtype)

        # stopping-criterion tolerances
        self.eps_abs_primal = sqrt(self.n_edges_2*dim)*self.eps_abs
        self.eps_abs_dual = sqrt(n_nodes*dim)*self.eps_abs

        # sparse (n_nodes, n_edges_2) map with a 1 at (source node of edge e, e). Built from
        # coordinates to avoid materializing a dense n_nodes x n_edges_2 comparison matrix.
        sources = np.repeat(np.arange(n_nodes), degrees)
        self.nodes_to_edges_2 = sparse.csr_array(
            (np.ones(self.n_edges_2, dtype=self.dtype), (sources, np.arange(self.n_edges_2))),
            shape=(n_nodes, self.n_edges_2),
        )
        self.dim = dim

        return self

    def solve(self, first_term, l_net= 1.0, progress_bar= False):
        """Run ADMM iterations, warm-started from the stored ``z`` and ``u``.

        Parameters
        ----------
        first_term : object
            Node-wise term exposing ``prox(z_minus_u, solver)`` that returns
            the (n_nodes, dim) primal update.
        l_net : float
            Weight of the network (edge-difference) penalty.
        progress_bar : bool
            Wrap the iteration range in ``tqdm`` when True.

        Returns
        -------
        numpy.ndarray
            The last primal iterate. If ``max_iter <= 0`` no iteration runs and
            the primal point for the current ``(z, u)`` is returned. The final
            ``z`` and ``u`` are stored on ``self``.
        """
        z = self.z
        u = self.u

        # loop invariants
        degrees = self.Adj_bool.sum(axis=1)
        l_net_over_rhoxweights = l_net/self.rho*self.weights[:, None]
        iter_range = range(self.max_iter)
        if progress_bar: iter_range = tqdm(iter_range)

        w_mat = None
        for _ in iter_range:
            # z is rebound below (never modified in place), so no copy is needed
            z_old = z

            # node-wise primal step
            w_mat = first_term.prox(z_old - u, self)
            # copy every node's value onto its outgoing edges (row-major edge order)
            w_rep = np.repeat(w_mat, repeats=degrees, axis=0)
            w_u = w_rep + u
            # difference between the two directions of each edge
            in_denom = w_u - w_u[self.range_transp]
            # closed-form z-update weight, clipped below at 0.5
            theta_param = np.maximum(0.5, 1 - l_net_over_rhoxweights/(norm(in_denom, axis=1, keepdims=True) + 1e-12))
            # equivalent to z = theta * w_u + (1 - theta) * w_u[range_transp]
            z = w_u[self.range_transp] + theta_param * in_denom
            # scaled dual update (in place on u)
            u += w_rep - z

            # a graph without edges has nothing left to iterate on
            if z.size == 0 or np.max(np.abs(z - z_old)) <= self.primal_thd:
                break

        if w_mat is None:
            # max_iter <= 0: no ADMM step ran; return the primal point for the current (z, u)
            w_mat = first_term.prox(z - u, self)

        self.z = z
        self.u = u

        return w_mat
    
class LstsqADMM():
    """Node-wise regularized least-squares term for ``SolverADMM``.

    Each node k has a system matrix ``A_k + (rho * deg_k + l_reg) * I``, whose
    inverse is precomputed, and a right-hand side ``b_k``. The proximal step
    solves that system with ``rho * sum(z - u)`` over the node's outgoing edges
    added to ``b_k``.
    """

    def __init__(self, l_reg= 0.0):
        self.l_reg = l_reg  # ridge coefficient added to every node's diagonal

    def get_external_params(self, solver, A= None, b= None):
        """Precompute the per-node inverse system matrices.

        Parameters
        ----------
        solver : SolverADMM
            Solver whose ``get_external_params`` has already run; this reads
            ``Adj_bool``, ``dim``, ``dtype`` and ``rho``.
        A : array_like, optional
            Per-node matrices added to the system, broadcastable to
            (n_nodes, dim, dim).
        b : array_like, optional
            (n_nodes, dim) right-hand side; defaults to zeros.

        Returns
        -------
        LstsqADMM
            ``self``, matching the other ``get_external_params`` methods.
        """
        n_nodes = len(solver.Adj_bool)
        eye = np.eye(solver.dim, dtype=solver.dtype)

        # always a fresh array: update() modifies self.b in place
        if b is None:
            self.b = np.zeros((n_nodes, solver.dim), dtype=solver.dtype)
        else:
            self.b = np.array(b, dtype=solver.dtype)

        degrees = solver.Adj_bool.sum(axis=1)
        A_ext = solver.rho*np.einsum("i,jk->ijk", degrees, eye) \
            + self.l_reg * np.tile(eye, (n_nodes, 1, 1))
        if A is not None:
            A_ext += A
        A_ext = A_ext.astype(solver.dtype)

        self.invA_ext = np.linalg.inv(A_ext)
        return self

    def prox(self, z_minus_u, solver):
        """Solve every node's system with ``rho * sum over outgoing edges of (z - u)`` added to b."""
        b_ext = self.b + solver.rho*solver.nodes_to_edges_2.dot(z_minus_u)
        return np.einsum('kij, kj -> ki', self.invA_ext, b_ext)

    def primal_step(self, arg):
        pass

    def update(self, u, x, y):
        """Add the sample (x, y) to node u: adjust its inverse and add ``y * x`` to ``b[u]``.

        NOTE(review): ``rank1_update`` comes from ``utils``; it presumably returns the
        Sherman-Morrison correction for adding ``x x^T`` to the system. Confirm.
        """
        self.invA_ext[u] -= rank1_update(self.invA_ext[u], x)
        self.b[u] += y*x
        return self
    

class ThetaSmootherADMM():
    """Quadratic node-wise term pulling each node's parameters towards ``W_raw``.

    The proximal step is the closed-form minimizer over w_k of
    ``0.5 * ||w_k - W_raw_k||^2 + (rho / 2) * sum_e ||w_k - (z_e - u_e)||^2``,
    where the sum runs over node k's outgoing edges.
    """

    def __init__(self, W_raw, dtype= 'float32'):
        self.dtype = dtype
        self.W_raw = W_raw.astype(dtype)  # target parameters, one row per node
        self.dim = W_raw.shape[1]

    def get_external_params(self, solver):
        """Cache the per-node denominator ``1 + rho * degree``; returns ``self``."""
        out_degree = solver.Adj_bool.sum(axis=1, keepdims=True)
        self.denom = 1 + solver.rho * out_degree
        return self

    def prox(self, z_minus_u, solver):
        """Return the closed-form node update given the per-edge ``z - u``."""
        per_node_sum = solver.nodes_to_edges_2.dot(z_minus_u)
        numerator = self.W_raw + solver.rho * per_node_sum
        return numerator / self.denom

def transp(tensor):
    """Swap the first two axes of a 3-D array and keep the last axis in place."""
    axes_order = (1, 0, 2)
    return np.transpose(tensor, axes=axes_order)

# def jtransp(tensor):
#     return jnp.transpose(tensor, (1,0,2))

def extended_incidence(Adj):
    """Build a sparse node-by-pair source indicator over upper-triangular node pairs.

    Columns enumerate the pairs (i, j) with i <= j in ``np.triu_indices``
    order, diagonal included. Entry (i, e) is 1.0 when pair e = (i, j)
    satisfies ``Adj[i, j] > 0``, i.e. node i is the row index of an existing
    upper-triangular edge. All other entries are 0.

    Parameters
    ----------
    Adj : numpy.ndarray
        (n_nodes, n_nodes) adjacency matrix; only the upper triangle is read.

    Returns
    -------
    scipy.sparse.csr_array
        float32 array of shape (n_nodes, n_nodes * (n_nodes + 1) // 2).
    """
    n_nodes = len(Adj)
    rows_tri, cols_tri = np.triu_indices(n_nodes)
    edge_exists = Adj[rows_tri, cols_tri] > 0
    pair_ids = np.flatnonzero(edge_exists)
    # Build directly from coordinates. The former dense (n_nodes, n_pairs) boolean
    # mask needed O(n_nodes**3) memory before being sparsified.
    return sparse.csr_array(
        (np.ones(pair_ids.size, dtype="float32"), (rows_tri[pair_ids], pair_ids)),
        shape=(n_nodes, rows_tri.size),
    )

def range_transpose(Adj_bool):
    """Return the permutation that maps the nonzeros of a matrix to those of its transpose.

    Let ``v`` be the nonzero entries of A read in row-major order, and number
    them 0 .. len(v) - 1. This function returns integer indices ``perm`` such
    that ``v[perm]`` is the vector of nonzero entries of ``A.T`` in row-major
    order. For a symmetric sparsity pattern, ``perm[e]`` is the number of the
    mirrored entry (j, i) of the e-th entry (i, j).

    Parameters
    ----------
    Adj_bool : numpy.ndarray
        Square boolean matrix, True iff the corresponding weight is nonzero.

    Returns
    -------
    numpy.ndarray
        int64 permutation of length equal to the number of nonzeros.
    """
    mask = np.asarray(Adj_bool, dtype=bool)
    edge_id = np.full(mask.shape, -1, dtype="int64")
    # number the nonzeros 0 .. n_edges - 1 in row-major order
    edge_id[mask] = np.arange(np.count_nonzero(mask))
    # read that numbering through the transposed view, at the transpose's nonzeros
    return edge_id.T[mask.T]

class SolverPrimalDual:
    """Diagonally preconditioned primal-dual solver for a network-TV-regularized problem.

    The primal ``Theta`` (one row per node) and the dual ``U`` (one row per
    edge) are updated alternately. The primal proximal step and its step sizes
    are provided by ``first_term`` (``prox`` and ``T_B_transpose``). The dual
    step uses ``Sigma_B`` followed by a row-wise projection onto the l2 ball
    of radius ``l_net``.
    """

    #### init func ###############################
    def __init__(self, tau= 0.999, max_iter= 500, abs_tol= 1e-8, rel_tol= 1e-8, dtype= 'float32'):

        self.tau = tau # sufficient to be chosen in (0,1)
        self.abs_tol = abs_tol  # solve() stops when max |Theta - Theta_old| <= abs_tol
        self.rel_tol = rel_tol  # NOTE(review): not read by solve()
        self.max_iter = max_iter # number of iterations of running the optimization code block
        self.dtype= dtype

    def get_external_params(self, Adj, Theta):
        """Build the weighted incidence matrix and dual step sizes from the graph.

        Parameters
        ----------
        Adj : numpy.ndarray
            (n_nodes, n_nodes) weighted adjacency, passed to
            ``networkx.from_numpy_array``.
        Theta : numpy.ndarray
            Only ``Theta.shape[1]`` (per-node parameter dimension) is used.

        Returns
        -------
        SolverPrimalDual
            ``self``, matching the other ``get_external_params`` methods.
        """
        self.Adj = Adj.astype(self.dtype)
        self.dim = Theta.shape[1]

        G = nx.from_numpy_array(Adj)
        # oriented, weighted incidence matrix, transposed so that rows index edges
        self.B = nx.incidence_matrix(G,weight='weight', oriented= True).T.astype(self.dtype)
        # dual step sizes: each row of B divided by the sum of |row|
        # NOTE keepdims does not work here. Older scipy versions may need
        # `self.B / np.sum(np.abs(self.B), axis=1)` instead (ValueError otherwise).
        self.Sigma_B = self.B / np.sum(np.abs(self.B),axis=1)[:,None]
        return self

    def U_hat_update(self, U_bar, alpha_par):
        """Project every row of ``U_bar`` onto the l2 ball of radius ``alpha_par``.

        Computes ``U_bar - max(0, 1 - alpha / ||row||) * U_bar``, i.e. rows with
        norm above ``alpha_par`` are rescaled to that norm and other rows are
        kept. The 1e-16 guards against division by zero for all-zero rows.
        """
        temp_norm = np.linalg.norm(U_bar, axis=1, keepdims= True)
        temp = 1 - alpha_par*np.reciprocal(temp_norm +1e-16)
        temp = np.maximum(0,temp)

        return U_bar - temp*U_bar

    #### new func ######################################

    def solve(self, first_term, l_net, theta0, U0, progress_bar= False):
        """Run primal-dual iterations from ``(theta0, U0)``.

        Parameters
        ----------
        first_term : object
            Exposes ``prox(input)`` (primal proximal step) and
            ``T_B_transpose`` (step-scaled B^T, used via ``.dot``).
        l_net : float
            Radius of the dual l2 balls (network penalty weight).
        theta0, U0 : numpy.ndarray
            Initial primal (one row per node) and dual (one row per edge).
        progress_bar : bool
            Wrap the iteration range in ``tqdm`` when True.

        Returns
        -------
        tuple
            ``(Theta, U)``. They are ``(theta0, U0)`` unchanged when ``max_iter <= 0``.
        """
        # returned unchanged if the loop does not run (previously an UnboundLocalError)
        Theta, U = theta0, U0
        Theta_old = theta0
        U_old = U0

        iter_range = range(self.max_iter)
        if progress_bar: iter_range = tqdm(iter_range)
        for _ in iter_range:
            # primal update
            Theta = first_term.prox(Theta_old - first_term.T_B_transpose.dot(U_old))

            # dual update at the extrapolated point 2*Theta - Theta_old, then projection
            U_bar = U_old + self.Sigma_B.dot(2*Theta-Theta_old)
            U = self.U_hat_update(U_bar, l_net)

            # stopping on the primal change
            Theta_diff = np.max(np.abs(Theta - Theta_old))
            if Theta_diff <= self.abs_tol: break

            Theta_old = Theta
            U_old = U

        return Theta, U
    


class LstsqPrimalDual():
    """Node-wise regularized least-squares term for ``SolverPrimalDual``.

    With per-node primal step sizes ``T`` (``T_vec``), every node system is
    ``(I + T * l_reg * I + T * A_k) theta_k = T * b_k + input_k``. The inverse
    of each system matrix is precomputed.
    """
    # TODO: make necessary changes
    def __init__(self, l_reg= 0.0):
        self.l_reg = l_reg  # ridge coefficient

    def get_external_params(self, solver, A= None, b= None):
        """Precompute step sizes, the step-scaled B^T, the right-hand side and the inverses.

        Parameters
        ----------
        solver : SolverPrimalDual
            Solver whose ``get_external_params`` has already run; this reads
            ``Adj``, ``B``, ``dim``, ``dtype`` and ``tau``.
        A : array_like, optional
            (n_nodes, dim, dim) per-node matrices added to the system.
        b : array_like, optional
            (n_nodes, dim) right-hand side; defaults to zeros.

        Returns
        -------
        LstsqPrimalDual
            ``self``, matching the other ``get_external_params`` methods.
        """
        n_nodes= len(solver.Adj)
        self.T_vec = self.T_vec_func(solver).astype(solver.dtype)  # (n_nodes, 1) step sizes
        # rows of B^T scaled by the node step sizes (elementwise broadcast).
        # NOTE(review): relies on `*` being elementwise, as for scipy sparse *arrays*.
        self.T_B_transpose = self.T_vec * solver.B.T

        # always a fresh array: update() modifies self.b in place
        self.b = np.zeros((n_nodes, solver.dim), dtype= solver.dtype) if b is None else self.T_vec*b
        A_ext = np.zeros((n_nodes, solver.dim, solver.dim), dtype= solver.dtype)
        range_dim = np.arange(solver.dim)
        A_ext[:, range_dim, range_dim] = np.ones(solver.dim, dtype= solver.dtype) + self.l_reg*self.T_vec
        if A is not None:
            A_ext += self.T_vec[:,None]*np.asarray(A, dtype= solver.dtype)
        self.invA_ext = np.linalg.inv(A_ext)
        return self

    def prox(self, input_mat):
        """Solve every node system with right-hand side ``input_mat + b``."""
        return np.einsum("kij,kj->ki", self.invA_ext, input_mat + self.b)

    def update(self, u, x, y):
        """Add the sample (x, y) to node u via a Sherman-Morrison update of its inverse.

        The system gains ``T_u * x x^T`` and ``b[u]`` gains ``T_u * y * x``. The use of
        ``outer(M x, M x)`` assumes ``invA_ext[u]`` is symmetric.
        """
        invA_ext_x = self.invA_ext[u] @ x
        self.invA_ext[u] -= self.T_vec[u]*np.outer(invA_ext_x, invA_ext_x)/(1 + self.T_vec[u]*x @ invA_ext_x)
        self.b[u] += self.T_vec[u]*y*x
        return self

    def T_vec_func(self, solver):
        """Return the (n_nodes, 1) primal step sizes ``tau / weighted degree``.

        Isolated nodes (degree 0) have no incident edges, so their step size does
        not interact with the dual. They get ``tau`` (the degree is treated as 1)
        instead of ``tau / 0 = inf``, which made their system matrix inf/NaN.
        """
        degrees = np.sum(solver.Adj,axis=1, keepdims= True)
        degrees = np.where(degrees == 0, 1, degrees)
        return solver.tau/degrees


def cost(XX, yX, Theta, B, alpha_opti, lambda_opti= 0.0):
    """Evaluate the network-regularized least-squares objective.

    Returns
        sum_k (0.5 * Theta_k^T XX_k Theta_k - yX_k^T Theta_k)
        + lambda_opti * sum(Theta ** 2)
        + alpha_opti * sum_e ||(B @ Theta)_e||_2

    NOTE(review): an earlier FIXME asked for a 0.5 factor in the least-squares
    term; the quadratic part below already carries it. Confirm that this was the intent.
    """
    XX_theta = np.einsum('kij, kj -> ki', XX, Theta)
    data_fit = np.einsum("ij,ij", XX_theta/2 - yX, Theta)
    ridge = lambda_opti*np.sum(Theta*Theta)
    tv = alpha_opti*np.linalg.norm(B.dot(Theta), axis= 1).sum()
    return data_fit + ridge + tv

def cost_and_grad(self, XX, yX, Theta, lambda_opti, alpha_opti, U):
    """Return the objective value and a (sub)gradient.

    Although defined at module level, this takes as ``self`` any object with an
    incidence attribute ``B`` (one row per edge).

    Value:    sum_k (0.5 * Theta_k^T XX_k Theta_k - yX_k^T Theta_k)
              + lambda_opti * sum(Theta ** 2) + alpha_opti * sum_e ||(B Theta)_e||
    Gradient: XX Theta - yX + 2 * lambda_opti * Theta + B^T U, where ``U``
              supplies the total-variation subgradient (dual variable).
    """
    XX_theta = np.einsum('kij, kj->ki', XX, Theta)
    fit_grad = XX_theta - yX
    # (XX Theta - 2 yX) . Theta / 2  ==  0.5 Theta^T XX Theta - yX^T Theta
    fit = np.einsum("ij,ij", fit_grad - yX, Theta)/2
    ridge = lambda_opti*np.sum(Theta*Theta)
    tv = alpha_opti*np.linalg.norm(self.B.dot(Theta), axis= 1).sum()

    value = fit + ridge + tv
    gradient = fit_grad + 2*lambda_opti*Theta + self.B.T @ U
    return value, gradient

# @njit(fastmath=  True, nopython= True)

# def my_einsum(M,N):
#     res = np.zeros((M.shape[0], M.shape[1]))
#     for k in range(M.shape[0]):
#         for i in range(M.shape[1]):
#             for j in range(M.shape[2]):
#                 res[k,i] += M[k,i,j]*N[i,j]
#     return res 
# %%
