"""
Generates language efficiently from a pretrained model
"""

from typing import Dict, Tuple, Callable
from pathlib import Path
import os
import json
import jax
from jax import numpy as jnp
import numpy as np
import optax
import flax
from datasets import disable_caching, DatasetDict
from transformers import AutoTokenizer, AutoConfig
from latte_trans.trainer.dfsp_jax import (
    DFSDPTrainer,
    shard_module_params,
    get_scheduler,
    gather_params,
)
from latte_trans.preproc.tiny_stories import TinyStories
from latte_trans.preproc.slim_pajama import SlimPajama
from latte_trans.experiments.utils import parse_args
from latte_trans.experiments.base import BaseTask
from latte_trans.config import LMTaskConfig
from latte_trans.models.modules.inference.lm import Gemma as Gemma, LanguageSampler
import flax.linen as nn

# from latte_trans.models.tasks.lm_pret import Gemma


# Forwarded as `min_weight_size` to `shard_module_params` when the sharded model
# is constructed. Presumably parameters smaller than this are left unsharded
# -- TODO confirm against `shard_module_params`.
MIN_SHARD_WEIGHT = 2**10


def zero_grads():
    """Build an optax transformation that replaces every update with zeros.

    Adapted from
    https://github.com/deepmind/optax/issues/159#issuecomment-896459491

    Returns:
        An ``optax.GradientTransformation`` whose state is always the empty
        tuple and whose update step maps each leaf of ``updates`` to a
        zero array of the same shape and dtype.
    """

    def _init(_params):
        # Nothing needs to be remembered between steps.
        return ()

    def _update(updates, state, params=None):
        # `state` and `params` are accepted for API compatibility but ignored.
        zeroed = jax.tree_util.tree_map(jnp.zeros_like, updates)
        return zeroed, ()

    return optax.GradientTransformation(init=_init, update=_update)


def prepare_freeze_optimizer(
    config: LMTaskConfig, total_steps: int
) -> Tuple[optax.GradientTransformation, optax.Schedule]:
    """Build an optimizer that only trains the ``latte_*`` parameters.

    Every parameter is labelled either ``"trainable"`` or ``"frozen"`` based
    on its path in the parameter tree. Trainable parameters are updated with
    AdamW driven by the scheduler returned from ``get_scheduler``; frozen
    parameters receive all-zero updates via ``zero_grads``.

    Args:
        config: Task configuration. ``weight_decay`` and
            ``grad_accumulation_steps`` are read here; the whole object is
            also handed to ``get_scheduler``.
        total_steps: Forwarded to ``get_scheduler``.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    lr_scheduler = get_scheduler(config=config, total_steps=total_steps)

    # A parameter is trainable when its direct parent module has one of these names...
    trainable_parents = (
        "latte_Wq",
        "latte_Wk",
        "latte_conv",
        "latte_lru_in",
        "latte_lru_norm",
        "latte_rglru",
    )
    # ...or when its grandparent module has one of these names.
    trainable_grandparents = ("latte_rglru",)

    def partition_fn(path):
        """Return the optimizer label for one flattened parameter path."""
        # The length guard keeps single-element paths (a parameter stored
        # directly at the root of the tree) from raising IndexError; such
        # parameters have no parent module and are therefore frozen.
        if len(path) >= 2 and path[-2] in trainable_parents:
            return "trainable"
        if len(path) > 2 and path[-3] in trainable_grandparents:
            return "trainable"
        return "frozen"

    def label_params(params):
        """Produce a tree shaped like ``params`` whose leaves are labels."""
        flat_params = flax.traverse_util.flatten_dict(params)
        flat_labels = {path: partition_fn(path) for path in flat_params}
        return flax.traverse_util.unflatten_dict(flat_labels)

    # inject_hyperparams keeps the current learning rate in the optimizer state.
    regular_opt = optax.inject_hyperparams(optax.adamw)(
        learning_rate=lr_scheduler, weight_decay=config.weight_decay
    )

    optimizer = optax.multi_transform(
        {
            "trainable": regular_opt,
            "frozen": zero_grads(),
        },
        label_params,
    )

    if config.grad_accumulation_steps > 1:
        optimizer = optax.MultiSteps(optimizer, config.grad_accumulation_steps)

    # Clip the global gradient norm to 1.0 before the optimizer step.
    # NOTE(review): clipping runs on the full update tree ahead of the
    # per-label split, so gradients of frozen parameters also contribute to
    # the global norm -- confirm this is intended.
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optimizer)
    return optimizer, lr_scheduler


def get_dp(config):
    """Load the tokenizer, the pre-tokenized dataset and a data processor.

    Args:
        config: Task configuration; ``hugg_chk``, ``base_dir`` and
            ``max_seq_len`` are read.

    Returns:
        ``(dp, tokenizer, raw_data, tok_data)`` where ``raw_data`` is always
        ``None`` because the data is loaded already tokenized.
    """
    # NOTE(review): `path_name` is only printed. The tokenizer below is always
    # "google/gemma-2-2b", regardless of `config.hugg_chk` -- confirm intended.
    path_name = config.hugg_chk or "allenai/longformer-base-4096"
    print("Path Name is: ", path_name)

    # A TinyStories pipeline (tokenizer loaded from `path_name`, then
    # `TinyStories.get_raw_data()` / `.tokenize()`) used to live here and is
    # currently disabled in favour of the pre-tokenized data below.
    hugg_cache = os.path.join(config.base_dir, "input/cache_hugg")
    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b",
        cache_dir=hugg_cache,
        truncation_side="right",
        padding_side="right",
    )

    # The tokenized data is loaded directly: tokenizing on the fly is slow.
    raw_data = None
    tok_data = DatasetDict.load_from_disk("/user_all_data/data/input/pajama_tokenized")
    tok_data.set_format("np")

    # Dummy data processor; it is only used for its collate_fn.
    dp = SlimPajama(
        tokenizer,
        cache_dir=None,
        num_load_procs=None,
        max_seq_len=config.max_seq_len,
    )
    return dp, tokenizer, raw_data, tok_data


def get_from_hugg_config(config: LMTaskConfig):
    """Copy model hyper-parameters from a Hugging Face config into ``config``.

    The Hugging Face configuration named by ``config.hugg_chk`` is loaded
    (cached under ``<base_dir>/input/cache_hugg``) and its architecture
    fields are written into a copy of ``config`` via ``config.replace``.

    Args:
        config: Task configuration providing ``hugg_chk`` and ``base_dir``.

    Returns:
        The result of ``config.replace(...)`` with the fields listed below
        overridden.
    """
    cache_dir = Path(config.base_dir) / "input/cache_hugg"
    hugg_config = AutoConfig.from_pretrained(config.hugg_chk, cache_dir=cache_dir)

    # task-config field -> attribute name on the Hugging Face config
    field_sources = {
        "hidden_dim": "hidden_size",
        "intermediate_dim": "intermediate_size",
        "hidden_act": "hidden_act",
        "nlayers": "num_hidden_layers",
        "num_key_value_heads": "num_key_value_heads",
        "pos_embed_max_len": "max_position_embeddings",
        "nheads": "num_attention_heads",
        "head_dim": "head_dim",
        "dropout_att": "attention_dropout",
        "attention_bias": "attention_bias",
        "text_vocab_size": "vocab_size",
    }
    overrides = {
        field: getattr(hugg_config, source) for field, source in field_sources.items()
    }
    return config.replace(**overrides)


class LMInferenceTask(BaseTask):
    """Restore a trained Gemma checkpoint and sample text from it."""

    def __init__(self, config):
        """Resolve model dimensions from the Hugging Face config and load data.

        Args:
            config: Task configuration, passed to ``BaseTask`` and ``get_dp``.
        """
        super().__init__(config)
        self.config = get_from_hugg_config(self.config)
        data_parts = get_dp(config)
        self.dp, self.tokenizer, self.raw_data, self.tok_data = data_parts

    def get_model(self, sharded=True):
        """Instantiate the Gemma module, optionally wrapped for sharding.

        Args:
            sharded: When true, ``Gemma`` is wrapped with
                ``shard_module_params`` over axis ``"B"``; the flag is also
                forwarded to the module itself.

        Returns:
            The constructed model.
        """
        constructor = (
            shard_module_params(Gemma, axis_name="B", min_weight_size=MIN_SHARD_WEIGHT)
            if sharded
            else Gemma
        )
        # The dtype is fixed to float32; `config.mixed_precision` is not
        # consulted here (the bf16 selection is disabled).
        return constructor(
            config=self.config,
            pad_id=self.tokenizer.pad_token_id,
            dtype=jnp.float32,
            sharded=sharded,
        )

    def load_params(self, train_rng):
        """Initialise a trainer state and overwrite it from the checkpoint.

        Args:
            train_rng: PRNG key; it is split and one half seeds both the
                trainer and the parameter initialisation.

        Returns:
            The ``params`` of the state restored from ``config.check_path``.
        """
        _, init_rng = jax.random.split(train_rng, 2)
        sharded_model = self.get_model(sharded=True)
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=sharded_model,
            train_data=self.tok_data["train"],
            data_collator=self.dp.get_collate_fn("np"),
            evaluator=None,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            prepare_optimizer_fn=prepare_freeze_optimizer,
        )

        # All-zero dummy batch, ordered as (input_ids, labels), used only to
        # build an empty state for the restore below.
        batch_shape = (self.config.batch_size, self.config.max_seq_len)
        data = (
            jnp.zeros(batch_shape, dtype=jnp.int32),
            jnp.zeros(batch_shape, dtype=jnp.int32),
        )
        zero_state = trainer._jit_init_fn(
            self.config.batchnorm, sharded_model, trainer._optimizer, init_rng, data
        )

        shape_m = jax.tree.map(
            lambda leaf: leaf.shape, nn.meta.unbox(zero_state.params)
        )
        print("Shape model is: ", json.dumps(shape_m))

        # step_number=None: presumably selects the latest checkpoint -- TODO confirm.
        state, _ = trainer.load_trainer_state(
            empty_state=zero_state,
            check_dir=self.config.check_path,
            step_number=None,
        )
        return state.params

    def inference(self, train_rng):
        """Load the checkpoint, un-shard the parameters and print a sample.

        Args:
            train_rng: PRNG key; one half of its split is passed to
                ``load_params``.
        """
        _, load_rng = jax.random.split(train_rng)
        params = self.load_params(load_rng)
        axis_name = "B"

        def _gather(p):
            """Drop the ``axis_name`` partitioning from a single parameter."""
            if not (isinstance(p, nn.Partitioned) and axis_name in p.names):
                return p
            names = p.names
            shard_axis = names.index(axis_name)
            # The value is fetched with device_get rather than gathered with a
            # jax.lax.all_gather collective over `axis_name`.
            value = jax.device_get(p.value)
            remaining = names[:shard_axis] + (None,) + names[shard_axis + 1 :]
            # If no other axis is partitioned the bare value is enough;
            # otherwise the Partitioned wrapper has to be kept.
            if all(name is None for name in remaining):
                return value
            return nn.Partitioned(value, remaining)

        params = jax.tree_util.tree_map(
            _gather, params, is_leaf=lambda node: isinstance(node, nn.Partitioned)
        )
        print(jax.tree_util.tree_map(lambda leaf: leaf.shape, params))

        sampler = LanguageSampler(
            compile=True, model=self.get_model(sharded=False), params=params
        )
        test_seq = "Tell me a joke"
        token_ids = self.tokenizer(test_seq)["input_ids"]
        print(token_ids)
        # Wrapping the id list in another list adds a leading axis of size one.
        input_ids = jnp.array([token_ids])
        generated = np.array(sampler.generate(input_ids, gen_length=150)["generated"])
        print(generated)
        print("Our model says: ", self.tokenizer.decode(generated[0]))


def main():
    """Parse the CLI arguments, load the task config and run inference."""
    seed = 0
    root_key = jax.random.PRNGKey(seed)
    # Only the second of the three derived keys is consumed.
    _, train_rng, _ = jax.random.split(root_key, 3)

    args = parse_args()
    config = LMTaskConfig.load(
        yaml_file=args.config_file,
        base_dir=args.base_dir,
        name=args.name,
    )

    if config.disable_cache:
        print("Disabling Cache")
        disable_caching()

    LMInferenceTask(config).inference(train_rng)


# Run the inference task only when this file is executed as a script.
if __name__ == "__main__":
    main()
