import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import is_bitsandbytes_available

from moe_peft.executors import executor

from .abstracts import LLMMoeBlock
from .config import LLMModelInput, LoraConfig

if is_bitsandbytes_available():
    import bitsandbytes as bnb
    from bitsandbytes.nn import Linear4bit, Linear8bitLt
else:
    from moe_peft.utils import Linear8bitLt, Linear4bit

from typing import Any, Dict, List, Tuple

from einops import einsum, rearrange
from opt_einsum import get_symbol, contract_expression


class MPO(nn.Module):
    """Matrix-product-operator (MPO) factorisation of a linear layer's weight.

    The weight of ``base_layer`` is split by successive SVDs into a chain of
    4-D cores shaped ``(left_bond, row_factor, col_factor, right_bond)``.
    The core at the middle position of the chain is stored as
    ``shared_cores`` (or replaced by an externally supplied
    ``central_cores``); the remaining cores are stored in the Python list
    ``specific_cores`` and have ``requires_grad`` switched on.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        shape: Tuple[int, int],
        config: "LoraConfig",
        device: str,
        merged=False,
        central_cores=None,
    ):
        """Decompose ``base_layer.weight`` and prepare the contraction.

        Args:
            base_layer: Module whose ``weight`` attribute is decomposed.
            shape: ``(in_features, out_features)`` of the wrapped layer.
            config: Supplies ``lora_init_``, ``mpo_r_``, ``lora_alpha_``,
                ``lora_dropout_``, ``mpo_rows_`` and ``mpo_cols_``.
            device: Anything accepted by ``torch.device``.
            merged: When true, the row factor at the shared position is
                doubled before decomposing.
            central_cores: Optional tensor used as the shared core instead
                of the one obtained from this layer's own decomposition.

        Raises:
            AssertionError: If ``config.lora_dropout_`` is not positive, or
                if the factor lists do not match the weight's shape.
        """
        super().__init__()

        self.base_layer_ = base_layer
        self.device_ = torch.device(device)

        self.initializer_ = config.lora_init_
        self.r_ = config.mpo_r_
        self.alpha_ = config.lora_alpha_
        self.merged_ = merged

        self.in_features_, self.out_features_ = shape

        assert config.lora_dropout_ > 0.0
        self.dropout_ = nn.Dropout(p=config.lora_dropout_)

        # Work on private copies of the factor lists: doubling an entry for
        # the merged case must not leak back into the (shared) config, or the
        # next MPO built from the same config would see an already-doubled
        # factor and fail the shape assertions in mpo_decomposition().
        self.mpo_rows, self.mpo_cols, self.shared_indices = self._resolve_factors(
            config, shape, merged
        )

        # The decomposition only provides initial values; no autograd graph
        # back to the base weight is wanted.
        with torch.no_grad():
            shared_cores, specific_cores = self.initialize_shared_mpo(
                matrices=base_layer.weight,
            )

        self.shared_cores = (
            central_cores if central_cores is not None else shared_cores
        )
        self.specific_cores = specific_cores

        # Lazily filled caches for the two string properties below.
        self._contraction_string = None
        self._expression_string = None

        # Pre-build the contraction. The shared core is passed as an actual
        # tensor and marked constant; every other operand is described by its
        # shape only and is supplied at call time (see contract_mpo()).
        split = self.shared_indices[0]
        self.mpo_expr = contract_expression(
            self.expression_string,
            *[core.shape for core in self.specific_cores[:split]],
            self.shared_cores,
            *[core.shape for core in self.specific_cores[split:]],
            constants=self.shared_indices,
        )

        for core in self.specific_cores:
            core.requires_grad_()

    @staticmethod
    def _resolve_factors(config: "LoraConfig", shape: Tuple[int, int], merged):
        """Return ``(rows, cols, shared_indices)`` without touching ``config``.

        ``rows``/``cols`` are fresh lists copied from ``config.mpo_rows_`` and
        ``config.mpo_cols_``; they are swapped when ``shape[0] < shape[1]``.
        ``shared_indices`` holds the single middle position of the chain.
        When ``merged`` is true the row factor at that position is doubled.
        """
        rows = list(config.mpo_rows_)
        cols = list(config.mpo_cols_)
        if shape[0] < shape[1]:
            rows, cols = cols, rows

        shared_indices = [len(config.mpo_rows_) // 2]
        if merged:
            rows[shared_indices[0]] *= 2

        return rows, cols, shared_indices

    @property
    def contraction_string(self):
        """Space-separated, named-axis pattern for contracting all cores.

        Built once from the number of column factors and cached.
        """
        if self._contraction_string is None:
            d = len(self.mpo_cols)
            operands = ", ".join(f"d{c} a{c} b{c} d{c + 1}" for c in range(d))
            row_axes = " ".join(f"a{c}" for c in range(d))
            col_axes = " ".join(f"b{c}" for c in range(d))
            self._contraction_string = (
                f"{operands} -> d0 {row_axes} {col_axes} d{d}"
            )
        return self._contraction_string

    @property
    def expression_string(self):
        """Einsum equation (single-symbol indices) for contracting all cores.

        Bond indices use symbols ``0..d``, row indices ``100..`` and column
        indices ``200..`` so the three families cannot collide. The output is
        ordered as (first bond, all rows, all columns, last bond). Built once
        and cached.
        """
        if self._expression_string is None:
            d = len(self.mpo_cols)
            operands = ",".join(
                f"{get_symbol(c)}{get_symbol(c + 100)}"
                f"{get_symbol(c + 200)}{get_symbol(c + 1)}"
                for c in range(d)
            )
            row_syms = "".join(get_symbol(c + 100) for c in range(d))
            col_syms = "".join(get_symbol(c + 200) for c in range(d))
            self._expression_string = (
                f"{operands}->{get_symbol(0)}{row_syms}{col_syms}{get_symbol(d)}"
            )
        return self._expression_string

    def mpo_decomposition(self, matrix):
        """Split ``matrix`` into a chain of 4-D cores via sequential SVD.

        Args:
            matrix: 2-D tensor whose shape equals
                ``(prod(self.mpo_rows), prod(self.mpo_cols))``.

        Returns:
            A list with one core per factor pair, each shaped
            ``(left_bond, mpo_rows[k], mpo_cols[k], right_bond)``. The first
            left bond and the last right bond are 1.

        Raises:
            AssertionError: If the factor lists differ in length or their
                products do not match ``matrix.shape``.
        """
        assert len(self.mpo_rows) == len(
            self.mpo_cols
        ), "Row and column shapes must have same length"
        assert math.prod(self.mpo_rows) == matrix.shape[0]
        assert math.prod(self.mpo_cols) == matrix.shape[1]

        d = len(self.mpo_rows)
        tensor = matrix.reshape(*self.mpo_rows, *self.mpo_cols)
        # Interleave the axes (a0 b0 a1 b1 ...) so each core owns one
        # (row factor, column factor) pair.
        row_axes = " ".join(f"a{i}" for i in range(d))
        col_axes = " ".join(f"b{i}" for i in range(d))
        paired_axes = " ".join(f"a{i} b{i}" for i in range(d))
        tensor = rearrange(tensor, f"{row_axes} {col_axes} -> {paired_axes}")

        cores = []
        r = 1  # bond dimension carried over from the previous core

        for k in range(d - 1):
            tensor = tensor.reshape(r * self.mpo_rows[k] * self.mpo_cols[k], -1)

            # SVD is computed in float32 and cast back to the working dtype.
            U, S, Vh = torch.linalg.svd(tensor.float(), full_matrices=False)
            # A positive r_ caps the bond dimension; otherwise keep everything.
            rank = min(self.r_, S.size(0)) if self.r_ > 0 else S.size(0)

            U = U[:, :rank].to(tensor.dtype)
            S = S[:rank].to(tensor.dtype)
            Vh = Vh[:rank, :].to(tensor.dtype)

            cores.append(U.reshape(r, self.mpo_rows[k], self.mpo_cols[k], rank))

            # Push the singular values into the remainder and continue.
            tensor = S.unsqueeze(-1) * Vh
            r = rank

        cores.append(tensor.reshape(r, self.mpo_rows[-1], self.mpo_cols[-1], 1))

        return cores

    def contract_mpo(self, partial=False):
        """Contract the cores back into a weight tensor.

        Args:
            partial: If true, return the raw contraction result (axes ordered
                as in ``expression_string``) without flattening it.

        Returns:
            By default a 2-D tensor of shape
            ``(prod(self.mpo_rows), prod(self.mpo_cols))``.
        """
        mat = self.mpo_expr(*self.specific_cores)

        if not partial:
            mat = mat.reshape(
                math.prod(self.mpo_rows),
                math.prod(self.mpo_cols),
            )

        return mat

    def initialize_shared_mpo(self, matrices):
        """Decompose ``matrices`` and separate shared from specific cores.

        Returns:
            ``(shared_cores, specific_cores)`` where ``shared_cores`` is the
            core removed at the (last) position in ``self.shared_indices``
            and ``specific_cores`` is the list of remaining cores.
        """
        specific_cores = self.mpo_decomposition(matrices)

        shared_cores_list = []
        for i in self.shared_indices:
            shared_cores_list.append(specific_cores.pop(i))

        # Only one shared position is configured in __init__; if several were
        # given, the last popped core is the one returned.
        shared_cores = shared_cores_list[-1]

        return shared_cores, specific_cores

    def reset_parameters(self, lora_tensor=(None, None)) -> None:
        """Optionally overwrite weights from a pair of tensors.

        Args:
            lora_tensor: ``(None, None)`` to do nothing, or a pair of tensors
                to copy from.
        """
        # if the lora_tensor is not (None, None), use it to init the lora weight
        assert isinstance(lora_tensor, Tuple)
        assert len(lora_tensor) == 2
        assert ((lora_tensor[0] is None) and (lora_tensor[1] is None)) or (
            isinstance(lora_tensor[0], torch.Tensor)
            and isinstance(lora_tensor[1], torch.Tensor)
        )

        if lora_tensor[0] is None and lora_tensor[1] is None:
            return

        # NOTE(review): lora_a_ / lora_b_ are not created anywhere in this
        # class as visible here, so this branch raises AttributeError unless
        # something else attaches them -- confirm the intended target.
        with torch.no_grad():
            self.lora_a_.weight.copy_(lora_tensor[0])
            self.lora_b_.weight.copy_(lora_tensor[1])

    def mpo_forward(self, hidden_states: torch.Tensor):
        """Apply the weight rebuilt from the cores as a bias-free linear map."""
        weight = self.contract_mpo()

        return F.linear(hidden_states, weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``mpo_forward(hidden_states)``."""
        return self.mpo_forward(hidden_states)


class MLinear(nn.Module):
    """Linear-layer wrapper that keeps a registry of per-adapter MPO objects.

    The wrapped layer is moved to ``device``; a plain ``nn.Linear`` also has
    its gradients disabled. MPO instances are stored in the ordinary dict
    ``mpos_`` keyed by adapter name.
    """

    def __init__(self, base_layer: nn.Module, device: str, merged=False):
        """Wrap ``base_layer`` and record its feature dimensions.

        Args:
            base_layer: An ``nn.Linear``, ``Linear8bitLt`` or ``Linear4bit``.
            device: Anything accepted by ``torch.device``.
            merged: Flag forwarded to every MPO created by this wrapper.

        Raises:
            AssertionError: If ``base_layer`` is none of the accepted types.
        """
        super().__init__()

        if isinstance(base_layer, nn.Linear):
            base_layer.requires_grad_(False)
        else:
            assert isinstance(
                base_layer, (Linear8bitLt, Linear4bit)
            ), f"error type - {type(base_layer)}."

        target_device = torch.device(device)
        self.device_ = target_device
        self.base_layer_ = base_layer.to(target_device)
        self.mpos_: Dict[str, MPO] = {}
        self.moes_: Dict[str, LLMMoeBlock] = {}
        self.merged_ = merged

        # For Linear4bit the dimensions come from the layer's own attributes
        # instead of weight.shape.
        if isinstance(self.base_layer_, Linear4bit):
            dims = (self.base_layer_.out_features, self.base_layer_.in_features)
        else:
            dims = self.base_layer_.weight.shape
        self.out_features_, self.in_features_ = dims

    def init_mpo_weight(
        self,
        lora_config: LoraConfig,
        mpo_tensor=(None, None),
        adapter_name=None,
        central_cores=None,
    ):
        """Create (if missing) the MPO for an adapter and reset it.

        Args:
            lora_config: Configuration handed to the MPO constructor; also
                provides the default adapter name.
            mpo_tensor: Forwarded to ``MPO.reset_parameters``.
            adapter_name: Registry key; defaults to
                ``lora_config.adapter_name``.
            central_cores: Optional shared core passed to a newly built MPO.

        Returns:
            ``central_cores`` when it was supplied, otherwise the
            ``shared_cores`` of the adapter's MPO.
        """
        key = lora_config.adapter_name if adapter_name is None else adapter_name

        if key not in self.mpos_:
            self.mpos_[key] = MPO(
                base_layer=self.base_layer_,
                shape=(self.in_features_, self.out_features_),
                config=lora_config,
                device=self.device_,
                merged=self.merged_,
                central_cores=central_cores,
            )

        mpo = self.mpos_[key]
        mpo.reset_parameters(mpo_tensor)

        if central_cores is not None:
            return central_cores
        return mpo.shared_cores

    def forward(
        self, hidden_states: torch.Tensor, input_args: LLMModelInput
    ) -> torch.Tensor:
        """Apply the wrapped layer's weight to ``hidden_states``.

        NOTE(review): no bias is passed to ``F.linear`` and neither
        ``input_args`` nor the registered MPOs are consulted here -- confirm
        this is intended.
        """
        base_weight = self.base_layer_.weight
        return F.linear(hidden_states, base_weight)
