import torch
import torch.nn as nn

import gc
from ..data_utils import return_given_alpha
from .state import OptimizationState

from .awq_tools.linear.gemm import WQLinear_GEMM
from .awq_tools.linear.gemv import WQLinear_GEMV
from .awq_tools.linear.gemv_fast import WQLinear_GEMVFast
from .awq_tools.linear.marlin import WQLinear_Marlin


class AWQLayer(nn.Module):
    def __init__(
        self,
        layer: nn.Module,
        sparsity_ratio=0.5,
        group_size=128,
        w_bit=4,
        prune_n=0,
        prune_m=0,
        use_variant=True,
        layer_id=0,
        layer_name="none",
        duo_scaling=True,
        version="gemm",
    ):
        """
        Initializes the AWQLayer with specific properties and configurations.

        layer: The underlying PyTorch layer (e.g., nn.Linear) to be wrapped by AWQLayer.
        sparsity_ratio: Target sparsity ratio for pruning.
        prune_n: Number of weights to keep in each group for structured pruning.
        prune_m: Group size for structured pruning.
        use_variant: Whether to use a pruning variant not mentioned in the paper.
        layer_id: Identifier for the layer, useful for models with multiple layers.
        layer_name: Human-readable name for the layer.
        """

        nn.Module.__init__(self)
        self.existing_layer = layer

        self.in_features = self.existing_layer.in_features
        self.out_features = self.existing_layer.out_features
        self.columns = self.in_features
        # self.dev = getattr(self.existing_layer, "weight.device", "cuda")
        self.dev = self.existing_layer.weight.device

        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

        self.sparsity_ratio = sparsity_ratio
        self.prune_n = prune_n
        self.prune_m = prune_m
        self.use_variant = use_variant

        self.layer_id = layer_id
        self.layer_name = layer_name

        # AWQ BELOW

        self.group_size = group_size
        self.w_bit = w_bit

        # TODO: Fix dimensions????
        self.activations = torch.zeros(
            (0, self.in_features), device=self.dev, dtype=torch.bfloat16
        )
        self.n_samples = 0

        self.duo_scaling = duo_scaling
        self.version = version

    def add_batch(self, inp, out):
        """
        Adjusts the scaling row (input activation size per row) based on the input batch for use in calculating pruning metrics.

        inp: Input tensor to the layer.
        out: Output from the layer
        """
        # inp: [n_batches, n_tokens, input_size] -> [n_tokens * n_batches, input_size]

        if isinstance(self.existing_layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            # TODO: Keep or not
            # inp = inp.t()

        # self.n_samples += inp.shape[0]

        self.activations = torch.cat((self.activations, inp), dim=0)

    def train(self, mode: bool = True):
        self.existing_layer.train(mode)

    def forward(self, x: torch.Tensor):
        if hasattr(self.existing_layer, "weight"):
            return self.existing_layer(x.to(self.existing_layer.weight.device))
        else:
            return self.existing_layer(x.to(self.existing_layer.dev))

    def find_scales(self, n_grid=20):
        """
        Grid-search the per-input-channel AWQ scaling factors.

        For every ratio in {0, 1/n_grid, ..., (n_grid - 1)/n_grid} a candidate
        scale vector is built from the mean absolute activation per input
        channel (divided by a weight statistic when ``duo_scaling`` is set).
        The weights are scaled, pseudo-quantized and un-scaled, the layer is
        run on the stored activations, and the candidate with the smallest
        mean squared error against the unquantized output is kept. The layer's
        parameters are restored after every candidate.

        n_grid: Number of ratios to try.

        Returns a detached 1-D tensor with one scale per input channel.

        Raises:
            RuntimeError: if no candidate produced a loss below infinity
                (e.g. every loss was NaN).
            AssertionError: if the selected scales contain NaN.
        """
        # The search only mutates weight data and compares outputs; no
        # autograd graph is needed.
        with torch.no_grad():
            acts = self.activations
            # Mean absolute activation per input channel.
            x_mean = acts.abs().view(-1, acts.shape[-1]).mean(0)

            weight = self.weight
            org_shape = weight.shape
            # group_size <= 0 means one group per output row, matching
            # find_clips().
            group_size = self.group_size if self.group_size > 0 else org_shape[-1]
            # The weights are reshaped to be organised by quantization group.
            grouped = weight.view(-1, group_size)
            # Normalise each group to a 0-1 scale. The clamp keeps an all-zero
            # group from producing 0 / 0 = NaN.
            group_max = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-6)
            w_scale = (grouped.abs() / group_max).view(org_shape)
            # Average normalised magnitude per *input* channel (mean over the
            # output dimension).
            w_mean = w_scale.mean(0)
            # Drop the temporaries here: clear_memory() cannot release names
            # that are still bound in this frame.
            del weight, grouped, group_max, w_scale
            clear_memory()

            x_mean = x_mean.view(-1).to(self.dev)
            w_mean = w_mean.view(-1).to(self.dev)

            WX = self(acts)
            # copy=True forces a real snapshot. On CPU a plain .cpu() returns
            # a tensor that aliases the live parameters, so the in-place
            # scaling below would corrupt the "original" state.
            org_sd = {
                k: v.detach().to("cpu", copy=True)
                for k, v in self.state_dict().items()
            }

            history = []
            best_loss = float("inf")
            best_ratio = -1
            best_scales = None

            for step in range(n_grid):
                ratio = step / n_grid

                # NOTE: s^-1 * x is fused here, according to paper
                if self.duo_scaling:
                    scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(
                        min=1e-4
                    )
                else:
                    scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
                # Geometric mean approx. normalization
                scales = scales / (scales.max() * scales.min()).sqrt()
                scales_view = scales.view(1, -1).to(self.dev)

                try:
                    # W <- Q(W * s) / s, written through the storage shared
                    # with the parameter. Rebinding `.data` on the tensor
                    # returned by the `weight` property would only change a
                    # temporary and leave the layer untouched.
                    live_weight = self.weight
                    live_weight.mul_(scales_view)
                    live_weight.copy_(
                        self.pseudo_quantize(live_weight)[0] / scales_view
                    )
                    q_WX = self(acts)
                finally:
                    # Restore the original parameters even if this candidate
                    # failed.
                    self.load_state_dict(org_sd)

                # float() avoids overflow when accumulating in half precision.
                loss = (q_WX - WX).float().pow(2).mean().item()
                history.append(loss)
                if loss < best_loss:
                    best_loss = loss
                    best_ratio = ratio
                    best_scales = scales

            if best_ratio == -1:
                raise RuntimeError(
                    f"No scale candidate produced a finite loss; loss history: {history}"
                )

            best_scales = best_scales.view(-1)

            assert torch.isnan(best_scales).sum() == 0, best_scales

            return best_scales.detach()

    def find_clips(self, n_grid=20, max_shrink=0.5, n_sample_token=512):
        """
        Search the best clipping threshold for every (output channel, group).

        For each quantization group the maximum absolute weight is shrunk in
        steps of 1/n_grid. The group is clamped to the shrunk range and
        pseudo-quantized, and the threshold with the smallest mean squared
        output error on a subsample of the stored activations is kept.

        n_grid: Resolution of the shrink grid.
        max_shrink: Largest fraction by which the threshold may be reduced;
            int(max_shrink * n_grid) candidates are tried.
        n_sample_token: Approximate number of activation rows to use.

        Returns a tensor of shape [out_features, n_group, 1] holding the
        selected clipping magnitude per group.
        """
        w = self.weight
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]

        # Only forward values are compared, so the autograd graph is dropped.
        input_feat = self.activations.detach()
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # With fewer than n_sample_token rows the integer division yields 0,
        # and a zero slice step raises ValueError; fall back to every row.
        step = max(1, input_feat.shape[1] // n_sample_token)
        # Moved to the weight's device once, rather than on every batch.
        input_feat = input_feat[:, ::step].to(w.device)
        w_all = w.reshape(org_w_shape[0], 1, -1, group_size)

        # Output channels are processed in batches to prevent OOM. A trailing
        # partial batch is allowed, so out_features need not be a multiple of
        # the batch size.
        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64
        best_max_val_all = []

        for start in range(0, org_w_shape[0], oc_batch_size):
            w_batch = w_all[start : start + oc_batch_size]

            org_max_val = w_batch.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            org_out = (input_feat * w_batch).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                cur_w = torch.clamp(w_batch, -max_val, max_val)
                # pseudo_quantize() requires 2-D input when group_size <= 0,
                # so hand it one row per group and restore the shape after.
                q_w = self.pseudo_quantize(cur_w.reshape(-1, group_size))[0]
                q_w = q_w.reshape(cur_w.shape)
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w, q_w, cur_out
                improved = err < min_errs
                min_errs[improved] = err[improved]
                best_max_val[improved] = max_val[improved]

            best_max_val_all.append(best_max_val)
            del org_out

        best_max_val = torch.cat(best_max_val_all, dim=0)

        # Drop the large temporaries in this frame before collecting;
        # passing them to clear_memory() would not release them.
        del input_feat, w_all
        clear_memory()

        return best_max_val.squeeze(1)

    def pseudo_quantize(self, w):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0

        # zero point quantization
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**self.w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
        w = (
            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
        ) * scales
        zeros = zeros.view(org_w_shape[0], -1)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

    def quantize(self):
        """
        Quantize the wrapped layer in place and replace it with a packed module.

        Steps:
            1. Search the per-input-channel scales and multiply them into the
               weights.
            2. Search the per-group clipping thresholds and clamp the weights.
            3. Pseudo-quantize the weights to obtain group scales and zero
               points.
            4. Build the packed linear module selected by ``self.version`` via
               its ``from_linear`` constructor and store it as
               ``self.existing_layer``.

        Raises:
            ValueError: if ``self.version`` is not one of "gemm", "gemv",
                "marlin" or "gemv_fast". This is checked before any weight is
                modified.
            AssertionError: if a parameter contains NaN after scaling.

        NOTE(review): the channel scales are multiplied into the weights but
        the inverse scaling is not applied to this layer's inputs (or to the
        stored activations used for the clip search) anywhere in this method.
        The caller or the preceding op presumably has to compensate; confirm.
        """
        kernels = {
            "gemm": WQLinear_GEMM,
            "gemv": WQLinear_GEMV,
            "marlin": WQLinear_Marlin,
            "gemv_fast": WQLinear_GEMVFast,
        }
        # Fail fast so that an invalid configuration never leaves the layer
        # with half-modified weights.
        if self.version not in kernels:
            raise ValueError(f"Unknown version {self.version}")
        q_linear_module = kernels[self.version]

        # 1. Channel scales.
        scales = self.find_scales()
        # In-place ops on the tensor returned by `self.weight` write through
        # to the parameter, because the two share storage.
        self.weight.mul_(scales.view(1, -1).to(self.weight.device))

        for p in self.parameters():
            assert torch.isnan(p).sum() == 0

        # 2. Clipping thresholds, shape [out_features, n_group, 1].
        clips = self.find_clips()

        org_shape = self.weight.shape
        clipped = torch.clamp(
            self.weight.reshape(*clips.shape[:2], -1), -clips, clips
        )
        # copy_ writes the clamped values into the parameter. Assigning to
        # `self.weight.data` would rebind a temporary tensor and silently
        # discard the clipping.
        self.weight.copy_(clipped.reshape(org_shape))

        # 3. Group scales / zero points from the clipped weights.
        quantized, scales, zeros = self.pseudo_quantize(self.weight)
        self.weight.copy_(quantized)

        if self.version == "gemm":
            # The GEMM variant is handed scales/zeros transposed.
            scales = scales.t().contiguous()
            zeros = zeros.t().contiguous()

        # 4. Pack. The wrapper itself is passed as `linear`; it exposes
        # in_features / out_features / weight / bias like the wrapped layer.
        # TODO: confirm from_linear does not require a real nn.Linear.
        q_linear = q_linear_module.from_linear(
            linear=self,
            w_bit=self.w_bit,
            group_size=self.group_size,
            init_only=False,
            scales=scales,
            zeros=zeros,
        )

        q_linear.to(next(self.parameters()).device)
        self.existing_layer = q_linear
        clear_memory()
        # TODO: figure out what q_linear is stored as and then modify forward pass

    def dequantize(self):
        pass

    def get_layer(self):
        return self.existing_layer

    # TODO: change these
    @property
    def weight(self):
        return self.existing_layer.weight.data

    @property
    def bias(self):
        if self.existing_layer.bias is not None:
            return self.existing_layer.bias.data
        else:
            return None

    # def apply_scale(self, scales, input_feat_dict=None):
    #     self.cuda()
    #     scales.cuda()

    #     if isinstance(self, nn.Linear):
    #         assert len(layers) == 1
    #         scale_fc_fc(self, layers[0], scales)
    #     elif isinstance(self, (nn.LayerNorm, LlamaRMSNorm)):
    #         scale_ln_fcs(self, layers, scales)
    #     elif isinstance(self, (nn.GELU, BloomGelu, GELUActivation)):
    #         new_module = ScaledActivation(prev_op, scales)
    #         set_op_by_name(module, prev_op_name, new_module)
    #         scale_gelu_fc(prev_op, layers[0], scales)
    #     else:
    #         raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

    #     # apply the scaling to input feat if given; prepare it for clipping
    #     if input_feat_dict is not None:
    #         for layer_name in layer_names:
    #             inp = input_feat_dict[layer_name]
    #             inp.div_(scales.view(1, -1).to(inp.device))

    #     prev_op.cpu()
    #     for layer in layers:
    #         layer.cpu()
    #     scales.cpu()


def clear_memory(weight=None):
    """
    Run the garbage collector and release PyTorch's cached CUDA memory.

    weight: Optional object whose reference held by this function is dropped
        before collecting. This frees the object only if the caller holds no
        other reference to it; callers should ``del`` their own names first.
    """
    # Rebinding the parameter releases this frame's reference to the object.
    weight = None
    gc.collect()
    torch.cuda.empty_cache()

    # def find_clip(self, s_X, WX, org_sd, scales):
    #     # TODO: check all of this
    #     scales_view = scales.view(1, -1).to(self.dev)
    #     self.weight.mul_(scales_view)

    #     org_max_val = self.weight.abs().amax(dim=-1, keepdim=True) # TODO Check dim

    #     # TODO: check these
    #     best_max_val = org_max_val.clone()
    #     min_errs = torch.ones_like(org_max_val) * 1e9
    #     input_feat = input_feat.to(self.weight.device)
    #     org_out = (input_feat * self.weight).sum(dim=-1)  # co, n_token, n_group

    #     history = []
    #     best_loss = float('inf')
    #     best_ratio = -1

    #     n_grid = 20
    #     for ratio in torch.linspace(0, 1, n_grid + 1):
    #         mag = org_max_val * ratio

    #         clipped_W = torch.clamp(self.weight, -mag, mag)
    #         q_WX = self.pseudo_quantize(clipped_W)

    #         # REDO THiS
    #         cur_out = (input_feat * q_WX).sum(dim=-1)

    # def find_scales(self, n_grid=20, duo_scaling=False):
    #     history = []
    #     best_ratio = -1
    #     best_scales = None
    #     best_error = float("inf")

    #     org_sd = {k: v.cpu() for k, v in self.state_dict().items()}

    #     device = x.device
    #     x_mean = x_mean.view(-1).to(device)
    #     w_mean = w_mean.view(-1).to(device)

    #     for ratio in range(n_grid):
    #         # create new scales
    #         ratio = ratio / n_grid

    #         # NOTE: s^-1 * x is fused here, according to paper
    #         if self.duo_scaling:
    #             scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
    #         else:
    #             scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
    #         scales = scales / (scales.max() * scales.min()).sqrt()
    #         scales_view = scales.view(1, -1).to(device)

    #         # Q(W * s)
    #         for fc in linears2scale:
    #             fc.weight.mul_(scales_view)
    #             fc.weight.data = (
    #                 self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
    #             )

    #         # W * X
    #         int_w_output = module2inspect(x, **kwargs)
    #         if isinstance(int_w_output, tuple):
    #             int_w_output = int_w_output[0]

    #         # compute mean squared error (L2 norm)
    #         loss = (
    #             (fp16_output - int_w_output).float().pow(2).mean().item()
    #         )  # NOTE: float prevents overflow

    #         history.append(loss)
    #         if loss < best_error:
    #             best_error = loss
    #             best_ratio = ratio
    #             best_scales = scales.clone()
    #         module2inspect.load_state_dict(org_sd)

    #     if best_ratio == -1:
    #         logging.debug(history)
    #         raise Exception

    #     assert torch.isnan(best_scales).sum() == 0, best_scales

    #     return best_scales.detach().cpu()
