import torch
import numpy as np 
import pickle
from termcolor import colored

import logging

# Module-level logger. getLogger() is called without a name, so this is the
# root logger; records go wherever the application configured root handlers.
logger = logging.getLogger()


def _safe_norm(v):
    """Norm of ``v`` that maps any non-finite input to infinity.

    Returns ``np.inf`` (a Python float) as soon as ``v`` contains a NaN or an
    infinite entry; otherwise returns ``torch.norm(v)`` (a tensor).
    """
    all_finite = torch.isfinite(v).all()
    return torch.norm(v) if all_finite else np.inf


def scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    """Backtracking search for a step size satisfying the Armijo condition.

    The scheme follows SciPy's routine of the same name: try ``alpha0``, then
    the minimizer of a quadratic interpolant, then repeated cubic
    interpolation. A step ``alpha`` is accepted when
    ``phi(alpha) <= phi0 + c1 * alpha * derphi0``.

    Args:
        phi: Callable mapping a step size to the objective value at that step.
        phi0: Objective value at step size 0.
        derphi0: Slope of the objective at step size 0 (negative for a
            descent direction).
        c1: Sufficient-decrease constant of the Armijo condition.
        alpha0: First step size tried.
        amin: The search gives up once the current candidate step is no
            longer strictly above this value.

    Returns:
        Tuple ``(alpha, phi_alpha, ite)``. ``alpha`` is the accepted step
        size, or ``None`` when no acceptable step was found. ``phi_alpha`` is
        the objective value returned with it (the latest interpolation value
        on failure). ``ite`` is the number of cubic-interpolation iterations
        that were performed.

    Note:
        The cubic step calls ``torch.sqrt``/``torch.abs``, so once the loop is
        entered the interpolation quantities must be torch tensors (e.g. a
        tensor-valued ``derphi0``).
    """
    def sufficient_decrease(alpha, phi_alpha):
        # First Wolfe (Armijo) condition.
        return phi_alpha <= phi0 + c1 * alpha * derphi0

    ite = 0
    phi_a0 = phi(alpha0)    # First do an update with step size alpha0
    if sufficient_decrease(alpha0, phi_a0):
        return alpha0, phi_a0, ite

    # Otherwise, compute the minimizer of a quadratic interpolant
    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)

    # Accept the quadratic step if it already gives sufficient decrease.
    # Without this check an exactly quadratic phi makes the cubic coefficient
    # `a` below vanish, so alpha2 becomes 0/0 and the search fails even though
    # alpha1 is a valid step. The `alpha1 > amin` guard keeps the degenerate
    # cases (zero / NaN / too-small alpha1) on the "not found" path.
    if alpha1 > amin and sufficient_decrease(alpha1, phi_a1):
        return alpha1, phi_a1, ite

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition.
    while alpha1 > amin:       # we are assuming alpha>0 is a descent direction
        # Coefficients of the cubic through phi0, derphi0, phi_a0 and phi_a1.
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        # Minimizer of the cubic interpolant.
        alpha2 = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = phi(alpha2)
        ite += 1

        if sufficient_decrease(alpha2, phi_a2):
            return alpha2, phi_a2, ite

        # Safeguard: if the new step shrank too much or too little, bisect.
        # NOTE: phi_a2 is not re-evaluated after this clamp, so the value
        # carried into the next iteration belongs to the unclamped alpha2.
        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

    # Failed to find a suitable step length
    return None, phi_a1, ite


def line_search(update, x0, g0, g, nstep=0, on=True):
    """
    `update` is the proposed direction of update.
    Code adapted from scipy.

    Takes one step from `x0` along `update`. With `on=True` the step size is
    chosen by `scalar_search_armijo` on the merit function ||g(x0 + s*update)||^2;
    if that search fails, or with `on=False`, the full step `s = 1.0` is used.

    Args:
        update: Proposed update direction (same shape as `x0`).
        x0: Current estimate.
        g0: Value of `g` at `x0`.
        g: Callable evaluating the residual at a point.
        nstep: Unused; kept for interface compatibility.
        on: Whether to run the Armijo search.

    Returns:
        Tuple `(x_est, g0_new, x_est - x0, g0_new - g0, ite)`: the new
        estimate, `g` at the new estimate, the change in `x`, the change in
        `g`, and the iteration count reported by the Armijo search (0 when the
        search is off or failed).
    """
    # One-slot cache of the most recently stored evaluation, so that the
    # accepted step does not need a second call to `g`.
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [torch.norm(g0)**2]

    def phi(s, store=True):
        # Merit function at step size `s`; reuses the cached value on a hit.
        if s == tmp_s[0]:
            return tmp_phi[0]    # If the step size is so small... just return something
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new

    if on:
        # phi(0) = ||g0||^2 and the slope passed to the search is -||g0||^2.
        s, _, ite = scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        # No search, or the search failed: fall back to the full step.
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite

def rmatvec(part_Us, part_VTs, x):
    """Batched left product x^T (-I + U V^T).

    Shapes:
        x:        (N, d)
        part_Us:  (N, d, threshold)
        part_VTs: (N, threshold, d)

    Returns a tensor of shape (N, d). With an empty `part_Us` the low-rank
    term is absent and the result is simply -x.
    """
    neg_x = -x
    if part_Us.nelement() != 0:
        # x^T U, shape (N, threshold)
        proj = torch.einsum('nd, ndk -> nk', x, part_Us)
        # (x^T U) V^T, shape (N, d)
        return neg_x + torch.einsum('nk, nkd -> nd', proj, part_VTs)
    return neg_x


def matvec(part_Us, part_VTs, x):
    """Batched right product (-I + U V^T) x.

    Shapes:
        x:        (N, d)
        part_Us:  (N, d, threshold)
        part_VTs: (N, threshold, d)

    Returns a tensor of shape (N, d). With an empty `part_Us` the low-rank
    term is absent and the result is simply -x.
    """
    neg_x = -x
    if part_Us.nelement() != 0:
        # V^T x, shape (N, threshold)
        coeffs = torch.einsum('nkd, nd -> nk', part_VTs, x)
        # U (V^T x), shape (N, d)
        return neg_x + torch.einsum('ndk, nk -> nd', part_Us, coeffs)
    return neg_x


def broyden(g_, x0, threshold, eps, ls=False, name="unknown"):
    """Search for a root of `g_` with a limited-memory Broyden iteration.

    The inverse Jacobian is approximated as -I + U V^T, with the rank-one
    factors stored in `Us` / `VTs` and applied through `matvec` / `rmatvec`.
    The iterate with the smallest residual norm seen so far is what gets
    returned, not necessarily the last one.

    Args:
        g_: Callable taking a tensor shaped like `x0` and returning a tensor
            with the same number of elements per batch entry.
        x0: Initial estimate; the first dimension is treated as the batch.
        threshold: Maximum number of Broyden steps (also the number of stored
            rank-one updates).
        eps: Per-element tolerance; internally scaled by sqrt(x0.numel()) and
            compared against the norm of the whole residual.
        ls: Whether to use the Armijo line search for each step.
        name: Unused; kept for interface compatibility.

    Returns:
        dict with keys:
            "result":      best estimate, reshaped to `x0`'s original shape
            "nstep":       number of Broyden steps taken
            "tnstep":      sum of (line-search iterations + 1) over all steps
            "lowest_step": step index at which the best estimate was found
            "diff":        residual norm at the best estimate
            "diff_detail": per-batch-entry residual norm at the best estimate
            "prot_break":  True if the protective break was triggered
            "trace":       residual norm after each step (initial one first)
            "eps":         the scaled tolerance actually used
            "threshold":   the `threshold` argument
    """
    # LBFGS_thres = min(threshold, 20)
    LBFGS_thres = threshold

    x0_shape = x0.shape
    x0 = x0.view(x0_shape[0], -1)

    bsz, total_hsize = x0.size()
    eps = eps * np.sqrt(np.prod(x0.shape))

    def g(x):
        # Evaluate g_ in the caller's original shape, return it flattened.
        return g_(x.view(x0_shape)).view(bsz, -1)

    x_est = x0           # (bsz, d)
    gx = g(x_est)        # (bsz, d)
    nstep = 0
    tnstep = 0

    # For fast calculation of inv_jacobian (approximately)
    Us = torch.zeros(bsz, total_hsize, LBFGS_thres).to(x0)
    VTs = torch.zeros(bsz, LBFGS_thres, total_hsize).to(x0)
    update = -gx
    new_objective = init_objective = torch.norm(gx).item()
    prot_break = False
    trace = [init_objective]

    # To be used in protective breaks
    protect_thres = 1e6
    lowest = new_objective
    lowest_xest, lowest_gx, lowest_step = x_est, gx, nstep
    while new_objective >= eps and nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        tnstep += (ite+1)
        new_objective = torch.norm(gx).item()
        trace.append(new_objective)
        if new_objective < lowest:
            lowest_xest, lowest_gx = x_est.clone().detach(), gx.clone().detach()
            lowest = new_objective
            lowest_step = nstep
        if new_objective < eps:
            break
        if new_objective < 3*eps and nstep == threshold and np.max(trace[-threshold:]) / np.min(trace[-threshold:]) < 1.3:
            # Step budget used up with hardly any progress over the last
            # `threshold` steps.
            logger.info('Iterations exceeded %d for broyden', threshold)
            break
        if new_objective > init_objective * protect_thres:
            # Residual blew up relative to the starting point: give up.
            logger.info('Broyden failed')
            prot_break = True
            break

        # Rank-one update of the inverse-Jacobian factors. The modulo never
        # wraps here because LBFGS_thres == threshold and nstep <= threshold.
        part_Us, part_VTs = Us[:,:,:(nstep-1) % LBFGS_thres], VTs[:,:(nstep-1) % LBFGS_thres]
        vT = rmatvec(part_Us, part_VTs, delta_x) # (N, d)
        u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bi, bi -> b', vT, delta_gx)[:,None]
        # Zero out NaNs (x != x is True only for NaN), e.g. from a 0/0 above.
        vT[vT != vT] = 0
        u[u != u] = 0
        VTs[:,(nstep-1) % LBFGS_thres] = vT
        Us[:,:,(nstep-1) % LBFGS_thres] = u
        update = -matvec(Us[:,:,:nstep], VTs[:,:nstep], gx)

    # Drop the references to the low-rank buffers.
    Us, VTs = None, None
    return {"result": lowest_xest.view(x0_shape),
            "nstep": nstep,
            "tnstep": tnstep,
            "lowest_step": lowest_step,
            "diff": torch.norm(lowest_gx).item(),
            "diff_detail": torch.norm(lowest_gx, dim=1),
            "prot_break": prot_break,
            "trace": trace,
            "eps": eps,
            "threshold": threshold}


def analyze_broyden(res_info, err=None, judge=True, name='forward', training=True, save_err=True):
    """
    For debugging use only :-)

    Inspects the dict returned by `broyden`.

    In judgment mode (`judge=True`) it only returns a truthy value when the
    run looks problematic: the step budget was used up, the solver made no
    step while the residual is NaN or above `eps`, the protective break
    fired, or the result contains NaN.

    Otherwise it classifies the failure, prints a colored warning and, when
    `save_err` is set, pickles `err` to a file in the current directory.

    Args:
        res_info: Result dict produced by `broyden`.
        err: Object to pickle for later inspection; required when
            `judge=False`.
        judge: Select judgment mode (see above).
        name: 'forward' selects red messages and no filename prefix; anything
            else selects blue messages and the 'back_' prefix.
        training: When False, 'eval_' is inserted into the dump filename.
        save_err: Whether to write the pickle dump.

    Returns:
        In judgment mode, a truthy/falsy value (bool or 0-dim tensor).
        Otherwise a tuple `(code, msg, res_info)` where `code` is 1 (NaN in
        result), 2 (bad run with zero steps), 3 (protective break) or -1 (no
        problem found, `msg` is then '').
    """
    res_est = res_info['result']
    nstep = res_info['nstep']
    diff = res_info['diff']
    prot_break = res_info['prot_break']
    trace = res_info['trace']
    eps = res_info['eps']
    threshold = res_info['threshold']
    if judge:
        return nstep >= threshold or (nstep == 0 and (diff != diff or diff > eps)) or prot_break or torch.isnan(res_est).any()

    assert (err is not None), "Must provide err information when not in judgment mode"
    prefix, color = ('', 'red') if name == 'forward' else ('back_', 'blue')
    eval_prefix = '' if training else 'eval_'

    def _dump_err(tag):
        # Persist `err` for offline debugging; the context manager makes sure
        # the file handle is closed even if pickling raises.
        if save_err:
            with open(f'{prefix}{eval_prefix}{tag}.pkl', 'wb') as fh:
                pickle.dump(err, fh)

    # Case 1: A nan entry is produced in Broyden
    if torch.isnan(res_est).any():
        msg = colored(f"WARNING: nan found in Broyden's {name} result. Diff: {diff}", color)
        print(msg)
        _dump_err('nan')
        return (1, msg, res_info)

    # Case 2: Unknown problem with Broyden's method (probably due to nan update(s) to the weights)
    if nstep == 0 and (diff != diff or diff > eps):
        msg = colored(f"WARNING: Bad Broyden's method {name}. Why?? Diff: {diff}. STOP.", color)
        print(msg)
        _dump_err('badbroyden')
        return (2, msg, res_info)

    # Case 3: Protective break during Broyden (so that it does not diverge to infinity)
    if prot_break:
        msg = colored(f"WARNING: Hit Protective Break in {name}. Diff: {diff}. Total Iter: {len(trace)}", color)
        print(msg)
        _dump_err('prot_break')
        return (3, msg, res_info)

    return (-1, '', res_info)