import copy
from dataclasses import dataclass
from typing import Dict

from moe_peft.common import LoraConfig


@dataclass
class LoraMoeConfig(LoraConfig):
    """LoRA adapter configuration for the ``loramoe`` mixture-of-experts variant.

    Extends ``LoraConfig`` with the expert/router settings that
    :meth:`from_config` reads, :meth:`check` validates and :meth:`export`
    writes back out.
    """

    # Number of LoRA experts; check() requires a positive int.
    num_experts_: int = None
    # Top-k routing parameter; from_config() defaults it to 0, and check()
    # requires num_experts_ - topk >= shared_experts_.
    topk: int = None
    # Router initialisation range; check() requires a non-negative float.
    router_init_range_: float = None
    # Routing strategy tag, emitted by export() as "routing_strategy".
    routing_strategy_: str = "loramoe"
    # Router-loss switch and auxiliary-loss coefficient. NOTE: from_config()
    # uses different fallbacks (False / 0.001) than these field defaults.
    router_loss_: bool = True
    router_aux_loss_coef_: float = 0.0
    # Number of shared experts; bounded by num_experts_ - topk in check().
    shared_experts_: int = 0

    def check(self) -> "LoraMoeConfig":
        """Validate the MoE-specific fields on top of the base LoRA checks.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            AssertionError: if a field is missing, has the wrong type, or is
                out of range.
        """
        super().check()
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )
        # topk defaults to None on the dataclass: validate it (and
        # shared_experts_) before the arithmetic below, so a missing value is
        # reported as a config error instead of a TypeError.
        assert isinstance(self.topk, int) and self.topk >= 0, "topk must be an int >= 0"
        assert (
            isinstance(self.shared_experts_, int) and self.shared_experts_ >= 0
        ), "shared_experts must be an int >= 0"
        assert self.shared_experts_ <= self.num_experts_ - self.topk, (
            f"shared_experts ({self.shared_experts_}) must not exceed "
            f"num_experts - topk ({self.num_experts_ - self.topk})"
        )

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "LoraMoeConfig":
        """Build a ``LoraMoeConfig`` from a plain configuration mapping.

        Args:
            config: mapping holding the base LoRA keys (parsed by
                ``LoraConfig.from_config``) plus ``num_experts`` (required),
                ``topk``, ``shared_experts``, ``router_loss``,
                ``router_aux_loss_coef`` and ``router_init_range``.

        Returns:
            A new, unvalidated ``LoraMoeConfig``; call :meth:`check` to validate.

        Raises:
            KeyError: if ``num_experts`` is missing.
        """
        return LoraMoeConfig(
            topk=config.get("topk", 0),
            shared_experts_=config.get("shared_experts", 0),
            # The canonical keys have no trailing underscore, like every other
            # key read here. The underscored spellings were the only ones
            # recognised previously, so they are still accepted as a fallback.
            router_loss_=config.get("router_loss", config.get("router_loss_", False)),
            router_aux_loss_coef_=config.get(
                "router_aux_loss_coef", config.get("router_aux_loss_coef_", 0.001)
            ),
            num_experts_=config["num_experts"],
            router_init_range_=config.get("router_init_range", 5.0),
            **LoraConfig.from_config(config).__dict__,
        )

    def export(self) -> Dict[str, any]:
        """Serialise this config, extending the base LoRA export.

        Returns:
            The base export dict plus ``peft_type``, ``routing_strategy``,
            ``num_experts`` and the routing parameters that
            :meth:`from_config` reads back.
        """
        config = super().export()
        config["peft_type"] = "LORAMOE"
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_
        # from_config() reads these. Omitting them made a reloaded adapter
        # silently fall back to topk=0 / shared_experts=0.
        if self.topk is not None:
            config["topk"] = self.topk
        config["shared_experts"] = self.shared_experts_

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Return a per-expert copy of this config.

        Args:
            expert_idx: index of the expert, embedded in the adapter name.

        Returns:
            A deep copy of this config whose ``adapter_name`` is
            ``"moe.<adapter_name>.experts.<expert_idx>"``.
        """
        # The previous ``copy.deepcopy(super())`` resolved __deepcopy__ and
        # __reduce_ex__ through the proxy, which bind to this same instance.
        # It therefore already copied ``self``; spell that out directly.
        config = copy.deepcopy(self)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config

@dataclass
class M2LoRAConfig(LoraConfig):
    """LoRA adapter configuration for the ``m2lora`` mixture-of-experts variant.

    Extends ``LoraConfig`` with the expert/router settings that
    :meth:`from_config` reads, :meth:`check` validates and :meth:`export`
    writes back out.
    """

    # Number of LoRA experts; check() requires a positive int.
    num_experts_: int = None
    # Top-k routing parameter; from_config() requires the "topk" key.
    topk: int = None
    # Router initialisation range; check() requires a non-negative float.
    router_init_range_: float = None
    # Routing strategy tag, emitted by export() as "routing_strategy".
    routing_strategy_: str = "m2lora"
    # Router-loss switch and auxiliary-loss coefficient. from_config() only
    # sets the coefficient (fallback 0.001).
    router_loss_: bool = True
    router_aux_loss_coef_: float = 0.0

    def check(self) -> "M2LoRAConfig":
        """Validate the MoE-specific fields on top of the base LoRA checks.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            AssertionError: if a field is missing, has the wrong type, or is
                out of range.
        """
        super().check()
        assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
        assert (
            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
        )
        # topk defaults to None on the dataclass. Selecting more experts than
        # exist is not meaningful, so bound it by num_experts_.
        assert (
            isinstance(self.topk, int) and 0 <= self.topk <= self.num_experts_
        ), "topk must be an int in [0, num_experts]"

        return self

    @staticmethod
    def from_config(config: Dict[str, any]) -> "M2LoRAConfig":
        """Build an ``M2LoRAConfig`` from a plain configuration mapping.

        Args:
            config: mapping holding the base LoRA keys (parsed by
                ``LoraConfig.from_config``) plus ``num_experts`` and ``topk``
                (both required), ``router_aux_loss_coef`` and
                ``router_init_range``.

        Returns:
            A new, unvalidated ``M2LoRAConfig``; call :meth:`check` to validate.

        Raises:
            KeyError: if ``num_experts`` or ``topk`` is missing.
        """
        return M2LoRAConfig(
            num_experts_=config["num_experts"],
            topk=config["topk"],
            # Auxiliary router-loss coefficient (noted as training-only).
            router_aux_loss_coef_=config.get("router_aux_loss_coef", 0.001),
            router_init_range_=config.get("router_init_range", 5.0),
            **LoraConfig.from_config(config).__dict__,
        )

    def export(self) -> Dict[str, any]:
        """Serialise this config, extending the base LoRA export.

        Returns:
            The base export dict plus ``peft_type``, ``routing_strategy``,
            ``num_experts`` and ``topk``.
        """
        config = super().export()
        config["peft_type"] = "M2LORA"
        config["routing_strategy"] = self.routing_strategy_
        config["num_experts"] = self.num_experts_
        # from_config() requires "topk". Without it here, reloading an exported
        # config raised KeyError.
        if self.topk is not None:
            config["topk"] = self.topk

        return config

    def expert_config(self, expert_idx: int) -> LoraConfig:
        """Return a per-expert copy of this config.

        Args:
            expert_idx: index of the expert, embedded in the adapter name.

        Returns:
            A deep copy of this config whose ``adapter_name`` is
            ``"moe.<adapter_name>.experts.<expert_idx>"``.
        """
        # The previous ``copy.deepcopy(super())`` resolved __deepcopy__ and
        # __reduce_ex__ through the proxy, which bind to this same instance.
        # It therefore already copied ``self``; spell that out directly.
        config = copy.deepcopy(self)
        config.adapter_name = f"moe.{self.adapter_name}.experts.{expert_idx}"
        return config