import os
import json
import torch

from dataclasses import dataclass, field
from contextlib import nullcontext
from typing import Union, Optional, Any, Literal

from litgpt.config import Config
from litgpt.config_dynamic import Config as DynamicConfig

from transformers import AutoModelForCausalLM, AutoConfig


@dataclass
class HuggingfaceConfig:
    """Minimal model config for building Hugging Face ``transformers`` causal LMs.

    Only records what is needed to instantiate an architecture through
    ``AutoConfig`` / ``AutoModelForCausalLM``.
    TODO: properly merge with the litgpt config path one day.
    """

    name: str  # Model id or path; used as the config source when ``checkpoint`` is not set.
    checkpoint: Optional[str]  # Optional config source that takes precedence over ``name``.
    block_size: Optional[int] = None
    strategy: Optional[str] = None  # Fabric strategy; "axonn_tp" wraps construction in AxoNN's parallelize().

    @property
    def Block(self):
        """Return the decoder-layer class for this architecture.

        Raises:
            ValueError: if the architecture (inferred from ``name``) is not a llama variant.
        """
        if "llama" in self.name.lower():
            from transformers.models.llama.modeling_llama import LlamaDecoderLayer

            return LlamaDecoderLayer
        raise ValueError("Provide the block name for this architecture.")

    def construct_model(self, objective, gradient_checkpointing: bool) -> torch.nn.Module:
        """Instantiate a model from the architecture config of ``checkpoint`` (or ``name``).

        ``from_config`` builds the architecture only; no pretrained weights are loaded.

        Args:
            objective: Not used by this implementation.
            gradient_checkpointing: Enable activation checkpointing on the built model.

        Returns:
            The constructed ``AutoModelForCausalLM`` instance.
        """
        source = self.checkpoint or self.name
        if self.strategy == "axonn_tp":
            # Imported lazily so AxoNN is only required when its tensor parallelism is requested.
            from axonn.models.transformers import parallelize

            context = parallelize(source)
        else:
            context = nullcontext()

        with context:
            model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(source))

        if gradient_checkpointing:
            # transformers' PreTrainedModel API is gradient_checkpointing_enable();
            # there is no enable_gradient_checkpointing() method.
            model.gradient_checkpointing_enable()
        return model


@dataclass
class DataEntry:
    type: str
    prefix: str
    weight: int = 1.0
    data_signature: Optional[dict[str, list[str] | str]] = None
    name: Optional[str] = None
    data_dir: Optional[str] = None
    text_key: Optional[str] = None
    repetitions: Optional[int] = None
    max_epoch: Optional[int] = None
    scheduler: Optional[tuple[str, int]] = None
    return_data_id: Optional[bool] = None


@dataclass
class FabricConfig:
    """Options forwarded to the Fabric / distributed-strategy layer.

    Every field has a default, so a partially specified dict can be expanded
    with ``FabricConfig(**partial)`` and the remaining fields keep these values.
    """

    # --- communication ---
    optimize_communication: Optional[bool] = False
    all_reduce_dtype: Optional[str] = None  # dtype used for all-reduce communication
    # --- tensor-parallel grid (any size > 1 requires the axonn_tp strategy) ---
    row_tensor_parallel_size: Optional[int] = 1
    col_tensor_parallel_size: Optional[int] = 1
    depth_tensor_parallel_size: Optional[int] = 1
    # --- optimizer ---
    optim_sharding: Optional[bool] = False  # ZeRO-1 style optimizer-state sharding
    allow_optim_fusion: Optional[bool] = False
    use_apex_adamw: Optional[bool] = False


@dataclass
class CLISettings:
    """Command-line settings for a training run.

    ``__post_init__`` validates the combination of options, loads and normalizes
    ``data_config``, reads SLURM / torch environment information, derives the
    batch-size bookkeeping attributes, and builds ``model_config``.
    """

    # Main settings
    run_name: str = "default-run"  # The name for logging.
    out_dir: str = None  # type: ignore # The checkpoint directory. Required: pass it or set the OUTPUT_DIR env var.
    resume: bool = True  # Whether to resume from a checkpoint in the out_dir.
    max_tokens: Optional[Union[int, float]] = None  # Currently unsupported: must stay None (use max_steps).
    max_steps: Optional[int] = None  # The number of training steps. Required and must be > 0.
    seed: int = 1337  # The random seed to use for reproducibility.

    # Model configuration
    model_name: str = "tiny-llama-1.1b"  # Name passed to Config.from_name to build the model config.
    model_impl: str = "retrieval"  # Model implementation; "litgpt" and "retrieval" are supported.
    block_size: int = 2048  # The block size to use (lit-gpt-ese for sequence length).
    ignore_block_size_mismatch: bool = False  # Whether to ignore block size mismatch.
    model_checkpoint: Optional[str] = None  # The model checkpoint to load. Else, from config.

    attn_impl: Literal["sdpa", "rocm"] = "sdpa"  # The attention implementation to use.
    structured_init: bool = False  # Whether to use layer structured initialization for the model.
    structured_init_for_wte: bool = False  # Whether to use structured initialization for the input embedding layer.
    structured_init_olmo_variant: bool = False  # Whether to use olmo style structured initialization.

    # Training hyperparameters
    world_batch_size: int = 2048  # The total batch size across all devices and nodes.
    batch_size_ramp: int = 0  # Over how many mbs steps to linearly increase the batch size to world_batch_size
    optimizer: str = "AdamW"  # The optimizer name.
    optim_config: dict[str, Any] = field(
        default_factory=lambda: dict(
            lr=0.0004,  # The learning rate.
            weight_decay=0.1,  # The weight decay.
            betas=(0.9, 0.95),  # The beta parameters for the Adam optimizer.
            eps=1e-8,  # The eps parameter for the Adam optimizer
        )
    )
    grad_clip: float = 1.0  # The gradient clipping value.
    warmup_steps: int = 0  # The number of warmup steps.
    cooldown_steps: int = 0  # The number of cooldown steps.
    lr_schedule: str = "cosine"  # The learning rate schedule to use.
    min_lr: float = 0.00004  # The minimum learning rate to decay to.
    no_weight_decay_for_bias_and_norm_params: bool = False  # do not use weight decay for bias and norm params
    lr_scaler: Optional[str] = None  # The learning rate scaling strategy to use. "inverse_n_embd"

    # Objective and Regularization
    z_regularization: float = 0.0
    # The target range of ids to use when computing a cls loss using special tokens for the training data.
    target_range_train: Optional[list[int]] = None
    target_range_val: Optional[list[int]] = None  # ...for the val data
    freeze_params: Optional[list[str]] = None  # List of parameter names to freeze (no gradients).

    # Implementation and backend
    fabric_strategy: str = "ddp"  # The fabric strategy to use: ddp, fsdp, axonn_tp.
    fabric_precision: Literal["bf16-true", "bf16-mixed", "16-mixed", "16", "32"] = "bf16-mixed"
    fabric_use_lightning_environment: bool = False  # If False, use the auto setting, True, use LightningEnvironment.
    fabric: FabricConfig = field(
        default_factory=lambda: FabricConfig(
            optimize_communication=False,  # Whether to optimize communication.
            all_reduce_dtype=None,  # The dtype to use for all-reduce communication.
            row_tensor_parallel_size=1,  # The size of the row tensor parallel dimension
            col_tensor_parallel_size=1,  # The size of the col tensor parallel dimension
            depth_tensor_parallel_size=1,  # The size of the depth tensor parallel dimension
            optim_sharding=False,  # zero-1, activated directly in pytorch. May not play nicely with non-ddp
            allow_optim_fusion=False,  # fishes for fusion opportunities in the optimizer
        )
    )
    micro_batch_size: int = 4  # The micro batch size to use.
    compile_model: bool = False  # Whether to compile the model.
    # Whether to compile in max-autotune-no-cudagraphs mode; setting it also forces compile_model=True.
    compile_model_max_autotune_no_cudagraphs: bool = False
    dynamo_ddp_config: Optional[Literal["ddp_optimizer", "python_reducer", "no_optimization"]] = None
    matmul_precision: str = "high"  # enable tf32 acc on cuda with this
    dataloader_num_workers: int = 0  # The number of workers to use for the dataloaders.
    n_chunks: int = 4  # The number of chunks to preload at a time from packed dataset.
    gradient_checkpointing: bool = False  # Whether to use activation checkpointing
    allow_nonfinite_loss: bool = False  # whether to end training immediately if non-finite loss is encountered
    use_liger_ce: bool = False  # Whether to use Liger Kernel's custom cross-entropy loss.

    # Logging
    logger_name: str = "wandb"  # The logger to use for logging, only supports "wandb" for now.
    logger_project: str = "tinyllama"  # The logger/wandb project to log to.
    wandb_tags: list[str] = field(default_factory=list)  # The tags to add to the wandb run.
    data_telemetry: Optional[Union[int, bool]] = None  # Data telemetry switch, set based on needs, marking step to stop at.
    # Whether to sample in lockstep intra device or across all devices, else freely evolving.
    lockstep_sampling: Optional[Literal["micro_batch", "world_batch"]] = None
    data_verbose: bool = True  # Dataset verbose switch, primarily for debugging pqds-pure.
    data_dry_run: bool = False  # Dataset dry run switch, doesn't take training steps, just rolls through the data.
    # Whether to monitor important model values to look for spikes. May increase overhead.
    # Induces compile warnings, ok/FIXME?
    model_telemetry: bool = False
    shape_watching_iters: int = 3  # Number of iterations to watch shapes for. Set to 0 to disable.
    log_rank_zero_only: bool = False  # Whether to log iters only from rank 0.
    log_step_interval: int = 1  # The base interval for logging (scales with gradient_accumulation_steps).
    eval_iters: int = 100  # The number of iterations to process during a validation loop.
    save_step_interval: int = 2000  # The number of iterations between saving.
    eval_step_interval: int = 2000  # The number of iterations between evaluating.
    save_first_step: bool = False  # Whether to save the checkpoint at the first step
    save_last_step: bool = False  # Whether to save the checkpoint at the last step
    save_n_min_before_job_done: Optional[int] = None  # Save the checkpoint n minutes before current job done
    sanity_validate: bool = False  # Whether to run a short sanity check validation loop at the start.
    measure_utilization: bool = False  # Print FLOPs and MFU. Flaky on XXXX-26, so defaulting to False, FIXME?
    estimate_param_count: bool = False  # Estimate the number of parameters in the model using a function.
    track_memory: bool = False  # Track memory usage during training.
    track_memory_finegrained: bool = False  # Track memory usage during training with fine-grained tracking.
    simple_gptneox_tflops: bool = True  # Use a simple GPT-NeoX flops calculation. Standin on XXXX-26.
    peak_tflops_per_device: float = 192.0  # The peak TFLOPS per device for the GPUs. default is XXXX-26's MI250X
    derive_cost_basis: bool = False  # Derive the cost basis for run on this topology.
    # The target token count for the cost basis (as opposed to max_step based).
    target_token_count: Optional[Union[float, int]] = None
    cards_per_node: int = 8  # The number of cards per node.
    validate_only: bool = False  # Whether to only run validation.
    initial_validate: bool = False  # Whether to run a validation loop when training starts.
    validate_at_end: bool = True  # Whether to run a validation loop at the end.
    stability_step: Optional[int] = None  # The step at which we log "stable run"

    # Retrieval args
    finetune_checkpoint: Optional[str] = None  # The model checkpoint to load for finetuning.
    pretrained_prefix_model: bool = False  # Whether to use a pretrained checkpoint.
    pretrained_suffix_model: bool = False  # Whether to use a pretrained checkpoint.
    mean_pooling: bool = False  # Whether to use mean pooling for the final hidden states.
    fixed_length: bool = False  # Whether to use fixed length sequences.
    attn_type: str = "causal_attn"  # The attention type to use.
    # The loss type to use. Choose from ["sequence_negative", "cross_batch_negative"].
    loss_type: str = "cross_batch_negative"
    # The maximum sequence length to use for hfds only. Will use this param to truncate seqs.
    max_seq_len: Optional[int] = None
    negatives_cross_device: bool = False  # Whether to use negatives from other devices.
    # The number of devices to group together for cross device negative gathering. If None, use all devices.
    negatives_cross_device_group_size: Optional[int] = None
    # Number of neighboring lower diagonals to mask out in scores matrix (prefix x suffix -> n x n)
    mask_k_ldiags: Optional[int] = None
    # Number of neighboring upper diagonals to mask out in scores matrix (prefix x suffix -> n x n)
    mask_k_udiags: Optional[int] = None
    pick_k: Optional[int] = None  # drop k% indices of prefix suffix pairs
    gen_loss: bool = False  # Whether to use generative loss (next token prediction)
    alpha: float = 0.5  # The weight value to use for the multi-loss setting (emb loss and gen loss)
    n_gram: Optional[int] = None  # The n-gram rows to remove from the score matrix (prefix x suffix)
    # The number of cross device negatives to keep. (a little bit of math is needed to figure out this number)
    keep_k_cross_device_negatives: Optional[int] = None
    # Use neighboring k upper diagonals as labels to compute loss k times (a hacky way to approximate k_pos_labels)
    compute_k_loss: Optional[int] = None
    # Randomly samples a suffix within k distance as positive label
    # (another hacky way to approximate k_pos_labels while using the fast CE kernel)
    k_random_pos_labels: Optional[int] = None
    # Whether to mask the full lower diagonal in the scores matrix. (We think lower diag negatives are bad negatives)
    mask_full_ldiag: bool = False
    length_shortcut_ablation: Optional[
        Literal[
            "permute_batch_tokens",
            "rand_toks_const_lens",
            "rand_toks_doc_lens",
            "rand_toks_rand_lens",
            "truncate_lens_100_uniform",
            "truncate_lens_100_normal",
        ]
    ] = None  # Ablate the length shortcut.
    # single model flags
    suffix_is_prefix: bool = False  # Whether to use suffix as prefix
    batch_prefix_and_suffix: bool = False  # Whether to batch prefix and suffix together in a single call.
    flip_rope_embedding_suffix: bool = False  # Whether to flip the rope embedding for suffix
    nope_pos_embedding: bool = False  # Whether to use nope positional embedding
    add_suf_pre_tokens: bool = False  # Whether to use suffix and prefix tokens

    # retrieval finetuning args
    # The number of hard negatives to sample for each query from the dataset (uses train_group_size - 1 negatives)
    train_group_size: Optional[int] = None

    # During in-batch negatives, instead of one positive you can use k nearby positives that are k tokens apart
    k_pos_labels: Optional[int] = None
    # When using k_pos_labels, we can decay the prob mass across pos labels.
    # 1.0 will do uniform mass, 0.5 will do moderate decay, 0.1 will do steep decay
    decay_factor: float = 1.0
    # NOTE(review): presumably toggles a SigLIP-style loss; the previous comment ("accumulate the average of
    # gradients for use in alpha stage") did not match the name — confirm semantics.
    siglip_loss: bool = False
    keep_eos: bool = False  # Whether to keep the eos token in the sequence.

    # Whether to load document wise from pqds for the retrieval task. Note applies to pqds-pure only.
    doc_wise_pqdsp: bool = False
    # Whether to skip the tail of the row when loading document wise from pqds for the retrieval task.
    doc_wise_pqdsp_skip_tail: bool = False
    # The separator token tokenizer attr to use when loading document wise from pqds for the retrieval task.
    doc_wise_pqdsp_sep_tok: str = "bos_id"

    # Data Handling
    # PKDS arguments:
    shuffle_filenames: bool = True  # (PKDS only.) Shuffle filenames glob'd up for each prefix
    shuffle_blocks: bool = True  # (PKDS only.) Whether to shuffle the blocks in files.
    # Assume all datasets return tensors of exactly block_size (same size), may reduce latency. Both backends.
    all_block_size_tensors: bool = False
    # HFDS arguments:
    pad_to_block_size: bool = False  # Whether to pad to the block size (HFDS only).
    add_bos: bool = True  # Whether to add the BOS token to the input (HFDS only).
    add_eos: bool = True  # Whether to add the EOS token to the input (HFDS only).
    data_signature: dict[str, list[str] | str] = field(
        default_factory=lambda: {"keys": ["text"], "format_fn": "pass_text"}
    )  # The data signature to use for processing rows of the dataset. Can be set individually per dataset. (HFDS only).
    # For both backends:
    collate_checks_enabled: bool = True  # Enable checks for the collate function.
    use_chat_template: bool = False  # Whether to use the chat template in the collator.
    return_data_id: bool = False  # Whether to return the data_id in the dataset.
    # Either a path to a JSON file or a dict with "train_data" / "val_data" lists of DataEntry (or dicts).
    data_config: Union[str, dict[str, list[DataEntry]]] = field(
        default_factory=lambda: {
            "train_data": [DataEntry("pkds", "", 1)],
            "val_data": [DataEntry("pkds", "", 1)],
        }
    )
    # The directories containing the training/validation data (environment variables are expanded).
    train_data_dir: str = "$DATA_DIR/spj_star_combined_full_tinyllama_tokd"
    val_data_dir: str = "$DATA_DIR/spj_star_combined_full_tinyllama_tokd"
    # The path to the tokenizer to use [required to identify pad_token_id even for pkds]
    tokenizer_path: str = (
        "/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_models/external/TinyLlama-1.1B-intermediate-step-1431k-3T"
    )
    # For exact match memorization validation logic. Range dicts are expanded to inclusive lists in __post_init__.
    memorization_validation: bool = False
    prefix_lengths: Union[dict[str, int], list[int]] = field(
        default_factory=lambda: {"min": 50, "XXXX-13": 150, "step": 50}
    )
    suffix_lengths: Union[dict[str, int], list[int]] = field(
        default_factory=lambda: {"min": 25, "XXXX-13": 75, "step": 25},
    )

    model_config: Union[Config, DynamicConfig, HuggingfaceConfig] = field(init=False)
    model_overwrite: dict[str, Any] = field(default_factory=dict)  # Extra kwargs forwarded to Config.from_name.

    def __post_init__(self):
        """Validate the settings, normalize nested configs and derive run-level attributes."""
        # Validate arguments
        if self.out_dir is None:
            self.out_dir = os.getenv("OUTPUT_DIR", "NOT_FOUND")
        assert self.out_dir != "NOT_FOUND", "out_dir must be given or set via the OUTPUT_DIR environment variable."
        assert self.tokenizer_path, "Tokenizer has to be specified."

        # Data config: load from file if needed, convert dicts to DataEntry, expand env vars in paths.
        self._parse_data_config()
        self._process_data_entries()
        self._expand_paths()

        # Memorization validation: expand range dicts into explicit length lists.
        self._complete_memorization_validation()

        # Fabric config: accept a (partial) dict as well as a FabricConfig.
        self._complete_fabric_config()
        # Tensor parallelism is implemented by the AxoNN fabric only.
        tensor_parallel_sizes = (
            self.fabric.depth_tensor_parallel_size,
            self.fabric.row_tensor_parallel_size,
            self.fabric.col_tensor_parallel_size,
        )
        if any(size > 1 for size in tensor_parallel_sizes):
            assert self.fabric_strategy == "axonn_tp", "x_tensor_parallel_size > 1 implies use of axonn_tp."

        # Sets devices / num_nodes (and SLURM/rendezvous info) used by the derived values below.
        self._parse_environment_variables()

        # Derived batch-size bookkeeping.
        self.node_batch_size = self.world_batch_size // self.num_nodes
        # NOTE(review): presumably +1 so inputs and shifted targets come from one loaded block — confirm.
        self.loader_block_size = self.block_size + 1
        self.global_total_time = 0
        self.max_tokens_per_device = 0
        self.tokens_per_step = 0

        self.batch_size = self.node_batch_size // self.devices
        if self.batch_size_ramp == 0:
            self.gradient_accumulation_steps = self.batch_size // self.micro_batch_size
        else:
            self.gradient_accumulation_steps = 1
        self.replicas = self.devices * self.num_nodes

        self.warmup_iters = self.warmup_steps * self.gradient_accumulation_steps
        self.cooldown_iters = self.cooldown_steps * self.gradient_accumulation_steps

        self.log_iter_interval = self.log_step_interval * self.gradient_accumulation_steps
        self.dataset_names = [entry.prefix for entry in self.data_config["train_data"]]

        self._validate_args()

        # Finally, store model config object itself
        self._build_model_config()

        self.n_hard_negatives = self.train_group_size - 1 if self.train_group_size else 0
        self._validate_retrieval_args()

    def _build_model_config(self):
        """Create ``self.model_config`` for ``model_impl`` and copy run-level options onto it.

        Raises:
            ValueError: if ``model_impl`` is not one of the supported implementations.
        """
        if self.model_impl in ("litgpt", "retrieval"):
            self.model_config = Config.from_name(self.model_name, **self.model_overwrite)
        else:
            # model_config is declared with field(init=False) and no default; without this branch the
            # attribute would not exist and the assignments below would fail with an opaque AttributeError.
            raise ValueError(f"Unsupported model_impl {self.model_impl!r}; expected 'litgpt' or 'retrieval'.")

        # Set strategy and attention implementation
        self.model_config.strategy = self.fabric_strategy
        self.model_config.attn_impl = self.attn_impl

        # check whether we're requesting a compatible modeling config and attn_impl
        if self.model_config.surrogate_config:
            assert self.attn_impl == "sdpa", "Surrogate models only support SDPA attention."

        # Set structured_init
        self.model_config.structured_init = self.structured_init
        self.model_config.structured_init_for_wte = self.structured_init_for_wte
        self.model_config.structured_init_olmo_variant = self.structured_init_olmo_variant

    def _validate_retrieval_args(self):
        """Reject mutually incompatible retrieval-objective options (requires ``n_hard_negatives``)."""
        assert not (
            self.mask_k_ldiags and self.mask_full_ldiag
        ), "mask_k_ldiags and mask_full_ldiag cannot be set together"
        if self.k_random_pos_labels is not None:
            assert self.k_pos_labels is None, "k_random_pos_labels is being used so k_pos_labels should be None"
            assert self.mask_k_udiags is None, "k_random_pos_labels is being used so mask_k_udiags should be None"
        if self.n_hard_negatives > 0:
            assert self.k_pos_labels is None, "k_pos_labels is not supported with hard_negatives"
            assert self.k_random_pos_labels is None, "k_random_pos_labels is not supported with hard_negatives"
            assert self.mask_k_ldiags is None, "mask_k_ldiags is not supported with hard_negatives"

    def _validate_args(self):
        """Sanity-check the step budget, dataset names, batch sizing and compile settings."""
        # max-autotune compilation implies compiling. Promote it *before* the compile-dependent checks
        # below; previously this ran last, so the ddp/compile/checkpointing check was skipped in this mode.
        if self.compile_model_max_autotune_no_cudagraphs:
            self.compile_model = True

        assert (
            self.max_tokens is None
        ), "max_tokens is not supported for this branch now bc inferring max_steps is tricky"

        assert ((self.max_steps is not None) and (self.max_steps > 0)) ^ (
            ((self.max_tokens is not None) and (self.max_tokens > 0))
        ), f"only max_steps ({self.max_steps}) xor max_tokens ({self.max_tokens}) can be specified"
        assert len(set(self.dataset_names)) == len(
            self.data_config["train_data"]
        ), "please provide different names for each subset"

        assert (
            self.gradient_accumulation_steps == 1
        ), "gradient_accumulation_steps must be 1, eg. no accumulation (check world batch size)"

        if self.batch_size_ramp == 0:
            assert (
                self.world_batch_size
                == self.micro_batch_size * self.gradient_accumulation_steps * self.devices * self.num_nodes
            ), "world batch size should be: micro_batch_size * gradient_accumulation_steps * devices * num_nodes"
        else:
            assert self.world_batch_size % (self.micro_batch_size * self.devices * self.num_nodes) == 0

        assert not (
            self.memorization_validation and self.target_range_val
        ), "both memorization_validation and target_range_val cannot be set at same time"

        if self.fabric_strategy == "ddp" and self.compile_model and self.gradient_checkpointing:
            # Dynamo's DDPOptimizer cannot handle the higher-order op produced by activation checkpointing
            # (NotImplementedError), see https://github.com/pytorch/pytorch/issues/104674.
            assert (
                self.dynamo_ddp_config == "python_reducer"
            ), "dynamo_ddp_config must be python_reducer for this setup."

    def _parse_environment_variables(self):
        """Parse env variables and directly store as non-field attributes."""
        self.SLURM_JOB_ID = int(os.getenv("SLURM_JOB_ID", 0))
        self.SLURM_ARRAY_JOB_ID = int(os.getenv("SLURM_ARRAY_JOB_ID", 0))
        self.SLURM_ARRAY_TASK_ID = int(os.getenv("SLURM_ARRAY_TASK_ID", 0))
        self.SLURM_ARRAY_TASK_COUNT = int(os.getenv("SLURM_ARRAY_TASK_COUNT", 1))
        self.MASTER_ADDR = os.getenv("MASTER_ADDR", "0")
        self.MASTER_PORT = int(os.getenv("MASTER_PORT", 0))
        self.WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
        self.RANK = int(os.getenv("SLURM_PROCID", "0"))
        # Falls back to the number of visible CUDA devices when not launched through SLURM.
        self.devices = int(os.getenv("SLURM_NTASKS_PER_NODE", torch.cuda.device_count()))
        self.num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES", 1))

    def _parse_data_config(self) -> None:
        """If ``data_config`` is a string, treat it as a JSON file path and load it in place.

        Raises:
            ValueError: if the file cannot be read or is not valid JSON.
        """
        if not isinstance(self.data_config, str):
            return
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
            with open(self.data_config, mode="r", encoding="utf-8") as json_file:
                loaded = json.load(json_file)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
            raise ValueError(
                f"data_config passed was a string, but failed to load as a json object from {self.data_config}: {e}"
            ) from e
        self.data_config = loaded

    def _process_data_entries(self):
        """Convert dict entries of ``train_data`` / ``val_data`` into ``DataEntry`` objects.

        Only the ``train_data`` and ``val_data`` keys are kept; both must be present.
        """

        def to_entry(entry):
            # Entries that are already DataEntry objects (or anything non-dict) pass through unchanged.
            return DataEntry(**entry) if isinstance(entry, dict) else entry

        self.data_config = {
            split: [to_entry(entry) for entry in self.data_config[split]] for split in ("train_data", "val_data")
        }

    def _expand_paths(self):
        """Materialize fully qualified paths by expanding environment variables (None becomes "")."""
        self.train_data_dir = os.path.expandvars(self.train_data_dir) if self.train_data_dir is not None else ""
        self.val_data_dir = os.path.expandvars(self.val_data_dir) if self.val_data_dir is not None else ""
        for entry in self.data_config["train_data"] + self.data_config["val_data"]:
            if entry.data_dir is not None:
                entry.data_dir = os.path.expandvars(entry.data_dir)

    def _complete_fabric_config(self):
        """Complete fabric config with missing values if only partially specified."""
        self.fabric = FabricConfig(**self.fabric) if isinstance(self.fabric, dict) else self.fabric

    @staticmethod
    def _expand_lengths(name, spec):
        """Return an explicit list of lengths from a range dict or a list.

        A dict with keys ``"min"``, ``"XXXX-13"`` (inclusive upper bound) and optional ``"step"``
        (default 1) is expanded with ``range``; a list is returned sorted.

        Raises:
            ValueError: if ``spec`` is neither a dict nor a list.
        """
        if isinstance(spec, dict):
            step = spec.get("step", 1)
            return list(range(spec["min"], spec["XXXX-13"] + 1, step))
        if isinstance(spec, list):
            return sorted(spec)
        raise ValueError(f"{name} must be a dict or list, got {spec}")

    def _complete_memorization_validation(self):
        """Replace ``prefix_lengths`` / ``suffix_lengths`` with explicit sorted length lists."""
        prefix_lengths = self._expand_lengths("prefix_lengths", self.prefix_lengths)
        suffix_lengths = self._expand_lengths("suffix_lengths", self.suffix_lengths)
        # Assign only after both specs validated, so a bad suffix spec leaves prefix_lengths untouched.
        self.prefix_lengths = prefix_lengths
        self.suffix_lengths = suffix_lengths