import math
from logging import getLogger

import numpy as np
import torch
from gekko import GEKKO
from torch import nn


logger = getLogger(__name__)

# ``cQIGen`` is an optional compiled extension. When it cannot be imported the
# module must still load, so ``qinfer`` is bound to a placeholder that raises a
# descriptive ImportError on first use instead of leaving the name undefined.
try:
    import cQIGen as qinfer
except ImportError as e:
    # ``e`` is unbound when the except clause ends; keep a module-level reference.
    exception_qinfer = e

    class FakeQInfer:
        """Stand-in for the missing ``cQIGen`` module; any attribute access raises ImportError."""

        def __getattr__(self, name):
            raise ImportError(f"cQIGen is not installed or not correctly installed. {exception_qinfer}")

    # Without this binding ``qinfer`` would not exist and every later use would
    # fail with a NameError rather than the ImportError above.
    qinfer = FakeQInfer()


def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
    """Choose the block sizes ``(mb, tb)`` with a small mixed-integer program.

    A GEKKO model maximizes

        L = 32 * mb * N + bits * mb * tb + 32 * tb * N

    subject to ``0 <= L <= l1`` and the following integrality constraints:

    * ``mb`` divides ``M`` and is a multiple of ``mu`` (and of ``gs`` when
      ``gs != -1``);
    * ``tb`` divides ``T // p``, is a multiple of ``tu`` and is at most
      ``int(T / p)``.

    The solve is tried with ``minlp_branch_method 2`` and, if that raises, with
    ``minlp_branch_method 1``. If both attempts raise, a greedy search is used
    instead: ``tb`` is fixed to ``tu`` and ``mb`` grows in steps of ``gs`` (or
    ``mu`` when ``gs == -1``) while the next step still keeps ``L`` below
    ``l1``, then shrinks by the same step until it divides ``M``.

    Args:
        N, M, T: problem dimensions; ``QuantLinear`` passes its ``hint``,
            ``infeatures`` and ``outfeatures`` respectively.
        mu, tu: unroll factors; also the initial guesses for ``mb`` and ``tb``.
        bits: weight bit width used in the footprint expression.
        l1: upper bound on ``L`` (presumably a cache budget -- confirm units).
        p: divisor applied to ``T`` before ``tb`` is constrained.
        gs: quantization group size, or ``-1`` for no grouping constraint.

    Returns:
        ``(mb, tb)`` as a tuple of Python ints.
    """
    grouped = gs != -1

    model = GEKKO()  # create GEKKO model
    bit_width = model.Const(value=bits)
    t_per_part = model.Const(value=T // p)

    # k, z, w, y (and gg) are integer multipliers that exist only to express
    # the divisibility constraints below. Creation order is kept as-is.
    k = model.Var(1, integer=True, lb=1)
    z = model.Var(1, integer=True, lb=1)
    w = model.Var(1, integer=True, lb=1)
    y = model.Var(1, integer=True, lb=1)
    mb = model.Var(mu, integer=True, lb=1)
    if grouped:
        gg = model.Var(1, integer=True, lb=1)
    tb = model.Var(tu, integer=True, lb=1, ub=int(T / p))
    footprint = model.Var(integer=True, lb=0, ub=l1)

    model.Equation(footprint == 32 * mb * N + bit_width * mb * tb + 32 * tb * N)
    model.Equation(mb * k == M)
    if grouped:
        model.Equation(gs * gg == mb)
    model.Equation(tb * z == t_per_part)
    model.Equation(mu * w == mb)
    model.Equation(tu * y == tb)
    model.Maximize(footprint)
    model.options.SOLVER = 1

    # First attempt uses branch method 2, the retry uses branch method 1.
    for branch_method in (2, 1):
        model.solver_options = [
            # minlp iterations
            "minlp_maximum_iterations 1000",
            # minlp iterations with integer solution
            "minlp_max_iter_with_int_sol 10",
            # treat minlp as nlp
            "minlp_as_nlp 0",
            # nlp sub-problem max iterations
            "nlp_maximum_iterations 100",
            # 1 = depth first, 2 = breadth first
            f"minlp_branch_method {branch_method}",
            # maximum deviation from whole number
            "minlp_integer_tol 0.00",
            # convergence tolerance
            "minlp_gap_tol 0.01",
        ]
        try:
            model.solve(disp=False)
        except Exception:
            continue
        return (int(mb.value[0]), int(tb.value[0]))

    # Both solver attempts failed: fall back to the greedy search.
    step = gs if grouped else mu
    mytb = tu
    mymb = step
    while 32 * (mymb + step) * N + bits * (mymb + step) * mytb + 32 * mytb * N < l1:
        mymb += step
    # NOTE(review): if no positive multiple of ``step`` divides M this reaches
    # ``M % 0`` and raises ZeroDivisionError.
    while M % mymb != 0:
        mymb -= step
    return (int(mymb), int(mytb))


# Process-wide cache of the ``(mb, tb)`` pairs returned by ``mem_model``.
# ``QuantLinear.__init__`` looks its layer up here first and only runs the
# solver on a miss, storing the result for later layers.
params = {}


def compute_reductions(x, gs=-1, cpp=True):
    """Sum ``x`` along its last axis, either per row or per group of ``gs`` columns.

    Args:
        x: 1-D or 2-D tensor. A 1-D tensor is treated as a single row.
        gs: group width. ``-1`` reduces each row to one value; otherwise the
            first ``cols // gs`` groups of ``gs`` consecutive columns are
            summed separately (trailing columns that do not fill a group are
            ignored).
        cpp: when true, delegate to ``qinfer.compute_reduction_cpp``; when
            false, compute the sums with torch.

    Returns:
        With ``cpp=True``: the float32 buffer handed to the C++ routine, of
        shape ``(rows,)`` for ``gs == -1`` and ``(rows, cols // gs)`` otherwise
        (presumably filled in place by the extension -- confirm).
        With ``cpp=False``: ``torch.sum`` over columns for ``gs == -1`` (a
        0-dim tensor for 1-D input), else a float32 tensor of shape
        ``(rows, cols // gs)`` or ``(cols // gs,)`` for 1-D input.
    """
    is_vector = len(x.shape) == 1
    grouped = gs != -1

    if cpp:
        if is_vector:
            rows, cols = 1, x.shape[0]
        else:
            rows, cols = x.shape
        if grouped:
            out = torch.zeros(rows, cols // gs).float().contiguous()
            mygs = gs
        else:
            # No grouping: the whole row is one group.
            out = torch.zeros(rows).float().contiguous()
            mygs = cols

        qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
        return out

    # Pure-torch path.
    if not grouped:
        return torch.sum(x) if is_vector else torch.sum(x, 1)

    if is_vector:
        n_groups = x.shape[0] // gs
        out = torch.zeros(n_groups).float().contiguous()
        for g, start in enumerate(range(0, n_groups * gs, gs)):
            out[g] = torch.sum(x[start : start + gs])
        return out

    rows, cols = x.shape
    n_groups = cols // gs
    out = torch.zeros(rows, n_groups).float().contiguous()
    for g, start in enumerate(range(0, n_groups * gs, gs)):
        out[:, g] = torch.sum(x[:, start : start + gs], 1)
    return out


def process_zeros_scales(zeros, scales, bits, M):
    """Return ``(zeros, scales)`` as contiguous tensors in the layout the kernels are given.

    Two cases, selected by the dtype of ``zeros``:

    * ``zeros`` is float32: both tensors are taken as already unpacked. Each
      one whose dimension 1 is not ``M`` is transposed; both are made
      contiguous.
    * otherwise: ``zeros`` is treated as packed. A float32 buffer shaped like
      ``scales`` is allocated and passed to ``qinfer.unpack_zeros4`` /
      ``qinfer.unpack_zeros2`` for 4 / 2 bits (presumably filled in place by
      the extension -- confirm). For 3 bits only a log message is emitted and
      the buffer stays zero, as it does for any other bit width. ``scales`` is
      returned contiguous without transposition.

    Args:
        zeros: zero-points, float32 or packed.
        scales: float scales.
        bits: quantization bit width.
        M: expected size of dimension 1 in the float32 case.

    Returns:
        ``(new_zeros, new_scales)``.
    """
    if zeros.dtype == torch.float32:

        def _with_m_columns(tensor):
            # Transpose only when dimension 1 is not already M.
            if tensor.shape[1] == M:
                return tensor.contiguous()
            return tensor.transpose(0, 1).contiguous()

        new_scales = _with_m_columns(scales)
        new_zeros = _with_m_columns(zeros)
        return new_zeros, new_scales

    unpacked = torch.zeros_like(scales).float().contiguous()
    if bits in (2, 4):
        unpack = qinfer.unpack_zeros4 if bits == 4 else qinfer.unpack_zeros2
        unpack(zeros, unpacked, unpacked.shape[0], unpacked.shape[1])
    elif bits == 3:
        logger.info("Unpacking zeros for 3 bits")
    return unpacked, scales.contiguous()


class QuantLinear(nn.Module):
    """Inference-only quantized linear layer backed by the ``cQIGen`` (QIGen) kernels.

    The layer owns the buffers ``bias``, ``zeros``, ``scales`` and ``qweight``
    (all zero-initialized; they are expected to be filled in afterwards) and
    the tiling parameters ``mb``, ``tb``, ``tt`` and ``cutoff`` that are
    forwarded to the kernels on every call.
    """

    QUANT_TYPE = "qigen"

    def __init__(
        self,
        bits,
        group_size,
        infeatures,
        outfeatures,
        bias=None,
        trainable=False,
        hint=1,
        p=8,
        l1=2**18,
    ):
        """Allocate buffers and compute the tiling parameters.

        Args:
            bits: weight bit width; only 2 and 4 are accepted.
            group_size: number of input features sharing one scale/zero pair,
                or ``-1`` for a single group spanning all input features.
            infeatures: input feature count.
            outfeatures: output feature count.
            bias: accepted for interface compatibility and not read; a zero
                ``bias`` buffer is always registered.
            trainable: must be false; training is not supported.
            hint: passed to ``mem_model`` as its ``N`` argument.
            p: number of partitions ``outfeatures`` is split into.
            l1: budget passed to ``mem_model`` (upper bound of its objective).

        Raises:
            NotImplementedError: for unsupported ``bits`` or ``trainable=True``.
        """
        super().__init__()
        if bits not in [2, 4]:
            raise NotImplementedError("Only 2,4 bits are supported.")
        if trainable:
            raise NotImplementedError("Qigen kernel does not support training.")
        self.bits = bits

        self.infeatures = infeatures
        self.outfeatures = outfeatures

        n = hint
        m = self.infeatures
        t = self.outfeatures

        # registers for now are fixed (bits is 2 or 4 here, so 32 % bits == 0)
        packed = 32 // bits
        mu = 16
        tu = 32

        # The solver result depends on every argument of mem_model, so all of
        # them are part of the cache key. Keying on (m, t) alone would hand a
        # layer the tiling computed for a different bits/group_size/l1/p/hint.
        cache_key = (n, m, t, mu, tu, bits, l1, p, group_size)
        if cache_key in params:
            mb, tb = params[cache_key]
        else:
            mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
            params[cache_key] = (mb, tb)

        # Split the t output columns into p partitions whose sizes are
        # multiples of tb: start from equal partitions that cover t, then take
        # tb off partitions from the last one backwards until the total is t.
        split = np.ones(p)
        split = split * tb
        while np.sum(split) < t:
            split = split + tb

        idx = p - 1
        while np.sum(split) > t:
            split[idx] = split[idx] - tb
            idx = idx - 1

        assert np.sum(split) == t

        split = split.astype(int)
        self.tt = int(split[0])

        # cutoff is the index of the first reduced partition, or p + 1 when
        # the first and last partitions have the same size.
        if split[0] == split[-1]:
            self.cutoff = int(p + 1)
        else:
            self.cutoff = int(idx + 1)

        self.mb = mb  # // packed
        self.tb = tb

        self.group_size = group_size

        # group_size == -1 means one group over all input features. Dividing
        # by -1 would give a negative row count and make torch.zeros raise.
        n_groups = 1 if group_size == -1 else math.ceil(infeatures / group_size)

        self.register_buffer("bias", torch.zeros(self.outfeatures))
        self.register_buffer(
            "zeros",
            torch.zeros((n_groups, outfeatures), dtype=torch.float32),
        )
        self.register_buffer(
            "scales",
            torch.zeros((n_groups, outfeatures), dtype=torch.float32),
        )
        # Flat int32 storage: ``packed`` weights per 32-bit word.
        self.register_buffer(
            "qweight",
            torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous(),
        )

    def forward(self, x):
        """Apply the quantized linear map to ``x``.

        ``x`` is flattened to 2-D over its last axis and cast to float32; the
        result is float32 with the leading dimensions of ``x`` and a last
        dimension of ``outfeatures``.

        Raises:
            NotImplementedError: if ``self.bits`` has no matching kernel name.
        """
        if self.bits not in (2, 3, 4):
            # Previously an unknown width skipped every kernel call and
            # silently returned an all-zero tensor.
            raise NotImplementedError(f"No QIGen kernel for {self.bits}-bit weights.")

        out_shape = x.shape[:-1] + (self.outfeatures,)
        # contiguous() is a no-op for already-contiguous input. It guards the
        # C++ reduction below, which presumably reads x through a raw
        # row-major pointer (confirm against the extension).
        x = x.reshape((-1, x.shape[-1])).to(torch.float32).contiguous()
        B = x.shape[0]
        new_x = x.T.contiguous()
        out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
        sums = compute_reductions(x, gs=self.group_size, cpp=True).contiguous()

        # Arguments shared by every kernel variant, in the order they expect.
        shared_args = (
            new_x,
            self.qweight,
            out,
            self.bias,
            self.scales,
            self.zeros,
            sums,
            B,
            self.infeatures,
            self.outfeatures,
            B,
            self.mb,
            self.tb,
            self.tt,
        )
        if self.group_size == -1:
            kernel = getattr(qinfer, f"forward{self.bits}")
            kernel(*shared_args, self.cutoff)
        else:
            # Grouped kernels take the group size just before the cutoff.
            kernel = getattr(qinfer, f"forward_gs{self.bits}")
            kernel(*shared_args, self.group_size, self.cutoff)
        return out.reshape(out_shape)
