import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# import bitsandbytes as bnb
# import bitsandbytes.functional as bnbF
# from torch.nn.init import _calculate_fan_in_and_fan_out
# from transformers import AutoModelForCausalLM, AutoConfig

# from loguru import logger
from torch.nn import Parameter
from torch import Tensor
import sparse_linear_me

# from torch_sparse import coalesce, spmm



class lora_sparse_linear(torch.autograd.Function):
    """Autograd bridge to the compiled ``sparse_linear_me`` extension.

    Both the forward and backward computations are delegated to the
    extension; this class only stashes the forward inputs and tells the
    extension which gradients are actually required.

    Callers in this file pass ``(x, lora_B, scaled lora_A, sparse values,
    sparse flat indices, bias)``. The exact math performed by the extension
    is not visible here.
    """

    @staticmethod
    def forward(ctx, input, lora_B, lora_A, dv, di, bias):
        """Save all six inputs for backward and return the extension's output.

        ``bias`` may be None (``save_for_backward`` accepts None entries).
        """
        ctx.save_for_backward(input, lora_B, lora_A, dv, di, bias)
        result = sparse_linear_me.forward(input, lora_B, lora_A, dv, di, bias)
        return result

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients computed by ``sparse_linear_me.backward``.

        The requires-grad flag for ``di`` (input 4) is not forwarded to the
        extension. A bias gradient is requested only when a bias exists.
        Autograd expects one entry (tensor or None) per forward input, i.e.
        six; presumably the extension returns exactly that — TODO confirm.
        """
        inp, b_mat, a_mat, vals, idx, b_vec = ctx.saved_tensors
        need_inp, need_b, need_a, need_vals, _need_idx, need_bias = ctx.needs_input_grad

        # Never ask for a bias gradient when no bias was supplied.
        want_bias_grad = b_vec is not None and need_bias

        grads = sparse_linear_me.backward(
            output_grad, inp, b_mat, a_mat, vals, idx,
            need_inp,
            need_b,
            need_a,
            need_vals,
            want_bias_grad,
            b_vec,
        )
        return tuple(grads)


class SpLoRaLinear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int,
            sp_ratio: float = 0.01,
            sp_type: str = 'random',
            *,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            trainable_scaling: bool = False,
            random_subspace: bool = False,
            bias=True,
            device=None,
            dtype=None,
    ):
        """
        Reparameterized sparse and low rank linear layer
                    x W_a @ W_b * lora_alpha / r + x W_sp + bias
        Notice that scale = lora_alpha / r (or tanh of a learned scalar when
        ``trainable_scaling`` is set).
        Notice that this class cannot be wrapped to linear layer and thus
        cannot be used for fine-tune. For fine-tune, see ``SpLoRaLinearFT``.

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output.
            r: rank of the low-rank factors; must be positive.
            sp_ratio: fraction of the ``out_features * in_features`` weight
                entries kept as trainable sparse values; must lie in (0, 1).
            sp_type: sparsity pattern. Only ``'random'`` (case-insensitive)
                is supported.
            lora_alpha: numerator of the fixed scale ``lora_alpha / r``.
            lora_dropout: accepted for interface compatibility; unused here.
            trainable_scaling: if True, the scale is ``tanh(s)`` for a learned
                scalar ``s`` initialised to 1.
            random_subspace: if True, ``lora_A`` is frozen (no gradient).
            bias: whether to add a learnable bias.
            device, dtype: forwarded to parameter construction.

        Raises:
            ValueError: if ``r <= 0``, ``sp_ratio`` is outside (0, 1), or
                ``sp_type`` is not ``'random'``.
        """
        super().__init__()
        if r <= 0:
            raise ValueError("r must be positive.")
        if sp_ratio <= 0 or sp_ratio >= 1:
            raise ValueError("sp_ratio must be between 0 and 1.")
        if sp_type.lower() != 'random':
            # Any other value used to build a layer without sparse parameters
            # whose forward silently returned the scalar 0.
            raise ValueError(
                f"Unsupported sp_type {sp_type!r}; only 'random' is supported.")

        if bias:
            self.bias = Parameter(torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True))
            # NOTE(review): nn.Linear bounds its bias by 1/sqrt(in_features);
            # this uses out_features — confirm this is intended.
            a = 1/math.sqrt(out_features)
            nn.init.uniform_(self.bias, -a, a)
        else:
            self.register_parameter('bias', None)

        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.lora_alpha = lora_alpha
        self.random_subspace = random_subspace
        self.trainable_scaling = trainable_scaling
        self.sp_ratio = sp_ratio
        self.sp_type = sp_type
        self.device = device
        self.dtype = dtype

        # Low-rank factors: lora_B starts at zero, so the low-rank term is
        # initially zero.
        lora_A_requires_grad = not random_subspace
        self.lora_A = nn.Parameter(torch.empty(r, in_features, dtype=dtype, device=device), requires_grad=lora_A_requires_grad)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_B = nn.Parameter(torch.empty(out_features, r, dtype=dtype, device=device))
        nn.init.zeros_(self.lora_B)
        if trainable_scaling:
            self.scaling = nn.Parameter(torch.tensor([1.], device=device, dtype=dtype), requires_grad=True)
        else:
            self.scaling = self.lora_alpha / self.r

        # Sparse component: fixed (buffer) flat indices, trainable values.
        indices, values, shape = self._init_sparse_parameters()
        self.shape = shape
        self.register_buffer("sparse_index", indices.to(device))
        self.sparse_value = Parameter(values.to(device), requires_grad=True)

    def _post_lora_scale(self):
        """Return the low-rank scale: tanh of the learned scalar, or lora_alpha / r."""
        if self.trainable_scaling:
            return self.scaling.tanh()

        return self.scaling

    def _init_sparse_parameters(self):
        """Draw a random sparse support and initial values.

        Returns:
            indices: sorted, distinct int64 flat indices in
                ``[0, out_features * in_features)``, on CPU (moved to the
                target device by the caller).
            values: ``len(indices)`` values drawn from
                U(-1/sqrt(in_features), 1/sqrt(in_features)).
            shape: ``[out_features, in_features]``.
        """
        shape = [self.out_features, self.in_features]
        total_elements = self.in_features * self.out_features
        num_nonzeros = int(self.sp_ratio * total_elements)

        # Choose distinct positions, then sort them.
        indices = torch.randperm(total_elements)[:num_nonzeros]
        indices, _ = torch.sort(indices)

        values = torch.empty(size=(num_nonzeros,), device=self.device, dtype=self.dtype)
        a = 1/math.sqrt(self.in_features)
        nn.init.uniform_(values, -a, a)

        return indices, values, shape

    def forward(self, x: Tensor) :
        """
            Input x : [..., in_dim] and Output [..., out_dim]

        The scale is folded into ``lora_A`` before the fused call, so
        gradients reach ``scaling`` when it is trainable.
        """
        return lora_sparse_linear.apply(
            x, self.lora_B, self.lora_A * self._post_lora_scale(),
            self.sparse_value, self.sparse_index,
            self.bias,
        )

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features}, rank={self.r}, '
                f'sparsity={self.sp_ratio}, bias={self.bias is not None}')






class SpLoRaLinearFT(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int,
            sp_ratio: float = 0.01,
            sp_type: str = 'random',
            *,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            trainable_scaling: bool = False,
            random_subspace:bool = False,
            weight_data=None,
            bias_data=None,
            bias=True,
            device=None,
            dtype=None,
    ):
        """
        Sparse + low-rank adapter on top of a frozen dense weight (fine-tune variant).

        The output is the frozen dense projection ``F.linear(x, weight)`` plus
        the fused sparse/low-rank term
                    x W_a @ W_b * scale + x W_sp + bias
        computed by ``lora_sparse_linear`` on ``dropout(x)``. Here
        scale = lora_alpha / r, or tanh of a learned scalar when
        ``trainable_scaling`` is set.

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output.
            r: rank of the low-rank factors; must be positive.
            sp_ratio: fraction of the ``out_features * in_features`` entries
                kept as trainable sparse values; must lie in (0, 1).
            sp_type: stored on the module only. This class always draws a
                random sparse support, whatever this value is.
            lora_alpha: numerator of the fixed scale ``lora_alpha / r``.
            lora_dropout: dropout probability applied to the input of the
                fused sparse/low-rank call only (not the frozen dense path).
            trainable_scaling: if True, the scale is ``tanh(s)`` for a learned
                scalar ``s`` initialised to 1.
            random_subspace: if True, ``lora_A`` is frozen (no gradient).
            weight_data: initial dense weight, used as ``F.linear``'s weight
                (so ``(out_features, in_features)``). Always frozen. Zeros if None.
            bias_data: initial bias values. Zeros if None; ignored when
                ``bias`` is False.
            bias: whether the layer has a (trainable) bias.
            device, dtype: forwarded to parameter construction.

        Raises:
            ValueError: if ``r <= 0`` or ``sp_ratio`` is outside (0, 1).
        """
        super().__init__()
        if r <= 0:
            raise ValueError("r must be positive.")
        if sp_ratio <= 0 or sp_ratio >= 1:
            raise ValueError("sp_ratio must be between 0 and 1.")

        if bias:
            if bias_data is None:
                bias_data = torch.zeros(out_features, device=device, dtype=dtype)
            self.bias = nn.Parameter(bias_data)
        else:
            # Register explicitly (as SpLoRaLinear does) rather than setting a
            # plain attribute.
            self.register_parameter('bias', None)

        if weight_data is None:
            # The trainable parts are the low-rank factors and sparse values;
            # the dense weight stays frozen.
            weight_data = torch.zeros(out_features, in_features, device=device, dtype=dtype)
        self.weight = nn.Parameter(weight_data, requires_grad=False)

        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout)
        self.trainable_scaling = trainable_scaling
        self.random_subspace = random_subspace
        self.sp_ratio = sp_ratio
        self.sp_type = sp_type
        self.device = device
        self.dtype = dtype

        # Low-rank factors: lora_B starts at zero, so the low-rank term is
        # initially zero.
        lora_A_requires_grad = not random_subspace
        self.lora_A = nn.Parameter(torch.empty(r, in_features, dtype=dtype, device=device), requires_grad=lora_A_requires_grad)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_B = nn.Parameter(torch.empty(out_features, r, dtype=dtype, device=device))
        nn.init.zeros_(self.lora_B)
        if trainable_scaling:
            self.scaling = nn.Parameter(torch.tensor([1.], device=device, dtype=dtype), requires_grad=True)
        else:
            self.scaling = self.lora_alpha / self.r

        # Sparse component: fixed (buffer) flat indices, trainable values.
        indices, values, shape = self._init_sparse_parameters()
        self.shape = shape
        self.register_buffer("sparse_index", indices.to(device))
        self.sparse_value = Parameter(values.to(device), requires_grad=True)

    def _post_lora_scale(self):
        """Return the low-rank scale: tanh of the learned scalar, or lora_alpha / r."""
        if self.trainable_scaling:
            return self.scaling.tanh()

        return self.scaling

    def _init_sparse_parameters(self):
        """Draw a random sparse support and initial values.

        Returns:
            indices: sorted, distinct int64 flat indices in
                ``[0, out_features * in_features)``, on CPU (moved to the
                target device by the caller).
            values: ``len(indices)`` values drawn from
                U(-1/sqrt(in_features), 1/sqrt(in_features)).
            shape: ``[out_features, in_features]``.
        """
        shape = [self.out_features, self.in_features]
        total_elements = self.in_features * self.out_features
        num_nonzeros = int(self.sp_ratio * total_elements)

        # Choose distinct positions, then sort them.
        indices = torch.randperm(total_elements)[:num_nonzeros]
        indices, _ = torch.sort(indices)

        values = torch.empty(size=(num_nonzeros,), device=self.device, dtype=self.dtype)
        a = 1/math.sqrt(self.in_features)
        nn.init.uniform_(values, -a, a)

        return indices, values, shape

    def forward(self, x: Tensor) :
        """Return frozen dense output plus the fused sparse/low-rank/bias term.

        Dropout is applied only to the input of the fused call; the frozen
        dense path sees the original ``x``.
        """
        adapter_out = lora_sparse_linear.apply(
            self.lora_dropout(x), self.lora_B, self.lora_A * self._post_lora_scale(),
            self.sparse_value, self.sparse_index,
            self.bias,
        )
        # Out-of-place add: avoid mutating the custom Function's output in place.
        return adapter_out + F.linear(x, self.weight, bias=None)

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features}, rank={self.r}, '
                f'sparsity={self.sp_ratio}, bias={self.bias is not None}')