import warnings
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple, Union

from hyptorch import pmath


def hopfield_core_forward(query,  # type: Tensor
                          key,  # type: Tensor
                          embed_dim_to_check,  # type: int
                          num_heads,  # type: int
                          in_proj_weight,  # type: Optional[Tensor]
                          in_proj_bias,  # type: Optional[Tensor]
                          bias_k,  # type: Optional[Tensor]
                          add_zero_attn,  # type: bool
                          dropout_p,  # type: float
                          out_proj_weight,  # type: Tensor
                          out_proj_bias,  # type: Tensor
                          training=True,  # type: bool
                          key_padding_mask=None,  # type: Optional[Tensor]
                          need_weights=True,  # type: bool
                          attn_mask=None,  # type: Optional[Tensor]
                          use_separate_proj_weight=False,  # type: bool
                          q_proj_weight=None,  # type: Optional[Tensor]
                          k_proj_weight=None,  # type: Optional[Tensor]
                          static_k=None,  # type: Optional[Tensor]

                          key_as_static=False,  # type: bool
                          query_as_static=False,  # type: bool
                          head_dim=None,  # type: Optional[int]
                          pattern_dim=None,  # type: Optional[int]
                          update_steps_max=0,  # type: Optional[Union[int, Tensor]]
                          update_steps_eps=1e-4,  # type: Union[float, Tensor]
                          return_raw_associations=False,  # type: bool
                          c=1.0,  # curvature (>0)
                          theta=1.0,  # temperature
                          lr=0.001  # η: step-size in CCCP update
                          ):
    # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
    r"""Hyperbolic (Poincaré-ball) Hopfield association with iterative query updates.

    Step 0 projects ``query``/``key`` into per-head states with ``mobius_linear``
    (or repeats them across heads when ``query_as_static``/``key_as_static``),
    validates and extends the masks, and reshapes the states to
    ``(bsz * num_heads, len, head_dim)``.

    Every iteration then:

    1. computes ``val = -cosh(d_c(q, k))`` and its gradient w.r.t. ``q``;
    2. forms associations ``p = softmax(theta * val)`` over the keys, after
       applying ``attn_mask`` (bool: masked out; float: added) and
       ``key_padding_mask`` (True entries masked out);
    3. builds the Riemannian gradient of the concave energy term, transports it
       to the origin and sets ``q = project(expmap_0(lr * v0))``.

    A head stays active while ``update_step < update_steps_max``; a negative
    limit (the ``None`` case) means no step limit. After the first step, a head
    also needs the largest change of its ``q`` over the batch to exceed
    ``update_steps_eps``. The loop runs while any head is active; all heads are
    updated on every pass, and at least one update is always performed.

    Args:
        query: ``(tgt_len, bsz, embed_dim)``; ``embed_dim`` must equal
            ``embed_dim_to_check``.
        key: ``(src_len, bsz, ...)``; its second dimension defines ``bsz``.
        in_proj_weight / in_proj_bias: stacked projection. Rows
            ``[0, num_heads*head_dim)`` project the query and the following
            ``num_heads*head_dim`` rows project the key.
        c: curvature of the Poincaré ball (> 0).
        theta: inverse temperature of the association softmax.
        lr: step size of the update at the origin.

    Returns:
        ``(q_out, attn_weights, raw_associations)``:

        - ``q_out`` is ``(tgt_len, bsz, num_heads*head_dim)`` before the
          optional Möbius output projection.
        - ``attn_weights`` is the dropout-applied association averaged over
          heads, ``(bsz, tgt_len, src_len)``, or ``None`` when ``need_weights``
          is False.
        - ``raw_associations`` is ``(bsz, num_heads, tgt_len, src_len)``, or
          ``None`` unless ``return_raw_associations``.

        Associations come from the last iteration, i.e. they were computed at
        ``q`` before its final update.
    """
    if not torch.jit.is_scripting():
        tens_ops = (query, key, in_proj_weight, in_proj_bias, bias_k,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor for t in tens_ops]) and nn.functional.has_torch_function(tens_ops):
            return nn.functional.handle_torch_function(
                hopfield_core_forward, tens_ops, query, key,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, static_k=static_k,
                key_as_static=key_as_static, query_as_static=query_as_static,
                head_dim=head_dim, pattern_dim=pattern_dim, update_steps_max=update_steps_max,
                update_steps_eps=update_steps_eps, return_raw_associations=return_raw_associations,
                c=c, theta=theta, lr=lr)

    tgt_len, bsz, embed_dim = query.shape[0], key.shape[1], query.shape[2]
    assert embed_dim == embed_dim_to_check

    # Normalise update_steps_max to a per-head int tensor (-1 == unlimited).
    assert (update_steps_max is None) or (type(update_steps_max) in (int, torch.Tensor))
    if type(update_steps_max) == torch.Tensor:
        assert update_steps_max.ndimension() == 1 and update_steps_max.shape[0] == num_heads
    elif type(update_steps_max) == int:
        update_steps_max = torch.tensor([update_steps_max] * num_heads, dtype=torch.int32, device=query.device)
    elif update_steps_max is None:
        update_steps_max = -torch.ones(size=(num_heads,), dtype=torch.int32, device=query.device)

    # Normalise update_steps_eps to a strictly positive per-head tensor.
    assert type(update_steps_eps) in (float, torch.Tensor)
    if type(update_steps_eps) == torch.Tensor:
        assert update_steps_eps.ndimension() == 1 and update_steps_eps.shape[0] == num_heads
        assert (update_steps_eps <= 0.0).sum() == 0
        update_steps_eps = update_steps_eps.to(device=query.device)
    elif type(update_steps_eps) == float:
        assert update_steps_eps > 0
        update_steps_eps = torch.tensor([update_steps_eps] * num_heads, dtype=query.dtype, device=query.device)

    if head_dim is None:
        head_dim = embed_dim // num_heads
        assert head_dim * num_heads == embed_dim
    hopfield_dim = num_heads * head_dim

    if pattern_dim is None:
        pattern_dim = head_dim

    q, k, p_bi, src_len = None, None, None, 0
    update_step, q_old = 0, None
    update_active_heads = torch.tensor([[[True]]] * num_heads * bsz, device=query.device)

    # ============================ iteration ============================
    while update_active_heads.any():
        if update_step == 0:
            # project inputs (Mobius linear) at step 0
            if not use_separate_proj_weight:
                # q and k are always projected with their own weight slices.
                # Möbius matvec normalises by the norm of the *whole* output
                # vector, so applying the stacked weight once and chunking (the
                # Euclidean self-attention shortcut) is not equivalent, and
                # would make the projection jump when query becomes equal to key.
                _start, _end = 0, hopfield_dim
                if query_as_static:
                    q = query.repeat(1, num_heads, 1)
                else:
                    _b = in_proj_bias
                    _w = in_proj_weight[_start:_end, :]
                    if _b is not None:
                        _b = _b[_start:_end]
                    q = mobius_linear(query, _w, _b, c=c)
                    _start += hopfield_dim
                    _end += hopfield_dim

                if key_as_static:
                    k = key.repeat(1, num_heads, 1)
                else:
                    _b = in_proj_bias
                    _w = in_proj_weight[_start:_end, :]
                    if _b is not None:
                        _b = _b[_start:_end]
                    k = mobius_linear(key, _w, _b, c=c)
                    _start += hopfield_dim
                    _end += hopfield_dim
            else:
                _start, _end = 0, hopfield_dim
                if query_as_static:
                    q = query.repeat(1, num_heads, 1)
                else:
                    q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
                    if in_proj_bias is not None:
                        q = mobius_linear(query, q_proj_weight_non_opt, in_proj_bias[_start:_end], c=c)
                        _start += hopfield_dim
                        _end += hopfield_dim
                    else:
                        q = mobius_linear(query, q_proj_weight_non_opt, in_proj_bias, c=c)
                if key_as_static:
                    k = key.repeat(1, num_heads, 1)
                else:
                    k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
                    _bias = None if in_proj_bias is None else in_proj_bias[_start:_end]
                    k = mobius_linear(key, k_proj_weight_non_opt, _bias, c=c)
                    _start += hopfield_dim
                    _end += num_heads * pattern_dim

            if attn_mask is not None:
                assert attn_mask.dtype in (torch.float32, torch.float64, torch.float16, torch.uint8, torch.bool)
                if attn_mask.dtype == torch.uint8:
                    warnings.warn("Byte tensor for attn_mask is deprecated. Use bool tensor instead.")
                    attn_mask = attn_mask.to(torch.bool)
                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
                    if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                        raise RuntimeError('Bad 2D attn_mask size.')
                elif attn_mask.dim() == 3:
                    if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                        raise RuntimeError('Bad 3D attn_mask size.')
                else:
                    raise RuntimeError("Unsupported attn_mask dim")

            if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for key_padding_mask is deprecated. Use bool.")
                key_padding_mask = key_padding_mask.to(torch.bool)

            if bias_k is not None:
                if static_k is None and not key_as_static:
                    k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                    if attn_mask is not None:
                        attn_mask = nn.functional.pad(attn_mask, [0, 1])
                    if key_padding_mask is not None:
                        key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1])
                else:
                    # bias_k cannot be combined with static_k / key_as_static.
                    assert static_k is None and (not key_as_static)
            else:
                assert bias_k is None

            q = q.contiguous().view(tgt_len, -1, head_dim).transpose(0, 1)                   # (B*H, T, D)
            if k is not None:
                k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)       # (B*H, S, D)
            if static_k is not None:
                k = static_k
            src_len = k.size(1)

            if key_padding_mask is not None:
                assert key_padding_mask.size(0) == bsz and key_padding_mask.size(1) == src_len
            if add_zero_attn:
                src_len += 1
                k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
                if attn_mask is not None:
                    attn_mask = nn.functional.pad(attn_mask, [0, 1])
                if key_padding_mask is not None:
                    key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1])

        # ---------- similarity & softmax weights ----------
        val_bi, grad_val_bi = hyperbolic_distance_val_and_grad(q, k, c=c)  # val=-cosh d; grad wrt q
        logits = theta * val_bi                                            # (B*H, T, S)
        # Apply the masks validated/padded at step 0 (previously they were ignored).
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                logits = logits.masked_fill(attn_mask, float('-inf'))
            else:
                logits = logits + attn_mask
        if key_padding_mask is not None:
            logits = logits.view(bsz, num_heads, tgt_len, src_len).masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
            ).view(bsz * num_heads, tgt_len, src_len)
        # softmax subtracts the row max internally, so this is numerically stable.
        p_bi = torch.softmax(logits, dim=2)                                # (B*H, T, S)

        # Energy driven down by the update (not evaluated here):
        #   E(q) = -(1/theta) * logsumexp(theta * val) + 1/2 * d_c(q, 0)^2

        # ---------- grad of concave part, Euclidean coords ----------
        # grad_euc = ∇ E_cave = - Σ_i p_i ∂ val / ∂ q
        neg_p = -p_bi.unsqueeze(-1)                                        # (B*H, T, S, 1)
        grad_all = neg_p * grad_val_bi                                     # (B*H, T, S, D)
        grad_euc = grad_all.sum(dim=2)                                     # (B*H, T, D)

        # ---------- convert to Riemannian grad at q ----------
        # conformal factor: gamma = ( (1 - c||q||^2)/2 )^2
        r2 = torch.sum(q * q, dim=-1, keepdim=True)                        # (B*H, T, 1)
        gamma = ((1.0 - c * r2) / 2.0) ** 2
        rgrad = gamma * grad_euc                                           # (B*H, T, D)

        # ---------- CCCP update with intrinsic regularizer ----------
        # v^{(t)} = - grad_cave (Riemannian); PT_{q->0}(v) = (λ_q/2)*v = v / (1 - c||q||^2)
        v_cur = -rgrad                                                     # (B*H, T, D)
        lam = 2.0 / (1.0 - c * r2).clamp_min(1e-6)                         # (B*H, T, 1)
        v0 = (lam / 2.0) * v_cur                                           # (B*H, T, D)
        # step at the base point (origin)
        origin = torch.zeros_like(q)
        q = pmath.expmap(origin, lr * v0, c=c)                             # (B*H, T, D)
        q = pmath.project(q, c=c)

        # ---------- stopping rule ----------
        with torch.no_grad():
            q_active = q.view(size=(bsz, num_heads, tgt_len, head_dim))
            update_active_heads = (update_step < update_steps_max) | (update_steps_max < 0)
            if q_old is not None:
                # per head: largest (over batch) Frobenius change of q
                update_active_heads &= ((q_old - q_active).norm(p=2, dim=(2, 3)).max(axis=0)[0]) > update_steps_eps
            update_active_heads = update_active_heads.unsqueeze(1).unsqueeze(2).repeat(repeats=(bsz, 1, 1))
            q_old = q_active

        update_step += 1
    # ============================ end iteration ============================

    q_out = q.transpose(0, 1).contiguous().view(tgt_len, bsz, -1)
    if out_proj_weight is not None:
        assert q_out.shape[2] == num_heads * pattern_dim
        q_out = mobius_linear(q_out, out_proj_weight, out_proj_bias, c=c)

    # Derive the returned weights from p_bi before deciding whether the raw
    # associations are returned; the two outputs are independent.
    attn_output_weights = None
    if need_weights:
        attn_output_weights = nn.functional.dropout(p_bi, p=dropout_p, training=training)
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len).sum(dim=1) / num_heads
    raw_associations = p_bi.view(bsz, num_heads, tgt_len, src_len) if return_raw_associations else None
    return q_out, attn_output_weights, raw_associations


def mobius_linear(x: torch.Tensor,
                  weight: torch.Tensor,
                  bias: Optional[torch.Tensor] = None,
                  c: float = 1.0) -> torch.Tensor:
    """Apply a Möbius (hyperbolic) linear map on the Poincaré ball of curvature ``c``.

    Computes ``project(W ⊗_c x ⊕_c b)``, or ``project(W ⊗_c x)`` when ``bias``
    is None, using ``pmath.mobius_matvec``, ``pmath.mobius_add`` and
    ``pmath.project``.

    Args:
        x: input points; the last dimension is the feature dimension.
        weight: matrix passed as ``pmath.mobius_matvec(weight, x)``. Presumably
            ``(out_features, in_features)``; callers slice its rows as output
            features. TODO confirm.
        bias: optional vector combined by Möbius addition. It is used as-is (no
            ``expmap0``). NOTE(review): callers must supply a point inside the ball.
        c: curvature, forwarded to every ``pmath`` call.

    Returns:
        Tensor with the same leading dimensions as ``x``.
    """
    mv = pmath.mobius_matvec(weight, x, c=c)
    if bias is not None:
        # Forward the curvature explicitly; without it mobius_add uses its own
        # default curvature rather than ``c``.
        mv = pmath.mobius_add(mv, bias, c=c)
    return pmath.project(mv, c=c)


def hyperbolic_distance_grad_q(q: torch.Tensor,
                               k: torch.Tensor,
                               dist_bts: torch.Tensor,
                               c: float
                               ) -> torch.Tensor:
    """Return the Euclidean gradient of the Poincaré-ball distance w.r.t. ``q``.

    For every pair ``(q[b, t], k[b, s])`` this computes ``∂ d_c(q, k) / ∂ q``,
    where::

        d_c(q, k) = arccosh(A) / sqrt(c)
        A = 1 + 2c ||q - k||^2 / ((1 - c||q||^2) (1 - c||k||^2))

    The distance is not differentiable where ``q == k``. There the zero vector,
    a valid subgradient, is returned instead of NaN.

    Args:
        q: ``(B, T, D)`` points, expected inside the ball of radius ``1/sqrt(c)``.
        k: ``(B, S, D)`` points, same constraint.
        dist_bts: precomputed ``(B, T, S)`` distances. Not used; kept for
            interface compatibility.
        c: curvature (> 0).

    Returns:
        ``(B, T, S, D)`` tensor of partial derivatives.
    """
    diff = q.unsqueeze(2) - k.unsqueeze(1)                              # (B,T,S,D)
    diff_sq = diff.pow(2).sum(dim=-1)                                   # (B,T,S)

    q_norm_sq = q.pow(2).sum(dim=-1, keepdim=True)                      # (B,T,1)
    k_norm_sq = k.pow(2).sum(dim=-1, keepdim=True).transpose(1, 2)      # (B,1,S)
    one_minus_ck = 1.0 - c * k_norm_sq                                  # (B,1,S)
    denom = (1.0 - c * q_norm_sq) * one_minus_ck                        # (B,T,S)

    # A = 1 + t, with t >= 0 for points inside the ball.
    t = 2.0 * c * diff_sq / denom                                       # (B,T,S)

    # d arccosh(A)/dA = 1/sqrt(A^2 - 1) = 1/sqrt(t (t + 2)).
    # - The factored form avoids catastrophic cancellation of A*A - 1 near A = 1.
    # - At t == 0 the zero subgradient is used.
    # - clamp_min keeps the unselected branch (and its autograd backward) finite.
    tiny = torch.finfo(t.dtype).tiny
    root = torch.sqrt((t * (t + 2.0)).clamp_min(tiny))
    partial_acosh_wrt_Arg = torch.where(t > 0, 1.0 / root, torch.zeros_like(root))
    partial_dist_c_wrt_Arg = (1.0 / (c ** 0.5)) * partial_acosh_wrt_Arg

    # dA/dq = (2c/denom) d||q-k||^2/dq - (2c||q-k||^2/denom^2) d(denom)/dq
    partialArg_wrt_diff = 2.0 * c / denom
    partial_diff_sq_wrt_q = 2.0 * diff
    partialArg_wrt_denom = -2.0 * c * diff_sq / denom.pow(2)
    partial_denom_wrt_q = one_minus_ck.unsqueeze(-1) * (-2.0 * c * q.unsqueeze(2))

    partialArg_wrt_q = (partialArg_wrt_diff.unsqueeze(-1) * partial_diff_sq_wrt_q
                        + partialArg_wrt_denom.unsqueeze(-1) * partial_denom_wrt_q)

    return partial_dist_c_wrt_Arg.unsqueeze(-1) * partialArg_wrt_q


def hyperbolic_distance_val_and_grad(q: torch.Tensor,
                                     k: torch.Tensor,
                                     c: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``-cosh(d_c(q, k))`` for all query/key pairs and its gradient w.r.t. ``q``.

    Args:
        q: ``(B, T, D)`` points on the Poincaré ball.
        k: ``(B, S, D)`` points on the Poincaré ball.
        c: curvature, forwarded to ``pmath.dist`` and the gradient helper.

    Returns:
        ``(val, grad)`` where ``val`` is ``(B, T, S)`` and ``grad`` is
        ``(B, T, S, D)``. ``grad`` is computed by the chain rule as
        ``-sinh(d) * ∂d/∂q``, with ``∂d/∂q`` from ``hyperbolic_distance_grad_q``.
    """
    dist_bts = pmath.dist(q.unsqueeze(2), k.unsqueeze(1), c=c)      # (B,T,S)
    val_bts = -torch.cosh(dist_bts)                                 # (B,T,S)
    dd_dq = hyperbolic_distance_grad_q(q, k, dist_bts, c)           # (B,T,S,D)
    grad_val_wrt_q = (-torch.sinh(dist_bts)).unsqueeze(-1) * dd_dq  # (B,T,S,D)
    return val_bts, grad_val_wrt_q
