from Topo_utils import threshold_W, create_Z, find_idx_set, create_new_topo, create_new_topo_greedy,gradient_l1
import numpy as np
import scipy.linalg as slin
from copy import copy


def my_permutation(topo):
    """Build the permutation matrix for the node order ``topo``.

    Row ``i`` of the result is the unit vector ``e_{topo[i]}``, i.e. the rows of
    the ``len(topo)``-by-``len(topo)`` identity matrix reordered by ``topo``.
    """
    return np.eye(len(topo))[topo, :]


class topo_colide_ev:
    """CoLiDE-EV: linear DAG learning with a single noise scale shared by all nodes.

    For a fixed topological order, W is optimized with Adam while the scalar
    noise scale ``sigma`` is re-estimated in closed form. ``fit`` then performs
    a greedy local search over orders, keeping moves that lower ``score``.
    The helpers ``create_Z``, ``find_idx_set``, ``create_new_topo`` and
    ``create_new_topo_greedy`` come from ``Topo_utils``.
    """

    def __init__(self, seed=0):
        super().__init__()
        # Seeds NumPy's global RNG (a process-wide side effect).
        np.random.seed(seed)

    def score(self, W, sigma):
        """Return the CoLiDE-EV objective and its (sub)gradient w.r.t. ``W``.

        loss = tr((I - W)^T C (I - W)) / (2 sigma) + d sigma / 2 + lambda1 * sum|W|,
        where C is ``self.cov``. The L1 term contributes ``lambda1 * sign(W)``.

        Args:
            W: (d, d) weight matrix.
            sigma: scalar noise scale shared by every node.

        Returns:
            ``(loss, G_loss)``: the scalar objective and its (d, d) (sub)gradient.
        """
        dif = self.Id - W
        rhs = self.cov @ dif
        loss = ((0.5 * np.trace(dif.T @ rhs)) / sigma) + (0.5 * sigma * self.d) + self.lambda1 * np.abs(W).sum()
        G_loss = (-rhs / sigma) + self.lambda1 * np.sign(W)
        return loss, G_loss

    def _adam_update(self, grad, step, beta_1, beta_2):
        """Advance the Adam moments and return the bias-corrected update direction.

        Mutates ``self.opt_m`` and ``self.opt_v``. ``step`` is the 1-based step count.
        """
        self.opt_m = self.opt_m * beta_1 + (1 - beta_1) * grad
        self.opt_v = self.opt_v * beta_2 + (1 - beta_2) * (grad ** 2)
        m_hat = self.opt_m / (1 - beta_1 ** step)
        v_hat = self.opt_v / (1 - beta_2 ** step)
        return m_hat / (np.sqrt(v_hat) + 1e-8)

    def _solve_W(self, U0, P, sigma0, lr=0.001, beta_1=0.99, beta_2=0.999):
        """Run ``self.inner_max_itr`` Adam steps on W with ``sigma0`` held fixed.

        W is parameterized as ``W = P.T @ U @ P``. After every step U is
        multiplied by the strictly upper-triangular mask ``self.U``, so W stays
        consistent with the order encoded by ``P``.

        Returns:
            ``(W, U)`` after the final step. ``U0`` itself is not modified.
        """
        # Adam state is reset on every call.
        self.opt_m, self.opt_v = 0, 0
        U_temp = U0.copy()
        for step in range(1, self.inner_max_itr + 1):
            W = P.T @ U_temp @ P
            Gobj = -(self.cov @ (self.Id - W) / sigma0) + self.lambda1 * np.sign(W)
            grad = self._adam_update(Gobj, step, beta_1, beta_2)
            # Map the W-space step back into U-space, then re-apply the triangular mask.
            U_temp -= P @ (lr * grad) @ P.T
            U_temp *= self.U
        return (P.T @ U_temp @ P), U_temp

    def _init_W(self, topo, sigma0):
        """Fit W and sigma for the fixed order ``topo`` by alternating updates.

        Each outer iteration calls ``_solve_W`` with sigma fixed. It then sets
        sigma = sqrt(tr((I - W)^T C (I - W)) / d). Every 20 outer iterations the
        objective is evaluated, and the loop stops once its relative change is
        <= 1e-6.

        Returns:
            ``(W, sigma)``. If ``self.outer_max_itr < 1``, this is the all-zero W
            and a copy of ``sigma0``.
        """
        P = my_permutation(topo)
        U_temp = np.zeros((self.d, self.d))
        # Defined up front so the return is valid even if the loop never runs.
        W = P.T @ U_temp @ P
        sigma = sigma0.copy()
        obj_prev = 1e16
        for outer in range(1, self.outer_max_itr + 1):
            W, U_temp = self._solve_W(U_temp, P, sigma)
            dif = self.Id - W
            rhs = self.cov @ dif
            sigma = np.sqrt(np.trace(dif.T @ rhs) / (self.d))
            if outer % 20 == 0:
                obj_new, _ = self.score(W, sigma)
                if np.abs((obj_prev - obj_new) / obj_prev) <= 1e-6:
                    break
                obj_prev = obj_new
        return W, sigma

    def _update_topo_linear(self, topo, idx, sigma0, opt=1):
        """Fit the model on the order produced by applying move ``idx`` to ``topo``.

        ``topo`` is shallow-copied before being handed to ``create_new_topo``.

        Returns:
            ``(W, new_topo, sigma)`` for the candidate order.
        """
        topo0 = create_new_topo(topo=copy(topo), idx=idx, opt=opt)
        W, sigma = self._init_W(topo0, sigma0)
        return W, topo0, sigma

    def _h(self, W):
        """Return ``(h, G_h)`` for h(W) = -log|det(sI - |W|)| + d log(s) with s = 1.

        ``G_h`` is ``inv(sI - |W|).T``.
        """
        I = np.eye(self.d)
        s = 1
        M = s * I - np.abs(W)
        h = - np.linalg.slogdet(M)[1] + self.d * np.log(s)
        G_h = slin.inv(M).T

        return h, G_h

    def _fit_current_topo(self, size_small, size_large):
        """Refit on ``self.topo`` (warm-started from ``self.sigma``) and propose moves.

        Updates ``self.Z``, ``self.W`` and ``self.sigma``.

        Returns:
            ``(loss, idx_set_small, idx_set_large)`` as given by ``score`` and ``find_idx_set``.
        """
        self.Z = create_Z(self.topo)
        self.W, self.sigma = self._init_W(self.topo, sigma0=self.sigma)
        loss, G_loss = self.score(W=self.W, sigma=self.sigma)
        _, G_h = self._h(W=self.W)
        idx_set_small, idx_set_large = find_idx_set(G_h=G_h, G_loss=G_loss, Z=self.Z, size_small=size_small,
                                                    size_large=size_large)
        return loss, idx_set_small, idx_set_large

    def _candidate_losses(self, idx_set):
        """Score each move in ``idx_set`` applied to ``self.topo``; returns a float array."""
        loss_collections = np.zeros(len(idx_set))
        for i, idx in enumerate(idx_set):
            W_c, _, sigma_c = self._update_topo_linear(topo=self.topo, idx=idx, sigma0=self.sigma)
            loss_c, _ = self.score(W=W_c, sigma=sigma_c)
            loss_collections[i] = loss_c
        return loss_collections

    def fit(self, X, topo: list, no_large_search, size_small, size_large, inner_max_itr=10, outer_max_itr=100,
            lambda1=0.05):
        """Learn W, sigma and a topological order by greedy local search.

        W and sigma are fitted for the current order. Every move proposed by
        ``find_idx_set`` (small set first) is then refitted and scored. If the
        best small-set move lowers the loss, the order is updated via
        ``create_new_topo_greedy``. Otherwise, the moves found only in the large
        set are tried; at most ``no_large_search`` of these fall-backs may
        succeed. The search ends when no move improves the loss, the large-search
        budget is spent, or no moves are proposed.

        Args:
            X: (n, d) data matrix; ``X.T @ X / n`` is used as the covariance.
            topo: initial order; must be a ``list``. It is copied, so the
                caller's list is not modified here.
            no_large_search: maximum number of successful large-set searches.
            size_small, size_large: candidate-set sizes passed to ``find_idx_set``.
            inner_max_itr: Adam steps per ``_solve_W`` call.
            outer_max_itr: maximum W/sigma alternations per fit.
            lambda1: L1 penalty weight (default 0.05, the previous hard-coded value).

        Returns:
            ``(W, sigma, topo, loss)`` for the final order.

        Raises:
            TypeError: if ``topo`` is not a list.
        """
        if not isinstance(topo, list):
            raise TypeError(f"topo must be a list, got {type(topo).__name__}")
        self.n, self.d = X.shape
        self.inner_max_itr, self.outer_max_itr = inner_max_itr, outer_max_itr
        self.X = X
        self.Id = np.eye(self.d).astype(np.float64)
        self.U = np.triu(np.ones((self.d, self.d)), k=1)
        self.cov = X.T @ X / float(self.n)
        self.lambda1 = lambda1
        # Initial noise scale: the smallest per-column RMS of X.
        self.sigma = np.min(np.linalg.norm(self.X, axis=0) / np.sqrt(self.n)).astype(np.float64)
        # Private copy so order updates never alias the caller's list.
        self.topo = list(topo)
        iter_count = 0
        large_space_used = 0

        loss, idx_set_small, idx_set_large = self._fit_current_topo(size_small, size_large)
        idx_set = list(idx_set_small)
        while idx_set:
            loss_collections = self._candidate_losses(idx_set)
            if np.min(loss_collections) < loss:
                # The small candidate set already contains an improving move.
                self.topo = create_new_topo_greedy(self.topo, loss_collections, idx_set, loss)
            elif large_space_used < no_large_search:
                # Fall back to moves that are only in the large candidate set.
                idx_set = list(idx_set_large.difference(idx_set_small))
                loss_collections = self._candidate_losses(idx_set)
                # np.any over an empty array is False, so an empty fall-back set ends the search.
                if np.any(loss_collections < loss):
                    large_space_used += 1
                    self.topo = create_new_topo_greedy(self.topo, loss_collections, idx_set, loss)
                else:
                    print("Using larger search space, but we cannot find better loss")
                    break
            else:
                print("We reach the number of chances to search large space, it is {}".format(
                    no_large_search))
                break

            loss, idx_set_small, idx_set_large = self._fit_current_topo(size_small, size_large)
            idx_set = list(idx_set_small)
            iter_count += 1
        print("Max count:", iter_count)
        return self.W, self.sigma, self.topo, loss

class topo_colide_nv:
    """CoLiDE-NV: linear DAG learning with a separate noise scale per node.

    ``sigma`` is a length-d vector, and Sigma = diag(sigma) weights the
    residuals. For a fixed topological order, W is optimized with Adam while
    sigma is re-estimated in closed form. ``fit`` then performs a greedy local
    search over orders, keeping moves that lower ``score``. The helpers
    ``create_Z``, ``find_idx_set``, ``create_new_topo`` and
    ``create_new_topo_greedy`` come from ``Topo_utils``.
    """

    def __init__(self, seed=0):
        super().__init__()
        # Seeds NumPy's global RNG (a process-wide side effect).
        np.random.seed(seed)

    def score(self, W, sigma):
        """Return the CoLiDE-NV objective and its (sub)gradient w.r.t. ``W``.

        loss = (tr(Sigma^-1 (I - W)^T C (I - W)) + sum(sigma)) / 2 + lambda1 * sum|W|,
        where C is ``self.cov`` and Sigma = diag(sigma).

        Args:
            W: (d, d) weight matrix.
            sigma: length-d vector of per-node noise scales.

        Returns:
            ``(loss, G_loss)``: the scalar objective and its (d, d) (sub)gradient.
        """
        dif = self.Id - W
        rhs = self.cov @ dif
        inv_SigMa = np.diag(1.0/(sigma))
        loss = ((np.trace(inv_SigMa @ (dif.T @ rhs)) + np.sum(sigma)) / (2.0)) + self.lambda1 * np.abs(W).sum()
        G_loss = (-rhs @ inv_SigMa) + self.lambda1 * np.sign(W)
        return loss, G_loss

    def _adam_update(self, grad, step, beta_1, beta_2):
        """Advance the Adam moments and return the bias-corrected update direction.

        Mutates ``self.opt_m`` and ``self.opt_v``. ``step`` is the 1-based step count.
        """
        self.opt_m = self.opt_m * beta_1 + (1 - beta_1) * grad
        self.opt_v = self.opt_v * beta_2 + (1 - beta_2) * (grad ** 2)
        m_hat = self.opt_m / (1 - beta_1 ** step)
        v_hat = self.opt_v / (1 - beta_2 ** step)
        return m_hat / (np.sqrt(v_hat) + 1e-8)

    def _solve_W(self, U0, P, sigma0, lr=0.001, beta_1=0.99, beta_2=0.999):
        """Run ``self.inner_max_itr`` Adam steps on W with ``sigma0`` held fixed.

        W is parameterized as ``W = P.T @ U @ P``. After every step U is
        multiplied by the strictly upper-triangular mask ``self.U``, so W stays
        consistent with the order encoded by ``P``.

        Returns:
            ``(W, U)`` after the final step. ``U0`` itself is not modified.
        """
        # Adam state is reset on every call.
        self.opt_m, self.opt_v = 0, 0
        # sigma0 is constant during the inner loop, so its inverse is built only once.
        inv_SigMa = np.diag(1.0/(sigma0))
        U_temp = U0.copy()
        for step in range(1, self.inner_max_itr + 1):
            W = P.T @ U_temp @ P
            Gobj = -(self.cov @ (self.Id - W) @ inv_SigMa) + self.lambda1 * np.sign(W)
            grad = self._adam_update(Gobj, step, beta_1, beta_2)
            # Map the W-space step back into U-space, then re-apply the triangular mask.
            U_temp -= P @ (lr * grad) @ P.T
            U_temp *= self.U
        return (P.T @ U_temp @ P), U_temp

    def _init_W(self, topo, sigma0):
        """Fit W and sigma for the fixed order ``topo`` by alternating updates.

        Each outer iteration calls ``_solve_W`` with sigma fixed. It then sets
        sigma = sqrt(diag((I - W)^T C (I - W))). Every 20 outer iterations the
        objective is evaluated, and the loop stops once its relative change is
        <= 1e-6.

        Returns:
            ``(W, sigma)``. If ``self.outer_max_itr < 1``, this is the all-zero W
            and a copy of ``sigma0``.
        """
        P = my_permutation(topo)
        U_temp = np.zeros((self.d, self.d))
        # Defined up front so the return is valid even if the loop never runs.
        W = P.T @ U_temp @ P
        sigma = sigma0.copy()
        obj_prev = 1e16
        for outer in range(1, self.outer_max_itr + 1):
            W, U_temp = self._solve_W(U_temp, P, sigma)
            dif = self.Id - W
            rhs = self.cov @ dif
            sigma = np.sqrt(np.diag(dif.T @ rhs))
            if outer % 20 == 0:
                obj_new, _ = self.score(W, sigma)
                if np.abs((obj_prev - obj_new) / obj_prev) <= 1e-6:
                    break
                obj_prev = obj_new
        return W, sigma

    def _update_topo_linear(self, topo, idx, sigma0, opt=1):
        """Fit the model on the order produced by applying move ``idx`` to ``topo``.

        ``topo`` is shallow-copied before being handed to ``create_new_topo``.

        Returns:
            ``(W, new_topo, sigma)`` for the candidate order.
        """
        topo0 = create_new_topo(topo=copy(topo), idx=idx, opt=opt)
        W, sigma = self._init_W(topo0, sigma0)
        return W, topo0, sigma

    def _h(self, W):
        """Return ``(h, G_h)`` for h(W) = -log|det(sI - |W|)| + d log(s) with s = 1.

        ``G_h`` is ``inv(sI - |W|).T``.
        """
        I = np.eye(self.d)
        s = 1
        M = s * I - np.abs(W)
        h = - np.linalg.slogdet(M)[1] + self.d * np.log(s)
        G_h = slin.inv(M).T

        return h, G_h

    def _fit_current_topo(self, size_small, size_large):
        """Refit on ``self.topo`` (warm-started from ``self.sigma``) and propose moves.

        Updates ``self.Z``, ``self.W`` and ``self.sigma``.

        Returns:
            ``(loss, idx_set_small, idx_set_large)`` as given by ``score`` and ``find_idx_set``.
        """
        self.Z = create_Z(self.topo)
        self.W, self.sigma = self._init_W(self.topo, sigma0=self.sigma)
        loss, G_loss = self.score(W=self.W, sigma=self.sigma)
        _, G_h = self._h(W=self.W)
        idx_set_small, idx_set_large = find_idx_set(G_h=G_h, G_loss=G_loss, Z=self.Z, size_small=size_small,
                                                    size_large=size_large)
        return loss, idx_set_small, idx_set_large

    def _candidate_losses(self, idx_set):
        """Score each move in ``idx_set`` applied to ``self.topo``; returns a float array."""
        loss_collections = np.zeros(len(idx_set))
        for i, idx in enumerate(idx_set):
            W_c, _, sigma_c = self._update_topo_linear(topo=self.topo, idx=idx, sigma0=self.sigma)
            loss_c, _ = self.score(W=W_c, sigma=sigma_c)
            loss_collections[i] = loss_c
        return loss_collections

    def fit(self, X, topo: list, no_large_search, size_small, size_large, inner_max_itr=10, outer_max_itr=100,
            lambda1=0.05):
        """Learn W, per-node sigma and a topological order by greedy local search.

        W and sigma are fitted for the current order. Every move proposed by
        ``find_idx_set`` (small set first) is then refitted and scored. If the
        best small-set move lowers the loss, the order is updated via
        ``create_new_topo_greedy``. Otherwise, the moves found only in the large
        set are tried; at most ``no_large_search`` of these fall-backs may
        succeed. The search ends when no move improves the loss, the large-search
        budget is spent, or no moves are proposed.

        Args:
            X: (n, d) data matrix; ``X.T @ X / n`` is used as the covariance.
            topo: initial order; must be a ``list``. It is copied, so the
                caller's list is not modified here.
            no_large_search: maximum number of successful large-set searches.
            size_small, size_large: candidate-set sizes passed to ``find_idx_set``.
            inner_max_itr: Adam steps per ``_solve_W`` call.
            outer_max_itr: maximum W/sigma alternations per fit.
            lambda1: L1 penalty weight (default 0.05, the previous hard-coded value).

        Returns:
            ``(W, sigma, topo, loss)`` for the final order; ``sigma`` has length d.

        Raises:
            TypeError: if ``topo`` is not a list.
        """
        if not isinstance(topo, list):
            raise TypeError(f"topo must be a list, got {type(topo).__name__}")
        self.n, self.d = X.shape
        self.inner_max_itr, self.outer_max_itr = inner_max_itr, outer_max_itr
        self.X = X
        self.Id = np.eye(self.d).astype(np.float64)
        self.U = np.triu(np.ones((self.d, self.d)), k=1)
        self.cov = X.T @ X / float(self.n)
        self.lambda1 = lambda1
        # Initial per-node noise scales: the RMS of each column of X.
        self.sigma = (np.linalg.norm(self.X, axis=0) / np.sqrt(self.n)).astype(np.float64)
        # Private copy so order updates never alias the caller's list.
        self.topo = list(topo)
        iter_count = 0
        large_space_used = 0

        loss, idx_set_small, idx_set_large = self._fit_current_topo(size_small, size_large)
        idx_set = list(idx_set_small)
        while idx_set:
            loss_collections = self._candidate_losses(idx_set)
            if np.min(loss_collections) < loss:
                # The small candidate set already contains an improving move.
                self.topo = create_new_topo_greedy(self.topo, loss_collections, idx_set, loss)
            elif large_space_used < no_large_search:
                # Fall back to moves that are only in the large candidate set.
                idx_set = list(idx_set_large.difference(idx_set_small))
                loss_collections = self._candidate_losses(idx_set)
                # np.any over an empty array is False, so an empty fall-back set ends the search.
                if np.any(loss_collections < loss):
                    large_space_used += 1
                    self.topo = create_new_topo_greedy(self.topo, loss_collections, idx_set, loss)
                else:
                    print("Using larger search space, but we cannot find better loss")
                    break
            else:
                print("We reach the number of chances to search large space, it is {}".format(
                    no_large_search))
                break

            loss, idx_set_small, idx_set_large = self._fit_current_topo(size_small, size_large)
            idx_set = list(idx_set_small)
            iter_count += 1
        print("Max count:", iter_count)
        return self.W, self.sigma, self.topo, loss