from typing import Any, Dict, List

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from clu import metrics
from flax import struct
import jax.tree_util

import dihedral
import DFT
from utils import compute_pytree_size
import training

from transformer_class import TransformerOneEmbed, TransformerTwoEmbed


# Generic alias for a string-keyed mapping with arbitrary values.
# NOTE(review): not referenced in the visible part of this file; presumably
# kept for annotating parameter trees elsewhere -- confirm before removing.
Params = Dict[str, Any]


@struct.dataclass
class Metrics(metrics.Collection):
    """Metric collection holding an accuracy metric and two averaged scalars.

    The field names and the ``from_output`` keys below match the keyword
    arguments (``loss``, ``l2_loss``) that ``compute_metrics`` in this file
    passes to ``single_from_model_output``.
    """

    # clu's built-in accuracy metric.
    accuracy: metrics.Accuracy
    # Running average of the model-output entry named "loss".
    loss: metrics.Average.from_output("loss")
    # Running average of the model-output entry named "l2_loss".
    l2_loss: metrics.Average.from_output("l2_loss")


def logits_last_token(logits_3d: jnp.ndarray) -> jnp.ndarray:
    """Select the final entry along axis 1 of a rank-3 array.

    Args:
        logits_3d: Array indexed with three subscripts; presumably laid out
            as (batch, position, vocab) -- TODO confirm against the model.

    Returns:
        The slice ``logits_3d[:, -1, :]``, i.e. axis 1 is dropped and only
        its last element is kept.
    """
    final_position = -1
    return logits_3d[:, final_position, :]


def cross_entropy_loss(y_pred_3d, y):
    """Mean softmax cross-entropy between last-token logits and integer labels.

    Args:
        y_pred_3d: Rank-3 logits; only the last entry along axis 1 is scored
            (see ``logits_last_token``).
        y: Integer class labels, one per row of the selected logits.

    Returns:
        Scalar mean of the per-example cross-entropy values.
    """
    final_logits = logits_last_token(y_pred_3d)
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=final_logits,
        labels=y,
    )
    return per_example.mean()


def total_loss(y_pred_and_l2, y, weight_decay: float):
    """Cross-entropy plus a weighted L2 penalty.

    Args:
        y_pred_and_l2: Three-element sequence ``(logits, aux, l2_loss)``; the
            middle element is ignored. This matches the tuple returned by the
            function built in ``apply_fn_builder``.
        y: Integer class labels forwarded to ``cross_entropy_loss``.
        weight_decay: Multiplier applied to the L2 term.

    Returns:
        ``cross_entropy_loss(logits, y) + l2_loss * weight_decay``.
    """
    logits, _unused_aux, sq_norm = y_pred_and_l2
    data_term = cross_entropy_loss(logits, y)
    penalty = sq_norm * weight_decay
    return data_term + penalty


def apply_fn_builder(model):
    """Wrap ``model.apply`` so it also reports the squared L2 norm of the params.

    Args:
        model: Object exposing ``apply(variables, x, training=...)``.

    Returns:
        A function ``apply(variables, x, training=False)`` returning the tuple
        ``(model_outputs, {}, l2_loss)``, where ``l2_loss`` is the sum of
        squares over every leaf of ``variables["params"]``.
    """

    def apply(variables, x, training=False):
        # Only the "params" collection is forwarded; other collections in
        # `variables` are dropped.
        weights = variables["params"]
        model_out = model.apply({"params": weights}, x, training=training)

        # Accumulate from the integer 0 so an empty parameter tree yields 0.
        sq_norm = 0
        for leaf in jax.tree_util.tree_leaves(weights):
            sq_norm = sq_norm + jnp.sum(jnp.square(leaf))

        return model_out, {}, sq_norm

    return apply


def compute_metrics(metrics_obj, *, loss, l2_loss, outputs, labels):
    """Fold one batch's results into an existing metrics object.

    Args:
        metrics_obj: Object providing ``single_from_model_output`` and
            ``merge`` (e.g. a ``Metrics`` collection).
        loss: Loss value for the batch.
        l2_loss: L2 penalty value for the batch.
        outputs: Rank-3 logits, or a tuple/list whose first element is the
            logits.
        labels: Integer class labels for the batch.

    Returns:
        The result of merging the per-batch metrics into ``metrics_obj``.
    """
    # Model wrappers may return (logits, aux, ...); keep only the logits.
    model_out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    batch_update = metrics_obj.single_from_model_output(
        logits=logits_last_token(model_out),
        labels=labels,
        loss=loss,
        l2_loss=l2_loss,
    )
    return metrics_obj.merge(batch_update)


def build_model(cfg, group_size: int):
    """Instantiate a one-layer ``TransformerOneEmbed`` from a config object.

    Args:
        cfg: Config exposing ``num_mlp_layers``, ``d_model``, ``d_head``,
            ``num_heads``, ``n_ctx``, ``act_type``, ``attn_coeff`` and
            ``nn_multiplier``.
        group_size: Used as the vocabulary size (``d_vocab``).

    Returns:
        The constructed ``TransformerOneEmbed`` instance.

    Raises:
        AssertionError: If ``d_head * num_heads != d_model`` (only when
            assertions are enabled).
    """
    assert cfg.d_head * cfg.num_heads == cfg.d_model

    model_kwargs = {
        "num_layers": 1,
        "num_mlp_layers": cfg.num_mlp_layers,
        "d_vocab": group_size,
        "d_model": cfg.d_model,
        "d_head": cfg.d_head,
        "num_heads": cfg.num_heads,
        "n_ctx": cfg.n_ctx,
        "act_type": cfg.act_type,
        "attn_coeff": cfg.attn_coeff,
        "nn_multiplier": cfg.nn_multiplier,
    }
    return TransformerOneEmbed(**model_kwargs)


def build_optimizer(
    optimizer_name: str,
    lr: float,
    *,
    steps_per_epoch: int,
    epochs: int,
    momentum: float = 0.0,
):
    """Create the optax optimizer selected by ``optimizer_name``.

    For ``"adam"`` the learning rate follows a piecewise-linear schedule: it
    rises from 0 to ``lr`` over the first half of the run and then falls
    linearly back toward 0 over the remainder. The schedule is not clamped,
    so steps beyond ``steps_per_epoch * epochs`` yield a negative rate.

    Args:
        optimizer_name: ``"adam"``, or any name starting with ``"SGD"``
            (case-sensitive).
        lr: Peak learning rate (Adam) or constant learning rate (SGD).
        steps_per_epoch: Optimizer steps in one epoch.
        epochs: Number of epochs; with ``steps_per_epoch`` it fixes the
            schedule length (at least 1 step).
        momentum: Momentum passed to ``optax.sgd``; unused for Adam.

    Returns:
        The optax optimizer.

    Raises:
        ValueError: If ``optimizer_name`` matches neither supported family.
    """
    n_steps = max(1, steps_per_epoch * epochs)
    n_warmup = n_steps // 2
    n_cooldown = n_steps - n_warmup

    warmup_f = float(n_warmup)
    cooldown_f = float(n_cooldown)

    def ramp_up(step_):
        return lr * (step_ / warmup_f)

    def ramp_down(step_):
        return lr * (1.0 - (step_ - warmup_f) / cooldown_f)

    def lr_schedule_fn(step):
        step_f = jnp.asarray(step, jnp.float32)
        in_warmup = step_f < warmup_f
        # lax.cond keeps the branch choice traceable under jit.
        return jax.lax.cond(in_warmup, ramp_up, ramp_down, step_f)

    if optimizer_name == "adam":
        return optax.adam(lr_schedule_fn)

    # NOTE(review): SGD receives the constant `lr`, not the schedule above --
    # confirm this asymmetry with Adam is intentional.
    if optimizer_name.startswith("SGD"):
        return optax.sgd(lr, momentum=momentum)

    raise ValueError(f"Unsupported optimizer type: {optimizer_name}")


def create_states(model, tx, weight_decay: float, batch_size: int, random_seed_ints: List[int]):
    """Build one ``training.TrainState`` per seed and stack them leaf-wise.

    Each seed initialises its own parameters and optimizer state. The
    per-seed states are then combined with ``jnp.stack`` so every array leaf
    gains a new leading axis of length ``len(random_seed_ints)``.

    Args:
        model: Model exposing ``init(rng, x, training=...)`` and ``apply``.
        tx: Optax gradient transformation; ``tx.init`` is called per seed.
        weight_decay: Coefficient of the L2 term in the state's loss function.
        batch_size: Leading dimension of the dummy input used for ``init``.
        random_seed_ints: Integer seeds, one per independent state.

    Returns:
        ``(states, init_metrics)``: the stacked train states and the stacked
        empty metric collections (one per seed).

    Raises:
        ValueError: If ``random_seed_ints`` is empty (there is nothing to
            stack).
    """
    if len(random_seed_ints) == 0:
        raise ValueError("random_seed_ints must contain at least one seed")

    # Dummy int32 input of shape (batch_size, 2) used only to trace `init`.
    # NOTE(review): assumes the model consumes length-2 token sequences -- confirm.
    dummy_x = jnp.zeros((batch_size, 2), dtype=jnp.int32)

    apply = apply_fn_builder(model)

    def loss_hessian_zero(*_):
        # Constant triple of zeros supplied as the state's loss_hessian_fn.
        return (0.0, 0.0, 0.0)

    def loss_fn(y_pred_and_l2, y):
        return total_loss(y_pred_and_l2, y, weight_decay)

    def stack_leaves(*leaves):
        return jnp.stack(leaves)

    states_list = []
    for seed in random_seed_ints:
        # NOTE(review): the same key seeds both parameter init and the
        # state's rng_key (as in the original code); splitting the key would
        # change the random streams, so it is left as-is.
        rng_key = jax.random.PRNGKey(seed)
        # Params and optimizer state are built directly per seed; stacking
        # them first and slicing them back out would yield the same values.
        params = model.init(rng_key, dummy_x, training=False)["params"]
        states_list.append(
            training.TrainState(
                apply_fn=apply,
                params=params,
                tx=tx,
                opt_state=tx.init(params),
                loss_fn=loss_fn,
                loss_hessian_fn=loss_hessian_zero,
                compute_metrics_fn=compute_metrics,
                rng_key=rng_key,
                initial_metrics=Metrics,
                batch_stats=None,
                injected_noise=0.0,
            )
        )

    states = jax.tree_util.tree_map(stack_leaves, *states_list)
    init_metrics = jax.tree_util.tree_map(
        stack_leaves,
        *[s.initial_metrics.empty() for s in states_list],
    )
    return states, init_metrics
