from typing import Mapping, Dict, List, Optional
import torch
from custom_models.kernels.bwell_gated_mlp import (
    BwellMLPv0,
    BwellMLPv01,
    BwellMLPBase
)

from custom_models.kernels.bwell_nongated_mlp import (
    BwellMLPNGv0,
)

def sparse_to_bwell_state_dict(
    state_dict: Mapping[str, torch.Tensor],
    drop_bias: bool = True,
) -> Dict[str, torch.Tensor]:
    """Rename MLP projection weights to the Bwell key layout.

    Keys ending in ``gate_proj.weight``, ``up_proj.weight`` and
    ``down_proj.weight`` are renamed to ``gate_weight``, ``up_weight`` and
    ``down_weight`` (same prefix). The gate and down tensors are stored
    as ``.T.contiguous()``; the up tensor is stored as ``.contiguous()``
    without a transpose. Every other entry is passed through untouched.

    NOTE(review): this module defines ``sparse_to_bwell_state_dict`` a second
    time further down; Python binds the name to the later definition, so this
    version is not the one reached by name after import.

    Args:
        state_dict: Mapping from parameter name to tensor.
        drop_bias: When true, entries whose key ends in ``gate_proj.bias``,
            ``up_proj.bias`` or ``down_proj.bias`` are omitted from the
            result; otherwise they are copied through unchanged.

    Returns:
        A new dict with the renamed/transposed weights and remaining entries.
    """
    # (source suffix, target suffix, transpose before storing?)
    weight_rules = (
        ("gate_proj.weight", "gate_weight", True),
        ("up_proj.weight", "up_weight", False),
        ("down_proj.weight", "down_weight", True),
    )
    bias_suffixes = ("gate_proj.bias", "up_proj.bias", "down_proj.bias")

    converted: Dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        for old_suffix, new_suffix, transpose in weight_rules:
            if name.endswith(old_suffix):
                prefix = name[: -len(old_suffix)]
                source = tensor.T if transpose else tensor
                converted[prefix + new_suffix] = source.contiguous()
                break
        else:
            # Not a projection weight: keep it unless it is a bias we drop.
            if not (drop_bias and name.endswith(bias_suffixes)):
                converted[name] = tensor

    return converted


def _try_get_layer_idx(k: str) -> Optional[int]:
    token = "layers."
    i = k.find(token)
    if i < 0:
        return None

    j = i + len(token)
    end = j
    while end < len(k) and k[end].isdigit():
        end += 1

    if end == j:
        return None
    if end < len(k) and k[end] != ".":
        return None

    return int(k[j:end])


def sparse_to_bwell_state_dict(
    state_dict: Mapping[str, torch.Tensor],
    drop_bias: bool = True,
    layer_numbers: Optional[List[int]] = None,
    gated: bool = True,
) -> Dict[str, torch.Tensor]:
    """Rename MLP projection weights to the Bwell key layout, per layer.

    For each key selected for conversion:

    * ``...gate_proj.weight`` -> ``...gate_weight``, stored as
      ``.T.contiguous()``.
    * ``...up_proj.weight`` -> ``...up_weight``, stored as ``.contiguous()``
      when ``gated`` is true and as ``.T.contiguous()`` when it is false.
    * ``...down_proj.weight`` -> ``...down_weight``, stored as
      ``.T.contiguous()``.
    * ``...gate_proj.bias`` / ``...up_proj.bias`` / ``...down_proj.bias`` are
      omitted when ``drop_bias`` is true.

    Keys that are not selected, and all other keys, are copied through
    unchanged.

    Args:
        state_dict: Mapping from parameter name to tensor.
        drop_bias: Whether to omit the projection biases of selected layers.
        layer_numbers: Layer indices to convert. ``None`` selects every key;
            otherwise a key is selected only if ``_try_get_layer_idx`` finds
            an index in it and that index is in this list.
        gated: Whether the MLP has a gate projection; controls the
            ``up_proj.weight`` transpose and the error below.

    Returns:
        A new dict holding the converted and passed-through entries.

    Raises:
        ValueError: If any key ends in ``gate_proj.weight`` while ``gated`` is
            false. This check runs for every such key, whether or not its
            layer is selected.
    """
    selected_layers = (
        None if layer_numbers is None else {int(n) for n in layer_numbers}
    )

    # (source suffix, target suffix, transpose before storing?)
    weight_rules = (
        ("gate_proj.weight", "gate_weight", True),
        ("up_proj.weight", "up_weight", not gated),
        ("down_proj.weight", "down_weight", True),
    )
    bias_suffixes = ("gate_proj.bias", "up_proj.bias", "down_proj.bias")

    result: Dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        idx = _try_get_layer_idx(name)
        in_scope = selected_layers is None or (
            idx is not None and idx in selected_layers
        )

        if name.endswith("gate_proj.weight") and not gated:
            raise ValueError(
                "Attempting to convert gate_proj.weight "
                "but gated is set to False"
            )

        rule = None
        for candidate in weight_rules:
            if name.endswith(candidate[0]):
                rule = candidate
                break

        if rule is not None and in_scope:
            old_suffix, new_suffix, transpose = rule
            source = tensor.T if transpose else tensor
            result[name[: -len(old_suffix)] + new_suffix] = source.contiguous()
        elif (
            rule is None
            and drop_bias
            and in_scope
            and name.endswith(bias_suffixes)
        ):
            # Projection bias of a selected layer: dropped on request.
            continue
        else:
            # Unselected projection weight/bias, or an unrelated entry.
            result[name] = tensor

    return result



def get_bwell_mlp_class(version: str):
    """Return the Bwell MLP class registered under ``version``.

    Recognised values are ``"bwell_v0"`` (``BwellMLPv0``), ``"bwell_v01"``
    (``BwellMLPv01``) and ``"bwell_nongated_v0"`` (``BwellMLPNGv0``).

    Args:
        version: Version identifier string.

    Returns:
        The matching class object (the class itself, not an instance).

    Raises:
        ValueError: If ``version`` is not one of the recognised values.
    """
    if version == "bwell_v0":
        return BwellMLPv0
    if version == "bwell_v01":
        return BwellMLPv01
    if version == "bwell_nongated_v0":
        return BwellMLPNGv0
    raise ValueError(f"Unknown Bwell model version: {version}")

