import torch
import torch.nn as nn

from torch import Tensor
from torch.nn import Linear, Module, Parameter
from typing import Optional, Tuple

from hyptorch import pmath
from .functional import hopfield_core_forward

# `_LinearWithBias` is a private torch class that may not be importable. When the
# import fails the name is bound to None, so the output-projection setup can test
# for it and use the public `Linear` instead.
try:
    from torch.nn.modules.linear import _LinearWithBias
except ImportError:
    _LinearWithBias = None


class HopfieldCore(Module):
    r"""Hopfield core with hyperbolic geometry support.

    Owns the query/key input projections, the optional key bias and the optional
    output projection, and delegates the actual computation to
    ``hopfield_core_forward``.

    Args:
        embed_dim: total dimension of the model.
        num_heads: number of heads.
        dropout: dropout on attention weights.
        bias: add bias.
        add_bias_k: add bias to key at dim=0.
        add_zero_attn: add a zero key at dim=1.
        kdim: key dim; default: embed_dim.
        head_dim: head dim; default: embed_dim // num_heads.
        pattern_dim: value/pattern dim per head; default: head_dim.
        out_dim: output projection dim; default: embed_dim.
        disable_out_projection: if True, skip out projection.
        key_as_static, query_as_static: use provided tensors as already projected.
        theta: softmax temperature used inside functional.
        lr: Riemannian step-size used inside functional.
        c: **Tensor** curvature handle for the Poincaré ball (manifold curvature = -c).
           Pass the SAME Tensor from the outer module so everyone shares one parameter.
           A non-Tensor value is converted with ``torch.tensor(float(c))``.

    Raises:
        ValueError: if ``c`` is None, or if any entry of ``c`` is not strictly
            positive (this includes NaN entries).
        AssertionError: if the dimension / static-mode arguments are inconsistent.
    """

    # Explicit attribute annotation. Built from the public `typing.Optional` and
    # `torch.Tensor` instead of the re-export in the private `torch._jit_internal`.
    __annotations__ = {
        'bias_k': Optional[Tensor],
    }

    def __init__(self,
                 embed_dim: Optional[int] = None,
                 num_heads: int = 1,
                 dropout: float = 0.0,
                 bias: bool = True,
                 add_bias_k: bool = False,
                 add_zero_attn: bool = False,
                 kdim: Optional[int] = None,

                 head_dim: Optional[int] = None,
                 pattern_dim: Optional[int] = None,
                 out_dim: Optional[int] = None,
                 disable_out_projection: bool = False,
                 key_as_static: bool = False,
                 query_as_static: bool = False,
                 theta: float = 1.0,
                 lr: float = 1e-3,
                 c: Optional[Tensor] = None,   # <--- IMPORTANT: expect Tensor handle here
                 ):
        """Validate the configuration and create the projection parameters.

        See the class docstring for the meaning of each argument.
        """
        super().__init__()

        # Keep c as a Tensor handle (shared with outer modules).
        if c is None:
            raise ValueError("HopfieldCore requires a Tensor 'c' handle (shared curvature).")
        if not isinstance(c, torch.Tensor):
            c = torch.tensor(float(c))
        # `not all(c > 0)` rather than `any(c <= 0)`: every comparison with NaN is
        # False, so only this form rejects NaN entries along with zero/negative ones.
        if not bool(torch.all(c > 0)):
            raise ValueError("Curvature parameter c must be positive (manifold curvature = -c).")
        self.c = c                 # <-- DO NOT cast to float; keep Tensor for autograd sharing
        self.theta = float(theta)
        self.lr = float(lr)

        assert isinstance(key_as_static, bool) and isinstance(query_as_static, bool)
        self.key_as_static, self.query_as_static = key_as_static, query_as_static
        # Number of inputs (query, key) that still need a learned projection.
        num_non_static = 2 - (self.key_as_static + self.query_as_static)
        assert 0 <= num_non_static < 3

        self.disable_out_projection = disable_out_projection

        # In case of static-only execution, no projection exists, so the
        # embedding dimensions are irrelevant and are cleared.
        self.static_execution = self._check_execution_mode()
        if self.static_execution:
            embed_dim, kdim = None, None
        if embed_dim is None:
            assert self.static_execution, r'static-only execution requires all projections to be deactivated.'

        # Set dims.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim

        # A single packed input projection is only used when query and key share
        # the embedding size and no explicit pattern dimension was requested.
        self._qkv_same_embed_dim = (self.kdim == embed_dim) and (pattern_dim is None)

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = None
        self.pattern_dim = pattern_dim
        self.virtual_hopfield_dim = None
        self.virtual_pattern_dim = None
        if not self.static_execution:
            if head_dim is None:
                self.head_dim = embed_dim // num_heads
                assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads."
            else:
                assert head_dim > 0, "dimension of the association space has to be positive."
                self.head_dim = head_dim
            if self.pattern_dim is None:
                self.pattern_dim = self.head_dim
            self.virtual_hopfield_dim = self.num_heads * self.head_dim
            self.virtual_pattern_dim = self.num_heads * self.pattern_dim

        self.out_dim = embed_dim if out_dim is None else out_dim
        assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive."

        # Rows of the packed projection: one `virtual_hopfield_dim` slab per
        # non-static input. Only meaningful (and only used) if num_non_static > 0,
        # in which case `virtual_hopfield_dim` has been set above.
        packed_rows = num_non_static * self.virtual_hopfield_dim if num_non_static > 0 else 0

        # Projections
        if not self._qkv_same_embed_dim:
            if query_as_static:
                self.register_parameter('q_proj_weight', None)
            else:
                self.q_proj_weight = Parameter(torch.empty(self.virtual_hopfield_dim, embed_dim))
            if key_as_static:
                self.register_parameter('k_proj_weight', None)
            else:
                self.k_proj_weight = Parameter(torch.empty(self.virtual_hopfield_dim, self.kdim))
            self.register_parameter('in_proj_weight', None)
        else:
            if num_non_static > 0:
                self.in_proj_weight = Parameter(torch.empty(packed_rows, embed_dim))
            else:
                self.register_parameter('in_proj_weight', None)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)

        if bias and (num_non_static > 0):
            self.in_proj_bias = Parameter(torch.empty(packed_rows))
        else:
            self.register_parameter('in_proj_bias', None)

        if disable_out_projection:
            self.register_parameter('out_proj', None)
        else:
            # Use the private biased linear layer when the import succeeded,
            # otherwise fall back to the public `Linear`.
            if bias and _LinearWithBias is not None:
                self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim)
            else:
                self.out_proj = Linear(self.virtual_pattern_dim, self.out_dim, bias=bias)

        self.bias_k = None
        if add_bias_k:
            if not key_as_static:
                self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            assert not (self.bias_k is None), r'cannot set key bias if both are static.'

        self.add_zero_attn = add_zero_attn
        self.reset_parameters()

    def _check_execution_mode(self) -> bool:
        """Return True if query, key and output projection are all deactivated."""
        return all((self.key_as_static, self.query_as_static, self.disable_out_projection))

    def reset_parameters(self):
        """(Re-)initialise all existing parameters.

        Weights are drawn from a normal distribution (mean 0, std 0.02); the
        input and output projection biases are set to zero.
        """
        if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
            nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
        else:
            if self.q_proj_weight is not None:
                nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02)
            if self.k_proj_weight is not None:
                nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if self.bias_k is not None:
            nn.init.normal_(self.bias_k, mean=0.0, std=0.02)

        if not self.disable_out_projection:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02)
            if self.out_proj.bias is not None:
                nn.init.constant_(self.out_proj.bias, 0.0)

    def __setstate__(self, state):
        """Restore module state; defers entirely to the parent implementation."""
        super().__setstate__(state)

    def forward(self,
                query: Tensor,
                key: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,

                scaling: Optional[Tensor] = None,
                update_steps_max: Optional[int] = 0,
                update_steps_eps: float = 1e-4,
                return_raw_associations: bool = False,
                ):
        """Check input sizes and run ``hopfield_core_forward``.

        Args:
            query: tensor whose size at dim 2 must equal ``embed_dim``, or
                ``head_dim`` when ``query_as_static`` is set.
            key: tensor whose size at dim 2 must equal ``kdim``, or ``head_dim``
                when ``key_as_static`` is set. If both inputs are static, the
                sizes of query and key at dim 2 only have to match each other.
            key_padding_mask: passed through to ``hopfield_core_forward``.
            need_weights: passed through to ``hopfield_core_forward``.
            attn_mask: passed through to ``hopfield_core_forward``.
            scaling: accepted, but NOT forwarded to ``hopfield_core_forward``;
                the module's ``theta`` is passed instead.
                NOTE(review): confirm that ignoring this argument is intended.
            update_steps_max: passed through to ``hopfield_core_forward``.
            update_steps_eps: passed through to ``hopfield_core_forward``.
            return_raw_associations: passed through to ``hopfield_core_forward``.

        Returns:
            The result of ``hopfield_core_forward``, unchanged.

        Raises:
            AssertionError: if the size of ``query`` or ``key`` at dim 2 does not
                match the configured dimensions.
        """
        # Type / shape checks
        if self.query_as_static and self.key_as_static:
            assert query.shape[2] == key.shape[2], \
                f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
            head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
        else:
            assert self.query_as_static or (query.shape[2] == self.embed_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
            assert (not self.query_as_static) or (query.shape[2] == self.head_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'

            assert self.key_as_static or (key.shape[2] == self.kdim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
            assert (not self.key_as_static) or (key.shape[2] == self.head_dim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
            head_dim = self.head_dim
            embed_dim_to_check = self.head_dim if self.query_as_static else self.embed_dim

        out_weights, out_bias = None, None
        if not self.disable_out_projection:
            out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

        # map Euclidean biases to the ball with curvature self.c
        in_proj_bias_exp = pmath.expmap0(self.in_proj_bias, c=self.c) if self.in_proj_bias is not None else None
        bias_k_exp = pmath.expmap0(self.bias_k, c=self.c) if self.bias_k is not None else None
        out_proj_bias_exp = pmath.expmap0(out_bias, c=self.c) if out_bias is not None else None

        # The two calls differ only in the separate-projection arguments; keep
        # them in sync when adding or changing an argument.
        if not self._qkv_same_embed_dim:
            return hopfield_core_forward(
                query=query, key=key, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
                in_proj_weight=self.in_proj_weight, in_proj_bias=in_proj_bias_exp, bias_k=bias_k_exp,
                add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
                out_proj_weight=out_weights, out_proj_bias=out_proj_bias_exp, training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,

                key_as_static=self.key_as_static, query_as_static=self.query_as_static,
                head_dim=head_dim, pattern_dim=self.pattern_dim,
                update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
                return_raw_associations=return_raw_associations,
                c=self.c, theta=self.theta, lr=self.lr
            )
        else:
            return hopfield_core_forward(
                query=query, key=key, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
                in_proj_weight=self.in_proj_weight, in_proj_bias=in_proj_bias_exp, bias_k=bias_k_exp,
                add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
                out_proj_weight=out_weights, out_proj_bias=out_proj_bias_exp, training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,

                key_as_static=self.key_as_static, query_as_static=self.query_as_static,
                head_dim=head_dim, pattern_dim=self.pattern_dim,
                update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
                return_raw_associations=return_raw_associations,
                c=self.c, theta=self.theta, lr=self.lr
            )
