import math

import numpy as np

import torch
import torch.nn.functional as f

from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.optim.mixin import OptimMixin

from manifolds import StiefelT


# Names exported by ``from <module> import *``: only the optimizer class.
__all__ = ["LandingStiefelAdam"]


def _check_orthogonal(param):
    if not hasattr(param, "manifold"):
        raise TypeError("Parameter should be a geoopt parameter")
    if not isinstance(
        param.manifold, StiefelT
    ):
        raise TypeError("Parameters should be on the Stiefel manifold")
    *_, p, q = param.shape
    if p > q:
        raise ValueError(
            "The second to last dimension of the parameters should be greater or equal than its last dimension. "
            "Only tall matrices are supported so far"
        )


def _safe_step_size(d, g, lambda_regul, eps_d):
    """Compute the safe step size

    Parameters
    ----------
    d : float
        The distance to the manifold
    g : float
        The norm of the landing update
    lambda_regul : float
        The regularisation parameter
    eps_d : float
        The tolerance: the maximal allowed distance to the manifold
    Return
    ------
    sol : float
        The maximal step-size one can take
    """
    beta = lambda_regul * d * (d-1)
    alpha = g**2
    sol = (- beta + torch.sqrt(beta**2 + alpha * torch.relu(eps_d - d)))/(alpha + 1e-8)
    return sol


def _landing_direction(point, grad, lambda_regul):
    r'''
    Computes the relative gradient, the normal direction, and the distance 
    towards the manifold.
    '''

    *_, p = point.shape

    if torch.is_complex(point):
        XtX = torch.matmul(torch.conj(point).transpose(-1, -2), point)
        GtX = torch.matmul(torch.conj(grad).transpose(-1, -2), point)
    else:
        XtX = torch.matmul(point.transpose(-1, -2), point)
        GtX = torch.matmul(grad.transpose(-1, -2), point)
    distance = XtX - torch.eye(p, device=point.device)

    # Note, that if we didn't need to know the rel_grad and distance norm, 
    # this could be further sped up by doing point@(GtX-lam*distance)
    rel_grad = .5*(torch.matmul(grad, XtX) - torch.matmul(point, GtX))  
    norm_dir = lambda_regul * torch.matmul(point, distance)
    # Distance norm for _safe_step_size computation
    distance_norm = torch.norm(distance, dim=(-1, -2), keepdim=True)

    return (rel_grad, norm_dir, distance_norm)


class LandingStiefelAdam(OptimMixin, torch.optim.Optimizer):
    r"""
    Landing algorithm on the Stiefel manifold with a rectified ("RAdam"
    style), "vector" Adam preconditioning of the Euclidean gradient.

    Instead of retracting after every step, each update follows the sum of a
    relative-gradient term and a normal term, weighted by ``lambda_regul``,
    that pulls the iterate back towards the manifold.

    A parameter of shape ``(..., p, q)`` is handled internally through its
    transpose of shape ``(..., q, p)``; the ``exp_avg`` state is stored in
    that transposed layout. The second moment ``exp_avg_sq`` holds one value
    per matrix (running average of the squared Frobenius norm of the
    gradient).

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups.
    lr : float
        learning rate; when ``safe_step`` is enabled it must not exceed
        ``1 / (2 * lambda_regul)`` (checked with an ``assert``).
    betas : tuple of float (optional)
        coefficients of the running averages of the gradient and of its
        squared Frobenius norm (default: ``(0.9, 0.999)``)
    eps : float (optional)
        term added to the denominator for numerical stability
        (default: ``1e-08``)
    weight_decay : float (optional)
        weight decay (L2 penalty) (default: 0)
    lambda_regul : float (optional)
        the hyperparameter lambda that controls the tradeoff between
        optimization in f and landing speed. :meth:`step` raises a
        ``ValueError`` unless it is strictly positive (default: 0)
    normalize_columns : bool (optional)
        after every step, rescale each column of the internal (transposed)
        iterate to unit L2 norm (default: False)
    safe_step : float (optional)
        maximal allowed distance to the manifold, used to cap the step size;
        a falsy value disables the cap (default: 0.5)
    check_type : bool (optional)
        whether to check, at every step, that each parameter is a geoopt
        parameter on the ``StiefelT`` manifold with a supported shape
        (default: False)
    use_vr : bool (optional)
        stored in the parameter groups but not used by this optimizer
        (default: False)
    dproject : bool (optional)
        when ``safe_step`` is enabled, project the matrices farther than
        ``safe_step`` from the manifold back onto it (and their ``exp_avg``
        onto the tangent space); those matrices are not moved further during
        that step (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they are off-manifold due to numerical
        reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
    """

    def __init__(
        self,
        params,
        lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        stabilize=None,
        lambda_regul=0,
        normalize_columns=False,
        safe_step=0.5,
        check_type=False,
        use_vr=False,
        dproject=False,  # Whether to project back to the manifold based on safe_step
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        if lambda_regul < 0.0:
            raise ValueError(
                "Invalid lambda_regul value: {}".format(lambda_regul)
            )
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            lambda_regul=lambda_regul,
            normalize_columns=normalize_columns,
            safe_step=safe_step,
            check_type=check_type,
            use_vr=use_vr,
            dproject=dproject,
        )

        super().__init__(params, defaults, stabilize=stabilize)

    def step(self, closure=None):
        """Perform a single optimization step.

        Parameters
        ----------
        closure : callable (optional)
            re-evaluates the model and returns the loss; it is called with
            gradients enabled before the parameters are updated.

        Returns
        -------
        The value returned by ``closure``, or ``None`` without a closure.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        ValueError
            If a group with a gradient has ``lambda_regul <= 0``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                if "step" not in group:
                    group["step"] = 0
                group["step"] += 1
                step = group["step"]
                weight_decay = group["weight_decay"]
                betas = group["betas"]
                eps = group["eps"]
                learning_rate = group["lr"]
                lambda_regul = group["lambda_regul"]
                normalize_columns = group["normalize_columns"]
                safe_step = group["safe_step"]
                check_type = group["check_type"]

                for point_ in group["params"]:
                    if check_type:
                        _check_orthogonal(point_)
                    if point_.grad is None:
                        continue
                    if point_.grad.is_sparse:
                        raise RuntimeError(
                            "LandingStiefelAdam does not support sparse gradients."
                        )
                    # Transposed view: in-place writes to ``point`` are
                    # writes to ``point_``.
                    point = point_.transpose(-1, -2)
                    # Out of place, so that ``point_.grad`` is left untouched.
                    grad = point_.grad.transpose(-1, -2).add(point, alpha=weight_decay)

                    state = self.state[point_]
                    # Per-parameter initialization: also covers parameters
                    # whose first gradient arrives after the first step.
                    if not state:
                        state["exp_avg"] = torch.zeros_like(grad)
                        state["exp_avg_sq"] = torch.zeros(
                            grad.shape[:-2], device=grad.device, dtype=grad.dtype
                        )

                    if not lambda_regul > 0:
                        # Steps without the landing (normal) term are not implemented.
                        raise ValueError(
                            "LandingStiefelAdam requires lambda_regul > 0, got {}".format(
                                lambda_regul
                            )
                        )

                    direction = self._adam_direction(state, grad, betas, eps, step)
                    rel_grad, normal_dir, distance_norm = _landing_direction(
                        point, direction, lambda_regul
                    )
                    landing = rel_grad + normal_dir

                    lr = learning_rate
                    # Batch entries projected back onto the manifold this step.
                    projected = None
                    if safe_step:
                        if group["dproject"]:
                            projected = self._project_far_points(
                                point_, point, state["exp_avg"], distance_norm, safe_step
                            )
                        g = torch.linalg.norm(landing, dim=(-1, -2), keepdim=True)
                        max_step = _safe_step_size(distance_norm, g, lambda_regul, safe_step)
                        assert lr <= 1 / (2. * lambda_regul)
                        lr = torch.clip(max_step, max=lr)

                    # Take the step with orthogonalization
                    new_point = point - lr * landing
                    if projected is not None:
                        # Freshly projected matrices keep their projected value.
                        new_point = torch.where(projected[..., None, None], point, new_point)

                    if normalize_columns:
                        f.normalize(new_point, p=2, dim=-2, out=new_point)
                    point_.copy_(new_point.transpose(-1, -2))

                # Stabilize once per group, after every parameter was updated,
                # so that no update overwrites the projected values.
                if group["stabilize"] is not None and step % group["stabilize"] == 0:
                    self.stabilize_group(group)
        return loss

    @staticmethod
    def _adam_direction(state, grad, betas, eps, step):
        """Update the moments in ``state`` and return the RAdam-preconditioned gradient.

        ``exp_avg`` tracks the gradient; ``exp_avg_sq`` tracks, per matrix,
        the squared Frobenius norm of the gradient ("vector" Adam). While the
        variance estimate is not tractable (``rho_t <= 5``), the bias-corrected
        momentum is returned; afterwards it is divided by
        ``sqrt(v_hat) + eps`` and multiplied by the RAdam rectification term.
        """
        beta1, beta2 = betas
        exp_avg = state["exp_avg"]
        exp_avg_sq = state["exp_avg_sq"]

        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).add_(
            torch.norm(grad, p=2, dim=(-1, -2)).pow(2), alpha=1 - beta2
        )

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        # Maximum and current length of the approximated SMA (RAdam).
        rho_inf = 2 / (1 - beta2) - 1
        rho_t = rho_inf - 2 * step * beta2 ** step / bias_correction2

        exp_avg_hat = exp_avg / bias_correction1
        if rho_t <= 5:
            return exp_avg_hat

        # The rectification term multiplies the adaptive step.
        rect = math.sqrt(
            (rho_t - 4) * (rho_t - 2) * rho_inf
            / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
        )
        denom = exp_avg_sq.sqrt() / math.sqrt(bias_correction2) + eps
        # One scalar per matrix: broadcast it over the last two dimensions.
        return exp_avg_hat * rect / denom[..., None, None]

    @staticmethod
    def _project_far_points(point_, point, exp_avg, distance_norm, safe_step):
        """Project the matrices farther than ``safe_step`` from the manifold.

        ``point`` and ``exp_avg`` are in the transposed layout and are updated
        in place. Returns the boolean mask (over the batch dimensions) of the
        projected matrices, or ``None`` when nothing was projected, including
        when ``point_`` is not a geoopt manifold tensor.
        """
        far = distance_norm[..., 0, 0] > safe_step
        if not far.any():
            return None
        if not isinstance(point_, (ManifoldParameter, ManifoldTensor)):
            # No manifold to project with: fall back to the regular landing step.
            return None
        manifold = point_.manifold
        point[far] = manifold.projx(point_[far]).transpose(-1, -2)
        exp_avg[far] = manifold.proju(
            point[far].transpose(-1, -2), exp_avg[far].transpose(-1, -2)
        ).transpose(-1, -2)
        return far

    @torch.no_grad()
    def stabilize_group(self, group):
        """Project each manifold parameter of ``group`` onto its manifold.

        The ``exp_avg`` state, when present, is projected onto the tangent
        space at the projected point (it is stored transposed).
        """
        for p in group["params"]:
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            manifold = p.manifold
            p.copy_(manifold.projx(p))
            state = self.state[p]
            if "exp_avg" in state:
                exp_avg = state["exp_avg"]
                exp_avg.copy_(
                    manifold.proju(p, exp_avg.transpose(-1, -2)).transpose(-1, -2)
                )





if __name__ == '__main__':
    # Small PCA-style demo: maximise the variance of ``data`` captured by a
    # (p, m) parameter whose rows the landing updates keep orthonormal.
    m, p = 10, 5
    n = 200
    batch_size = 7
    data = torch.randn(n, m)
    # Stored parameters must be wide (the optimizer works on the transpose),
    # and the safe step size is zero once the distance to the manifold
    # exceeds both ``safe_step`` and 1, so start exactly on the manifold.
    init = torch.linalg.qr(torch.randn(m, p)).Q.transpose(-1, -2)
    x = torch.nn.Parameter(init)
    n_epochs = 3
    optim = LandingStiefelAdam([x], lr=0.1, lambda_regul=1.)

    # True division so that the last, partial batch is included.
    n_batch = math.ceil(n / batch_size)
    for _ in range(n_epochs):
        for idx in np.random.permutation(n_batch):
            batch = data[idx * batch_size:(idx + 1) * batch_size]
            this_size, _ = batch.shape
            optim.zero_grad()
            loss = -torch.linalg.norm(batch.mm(x.transpose(-1, -2))) ** 2 / this_size
            loss.backward()
            optim.step()
    
