import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import is_bitsandbytes_available

from moe_peft.executors import executor
from moe_peft.common import TASK_IDS

from .abstracts import LLMMoeBlock
from .config import LLMModelInput, LoraConfig

if is_bitsandbytes_available():
    import bitsandbytes as bnb
    from bitsandbytes.nn import Linear4bit, Linear8bitLt
else:
    from moe_peft.utils import Linear8bitLt, Linear4bit

from typing import Any, Dict, List, Tuple


def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    # BNB requires CUDA weights
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    if is_cpu:
        weight = weight.to(torch.device("cuda"))

    cls_name = weight.__class__.__name__
    if cls_name == "Params4bit":
        dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        if is_cpu:
            dequantized = dequantized.to(device)
        return dequantized

    if state.SCB is None:
        state.SCB = weight.SCB

    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(
            weight.data, to_order=state.formatB
        )
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    dequantized = bnb.functional.mm_dequant(
        out32, Sout32, SCim, state.SCB, bias=None
    ).t()
    if is_cpu:
        dequantized = dequantized.to(device)
    return dequantized


def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
    if hasattr(module, "W_q"):  # For handling HQQ quantized weight
        weight = module.dequantize()
        return weight

    weight = module.weight
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(
            f"Input weight should be of type nn.Parameter, got {type(weight)} instead"
        )

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        return weight

    quant_state = getattr(module, "state", None)
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
    if is_cpu:
        # dequantize_bnb_weight for 8bit moves the device in-place, thus we need to move it back to CPU if necessary
        module.weight = module.weight.to(device)
    return weight


g_cached_range_tensor: Dict[torch.device, torch.Tensor] = {}
# also max batch size
g_max_range = 128


def get_range_tensor(device: torch.device, batch_size: int = 1024):
    global g_cached_range_tensor
    global g_max_range
    if device not in g_cached_range_tensor or batch_size > g_max_range:
        g_max_range = g_max_range if g_max_range > batch_size else batch_size
        g_cached_range_tensor[device] = torch.arange(
            0, g_max_range, step=1, device=device
        )
    return g_cached_range_tensor[device]


class LoraFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        result: torch.Tensor,
        data: torch.Tensor,
        input_args: LLMModelInput,
        dropouts: List[float],
        scalings: List[float],
        *args,
    ):
        # the lora module is f32 precision
        data = data.to(torch.float32)

        save_inputs: Tuple[torch.Tensor | None, ...] = (data,)

        lora_range = get_range_tensor(data.device, data.shape[0])
        for lora_a, lora_b, lora_config, dropout, scaling in zip(
            args[::2],
            args[1::2],
            input_args.batch_configs_,
            dropouts,
            scalings,
        ):
            assert not ((lora_a is None) ^ (lora_b is None))
            if lora_a is None and lora_b is None:
                save_inputs += (None, None, None)
                continue

            assert not ((lora_a.requires_grad) ^ (lora_b.requires_grad))
            if not lora_a.requires_grad and not lora_b.requires_grad:
                save_inputs += (None, None, None)
                continue

            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            # must ensure the dropout is not zero
            # is dropout == 0, dropdata is a data's referece, so the data will be changed
            assert dropout > 0.0

            drop_data = F.dropout(data[start_idx:end_idx], p=dropout)
            drop_data.mul_(scaling)
            drop_data = drop_data @ lora_a.transpose(0, 1)
            lora_data = drop_data @ lora_b.transpose(0, 1)

            lora_data = lora_data.to(result.dtype)

            result.index_add_(0, lora_range[start_idx:end_idx], lora_data)

            save_inputs += (lora_a, lora_b, drop_data)

        ctx.input_args = input_args
        ctx.dropouts = dropouts
        ctx.scalings = scalings
        ctx.save_for_backward(*save_inputs)

        return result

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        grad_output: torch.Tensor = grad_outputs[0]
        grad_result = None
        grad_data: torch.Tensor | None = None
        grad_input_args = None
        grad_dropouts = None
        grad_scalings = None
        grad_loras: Tuple[torch.Tensor | None, ...] = ()

        data, *loras = ctx.saved_tensors

        if ctx.needs_input_grad[0]:
            grad_result = grad_output
        if ctx.needs_input_grad[1]:
            grad_data = executor.init_tensor(data)

        # the lora module is fp32 precision
        grad_output = grad_output.to(torch.float32)
        lora_range = get_range_tensor(
            grad_output.device, batch_size=grad_output.shape[0]
        )
        for lora_a, lora_b, drop_data, dropout, scaling, lora_config in zip(
            loras[::3],
            loras[1::3],
            loras[2::3],
            ctx.dropouts,
            ctx.scalings,
            ctx.input_args.batch_configs_,
        ):
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_
            assert not ((lora_a is None) ^ (lora_b is None))
            if lora_a is None and lora_b is None:
                grad_loras += (None, None)
                if grad_data is not None:
                    executor.index_fill(grad_data, 0, lora_range[start_idx:end_idx], 0)
                continue

            # lora_data shape is batch_size * seq_len * in_dim
            lora_data = data[start_idx:end_idx]
            # grad_y shape is batch_size * seq_len * out_dim
            grad_y = grad_output[start_idx:end_idx]

            # drop_data shape is batch_size * seq_len * r

            # bstage shape is batch_size * seq_len * r
            bstage = grad_y @ lora_b
            bstage *= scaling / (1 - dropout)

            grad_a = torch.sum(bstage.transpose(1, 2) @ lora_data, dim=0)
            grad_b = torch.sum(grad_y.transpose(1, 2) @ drop_data, dim=0)
            grad_loras += (grad_a, grad_b)

            # grad_data shape is batch_size * seq_len * in_dim
            if grad_data is not None:
                grad_x = bstage @ lora_a
                executor.index_copy(grad_data, 0, lora_range[start_idx:end_idx], grad_x)

        return (
            grad_result,
            grad_data,
            grad_input_args,
            grad_dropouts,
            grad_scalings,
            *grad_loras,
        )


class Lora(nn.Module):
    def __init__(
        self,
        base_layer: nn.Module,
        shape: Tuple[int, int],
        config: LoraConfig,
        device: str,
    ):

        super().__init__()

        self.base_layer_ = base_layer
        self.device_ = torch.device(device)

        self.initializer_ = config.lora_init_
        self.r_ = config.lora_r_
        self.alpha_ = config.lora_alpha_

        if config.use_rslora_:
            self.scaling_ = self.alpha_ / math.sqrt(self.r_)
        else:
            self.scaling_ = self.alpha_ / self.r_

        self.in_features_, self.out_features_ = shape

        assert config.lora_dropout_ > 0.0
        self.dropout_ = nn.Dropout(p=config.lora_dropout_)

        self.lora_a_ = nn.Linear(
            self.in_features_,
            self.r_,
            bias=False,
            dtype=torch.float32,
            device=self.device_,
        )
        self.lora_b_ = nn.Linear(
            self.r_,
            self.out_features_,
            bias=False,
            dtype=torch.float32,
            device=self.device_,
        )

        self.use_dora_: bool = config.use_dora_
        self.magnitude_vector_: nn.Parameter = None

    def _get_weight_norm(self, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = dequantize_module_weight(self.base_layer_).to(dtype)
        lora_weight = self.lora_b_.weight @ self.lora_a_.weight
        weight = weight + self.scaling_ * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm

    def reset_parameters(self, lora_tensor=(None, None)) -> None:
        # if the lora_tensor is not (None, None), use it to init the lora weight
        assert isinstance(lora_tensor, Tuple)
        assert len(lora_tensor) == 2
        assert ((lora_tensor[0] is None) and (lora_tensor[1] is None)) or (
            isinstance(lora_tensor[0], torch.Tensor)
            and isinstance(lora_tensor[1], torch.Tensor)
        )

        if lora_tensor == (None, None):
            if self.initializer_ == "original":
                nn.init.kaiming_uniform_(self.lora_a_.weight, a=math.sqrt(5))
            elif self.initializer_ == "gaussian":
                nn.init.normal_(self.lora_a_.weight, std=1 / self.r_)
            else:
                raise ValueError(f"Unknown initialization {self.initializer_}")
            nn.init.zeros_(self.lora_b_.weight)
        else:
            with torch.no_grad():
                self.lora_a_.weight.copy_(lora_tensor[0])
                self.lora_b_.weight.copy_(lora_tensor[1])

        if self.use_dora_:
            self.magnitude_vector_ = nn.Parameter(
                self._get_weight_norm(), requires_grad=True
            )

    def apply_dora(
        self,
        residual: torch.Tensor,
        result_lora: torch.Tensor,
    ):
        weight_norm = self._get_weight_norm().detach()
        mag_norm_scale = (self.magnitude_vector_ / weight_norm).view(1, -1)
        return mag_norm_scale * residual + mag_norm_scale * result_lora

    def lora_forward(self, hidden_states: torch.Tensor):
        # mixlora ffn
        return (
            self.lora_b_(self.lora_a_(self.dropout_(hidden_states.to(torch.float32))))
            * self.scaling_
        )

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # mixlora ffn
        result_lora = self.lora_forward(hidden_states)
        if self.use_dora_:
            return self.apply_dora(residual, result_lora).to(residual.dtype)
        else:
            return residual + result_lora.to(residual.dtype)

class MoORE(nn.Module):
    def __init__(
        self,
        base_layer: nn.Module,
        shape: Tuple[int, int],
        config: LoraConfig,
        device: str,
    ):
        super().__init__()

        self.base_layer_ = base_layer
        self.r_ = config.lora_r_
        self.in_features_, self.out_features_ = shape
        self.num_experts_ = config.num_experts_
        self.routing_strategy_ = config.routing_strategy_
        self.task_embedding_dim_ = config.task_embedding_dim_
        self.output_router_logits_ = config.output_router_logits_
        self.dtype = self.base_layer_.weight.dtype
        # self.dtype = torch.float32
        
        self.dim_S = min(self.out_features_, self.in_features_)
        U, S, Vh = torch.linalg.svd(self.base_layer_.weight.data.float(), full_matrices=False)
        
        task_num = len(TASK_IDS)
        self.task_embedding = nn.Embedding(task_num, self.task_embedding_dim_, device=device, dtype=self.dtype)
        self.task_linear = nn.Linear(self.task_embedding_dim_, self.dim_S, bias=False, device=device, dtype=self.dtype)
        self.task_linear.weight.data = self.task_linear.weight.data.new_zeros((self.dim_S, self.task_embedding_dim_))
        self.gate = nn.Linear(self.in_features_, self.num_experts_, bias=False, device=device, dtype=self.dtype)
        nn.init.kaiming_uniform_(self.gate.weight, a=math.sqrt(5))
        self.gate_up = nn.Linear(self.num_experts_, self.dim_S, bias=False, device=device, dtype=self.dtype)
        self.gate_up.weight.data = self.gate_up.weight.data.new_zeros((self.dim_S, self.num_experts_))
            
        self.svd_U = nn.Parameter(torch.zeros(self.dim_S, self.out_features_, device=device, dtype=self.dtype))
        self.svd_S = nn.Parameter(torch.zeros(self.dim_S, device=device, dtype=self.dtype))
        self.svd_Vh = nn.Parameter(torch.zeros(self.in_features_, self.dim_S, device=device, dtype=self.dtype))
        
        self.svd_U.data = self.svd_U.data + U.to(self.dtype).permute(1, 0)[:self.dim_S, :]
        self.svd_S.data = self.svd_S.data + S.to(self.dtype)[:self.dim_S]
        self.svd_Vh.data = self.svd_Vh.data + Vh.to(self.dtype).permute(1, 0)[:, :self.dim_S]

        half_u_in = torch.zeros(self.in_features_, self.r_//2, device=device, dtype=torch.float32)
        nn.init.kaiming_uniform_(half_u_in, a=math.sqrt(5))
        hra_u_in = torch.repeat_interleave(half_u_in, 2, dim=1)
        self.hra_u_in = nn.ParameterList([nn.Parameter(hra_u_in[:, i].view(-1, 1), requires_grad=True) for i in range(self.total_r_)])

        self.svd_U.requires_grad = False
        self.svd_S.requires_grad = False
        self.svd_Vh.requires_grad = False

        self.router_profile_ = False
        self.profiler = {}
    
    def reset_parameters(self, moore_tensor=(None, None, None)) -> None:
        # if the lora_tensor is not (None, None), use it to init the lora weight
        assert isinstance(moore_tensor, Tuple)
        assert len(moore_tensor) == 3
        assert ( 
            (moore_tensor[0] is None) and
            (moore_tensor[1] is None) and 
            (moore_tensor[2] is None)) or (
            isinstance(moore_tensor[0], List) and 
            isinstance(moore_tensor[1], List) and 
            isinstance(moore_tensor[2], List)
        )

        if moore_tensor != (None, None, None):
            with torch.no_grad():
                self.gate.weight.copy_(moore_tensor[0][0])
                self.task_embedding.weight.copy_(moore_tensor[0][2])
                self.task_linear.weight.copy_(moore_tensor[0][3])
                self.gate_up.weight.copy_(moore_tensor[0][4])
                for i in range(self.r_):
                    self.hra_u_in[i].copy_(moore_tensor[1][i])

    def _profiling(self, routing_weights):
        if not self.router_profile_:
            return
        if 'routing_weights' in self.profiler:
            new_routing_weights = routing_weights.abs().mean(1).cpu().to(torch.float).numpy()
            self.profiler['routing_weights'] = np.concatenate((self.profiler['routing_weights'], new_routing_weights),axis=0)
        else:
            self.profiler['routing_weights'] = routing_weights.abs().mean(1).cpu().to(torch.float).numpy()

    def forward(
        self,
        hidden_states: torch.Tensor,
        task_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(self.dtype)
        result = hidden_states

        result = result.to(torch.float32)
        for i in range(self.r_):
            ui = self.hra_u_in[i] / self.hra_u_in[i].norm()
            result = result - 2 * (result @ ui) @ ui.t()
        result = result.to(self.dtype)
        result = result @ self.svd_Vh

        task_embedding = self.task_embedding(task_ids) # B x task_embedding_dim
        task_embedding = self.task_linear(task_embedding) # B x dim_S
        router_logits = self.gate(hidden_states) # B x L x num_experts
        routing_weights = self.gate_up(router_logits) # B x L x dim_S
        # self._profiling(task_embedding.unsqueeze(1) + routing_weights)
        routing_weights = task_embedding.unsqueeze(1) + routing_weights + self.svd_S.data
        routing_weights = F.relu(routing_weights)

        # self._profiling(routing_weights)

        result = result * routing_weights
        result = result @ self.svd_U

        if not self.output_router_logits_:
            router_logits = []
        else:
            router_logits = [router_logits.flatten(start_dim=0, end_dim=1)]

        return result.to(dtype), router_logits

class Linear(nn.Module):
    def __init__(self, base_layer: nn.Module, device: str):
        super().__init__()

        if not isinstance(base_layer, nn.Linear):
            assert isinstance(base_layer, Linear8bitLt) or isinstance(
                base_layer, Linear4bit
            ), f"error type - {type(base_layer)}."
        else:
            base_layer.requires_grad_(False)

        self.device_ = torch.device(device)
        self.base_layer_ = base_layer.to(self.device_)
        self.loras_: Dict[str, Lora] = {}
        self.moes_: Dict[str, LLMMoeBlock] = {}
        self.svds_: Dict[str, MoORE] = {}
        self.adapter_name = None

        if isinstance(self.base_layer_, Linear4bit):
            self.out_features_, self.in_features_ = (
                self.base_layer_.out_features,
                self.base_layer_.in_features,
            )
        else:
            self.out_features_, self.in_features_ = self.base_layer_.weight.shape

    def init_lora_weight(
        self, lora_config: LoraConfig, lora_tensor=(None, None), adapter_name=None
    ):
        if adapter_name is None:
            adapter_name = lora_config.adapter_name

        if adapter_name not in self.loras_:
            self.loras_[adapter_name] = Lora(
                self.base_layer_,
                (self.in_features_, self.out_features_),
                lora_config,
                self.device_,
            )

        self.loras_[adapter_name].reset_parameters(lora_tensor)
    
    def init_svd_weight(
        self, lora_config: LoraConfig, lora_tensor=(None, None, None), adapter_name=None
    ):
        if adapter_name is None:
            adapter_name = lora_config.adapter_name
            self.adapter_name = adapter_name

        if adapter_name not in self.svds_:
            self.svds_[adapter_name] = MoORE(
                self.base_layer_,
                (self.in_features_, self.out_features_),
                lora_config,
                self.device_,
            )

        self.svds_[adapter_name].reset_parameters(lora_tensor)

    def _appy_dora(
        self,
        residual: torch.Tensor,
        lora_delta: torch.Tensor,
        input_args: LLMModelInput,
    ):
        next_states = executor.init_tensor(residual)
        lora_range = get_range_tensor(
            next_states.device, batch_size=next_states.shape[0]
        )
        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if adapter_name == "" or adapter_name not in self.loras_:
                continue

            if self.loras_[adapter_name].use_dora_:
                lora_data = self.loras_[adapter_name].apply_dora(
                    residual[start_idx:end_idx],
                    lora_delta[start_idx:end_idx],
                )
            else:
                lora_data = residual[start_idx:end_idx] + lora_delta[start_idx:end_idx]

            executor.index_copy(
                next_states, 0, lora_range[start_idx:end_idx], lora_data
            )

        return next_states

    def _efficient_impl(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        # hidden_states shape is: batch_size * max_seq_len * dim
        # result = hidden_states @ self.weight_.transpose(0, 1)
        residual = self.base_layer_.forward(hidden_states.to(self.base_layer_.weight.dtype))

        if len(self.loras_) == 0:
            return residual

        # split the data and result
        dropouts: List[float] = []
        scalings: List[float] = []
        loras: Tuple[torch.Tensor] = ()
        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_

            if adapter_name not in self.loras_:
                loras += (None, None)
                dropouts.append(None)
                scalings.append(None)
                continue

            loras += (
                self.loras_[adapter_name].lora_a_.weight,
                self.loras_[adapter_name].lora_b_.weight,
            )
            dropouts.append(self.loras_[adapter_name].dropout_.p)
            scalings.append(self.loras_[adapter_name].scaling_)

        have_dora = any(lora.use_dora_ for lora in self.loras_.values())

        if have_dora:
            lora_delta = torch.zeros_like(residual, dtype=torch.float32)
            lora_delta = LoraFunction.apply(
                lora_delta,
                hidden_states.to(torch.float32),
                input_args,
                dropouts,
                scalings,
                *loras,
            )
            next_states = self._appy_dora(
                residual.to(torch.float32), lora_delta, input_args
            )
        else:
            next_states = LoraFunction.apply(
                residual.to(torch.float32),
                hidden_states.to(torch.float32),
                input_args,
                dropouts,
                scalings,
                *loras,
            )

        return next_states.to(hidden_states.dtype)

    def _compatible_impl(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        # hidden_states shape is: batch_size * max_seq_len * dim
        # result = hidden_states @ self.weight_.transpose(0, 1)
        residual = self.base_layer_.forward(hidden_states)

        if len(self.loras_) == 0:
            return residual

        next_states = executor.init_tensor(residual)
        lora_range = get_range_tensor(hidden_states.device, hidden_states.shape[0])

        for lora_config in input_args.batch_configs_:
            adapter_name = lora_config.adapter_name_
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            if adapter_name in self.loras_:
                fwd_fn = self.loras_[adapter_name].forward
                kwargs = {}
            elif adapter_name in self.moes_:
                fwd_fn = self.moes_[adapter_name].forward
                kwargs = {"lora_linear": self}
            else:
                executor.index_copy(
                    next_states,
                    0,
                    lora_range[start_idx:end_idx],
                    residual[start_idx:end_idx],
                )
                continue

            lora_data = fwd_fn(
                residual=residual[start_idx:end_idx],
                hidden_states=hidden_states[start_idx:end_idx],
                **kwargs,
            )
            executor.index_copy(
                next_states, 0, lora_range[start_idx:end_idx], lora_data
            )

        return next_states

    def forward(
        self, 
        hidden_states: torch.Tensor, 
        input_args: LLMModelInput,
        task_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        if len(self.svds_) != 0:
            return self.svds_[self.adapter_name].forward(hidden_states, task_ids)
        else:
            if input_args.efficient_operator_ and len(self.moes_) == 0:
                return self._efficient_impl(hidden_states, input_args), []
            else:
                return self._compatible_impl(hidden_states, input_args), []
