import torch
from cvxpy import Problem
from torch import nn
import sys
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

# Lower limit for positive numbers
LOWLIMIT = 1e-60

# Move tensor to GPU when possible
def w(x):
    """Return ``x`` on the GPU when CUDA is available, otherwise ``x`` unchanged.

    Args:
        x: Any object exposing a ``.cuda()`` method (tensors in this file).

    Returns:
        ``x.cuda()`` if ``torch.cuda.is_available()`` reports a device,
        else the very same object that was passed in.
    """
    return x.cuda() if torch.cuda.is_available() else x

# Truncated log function (avoid ln(0))
def log_th(x,eps=1e-8):
    """Natural logarithm of ``x`` with a floor at ``log(eps)`` (avoids ln(0)).

    Args:
        x: Tensor to take the logarithm of. The comparison ``x > eps`` is
            used as a plain boolean test, so a tensor argument has to hold a
            single element.
        eps: Threshold below which (or at which) the constant ``log(eps)``
            is returned instead of ``log(x)``.

    Returns:
        ``x.log()`` when ``x > eps``; otherwise a freshly created tensor
        holding ``log(eps)``, passed through ``w``.
    """
    if not (x > eps):
        # Covers x <= eps as well as any x for which the comparison is false.
        return w(torch.tensor(eps).log())
    return x.log()



""" L-BFGS algorithm

        Takes as input the gradient q, previous x_k and g_k concatenated in state.
        alp_m and beta_m are unused multipliers, H0 is the initial H matrix, lim is the past horizon limit

        Typical usage:

            q = LBFGS(g.clone(), self.state, self.alp_m, self.beta_m, self.H0, self.lim)


    """
def LBFGS(q, state, alp_m=1, beta_m=1, H0=1, lim=0):
    """Run the L-BFGS two-loop recursion on the gradient ``q``.

    Args:
        q: Current gradient as a 1-D tensor (``.dot`` is applied to it). The
            name is only rebound inside the function, never written in place.
        state: History tensor read as ``state[i, :, 0]`` (stored gradient
            ``i``) and ``state[i, :, 1]`` (stored step ``i``). Entry 0 is
            differenced against ``q``; entry ``i`` against entry ``i - 1``.
        alp_m: Multiplier applied to every alpha coefficient.
        beta_m: Multiplier applied to every beta coefficient.
        H0: Scale applied to ``q`` between the two loops.
        lim: Number of history entries to use; with 0 the result is
            ``q * H0``.

    Returns:
        The transformed vector. The sign is not flipped here.

    NOTE(review): ``LOWLIMIT`` is 1e-60 at module level, which is smaller
    than anything float32 can represent, so for float32 inputs the
    "too small" guards below probably never trigger -- TODO confirm.
    """
    steps = state[:, :, 1]
    inv_curv = []    # rho_k, one per history entry
    coeffs = []      # alpha_k, one per history entry
    grad_diffs = []  # y_k, one per history entry

    # First loop: walk the history starting at entry 0.
    for k in range(lim):
        newer_grad = q if k == 0 else state[k - 1, :, 0]
        y = newer_grad - state[k, :, 0]
        s = steps[k]
        grad_diffs.append(y)

        curvature = y.dot(s)
        if curvature.abs() < LOWLIMIT:  # workaround when y.s is too small to invert
            print("rho too small: " + str(curvature.abs()), file=sys.stderr)
            r = 1.0 / (LOWLIMIT * curvature.sign() + LOWLIMIT / 10)
        else:
            r = 1.0 / curvature
        inv_curv.append(r)

        a = r * s.dot(q) * alp_m
        coeffs.append(a)
        if r > 0:  # entries with non-positive rho leave q untouched
            q = q - a * y

    q = q * H0

    # Rescale with |y0.s0 / y0.y0| computed from history entry 0.
    if lim > 0:
        y0 = grad_diffs[0]
        s0 = steps[0]
        y0_sq = y0.dot(y0)
        if y0_sq < LOWLIMIT:  # workaround when y0.y0 is too small
            print("fact too small: " + str(y0_sq), file=sys.stderr)
            fact = y0.dot(s0) / (LOWLIMIT)
        else:
            fact = y0.dot(s0) / (y0_sq)

        if fact != fact:  # NaN is the only value that differs from itself
            fact = w(torch.tensor(1.0))
        if fact == 0:
            fact = w(torch.tensor(LOWLIMIT))
        q = q * fact.abs()

    # Second loop: walk the history in the opposite order.
    for k in reversed(range(lim)):
        r = inv_curv[k]
        b = r * grad_diffs[k].dot(q) * beta_m
        if r > 0:  # same skip rule as in the first loop
            q = q + steps[k] * (coeffs[k] - b)
    return q


class LearnedLBFGS_v1(torch.nn.Module):
    """L-BFGS optimizer module with two learnable scalar multipliers.

        ``alp_m`` and ``beta_m`` are ``nn.Parameter`` scalars initialised to 1.0
        and passed to ``LBFGS`` as multipliers of its alpha and beta
        coefficients. The resulting direction can optionally be rescaled
        (``adpt_step``) and/or shortened by a backtracking search
        (``line_search``).

        The history is kept in ``self.state`` with shape ``(m, n, 2)``:
        ``state[i, :, 0]`` is a stored gradient and ``state[i, :, 1]`` the step
        returned for it; index 0 is the most recent entry. ``self.lim`` counts
        how many entries are valid.

        Typical usage:

        my_optim = LearnedLBFGS_v1()
        my_optim.reset_dim(t.n)
        my_optim.init_state(t.df())
        for ...
            d = my_optim.forward(t.df())
            t.w = t.w + d.view(-1)
    """
    def __init__(self, m_in=5, n_in=1, nh = 5, adpt_step=False, line_search=False, tau=0.5, c=0.5, ulim=1.0):
        """Create the optimizer.

        Args:
            m_in: History length ``m`` (number of stored gradient/step pairs).
            n_in: Problem dimension ``n``.
            nh: Stored as ``self.nh``; not read anywhere in this class.
            adpt_step: Enable the rescaling of the direction in ``forward``.
            line_search: Enable the backtracking search in ``forward``.
            tau: Factor by which the trial length is multiplied on each
                backtracking iteration.
            c: Constant multiplying the directional derivative in the
                backtracking test.
            ulim: Factor used to form the initial trial length
                ``a = ||q|| * ulim`` of the backtracking search.
        """
        super().__init__()

        # Learnable multipliers handed to LBFGS (1.0 leaves the recursion unscaled).
        self.alp_m = torch.nn.Parameter(data=torch.tensor(1.0), requires_grad=True)
        self.beta_m = torch.nn.Parameter(data=torch.tensor(1.0), requires_grad=True)

        self.adpt_step = adpt_step
        self.line_search = line_search
        self.last_step_size = None
        self.ulim = ulim
        self.ns = 2  # vectors stored per history entry: gradient and step
        self.nh = nh
        self.n = n_in
        self.m = m_in
        self.last_step = w(torch.zeros(self.n))
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        self.lim=0
        self.H0 = 1.0
        self.tau = w(torch.tensor(tau))
        self.c = w(torch.tensor(c))


    def reset_dim(self, new_n):
        """Change the problem dimension to ``new_n`` and clear the history."""
        self.n = new_n
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        self.last_step = w(torch.zeros(self.n))
        self.lim = 0

    def reset_horizon(self, new_m):
        """Change the history length to ``new_m`` and clear the history."""
        self.m = new_m
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        self.lim = 0

    def init_state(self, g):
        """Fill every history slot with gradient ``g`` paired with a zero step.

        ``lim`` is reset to 0, so none of these slots is used by ``LBFGS``
        until ``update_state`` has pushed real entries.
        """
        aux = w(torch.cat((g.view(1, -1, 1), w(torch.zeros(1, self.n, 1))), 2))
        self.state = w(aux.repeat(self.m, 1, 1))
        self.lim = 0

    def update_state(self, g):
        """Push ``(g, self.last_step)`` to the front of the history.

        The oldest entry is dropped and ``lim`` grows by one, capped at ``m``.
        """
        aux = w(torch.cat((g.view(1, -1, 1), self.last_step.view(1, -1, 1)), 2))
        self.state = w(torch.cat([aux, self.state[:(self.m - 1), :, :]], 0))
        self.lim = min(self.lim+1, self.m)

    def forward(self, g, dfx=None, fx=None):
        """Compute the next step from gradient ``g`` and record it in the history.

        Args:
            g: Current gradient as a 1-D tensor.
            dfx: Callable, required only when ``adpt_step`` is set. It is
                called as ``dfx(eps * q)`` and its result is differenced
                against ``g`` -- presumably the gradient at the displaced
                point; verify against the caller.
            fx: Callable, required only when ``line_search`` is set. It is
                called as ``fx(None)`` for the reference value and as
                ``fx(-a * p)`` for each trial step.

        Returns:
            The step ``-q``, also stored in ``self.last_step``.
        """

        q = LBFGS(g.clone(), self.state, self.alp_m, self.beta_m, self.H0, self.lim)

        if self.adpt_step:
            r = g.dot(q)
            eps = 1e-3
            # Finite-difference estimate built from dfx at eps*q and the gradient g.
            Hd = 1/eps * (dfx(eps*q) - g)
            delta = q.dot(Hd).abs().sqrt()
            q = q * r/((r+delta)*delta)

        if self.line_search:
            v = w(fx(None))
            tau = self.tau
            c = self.c
            a = w(q.norm()*torch.tensor(self.ulim))
            # NOTE(review): p is q divided by a (not by ||q||), so a*p equals q
            # before any shrinking regardless of ulim -- confirm this is intended.
            p = q/a
            m = g.dot(p)
            t = c*m
            # Shrink a while the decrease v - fx(-a*p) is below a*c*(g.p).
            while(v-fx(-a*p) < a*t):
                a = a*tau
                if a < 0.01:
                    break
            self.last_step_size = ( a/q.norm()).detach()
            q = a*p
        else:
            self.last_step_size = 1

        #print(self.last_step_size)
        self.last_step = -q
        self.update_state(g)
        return self.last_step

    def flush(self):
        """Clear the stored history and the last step.

        ``lim`` is reset together with the buffers: the zeroed history must
        not be fed to ``LBFGS``, where an all-zero step yields a zero
        curvature term and a division by zero.
        """
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        # Same size as in reset_dim so update_state can stack it with a gradient.
        self.last_step = w(torch.zeros(self.n))
        self.lim = 0


class LearnedLBFGS_pi(torch.nn.Module):
    """ L-BFGS with cvxlayer to calculate step-size.

        The direction comes from ``LBFGS``. The logarithm of the step size is
        obtained from three scalar features passed through a linear layer
        whose output is split into two vectors ``A`` and ``u`` of length
        ``nh``; the log step size ``z`` minimises ``||A z - u||_2`` subject to
        ``log(llim) <= z <= log(ulim)``. In training mode this is solved by a
        ``CvxpyLayer``; in test mode by the closed form ``A.u / A.A`` followed
        by clipping to the same bounds.

        Typical usage:

        my_optim = LearnedLBFGS_pi()
        my_optim.reset_dim(t.n)
        my_optim.init_state(t.df())
        for ...
            d = my_optim.forward(t.df())
            t.w = t.w + d.view(-1)
    """
    def __init__(self, m_in=5, n_in=1, test=False, ulim=1.0,llim=0.001):
        """Create the optimizer.

        Args:
            m_in: History length ``m`` (number of stored gradient/step pairs).
            n_in: Problem dimension ``n``.
            test: If true, ``forward`` uses the closed-form step size instead
                of the cvxpy layer.
            ulim: Upper limit of the step size (stored as its logarithm).
            llim: Lower limit of the step size (stored as its logarithm).
        """
        super().__init__()
        nh = 6              # higher space dimension
        self.ns = 2         # n of vectors saved in state
        self.np = 1         # cvx problem dimension
        self.alp_m = 1      # scale factor passed to LBFGS (1 = no scaling)
        self.beta_m = 1     # scale factor passed to LBFGS (1 = no scaling)
        self.ulim = w(torch.tensor(ulim).log())  # upper-limit to tau_k
        self.llim = w(torch.tensor(llim).log())  # lower-limit to tau_k
        self.nf = 3         # number of input features
        self.nh = nh
        self.test = test    # test or train?
        self.n = n_in
        self.m = m_in


        self.inputlayer = nn.Linear(self.nf,self.nh*2) # Layers 1 and 2 concatenated

        self.cvxpylayer = self.buildCVXLayer(nh, self.np) # cvx layer

        self.last_step_size = None # last t_k

        self.last_step = w(torch.zeros(self.n)) # last t_k*d_k
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        self.lim=0
        self.H0 = 1.0


    def buildCVXLayer(self, nh, np):
        """Build the differentiable layer that solves for the log step size.

        Args:
            nh: Length of the two parameter vectors ``A`` and ``u``.
            np: Dimension of the optimisation variable ``z``.

        Returns:
            A ``CvxpyLayer`` taking ``(A, u)`` and returning ``z`` that
            minimises ``||A * z - u||_2`` with ``z`` bounded by ``self.llim``
            and ``self.ulim``.
        """
        z = cp.Variable(np)
        A = cp.Parameter(nh)
        u = cp.Parameter(nh)
        constraints = [z >= self.llim.cpu(), z <= self.ulim.cpu()]
        objective = cp.Minimize(cp.norm(A * z - u, p=2))
        problem = cp.Problem(objective, constraints)
        assert problem.is_dpp()
        return CvxpyLayer(problem, parameters=[A, u], variables=[z])


    def reset_dim(self, new_n):
        """Change the problem dimension to ``new_n`` and clear the history."""
        self.n = new_n
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        self.last_step = w(torch.zeros(self.n))
        self.lim = 0

    def reset_horizon(self, new_m):
        """Change the history length to ``new_m`` and clear the history."""
        self.m = new_m
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        self.lim = 0

    def init_state(self, g):
        """Fill every history slot with gradient ``g`` paired with a zero step.

        ``lim`` is reset to 0, so none of these slots is used by ``LBFGS``
        until ``update_state`` has pushed real entries.
        """
        aux = torch.cat((g.view(1, -1, 1), w(torch.zeros(1, self.n, 1))), 2)
        self.state = aux.repeat(self.m, 1, 1)
        self.lim = 0

    def update_state(self, g):
        """Push ``(g, self.last_step)`` to the front of the history.

        The oldest entry is dropped and ``lim`` grows by one, capped at ``m``.
        """
        aux = torch.cat((g.view(1, -1, 1), self.last_step.view(1, -1, 1)), 2)
        self.state = torch.cat([aux, self.state[:(self.m - 1), :, :]], 0)
        self.lim = min(self.lim+1, self.m)

    def forward(self, g):
        """Compute the next step from gradient ``g`` and record it in the history.

        Args:
            g: Current gradient as a 1-D tensor.

        Returns:
            ``-exp(z) * q`` where ``q`` is the ``LBFGS`` direction and ``z``
            the log step size. The returned tensor keeps ``z`` attached to
            the autograd graph; the copy stored in ``self.last_step`` uses a
            detached step size.
        """
        q = LBFGS(g.clone(), self.state, self.alp_m, self.beta_m, self.H0, self.lim)

        aux = q.dot(g)
        # Features: log(q.q), log(g.g), log(q.g).
        # NOTE(review): q.g is not guarded against non-positive values before
        # the log -- confirm this cannot occur for the directions produced.
        feat_step = (
            torch.cat([q.dot(q).log().view(-1), g.dot(g).log().view(-1), (aux).log().view(-1)]))
        if self.test:
            x, a = self.inputlayer(feat_step).split(self.nh,0)

            # Closed-form minimiser of ||x*t - a||, then clipped to the bounds.
            t = x.dot(a)/ x.dot(x)

            step = self.ulim if t>self.ulim else self.llim if t<self.llim else t
        else:
            step, = self.cvxpylayer(*self.inputlayer(feat_step).split(self.nh, 0))
        #print(step)
        last_step_size = step.exp().detach()
        self.last_step_size = last_step_size.detach()
        self.last_step = -last_step_size*q
        self.update_state(g)
        #print(self.last_step_size)
        return -step.exp()*q

    def flush(self):
        """Clear the stored history and the last step.

        ``lim`` is reset together with the buffers: the zeroed history must
        not be fed to ``LBFGS``, where an all-zero step yields a zero
        curvature term and a division by zero.
        """
        self.state = w(torch.zeros(self.m, self.n, self.ns))
        # Same size as in reset_dim so update_state can stack it with a gradient.
        self.last_step = w(torch.zeros(self.n))
        self.lim = 0

    def detach(self):
        """Detach the history and the last step from the autograd graph in place."""
        self.state.detach_()
        self.last_step.detach_()



