from collections import defaultdict
from contextlib import contextmanager
from fnmatch import fnmatch
import logging
import time
from typing import Callable, List, Optional, Union

import flax
from flax import struct
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
from ml_collections import ConfigDict
import numpy as np
import optax

from octo.data.utils.text_processing import TextProcessor
from octo.model.octo_model import OctoModel
from octo.utils import jax_utils
from octo.utils.typing import Config, Data, Params, PRNGKey

from hypervla.model import HyperVLA
from octo.model.octo_model import OctoModel
from hypervla.base_model import BaseModel

@struct.dataclass
class TrainState:
    """Training state: RNG key, model, step counter and optimizer state.

    Instances are not mutated; `apply_gradients` returns a new state built
    with `replace`.
    """

    rng: PRNGKey
    model: Union[HyperVLA, OctoModel, BaseModel]
    step: int
    opt_state: optax.OptState
    # Declared with pytree_node=False, i.e. excluded from the pytree leaves.
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls,
        rng: PRNGKey,
        model: Union[HyperVLA, OctoModel, BaseModel],
        tx: optax.GradientTransformation,
    ):
        """Build the initial state (step 0) for `model` optimized by `tx`.

        Args:
            rng: PRNG key stored on the state.
            model: model whose `params` are used to initialize the optimizer.
            tx: Optax gradient transformation.

        Returns:
            A new `TrainState` with a freshly initialized optimizer state.
        """
        initial_opt_state = tx.init(model.params)
        return cls(
            rng=rng, model=model, step=0, opt_state=initial_opt_state, tx=tx
        )

    def apply_gradients(self, *, grads, rng):
        """Apply one optimizer update and return the resulting state.

        Args:
            grads: gradients with respect to `self.model.params`.
            rng: PRNG key to store on the returned state.

        Returns:
            A new `TrainState` with updated params and optimizer state, and
            `step` incremented by one.
        """
        current_params = self.model.params
        param_updates, next_opt_state = self.tx.update(
            grads, self.opt_state, current_params
        )
        updated_model = self.model.replace(
            params=optax.apply_updates(current_params, param_updates)
        )
        return self.replace(
            step=self.step + 1,
            model=updated_model,
            opt_state=next_opt_state,
            rng=rng,
        )


def format_name_with_config(name, config):
    """Formats a name string with a (possibly nested) config dict.

    Formatting keys may be specified as {key} (the innermost dict key) or
    {full_path_to_key_with_underscores}. For backward compatibility, the last
    underscore-separated token of the full path is also accepted as an alias.
    When aliases collide, the full path wins over the leaf key, which wins
    over the legacy last-token alias.

    Example:
        name = "model_{model_type}_{model_size}"
        config = {"model_type": "transformer", "model_size": "small"}
        format_name_with_config(name, config) -> "model_transformer_small"

    Args:
        name: format string using `str.format` replacement fields.
        config: nested dict of values to substitute.

    Returns:
        The formatted string.

    Raises:
        KeyError: if `name` references a key that is not present in `config`.
    """
    # Flatten to tuple paths so the real leaf key is available; splitting the
    # joined path on "_" would truncate leaf keys that contain underscores.
    values_by_path = flax.traverse_util.flatten_dict(config)
    by_full_path = {"_".join(path): v for path, v in values_by_path.items()}
    by_leaf_key = {path[-1]: v for path, v in values_by_path.items()}
    # Legacy alias kept so existing format strings continue to resolve.
    by_last_token = {k.split("_")[-1]: v for k, v in by_full_path.items()}
    format_dict = {**by_last_token, **by_leaf_key, **by_full_path}
    return name.format(**format_dict)


class Timer:
    """
    Timer utility that accumulates elapsed time per key. Usage:

        timer = Timer()
        with timer("foo"):
            do_something()

        timer.tick("bar")
        do_something_else()
        timer.tock("bar")

        timer.get_average_times() -> {"foo": 0.1, "bar": 0.2}

    Intervals are measured with `time.perf_counter()`, so they are not
    affected by adjustments to the system clock.
    """

    def __init__(self):
        self.reset()

    @contextmanager
    def __call__(self, key):
        """Context manager that times the enclosed block under `key`.

        The interval is recorded even if the block raises.
        """
        self.tick(key)
        try:
            yield None
        finally:
            self.tock(key)

    def reset(self):
        """Discard all accumulated statistics and any in-progress timers."""
        self.counts = defaultdict(int)
        self.times = defaultdict(float)
        self.start_times = {}

    def tick(self, key):
        """Start timing `key`.

        Raises:
            ValueError: if `key` is already being timed.
        """
        if key in self.start_times:
            raise ValueError(f"Timer is already ticking for key: {key}")
        self.start_times[key] = time.perf_counter()

    def tock(self, key):
        """Stop timing `key` and add the elapsed interval to its totals.

        Raises:
            ValueError: if `key` is not currently being timed.
        """
        if key not in self.start_times:
            raise ValueError(f"Timer is not ticking for key: {key}")
        self.counts[key] += 1
        self.times[key] += time.perf_counter() - self.start_times.pop(key)

    def get_average_times(self, reset=True):
        """Return the mean interval in seconds for every completed key.

        Args:
            reset: if True (default), call `reset()` after computing the
                averages. Note that this also drops in-progress timers.

        Returns:
            Dict mapping each key to total time divided by number of intervals.
        """
        ret = {key: self.times[key] / self.counts[key] for key in self.counts}
        if reset:
            self.reset()
        return ret


def batched_apply(fn, batch_size):
    """Turns a function that applies to a fixed batch size into one that applies to a variable batch size.
    Useful for passing variable batch sizes to jit-compiled functions.

    The input is processed in chunks of `batch_size` along the leading axis.
    The final chunk is zero-padded up to `batch_size` before calling `fn`, and
    the padding is stripped from its output.

    Args:
        fn: function that accepts array pytrees with leading dim `batch_size`.
        batch_size: the fixed batch size `fn` expects.

    Returns:
        A wrapper with the same call signature as `fn`, returning the
        per-chunk outputs concatenated along axis 0 as numpy arrays.
    """

    def pad_to_size(arr, size):
        # Pad only the leading (batch) axis, at the end.
        return np.pad(arr, ((0, size - len(arr)), *[(0, 0)] * (arr.ndim - 1)))

    def get_batch_size(tree):
        # The leading dim of the first leaf is taken as the batch size.
        return next(iter(jax.tree_util.tree_leaves(tree))).shape[0]

    def wrapped_fn(*args, **kwargs):
        input_batch_size = get_batch_size((args, kwargs))
        # The loop below runs ceil(input_batch_size / batch_size) times, so
        # that (not the floor) is what must agree across hosts.
        num_iterations = -(-input_batch_size // batch_size)
        multihost_utils.assert_equal(
            num_iterations,
            "batched_apply has been called with arguments that would lead to"
            " a different number of iterations on different hosts."
            f" got batch_size={batch_size}, input_batch_size={input_batch_size}"
            f" on host {jax.process_index()}.",
        )
        outputs = []
        for start in range(0, input_batch_size, batch_size):
            step_batch_size = min(batch_size, input_batch_size - start)
            step_args, step_kwargs = jax.tree_util.tree_map(
                lambda arr: pad_to_size(
                    arr[start : start + batch_size], batch_size
                ),
                (args, kwargs),
            )
            step_args, step_kwargs = jax_utils.merge_along_axis(
                (step_args, step_kwargs)
            )
            step_output = fn(*step_args, **step_kwargs)
            step_output = jax.device_get(jax_utils.split_along_axis(step_output))
            # Drop the rows that correspond to padding.
            outputs.append(
                jax.tree_util.tree_map(
                    lambda arr: arr[:step_batch_size],
                    step_output,
                )
            )
        return jax.tree_util.tree_map(
            lambda *chunks: np.concatenate(chunks, axis=0), *outputs
        )

    return wrapped_fn


def filter_eval_datasets(dataset_kwargs_list, sample_weights, eval_datasets=None):
    """Restrict datasets (and their sampling weights) to those named in `eval_datasets`.

    Args:
        dataset_kwargs_list: list of dataset kwargs dicts, each with a "name" key.
        sample_weights: per-dataset weights, or None for a weight of 1.0 each.
        eval_datasets: names to keep. None keeps everything; an empty
            collection keeps nothing.

    Returns:
        A pair `(dataset_kwargs, sample_weights)`. The result always has
        exactly two elements, including when no dataset name matches.
    """
    if sample_weights is None:
        sample_weights = [1.0] * len(dataset_kwargs_list)
    if eval_datasets is None:
        return dataset_kwargs_list, sample_weights
    if len(eval_datasets) == 0:
        return [], []

    kept_kwargs, kept_weights = [], []
    for dkwargs, weight in zip(dataset_kwargs_list, sample_weights):
        if dkwargs["name"] in eval_datasets:
            kept_kwargs.append(dkwargs)
            kept_weights.append(weight)
    # Built explicitly rather than via zip(*pairs), which collapses to an
    # empty list (not two empty lists) when nothing matches.
    return [kept_kwargs, kept_weights]


def create_lr_schedule(name: str, **kwargs):
    """Build a learning-rate schedule (a callable of the step) by name.

    Supported schedules:
        cosine: cosine decay with warmup.
            kwargs: init_value, peak_value, warmup_steps, decay_steps
        rsqrt: inverse square root decay with warmup, from the "Scaling Vision Transformers" paper.
            kwargs: init_value, peak_value, warmup_steps, timescale (optional, default 10000)
        constant: constant learning rate with warmup.
            kwargs: init_value, peak_value, warmup_steps

    Args:
        name: name of the schedule
        **kwargs: schedule-specific settings, as listed above

    Raises:
        ValueError: if `name` is not one of the supported schedules.
    """
    if name == "cosine":
        return optax.warmup_cosine_decay_schedule(**kwargs)
    if name not in ("rsqrt", "constant"):
        raise ValueError(f"Unsupported lr schedule: {name}")

    # Both remaining schedules are a linear warmup followed by a second phase.
    if name == "rsqrt":
        timescale = kwargs.get("timescale", 10000)

        def after_warmup(step):
            return kwargs["peak_value"] / jnp.sqrt((step + timescale) / timescale)

    else:

        def after_warmup(step):
            return kwargs["peak_value"]

    warmup = optax.linear_schedule(
        init_value=kwargs["init_value"],
        end_value=kwargs["peak_value"],
        transition_steps=kwargs["warmup_steps"],
    )
    boundaries = [kwargs["warmup_steps"]]
    return optax.join_schedules([warmup, after_warmup], boundaries)


def freeze_weights(
    tx: optax.GradientTransformation,
    params_or_params_shape: Params,
    frozen_keys: List[str],
    return_partitions: bool = False,
):
    """
    Freezes all weights in params_or_params_shape whose keys fnmatch the ones in frozen_keys.
    Example usage:
        tx = freeze_weights(tx, model.params, ["octo_transformer.*"])

    Args:
        tx: optimizer applied to the parameters that remain trainable.
        params_or_params_shape: nested dict of parameters (or their shapes);
            leaves must expose `.size`.
        frozen_keys: fnmatch patterns matched against the "."-joined path of
            each parameter.
        return_partitions: if True, also return the "trainable"/"frozen" label tree.

    Returns:
        The wrapped optimizer, or `(optimizer, param_partitions)` when
        `return_partitions` is True.
    """
    logging.info(f"Freezing parameters that include the following keys: {frozen_keys}.")
    partition_optimizers = {
        "trainable": tx,
        "frozen": optax.set_to_zero(),
    }

    # freeze anything that matches fnmatch patterns in `frozen_keys`
    # path is a string of .-separated module names, e.g. ('octo_transformer.BlockTransformer_0...')
    def label_for(path, _):
        dotted_path = ".".join(path)
        if any(fnmatch(dotted_path, pattern) for pattern in frozen_keys):
            return "frozen"
        return "trainable"

    param_partitions = flax.traverse_util.path_aware_map(
        label_for, params_or_params_shape
    )
    tx = optax.multi_transform(partition_optimizers, param_partitions)

    logging.debug("Frozen params:")
    for path, opt_status in flax.traverse_util.flatten_dict(param_partitions).items():
        if opt_status == "frozen":
            logging.debug(".".join(path))

    total_params = sum(
        x.size for x in jax.tree_util.tree_leaves(params_or_params_shape)
    )
    # jax.tree_util.tree_map is used instead of the deprecated jax.tree_map alias.
    trainable_params = sum(
        jax.tree_util.tree_leaves(
            jax.tree_util.tree_map(
                lambda x, y: x.size if y == "trainable" else 0,
                params_or_params_shape,
                param_partitions,
            )
        )
    )
    logging.info(f"Num trainable params: {trainable_params:,}.")
    logging.info(f"Num frozen params: {total_params - trainable_params:,}.")
    logging.info("To see a detailed list of frozen params, set logging level to DEBUG.")
    return (tx, param_partitions) if return_partitions else tx


def create_optimizer(
    params_or_params_shape: Params, HN_param_type, weight_decay_strategy='v1', **kwargs: dict
) -> tuple:
    """Creates optimizer for Octo.

    kwargs are the kwargs for optax.adamw; if the "learning_rate" key is a dict, it is interpreted
    as the kwargs for create_lr_schedule (see above), otherwise it is interpreted as a constant
    learning rate. "base_learning_rate" is interpreted the same way and is used for the
    "shared" parameters; when absent or None, the main learning rate is used for them too.

    If clip_gradient is specified, then gradient clipping is applied. If frozen_keys is specified,
    then those parameters are frozen (i.e. not updated) during training.
    If grad_accumulation_steps > 1, the optimizer is wrapped in optax.MultiSteps.

    Args:
        params_or_params_shape: nested dict of parameters (or their shapes).
        HN_param_type: tree with the same structure as the params whose leaves are
            compared against "generated" and "shared"; if any leaf is "shared", each
            group gets its own AdamW (learning rate, weight decay and mask).
        weight_decay_strategy: 'v5', 'v3' or 'v2' select the corresponding mask below;
            any other value applies weight decay only to paths containing "kernel".
        **kwargs: see above. Also consumed here: weight_decay, base_weight_decay.

    Returns:
        tx: an Optax optimizer
        lr_callable: Function that takes the current step and returns the learning rate
        base_lr_callable: same, for the "shared" parameters
        param_norm_callable: Function that takes params and returns their global norm
            (frozen params are excluded when frozen_keys is given)
    """

    def make_lr_callable(spec):
        if isinstance(spec, dict):
            return create_lr_schedule(**spec)
        # Capture the value itself: the entry is popped from kwargs, so a closure
        # that looked it up in kwargs when called would raise KeyError.
        return lambda _: spec

    lr_callable = make_lr_callable(kwargs.pop("learning_rate"))
    # learning rate schedule for shared params
    # Always pop, so an explicit base_learning_rate=None is not forwarded to optax.adamw.
    base_lr_spec = kwargs.pop("base_learning_rate", None)
    base_lr_callable = (
        lr_callable if base_lr_spec is None else make_lr_callable(base_lr_spec)
    )

    # Do not apply weight decay to normalization layers in the HN
    def filter_weight_decay(path, x):
        path_str = jax.tree_util.keystr(path)
        return not ("norm" in path_str.lower() and "output_head" not in path_str)

    # For the context encoder in HN, only apply weight decay to "kernel" (same as Octo's default strategy)
    # For the output heads in HN, only apply weight decay to the heads that generate "kernel" in the base net,
    # so that the base net can be regularized in the same way as Octo
    def weight_decay_v3(path, x):
        top_level_key = path[0].key
        if "output_head" in top_level_key:
            return "kernel" in top_level_key
        path_str = jax.tree_util.keystr(path)
        # apply delta decay to all the params in the DINO image encoder;
        # for the remaining parts, only apply weight decay to kernel
        return "image_encoder" in path_str or "kernel" in path_str

    # For the HN params, apply WD to output heads that generate kernels in the base network
    # No WD to context encoder
    def weight_decay_v5(path, x):
        top_level_key = path[0].key
        if "output_head" in top_level_key:
            return "kernel" in top_level_key
        return "image_encoder" in jax.tree_util.keystr(path)

    # Following ViT, timm, MAE: this mask skips weight decay on biases and LayerNorm parameters
    def kernel_only(path, x):
        return "kernel" in jax.tree_util.keystr(path)

    mask_fn_by_strategy = {
        'v5': weight_decay_v5,
        'v3': weight_decay_v3,
        'v2': filter_weight_decay,
    }
    wd_mask = jax.tree_util.tree_map_with_path(
        mask_fn_by_strategy.get(weight_decay_strategy, kernel_only),
        params_or_params_shape,
    )

    clip_gradient = kwargs.pop("clip_gradient", None)
    frozen_keys = kwargs.pop("frozen_keys", None)
    grad_accumulation_steps = kwargs.pop("grad_accumulation_steps", 1)
    # NOTE(review): when these keys are absent, None (not optax's default) is
    # forwarded as `weight_decay` to optax.adamw below — confirm this is intended.
    weight_decay = kwargs.pop("weight_decay", None)
    base_weight_decay = kwargs.pop("base_weight_decay", None)

    def pretty_print_nested_dict(nested_dict, generation_strategy):
        # Prints the mask tree; for decayed leaves also prints the coefficient in use.
        def print_node(node, generation, depth):
            prefix = '-' * depth * 2
            for key in node:
                if isinstance(node[key], dict):
                    print (f'{prefix}{key}')
                    print_node(node[key], generation[key], depth + 1)
                elif node[key]:
                    if generation[key] == "generated":
                        WD_coefficient = weight_decay
                    else:
                        WD_coefficient = base_weight_decay
                    print (f'{prefix}{key}: {generation[key]}, WD = {node[key]}, {WD_coefficient}')
                else:
                    print (f'{prefix}{key}: {generation[key]}, WD = {node[key]}')
        print_node(nested_dict, generation_strategy, 0)

    pretty_print_nested_dict(wd_mask, HN_param_type)

    if "shared" not in jax.tree_util.tree_leaves(HN_param_type):
        tx = optax.adamw(
            mu_dtype=jnp.bfloat16,
            **kwargs,
            learning_rate=lr_callable,
            mask=wd_mask,
            weight_decay=weight_decay,
        )
    else:
        # Each group's mask is the base mask restricted to that group's leaves.
        # jax.tree_util.tree_map is used instead of the deprecated jax.tree_map alias.
        HN_weight_decay_mask = jax.tree_util.tree_map(
            lambda x, y: x & (y == "generated"), wd_mask, HN_param_type
        )
        base_weight_decay_mask = jax.tree_util.tree_map(
            lambda x, y: x & (y == "shared"), wd_mask, HN_param_type
        )
        partition_optimizers = {
            "generated": optax.adamw(
                mu_dtype=jnp.bfloat16,
                **kwargs,
                learning_rate=lr_callable,
                mask=HN_weight_decay_mask,
                weight_decay=weight_decay,
            ),
            "shared": optax.adamw(
                mu_dtype=jnp.bfloat16,
                **kwargs,
                learning_rate=base_lr_callable,
                mask=base_weight_decay_mask,
                weight_decay=base_weight_decay,
            ),
        }
        tx = optax.multi_transform(partition_optimizers, HN_param_type)
    if grad_accumulation_steps > 1:
        tx = optax.MultiSteps(tx, grad_accumulation_steps)
    if clip_gradient is not None:
        tx = optax.chain(
            optax.clip_by_global_norm(clip_gradient),
            tx,
        )

    if frozen_keys:
        tx, param_partitions = freeze_weights(
            tx, params_or_params_shape, frozen_keys, return_partitions=True
        )

        def zero_frozen_params(params):
            # Replace frozen leaves by a scalar zero so they do not contribute to the norm.
            return jax.tree_util.tree_map(
                lambda x, y: x if y == "trainable" else jnp.zeros(()),
                params,
                param_partitions,
            )

        def param_norm_callable(params):
            return optax.global_norm(zero_frozen_params(params))

    else:
        param_norm_callable = optax.global_norm

    return tx, lr_callable, base_lr_callable, param_norm_callable


def check_config_diff(new_conf: Config, old_conf: Config, silent: bool = False):
    """Compare two configs and report how the new one differs from the old one.

    Args:
        new_conf: new config (dict or ConfigDict).
        old_conf: old config (dict or ConfigDict).
        silent: if True, nothing is logged.

    Returns:
        The dict `{flat_key: (new_value, old_value)}` of changed values if it is
        non-empty; otherwise a bool telling whether the two key sets differ.
    """

    def _flatten(conf):
        plain = conf.to_dict() if isinstance(conf, ConfigDict) else conf
        return flax.traverse_util.flatten_dict(plain)

    new_flat = _flatten(new_conf)
    old_flat = _flatten(old_conf)
    new_keys = set(new_flat.keys())
    old_keys = set(old_flat.keys())
    keys_differ = new_keys != old_keys

    # Report keys present on only one side.
    if keys_differ and not silent:
        logging.info("New config contains extra items: %s", new_keys - old_keys)
        logging.info("New config doesn't contain items: %s", old_keys - new_keys)

    # Collect keys present on both sides whose values differ.
    mismatched_keys = {}
    for flat_key, new_value in new_flat.items():
        if flat_key in old_flat and new_value != old_flat[flat_key]:
            mismatched_keys[flat_key] = (new_value, old_flat[flat_key])

    if mismatched_keys and not silent:
        logging.info(
            "New config contains keys with new values: %s",
            flax.core.pretty_repr(mismatched_keys),
        )
    return mismatched_keys or keys_differ


def merge_params(target_params: Params, pretrained_params: Params) -> Params:
    """Overwrite entries of target_params with pre-trained values where key and shape both match.

    Target entries that are missing from the pre-trained params, or whose shape
    differs, are left untouched and reported via logging.

    Args:
        target_params: nested dict of parameters to fill in.
        pretrained_params: nested dict of pre-trained parameters.

    Returns:
        A nested dict with the structure of `target_params`.
    """
    flat_target = flax.traverse_util.flatten_dict(target_params)
    flat_pretrained = flax.traverse_util.flatten_dict(pretrained_params)

    # Sort every target key into exactly one of three groups.
    keys_to_update, missing_keys, shape_mismatch_keys = [], [], []
    for path, target_value in flat_target.items():
        if path not in flat_pretrained:
            missing_keys.append(path)
        elif target_value.shape == flat_pretrained[path].shape:
            keys_to_update.append(path)
        else:
            shape_mismatch_keys.append(path)

    for path in keys_to_update:
        logging.debug(f"Param copied from pre-trained: {'.'.join(path)}")
    if missing_keys or shape_mismatch_keys:
        logging.info("########## Parameters skipped during model loading: ##########")
        for path in missing_keys:
            logging.info(
                f"Param missing in pre-trained model, skipping: {'.'.join(path)}"
            )
        for path in shape_mismatch_keys:
            logging.info(
                f"Param with differing shape in pre-trained model, skipping: {'.'.join(path)}"
            )

    replacements = {path: flat_pretrained[path] for path in keys_to_update}
    merged_flat = flax.core.copy(flat_target, replacements)
    return flax.traverse_util.unflatten_dict(merged_flat)


def process_text(batch: Data, text_processor: Optional[TextProcessor]) -> Data:
    """Encode the language instructions of a batch in place and return the batch.

    Expects batch to be a nested dictionary, where
        batch["task"]["language_instruction"] is a sequence of byte strings

    With no text processor, the language instruction is removed from
    batch["task"]. Otherwise the raw instructions are kept under
    batch["task"]["instruction_string"], the encoded ones replace
    batch["task"]["language_instruction"], and, if the batch has a
    "rephrased_task" entry, its "language_instruction" is encoded as well.
    """
    task = batch["task"]
    if text_processor is None:
        task.pop("language_instruction")
        return batch

    def _encode(raw_instructions):
        decoded = [s.decode("utf-8") for s in raw_instructions]
        return text_processor.encode(decoded)

    raw = task["language_instruction"]
    task["instruction_string"] = raw
    task["language_instruction"] = _encode(raw)

    if "rephrased_task" in batch:
        rephrased = batch["rephrased_task"]
        rephrased["language_instruction"] = _encode(
            rephrased["language_instruction"]
        )
    return batch


# Type alias for a weight-loading function: takes `Params` and returns `Params`.
WeightLoader = Callable[[Params], Params]


def hf_weights_loader(params, hf_model):
    """Loads weights from a HuggingFace model into params.

    Every entry named "hf_model" found while walking `params` is replaced
    (in place) by the parameters of the pre-trained HuggingFace model.

    Args:
        params: nested mutable mapping of parameters; modified in place.
        hf_model: HuggingFace model identifier passed to `from_pretrained`.
            If it contains "t5", the model is loaded as a T5 encoder.

    Returns:
        The same `params` object, after replacement.

    Raises:
        AssertionError: if no "hf_model" entry was found in `params`.
    """
    # Imported lazily so that `transformers` is only required when this loader is used.
    from transformers import AutoConfig, FlaxAutoModel, FlaxT5EncoderModel

    if "t5" in hf_model:
        config = AutoConfig.from_pretrained(hf_model)
        model = FlaxT5EncoderModel.from_pretrained(hf_model, config=config)
    else:
        model = FlaxAutoModel.from_pretrained(hf_model)

    model_variables = model.params
    replaced = False

    def find_and_replace(node, key, replacement):
        nonlocal replaced
        for k in node.keys():
            if k == key:
                node[k] = replacement
                print(f"Replaced {key} in params")
                replaced = True
                return
            # Only descend into children of the same mapping type as their parent.
            if isinstance(node[k], type(node)):
                find_and_replace(node[k], key, replacement)

    find_and_replace(params, "hf_model", model_variables)
    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with -O.
    if not replaced:
        raise AssertionError("Failed to load weights")
    return params
