from typing import Any, Dict, Literal, Optional
import flax
import jax
import optax
import numpy as np

from multinav.data.dataset import DatasetConfig, make_dataset
from multinav.model.heads import Head
from multinav.model.model_base import MultiNavModel

from .train_state import TrainState


def make_single_optimizer(
    *,
    optim_type: Literal["sgd", "adam"],
    learning_rate: optax.Schedule | float,
    weight_decay: Optional[float],
    clip_grad_norm: Optional[float],
    learning_rate_scale: Optional[float | optax.Schedule],
):
    """Build a single optax optimizer as a chain of gradient transformations.

    The chain is, in order:

    1. optional global-norm gradient clipping,
    2. the core optimizer: SGD (with optional additive weight decay applied
       to the gradients before the SGD step) or Adam/AdamW,
    3. an optional final scaling of the resulting updates.

    Args:
        optim_type: Core optimizer to use, ``"sgd"`` or ``"adam"``.
        learning_rate: Constant learning rate or schedule handed to the core
            optimizer.
        weight_decay: If not None, the weight-decay coefficient. For SGD it is
            added via ``optax.add_decayed_weights``; for Adam it selects
            ``optax.adamw`` instead of ``optax.adam``.
        clip_grad_norm: If not None, gradients are clipped to this global
            norm before any other transformation.
        learning_rate_scale: Optional extra multiplier on the final updates.
            It is either a real number (applied with ``optax.scale``) or a
            schedule (applied with ``optax.scale_by_schedule``).

    Returns:
        An ``optax.GradientTransformation`` chaining the stages above.

    Raises:
        ValueError: If ``optim_type`` is not ``"sgd"``/``"adam"``, or if
            ``learning_rate_scale`` is neither None, a real number, nor a
            callable.
    """
    import numbers  # local import keeps the module's import block untouched

    stages = []
    if clip_grad_norm is not None:
        stages.append(optax.clip_by_global_norm(clip_grad_norm))

    if optim_type == "sgd":
        if weight_decay is not None:
            stages.append(optax.add_decayed_weights(weight_decay))
        stages.append(optax.sgd(learning_rate))
    elif optim_type == "adam":
        if weight_decay is not None:
            stages.append(
                optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay)
            )
        else:
            stages.append(optax.adam(learning_rate=learning_rate))
    else:
        # Without a core optimizer the chain would only clip/scale the raw
        # gradients, silently producing a wrong update rule. Fail loudly.
        raise ValueError(
            f"Invalid optim_type: {optim_type!r} (expected 'sgd' or 'adam')"
        )

    if callable(learning_rate_scale):
        stages.append(optax.scale_by_schedule(learning_rate_scale))
    elif isinstance(learning_rate_scale, numbers.Real) and not isinstance(
        learning_rate_scale, bool
    ):
        # Accept ints and numpy scalars (e.g. np.float32), not only Python floats.
        stages.append(optax.scale(learning_rate_scale))
    elif learning_rate_scale is not None:
        raise ValueError(
            f"Invalid learning_rate_scale: {learning_rate_scale} ({type(learning_rate_scale)})"
        )

    return optax.chain(*stages)

@optax.inject_hyperparams
def make_optimizer(
    *,
    config: MultiNavModel.Config,
    learning_rate: optax.Schedule | float,
    weight_decay: Optional[float],
    clip_grad_norm: Optional[float],
    params: Dict[str, jax.Array],
):
    """Build the model's optimizer, with per-head optimizers where requested.

    A shared "base" optimizer is built with ``make_single_optimizer``. Its
    settings are ``config.optimizer`` as the optimizer type,
    ``config.base_learning_rate`` as the final update scale, and the
    ``learning_rate``/``weight_decay``/``clip_grad_norm`` arguments. Every
    parameter is labelled ``"base"`` by default.

    Then, for each entry in ``config.head_configs``, that head's
    ``make_optimizer`` is given the ``params["head_<name>"]`` subtree and
    returns extra optimizers plus a label tree for the subtree. Head labels
    other than ``"base"`` are prefixed with ``"head_<name>_"``, so they do not
    clash across heads. A head label equal to ``"base"`` keeps using the
    shared base optimizer.

    The decorator ``optax.inject_hyperparams`` makes the numeric
    arguments/schedules injectable hyperparameters stored in the optimizer
    state. ``config`` and ``params`` are passed through unchanged.

    Args:
        config: Model config providing ``optimizer``, ``base_learning_rate``
            and ``head_configs``.
        learning_rate: Learning rate or schedule for every sub-optimizer.
        weight_decay: Optional weight-decay coefficient.
        clip_grad_norm: Optional global-norm clipping threshold.
        params: Model parameter tree. It is used for its structure (labels)
            and to hand each head its subtree. NOTE(review): the labels tree
            is assigned into by key, so this assumes a plain mutable dict
            rather than a FrozenDict. Confirm against the flax version.

    Returns:
        An ``optax.multi_transform`` dispatching each parameter leaf to the
        optimizer named by its label.
    """
    optimizer_kwargs = {
        "optim_type": config.optimizer,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "clip_grad_norm": clip_grad_norm,
        "learning_rate_scale": config.base_learning_rate,
    }

    base_optimizer = make_single_optimizer(**optimizer_kwargs)

    optimizers = {"base": base_optimizer}
    labels = jax.tree_util.tree_map(lambda _: "base", params)

    for head_name, head_config in config.head_configs.items():
        head_key = f"head_{head_name}"
        head_optimizers, head_labels = head_config.make_optimizer(
            params[head_key],
            base_optimizer_name="base",
            optimizer_kwargs=optimizer_kwargs,
        )

        for label, optimizer in head_optimizers.items():
            optimizers[f"{head_key}_{label}"] = optimizer

        def relabel_with_head(label, prefix=head_key):
            # `prefix` is bound at definition time to avoid late-binding the
            # loop variable.
            if label == "base":
                return "base"
            return f"{prefix}_{label}"

        labels[head_key] = jax.tree_util.tree_map(relabel_with_head, head_labels)

    return optax.multi_transform(transforms=optimizers, param_labels=labels)


def _count_params(tree) -> int:
    """Return the total number of scalar entries across all leaves of ``tree``."""
    # int() matters: np.prod(()) is the float 1.0 for scalar leaves.
    return sum(int(np.prod(leaf.shape)) for leaf in jax.tree_util.tree_leaves(tree))


def make_model_and_dataset(
    model_config: MultiNavModel.Config,
    data_config: DatasetConfig,
    batch_size: int,
    epochs: int,
    eval_interval: int,
    device_list: list,
    rng: jax.Array,
):
    """Build the model, its datasets, and a device-replicated TrainState.

    Side effects: pulls one batch from a fresh iterator over the training
    dataset to initialize parameters, and prints parameter counts (total and
    per top-level key).

    Args:
        model_config: Model configuration. Its ``warmup_steps``,
            ``weight_decay`` and ``clip_grad_norm`` also parameterize the
            optimizer, together with the fields ``make_optimizer`` reads.
        data_config: Dataset configuration forwarded to ``make_dataset``.
        batch_size: Batch size forwarded to ``make_dataset``.
        epochs: Multiplied by ``eval_interval`` to give the cosine schedule's
            ``decay_steps``.
        eval_interval: See ``epochs``. NOTE(review): the product is presumably
            the total number of training steps. Confirm with the training loop.
        device_list: Devices the train state is replicated onto.
        rng: PRNG key used for the ``"params"`` init collection.

    Returns:
        ``(train_state, train_dataset, val_datasets)``. ``train_state`` has
        been passed through ``flax.jax_utils.replicate``.
    """
    model: MultiNavModel = MultiNavModel(config=model_config)

    train_dataset, val_datasets = make_dataset(
        data_config,
        batch_size=batch_size,
        num_steps_predict=model.batch_predict_horizon,
        history_size=model.batch_seq_len,
    )

    # One example batch drives shape inference for init. `_numpy()` presumably
    # converts framework tensors (e.g. tf EagerTensor) to numpy. Confirm.
    example_batch = jax.tree_util.tree_map(
        lambda x: x._numpy(), next(iter(train_dataset))
    )
    # Initialize through the model's "loss" method. The "sample" stream uses a
    # fixed key, so only `rng` varies initialization of "params".
    params = jax.jit(model.init, static_argnames=["train", "method"])(
        {"params": rng, "sample": jax.random.PRNGKey(0)},
        example_batch,
        train=False,
        step=0,
        method="loss",
    )["params"]

    print("Total param count:", _count_params(params))
    for k, v in params.items():
        print(f"\t{k}: {_count_params(v)}")

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        target_params=params,
        tx=make_optimizer(
            config=model_config,
            # Peak value 1: this schedule acts as a multiplier. The absolute
            # rate comes from model_config.base_learning_rate, which
            # make_optimizer passes on as the update scale.
            learning_rate=optax.warmup_cosine_decay_schedule(
                init_value=0,
                peak_value=1,
                warmup_steps=model_config.warmup_steps,
                decay_steps=epochs * eval_interval,
            ),
            weight_decay=model_config.weight_decay,
            clip_grad_norm=model_config.clip_grad_norm,
            params=params,
        ),
    )

    train_state = flax.jax_utils.replicate(train_state, devices=device_list)

    return train_state, train_dataset, val_datasets
