import torch as ch
import warnings
import time
import os

import numpy as np

# reference from https://gist.github.com/mblondel/6f3b7aaad90606b98f71
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the scaled simplex ``{w : w >= 0, sum(w) = z}`` (NumPy).

    Sort-based algorithm: sort the entries in descending order, find the last
    position whose running threshold still leaves that entry positive, then
    shift every entry by that threshold and clip at zero.

    Args:
        v: 1-D array to project.
        z: target sum of the projected vector.

    Returns:
        The projected array, same length as ``v``.
    """
    n = v.shape[0]
    desc = np.sort(v)[::-1]
    excess = np.cumsum(desc) - z
    positions = np.arange(1, n + 1)
    keep = desc - excess / positions > 0
    # Last index where the sorted entry survives the shift.
    last = np.flatnonzero(keep)[-1]
    shift = excess[last] / float(positions[last])
    return np.maximum(v - shift, 0)

def proj_simplex(v,z=1):
    """Project ``v`` onto the scaled simplex ``{w : w >= 0, sum(w) = z}`` (PyTorch).

    ``v`` is flattened first, so the result is always 1-D. This is the same
    sort-based algorithm as ``projection_simplex_sort``.

    Args:
        v: tensor of any shape (contiguous or not); it is flattened.
        z: target sum of the projected vector.

    Returns:
        1-D tensor on ``v``'s device with the projected values.
    """
    # reshape (not view) so non-contiguous inputs such as transposes work.
    v = v.reshape(-1)
    n_features = v.size(0)
    u, _ = ch.sort(v, descending=True)
    cssv = u.cumsum(dim=0) - z
    # Allocate on v's device; a CPU arange fails against CUDA tensors.
    ind = ch.arange(1, n_features + 1, device=v.device)
    cond = u - cssv / ind > 0
    # rho: size of the support of the projection.
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / (rho.float())
    w = (v - theta).clamp(min=0)
    return w

# Components of the active set simplex method

def lam(x,g):
    """Return the scalar inner product of ``x`` and ``g`` after flattening both.

    Presumably the multiplier estimate for the sum-to-one constraint in the
    active-set method below (NOTE(review): confirm intended meaning).
    """
    flat_x = x.view(-1)
    flat_g = g.view(-1)
    return ch.dot(flat_x, flat_g)

def mu(x,g):
    """Return ``g`` minus the scalar ``lam(x, g)``.

    Presumably the reduced gradient used to decide the active set
    (NOTE(review): confirm intended meaning).
    """
    shift = lam(x, g)
    return g - shift

def active_set(x,g,eps):
    """Boolean mask of active coordinates: ``x[i] <= eps * mu(x, g)[i]``.

    In ``as_simplex`` the active coordinates are the ones whose mass is moved
    away and which are pinned at zero during the step.
    """
    threshold = eps * mu(x, g)
    return ch.le(x, threshold)

def armijo(x,d,f,ip,amax=1,delta=0.9, gamma=0.9, maxiters=100):
    """Backtracking (Armijo) line search along direction ``d``.

    Starting from ``amax``, the step ``a`` is multiplied by ``delta`` until the
    sufficient-decrease condition ``f(x + a*d) <= f(x) + gamma*a*ip`` holds, or
    until ``maxiters`` reductions have been made. ``ip`` is expected to be the
    directional derivative ``<grad f(x), d>``, which is negative for a descent direction.

    An ``Exception`` raised while evaluating a trial point counts as a failed
    condition, so infeasible trial points are simply shrunk away. Interrupts
    (e.g. KeyboardInterrupt) are no longer swallowed.

    Returns:
        The accepted step size, ``amax * delta**k`` for some ``0 <= k <= maxiters``.

    Raises:
        Whatever ``f(x)`` raises at the base point.
    """
    # The base value does not depend on the step; evaluate it once.
    fx = f(x)

    def rejected(step):
        try:
            return bool(f(x + step * d) > fx + gamma * step * ip)
        except Exception:
            # Best-effort: a trial point f cannot evaluate is treated as too long.
            return True

    a = amax
    for _ in range(maxiters):
        if not rejected(a):
            break
        a = delta * a
    return a

# Active set solver
def as_simplex(x0,f,grad_fn=None,maxiters=100,eps=0.1,lr=0.1,tol=1e-6,beta=0.9,
               soft_A=None, verbose=None, checkpoint=None, grad_eps=0.05, restart=False):
    """Minimize ``f`` over the probability simplex with a smoothed active-set method.

    Each iteration performs the following steps:

    1. Computes the active set via ``active_set`` and smooths it with an
       exponential moving average (``soft_A``); coordinates with
       ``soft_A > 0.5`` are active.
    2. Takes a corrective step that moves all active mass onto the free
       coordinate with the smallest gradient entry.
    3. Takes a projected gradient step over the free coordinates, sized by an
       Armijo line search.

    Args:
        x0: starting point, a 1-D tensor expected to lie on the simplex.
        f: objective called as ``f(x)``; must return a scalar tensor.
        grad_fn: optional ``grad_fn(f, x)`` returning the gradient; autograd is used when None.
        maxiters: total iteration budget, including iterations restored from a checkpoint.
        eps: threshold factor passed to ``active_set``.
        lr: gradient step size applied before projection.
        tol: stop once the L1 norm of an accepted step is below this (None disables).
        beta: EMA factor for the soft active set.
        soft_A: optional initial soft active-set weights (default 0.5 everywhere).
            Ignored when resuming from a checkpoint.
        verbose: print progress every ``verbose`` iterations (falsy disables).
        checkpoint: optional path. The history is saved there after every
            iteration and resumed from it when the file exists.
        grad_eps: currently unused.
        restart: if True, ignore an existing checkpoint and start from ``x0``.

    Returns:
        ``(history, x)``. ``history`` is a list of ``{'x', 'A'}`` CPU snapshots;
        ``x`` is the final iterate.
    """
    start_time = time.time()
    x = x0.clone().detach()
    x.requires_grad = True

    if checkpoint is not None and os.path.isfile(checkpoint) and not restart:
        history = ch.load(checkpoint)
        # Snapshots are stored on the CPU; bring them back to x's device/dtype.
        x.data = history[-1]['x'].to(x)
        soft_A = history[-1]['A'].to(x)
    else:
        # Honour a caller-supplied soft_A instead of silently discarding it.
        if soft_A is None:
            soft_A = ch.ones_like(x) * 0.5
        else:
            soft_A = soft_A.detach().to(x).clone()
        history = [{'x': x0.detach().clone().cpu(), 'A': soft_A.clone().cpu()}]

    start_idx = len(history) - 1
    for i in range(start_idx, maxiters):
        if _as_simplex_collapsed(x):
            print("Warning: converged to zero, reinitialization")
            # In-place reset keeps x a grad-requiring leaf on the same device.
            _as_simplex_reinit(x)
            soft_A = ch.ones_like(x) * 0.5

        loss = f(x)
        if verbose and i % verbose == 0:
            sp = (x > 0).float().mean().item()
            vals, inds = ch.topk(x, min(5, x.size(0)))
            top5 = ', '.join([f'{j.item()}: {v.item():.3f}' for j, v in zip(inds, vals)])
            print(f'Iter {i}: loss {loss.item()} sp {sp} top5 ({top5}) time {time.time() - start_time}')
        start_time = time.time()
        if grad_fn is None:
            g = ch.autograd.grad([loss], [x])[0]
        else:
            g = grad_fn(f, x)

        # Direction bookkeeping needs no autograd graph.
        with ch.no_grad():
            # Smoothed active set: A marks coordinates pinned at zero, N the free ones.
            A0 = active_set(x, g, eps)
            soft_A = soft_A * beta + A0 * (1 - beta)
            A = soft_A > 0.5
            N = ~A
            # Free coordinate with the smallest gradient entry (device-agnostic lookup).
            free = N.nonzero(as_tuple=True)[0]
            j = free[g[N].argmin()].item()

            # Corrective step: keep the free coordinates and move all of the
            # active coordinates' mass onto j.
            xt = ch.zeros_like(x)
            xt[N] = x[N]
            xt[j] = x[j] + x[A].sum()

            # Feasible direction: gradient step on the free coordinates,
            # projected back onto the simplex; active coordinates stay at zero.
            xt_ = xt - lr * g
            xt_[A] = 0
            d = proj_simplex(xt_) - xt
            d[A] = 0
            ip = g.dot(d)

        alpha = armijo(xt, d, f, ip)
        step = alpha * d
        with ch.no_grad():
            x.data = xt + step
        # Stop once the accepted step is negligibly small.
        if tol is not None and step.norm(p=1) < tol:
            return history, x

        history.append({'x': x.detach().clone().cpu(), 'A': soft_A.clone().cpu()})
        if checkpoint is not None:
            ch.save(history, checkpoint)
    warnings.warn("AS-Simplex did not reach a stationary point")
    return history, x


def _as_simplex_collapsed(x):
    """Heuristic used by ``as_simplex`` to detect that ``x`` has lost its mass."""
    # Work in units of 1e-3 so the thresholds below are integer counts.
    counts = (x.detach() * 1000).round().long()
    single_small_spike = (counts < 50).all() and (counts == 0).sum() == (x.size(0) - 1)
    return bool(single_small_spike or (counts < 10).all())


def _as_simplex_reinit(x, k_max=25):
    """Reset ``x`` in place to mass ``1/k`` on ``k = min(k_max, len(x))`` random coordinates."""
    n = x.size(0)
    # With fewer than k_max coordinates a fixed 1/k_max would leave the simplex.
    k = min(k_max, n)
    fresh = ch.zeros_like(x)
    # Draw the permutation on the CPU, then move the indices to x's device.
    fresh[ch.randperm(n)[:k].to(x.device)] = 1.0 / k
    x.data = fresh


# Projected gradient descent solver
def pgd(x0,f,grad_fn=None,maxiters=100,eps=0.1,lr=0.1,tol=1e-6,beta=0.9,
               soft_A=None, verbose=None, checkpoint=None):
    """Minimize ``f`` over the probability simplex by projected gradient descent.

    Each iteration works in two stages:

    1. Takes a gradient step of size ``lr`` and projects it onto the simplex
       with ``proj_simplex``, which yields a feasible direction ``d``.
    2. Moves along ``d`` by a step chosen by ``armijo``.

    Args:
        x0: starting point, a 1-D tensor.
        f: objective called as ``f(x)``; must return a scalar tensor.
        grad_fn: optional ``grad_fn(f, x)`` returning the gradient; autograd is used when None.
        maxiters: total iteration budget, including iterations restored from a checkpoint.
        eps, beta, soft_A: unused; accepted for signature parity with ``as_simplex``.
        lr: gradient step size applied before projection.
        tol: stop once the L1 norm of an accepted step is below this (None disables).
        verbose: print progress every ``verbose`` iterations (falsy disables).
        checkpoint: optional path. The history is saved there after every
            iteration and resumed from it when the file exists.

    Returns:
        ``(history, x)``. ``history`` is a list of ``{'x'}`` CPU snapshots;
        ``x`` is the final iterate.
    """
    start_time = time.time()
    x = x0.clone().detach()
    x.requires_grad = True

    if checkpoint is not None and os.path.isfile(checkpoint):
        history = ch.load(checkpoint)
        # Snapshots are stored on the CPU; bring them back to x's device/dtype.
        x.data = history[-1]['x'].to(x)
        print(f"Resuming from checkpoint (history: {len(history)})")
    else:
        history = [{'x': x0.detach().clone().cpu()}]

    start_idx = len(history) - 1
    for i in range(start_idx, maxiters):
        loss = f(x)
        if verbose and i % verbose == 0:
            sp = (x > 0).float().mean().item()
            vals, inds = ch.topk(x, min(5, x.size(0)))
            top5 = ', '.join([f'{j.item()}: {v.item():.3f}' for j, v in zip(inds, vals)])
            print(f'Iter {i}: loss {loss.item()} sp {sp} top5 ({top5}) time {time.time() - start_time}')
        start_time = time.time()
        if grad_fn is None:
            g = ch.autograd.grad([loss], [x])[0]
        else:
            g = grad_fn(f, x)

        # Feasible direction towards the projected gradient step; no graph needed.
        with ch.no_grad():
            d = proj_simplex(x - lr * g) - x
            ip = g.dot(d)

        # Line search on a detached copy so trial evaluations build no graph through x.
        alpha = armijo(x.detach(), d, f, ip)
        step = alpha * d
        with ch.no_grad():
            x.add_(step)
        # Stop once the accepted step is negligibly small.
        if tol is not None and step.norm(p=1) < tol:
            return history, x

        history.append({'x': x.detach().clone().cpu()})
        if checkpoint is not None:
            ch.save(history, checkpoint)
    warnings.warn("PGD did not reach a stationary point")
    return history, x

if __name__ == "__main__":
    nf = 10

    # Check the torch simplex projection against the NumPy reference.
    for _ in range(100):
        x = ch.randn(nf)
        diff = np.abs(proj_simplex(x).numpy() - projection_simplex_sort(x.numpy()))
        assert (diff < 1e-4).all(), f"proj_simplex mismatch: max abs diff {diff.max()}"

    # Toy problem: random start on the simplex, objective <x, w> + ||x||_2.
    x = ch.randn(nf).abs()
    x = x / x.sum()
    w = ch.randn(nf)
    w = -(w > 0).float()

    def f(x):
        return x.dot(w) + x.norm(p=2)

    history, x_opt = as_simplex(x, f)
    print(f"recorded iterations: {len(history) - 1}")
    print(x_opt.detach())
