from .utils import *
from . import affine_set
from . import halfspace
from . import lp_balls
from . import set_functions
from . import spec_proj_ops
from . import rounding
import torch

# Constraint name -> operator name for the constraints that are handled entirely
# by ``spec_proj_ops.project_specset_multiple`` (eigendecomposition based).
_SPECTRAL_MULTI_OPS = {
    'Spectral-radius': 'eigmax_adj_ub',
    'Nconn_atleast': 'nzeros_lap_lb',
    'Nconn_atmost': 'nzeros_lap_ub',
    'Eigenvalue-Box': 'eigi_lap_bound',
    'Cheeger-bound': 'cheeger_bound',
    'Diameter': 'diameter',
    'Eigensum': 'eigsum_lap_bound',
}


def _load_center (spec, adjs):
    """Return the reference adjacency ``adj0`` that the L1/L2 balls are centred on.

    ``spec == 'zeros'`` yields an all-zero tensor shaped like ``adjs``; any other
    value is passed to ``torch.load`` and mapped onto ``adjs.device`` so the
    subsequent ``adjs - adj0`` does not fail on a CPU/GPU mismatch.
    """
    if spec == 'zeros':
        return torch.zeros_like(adjs)
    return torch.load(spec, map_location=adjs.device)


def _symmetrize_upper (upper):
    """Mirror a batch of upper-triangular (diagonal included) matrices.

    ``U + U^T`` alone would count the diagonal twice, so it is subtracted once.
    """
    diag = torch.diagonal(upper, dim1=-2, dim2=-1)
    return upper + torch.transpose(upper, -2, -1) - torch.diag_embed(diag)


def project (xs, adjs, constraint_config, zero_tol=1e-4):
    """Project a batch of graphs ``(xs, adjs)`` onto the configured constraint set.

    Args:
        xs: node-feature tensor. Returned unchanged by every constraint that
            only acts on the adjacency.
        adjs: batched adjacency tensor. The indexing below (``adjs[:, r, c]``)
            requires shape ``(batch, n, n)``.
        constraint_config: object exposing ``.constraint`` (the constraint
            name) and ``.params`` (constraint-specific parameters).
        zero_tol: in the ``'Num-triangles'`` branch, eigenvalues of the
            modified ``A^3`` with magnitude ``<= zero_tol`` are snapped to 0.

    Returns:
        ``(proj_xs, proj_adjs)``.

    Raises:
        NotImplementedError: if ``constraint_config.constraint`` is unknown.
    """
    constraint = constraint_config.constraint

    if constraint == 'L1-adj':
        # params[0]: 'zeros' or a path to the centre adj0 (indexed per graph,
        # so expected batched like adjs); params[1]: budget handed to
        # lp_balls.l1_ball_vec (presumably the ball radius).
        adj0 = _load_center(constraint_config.params[0], adjs)
        budget = constraint_config.params[1]
        # Undirected graphs assumed: only the upper triangle (incl. diagonal) is free.
        n = adjs.shape[-1]
        row_inds, col_inds = torch.triu_indices(n, n, device=adjs.device)
        adj_vec = (adjs - adj0)[:, row_inds, col_inds]
        proj_adjs = torch.zeros_like(adjs)
        # l1_ball_vec works on a single vector, hence the per-graph loop.
        for i in range(adjs.shape[0]):
            adj_proj = lp_balls.l1_ball_vec(adj_vec[i], budget)
            proj_adjs[i, row_inds, col_inds] = adj0[i, row_inds, col_inds] + adj_proj
        # adj0 is already folded into the upper triangle above; it must not be
        # added a second time (the previous version did, unlike 'L2-adj').
        return xs, _symmetrize_upper(proj_adjs)

    if constraint == 'L2-adj':
        # Same layout as 'L1-adj', but l2_ball_vecs projects the whole batch at once.
        adj0 = _load_center(constraint_config.params[0], adjs)
        budget = constraint_config.params[1]
        n = adjs.shape[-1]
        row_inds, col_inds = torch.triu_indices(n, n, device=adjs.device)
        adj_vec = (adjs - adj0)[:, row_inds, col_inds]
        adj_proj = lp_balls.l2_ball_vecs(adj_vec, budget)
        proj_adjs = torch.zeros_like(adjs)
        proj_adjs[:, row_inds, col_inds] = adj0[:, row_inds, col_inds] + adj_proj
        return xs, _symmetrize_upper(proj_adjs)

    if constraint in _SPECTRAL_MULTI_OPS:
        proj_adjs = spec_proj_ops.project_specset_multiple(
            adjs, _SPECTRAL_MULTI_OPS[constraint], params=constraint_config.params)
        return xs, proj_adjs

    if constraint == 'Rank':
        # project_specset handles one matrix at a time.
        proj_adjs = torch.stack([
            spec_proj_ops.project_specset(adj, 'rank_lap_ub', params=constraint_config.params)
            for adj in adjs
        ])
        return xs, proj_adjs

    if constraint == 'Num-triangles':
        # For a simple undirected graph, trace(A^3) = 6 * (#triangles). The
        # diagonal of A^3 is pushed through l1_ball_vec with budget 6 * params[0].
        # A is then recovered as the real cube root of the modified A^3 via its
        # eigendecomposition.
        adjs3 = torch.matrix_power(adjs, 3)
        adjs3_diag = torch.diagonal(adjs3, dim1=1, dim2=2)  # view into adjs3
        diag_idx = torch.arange(adjs3_diag.shape[1], device=adjs3.device)
        budget = 6 * constraint_config.params[0]
        for i in range(adjs3_diag.shape[0]):
            new_diag = lp_balls.l1_ball_vec(adjs3_diag[i], budget)
            adjs3[i, diag_idx, diag_idx] = new_diag
        eigs3, Qs = torch.linalg.eigh(adjs3)
        eigs3[torch.abs(eigs3) <= zero_tol] = 0
        # Sign-preserving real cube root of each eigenvalue.
        proj_eigs = torch.sign(eigs3) * torch.pow(torch.abs(eigs3), 1 / 3)
        proj_adjs = Qs @ torch.diag_embed(proj_eigs) @ torch.transpose(Qs, 1, 2)
        return xs, proj_adjs

    # The helpers below return (adjs, xs); this function returns (xs, adjs).
    if constraint == 'Valency':
        valencies = torch.tensor(constraint_config.params[0], dtype=xs.dtype, device=xs.device)
        proj_adjs, proj_xs = affine_set.valency_projection_multiple(adjs, xs, valencies)
        return proj_xs, proj_adjs

    if constraint == 'Atom-Count':
        atomCounts = torch.tensor(constraint_config.params[0], dtype=xs.dtype, device=xs.device)
        proj_adjs, proj_xs = affine_set.atomCount_projection_multiple(adjs, xs, atomCounts)
        return proj_xs, proj_adjs

    if constraint == 'Mol-Weight':
        atomWeights = torch.tensor(constraint_config.params[0], dtype=xs.dtype, device=xs.device)
        max_weight = torch.tensor(constraint_config.params[1], dtype=float)
        proj_adjs, proj_xs = halfspace.molwt_projection_multiple(adjs, xs, weights=atomWeights,
                                                                max_weight=max_weight)
        return proj_xs, proj_adjs

    if constraint == 'Regression':
        c_theta = torch.load(constraint_config.params[0])
        b_theta = torch.load(constraint_config.params[1])
        proj_adjs, proj_xs = halfspace.reg_projection_multiple(adjs, xs, c_theta=c_theta, b_theta=b_theta)
        return proj_xs, proj_adjs

    raise NotImplementedError(f"{constraint} not supported")

def drift_transformProject (xs, adjs, constraint_config, dataset="community_small", *, verbose=False):
    """Drifted projection computed through a "transform, then project" fixed-point loop.

    Starting from a copy of ``(xs, adjs)``, each iteration applies
    ``implicitConstr_transform`` (defined outside this file) and then ``project``.
    The loop stops after 100 iterations, or as soon as the L1 change between
    successive iterates grows, or once ``project`` moves the transformed
    iterate by at most 1e-3 entrywise. The inputs are then moved a fraction
    ``gamma`` of the way towards the final projected iterate.

    Args:
        xs, adjs: batched node features / adjacency matrices.
        constraint_config: config with ``.constraint``/``.params`` (see
            ``project``) and optionally ``.method.gamma``, the drift factor
            (defaults to 1 when no ``method`` entry exists).
        dataset: forwarded to ``implicitConstr_transform``.
        verbose: print per-iteration diagnostics. These were previously
            printed unconditionally, including an O(n^3) triangle count.

    Returns:
        ``(xs + gamma * (proj_xs - xs), adjs + gamma * (proj_adjs - adjs))``;
        the inputs themselves when ``method.gamma == 0``.
    """
    drift = constraint_config.method.gamma if "method" in constraint_config else 1
    if "method" in constraint_config and drift == 0:
        return xs, adjs
    # Clone so that the transform cannot alias the caller's tensors.
    xs_prev, adjs_prev = xs.clone(), adjs.clone()
    xs_diff_prev, adjs_diff_prev = float('inf'), float('inf')
    xtol, adjtol = 1e-3, 1e-3
    maxiters = 100
    for niter in range(1, maxiters + 1):
        T_xs, T_adjs = implicitConstr_transform(xs_prev, adjs_prev, dataset)
        projT_xs, projT_adjs = project(T_xs, T_adjs, constraint_config)
        # Batch-mean of the entrywise L1 change between successive iterates.
        xs_diff = (projT_xs - xs_prev).abs().sum(dim=-1).sum(dim=-1).mean()
        adjs_diff = (projT_adjs - adjs_prev).abs().sum(dim=-1).sum(dim=-1).mean()
        if verbose:
            print(niter, torch.sum((projT_adjs - T_adjs).abs()))
            print(niter, xs_diff, adjs_diff)
        # NOTE(review): on divergence the *current* iterate is returned rather
        # than the previous (smaller-step) one; kept as-is, confirm intent.
        if (xs_diff > xs_diff_prev) or (adjs_diff > adjs_diff_prev):
            break
        # project() barely moved the transformed iterate: treat as converged.
        if torch.all((projT_xs - T_xs).abs() <= xtol) and torch.all((projT_adjs - T_adjs).abs() <= adjtol):
            if verbose:
                print(satisfies(T_adjs, T_xs, constraint_config).sum())
                print(satisfies(projT_adjs, projT_xs, constraint_config).sum())
            break
        xs_prev, adjs_prev = projT_xs, projT_adjs
        xs_diff_prev, adjs_diff_prev = xs_diff, adjs_diff
    if verbose:
        # trace(A^3) / 6 per graph (the triangle count for simple graphs).
        print(torch.diagonal(torch.matrix_power(T_adjs, 3), dim1=1, dim2=2).sum(dim=1) / 6)
    return xs + drift * (projT_xs - xs), adjs + drift * (projT_adjs - adjs)

def drifted_project (xs, adjs, constraint_config):
    """Move ``(xs, adjs)`` a fraction ``gamma`` of the way towards their projection.

    ``gamma`` is read from ``constraint_config.method.gamma`` and defaults to 1
    (plain projected diffusion) when the config has no ``method`` entry. An
    explicit ``gamma == 0`` disables projection and returns the inputs as-is.
    """
    has_method = "method" in constraint_config
    gamma = constraint_config.method.gamma if has_method else 1
    if has_method and gamma == 0:
        return xs, adjs

    target_xs, target_adjs = project (xs, adjs, constraint_config)
    # Convex step from the current point towards the projected point.
    stepped_xs = xs + gamma * (target_xs - xs)
    stepped_adjs = adjs + gamma * (target_adjs - adjs)
    return stepped_xs, stepped_adjs

