from __future__ import annotations
import torch
from inspect import signature
from operator import attrgetter
from typing import Optional, Tuple
from types import NoneType
from typing import Any, Callable, Dict, Annotated, List, TypeVar, Union, Tuple
from msgspec import Struct
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(parent_dir)
sys.path.append(os.path.join(parent_dir, 'triton_kernels'))
from sparse_gemv import sparse_gemv
from swiglu_norm import swiglu_norm

__all__ = ['Dynamic', 'BaseKernel', 'SparseGEMV', 'SwigluNorm']

T = TypeVar('T')

NAMESPACE: str = "norsa"

TENSOR: str = "Tensor"
NONE: str = "None"

DOT_OP: str = "."
TYPE_SEP: str = " | "
COMMA_SPACE: str = ", "

TYPE_MAPPINGS: Dict[str, str] = (
    {
        'dtype': 'ScalarType',
        'device': 'Device',
    }
)


class Dynamic:
    Int = Annotated[int, 'SymInt']
    Num = Annotated[int | float, 'Scalar']


def check_origin(annotation: Any, target: Any) -> bool:
    return hasattr(annotation, '__origin__') and annotation.__origin__ is target


def unpack_optional(annotation: Union[T, NoneType] | T) -> Tuple[T, bool]:
    if check_origin(annotation, Union):
        target, maybe_none, *maybe_more = annotation.__args__
        if maybe_more or maybe_none is not NoneType:
            raise TypeError(f'unsupported multi-union type: {annotation}')
        else:
            return target, True
    return annotation, False


def unpack_list(annotation: List[T] | T) -> Tuple[T, bool]:
    if check_origin(annotation, list):
        return annotation.__args__[0], True
    return annotation, False


def resolve_type(annotation: Any) -> str:

    target, is_optional = unpack_optional(annotation)
    target, is_iterable = unpack_list(target)

    if hasattr(target, '__metadata__'):
        resolved = target.__metadata__[0]
    elif hasattr(target, '__name__'):
        if (str_value := target.__name__).endswith(TENSOR):
            resolved = TENSOR
        else:
            resolved = TYPE_MAPPINGS.get(str_value, str_value)
    else:
        raise TypeError(f'unable to infer type from given {target}')
    if is_iterable:
        resolved = f'{resolved}[]'
    if is_optional:
        resolved = f'{resolved}?'
    return resolved


def resolve_return(annotation: Any) -> str:
    target = annotation
    if check_origin(annotation, tuple):
        target = annotation.__args__
    if isinstance(target, tuple):
        return f'({COMMA_SPACE.join(map(resolve_type, target))})'
    else:
        return resolve_type(target)


class BaseKernel(Struct):

    name: str
    target: str
    schema: str

    @classmethod
    def initialize(cls, name: str, target: str, **kwargs) -> BaseKernel:
        return cls(name, target, cls.schematize())

    @classmethod
    def schematize(cls) -> str:
        params = dict((forward_signature := signature(cls.forward, eval_str=True)).parameters)
        _ = params.pop('self')
        arguments = (f"{resolve_type(p.annotation)} {name}" for name, p in params.items())
        out = resolve_return(forward_signature.return_annotation)
        return f"({', '.join(arguments)}) -> {out}"

    @property
    def is_registered(self) -> bool:
        return hasattr(getattr(torch.ops, NAMESPACE), self.name)

    def operator(self, compiled: bool = False) -> Callable:
        if compiled:
            self.register()
            return attrgetter(f'{NAMESPACE}.{self.name}')(torch.ops)
        else:
            return self.forward

    def meta(self, *args, **kwargs) -> Any:
        raise NotImplementedError(
            f'abstract implementation `meta` in {self.__class__.__name__} required for '
            f'registration'
        )

    def forward(self, *args, **kwargs) -> Any:
        raise NotImplementedError(
            f'concrete implementation `forward` in {self.__class__.__name__} required for '
            f'registration'
        )

    def register(self) -> None:
        if not self.is_registered:
            qualname = f'{NAMESPACE}::{self.name}'
            torch.library.define(qualname, self.schema)
            torch.library.impl_abstract(qualname, func=self.meta)
            torch.library.impl(qualname, self.target, func=self.forward)


class SparseGEMV(BaseKernel):

    def meta(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        threshold_ratio: float,
        norm_squared: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return x.new_empty((x.size(0), x.size(1), weight.size(0)))

    def forward(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        threshold_ratio: float,
        norm_squared: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return sparse_gemv(x, weight, threshold_ratio, norm_squared)



class SwigluNorm(BaseKernel):

    def meta(self,
             a: torch.Tensor,
             b: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return (a.new_empty(a.shape[0], a.shape[1], a.shape[2]), torch.empty(1, dtype=torch.float32, device=a.device))
    
    def forward(self,
             a: torch.Tensor,
             b: torch.Tensor
    ) -> (torch.Tensor, torch.Tensor):
        return swiglu_norm(a, b)