import os
from dataclasses import dataclass, field
from typing import Optional, List

import torch
from transformers import TrainingArguments, HfArgumentParser, add_start_docstrings
from transformers.utils import logging

logger = logging.get_logger(__name__)


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):
    """
    Project-specific training arguments layered on top of the Hugging Face
    `TrainingArguments`.

    The class deliberately reuses the parent's name, so within this module
    `TrainingArguments` refers to this subclass once the definition has run.
    The added fields are grouped into model, dataset, optimization and LoRA
    sections below; each is exposed as a command-line flag when the class is
    handed to `HfArgumentParser`.
    """

    # === 1. Model Parameters ===

    text_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the text model checkpoint"}
    )

    audio_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the audio model checkpoint"}
    )

    audio_model_type: Optional[str] = field(
        default=None,
        metadata={"help": "Type of audio model, which can be found in `configs/model_configs`"}
    )

    vision_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the vision model checkpoint"}
    )

    # The help text previously duplicated the one for `vision_model_path`.
    vision_projector_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the vision projector checkpoint"}
    )

    vision_projector_type: Optional[str] = field(
        default="mlp2x_gelu",
        metadata={"help": "Type of vision projector"}
    )

    vision_projector_learning_rate: Optional[float] = field(
        default=None, metadata={"help": "Learning rate for vision projector"}
    )

    snac_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the snac model checkpoint"}
    )

    # === 2. Dataset Parameters ===

    dataset_paths: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Paths to dataset"}
    )

    dataset_ratios: Optional[List[float]] = field(
        default=None,
        metadata={"help": "Ratios used for each dataset"}
    )

    dataset_split: Optional[str] = field(
        default='train', metadata={"help": "Tag of dataset"}
    )

    prompt_type: Optional[str] = field(
        default="chat", metadata={"help": "Type of prompt, should be either `plain` or `chat`."}
    )

    group_by_modality_length: Optional[bool] = field(
        default=False, metadata={"help": "Group training samples by modality length, which can accelerate training."}
    )

    image_folder: Optional[str] = field(
        default=None,
        metadata={"help": "Path to image folder"}
    )

    image_aspect_ratio: Optional[str] = field(
        default="square", metadata={"help": "Aspect ratio of images, which can be either `square` or `pad`."}
    )

    use_audio_representation: Optional[bool] = field(
        default=True, metadata={"help": "Whether to use audio representation."}
    )

    # === 3. Optimization Parameters ===

    # Read by `parse_args` in this module to fill `lr_scheduler_kwargs` for
    # the `polynomial` and `cosine_with_min_lr` schedulers.
    min_learning_rate: float = field(
        default=1e-6, metadata={"help": "Min learning rate"}
    )

    loss_ratios: Optional[List[float]] = field(
        default=None, metadata={"help": "Ratios of different losses"}
    )

    loss_types: Optional[List[str]] = field(
        default=None, metadata={"help": "Types of different losses"}
    )

    freeze_encoder: Optional[bool] = field(
        default=True, metadata={"help": "Freeze encoder"}
    )

    freeze_decoder: Optional[bool] = field(
        default=False, metadata={"help": "Freeze decoder layers"}
    )

    train_projector_only: Optional[bool] = field(
        default=True, metadata={"help": "Only train vision projector between CLIP Encoder and LLM."}
    )

    # === 4. LoRA Parameters ===

    lora_rank: Optional[int] = field(
        default=None, metadata={"help": "Rank for LoRA"}
    )

    lora_alpha: Optional[int] = field(
        default=None, metadata={"help": "Alpha for LoRA"}
    )

    lora_dropout: Optional[float] = field(
        default=0.05, metadata={"help": "Dropout for LoRA"}
    )


def parse_args() -> TrainingArguments:
    """Parse command-line arguments, log them, and save them to disk.

    The arguments are parsed into a single `TrainingArguments` instance via
    `HfArgumentParser`. Depending on `lr_scheduler_type`, the scheduler
    kwargs are overwritten so that `min_learning_rate` is honoured. All
    arguments are then logged as an aligned table and serialized with
    `torch.save` into `output_dir`.

    Returns:
        The parsed (and possibly adjusted) `TrainingArguments` instance.
    """
    parser = HfArgumentParser((TrainingArguments, ))
    (args, ) = parser.parse_args_into_dataclasses()

    if args.lr_scheduler_type == "polynomial":
        # A polynomial schedule with power 1.0 is a linear decay that ends
        # at `min_learning_rate`.
        args.lr_scheduler_kwargs = {
            "power": 1.0,
            "lr_end": args.min_learning_rate
        }
    elif args.lr_scheduler_type == "cosine_with_min_lr":
        args.lr_scheduler_kwargs = {"min_lr": args.min_learning_rate}

    # Build the log table as a list of pieces and join once at the end.
    rule = "#" * 40
    pieces = ["\n" + rule + "Training Arguments" + rule + "\n"]
    for i, arg in enumerate(vars(args)):
        # Newlines are replaced so that each argument stays on one row.
        name = str(arg).replace("\n", "↵")
        value = str(getattr(args, arg)).replace("\n", "↵")
        # Pad name + value to 92 characters, with at least 5 dots between.
        dots = '.' * max(92 - len(name) - len(value), 5)
        pieces.append(f"[{i:03d}] {name}{dots}{value}\n")
    pieces.append("#" * 98 + "\n")
    logger.info("".join(pieces))

    # Make sure the target directory exists; `torch.save` does not create
    # missing parent directories and would fail on a fresh output path.
    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): every process that calls this function writes the same
    # file — confirm that is acceptable under multi-process launches.
    filename = "training_args.bin" if args.do_train else "evaluation_args.bin"
    torch.save(args, os.path.join(args.output_dir, filename))

    return args
