import torch
import numpy as np
from projop.lp_balls import *
from projop.utils import *
from projop.halfspace import *

def seq_clamp(v, bound=(0, np.inf), index=0, order='asc'):
    """Clamp ``v[index]`` into ``bound`` and make every later entry non-decreasing.

    The 1-D tensor ``v`` is walked left to right:

    * entries before ``index`` are copied unchanged,
    * ``v[index]`` is clamped into ``[bound[0], bound[1]]``,
    * every later entry is raised to at least the (already processed) entry
      immediately before it.

    Args:
        v: 1-D tensor.
        bound: ``(lower, upper)`` limits applied to ``v[index]``. Only indexed,
            so any 2-sequence (list or tuple) works.
        index: position, after the optional flip, whose value is bounded.
        order: ``'desc'`` reverses ``v`` before processing; any other value
            processes ``v`` as given.

    Returns:
        A new 1-D tensor with the same number of entries as ``v``. For an
        empty ``v``, an empty tensor is returned.
        NOTE(review): with ``order='desc'`` the result stays in the flipped
        (reversed) order and is not flipped back. Confirm that callers expect this.
    """
    if order == 'desc':
        v = torch.flip(v, dims=[0])
    if v.shape[0] == 0:
        # torch.cat rejects an empty list of pieces, so short-circuit here.
        return v.clone()

    pieces = []
    for i, value in enumerate(v):
        if i < index:
            piece = value
        elif i == index:
            piece = torch.clamp(value, min=bound[0], max=bound[1])
        else:
            # .item() turns the previous entry into a plain Python bound, so no
            # gradient flows into the earlier entry through this clamp.
            piece = torch.clamp(value, min=pieces[-1].item())
        pieces.append(piece.ravel())
    return torch.cat(pieces)


def project_specset(A, phi_name, params=(), tol=1e-5):
    """Project a single graph onto a spectral constraint set.

    The spectrum is computed in ascending order from one of two matrices:

    * the combinatorial Laplacian ``L = diag(A.sum(1)) - A``, when ``phi_name``
      contains ``'lap'`` or ``'cheeger'``;
    * ``A`` itself, when ``phi_name`` contains ``'adj'``.

    The spectrum is then projected according to ``phi_name``, and the matrix
    is rebuilt from the projected spectrum.

    Args:
        A: square adjacency matrix (2-D tensor). The Laplacian path uses an
            SVD, which yields eigenvalues only if ``L`` is symmetric PSD.
            Presumably ``A`` is symmetric with non-negative weights; confirm
            against callers.
        phi_name: constraint name. It selects both the spectrum and the
            projection branch.
        params: constraint parameters, indexed according to ``phi_name``.
        tol: spectrum entries with ``|value| <= tol`` are snapped to zero
            before and after the projection.

    Returns:
        The reconstructed adjacency matrix. On the Laplacian path it is
        ``diag(diag(L')) - L'``, so its diagonal is zero.

    Raises:
        ValueError: if ``phi_name`` selects no spectrum or no projection.
    """
    is_lap = 'lap' in phi_name or 'cheeger' in phi_name
    if is_lap:
        L = torch.diag(torch.sum(A, dim=1)) - A
        # Singular values come back in descending order; flip them to
        # ascending so every branch sees the same ordering as eigh's output.
        U, eigs, Vh = torch.linalg.svd(L)
        eigs = torch.flip(eigs, [0])
    elif 'adj' in phi_name:
        eigs, Q = torch.linalg.eigh(A)  # ascending order
    else:
        raise ValueError(f"phi_name must contain 'lap', 'cheeger' or 'adj', got {phi_name!r}")
    eigs[torch.abs(eigs) <= tol] = 0

    if phi_name == "rank_lap_ub":
        proj_eigs = l0_ball_vec(eigs, bound=params[0])
    elif "nzeros" in phi_name and "lb" in phi_name:
        # At least params[0] zero eigenvalues: the spectrum is sorted
        # ascending, so zero its first params[0] entries. (A hyperplane
        # "sum of first k == 0" projection does not force zeros; it only
        # shifts those entries by their mean.)
        proj_eigs = eigs.clone()
        proj_eigs[:params[0]] = 0
    elif phi_name == "eig1_lap_ub":
        proj_eigs = linf_ball_vec(eigs, bound=params[0])
    elif phi_name == "eigsum_lap_ub":
        proj_eigs = l1_ball_vec(eigs, bound=params[0])
    elif phi_name == "eigi_lap_bound":
        proj_eigs = seq_clamp(eigs, bound=[params[1], params[2]], index=params[0], order='asc')
    elif phi_name == "eig2m_lap_bound":
        proj_eigs = seq_clamp(eigs, bound=[params[0], params[1]], index=1, order='asc')
    elif phi_name == "eig2m_lap_ubound":
        proj_eigs = seq_clamp(eigs, bound=[0, params[0]], index=1, order='asc')
    elif phi_name == "eig2m_lap_lbound":
        proj_eigs = seq_clamp(eigs, bound=[params[0], np.inf], index=1, order='asc')
    elif "cheeger_bound" in phi_name:
        # params[0] is the Cheeger constant h; the bounds used are h**2/2 and 2*h.
        cheeger_chi = params[0]
        lbound, ubound = cheeger_chi**2/2, 2*cheeger_chi
        positive = eigs > 0
        # Weights n, n-1, ..., 1 make argmax return the FIRST positive entry
        # (index 0 if none is positive).
        weights = torch.arange(start=eigs.shape[0], end=0, step=-1,
                               device=A.device, dtype=A.dtype)
        first = torch.argmax(positive * weights)
        later = positive.clone()
        later[first] = False
        proj_eigs = eigs.clone()
        proj_eigs[first] = torch.clamp(eigs[first], min=lbound, max=ubound)
        # NOTE(review): project_specset_multiple clamps these with max=ubound
        # and uses h**2/(2*dmax) as the lower bound; confirm which is intended.
        proj_eigs[later] = torch.clamp(eigs[later], min=ubound)
    else:
        raise ValueError(f"unsupported phi_name: {phi_name!r}")

    proj_eigs[torch.abs(proj_eigs) <= tol] = 0
    if is_lap:
        # Undo the ascending flip so the values line up with U / Vh again.
        L = U @ torch.diag(torch.flip(proj_eigs, [0])) @ Vh
        # Dropping the diagonal of L' recovers the (negated) off-diagonal weights.
        return torch.diag(torch.diag(L)) - L
    return Q @ torch.diag(proj_eigs) @ Q.T


def project_specset_multiple(As, phi_name, params=(), tol=1e-4):
    """Project a batch of graphs onto a spectral constraint set.

    This is the batched counterpart of ``project_specset``. The spectrum is
    computed for every matrix in the batch, sorted in ascending order along
    dim 1, and projected according to ``phi_name``. The batch of matrices is
    then rebuilt from the projected spectra.

    Args:
        As: batch of square adjacency matrices, indexed as ``As[b, i, j]``.
        phi_name: constraint name. The substring selects the spectrum:

            * ``'lap'`` or ``'cheeger'``: the Laplacian
              ``diag(As.sum(1)) - As``, via SVD. This equals the eigenvalues
              only for symmetric PSD Laplacians (presumably symmetric,
              non-negative ``As``).
            * ``'adj'`` or ``'num_triangles'``: ``As`` itself, via ``eigh``.
            * ``'diameter'``: handled separately and returned directly.

        params: constraint parameters, indexed according to ``phi_name``.
        tol: spectrum entries with ``|value| <= tol`` are snapped to zero
            before and after the projection.

    Returns:
        Batch of reconstructed matrices with the same shape as ``As``.

    Raises:
        ValueError: if ``phi_name`` selects no spectrum or no projection.
    """
    is_lap = 'lap' in phi_name or 'cheeger' in phi_name
    is_adj = not is_lap and ('adj' in phi_name or 'num_triangles' in phi_name)
    if is_lap:
        Ls = torch.diag_embed(torch.sum(As, dim=1)) - As
        Us, eigs, Vhs = torch.linalg.svd(Ls)
        # Singular values come back in descending order; flip them to ascending.
        eigs = torch.flip(eigs, [1])
        eigs[torch.abs(eigs) <= tol] = 0
    elif is_adj:
        eigs, Qs = torch.linalg.eigh(As)  # ascending order
        eigs[torch.abs(eigs) <= tol] = 0
    elif "diameter" not in phi_name:
        raise ValueError(
            f"phi_name must contain 'lap', 'cheeger', 'adj', 'num_triangles' or 'diameter', got {phi_name!r}")

    if "rank" in phi_name:
        # l0-ball projection of each spectrum (rank bound)
        proj_eigs = torch.stack([l0_ball_vec(eig, bound=params[0]) for eig in eigs])
    elif "nzeros" in phi_name and "lb" in phi_name:
        # At least params[0] zero eigenvalues: zero the smallest params[0]
        # entries. (A hyperplane "sum of first k == 0" trick does not work,
        # because the spectrum can contain negative values.)
        proj_eigs = eigs.clone()
        proj_eigs[:, :params[0]] = 0
    elif "nzeros" in phi_name and "ub" in phi_name:
        # At most params[0] leading zeros: the leading entries beyond the
        # first params[0] are replaced by sorted uniform random values in
        # [0, first positive eigenvalue).
        bound = params[0]
        proj_eigs = torch.clone(eigs)
        # Weights n, ..., 1 make argmax return the FIRST positive entry per row.
        weights = torch.arange(start=eigs.shape[1], end=0, step=-1,
                               device=As.device, dtype=As.dtype)
        first_pos = torch.argmax((eigs > 0) * weights, 1)
        for i in range(eigs.shape[0]):
            first = int(first_pos[i])
            if first > bound:
                fill = torch.rand(first - bound, dtype=eigs.dtype, device=eigs.device) * eigs[i, first]
                proj_eigs[i, bound:first] = torch.sort(fill).values
    elif "eigmax" in phi_name and "ub" in phi_name:
        # Clip every eigenvalue's magnitude to params[0], keeping its sign.
        bound = params[0]
        proj_eigs = torch.where(torch.abs(eigs) <= bound, eigs, bound * torch.sign(eigs))
    elif "eigsum" in phi_name:
        # l1-ball projection applied to each spectrum of the batch separately.
        proj_eigs = torch.stack([l1_ball_vec(eig, bound=params[0]) for eig in eigs])
    elif "eigi" in phi_name:
        # Keep eigenvalue `index` in [lbound, ubound]. Earlier entries are
        # capped at lbound and later ones raised to ubound, so the ordering
        # survives. `index:index + 1` keeps the middle piece 2-D for cat.
        index = params[0]
        lbound, ubound = params[1:]
        proj_eigs = torch.cat((torch.clamp(eigs[:, :index], max=lbound),
                               torch.clamp(eigs[:, index:index + 1], min=lbound, max=ubound),
                               torch.clamp(eigs[:, index + 1:], min=ubound)),
                              dim=1)
    elif "cheeger_bound" in phi_name:
        # params = (Cheeger constant h, dmax); the first positive eigenvalue
        # is clamped to [h**2/(2*dmax), 2*h] and the later positive ones are
        # capped at 2*h.
        cheeger_chi, dmax = params[0], params[1]
        lbound, ubound = cheeger_chi**2/(2*dmax), 2*cheeger_chi
        positive = eigs > 0
        weights = torch.arange(start=eigs.shape[1], end=0, step=-1,
                               device=As.device, dtype=As.dtype)
        first_pos = torch.argmax(positive * weights, 1, keepdim=True)
        later_pos = positive.scatter(1, first_pos, False)
        rows = torch.arange(eigs.shape[0], device=eigs.device)
        cols = first_pos[:, 0]
        proj_eigs = eigs.clone()
        proj_eigs[rows, cols] = torch.clamp(eigs[rows, cols], min=lbound, max=ubound)
        proj_eigs[later_pos] = torch.clamp(eigs[later_pos], max=ubound)
    elif "num_triangles" in phi_name:
        # The cubed spectrum is projected with target 6*params[0]; presumably
        # because trace(A^3) = sum(eig^3) = 6 * #triangles for a simple
        # undirected graph. The real cube root is then taken.
        bound = params[0]
        proj_eigs3 = hyperplane_projection_multiple(torch.pow(eigs, 3), torch.ones_like(eigs[0]), 6*bound)
        proj_eigs = torch.sign(proj_eigs3) * torch.pow(torch.abs(proj_eigs3), 1/3)
    elif "diameter" in phi_name:
        # Works on (I + A)^D directly and returns without spectrum reconstruction.
        D = params[0]
        n = As.shape[1]
        Is = torch.eye(n, dtype=As.dtype, device=As.device).expand(As.shape[0], n, n)
        IA_d = torch.matrix_power(Is + As, D)
        # Near-zero entries are replaced by uniform random values of the same dtype.
        small = IA_d.abs() <= tol
        IA_d[small] = torch.rand(int(small.sum()), dtype=IA_d.dtype, device=IA_d.device)
        # NOTE(review): the random fill can make IA_d asymmetric while eigh only
        # reads one triangle, and negative eigenvalues give NaN under a
        # fractional power; confirm this is acceptable.
        d_eigs, d_Qs = torch.linalg.eigh(IA_d)
        return d_Qs @ torch.diag_embed(d_eigs ** (1 / D)) @ d_Qs.transpose(1, 2) - Is
    else:
        raise ValueError(f"unsupported phi_name: {phi_name!r}")

    proj_eigs[torch.abs(proj_eigs) <= tol] = 0
    if is_lap:
        # Undo the ascending flip so the values line up with Us / Vhs again.
        Ls = Us @ torch.diag_embed(torch.flip(proj_eigs, [1])) @ Vhs
        # Dropping the diagonal of each L' recovers the (negated) off-diagonal weights.
        return torch.diag_embed(torch.diagonal(Ls, dim1=-2, dim2=-1)) - Ls
    return Qs @ torch.diag_embed(proj_eigs) @ Qs.transpose(1, 2)