from __future__ import annotations

import argparse
from functools import partial
from typing import Callable
import copy

import flax
import jax
import jax.numpy as jnp
import optax
from chex import ArrayTree, PRNGKey
from flax.training import train_state
from flax.training.common_utils import shard_prng_key
from jax.tree_util import tree_map_with_path

from modeling import ViT, Zoomer
from utils import (
    get_layer_index_fn,
    timm_to_flax,
)

from training_util import (
    CRITERION_COLLECTION,
    OPTIMIZER_COLLECTION,
)
from training import TrainModule as SupervisedTrainModule
from training_reinforce import TrainModule as ReinforceTrainModule


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with extra RNG streams, teacher weights and
    gradient-accumulation bookkeeping.

    The extra fields are declared without ``pytree_node=False``, so they are
    ordinary pytree leaves/subtrees: they travel through ``jax.pmap`` and are
    copied by ``flax.jax_utils.replicate`` together with the rest of the state.
    """

    # PRNG keys for the three stochastic streams; advanced every step via
    # `split_rngs`.
    mixup_rng: PRNGKey
    dropout_rng: PRNGKey
    patch_selection_rng: PRNGKey
    # Teacher weights carried alongside the trainable `params`; they are not
    # touched by `apply_gradients`, which only updates `params`/`opt_state`.
    teacher_params: ArrayTree

    # Number of micro-batches accumulated since the last optimizer update.
    micro_step: int = 0
    # Number of micro-batches per optimizer update (1 means no accumulation).
    micro_in_mini: int = 1
    # Running sum of micro-batch gradients, or None when accumulation is off.
    grad_accum: ArrayTree | None = None

    def split_rngs(self) -> tuple[ArrayTree, ArrayTree]:
        """Split every stored PRNG key into a key for this step and a carry key.

        Returns:
            A pair ``(rngs, updates)``. ``rngs`` maps the RNG collection names
            ``"mixup"``, ``"dropout"`` and ``"patch_selection"`` to keys to use
            now; ``updates`` maps the matching field names to fresh keys, meant
            to be stored back with ``state.replace(**updates)``.
        """
        mixup_rng, new_mixup_rng = jax.random.split(self.mixup_rng)
        dropout_rng, new_dropout_rng = jax.random.split(self.dropout_rng)
        patch_selection_rng, new_patch_selection_rng = jax.random.split(
            self.patch_selection_rng
        )

        rngs = {
            "mixup": mixup_rng,
            "dropout": dropout_rng,
            "patch_selection": patch_selection_rng,
        }
        updates = {
            "mixup_rng": new_mixup_rng,
            "dropout_rng": new_dropout_rng,
            "patch_selection_rng": new_patch_selection_rng,
        }
        return rngs, updates

    def replicate(self) -> TrainState:
        """Replicate the state onto all local devices for use with ``jax.pmap``.

        Every pytree field, ``teacher_params`` included, is copied to each
        device by ``flax.jax_utils.replicate``. The RNG keys are then replaced
        with ``shard_prng_key`` splits so each device gets its own key.
        """
        # `teacher_params` is already replicated by the call below; replicating
        # it again would only allocate a redundant device copy of the teacher.
        return flax.jax_utils.replicate(self).replace(
            mixup_rng=shard_prng_key(self.mixup_rng),
            dropout_rng=shard_prng_key(self.dropout_rng),
            patch_selection_rng=shard_prng_key(self.patch_selection_rng),
        )


@partial(jax.pmap, axis_name="batch", donate_argnums=0)
def training_step(state: TrainState, batch: ArrayTree) -> tuple[TrainState, ArrayTree]:
    """Run one data-parallel training (micro-)step on every device.

    Args:
        state: Replicated training state. Its buffers are donated to the call.
        batch: Per-device batch, unpacked positionally into ``state.apply_fn``.

    Returns:
        The updated state (with advanced RNG keys) and the metrics dict. The
        metrics are averaged over the ``"batch"`` axis and merged with the
        optimizer's ``opt_state.hyperparams``.
    """

    def loss_fn(params: ArrayTree) -> ArrayTree:
        # `apply_fn` is expected to return a dict of metrics containing "loss";
        # each entry is reduced to its mean before differentiation.
        metrics = state.apply_fn(
            {"params": params}, *batch, rngs=rngs, teacher_params=state.teacher_params
        )
        metrics = jax.tree_util.tree_map(jnp.mean, metrics)
        return metrics["loss"], metrics

    def update_fn(state: TrainState) -> TrainState:
        # Collect a global gradient from the accumulated gradients and apply actual
        # parameter update with resetting the accumulations to zero.
        grads = jax.tree_util.tree_map(
            lambda g: g / state.micro_in_mini, state.grad_accum
        )
        return state.apply_gradients(
            grads=jax.lax.pmean(grads, axis_name="batch"),
            grad_accum=jax.tree_util.tree_map(jnp.zeros_like, state.grad_accum),
            micro_step=state.micro_step % state.micro_in_mini,
        )

    rngs, updates = state.split_rngs()
    (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    metrics = jax.lax.pmean(metrics, axis_name="batch")

    # Update parameters with the gradients. If the gradient accumulation is enabled,
    # then the parameters will be updated at the end of each mini-batch step. In every
    # micro steps, the gradients will be accumulated. The `is None` test is resolved
    # at trace time, so only one of the two paths is compiled.
    if state.grad_accum is None:
        state = state.apply_gradients(grads=jax.lax.pmean(grads, axis_name="batch"))
    else:
        state = state.replace(
            grad_accum=jax.tree_util.tree_map(
                lambda ga, g: ga + g, state.grad_accum, grads
            ),
            micro_step=state.micro_step + 1,
        )
        state = jax.lax.cond(
            state.micro_step == state.micro_in_mini, update_fn, lambda x: x, state
        )
    return state.replace(**updates), metrics | state.opt_state.hyperparams


def create_train_state(args: argparse.Namespace) -> TrainState:
    """Build the student/zoomer training module and its initial ``TrainState``.

    Parameter sources:

    * Student and zoomer parameters start from the timm checkpoint
      ``args.timm_model_name``, converted with ``timm_to_flax``.
    * Their ``head`` entries, and the target shape of the zoomer's
      ``embed/wpe`` (bicubically resized), come from ``module.init``.
    * Zoomer entries whose key contains ``"layer"`` and whose trailing
      ``_``-separated index is ``>= args.zoomer_depth`` are dropped.
    * A separate copy of the timm parameters, with its own head, is stored
      as ``teacher_params``.

    Args:
        args: Parsed experiment configuration.

    Returns:
        An unreplicated ``TrainState`` whose ``params`` are
        ``{"student": ..., "zoomer": ...}``.
    """
    ##### Create student and teacher ViTs #####
    student = ViT(
        layers=args.layers,
        dim=args.dim,
        heads=args.heads,
        num_registers=args.num_registers,
        labels=args.labels,
        patch_size=args.patch_size,
        image_size=args.image_size,
        args=args,
    )

    timm_params = timm_to_flax(args.timm_model_name)
    teacher_params = copy.deepcopy(timm_params)
    # NOTE(review): the teacher head is hard-coded to a (768, 1000) kernel drawn
    # from a fixed PRNGKey(0), not derived from `args.dim` / `args.labels`;
    # confirm this matches the checkpoint's hidden size and the intended teacher.
    teacher_params["head"] = {
        "kernel": jax.random.normal(jax.random.PRNGKey(0), (768, 1000)) / jnp.sqrt(768),
        "bias": jnp.zeros(1000),
    }

    ##### Create zoomer ViTs #####
    zoomer = Zoomer(
        layers=args.zoomer_depth,
        dim=args.dim,
        heads=args.heads,
        num_registers=args.num_registers,
        labels=args.labels,
        patch_size=args.patch_size,
        image_size=args.zoomer_image_size,
    )

    is_reinforcement = "reinforcement" in args.patch_selection_method
    ModuleFunction = ReinforceTrainModule if is_reinforcement else SupervisedTrainModule
    print("Using module function:", ModuleFunction)
    module = ModuleFunction(
        args=args,
        student=student,
        zoomer=zoomer,
        criterion=CRITERION_COLLECTION[args.criterion],
        zoom_map_criterion=CRITERION_COLLECTION[args.zoom_map_criterion],
    )

    # Dummy single-example inputs, used only to trace `module.init` and obtain
    # parameter shapes; their values are irrelevant.
    example_inputs = {
        "images": jnp.zeros((1, 3, args.image_size, args.image_size), dtype=jnp.uint8),
        "labels": jnp.zeros((1), dtype=jnp.uint8),
        "k_patches": jnp.full(
            (1),
            args.top_k if args.top_k is not None else args.top_k_range.min,
            dtype=jnp.int32,
        ),
        "gumbel_temperature": jnp.full((1), 0.1, dtype=jnp.float32),
        "gumbel_noise_coeff": jnp.full((1), 0.1, dtype=jnp.float32),
        "rl_coeff": jnp.full((1), 0.1, dtype=jnp.float32),
        "kl_coeff": jnp.full((1), 0.1, dtype=jnp.float32),
        "teacher_params": teacher_params,
    }
    init_rngs = {"params": jax.random.PRNGKey(args.init_seed)}
    init_params = module.init(init_rngs, **example_inputs)["params"]

    # Student: pretrained backbone + freshly initialized head.
    student_params = copy.deepcopy(timm_params)
    student_params["head"] = init_params["student"]["head"]

    # Zoomer: pretrained backbone with a head from `module.init` and position
    # embeddings resized to the zoomer's expected shape.
    zoomer_params = copy.deepcopy(timm_params)
    zoomer_params["head"] = init_params["zoomer"]["head"]
    zoomer_params["embed"]["wpe"] = jax.image.resize(
        zoomer_params["embed"]["wpe"],
        init_params["zoomer"]["embed"]["wpe"].shape,
        method="bicubic",
    )

    # Truncate the zoomer backbone to its first `args.zoomer_depth` layers.
    print(f"zoomer keys before removing: {zoomer_params.keys()}")
    keys_to_remove = [
        key
        for key in zoomer_params
        if "layer" in key and int(key.split("_")[-1]) >= args.zoomer_depth
    ]
    for key in keys_to_remove:
        del zoomer_params[key]
    print(f"zoomer keys after removing: {zoomer_params.keys()}")

    # Combine student and zoomer parameters into one dictionary.
    combined_params = {"student": student_params, "zoomer": zoomer_params}

    # Zero-initialized gradient buffer, only needed when accumulating over
    # several micro-batches; `None` disables accumulation in `training_step`.
    grad_accum = (
        jax.tree_util.tree_map(jnp.zeros_like, combined_params)
        if args.grad_accum > 1
        else None
    )

    # Create learning rate scheduler and optimizer with gradient clipping. The learning
    # rate will be recorded at `hyperparams` by `optax.inject_hyperparameters`.
    @partial(optax.inject_hyperparams, hyperparam_dtype=jnp.float32)
    def create_optimizer_fn(
        learning_rate: optax.Schedule,
    ) -> optax.GradientTransformation:
        tx = OPTIMIZER_COLLECTION[args.optimizer.name](
            learning_rate=learning_rate,
            b1=args.optimizer.betas[0],
            b2=args.optimizer.betas[1],
            eps=float(args.optimizer.eps),
            weight_decay=args.optimizer.weight_decay,
            # Apply weight decay only to leaves whose final path key is "kernel".
            mask=partial(tree_map_with_path, lambda kp, *_: kp[-1].key == "kernel"),
        )
        if args.optimizer.lr_decay < 1.0:
            # Layer-wise LR decay: label i is scaled by lr_decay ** (layers - i).
            layerwise_scales = {
                i: optax.scale(args.optimizer.lr_decay ** (args.layers - i))
                for i in range(args.layers + 1)
            }
            label_fn = partial(get_layer_index_fn, num_layers=args.layers)
            label_fn = partial(tree_map_with_path, label_fn)
            tx = optax.chain(tx, optax.multi_transform(layerwise_scales, label_fn))
        if args.optimizer.clip_grad > 0:
            tx = optax.chain(optax.clip_by_global_norm(args.optimizer.clip_grad), tx)
        return tx

    learning_rate = optax.warmup_cosine_decay_schedule(
        init_value=1e-6,
        peak_value=float(args.optimizer.lr),
        warmup_steps=args.warmup_steps,
        decay_steps=args.training_steps,
        end_value=1e-5,
    )
    return TrainState.create(
        apply_fn=module.apply,
        params=combined_params,
        tx=create_optimizer_fn(learning_rate),
        # Offset seeds by the process index so each host draws different noise.
        mixup_rng=jax.random.PRNGKey(args.mixup_seed + jax.process_index()),
        dropout_rng=jax.random.PRNGKey(args.dropout_seed + jax.process_index()),
        patch_selection_rng=jax.random.PRNGKey(
            args.patch_select_seed + jax.process_index()
        ),
        micro_step=0,
        micro_in_mini=args.grad_accum,
        grad_accum=grad_accum,
        teacher_params=teacher_params,
    )
