from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
from torch import Tensor, device, dtype, nn
from lpmm.functional import vectorwise_quant, vectorwise_dequant
from copy import deepcopy

# Type of ``self`` in the ``Int8Param.to`` overloads. Bound to ``torch.Tensor``
# because ``Int8Param`` is an ``nn.Parameter`` (i.e. Tensor) subclass, not an
# ``nn.Module``; a Module bound makes type checkers reject ``self: T`` there.
T = TypeVar("T", bound="torch.Tensor")


class Int8Param(torch.nn.Parameter):
    """``nn.Parameter`` that stores its value 8-bit vector-wise quantized.

    The parameter starts out holding ordinary data. Moving it from CPU to a
    CUDA device (``cuda`` / ``to``) replaces ``self.data`` with the payload
    returned by ``vectorwise_quant(..., b=8, quant_type="vector")`` and keeps
    the accompanying quantization state in ``self.metadata``.

    While quantized, torch operations that receive the parameter operate on
    its dequantized value (see ``__torch_function__``). Property access
    (``.data``, ``.device``, ``.dtype``, ``.shape`` ...) and the autograd-flag
    methods in ``_PASSTHROUGH_FUNCS`` act on the stored payload itself.
    """

    # Calls that must run on the parameter object itself rather than on a
    # dequantized temporary. Tensor property getters/setters reach
    # ``__torch_function__`` as the descriptor's ``__get__``/``__set__``/
    # ``__delete__``; ``cuda`` and ``to`` rely on these to read and rebind the
    # raw payload. The remaining names only touch autograd bookkeeping.
    _PASSTHROUGH_FUNCS = frozenset(
        {
            "__get__",
            "__set__",
            "__delete__",
            "requires_grad_",
            "retain_grad",
            "register_hook",
        }
    )

    def __new__(
        cls,
        data=None,
        requires_grad=True,
    ):
        # Allow construction without data by wrapping an empty placeholder.
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    @staticmethod
    def _holds_quantized_payload(value):
        """Return True if *value* is an ``Int8Param`` carrying quantization metadata."""
        return isinstance(value, Int8Param) and getattr(value, "metadata", None) is not None

    @staticmethod
    def _dequantize_arg(value):
        """Replace quantized ``Int8Param``s in *value* by their dequantized tensors.

        Recurses into lists and tuples (e.g. the tensor list given to
        ``torch.cat``). Containers with nothing to replace are returned as-is,
        so e.g. a ``torch.Size`` argument keeps its type.
        """
        if isinstance(value, (list, tuple)):
            items = [Int8Param._dequantize_arg(item) for item in value]
            if all(new is old for new, old in zip(items, value)):
                return value
            return items if isinstance(value, list) else tuple(items)
        if Int8Param._holds_quantized_payload(value):
            return vectorwise_dequant(
                value.data, b=8, quant_type="vector", metadata=value.metadata
            )
        return value

    def cuda(self, device=None):
        """Quantize the parameter and move the quantized payload to *device*.

        Args:
            device: CUDA device (index, ``torch.device`` or ``None`` for the
                current device), forwarded to ``Tensor.cuda``.

        Returns:
            ``self``, modified in place (``data`` and ``metadata`` replaced).
        """
        data = self.data
        if Int8Param._holds_quantized_payload(self):
            # Quantizing an already-quantized payload would corrupt it; start
            # again from the dequantized values.
            data = vectorwise_dequant(
                data, b=8, quant_type="vector", metadata=self.metadata
            )
        qdata, metadata = vectorwise_quant(data, b=8, quant_type="vector")
        # NOTE(review): quantization runs before the move, so ``metadata`` stays
        # wherever vectorwise_quant produced it; confirm vectorwise_dequant
        # accepts that alongside a CUDA payload.
        self.data = qdata.cuda(device)
        self.metadata = metadata
        return self

    @overload
    def to(
        self: "T",
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> "T":
        ...

    @overload
    def to(self: "T", dtype: Union[dtype, str], non_blocking: bool = ...) -> "T":
        ...

    @overload
    def to(self: "T", tensor: Tensor, non_blocking: bool = ...) -> "T":
        ...

    def to(self, *args, **kwargs):
        """Move/cast the parameter; accepts the same arguments as ``Tensor.to``.

        A CPU-resident parameter sent to a CUDA device is quantized via
        ``cuda`` (a requested dtype is ignored on that path). Any other
        move/cast applies ``Tensor.to`` to the value ``__torch_function__``
        exposes (the dequantized value if the parameter is quantized) and wraps
        the result in a new ``Int8Param`` with ``requires_grad=False`` and no
        quantization metadata. ``memory_format`` is parsed but ignored.
        """
        device, dtype, non_blocking, _memory_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        if (
            device is not None
            and device.type == "cuda"
            and self.device.type == "cpu"
        ):
            return self.cuda(device)
        return Int8Param(
            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=False,
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch hook: run *func* with quantized arguments dequantized.

        Calls named in ``_PASSTHROUGH_FUNCS`` are forwarded unchanged. For all
        others, quantized ``Int8Param`` arguments (positional, inside
        lists/tuples, or keyword values) are replaced by their dequantized
        tensors before *func* runs. If the call is in place (a non-dunder name
        ending in ``"_"``) and its first argument is a quantized
        ``Int8Param``, the modified temporary is requantized and written back
        into that parameter, which is returned.
        """
        # NOTE: never repr()/log the tensor arguments here: formatting a tensor
        # goes through ``Tensor.__repr__`` and property getters, which re-enter
        # this hook and recurse.
        if kwargs is None:
            kwargs = {}
        name = getattr(func, "__name__", "")
        if name in cls._PASSTHROUGH_FUNCS:
            return super().__torch_function__(func, types, args, kwargs)

        new_args = tuple(cls._dequantize_arg(arg) for arg in args)
        new_kwargs = {key: cls._dequantize_arg(val) for key, val in kwargs.items()}
        ret = super().__torch_function__(func, types, new_args, new_kwargs)

        target = args[0] if args else None
        is_inplace = name.endswith("_") and not name.endswith("__")
        if is_inplace and cls._holds_quantized_payload(target):
            # The op mutated the dequantized temporary ``new_args[0]``; persist
            # that change into the parameter's quantized storage.
            qdata, metadata = vectorwise_quant(new_args[0], b=8, quant_type="vector")
            target.data = qdata
            target.metadata = metadata
            return target
        return ret
         

class LPMM(nn.Module):
    """Wrap an ``nn.Sequential`` and backpropagate through it layer by layer.

    ``forward`` runs a deep copy of the layers with parameter gradients
    disabled and records each layer's input in ``self.inputs``. ``backward``
    walks the layers in reverse, recomputing each one from its recorded input
    to push ``grad_output`` into that layer's parameters and to obtain the
    gradient handed to the previous layer.
    """

    def __init__(self, net):
        super().__init__()
        assert isinstance(net, nn.Sequential), (
            f"LPMM expects an nn.Sequential, got {type(net).__name__}"
        )
        self.net = deepcopy(net)
        self.L = len(self.net)
        self.inputs = [None] * self.L

    def forward(self, inputs, **kwargs):
        """Run the layers in order, storing each layer's input.

        Args:
            inputs: input to the first layer.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            The output of the last layer.
        """
        self.net.requires_grad_(False)
        for idx, layer in enumerate(self.net):
            self.inputs[idx] = inputs
            inputs = layer(inputs)
        return inputs

    def backward(self, grad_output):
        """Backpropagate *grad_output* through the layers, last to first.

        Parameter gradients accumulate into ``.grad`` of ``self.net``'s
        parameters. Each layer is recomputed from the input stored by
        ``forward``, so ``forward`` must have been called first.

        Args:
            grad_output: gradient w.r.t. the output of the last layer.

        Returns:
            The gradient w.r.t. the network input, or ``None`` when it cannot
            be formed (e.g. a layer received a non-floating input such as
            ``nn.Embedding`` indices).
        """
        self.net.requires_grad_(True)
        # Recompute with autograd on even if the caller is inside no_grad().
        with torch.enable_grad():
            for idx in reversed(range(self.L)):
                layer = self.net[idx]
                stored = self.inputs[idx]
                # Differentiate w.r.t. a fresh leaf instead of toggling flags on
                # the stored tensor: it may be the caller's input or a non-leaf,
                # whose requires_grad cannot be reset afterwards.
                if stored.is_floating_point() or stored.is_complex():
                    leaf = stored.detach().requires_grad_(True)
                else:
                    leaf = stored  # integer inputs cannot carry gradients
                outputs = layer(leaf)
                torch.autograd.backward(outputs, grad_tensors=grad_output)
                grad_output = leaf.grad if leaf.requires_grad else None
                if grad_output is None:
                    # No gradient can reach the earlier layers.
                    break
        return grad_output
        