from typing import Dict, Tuple, Callable
from pathlib import Path
from functools import partial
import itertools
import math
import os
import json
import wandb
import jax
from jax import numpy as jnp
from torch.utils.data import DataLoader
import numpy as np
from datasets import disable_caching, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoConfig,
    LlamaConfig,
    LlamaForCausalLM,
    Gemma2ForCausalLM,
    Gemma2Model,
    AutoModelForCausalLM,
)

from latte_trans.trainer.dfsp_jax import (
    DFSDPTrainer,
    gather_params,
    shard_params,
    shard_module_params,
    BatchNormTrainState,
    TrainState,
    get_scheduler,
)
from recurrentgemma import jax as recurrentgemma

from latte_trans.experiments.utils import parse_args
from latte_trans.experiments.base import BaseTask
from latte_trans.preproc.tiny_stories import TinyStories
from latte_trans.preproc.slim_pajama import SlimPajama
from latte_trans.preproc.lm_dp import Wiki103DP
from latte_trans.preproc.openweb import OpenWeb
from latte_trans.config import LMTaskConfig
from latte_trans.evals.lang_eval import LanguageEvaluator, LanguageEvaluatorSeq
from latte_trans.models.tasks.lm_pret import Gemma
from latte_trans.preproc.book_corpus import BookCorpusLong
from latte_trans.preproc.pile import PileTok2

from flax import linen as nn
import optax
from flax.training import train_state
import flax

# Value passed as `min_weight_size` to `shard_module_params` in
# `LMPretTask.get_model`; presumably parameters smaller than this stay
# unsharded — confirm against latte_trans.trainer.dfsp_jax.
MIN_SHARD_WEIGHT = 1 << 10  # == 1024


def zero_grads():
    """Return a stateless optax transformation that zeroes every update.

    NOTE: this module defines `zero_grads` a second time further down; that
    later definition shadows this one. Both now use `jax.tree_util.tree_map`
    (the top-level `jax.tree_map` alias is deprecated/removed in newer JAX).

    Adapted from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    """

    def init_fn(_):
        # No optimizer state is needed to emit zeros.
        return ()

    def update_fn(updates, state, params=None):
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)


def map_nested_fn(fn):
    """Build a mapper that applies ``fn(key, value)`` to every leaf of a nested dict.

    Values exposing a ``keys`` attribute are treated as nested mappings and are
    recursed into. Every other value is replaced by ``fn(key, value)``, where
    ``key`` is the leaf's own (innermost) key. The mapper returns a fresh
    ``dict`` with the same nesting and key order as its input. This is useful
    for building optax label trees.
    """

    def _apply(tree):
        mapped = {}
        for key, value in tree.items():
            if hasattr(value, "keys"):
                mapped[key] = _apply(value)
            else:
                mapped[key] = fn(key, value)
        return mapped

    return _apply


def zero_grads():
    """Return an optax transformation that discards every update.

    The transformation carries no state (``()``). It maps each update leaf to
    a zero array with the same shape and dtype, so parameters routed to it
    receive zero updates.

    Adapted from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    """

    def _init(params):
        del params  # stateless
        return ()

    def _update(updates, state, params=None):
        zeroed = jax.tree_util.tree_map(jnp.zeros_like, updates)
        return zeroed, ()

    return optax.GradientTransformation(_init, _update)


def prepare_freeze_optimizer(
    config: LMTaskConfig, total_steps: int
) -> Tuple[optax.GradientTransformation, optax.Schedule]:
    """Build an optimizer that updates only the Latte modules and freezes the rest.

    A parameter is labelled ``"trainable"`` in two cases:

    - its parent module (second-to-last path element) is one of the Latte
      modules listed in ``trainable_parents``; or
    - it sits under a ``latte_rglru`` submodule (third-to-last path element).

    Trainable parameters are optimised with AdamW on the schedule returned by
    ``get_scheduler``. All other parameters are labelled ``"frozen"`` and
    routed to ``zero_grads()``, which emits zero updates.

    When ``config.grad_accumulation_steps > 1`` the partitioned optimizer is
    wrapped in ``optax.MultiSteps``. The result is then chained after global
    gradient-norm clipping at 1.0.

    Args:
        config: Task config. Reads ``weight_decay`` and
            ``grad_accumulation_steps``, and is forwarded to ``get_scheduler``.
        total_steps: Total number of optimisation steps, forwarded to
            ``get_scheduler``.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    lr_scheduler = get_scheduler(config=config, total_steps=total_steps)

    # Parameters whose direct parent module is one of these are trained.
    trainable_parents = frozenset(
        {
            "latte_Wq",
            "latte_Wk",
            "latte_conv",
            "latte_lru_in",
            "latte_lru_norm",
            "latte_rglru",
        }
    )
    # Parameters of submodules nested one level below these are trained too.
    trainable_grandparents = frozenset({"latte_rglru"})

    def partition_fn(path: Tuple[str, ...]) -> str:
        # Length guards: a top-level leaf has a 1-element path, and indexing
        # path[-2] on it would raise IndexError.
        if len(path) >= 2 and path[-2] in trainable_parents:
            return "trainable"
        if len(path) >= 3 and path[-3] in trainable_grandparents:
            return "trainable"
        return "frozen"

    def label_fn(params):
        """Map a params pytree to a same-structured tree of partition labels."""
        flat_params = flax.traverse_util.flatten_dict(params)
        flat_labels = {path: partition_fn(path) for path in flat_params}
        n_trainable = sum(label == "trainable" for label in flat_labels.values())
        # One summary line instead of one print per parameter.
        print(
            f"Freeze partition: {n_trainable} trainable / "
            f"{len(flat_labels) - n_trainable} frozen parameter arrays"
        )
        return flax.traverse_util.unflatten_dict(flat_labels)

    regular_opt = optax.inject_hyperparams(optax.adamw)(
        learning_rate=lr_scheduler, weight_decay=config.weight_decay
    )

    optimizer = optax.multi_transform(
        {
            "trainable": regular_opt,
            "frozen": zero_grads(),
        },
        label_fn,
    )

    if config.grad_accumulation_steps > 1:
        optimizer = optax.MultiSteps(optimizer, config.grad_accumulation_steps)
    # NOTE(review): clipping sits outside MultiSteps, so each micro-batch
    # gradient is clipped before accumulation. The global norm also includes
    # the frozen parameters' gradients. Confirm both are intended.
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optimizer)
    return optimizer, lr_scheduler


def _load_gemma_tokenizer(config):
    """Load the ``google/gemma-2-2b`` tokenizer with right-side truncation and padding.

    Tokenizer files are cached under ``<config.base_dir>/input/cache_hugg``.
    """
    return AutoTokenizer.from_pretrained(
        "google/gemma-2-2b",
        cache_dir=os.path.join(config.base_dir, "input/cache_hugg"),
        truncation_side="right",
        padding_side="right",
    )


def get_dp(config):
    """Build the data processor, tokenizer and datasets for ``config.dataset_name``.

    Supported names are ``"tiny-stories"``, ``"wiki103"``, ``"owt"``,
    ``"pajama"``, ``"bookcorpus"`` and ``"pile"``. All but ``"tiny-stories"``
    use the ``google/gemma-2-2b`` tokenizer. ``"tiny-stories"`` uses
    ``config.hugg_chk`` when set, otherwise ``allenai/longformer-base-4096``.

    Returns:
        ``(dp, tokenizer, raw_data, tok_data)``. ``raw_data`` is ``None`` for
        ``"pajama"``, whose pre-tokenized dataset is loaded straight from disk.

    Raises:
        ValueError: if ``config.dataset_name`` is not one of the names above.
    """
    name = config.dataset_name
    if name == "tiny-stories":
        path_name = (
            config.hugg_chk if config.hugg_chk else "allenai/longformer-base-4096"
        )
        print("Path Name is: ", path_name)
        tokenizer = AutoTokenizer.from_pretrained(
            path_name,
            cache_dir=Path(config.base_dir) / "input/cache_hugg",
            truncation_side="right",
            padding_side="right",
        )
        # NOTE: "tiny-sories" (sic) is the existing on-disk cache directory.
        # Fixing the spelling would silently invalidate previously cached data.
        cache_dir = os.path.join(config.base_dir, "input", "tiny-sories")
        dp = TinyStories(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            max_seq_len=config.max_seq_len,
            num_load_procs=None,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, force_preproc=config.disable_cache)
    elif name == "wiki103":
        cache_dir = os.path.join(config.base_dir, "input", "test_wikitext-103")
        tokenizer = _load_gemma_tokenizer(config)
        dp = Wiki103DP(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(
            raw_data, add_special_tokens=False, force_tok=config.disable_cache
        )
    elif name == "owt":
        # Only for evaluation purposes.
        cache_dir = os.path.join(config.base_dir, "input", "test_openweb")
        tokenizer = _load_gemma_tokenizer(config)
        dp = OpenWeb(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        # Keep only a small (first 1000 rows) train sample; presumably just so
        # code that expects a "train" split still works.
        raw_data["train"] = raw_data["train"].select(np.arange(1000))
        tok_data = dp.tokenize(raw_data, force_tok=config.disable_cache)
    elif name == "pajama":
        tokenizer = _load_gemma_tokenizer(config)
        # The tokenized data is loaded directly: tokenizing on the fly is too slow.
        # NOTE(review): this absolute path is machine-specific; consider making
        # it a config field.
        raw_data = None
        tok_data = DatasetDict.load_from_disk(
            "/user_all_data/data/input/pajama_tokenized"
        )
        tok_data.set_format("np")
        # Placeholder data processor, only used for its collate_fn.
        dp = SlimPajama(
            tokenizer,
            cache_dir=None,
            num_load_procs=None,
            max_seq_len=config.max_seq_len,
        )
    elif name == "bookcorpus":
        # Only for evaluation purposes.
        tokenizer = _load_gemma_tokenizer(config)
        cache_dir = os.path.join(config.base_dir, "input/bookcorpus_test")
        dp = BookCorpusLong(
            tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, force_tok=config.disable_cache)
        tok_data.set_format("np")
    elif name == "pile":
        # Only for evaluation purposes.
        tokenizer = _load_gemma_tokenizer(config)
        cache_dir = os.path.join(config.base_dir, "input/test_pile2")
        dp = PileTok2(
            tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data_test()
        tok_data = dp.tokenize(raw_data, force_tok=config.disable_cache)
        tok_data.set_format("np")
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown dataset_name {name!r}; expected one of 'tiny-stories', "
            "'wiki103', 'owt', 'pajama', 'bookcorpus', 'pile'."
        )

    return dp, tokenizer, raw_data, tok_data


def get_from_hugg_config(config: LMTaskConfig):
    """Return ``config`` with its architecture fields taken from a HuggingFace config.

    Loads the HF config of ``config.hugg_chk`` and copies over the hidden and
    intermediate sizes, activation, layer/head counts, head dimension,
    maximum position embeddings, attention dropout/bias and vocab size.

    ``head_dim`` is read from the HF config when it is present and not
    ``None``. Otherwise it is derived as ``hidden_size // num_attention_heads``,
    because some HF configs omit it.
    """
    hugg_config = AutoConfig.from_pretrained(
        config.hugg_chk,
        cache_dir=Path(config.base_dir) / "input/cache_hugg",
    )

    head_dim = getattr(hugg_config, "head_dim", None)
    if head_dim is None:
        head_dim = hugg_config.hidden_size // hugg_config.num_attention_heads

    return config.replace(
        hidden_dim=hugg_config.hidden_size,
        intermediate_dim=hugg_config.intermediate_size,
        hidden_act=hugg_config.hidden_act,
        nlayers=hugg_config.num_hidden_layers,
        num_key_value_heads=hugg_config.num_key_value_heads,
        pos_embed_max_len=hugg_config.max_position_embeddings,
        nheads=hugg_config.num_attention_heads,
        head_dim=head_dim,
        dropout_att=hugg_config.attention_dropout,
        attention_bias=hugg_config.attention_bias,
        text_vocab_size=hugg_config.vocab_size,
    )


class LMPretTask(BaseTask):
    def get_model(self, dp: TinyStories, sharded: bool = True):
        tokenizer = dp.tokenizer
        if self.config.hugg_chk:
            config = get_from_hugg_config(self.config)
        else:
            config = self.config
            config = config.replace(text_vocab_size=tokenizer.vocab_size + 2)

        match self.config.mixed_precision:
            case "bf16":
                mixed_precission = jnp.bfloat16
            case _:
                mixed_precission = jnp.float32

        match self.config.block_type:
            case "llama":
                # constructor = LavaCausalLM
                raise Exception("Not yet implemented")
            case "gemma":
                constructor = Gemma

        if sharded:
            sharded_dense = shard_module_params(
                constructor, axis_name="B", min_weight_size=MIN_SHARD_WEIGHT
            )
        else:
            sharded_dense = constructor

        model = sharded_dense(
            config,
            pad_id=tokenizer.pad_token_id,
            dtype=mixed_precission,
            sharded=sharded,
        )

        print(model)
        return model

    def train(self, train_rng):
        """Trains a model from scratch"""
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        data_collator = dp.get_collate_fn(return_type="np")
        print("Tok data: ", tokenized_data)
        model = self.get_model(
            dp, sharded=True
        )  # sharded=False) # get_lava_model(self.config, dp.tokenizer)  #
        train_rng, init_rng, eval_rng = jax.random.split(train_rng, 3)

        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
        )
        if not self.config.check_path is None:
            trainer.train(train_rng, checkpoint_path=self.config.check_path)
        else:
            trainer.train(train_rng)

    def finetune(self, train_rng):
        """Trains a model from scratch"""
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        data_collator = dp.get_collate_fn(return_type="np")

        model = self.get_model(
            dp, sharded=True
        )  # sharded=False) # get_lava_model(self.config, dp.tokenizer)  #
        train_rng, init_rng, eval_rng = jax.random.split(train_rng, 3)

        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
        )
        empty_state = trainer.get_distrib_state(train_rng)
        state, metadata = trainer.load_trainer_state(
            empty_state=empty_state, check_dir=self.config.check_path, step_number=None
        )
        trainer.train(train_rng, state=state)

    def train_formpt(self, train_rng):
        if self.config.hugg_chk:
            match self.config.block_type:
                case "llama":
                    state_pt = LlamaForCausalLM.from_pretrained(
                        self.config.hugg_chk,
                        cache_dir=Path(self.config.base_dir) / "input/cache_hugg",
                        device_map="cpu",
                        max_memory={"cpu": "5GB"},
                        offload_state_dict=True,
                    ).state_dict()
                case "gemma":
                    state_pt = Gemma2Model.from_pretrained(
                        self.config.hugg_chk,
                        cache_dir=Path(self.config.base_dir) / "input/cache_hugg",
                        device_map="cpu",
                        max_memory={"cpu": "10GB"},
                        offload_state_dict=True,
                    ).state_dict()
                case _:
                    state_pt = AutoModelForCausalLM.from_pretrained(
                        self.config.hugg_chk,
                        cache_dir=Path(self.config.base_dir) / "input/cache_hugg",
                        device_map="cpu",
                        max_memory={"cpu": "5GB"},
                        offload_state_dict=True,
                    ).state_dict()

        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        data_collator = dp.get_collate_fn(return_type="np")

        model = self.get_model(dp, sharded=True)
        train_rng, init_rng, eval_rng = jax.random.split(train_rng, 3)

        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            model_outputs_orded=("loss",),
            prepare_optimizer_fn=prepare_freeze_optimizer,
        )
        state = trainer.get_distrib_state(train_rng)
        flattened_s = flax.traverse_util.flatten_dict(state.params, sep=".")

        flattened_s = model.load_pret_scan(state_pt, flattened_s)
        params = flax.traverse_util.unflatten_dict(flattened_s, sep=".")
        state = state.replace(params=params)

        shape_m = jax.tree.map(
            lambda x: x.shape,
            nn.meta.unbox(state.params),
        )
        print("Shape model is: ", json.dumps(shape_m))

        print(jax.tree_util.tree_map(lambda x: x.shape, state.params))
        trainer.train(train_rng, state)

    def evaluate(self, train_rng):
        """
        Train up to 1024 or (any other fixed) and evaluate next token prediction accuracy on larger seq lens
        """
        print("Evaluation")
        train_rng, init_rng, init_rng2, eval_rng = jax.random.split(train_rng, 4)
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)

        print(tokenized_data)
        data_collator = dp.get_collate_fn(
            return_type="np", max_seq_len=self.config.max_seq_len
        )

        model = self.get_model(dp, sharded=True)
        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )

        # chk_path = "/home/user/latte_trans/data/out_latte/copy_50_latte_vapor_1e_4_16/checkpoints"
        # chk_path = "/data_user/data/out_latte/pile_vapor_64/checkpoints"
        chk_path = self.config.check_path
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["validation"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            prepare_optimizer_fn=prepare_freeze_optimizer,
        )
        data = {
            "input_ids": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
            "labels": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
        }
        data = tuple(data.values())  # from BatchEncoding to tuple
        zero_state = trainer._jit_init_fn(
            self.config.batchnorm, model, trainer._optimizer, init_rng2, data
        )
        state, metadata = trainer.load_trainer_state(
            empty_state=zero_state, check_dir=chk_path, step_number=None
        )
        print("loaded state")
        eval_fn = partial(trainer.trainer_eval, state, eval_rng)
        scores = evaluator.evaluate(
            trainer_eval_fn=eval_fn, prefix="eval_", state=state
        )
        print("Scores loaded: ", scores)

    def eval_speed_bench(self, train_rng):
        """
        Train up to 1024 or (any other fixed) and evaluate next token prediction accuracy on larger seq lens
        """
        print("Evaluation")
        train_rng, init_rng, init_rng2, eval_rng = jax.random.split(train_rng, 4)
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)

        print(tokenized_data)
        data_collator = dp.get_collate_fn(
            return_type="np", max_seq_len=self.config.max_seq_len
        )

        model = self.get_model(dp, sharded=True)

        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )

        # chk_path = "/home/user/latte_trans/data/out_latte/copy_50_latte_vapor_1e_4_16/checkpoints"
        # chk_path = "/data_user/data/out_latte/pile_vapor_64/checkpoints"
        chk_path = self.config.check_path
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["validation"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
        )
        data = {
            "input_ids": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
            "labels": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
        }
        data = tuple(data.values())  # from BatchEncoding to tuple
        zero_state = trainer._jit_init_fn(
            self.config.batchnorm, model, trainer._optimizer, init_rng2, data
        )
        print("loaded state")
        param_count = sum(x.size for x in jax.tree_util.tree_leaves(zero_state.params))
        jax.debug.print("Number of parameters: {x} M", x=param_count / 1000000)
        shape_m = jax.tree.map(
            lambda x: x.shape,
            nn.meta.unbox(zero_state.params),
        )
        print("Shape model is: ", json.dumps(shape_m))

        eval_fn = partial(trainer.trainer_eval, zero_state, eval_rng)
        scores = evaluator.evaluate(
            trainer_eval_fn=eval_fn, prefix="eval_", state=zero_state
        )
        print("Scores loaded: ", scores)

    def transfer_learning(self, train_rng):
        if self.config.hugg_chk:
            match self.config.block_type:
                case "llama":
                    state_pt = LlamaForCausalLM.from_pretrained(
                        self.config.hugg_chk,
                        cache_dir=Path(self.config.base_dir) / "input/cache_hugg",
                        device_map="cpu",
                        max_memory={"cpu": "5GB"},
                        offload_state_dict=True,
                    ).state_dict()
                case "gemma":
                    state_pt = Gemma2Model.from_pretrained(
                        self.config.hugg_chk,
                        cache_dir=Path(self.config.base_dir) / "input/cache_hugg",
                        device_map="cpu",
                        max_memory={"cpu": "10GB"},
                        offload_state_dict=True,
                    ).state_dict()
                case _:
                    state_pt = AutoModelForCausalLM.from_pretrained(
                        self.config.hugg_chk,
                        cache_dir=Path(self.config.base_dir) / "input/cache_hugg",
                        device_map="cpu",
                        max_memory={"cpu": "5GB"},
                        offload_state_dict=True,
                    ).state_dict()

        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        # print(tokenized_data)
        data_collator = dp.get_collate_fn(return_type="np")

        # print("=" * 100)
        # print(raw_data["validation"][0])
        # print(raw_data["validation"][1])
        # print(tokenized_data["validation"][0])
        # print(tokenizer.decode(tokenized_data["validation"][0]["input_ids"]))
        val_loader = DataLoader(
            tokenized_data["validation"],
            batch_size=4,
            shuffle=False,
            collate_fn=data_collator,
            drop_last=True,
        )
        # print("-" * 100)
        # test_data = next(iter(val_loader))
        # print(test_data)
        # print(tokenizer.batch_decode(test_data["input_ids"]))

        model = self.get_model(dp, sharded=True)
        train_rng, init_rng, eval_rng, init_rng2 = jax.random.split(train_rng, 4)

        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            model_outputs_orded=("loss",),
            prepare_optimizer_fn=prepare_freeze_optimizer,
        )
        state = trainer.get_distrib_state(train_rng)
        flattened_s = flax.traverse_util.flatten_dict(state.params, sep=".")

        flattened_s = model.load_pret_scan(state_pt, flattened_s)
        params = flax.traverse_util.unflatten_dict(flattened_s, sep=".")
        state = state.replace(params=params)

        if self.config.check_path:
            data = {
                "input_ids": jnp.zeros(
                    (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
                ),
                "labels": jnp.zeros(
                    (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
                ),
            }
            data = tuple(data.values())  # from BatchEncoding to tuple
            state, metadata = trainer.load_trainer_state(
                empty_state=state,
                check_dir=self.config.check_path,
                step_number=None,
            )
        print("loaded state")

        shape_m = jax.tree.map(
            lambda x: x.shape,
            nn.meta.unbox(state.params),
        )
        print("Shape model is: ", json.dumps(shape_m))
        print(jax.tree_util.tree_map(lambda x: x.shape, state.params))
        # trainer.train(train_rng, state)
        eval_fn = partial(trainer.trainer_eval, state, eval_rng)
        scores = evaluator.evaluate(
            trainer_eval_fn=eval_fn, prefix="eval_", state=state
        )
        print("Scores loaded: ", scores)


def main():
    """Script entry point: load the YAML config and run the selected task.

    A fixed PRNG seed (0) is split into train and sample keys. When
    ``args.evaluate`` is truthy, ``transfer_learning`` runs with the sample
    key. Otherwise ``train_formpt`` runs with the train key. The other entry
    points (``train``, ``finetune``, ``evaluate``, ``eval_speed_bench``) are
    chosen by editing this function.
    """
    root_key = jax.random.PRNGKey(0)
    _, train_rng, sample_rng = jax.random.split(root_key, 3)

    args = parse_args()
    config = LMTaskConfig.load(
        yaml_file=args.config_file, base_dir=args.base_dir, name=args.name
    )

    if config.disable_cache:
        print("Disabling Cache")
        disable_caching()

    task = LMPretTask(config)
    if not args.evaluate:
        task.train_formpt(train_rng)
        return
    task.transfer_learning(sample_rng)


# Run the task only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
