# Some code based on https://github.com/epfml/landmark-attention
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, Optional, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from llama_attn_hici import (
    replace_llama_attn,
    register_hici_to_model,
)
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier

from datasets import load_dataset

# Label value meant to be excluded from the loss (-100 is the default
# `ignore_index` of torch.nn.CrossEntropyLoss).
# NOTE(review): not referenced anywhere in this file; possibly imported by
# other scripts -- confirm before removing.
IGNORE_INDEX = -100

# ============================================================================
# Custom Trainer with Layered Learning Rates
# ============================================================================

class LayeredLRTrainer(Trainer):
    """
    Trainer with an optional separate learning rate and gradient clip for HiCI parameters.

    A trainable parameter counts as a HiCI parameter when its name contains
    ``"local_constructor"`` or ``"global_integrator"``. Two extra fields on
    ``args`` control the behaviour:

    * ``hici_lr`` -- if not None, HiCI parameters get this learning rate and all
      other trainable parameters get ``args.learning_rate``. If None, the stock
      ``Trainer.create_optimizer`` is used.
    * ``hici_grad_clip`` -- if not None (and ``hici_lr`` is set), HiCI gradients
      are additionally clipped to this max norm once per optimizer step.

    Usage:
        trainer = LayeredLRTrainer(
            model=model,
            args=training_args,  # Must have hici_lr set
            ...
        )
    """

    def _is_main_process(self):
        """Return True only on the global rank-0 process (used to gate logging)."""
        return self.is_world_process_zero()

    def create_optimizer(self):
        """
        Create (once) and return the optimizer.

        If ``args.hici_lr`` is None this defers to ``Trainer.create_optimizer``.
        Otherwise trainable parameters are split into two param groups:

        1. HiCI parameters (LocalConstructor + global integrator), lr = ``args.hici_lr``
        2. every other trainable parameter, lr = ``args.learning_rate``

        Both groups use ``args.weight_decay``. The optimizer class and its
        remaining kwargs come from ``Trainer.get_optimizer_cls_and_kwargs(args)``.

        Returns:
            The optimizer (also stored in ``self.optimizer``).
        """
        if self.optimizer is not None:
            return self.optimizer

        if self.args.hici_lr is None:
            # Use standard optimizer creation (only print on rank 0)
            if self._is_main_process():
                print("\n Using uniform learning rate for all parameters")
                print(f"   Learning Rate: {self.args.learning_rate:.2e}\n")
            return super().create_optimizer()

        is_main_process = self._is_main_process()
        if is_main_process:
            print("\n" + "=" * 70)
            print("Creating Optimizer with Layered Learning Rates")
            print("=" * 70)

        local_constructor_params = []
        global_integrator_params = []
        other_params = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if "local_constructor" in name:
                local_constructor_params.append(param)
            elif "global_integrator" in name:
                global_integrator_params.append(param)
            else:
                other_params.append(param)

        # May be empty if HiCI is not enabled. Both groups are kept even when
        # empty so the optimizer state layout (two param groups) stays the same
        # as in checkpoints written by earlier runs.
        hici_params = local_constructor_params + global_integrator_params

        # NOTE(review): weight decay is applied to every parameter here,
        # including norm weights and biases, whereas the stock Trainer excludes
        # those from decay. Irrelevant when weight_decay == 0 -- confirm otherwise.
        optimizer_grouped_parameters = [
            {
                "params": hici_params,
                "lr": self.args.hici_lr,
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": other_params,
                "lr": self.args.learning_rate,
                "weight_decay": self.args.weight_decay,
            },
        ]

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
            self.args
        )
        # Every group sets its own "lr", so the global one is not needed.
        optimizer_kwargs.pop("lr", None)
        self.optimizer = optimizer_cls(
            optimizer_grouped_parameters, **optimizer_kwargs
        )

        if is_main_process:
            self._print_layered_lr_summary(
                local_constructor_params, global_integrator_params, other_params
            )
        return self.optimizer

    def _print_layered_lr_summary(
        self, local_constructor_params, global_integrator_params, other_params
    ):
        """Print per-group parameter counts and learning rates for the layered optimizer."""
        local_constructor_count = sum(p.numel() for p in local_constructor_params)
        hierarchical_count = sum(p.numel() for p in global_integrator_params)
        hici_count = local_constructor_count + hierarchical_count
        other_count = sum(p.numel() for p in other_params)
        total_count = hici_count + other_count

        def pct(count):
            # Guard against a model with no trainable parameters at all.
            return count / total_count * 100 if total_count else 0.0

        print("\n   HiCI Module Parameters Breakdown:")
        print("  " + "-" * 68)

        if local_constructor_count > 0:
            print("     LocalConstructor:")
            print(
                f"       Count: {local_constructor_count:,} ({pct(local_constructor_count):.2f}%)"
            )
        else:
            print("     LocalConstructor: Not enabled (0 parameters)")

        if hierarchical_count > 0:
            print("      HierarchicalAggregator:")
            print(
                f"       Count: {hierarchical_count:,} ({pct(hierarchical_count):.2f}%)"
            )
        else:
            print("      HierarchicalAggregator: Not enabled (0 parameters)")

        print("  " + "-" * 68)
        print(f"     Total HiCI Modules: {hici_count:,} ({pct(hici_count):.2f}%)")
        print(f"     Learning Rate: {self.args.hici_lr:.2e}")

        print("\n   Other Trainable Parameters:")
        print(f"    Count: {other_count:,} ({pct(other_count):.2f}%)")
        print(f"    Learning Rate: {self.args.learning_rate:.2e}")

        if self.args.learning_rate:
            print(
                f"\n   Learning Rate Ratio: {self.args.hici_lr / self.args.learning_rate:.1f}x"
            )
        else:
            print("\n   Learning Rate Ratio: n/a (learning_rate is 0)")
        print("=" * 70 + "\n")

    def _gradients_complete(self):
        """
        Return True when the micro-batch just processed completes gradient accumulation.

        Reads ``self.accelerator.sync_gradients``, which the accelerate-based
        Trainer loop sets per micro-batch. Without an accelerator (older
        transformers releases) every step is treated as complete, which matches
        the previous clip-every-step behaviour.
        """
        accelerator = getattr(self, "accelerator", None)
        if accelerator is None:
            return True
        return bool(getattr(accelerator, "sync_gradients", True))

    def training_step(self, model, inputs, *args, **kwargs):
        """
        Run the parent training step, then optionally clip HiCI gradients separately.

        Extra positional/keyword arguments (e.g. ``num_items_in_batch`` passed by
        newer transformers releases) are forwarded unchanged to
        ``Trainer.training_step``.

        When ``args.hici_grad_clip`` and ``args.hici_lr`` are both set, the
        gradients of HiCI parameters are clipped to ``args.hici_grad_clip`` on
        the micro-batch that completes gradient accumulation, i.e. on the final
        accumulated gradient rather than on partial sums. The stock Trainer loop
        additionally applies its global ``args.max_grad_norm`` clip (when > 0)
        to all parameters.

        NOTE(review): with fp16 + GradScaler the gradients here would still be
        loss-scaled; this script loads bf16 weights -- confirm before using fp16.
        Under DeepSpeed ZeRO ``param.grad`` may be None, making this a no-op.

        Returns:
            The loss returned by ``Trainer.training_step``.
        """
        loss = super().training_step(model, inputs, *args, **kwargs)

        if self.args.hici_grad_clip is None or self.args.hici_lr is None:
            return loss

        if not self._gradients_complete():
            # Mid-accumulation: gradients are partial; clip only the final sum.
            return loss

        hici_params = []
        other_params = []
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if "local_constructor" in name or "global_integrator" in name:
                hici_params.append(param)
            else:
                other_params.append(param)

        if hici_params:
            torch.nn.utils.clip_grad_norm_(
                hici_params, max_norm=self.args.hici_grad_clip
            )

        # Print the clipping configuration only once (on rank 0).
        if not getattr(self, "_grad_clip_printed", False):
            if self._is_main_process():
                self._print_grad_clip_config(len(hici_params), len(other_params))
            self._grad_clip_printed = True

        return loss

    def _print_grad_clip_config(self, num_hici, num_other):
        """Print how HiCI and non-HiCI gradients are clipped (counts are tensors, not elements)."""
        max_grad_norm = self.args.max_grad_norm
        if max_grad_norm is not None and max_grad_norm > 0:
            other_desc = f"{max_grad_norm} (Trainer global clip over all parameters)"
        else:
            other_desc = "None (no clipping)"

        print("\n" + "=" * 70)
        print("Gradient Clipping Configuration")
        print("=" * 70)
        print("   HiCI Modules:")
        print(f"     Max Gradient Norm: {self.args.hici_grad_clip}")
        print(f"     Num Parameters: {num_hici}")
        print("\n   Other Parameters:")
        print(f"     Max Gradient Norm: {other_desc}")
        print(f"     Num Parameters: {num_other}")
        print("=" * 70 + "\n")

# Fallback special tokens: train() adds whichever of pad/eos/bos/unk the loaded
# tokenizer does not already define, via smart_tokenizer_and_embedding_resize.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

@dataclass
class ModelArguments:
    """Model-selection arguments parsed from the command line by ``HfArgumentParser``."""

    # NOTE: the default checkpoint (Pythia) is a GPT-NeoX model while the default
    # model_type is "llama"; pass both flags explicitly so they agree.
    model_name_or_path: Optional[str] = field(
        default="EleutherAI/pythia-1.4b-deduped",
        metadata={"help": "Hugging Face model id or local path of the base checkpoint."},
    )
    model_type: Optional[str] = field(
        default="llama",
        metadata={
            "help": "Architecture family of the checkpoint: 'llama' or 'gpt-neox'. "
            "Selects which attention patch is applied; must match model_name_or_path."
        },
    )

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """
    Hugging Face ``TrainingArguments`` extended with long-context and HiCI options.

    The extra fields are parsed from the command line by ``HfArgumentParser``
    and read by ``train()`` and ``LayeredLRTrainer``.
    """

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=8192 * 4,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "Whether use flash attention for training."},
    )
    use_full_attn: bool = field(
        default=False,
        metadata={"help": "Whether to use plain, full-attention for training."},
    )
    low_rank_training: bool = field(
        default=True,
        metadata={"help": "Whether use low rank adaptation for training."},
    )
    use_yarn_rope: bool = field(
        default=False,
        metadata={"help": "Whether to use YaRN (Yet another RoPE extensioN) instead of PI. No Transformers upgrade needed."},
    )
    trainable_params: str = field(
        default="embed,norm",
        metadata={
            "help": "Comma-separated name substrings; with low rank training, parameters whose "
            "names contain any of them are also trained (e.g. embed,norm,local_constructor,global_integrator)."
        },
    )
    num_local_slots: int = field(
        default=8,
        metadata={
            "help": "Number of local query slots for capturing chunk-level context (default: 8)."
        },
    )
    global_slots: int = field(
        default=16,
        metadata={
            "help": "Number of HiCI slots for capturing document-level context (default: 16)."
        },
    )
    num_chunks: int = field(
        default=4,
        metadata={
            "help": "Number of chunks to split each sequence into for HiCI."
        },
    )
    use_local_summary: bool = field(
        default=True,
        metadata={"help": "Whether to use local representation extraction."},
    )
    use_global_repr: bool = field(
        default=True,
        metadata={"help": "Whether to use HiCI aggregator."},
    )
    use_flash_attn_in_hici: bool = field(
        default=False,
        metadata={"help": "Whether to use flash attn in LocalConstructorFlash."},
    )
    use_flash_plus: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use LocalConstructorFlashPlus."},
    )
    use_flash_plus_norope: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the no-RoPE variant of the optimized 'plus' attention "
            "forward (passed to replace_llama_attn as use_optimized_plus_norope)."
        },
    )
    forward_flashattn_optimized: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the optimized flash-attention forward "
            "(passed to replace_llama_attn as use_optimized)."
        },
    )
    use_hierarchical_forward: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the hierarchical attention forward "
            "(passed to replace_llama_attn as use_hierarchical_forward)."
        },
    )
    use_llama_init: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to warm-initialize the HiCI Q/K/V projections "
            "(passed to register_hici_to_model as use_llama_init)."
        },
    )
    num_heads: int = field(
        default=32,
        metadata={"help": "Number of attention heads in the HiCI module."},
    )
    use_bottleneck: bool = field(
        default=True,
        metadata={
            "help": "Whether to use bottleneck in HiCI aggregator."
        },
    )
    bottleneck_dim: int = field(
        default=4096,
        metadata={"help": "Bottleneck dimension for representation compression."},
    )
    # NOTE(review): not read anywhere in this file; confirm it is consumed elsewhere.
    recurrence_size: Optional[int] = field(
        default=128,
        metadata={
            "help": "Number of tokens to carry from previous chunk (Transformer-XL style, default: 128)."
        },
    )
    hici_lr: Optional[float] = field(
        default=None,
        metadata={
            "help": "Separate learning rate for HiCI parameters. "
            "If None, uses the same learning rate as other parameters. "
            "Recommended: 2e-4 to 5e-4 (10-25x base lr)."
        },
    )
    hici_grad_clip: Optional[float] = field(
        default=None,
        metadata={
            "help": "Separate gradient clipping for HiCI module parameters "
            "(only applied when hici_lr is also set). "
            "If None, uses the same max_grad_norm as other parameters. "
            "Recommended: 0.1 to 0.3 (stricter than default 1.0). "
            "This helps prevent gradient explosion in HiCI modules."
        },
    )

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Register special tokens, grow the embedding tables, and seed the new rows.

    ``special_tokens_dict`` is handed to ``tokenizer.add_special_tokens`` and the
    model's token embeddings are resized to ``len(tokenizer)``. If any tokens were
    actually added, the last ``added`` rows of both the input and the output
    embedding matrix are overwritten with the mean of all rows before them.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    embedding_modules = (model.get_input_embeddings(), model.get_output_embeddings())
    for module in embedding_modules:
        table = module.weight.data
        # Only the trailing `added` rows are written, so the mean over the
        # leading rows is unaffected even if both modules share one tensor.
        table[-added:] = table[:-added].mean(dim=0, keepdim=True)

def tokenize_fn(tokenizer, example, max_chars_per_segment=200_000):
    """
    Tokenize a batch of raw documents and pack them into fixed-length chunks.

    Meant for ``datasets.Dataset.map(..., batched=True)``. Documents are
    accumulated until adding the next one would push the running character
    count (separators not counted) past ``max_chars_per_segment``. The
    accumulated group is then joined with ``tokenizer.eos_token``, tokenized,
    and cut into chunks of ``tokenizer.model_max_length`` tokens by
    ``_tokenize_and_chunk``, which pads the last chunk of each group.

    A single document longer than ``max_chars_per_segment`` characters first
    flushes the pending group. It is then split into character segments of
    that size, and each segment is tokenized and chunked on its own. Splits
    fall at arbitrary character offsets, possibly mid-word.

    Args:
        tokenizer: tokenizer providing ``model_max_length`` and ``eos_token``.
        example: batch dict with a ``"text"`` list of strings.
        max_chars_per_segment: character budget per tokenizer call
            (default 200_000, the previously hard-coded value).

    Returns:
        ``{"input_ids": chunks}``, where each chunk is a list of
        ``model_max_length`` token ids. The number of chunks generally differs
        from the number of input documents.

    Raises:
        ValueError: if ``max_chars_per_segment`` is not positive.
    """
    if max_chars_per_segment <= 0:
        raise ValueError(
            f"max_chars_per_segment must be positive, got {max_chars_per_segment}"
        )

    context_length = tokenizer.model_max_length
    all_chunks = []
    pending = []
    pending_chars = 0

    def flush_pending():
        # Tokenize the accumulated documents as one EOS-joined string, then reset.
        nonlocal pending_chars
        if pending:
            combined = tokenizer.eos_token.join(pending)
            all_chunks.extend(_tokenize_and_chunk(tokenizer, combined, context_length))
            pending.clear()
        pending_chars = 0

    for text in example["text"]:
        text_len = len(text)
        if text_len > max_chars_per_segment:
            flush_pending()
            for start in range(0, text_len, max_chars_per_segment):
                segment = text[start:start + max_chars_per_segment]
                all_chunks.extend(
                    _tokenize_and_chunk(tokenizer, segment, context_length)
                )
        else:
            if pending and pending_chars + text_len > max_chars_per_segment:
                flush_pending()
            pending.append(text)
            pending_chars += text_len

    flush_pending()
    return {"input_ids": all_chunks}

def _tokenize_and_chunk(tokenizer, text, context_length):
    """

    Returns:
        List[List[int]]: chunks of input_ids
    """
    outputs = tokenizer(
        text,
        truncation=False,
        return_tensors=None,  # Python list
        padding=False,
    )

    input_ids = outputs["input_ids"]

    total_length = len(input_ids)
    if total_length % context_length != 0:
        padding_length = context_length - (total_length % context_length)
        pad_token_id = (
            tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        )
        input_ids = input_ids + [pad_token_id] * padding_length

    num_chunks = len(input_ids) // context_length
    chunks = []
    for i in range(num_chunks):
        start = i * context_length
        end = start + context_length
        chunk = input_ids[start:end]
        chunks.append(chunk)

    return chunks

def _configure_rope_scaling(config, model_max_length):
    """
    Set linear RoPE scaling on ``config`` when ``model_max_length`` exceeds the native context.

    The native context is ``config.max_position_embeddings`` multiplied by any
    RoPE ``factor`` already present in ``config.rope_scaling``. When the target
    length is longer, ``config.rope_scaling`` is overwritten with
    ``{"type": "linear", "factor": ceil(model_max_length / native)}``.

    Returns:
        ``(orig_ctx_len, scaling_factor)``: the native context length (falsy if
        the config has no ``max_position_embeddings``) and the new float
        scaling factor, or ``None`` when no extension was applied.
    """
    orig_rope_scaling = getattr(config, "rope_scaling", None)
    if orig_rope_scaling is None:
        orig_rope_scaling = {"factor": 1}
    orig_rope_scaling_factor = (
        orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
    )

    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    scaling_factor = None
    if orig_ctx_len:
        orig_ctx_len *= orig_rope_scaling_factor
        if model_max_length > orig_ctx_len:
            scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
            config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return orig_ctx_len, scaling_factor


def _load_and_tokenize_dataset(
    tokenizer, dataset_path="/path/to/cache/datasets", num_proc=128
):
    """
    Load the raw-text dataset from disk, tokenize it in length buckets, and shuffle it.

    Documents are bucketed by character length (<20K, 20K-100K, 100K-300K,
    >=300K). Each bucket is mapped through ``tokenize_fn`` with a smaller
    ``batch_size`` for longer documents. The tokenized ``train`` splits are
    then concatenated and shuffled with seed 42.

    Assumes a ``DatasetDict`` with a ``train`` split whose rows have ``text``
    and ``meta`` columns; both columns are removed after tokenization.

    Returns:
        ``{"train": Dataset}`` of ``input_ids`` chunks.
    """
    from datasets import concatenate_datasets, load_from_disk

    # TODO: the default path is a placeholder; point it at the real dataset.
    dataset = load_from_disk(dataset_path)

    print("=" * 70)
    print("（token）")
    print("=" * 70)

    print("\n 1: token...")
    very_short_docs = dataset.filter(
        lambda x: len(x["text"]) < 20_000, num_proc=num_proc
    )
    short_docs = dataset.filter(
        lambda x: 20_000 <= len(x["text"]) < 100_000, num_proc=num_proc
    )
    medium_docs = dataset.filter(
        lambda x: 100_000 <= len(x["text"]) < 300_000, num_proc=num_proc
    )
    long_docs = dataset.filter(
        lambda x: len(x["text"]) >= 300_000, num_proc=num_proc
    )  # >90K tokens, batch=1
    print(f" (<20K, <6K tokens): {len(very_short_docs['train']):,} ")
    print(f" (20K-100K, 6K-30K tokens): {len(short_docs['train']):,} ")
    print(f" (100K-300K, 30K-90K tokens): {len(medium_docs['train']):,} ")
    print(f" (>=300K, >90K tokens): {len(long_docs['train']):,} ")

    print("\n 2: batch_size ...")
    tokenize = partial(tokenize_fn, tokenizer)
    processed_splits = []
    for docs, batch_size in (
        (very_short_docs, 200),
        (short_docs, 40),
        (medium_docs, 5),
        (long_docs, 1),
    ):
        processed = docs.map(
            tokenize,
            batched=True,
            batch_size=batch_size,
            num_proc=num_proc,
            remove_columns=["text", "meta"],
        )
        processed_splits.append(processed["train"])

    print("\n 3: ...")
    train_split = concatenate_datasets(processed_splits).shuffle(seed=42)
    dataset = {"train": train_split}
    print(f" : {len(dataset['train']):,}")
    print("=" * 70)
    return dataset


def _summarize_trainable_params(model, training_args, is_main):
    """
    Print a per-category breakdown of trainable parameters and HiCI setup warnings.

    Categories are assigned by parameter-name substring, first match wins:
    "lora" -> LoRA Adapters, "local_constructor"/"global_integrator" -> HiCI
    Modules, "embed" -> Embeddings, "norm" -> LayerNorm, else Other. Nothing is
    printed unless ``is_main``.
    """
    if is_main:
        print("\n" + "=" * 70)
        print("Trainable Parameters Summary")
        print("=" * 70)

    trainable_params_dict = {}
    local_constructor_count = 0
    hierarchical_count = 0
    total_params = 0
    hici_registered = False  # any HiCI parameter exists, trainable or frozen
    for n, p in model.named_parameters():
        numel = p.numel()
        total_params += numel
        is_hici = "local_constructor" in n or "global_integrator" in n
        hici_registered = hici_registered or is_hici
        if not p.requires_grad:
            continue

        lowered = n.lower()
        if "lora" in lowered:
            category = "LoRA Adapters"
        elif is_hici:
            category = "HiCI Modules"
        elif "embed" in lowered:
            category = "Embeddings"
        elif "norm" in lowered:
            category = "LayerNorm"
        else:
            category = "Other"
        trainable_params_dict[category] = trainable_params_dict.get(category, 0) + numel

        # The HiCI breakdown is independent of the category above, so a LoRA
        # adapter inside a HiCI module is counted both here and as LoRA.
        if "local_constructor" in n:
            local_constructor_count += numel
        elif "global_integrator" in n:
            hierarchical_count += numel

    total_trainable = sum(trainable_params_dict.values())

    if is_main:
        for category, count in sorted(trainable_params_dict.items()):
            share = count / total_trainable * 100 if total_trainable else 0.0
            print(f"  {category:20s}: {count:15,} params ({share:5.2f}%)")

        if "HiCI Modules" in trainable_params_dict and (
            local_constructor_count > 0 or hierarchical_count > 0
        ):
            print(f"    {' LocalConstructor':20s}: {local_constructor_count:15,} params")
            print(f"    {' HierarchicalAgg':20s}: {hierarchical_count:15,} params")

        print(f"  {'-' * 20}   {'-' * 15}   {'-' * 7}")
        overall = total_trainable / total_params * 100 if total_params else 0.0
        print(
            f"  {'Total Trainable':20s}: {total_trainable:15,} params ({overall:5.2f}% of total)"
        )
        print(f"  {'Total Params':20s}: {total_params:15,} params")

    trainable_spec = training_args.trainable_params
    has_hici_in_trainable = "local_constructor" in trainable_spec
    # HiCI aggregator parameters are named "global_integrator"; the older
    # "hierarchical" token is still accepted for the check.
    has_hierarchical_in_trainable = (
        "global_integrator" in trainable_spec or "hierarchical" in trainable_spec
    )
    has_hici_params = "HiCI Modules" in trainable_params_dict

    if is_main:
        if has_hici_in_trainable and not has_hici_params:
            print(
                "\n  WARNING: 'local_constructor' specified in --trainable_params but no HiCI parameters found!"
            )
        elif (
            training_args.low_rank_training
            and hici_registered
            and not has_hici_in_trainable
        ):
            # With LoRA, parameters not matched by --trainable_params stay frozen.
            print(
                "\n  WARNING: HiCI module parameters found but not in --trainable_params!"
            )
            print(
                "    Add '--trainable_params \"embed,norm,local_constructor,global_integrator\"' to enable training."
            )

        if training_args.use_global_repr:
            if has_hici_in_trainable and not has_hierarchical_in_trainable:
                print(
                    "\n  WARNING: Using HiCI but 'global_integrator' not in --trainable_params!"
                )
                print("    HierarchicalAggregator parameters may not be trained!")
                print(
                    "    Recommended: '--trainable_params \"embed,norm,local_constructor,global_integrator\"'"
                )

        print("=" * 70 + "\n")


def _print_checkpoint_info(training_args):
    """
    Print which ``checkpoint-<step>`` directories already exist in ``output_dir``.

    Informational only: ``train()`` resumes only when ``--resume_from_checkpoint``
    is given. Entries whose suffix is not a step number are ignored.
    """
    import glob

    def step_of(path):
        return os.path.basename(path).rsplit("-", 1)[-1]

    checkpoint_dirs = [
        path
        for path in glob.glob(os.path.join(training_args.output_dir, "checkpoint-*"))
        if step_of(path).isdecimal()
    ]
    if checkpoint_dirs:
        latest_checkpoint = max(checkpoint_dirs, key=lambda p: int(step_of(p)))

        print("\n" + "=" * 80)
        print(" Checkpoint")
        print("=" * 80)
        print(f" Output : {training_args.output_dir}")
        print(f" {len(checkpoint_dirs)} checkpoints")
        print(f" HuggingFace Trainer checkpoint :")
        print(f"   → {latest_checkpoint}")
        print()
        print(f"   - HiCI local query slots: {training_args.num_local_slots}")
        print("=" * 80 + "\n")
    else:
        print("\n" + "=" * 80)
        print(" （ checkpoints）")
        print("=" * 80)
        print(f" Output : {training_args.output_dir}")
        print(f"   - HiCI local query slots: {training_args.num_local_slots}")
        print("=" * 80 + "\n")


def train():
    """
    Command-line entry point: fine-tune a causal LM with HiCI modules.

    Steps:
    1. Parse ``ModelArguments`` / ``TrainingArguments``, seed, and patch the
       attention implementation.
    2. Extend RoPE (linear PI, optionally YaRN) when ``model_max_length``
       exceeds the native context.
    3. Load the model and tokenizer, add missing special tokens, and register
       the HiCI modules.
    4. Tokenize the on-disk dataset; rank 0 goes first while other ranks wait
       at a barrier.
    5. Optionally wrap the model with LoRA and print a trainable-parameter
       summary.
    6. Train with ``LayeredLRTrainer`` and save state and model to
       ``output_dir``.
    """
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Seed before model setup, so any random initialisation done while building
    # the model / registering HiCI below is reproducible. The Trainer only
    # seeds when it is constructed, at the very end.
    transformers.set_seed(training_args.seed)

    # NOTE: May expand supported model types in the future
    if model_args.model_type == "gpt-neox":
        replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
    elif model_args.model_type == "llama":
        replace_llama_attn(
            use_flash_attn=training_args.use_flash_attn,
            use_full=training_args.use_full_attn,
            use_optimized=training_args.forward_flashattn_optimized,
            use_optimized_plus=training_args.use_flash_plus,
            use_optimized_plus_norope=training_args.use_flash_plus_norope,
            use_hierarchical_forward=training_args.use_hierarchical_forward,
        )
    else:
        raise ValueError("Only support llama and gpt-neox for now")

    # Set RoPE scaling factor
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    orig_ctx_len, scaling_factor = _configure_rope_scaling(
        config, training_args.model_max_length
    )

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch.bfloat16,
    )

    # Replace RoPE with YaRN if requested (only when the context is extended).
    if training_args.use_yarn_rope and scaling_factor is not None:
        from yarn_rope_official import replace_rope_with_yarn

        print("\n" + "=" * 80)
        print(" Replacing Position Interpolation (PI) with YaRN")
        print("=" * 80)
        print(f"  Original context length: {orig_ctx_len}")
        print(f"  Target context length:   {training_args.model_max_length}")
        print(f"  Scaling factor:          {scaling_factor:.1f}")
        print("  YaRN advantages:")
        print("    - Better perplexity (~3% improvement over PI)")
        print("    - Higher passkey accuracy (~6% improvement)")
        print("    - Faster convergence (75% fewer tokens)")
        print("=" * 80 + "\n")

        model = replace_rope_with_yarn(
            model,
            scaling_factor=scaling_factor,
            original_max_length=orig_ctx_len,
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )

    special_tokens_dict = {}
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    # Register HiCI modules (CRITICAL: before optimizer initialization!)
    register_hici_to_model(
        model,
        num_local_slots=training_args.num_local_slots,
        global_slots=training_args.global_slots,
        num_chunks=training_args.num_chunks,
        num_heads=training_args.num_heads,
        use_bottleneck=training_args.use_bottleneck,
        bottleneck_dim=training_args.bottleneck_dim,
        use_local_summary=training_args.use_local_summary,
        use_hierarchical=training_args.use_global_repr,
        use_flash_plus=training_args.use_flash_plus,
        use_flash=training_args.use_flash_attn_in_hici,
        use_llama_init=training_args.use_llama_init,
    )
    print("=" * 70 + "\n")

    # RANK is -1 when not launched by a distributed launcher.
    rank = int(os.environ.get("RANK", -1))
    is_main = rank <= 0

    # "Main process first": rank 0 builds the processed dataset while the other
    # ranks wait; they then presumably reuse the datasets cache rank 0 wrote.
    if rank > 0:
        barrier()
    dataset = _load_and_tokenize_dataset(tokenizer)
    if rank == 0:
        barrier()

    print(dataset)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Setup LoRA (if enabled). HiCI parameters are already in model.parameters().
    if training_args.low_rank_training:
        if model_args.model_type == "gpt-neox":
            # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
            targets = ["query_key_value", "dense"]
        else:
            targets = ["q_proj", "k_proj", "v_proj", "o_proj"]

        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=targets,
            lora_dropout=0,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        # Re-enable gradients for every parameter whose name contains one of
        # the --trainable_params substrings. Blank entries are skipped because
        # "" is a substring of every name and would unfreeze the whole model.
        extra_keys = [
            key.strip() for key in training_args.trainable_params.split(",") if key.strip()
        ]
        for n, p in model.named_parameters():
            if any(key in n for key in extra_keys):
                p.requires_grad_()

    _summarize_trainable_params(model, training_args, is_main)

    model.config.use_cache = False  # required for gradient checkpointing
    model.enable_input_require_grads()  # required for gradient checkpointing
    model.gradient_checkpointing_enable()  # enable gradient checkpointing

    if is_main:
        _print_checkpoint_info(training_args)

    # Initialize the Trainer (the optimizer is created here). At this point
    # model.parameters() includes base weights (frozen under LoRA), LoRA
    # adapters, HiCI parameters, and any embed/norm weights matched by
    # --trainable_params. LayeredLRTrainer supports a separate HiCI learning rate.
    trainer = LayeredLRTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=None,
        data_collator=data_collator,
    )
    # Trainer does not read args.resume_from_checkpoint by itself; pass it on
    # (None -> train from scratch, as before).
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()
