import json
import os
from typing import Callable

import jax
from flax.core import FrozenDict

# Registry of activation functions addressable by a string name. Callables
# cannot be written to JSON, so configs are persisted with the name instead.
act_fns = dict(silu=jax.nn.silu)
# Reverse lookup (callable -> name) used when serializing a config.
inv_act_fns = {fn: name for name, fn in act_fns.items()}

# Preset hyperparameters keyed by size name. The presets differ only in layer
# count, head count and hidden width; num_kv_heads always equals num_heads and
# intermediate_size is always 4x the hidden width.
molecular_configs = {
    preset: {
        "num_layers": depth,
        "num_heads": heads,
        "num_kv_heads": heads,
        "hidden_size": width,
        "intermediate_size": 4 * width,
        "act_fn": jax.nn.silu,
        "norm_eps": 1e-6,
        "io_tying": True,
    }
    for preset, depth, heads, width in (
        ("small", 6, 8, 512),
        ("medium", 12, 12, 768),
        ("large", 24, 16, 1024),
    )
}

# Names of the supported presets, in declaration order.
sizes = list(molecular_configs)


def get_config_for(size, vocab_size, rope_base=100_000):
    """Return an immutable config for one of the preset model sizes.

    Args:
        size: Name of a preset; must be one of ``sizes``.
        vocab_size: Stored in the result under ``"vocab_size"``.
        rope_base: Stored in the result under ``"rope_base"``.

    Returns:
        A ``FrozenDict`` holding the preset's entries plus ``vocab_size`` and
        ``rope_base``.

    Raises:
        ValueError: If ``size`` is not a supported preset name.
    """
    if size not in sizes:
        raise ValueError(f"Invalid size! Supported: {','.join(sizes)}")

    # Build a fresh dict instead of updating the preset in place; otherwise
    # the per-call vocab_size / rope_base would leak into the shared
    # module-level ``molecular_configs`` table.
    cfg = {
        **molecular_configs[size],
        "vocab_size": vocab_size,
        "rope_base": rope_base,
    }

    return FrozenDict(cfg)


def extend_llama(base_config):
    """Return a frozen copy of ``base_config`` with fixed norm settings applied.

    The norm-related entries listed below are added to the result and take
    precedence over entries with the same key in ``base_config``. The input
    mapping itself is left untouched.
    """
    norm_overrides = {
        "norm_eps": 1e-6,
        "norm_convert_w": False,
        "norm_w_bias": 0.0,
        "pre_ffn_norm": False,
        "post_ffn_norm": False,
    }
    # Later unpacking wins, so the overrides replace any clashing base values.
    return FrozenDict({**base_config, **norm_overrides})


def save_to_dir(config, dirname, filename="model_config.json"):
    """Write ``config`` as indented JSON to ``dirname/filename``.

    The caller's config is never modified: a shallow plain-``dict`` copy is
    taken first, which works for any mapping (``dict`` or ``FrozenDict``)
    without relying on private attributes. A callable ``"act_fn"`` entry is
    replaced by its registered name from ``inv_act_fns`` because functions
    are not JSON-serializable.

    Args:
        config: Mapping of config values; must contain an ``"act_fn"`` key.
        dirname: Existing directory to write into.
        filename: Name of the JSON file created inside ``dirname``.

    Raises:
        NotADirectoryError: If ``dirname`` is not an existing directory.
        KeyError: If ``"act_fn"`` is missing, or is a callable that is not
            registered in ``inv_act_fns``.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.isdir(dirname):
        raise NotADirectoryError(f"Not a directory: {dirname}")

    serializable = dict(config)
    if callable(serializable["act_fn"]):
        serializable["act_fn"] = inv_act_fns[serializable["act_fn"]]

    with open(f"{dirname}/{filename}", "w") as fp:
        json.dump(serializable, fp, indent=2)


def load_from_dir(dirname, filename="model_config.json"):
    """Read a JSON config from ``dirname/filename`` and return it frozen.

    The stored ``"act_fn"`` name is mapped back to a callable through
    ``act_fns``; a name that is not registered there resolves to
    ``jax.nn.silu``.
    """
    with open(f"{dirname}/{filename}") as handle:
        loaded = json.load(handle)

    # Swap the serialized activation name for the actual function.
    act_name = loaded["act_fn"]
    loaded["act_fn"] = act_fns.get(act_name, jax.nn.silu)
    return FrozenDict(loaded)
