"""
Box Embeddings Sentence Transformer Training Script

This script trains sentence transformers with box embeddings for improved
semantic similarity and entailment tasks.
"""

import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from datasets import Dataset, load_dataset, load_from_disk
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    models,
)
from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

import wandb
from box_similarity import similarity_function, similarity_function_entailment
from EmbeddingSimilarityEvaluatorDiff import (
    EmbeddingSimilarityEvaluatorDiff,
    TripletEvaluatorDiff,
)

# Make modules under the relative "evaluation" directory importable.
# NOTE(review): this runs after the project-local imports above have already
# executed, so it can only affect imports performed later (e.g. lazily or by
# other modules). The path is relative to the process working directory —
# confirm this ordering is intended.
sys.path.append("evaluation")


@dataclass
class TrainingConfig:
    """Bundle of settings for one training run.

    The first five fields carry no default and must always be supplied;
    everything else is optional.
    """

    # --- Required run settings ---
    model_type: str
    dataset_size: int
    run_name: str
    batch_size: int
    mini_batch_size: int

    # --- Optional data / run switches ---
    use_augmentation: float = 0.0
    use_simcse: bool = False
    resume: bool = False
    run_id: Optional[str] = None
    use_synthesized: bool = False
    use_synthesized_negative: bool = False
    use_links: bool = False
    entailment_type: str = ""

    # --- Training mode ---
    # True -> train standard vector embeddings; False -> train box embeddings.
    use_vector_mode: bool = False

    # --- Model hyperparameters ---
    volume_temp: float = 1.0
    # Original author's note: consider changing this value to 0.1.
    intersection_temp: float = 0.001
    learning_rate: float = 2e-5
    fp16: bool = False
    grad_norm: float = 1.0

    # --- Paths ---
    # Placeholder root directory; "." was the previously commented-out alternative.
    base_directory: str = "path"


class VectorEntailmentHead(nn.Module):
    """
    Linear head that adds a scalar entailment projection to sentence features.

    This module:
    - Reads the pooled ``sentence_embedding`` from the features dict
    - Projects it to a single value with one linear layer
    - Stores the result under the ``entailment_embedding`` key

    ``get_scores`` is an independent helper that returns the element-wise
    difference of two tensors; it does not use the linear layer.
    """

    def __init__(self, input_dim: int):
        """
        Initialize the head.

        Args:
            input_dim: Dimension of the incoming sentence embedding
        """
        super().__init__()
        self.input_dim = input_dim

        # Single linear projection: input_dim -> 1
        self.classifier = nn.Sequential(nn.Linear(input_dim, 1))

    def get_scores(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> torch.Tensor:
        """
        Return ``embedding1 - embedding2`` with a trailing size-1 axis removed.

        Args:
            embedding1: First tensor
            embedding2: Second tensor (must broadcast against ``embedding1``)

        Returns:
            The difference; ``squeeze(-1)`` drops the last axis only when it
            has size 1, otherwise the shape is left unchanged.
        """
        score = embedding1 - embedding2

        return score.squeeze(-1)

    def forward(self, embedding1: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass: project the sentence embedding to a scalar.

        Args:
            embedding1: Features dict containing ``sentence_embedding``

        Returns:
            The same dict (mutated in place) with ``entailment_embedding``
            added; ``sentence_embedding`` is left untouched.
        """
        embedding1["entailment_embedding"] = self.classifier(
            embedding1["sentence_embedding"]
        )
        return embedding1

    def save(self, save_dir: str, **kwargs) -> None:
        """Save the classifier state and configuration into ``save_dir``."""
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_path / "vector_entailment_classifier.pth")

        config = {
            "input_dim": self.input_dim,
        }
        with open(save_path / "classifier_config.json", "w") as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def load(load_dir: str, **kwargs) -> "VectorEntailmentHead":
        """
        Load a head previously written by ``save``.

        Args:
            load_dir: Directory that contains the ``2_VectorEntailmentHead``
                sub-directory holding the saved files.

        Returns:
            A new ``VectorEntailmentHead`` with the stored weights.
        """
        load_path = Path(load_dir) / "2_VectorEntailmentHead"

        with open(load_path / "classifier_config.json") as f:
            config = json.load(f)

        print(config)
        head = VectorEntailmentHead(config["input_dim"])
        # map_location="cpu" lets a checkpoint saved from a GPU be restored on a
        # CPU-only host; load_state_dict copies values into the existing
        # parameters, so the module's device is unaffected.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(
            load_path / "vector_entailment_classifier.pth", map_location="cpu"
        )
        head.load_state_dict(state_dict)
        return head


class MLPHead(nn.Module):
    """
    MLP head that produces box embeddings (center and delta) from sentence embeddings.

    This module takes pooled sentence embeddings and produces two separate embeddings:
    - center: output of the ``mlp_center`` branch
    - delta: output of the ``mlp_delta`` branch

    The two are concatenated along the last axis, so the resulting
    ``sentence_embedding`` has ``2 * output_dim`` features.
    """

    def __init__(self, input_dim: int, output_dim: int):
        """
        Build the two MLP branches.

        Args:
            input_dim: Dimension of the incoming sentence embedding
            output_dim: Dimension of each of the center and delta outputs
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.mlp_center = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, output_dim),
        )

        self.mlp_delta = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass: convert sentence embedding to box embedding.

        Args:
            x: Features dict containing ``sentence_embedding``

        Returns:
            The same dict (mutated in place) whose ``sentence_embedding`` is
            replaced by ``[center, delta]`` concatenated on the last axis.
        """
        sentence_embedding = x["sentence_embedding"]
        center = self.mlp_center(sentence_embedding)
        delta = self.mlp_delta(sentence_embedding)

        # Concatenate center and delta to form box embedding
        x["sentence_embedding"] = torch.cat([center, delta], dim=-1)
        return x

    def save(self, save_dir: str, **kwargs) -> None:
        """Save the MLP head state and configuration into ``save_dir``."""
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_path / "mlp_head.pth")

        config = {"input_dim": self.input_dim, "output_dim": self.output_dim}

        with open(save_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def load(load_dir: str, **kwargs) -> "MLPHead":
        """
        Load an MLP head previously written by ``save``.

        Args:
            load_dir: Directory that contains the ``2_MLPHead`` sub-directory
                holding the saved files.

        Returns:
            A new ``MLPHead`` with the stored weights.
        """
        load_path = Path(load_dir) / "2_MLPHead"

        with open(load_path / "config.json") as f:
            config = json.load(f)

        print(config)
        mlp_head = MLPHead(config["input_dim"], config["output_dim"])
        # map_location="cpu" lets a checkpoint saved from a GPU be restored on a
        # CPU-only host; load_state_dict copies values into the existing
        # parameters, so the module's device is unaffected.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(load_path / "mlp_head.pth", map_location="cpu")
        mlp_head.load_state_dict(state_dict)
        return mlp_head


class DatasetManager:
    """Loads the corpora used for training/validation and applies model-specific formatting."""

    def __init__(self, config: TrainingConfig):
        self.config = config

    def _format_for_e5(self, sample: Dict) -> Dict:
        """Return a copy of the text fields with E5 ``query:`` / ``passage:`` prefixes.

        ``anchor`` and ``negative`` get the ``query: `` prefix, ``positive`` gets
        ``passage: ``; ``sources`` is passed through unchanged when present.
        """
        formatted = {
            "anchor": f"query: {sample['anchor']}",
            "positive": f"passage: {sample['positive']}",
        }
        if "negative" in sample:
            formatted["negative"] = f"query: {sample['negative']}"
        if "sources" in sample:
            formatted["sources"] = sample["sources"]
        return formatted

    def _process_infinity_instruct(self, sample: Dict) -> Dict:
        """Turn one Infinity-Instruct row into an ``anchor`` / ``positive`` pair.

        Turns whose ``from`` is ``"human"`` set the anchor; every other turn
        sets the positive. Later turns overwrite earlier ones, and a side that
        never appears stays ``None``.
        """
        anchor = None
        positive = None

        for turn in sample["conversations"]:
            speaker = turn["from"]
            text = turn["value"]
            # The E5 prefixes are only added for the "e5-base" model type.
            add_prefix = self.config.model_type == "e5-base"
            if speaker == "human":
                anchor = f"query: {text}" if add_prefix else text
            else:  # any non-human turn is treated as the assistant response
                positive = f"passage: {text}" if add_prefix else text

        return {"anchor": anchor, "positive": positive}

    def load_infinity_instruct_dataset(self) -> Dataset:
        """Load the first ``dataset_size`` Infinity-Instruct rows, split and clean them.

        Returns the result of ``train_test_split`` (70/30, unshuffled) after
        mapping rows to anchor/positive pairs, keeping rows whose
        ``langdetect`` is ``"en"`` and dropping the raw columns.
        """
        raw = load_dataset(
            "BAAI/Infinity-Instruct",
            "0625",
            split=f"train[:{self.config.dataset_size}]",
        )

        splits = raw.train_test_split(test_size=0.3, shuffle=False)
        splits = splits.map(self._process_infinity_instruct)

        # Keep English rows only, then drop the columns that are no longer needed.
        english_only = splits.filter(lambda row: row["langdetect"] == "en")
        return english_only.remove_columns(
            ["id", "conversations", "label", "langdetect", "source"]
        )

    def load_entailment_dataset(self) -> Tuple[Dataset, Dataset]:
        """Load the entailment dataset at ``config.entailment_type`` and split it.

        Returns:
            ``(train, validation)`` where ``train`` is the 70% split and
            ``validation`` is the 30% split sliced with ``[:5000]``.
            NOTE(review): slicing a ``datasets.Dataset`` presumably yields a
            column-name -> list mapping rather than a ``Dataset`` — confirm
            against the installed ``datasets`` version.
        """
        on_disk = load_from_disk(self.config.entailment_type)
        head_rows = on_disk.remove_columns(["sources"])[: self.config.dataset_size]
        entailment = Dataset.from_dict(head_rows)

        if self.config.model_type == "e5-base":
            entailment = entailment.map(self._format_for_e5)

        parts = entailment.train_test_split(test_size=0.3, shuffle=False)
        return parts["train"], parts["test"][:5000]

    def load_auxiliary_datasets(self) -> Dict[str, Dataset]:
        """Load the additional datasets (MNLI, FollowBench, links, synthesized).

        ``mnli``, ``followbench`` and ``links`` are always loaded;
        ``synthesized`` is added only when one of the synthesized flags is set
        (``use_synthesized`` takes precedence).
        """
        limit = self.config.dataset_size

        def first_rows(path: str, count: int) -> Dataset:
            # Slice the on-disk dataset and rebuild a Dataset from the slice.
            return Dataset.from_dict(load_from_disk(path)[:count])

        loaded = {
            # MNLI triplets
            "mnli": first_rows("./mnli_triplets", limit),
            # FollowBench (fixed cap of 5000 rows)
            "followbench": first_rows("./followbench_dataset", 5000),
            "links": first_rows("./links_dataset_with_negative_fixed/", limit),
        }

        # Synthesized datasets
        synthesized_path = None
        if self.config.use_synthesized:
            synthesized_path = "./synthesized_dataset"
        elif self.config.use_synthesized_negative:
            synthesized_path = "./synthesized_dataset_negative_new"
        if synthesized_path is not None:
            loaded["synthesized"] = first_rows(synthesized_path, limit).shuffle(seed=42)

        # Apply E5 formatting if needed
        if self.config.model_type == "e5-base":
            loaded = {
                name: split.map(self._format_for_e5) for name, split in loaded.items()
            }

        return loaded

    def apply_augmentation(self, dataset: Dataset) -> Dataset:
        """Apply SimCSE-style augmentation (positive := anchor) when enabled.

        When ``config.use_simcse`` is false the dataset is returned untouched.
        """
        if not self.config.use_simcse:
            return dataset
        return dataset.map(lambda row: {**row, "positive": row["anchor"]})


class ModelFactory:
    """Builds the transformer backbone and the complete SentenceTransformer pipeline."""

    @staticmethod
    def create_transformer(model_type: str) -> models.Transformer:
        """Instantiate the backbone for ``model_type``.

        Raises:
            ValueError: If ``model_type`` is not one of the known keys.
        """
        checkpoints = {
            "pretrained": "microsoft/mpnet-base",
            "finetuned": "sentence-transformers/all-mpnet-base-v2",
            "e5-base": "intfloat/e5-base-v2",
        }

        checkpoint = checkpoints.get(model_type)
        if checkpoint is None:
            raise ValueError(f"Unknown model type: {model_type}")

        return models.Transformer(checkpoint)

    @staticmethod
    def create_sentence_transformer(config: TrainingConfig) -> SentenceTransformer:
        """Assemble transformer + mean pooling, plus an ``MLPHead`` in box mode.

        The model is moved to CUDA when available, otherwise it stays on CPU.
        """
        transformer = ModelFactory.create_transformer(config.model_type)
        embedding_dim = transformer.get_word_embedding_dimension()
        pooling = models.Pooling(embedding_dim, pooling_mode="mean")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        pipeline = [transformer, pooling]
        if not config.use_vector_mode:
            # Box mode: append the head that emits center and delta, each of
            # size embedding_dim.
            pipeline.append(MLPHead(embedding_dim, embedding_dim))
        # TODO: vector mode currently uses the plain transformer + pooling
        # stack; wiring in VectorEntailmentHead still needs the proper options.

        return SentenceTransformer(modules=pipeline).to(device)


class LossManager:
    """Builds the per-dataset loss functions for box or vector training."""

    def __init__(self, model: SentenceTransformer, config: TrainingConfig):
        self.model = model
        self.config = config

    def create_losses(self) -> Dict[str, CachedMultipleNegativesRankingLoss]:
        """Return the loss mapping that matches ``config.use_vector_mode``."""
        builder = (
            self.create_losses_vector
            if self.config.use_vector_mode
            else self.create_losses_box
        )
        return builder()

    def vector_entailment_similarity_diff(
        self,
        embedding1: torch.Tensor,
        embedding2: torch.Tensor,
    ) -> torch.Tensor:
        """Compute all pairwise differences ``embedding1[i] - embedding2[j]``.

        For 2-D inputs of shape ``[n, d]`` and ``[m, d]`` the result has shape
        ``[n, m, d]``. 3-D inputs are first flattened to ``[total, d]``.

        NOTE(review): for 3-D inputs the pairwise result is reshaped back to
        the original input shape, which only succeeds when the flattened input
        has a single row — confirm the intended output shape before relying on
        this path.
        """
        input_shape = embedding1.shape
        was_batched = len(input_shape) == 3
        if was_batched:
            # [num_batches, batch_size, dim] -> [total_samples, dim]
            feature_dim = input_shape[-1]
            embedding1 = embedding1.reshape(-1, feature_dim)
            embedding2 = embedding2.reshape(-1, feature_dim)

        # Broadcast rows of embedding1 against rows of embedding2.
        left = embedding1[:, None, :]
        right = embedding2[None, :, :]
        pairwise_diff = left - right

        if was_batched:
            pairwise_diff = pairwise_diff.reshape(
                input_shape[0], input_shape[1], input_shape[2]
            )

        return pairwise_diff

    def create_losses_box(self) -> Dict[str, CachedMultipleNegativesRankingLoss]:
        """Create the box-embedding losses.

        Two loss objects are built (intersection similarity and entailment
        similarity) and shared across the dataset keys.
        """
        # Standard similarity loss
        similarity_loss = CachedMultipleNegativesRankingLoss(
            self.model,
            similarity_fct=lambda x, y: similarity_function(
                x, y, self.config.volume_temp, self.config.intersection_temp
            ),
            mini_batch_size=self.config.mini_batch_size,
            show_progress_bar=False,
        )

        # Entailment loss
        entailment_loss = CachedMultipleNegativesRankingLoss(
            self.model,
            similarity_fct=lambda x, y: similarity_function_entailment(
                x, y, self.config.volume_temp, self.config.intersection_temp
            ),
            mini_batch_size=self.config.mini_batch_size,
            show_progress_bar=False,
        )

        return {
            "response_instruction": similarity_loss,
            "entailment": entailment_loss,
            # The remaining dataset keys reuse one of the two objects above.
            "nli": entailment_loss,
            "train_eval": similarity_loss,
            "train_synth": entailment_loss,
            "links": entailment_loss,
        }

    def create_losses_vector(self) -> Dict[str, CachedMultipleNegativesRankingLoss]:
        """Create the vector-mode losses using the loss class's default similarity.

        WARNING (from the original author): known issues here remain to be
        fixed; the entailment loss is currently configured identically to the
        similarity loss.
        """
        similarity_loss = CachedMultipleNegativesRankingLoss(
            self.model,
            mini_batch_size=self.config.mini_batch_size,
            show_progress_bar=False,
        )

        entailment_loss = CachedMultipleNegativesRankingLoss(
            self.model,
            mini_batch_size=self.config.mini_batch_size,
            show_progress_bar=False,
        )

        return {
            "response_instruction": similarity_loss,
            "entailment": entailment_loss,
            "nli": entailment_loss,
            "train_eval": similarity_loss,
            "train_synth": entailment_loss,
            "links": entailment_loss,
        }


class EvaluatorManager:
    """Builds the evaluators that run during training."""

    @staticmethod
    def create_evaluators(
        validation_datasets: Dict[str, Dataset], config: TrainingConfig
    ) -> List:
        """Create the STS-B evaluator plus one triplet evaluator per available split.

        Args:
            validation_datasets: Mapping that may contain ``entailment_val``
                and/or ``followbench``; each must expose ``anchor``,
                ``positive`` and ``negative`` columns.
            config: Run configuration; only ``use_vector_mode`` is read.

        Returns:
            Evaluators in the order: STS-B, entailment_val, followbench.
        """
        # Vector mode scores with cosine; box mode uses the box-specific names.
        if config.use_vector_mode:
            entailment_fn, similarity_fn = "cosine", "cosine"
        else:
            entailment_fn, similarity_fn = "box_entailment", "box_intersection"

        # STS-B evaluator
        stsb_eval = load_dataset("sentence-transformers/stsb", split="validation")
        evaluators = [
            EmbeddingSimilarityEvaluatorDiff(
                sentences1=stsb_eval["sentence1"],
                sentences2=stsb_eval["sentence2"],
                scores=stsb_eval["score"],
                name="sts-dev",
                similarity_fn_names=[similarity_fn],
            )
        ]

        # Entailment evaluators: (key in validation_datasets, evaluator name)
        triplet_specs = (
            ("entailment_val", "entailment_val"),
            ("followbench", "entailment_followbench"),
        )
        for dataset_key, evaluator_name in triplet_specs:
            if dataset_key not in validation_datasets:
                continue
            triplets = validation_datasets[dataset_key]
            evaluators.append(
                TripletEvaluatorDiff(
                    anchors=triplets["anchor"],
                    positives=triplets["positive"],
                    negatives=triplets["negative"],
                    name=evaluator_name,
                    similarity_fn_names=[entailment_fn],
                )
            )

        return evaluators


class BoxEmbeddingTrainer:
    """Main trainer class that orchestrates the training process."""

    # Single W&B project name used for the environment variable and for both
    # fresh and resumed runs, so a run created by this script can be resumed.
    WANDB_PROJECT = "clean_box_embeddings"

    # Entailment dataset loaded for validation only when the configured
    # entailment_type is "none".
    DEFAULT_VALIDATION_ENTAILMENT = "new_entailment_dataset_with_sister_with_negative"

    def __init__(self, config: TrainingConfig):
        """Store the config, create the dataset manager and start W&B logging."""
        self.config = config
        self.dataset_manager = DatasetManager(config)
        self.setup_wandb()

    def setup_wandb(self):
        """Initialize Weights & Biases logging.

        Raises:
            ValueError: If ``config.resume`` is set without ``config.run_id``.
        """
        os.environ["WANDB_PROJECT"] = self.WANDB_PROJECT
        wandb.login()

        wandb_config = {
            "volume_temp": self.config.volume_temp,
            "intersection_temp": self.config.intersection_temp,
            "training_size": self.config.dataset_size,
            "learning_rate": self.config.learning_rate,
            "fp16": self.config.fp16,
            "run_name": self.config.run_name,
            "batch_size": self.config.batch_size,
            "mini_batch_size": self.config.mini_batch_size,
            "grad_norm": self.config.grad_norm,
        }

        init_kwargs = {
            "project": self.WANDB_PROJECT,
            "config": wandb_config,
            "name": self.config.run_name,
        }

        if self.config.resume:
            if not self.config.run_id:
                raise ValueError("Resume requires run_id for wandb")

            # resume="must" fails unless the run id already exists in the
            # given project, so the project must match the one used when the
            # run was first created.
            init_kwargs["resume"] = "must"
            init_kwargs["id"] = self.config.run_id

        wandb.init(**init_kwargs)

    def prepare_datasets(self) -> Tuple[Dict[str, Dataset], Dict[str, Dataset]]:
        """Prepare training and validation datasets.

        Returns:
            ``(train_datasets, validation_datasets)`` keyed by dataset name.
            ``entailment_val`` is always present in the validation mapping;
            the ``entailment`` training set is only added when
            ``config.entailment_type`` is not ``"none"``.
        """
        # Load main dataset
        infinity_dataset = self.dataset_manager.load_infinity_instruct_dataset()

        # Load entailment data
        train_datasets = {"response_instruction": infinity_dataset["train"]}
        print(self.config.entailment_type)
        if self.config.entailment_type != "none":
            entailment_train, entailment_val = (
                self.dataset_manager.load_entailment_dataset()
            )
            train_datasets["entailment"] = entailment_train
        else:
            # No entailment training data requested, but validation still
            # needs an entailment split. The dataset manager reads the path
            # from the shared config, so swap it temporarily and always
            # restore it, even if loading fails.
            requested_type = self.config.entailment_type
            self.config.entailment_type = self.DEFAULT_VALIDATION_ENTAILMENT
            try:
                _, entailment_val = self.dataset_manager.load_entailment_dataset()
            finally:
                self.config.entailment_type = requested_type
        validation_datasets = {"entailment_val": entailment_val}

        # Load auxiliary datasets
        aux_datasets = self.dataset_manager.load_auxiliary_datasets()

        if "mnli" in aux_datasets:
            train_datasets["nli"] = aux_datasets["mnli"]

        if self.config.use_synthesized or self.config.use_synthesized_negative:
            train_datasets["train_synth"] = aux_datasets["synthesized"]

        # Apply augmentation if needed
        if self.config.use_simcse:
            train_datasets["response_instruction"] = (
                self.dataset_manager.apply_augmentation(
                    train_datasets["response_instruction"]
                )
            )

        if self.config.use_links:
            train_datasets["links"] = aux_datasets["links"]

        # Prepare validation datasets
        if "followbench" in aux_datasets:
            validation_datasets["followbench"] = aux_datasets["followbench"]

        return train_datasets, validation_datasets

    def create_output_directory(self) -> str:
        """Return the checkpoint directory path for this run.

        The path is ``<base_directory>/models/<descriptive name>``; this
        method only builds the string and does not create anything on disk.
        """
        synthesized_suffix = ""
        if self.config.use_synthesized:
            synthesized_suffix = "_with_synthesized"
        elif self.config.use_synthesized_negative:
            synthesized_suffix = "_with_synthesized_negative"

        model_mode = "vector" if self.config.use_vector_mode else "box"
        local_path = (
            f"{self.config.model_type}_{self.config.dataset_size}_"
            f"use_simcse_{self.config.use_simcse}_new_entailment_corrected_"
            f"sister_with_mnli_with_negative{synthesized_suffix}_small_dim_{model_mode}"
        )

        models_path = Path("models") / local_path
        return str(Path(self.config.base_directory) / models_path)

    def train(self):
        """Execute the complete training pipeline and save the final model."""
        print(f"Starting training with config: {self.config}")

        # Prepare datasets
        train_datasets, validation_datasets = self.prepare_datasets()

        # Create model
        model = ModelFactory.create_sentence_transformer(self.config)

        # Create losses
        loss_manager = LossManager(model, self.config)
        losses = loss_manager.create_losses()

        # Create evaluators
        evaluators = EvaluatorManager.create_evaluators(
            validation_datasets, self.config
        )

        # Setup training arguments
        output_dir = self.create_output_directory()

        args = SentenceTransformerTrainingArguments(
            output_dir=output_dir,
            report_to="wandb",
            run_name=self.config.run_name,
            num_train_epochs=1,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            learning_rate=self.config.learning_rate,
            warmup_ratio=0.1,
            fp16=self.config.fp16,
            bf16=False,
            max_grad_norm=self.config.grad_norm,
            batch_sampler=BatchSamplers.NO_DUPLICATES,
            eval_strategy="steps",
            eval_steps=10,
            save_strategy="steps",
            save_steps=10,
            save_total_limit=10,
            logging_steps=1,
            remove_unused_columns=True,
            do_eval=True,
        )

        # Create trainer
        trainer = SentenceTransformerTrainer(
            model=model,
            args=args,
            train_dataset=train_datasets,
            loss=losses,
            evaluator=SequentialEvaluator(evaluators),
        )

        # Train; the resume flag is forwarded as resume_from_checkpoint.
        trainer.train(resume_from_checkpoint=self.config.resume)

        # Save final model under a name that encodes the main settings.
        synthesized_suffix = ""
        if self.config.use_synthesized:
            synthesized_suffix = "_synth"
        elif self.config.use_synthesized_negative:
            synthesized_suffix = "_synth_neg"

        mode = "vector" if self.config.use_vector_mode else "box"
        entailment_name = (
            Path(self.config.entailment_type).name
            if self.config.entailment_type
            else "no_entailment"
        )

        final_output_path = Path(self.config.base_directory) / (
            f"{self.config.model_type}_"
            f"ds{self.config.dataset_size}_"
            f"{mode}_"
            f"bs{self.config.batch_size}_"
            f"mbs{self.config.mini_batch_size}_"
            f"lr{self.config.learning_rate}_"
            f"vt{self.config.volume_temp}_"
            f"it{self.config.intersection_temp}_"
            f"links{self.config.use_links}_"
            f"{entailment_name}"
            f"{synthesized_suffix}"
            f"_grad_norm_{self.config.grad_norm}_i_am_built_different"
        )
        model.save_pretrained(str(final_output_path))

        print(f"Training completed. Model saved to: {final_output_path}")


def parse_arguments() -> TrainingConfig:
    """Read the command line and translate it into a ``TrainingConfig``.

    ``--model`` maps to ``model_type``; every other option maps to the config
    field with the same name.
    """
    cli = argparse.ArgumentParser(
        description="Train box embedding or vector sentence transformers"
    )

    # Required run settings
    cli.add_argument(
        "-m",
        "--model",
        required=True,
        choices=["pretrained", "finetuned", "e5-base"],
        help="Model type to use",
    )
    cli.add_argument(
        "-d",
        "--dataset_size",
        type=int,
        required=True,
        help="Size of training dataset",
    )
    cli.add_argument(
        "-r",
        "--run_name",
        required=True,
        help="Name for this training run",
    )
    cli.add_argument(
        "-b",
        "--batch_size",
        type=int,
        required=True,
        help="Training batch size",
    )
    cli.add_argument(
        "-mb",
        "--mini_batch_size",
        type=int,
        required=True,
        help="Mini batch size for loss computation",
    )

    # Optional switches
    cli.add_argument(
        "--use_augmentation",
        type=float,
        default=0.0,
        help="Data augmentation value",
    )
    cli.add_argument(
        "--use_simcse",
        action="store_true",
        help="Use SimCSE-style augmentation",
    )
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Resume training from checkpoint",
    )
    cli.add_argument("--run_id", help="W&B run ID for resuming")
    cli.add_argument(
        "--use_synthesized",
        action="store_true",
        help="Use synthesized dataset",
    )
    cli.add_argument(
        "--use_synthesized_negative",
        action="store_true",
        help="Use synthesized dataset with negatives",
    )
    cli.add_argument("--use_links", action="store_true", help="Use links dataset")
    cli.add_argument(
        "--entailment_type",
        required=True,
        help="Entailment dataset path",
    )
    cli.add_argument("--grad_norm", default=1.0, help="Grad Norm", type=float)
    cli.add_argument(
        "--use_vector_mode",
        action="store_true",
        help="Train basic vector models instead of box embeddings",
    )

    parsed = cli.parse_args()

    return TrainingConfig(
        # Required
        model_type=parsed.model,
        dataset_size=parsed.dataset_size,
        run_name=parsed.run_name,
        batch_size=parsed.batch_size,
        mini_batch_size=parsed.mini_batch_size,
        # Optional
        use_augmentation=parsed.use_augmentation,
        use_simcse=parsed.use_simcse,
        resume=parsed.resume,
        run_id=parsed.run_id,
        use_synthesized=parsed.use_synthesized,
        use_synthesized_negative=parsed.use_synthesized_negative,
        use_links=parsed.use_links,
        entailment_type=parsed.entailment_type,
        grad_norm=parsed.grad_norm,
        use_vector_mode=parsed.use_vector_mode,
    )


def main():
    """Parse the CLI options, build the trainer and run training."""
    BoxEmbeddingTrainer(parse_arguments()).train()


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
