"""
Language-modelling (LM) task: data-parallel training with sharded model parameters.
"""

from typing import Dict, Tuple
from functools import partial
import itertools
import math
import os
import json
import logging
import wandb
import jax
from jax import numpy as jnp
import numpy as np
import optax
from datasets import disable_caching
import dataclasses

from transformers import AutoTokenizer
from datasets import DatasetDict
from latte_trans.preproc.lm_dp import Wiki103DP
from latte_trans.preproc.pile import PileStream
from latte_trans.preproc.openweb import OpenWeb
from latte_trans.preproc.toks import SpecialToksGPT2TokenizerFast
from latte_trans.preproc.tiny_stories import TinyStories
from latte_trans.preproc.book_corpus import BookCorpusLong

from latte_trans.trainer.dfsp_jax import (
    DFSDPTrainer,
    gather_params,
    shard_module_params,
    get_scheduler,
)
from latte_trans.trainer.jax_single_host import Trainer
from latte_trans.experiments.base import BaseTask
from latte_trans.models.tasks.lm import (
    LMHeadVanilla,
    LMHeadSOTA,
    LMHeadQual,
    RecurrentGemmaWrapper,
    Gemma,
)
from latte_trans.evals.lang_eval import (
    LanguageEvaluator,
    LanguageEvaluatorSeq,
)
from latte_trans.preproc.slim_pajama import SlimPajama
from latte_trans.experiments.utils import parse_args
from latte_trans.config import LMTaskConfig
from recurrentgemma import jax as recurrentgemma

# Configure root logging once at import time: timestamp (with milliseconds),
# level, and the emitting file/line are prepended to every message.
logging.basicConfig(
    format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
# Module-level logger used throughout this file.
LOG = logging.getLogger(__name__)

# Debugging toggles, disabled by default. Uncomment to force the CPU backend,
# enable NaN debugging, or disable jit respectively.
# jax.config.update("jax_platform_name", "cpu")
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_disable_jit", True)

# Passed as `min_weight_size` to `shard_module_params` when wrapping models.
# NOTE(review): presumably parameters smaller than this stay unsharded --
# confirm against latte_trans.trainer.dfsp_jax.
MIN_WEIGHT_SHARD_SIZE = 2**10


def map_nested_fn(fn):
    """
    Build a function that applies `fn` to every leaf of a nested dict.

    The returned callable walks a mapping recursively. Any value exposing a
    `keys` attribute is treated as a sub-tree and descended into; every other
    value is a leaf and is replaced by `fn(key, value)`, where `key` is the
    leaf's own (innermost) key. The result is a plain nested `dict` with the
    same structure as the input.

    Used below to label parameters for `optax.multi_transform`.
    """

    def _walk(tree):
        mapped = {}
        for key, value in tree.items():
            if hasattr(value, "keys"):
                # Sub-tree: recurse and keep the nesting.
                mapped[key] = _walk(value)
            else:
                # Leaf: hand the key and value to the user callback.
                mapped[key] = fn(key, value)
        return mapped

    return _walk


def prepare_optimizer(
    config: LMTaskConfig, total_steps: int
) -> Tuple[optax.GradientTransformation, optax.Schedule]:
    """
    Build the optimizer and learning-rate schedule used for training.

    Parameters are split into two groups by their leaf name:
      * "ssm"     -- leaves whose key is one of the names listed below; these
                     use Adam with a fixed learning rate of 1e-4.
      * "regular" -- everything else; these use AdamW driven by the schedule
                     from `get_scheduler` and `config.weight_decay`.

    When `config.grad_accumulation_steps > 1` the grouped optimizer is wrapped
    in `optax.MultiSteps`. Global-norm clipping (max norm 1.0) is chained in
    front of the result in all cases.

    Returns:
        `(optimizer, lr_scheduler)`.
    """
    lr_scheduler = get_scheduler(config=config, total_steps=total_steps)

    # Leaf names that are routed to the "ssm" optimizer group.
    ssm_leaf_names = (
        "B",
        "Lambda_re",
        "Lambda_im",
        "log_step",
        "norm",
        "theta_log",
        "nu_log",
        "gamma_log",
        "B_re",
        "B_im",
    )

    def label_leaf(name, _value):
        if name in ssm_leaf_names:
            return "ssm"
        return "regular"

    label_params = map_nested_fn(label_leaf)

    adamw_with_schedule = optax.inject_hyperparams(optax.adamw)(
        learning_rate=lr_scheduler, weight_decay=config.weight_decay
    )
    adam_fixed_lr = optax.inject_hyperparams(optax.adam)(learning_rate=1e-4)

    optimizer = optax.multi_transform(
        {"ssm": adam_fixed_lr, "regular": adamw_with_schedule},
        label_params,
    )

    accumulation_steps = config.grad_accumulation_steps
    if accumulation_steps > 1:
        optimizer = optax.MultiSteps(optimizer, accumulation_steps)

    # Clip first, then hand the clipped gradients to the (possibly
    # accumulating) optimizer.
    clipped = optax.chain(optax.clip_by_global_norm(1.0), optimizer)
    return clipped, lr_scheduler


def get_dp(config):
    """
    Build the data processor, tokenizer and datasets used for training.

    Args:
        config: task configuration. Reads `dataset_name`, `base_dir`,
            `max_seq_len` and `disable_cache`.

    Returns:
        A tuple `(dp, tokenizer, raw_data, tok_data)`. `raw_data` is `None`
        for "owt" and "pajama", where already-tokenized data is loaded from
        disk instead of being tokenized here.

    Raises:
        ValueError: if `config.dataset_name` is not one of the supported
            names. (Previously this surfaced as an `UnboundLocalError` at the
            `return` statement.)
    """
    if config.dataset_name == "wiki103":
        cache_dir = os.path.join(config.base_dir, "input", "wikitext-103")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        dp = Wiki103DP(tokenizer=tokenizer, cache_dir=cache_dir)
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, add_special_tokens=True)
    elif config.dataset_name == "pile":
        cache_dir = os.path.join(config.base_dir, "input", "pile2")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        dp = PileStream(tokenizer=tokenizer, cache_dir=cache_dir, num_load_procs=None)
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, max_seq_len=config.max_seq_len)
    elif config.dataset_name == "owt":
        cache_dir = os.path.join(config.base_dir, "input", "tok_openweb")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        dp = OpenWeb(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            num_load_procs=None,
            max_seq_len=config.max_seq_len,
        )
        # The tokenized dataset is read straight from `cache_dir`.
        raw_data = None
        tok_data = DatasetDict.load_from_disk(cache_dir).with_format("np")
        print(tok_data)
    elif config.dataset_name == "tiny-stories":
        # NOTE: the directory name is spelled "tiny-sories" on purpose here;
        # it is kept as-is so existing on-disk caches keep being found.
        cache_dir = os.path.join(config.base_dir, "input", "tiny-sories")
        tokenizer = SpecialToksGPT2TokenizerFast.from_pretrained("gpt2")
        dp = TinyStories(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            max_seq_len=config.max_seq_len,
            # NOTE(review): `min(1, n - 1)` never exceeds 1 and is 0 on a
            # single-core machine; `max(1, ...)` may have been intended.
            # Left unchanged because it alters how many workers are spawned.
            num_load_procs=min(1, os.cpu_count() - 1),
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, force_preproc=config.disable_cache)
    elif config.dataset_name == "pajama":
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b",
            cache_dir=os.path.join(config.base_dir, "input/cache_hugg"),
            truncation_side="right",
            padding_side="right",
        )
        # Load the tokenized data directly; tokenizing on the fly is slow.
        raw_data = None
        tok_data = DatasetDict.load_from_disk(
            "/user_all_data/data/input/pajama_tokenized"
        )
        tok_data.set_format("np")
        # Dummy data processor, only used for its collate_fn.
        dp = SlimPajama(
            tokenizer,
            cache_dir=None,
            num_load_procs=None,
            max_seq_len=config.max_seq_len,
        )
    elif config.dataset_name == "bookcorpus":
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b",
            cache_dir=os.path.join(config.base_dir, "input/cache_hugg"),
            truncation_side="right",
            padding_side="right",
        )
        data_dir = config.base_dir
        cache_dir = os.path.join(data_dir, "input/bookcorpus_abl")
        print("Book corpus")
        dp = BookCorpusLong(
            tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, force_tok=config.disable_cache)
        tok_data.set_format("np")
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported dataset_name {config.dataset_name!r}; expected one of "
            "'wiki103', 'pile', 'owt', 'tiny-stories', 'pajama', 'bookcorpus'"
        )
    return dp, tokenizer, raw_data, tok_data


def get_eval_dp(config):
    """
    Creates Smaller dataset for evaluation
    """
    from transformers import AutoTokenizer
    from datasets import DatasetDict

    # We directly load the tokenized data. Slow to tokenize on the fligh
    match config.dataset_name:
        case "pajama":
            tokenizer = AutoTokenizer.from_pretrained(
                "google/gemma-2-2b",
                cache_dir=os.path.join(config.base_dir, "input/cache_hugg"),
                truncation_side="right",
                padding_side="right",
            )
            # tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
            data_path = "/user_all_data/data/input/pajama_raw"
            cache_dir = os.path.join(config.base_dir, "input/test_pajama")
            raw_data = DatasetDict.load_from_disk(data_path)
            raw_data["train"] = raw_data["train"].select(np.arange(100))
            raw_data["validation"] = raw_data["validation"].select(np.arange(2000))
            raw_data["test"] = raw_data["test"].select(np.arange(1000))
            dp = SlimPajama(
                tokenizer,
                cache_dir=cache_dir,
                num_load_procs=None,
                max_seq_len=config.max_seq_len,
            )
            tokenized_data = dp.tokenize(raw_dataset=raw_data, force_tok=True)
        case "bookcorpus":
            print("Book Corpus Loading")
            tokenizer = AutoTokenizer.from_pretrained(
                "google/gemma-2-2b",
                cache_dir=os.path.join(config.base_dir, "input/cache_hugg"),
                truncation_side="right",
                padding_side="right",
            )
            data_dir = config.base_dir
            cache_dir = os.path.join(data_dir, "input/bookcorpus_lmsh")
            dp = BookCorpusLong(
                tokenizer,
                cache_dir=cache_dir,
                num_load_procs=8,
                max_seq_len=config.max_seq_len,
            )
            raw_data = dp.get_raw_data()
            tokenized_data = dp.tokenize(raw_data, force_tok=config.disable_cache)
            tokenized_data.set_format("np")
        case "owt":
            cache_dir = os.path.join(config.base_dir, "input", "test_openweb")
            tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
            dp = OpenWeb(
                tokenizer=tokenizer,
                cache_dir=cache_dir,
                num_load_procs=8,
                max_seq_len=config.max_seq_len,
            )
            raw_data = dp.get_raw_data()
            # add a small sample for compatibility reasons
            raw_data["train"] = raw_data["train"].select(np.arange(1000))
            tokenized_data = dp.tokenize(raw_data, force_tok=config.disable_cache)
    return dp, tokenizer, raw_data, tokenized_data


class LMTask(BaseTask):
    def get_model(self, tokenizer, sharded=True):
        match self.config.mixed_precision:
            case "bf16":
                mixed_precission = jnp.bfloat16
            case _:
                mixed_precission = jnp.float32

        match self.config.block_type:
            case "griffin":
                griffin_pattern = itertools.cycle(
                    [
                        recurrentgemma.TemporalBlockType.RECURRENT,
                        # recurrentgemma.TemporalBlockType.RECURRENT,
                        recurrentgemma.TemporalBlockType.ATTENTION,
                    ]
                )

                print(tuple(itertools.islice(griffin_pattern, self.config.nlayers)))
                config = recurrentgemma.GriffinConfig(
                    vocab_size=tokenizer.vocab_size,
                    width=self.config.hidden_dim,
                    mlp_expanded_width=self.config.intermediate_dim,
                    lru_width=self.config.state_dim,
                    num_heads=self.config.nheads,
                    block_types=tuple(
                        itertools.islice(griffin_pattern, self.config.nlayers)
                    ),
                    embeddings_scale_by_sqrt_dim=True,
                    attention_window_size=self.config.att_block_len,
                    logits_soft_cap=None,  # 30.0,
                )
                sharded_model = shard_module_params(
                    RecurrentGemmaWrapper,
                    axis_name="B",
                    min_weight_size=MIN_WEIGHT_SHARD_SIZE,
                )
                model = sharded_model(config, dtype=mixed_precission)
                print(model)
                return model
            case "transformer":
                constructor = LMHeadVanilla
            case "transformer-sota":
                constructor = LMHeadSOTA
            case "transformer-qual":
                constructor = LMHeadQual
            case "gemma":
                constructor = Gemma
            case _ as unreachable:
                raise IOError("Type of model not supported")
        if sharded:
            sharded_model = shard_module_params(
                constructor, axis_name="B", min_weight_size=MIN_WEIGHT_SHARD_SIZE
            )
        else:
            sharded_model = constructor

        model = sharded_model(
            self.config,
            vocab_size=tokenizer.vocab_size,
            pad_id=tokenizer.pad_token_id,
            dtype=mixed_precission,
            sharded=sharded,
        )
        print(model)
        return model

    def train(self, train_rng):
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        data_collator = dp.get_collate_fn(
            return_type="np", max_seq_len=self.config.max_seq_len
        )

        model = self.get_model(tokenizer, sharded=True)
        train_rng, init_rng, eval_rng = jax.random.split(train_rng, 3)
        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )
        trainer = DFSDPTrainer(  # Trainer( #
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            prepare_optimizer_fn=prepare_optimizer,  # add ssm grad
        )
        if not self.config.check_path is None:
            trainer.train(train_rng, checkpoint_path=self.config.check_path)
        else:
            trainer.train(train_rng)

    def evaluate_seq_gen(self, train_rng):
        """
        Train up to 1024 or (any other fixed) and evaluate next token prediction accuracy on larger seq lens
        """
        print("Evaluation")
        train_rng, init_rng, init_rng2, eval_rng = jax.random.split(train_rng, 4)
        dp, tokenizer, raw_data, tokenized_data = get_eval_dp(self.config)
        data_collator = dp.get_collate_fn(
            return_type="np", max_seq_len=self.config.max_seq_len
        )

        model = self.get_model(tokenizer, sharded=True)
        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )

        # chk_path = "/home/user/latte_trans/data/out_latte/copy_50_latte_vapor_1e_4_16/checkpoints"
        # chk_path = "/data_user/data/out_latte/pile_vapor_64/checkpoints"
        chk_path = self.config.check_path
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["validation"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            prepare_optimizer_fn=prepare_optimizer,  # add ssm grad
        )
        data = {
            "input_ids": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
            "labels": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
        }
        data = tuple(data.values())  # from BatchEncoding to tuple
        zero_state = trainer._jit_init_fn(
            self.config.batchnorm, model, trainer._optimizer, init_rng2, data
        )
        state, metadata = trainer.load_trainer_state(
            empty_state=zero_state, check_dir=chk_path, step_number=None
        )
        print("loaded state")
        # eval_fn = partial(trainer.trainer_eval, state, eval_rng)
        # scores = evaluator.evaluate(
        #     trainer_eval_fn=eval_fn, prefix="eval_", state=state
        # )
        # print("Eval metrics: ", "-" * 100)
        # print(scores)
        # print("Seq Len Eval: ", "-" * 100)
        evaluator = LanguageEvaluatorSeq(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            eval_batch_size=self.config.batch_size,
            rng=eval_rng,
        )
        eval_fn = partial(trainer.trainer_eval, state, eval_rng)
        scores = evaluator.evaluate(
            trainer_eval_fn=eval_fn, prefix="eval_", state=state
        )

        import matplotlib.pyplot as plt
        import json

        SAVE_PATH = "/home/user/latte_trans/data/seq_len_plts"
        with open(
            f"{SAVE_PATH}/scores_{self.config.dataset_name}.json",
            "w",
        ) as f:
            to_save = {}
            to_save["config"] = str(self.config)
            to_save["PPL"] = scores["PPL"].tolist()
            to_save["PPL_mean"] = scores["PPL_mean"].tolist()
            json.dump(to_save, f)
        start_window = 128
        cap = 200
        ppl = np.array(scores["PPL"])
        print(ppl.shape)
        ppl = np.where(ppl > cap, cap, ppl)

        plt.figure()
        plt.plot(
            start_window + np.arange(len(ppl[0, start_window:])), ppl[0, start_window:]
        )
        plt.savefig(f"{SAVE_PATH}/ppl_{self.config.dataset_name}.png")

        ppl = np.array(scores["PPL_mean"])
        print(ppl.shape)
        ppl = np.where(ppl > cap, cap, ppl)

        plt.figure()
        plt.plot(
            start_window + np.arange(len(ppl[0, start_window:])), ppl[0, start_window:]
        )
        plt.savefig(f"{SAVE_PATH}/ppl_{self.config.dataset_name}2.png")

        plt.figure()
        step_sz = 200
        ppl = np.array(scores["PPL"])
        ppl = np.where(ppl > cap, cap, ppl)
        plt_every = ppl[0, start_window::step_sz]
        x = np.arange(len(ppl[0]))[start_window::step_sz]
        plt.plot(x, plt_every)
        plt.savefig(f"{SAVE_PATH}/ppl_{self.config.dataset_name}3.png")

    def check_latens_collapse(self, train_rng):
        """
        Check if latents collapse after training
        """
        train_rng, init_rng, init_rng2, eval_rng = jax.random.split(train_rng, 4)
        dp, tokenizer, raw_data, tokenized_data = get_eval_dp(self.config)
        data_collator = dp.get_collate_fn(
            return_type="np", max_seq_len=self.config.max_seq_len
        )

        model = self.get_model(tokenizer, sharded=True)
        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )

        # chk_path = "/home/user/latte_trans/data/out_latte/copy_50_latte_vapor_1e_4_16/checkpoints"
        # chk_path = "/data_user/data/out_latte/pile_vapor_64/checkpoints"
        chk_path = self.config.check_path
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["validation"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
            prepare_optimizer_fn=prepare_optimizer,  # add ssm grad
        )
        data = {
            "input_ids": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
            "labels": jnp.zeros(
                (self.config.batch_size, self.config.max_seq_len), dtype=jnp.int32
            ),
        }
        data = tuple(data.values())  # from BatchEncoding to tuple
        zero_state = trainer._jit_init_fn(
            self.config.batchnorm, model, trainer._optimizer, init_rng2, data
        )
        state, metadata = trainer.load_trainer_state(
            empty_state=zero_state, check_dir=chk_path, step_number=None
        )
        eval_fn = partial(trainer.trainer_eval, state, eval_rng)
        scores = evaluator.evaluate(
            trainer_eval_fn=eval_fn, prefix="eval_", state=state
        )
        print("Eval metrics: ", "-" * 100)
        print(scores)
        print("Seq Len Eval: ", "-" * 100)

        res = jax.tree_util.tree_map(jax.device_get, state.params)
        model = self.get_model(tokenizer, sharded=False)
        data = next(iter(evaluator._val_loader))
        output, interim = model.apply(
            {"params": res},  # {"params": state.params},  #
            data["input_ids"],
            data["labels"],
            train=False,
            # capture_intermediates=True,
            mutable="intermediates",
        )

        def f(Qs):
            jax.tree_util.tree_map_with_path(
                lambda k, x: jax.numpy.save(f"./Qs_{str(k)}", x),
                Qs,
            )
            return Qs

        # jax.jax.experimental.io_callback(f, interim, interim)
        print(interim["intermediates"].keys())
        Qs = interim["intermediates"]["DecoderSota_0"]["residual_block"][
            "CausalRopeLatteMachiattoSliding_0"
        ]["Qs"]
        # K = interim["intermediates"]["Decoder_0"]["transformer_block"][
        #     "RotCausalScanLatte_0"
        # ]["K"]
        # nr_applications module.apply, nr_layers, batch_size, shape
        print(len(Qs))
        Qs = Qs[0]
        # K = K[0]
        print(Qs.shape)
        # NTBHL
        print(
            Qs[0, 10, 0, 0, 0],
            Qs[0, 10, 0, 0, 1],
            Qs[0, 10, 0, 0, 2],
            Qs[0, 10, 0, 0, -1],
        )
        # print(K.shape)
        import matplotlib.pyplot as plt
        import seaborn as sns
        from mpl_sizes import get_format

        plt.rcParams["text.usetex"] = True
        formatter = get_format("ICLR")  # options: ICLR, ICML, NeurIPS, InfThesis
        plt.rcParams["text.usetex"] = True
        plt.rcParams["font.family"] = "serif"
        plt.rcParams["font.serif"] = ["Times"]
        plt.rcParams["axes.labelsize"] = "large"

        # colors = ['#03045e', '#033e8a', '#0077b6', '#0296c8', '#06b4d8', '#49cae4']
        colors = ["#EB8531", "#30CEEA"]

        # plt.figure(figsize=(10, 10))
        # plt.plot(np.arange(Qs.shape[1]), Qs[0, :, 0, 0, 0], label="layer0")
        # plt.plot(np.arange(Qs.shape[1]), Qs[1, :, 0, 0, 0], label="layer1")
        # plt.plot(np.arange(Qs.shape[1]), Qs[11, :, 0, 0, 0], label="layer11")
        # plt.ylabel("p(l=0|t)")
        # plt.xlabel("t")
        # plt.legend()
        # plt.tight_layout()
        # plt.savefig(f"test_collapse.png")

        # plt.figure()
        # plt.plot(np.arange(Qs.shape[1]), Qs[1, :, 0, 0, 0], label="l=0")
        # plt.plot(np.arange(Qs.shape[1]), Qs[1, :, 0, 0, 1], label="l=1")
        # plt.plot(np.arange(Qs.shape[1]), Qs[1, :, 0, 0, 16], label="l=16")
        # plt.ylabel("Layer 1: p(l|t)")
        # plt.xlabel("t")
        # plt.legend()
        # plt.tight_layout()
        # plt.savefig(f"test_collapse2.png")

        # plt.figure()
        # plt.plot(np.arange(Qs.shape[-1]), Qs[1, 128, 0, 0, :], label="t=128")
        # plt.plot(np.arange(Qs.shape[-1]), Qs[1, 1, 0, 0, :], label="t=1")
        # plt.plot(np.arange(Qs.shape[-1]), Qs[1, 890, 0, 0, :], label="t=890")
        # plt.ylabel("Layer 1: p(l|t)")
        # plt.xlabel("l")
        # plt.legend()
        # plt.tight_layout()
        # plt.savefig(f"test_collapse3.png")

        # plot_attention_maps(Qs)
        plt.figure(figsize=(10, 10))
        T = 24
        x = Qs[0, :T, 0, 0, :]
        L = x.shape[-1]
        x = np.array(x, dtype=np.float32)
        plt.xlabel("l")
        plt.ylabel("T")
        im = plt.imshow(
            x, extent=[0, L, 0, T], aspect="auto", vmin=0, vmax=1, origin="lower"
        )  # , cmap="grey")
        plt.yticks(list(range(T)), rotation=60)
        plt.xticks(list(range(L)))

        cbar = plt.colorbar(im)
        plt.tight_layout()
        plt.savefig("test_att.pdf")

        figsize = (
            10,
            11,
        )  # (formatter.text_width_plot()[0], formatter.text_width_plot()[1])
        plt.figure()
        # NTBHL -> BNHTL
        attn_maps = Qs.transpose(2, 0, 3, 1, 4)[0, :, :, :T, :]
        attn_maps = np.array(attn_maps, dtype=np.float32)
        num_heads = attn_maps.shape[1]
        num_layers = attn_maps.shape[0]
        fig, ax = plt.subplots(
            num_layers,
            num_heads,
            figsize=(num_heads * figsize[0], num_layers * figsize[1]),
        )

        if num_layers == 1:
            ax = [ax]
        if num_heads == 1:
            ax = [[a] for a in ax]
        for row in range(num_layers):
            for column in range(num_heads):
                ax[row][column].set_xlabel("l", fontsize=30)
                ax[row][column].set_ylabel("t", fontsize=30)
                ax[row][column].imshow(
                    attn_maps[row][column],
                    extent=[0, L, 0, T],
                    aspect="auto",
                    vmin=0,
                    vmax=1,
                    origin="lower",
                )
                ax[row][column].set_yticks(
                    [x - 0.5 for x in list(range(1, T + 1))],
                    list(range(T)),
                    fontsize=30,
                )
                ax[row][column].set_xticks(
                    [x - 0.5 for x in list(range(1, L + 1))],
                    list(range(L)),
                    fontsize=30,
                )
                ax[row][column].set_title(f"Layer {row+1}, Head {column+1}")
        fig.subplots_adjust(hspace=0.5)
        plt.tight_layout()
        fig.savefig("test_att2.pdf")

        wdt = 11
        hgt = 10
        plt.figure(figsize=(hgt, wdt))
        # extent = [0.5, L + 0.5, 0.5, T + 0.5]
        im = plt.imshow(
            attn_maps[3][0],
            extent=[0, L, 0, T],
            aspect="auto",
            vmin=0,
            vmax=1,
            interpolation="none",
            origin="lower",
        )
        plt.yticks(
            [x - 0.5 for x in list(range(1, T + 1))], list(range(T)), fontsize=30
        )
        plt.xticks(
            [x - 0.5 for x in list(range(1, L + 1))], list(range(L)), fontsize=30
        )

        plt.xlabel("l", fontsize=30)
        plt.ylabel("t", fontsize=30)
        plt.tight_layout()
        # cbar = plt.colorbar(im)
        plt.savefig("test_att_l4h1.pdf")

        plt.figure(figsize=(hgt, wdt))
        im = plt.imshow(
            attn_maps[2][5],
            extent=[0, L, 0, T],
            aspect="auto",
            vmin=0,
            vmax=1,
            origin="lower",
        )
        plt.yticks(
            [x - 0.5 for x in list(range(1, T + 1))], list(range(T)), fontsize=30
        )
        plt.xticks(
            [x - 0.5 for x in list(range(1, L + 1))], list(range(L)), fontsize=30
        )

        plt.xlabel("l", fontsize=30)
        plt.ylabel("t", fontsize=30)
        plt.tight_layout()
        plt.savefig("test_att_l3h6.pdf")

        plt.figure(figsize=(hgt, wdt))
        im = plt.imshow(
            attn_maps[0][7],
            extent=[0, L, 0, T],
            aspect="auto",
            vmin=0,
            vmax=1,
            origin="lower",
        )
        plt.yticks(
            [x - 0.5 for x in list(range(1, T + 1))], list(range(T)), fontsize=30
        )
        plt.xticks(
            [x - 0.5 for x in list(range(1, L + 1))], list(range(L)), fontsize=30
        )

        plt.xlabel("l", fontsize=30)
        plt.ylabel("t", fontsize=30)
        plt.tight_layout()
        plt.savefig("test_att_l1h8.pdf")

        plt.figure(figsize=(hgt, wdt))
        im = plt.imshow(
            attn_maps[9][1],
            extent=[0, L, 0, T],
            aspect="auto",
            vmin=0,
            vmax=1,
            origin="lower",
        )
        plt.yticks(
            [x - 0.5 for x in list(range(1, T + 1))], list(range(T)), fontsize=30
        )
        plt.xticks(
            [x - 0.5 for x in list(range(1, L + 1))], list(range(L)), fontsize=30
        )

        plt.xlabel("l", fontsize=30)
        plt.ylabel("t", fontsize=30)
        plt.tight_layout()
        plt.tight_layout()
        plt.savefig("test_att_l10h2.pdf")

        # for l in range(Qs.shape[0]):
        #     plt.figure()
        #     for t in [0, 126, 256, 512, 1024]:
        #         plt.plot(np.arange(Qs.shape[-1]), Qs[l, t, 16, 0], label=f"t_{t}")
        #         plt.xlabel("Latent l")
        #         plt.ylabel("p(l|t)")
        #         plt.legend()
        #         plt.savefig(f"/home/ubuntu/latte_trans/data/coll/qs_{l}.png")

        # for layer in range(K.shape[0]):
        #     for t in [256, 512, 1024]:
        #         plt.figure(figsize=(10, 7))
        #         for lat in [0, 16, 32, 64]:
        #             k = K[layer, :t, 0, 0, lat]
        #             k = jax.nn.softmax(k)
        #             plt.plot(np.arange(t), k, label=f"t_{t}_l_{lat}")

        #             plt.xlabel("Token s")
        #             plt.ylabel("p(s|l,t)")
        #             plt.legend()
        #             plt.savefig(
        #                 f"/home/ubuntu/latte_trans/data/coll/ks_{layer}_t{t}.png"
        #             )


def main():
    """
    Script entry point.

    Derives the training and evaluation PRNG keys from a fixed seed, loads
    the task configuration named on the command line, optionally disables
    dataset caching, and then either runs the latent-collapse check (when
    `--evaluate`-style `args.evaluate` is set) or trains.
    """
    root_key = jax.random.PRNGKey(0)
    _, train_rng, sample_rng = jax.random.split(root_key, 3)

    cli = parse_args()
    config = LMTaskConfig.load(
        yaml_file=cli.config_file,
        base_dir=cli.base_dir,
        name=cli.name,
    )

    if config.disable_cache:
        LOG.info("Disabling Cache")
        disable_caching()

    task = LMTask(config)
    if not cli.evaluate:
        task.train(train_rng)
        return

    # Other evaluation entry points exist on the task (e.g.
    # `evaluate_seq_gen`); the latent-collapse check is the one wired up here.
    task.check_latens_collapse(sample_rng)

# Run `main()` only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
