# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import math
from dataclasses import dataclass
from typing import Optional, Union
import warnings


@dataclass
class TrainArgs:
    """Training-related arguments"""

    save_interval: Optional[int] = None
    """Number of optimizer steps between saving checkpoints"""
    save_per_tokens: Optional[int] = None
    """Number of tokens between saving checkpoints"""
    log_interval: int = 1
    """Number of iterations between logging calls"""
    global_batch_size: int = 64
    """Number of samples between optimizer steps across data-parallel ranks"""
    micro_batch_size: int = 4
    """Number of samples per data-parallel rank"""
    lr_warmup_steps: Optional[int] = None
    """Number of iterations with learning rate warmup active"""
    lr_warmup_fraction: Optional[float] = None
    """The fraction of an epoch to use for learning rate warmup (fraction of the training data)"""
    lr_warmup_ratio: Optional[float] = None
    """The ratio of iters to use for learning rate warmup (ratio of the training iterations)"""
    epochs: Optional[int] = None
    """Number of epochs to train on"""
    # TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?
    max_tokens: Optional[int] = None
    """Total number of tokens to train on"""
    max_steps: Optional[int] = None
    """Limits the number of optimizer steps to run"""
    max_seq_length: Optional[int] = None
    """Limits the length of samples"""
    tie_embeddings: Optional[bool] = None
    """Whether to tie the embedding weights with the language modeling head weights"""

    # Optimization args
    max_norm: Optional[float] = None
    min_lr: float = 6e-5

    def __post_init__(self) -> None:
        if self.save_per_tokens is not None:
            raise NotImplementedError("`--train.save_per_tokens` is not yet implemented.")
        
        if self.save_interval and self.save_per_tokens:
            raise ValueError(
                "Can't provide both `--train.save_interval` and `--train.save_per_tokens`. Choose one."
            )
        if self.save_interval is None and self.save_per_tokens is None:
            raise ValueError(
                "Must provide either `--train.save_interval` or `--train.save_per_tokens`."
            )
        
        if self.lr_warmup_steps is None and self.lr_warmup_fraction is None and self.lr_warmup_ratio is None:
            raise ValueError(
                "Must provide either `--train.lr_warmup_steps`, `--train.lr_warmup_fraction` or `--train.lr_warmup_ratio`."
            )
        if self.lr_warmup_fraction and self.lr_warmup_steps:
            raise ValueError(
                "Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. Choose one."
            )
        if self.lr_warmup_fraction and self.lr_warmup_ratio:
            raise ValueError(
                "Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_ratio`. Choose one."
            )
        if self.lr_warmup_steps and self.lr_warmup_ratio:
            raise ValueError(
                "Can't provide both `--train.lr_warmup_steps` and `--train.lr_warmup_ratio`. Choose one."
            )
        if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1):
            raise ValueError("`--train.lr_warmup_fraction` must be between 0 and 1.")
        if self.lr_warmup_ratio and not (0 <= self.lr_warmup_ratio <= 1):
            raise ValueError("`--train.lr_warmup_ratio` must be between 0 and 1.")

        if self.lr_warmup_steps and self.max_steps and (self.lr_warmup_steps >= self.max_steps):
            warnings.warn(
                "`--train.lr_warmup_steps` should be less than `--train.max_steps`."
                f" Got {self.lr_warmup_steps} lr_warmup_steps and {self.max_steps} max_steps.", UserWarning)

    def gradient_accumulation_iters(self, devices: int) -> int:
        """Number of iterations between gradient synchronizations"""
        gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size
        assert gradient_accumulation_iters > 0
        return gradient_accumulation_iters

    def batch_size(self, devices: int) -> int:
        """Number of samples between optimizer steps per data-parallel rank"""
        batch_size = self.global_batch_size // devices
        assert batch_size > 0
        return batch_size

    def warmup_iters(self, devices: int, max_iters: int, train_dataloader) -> int:
        """Number of iterations to warm up the learning rate."""
        if self.lr_warmup_ratio:
            return math.ceil(self.lr_warmup_ratio * max_iters)
        if self.lr_warmup_fraction:
            return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader)))
        if self.lr_warmup_steps:
            return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices))
        return 0


@dataclass
class EvalArgs:
    """Evaluation-related arguments"""

    interval: int = 600
    """Number of optimizer steps between evaluation calls"""
    max_new_tokens: Optional[int] = None
    """Number of tokens to generate"""
    max_iters: int = 100
    """Number of iterations"""
    initial_validation: bool = False
    """Whether to evaluate on the validation set at the beginning of the training"""
    final_validation: bool = True
    """Whether to evaluate on the validation set at the end of the training"""
    evaluate_example: Union[str, int] = "first"
    """How to pick an example instruction to evaluate periodically during training.
       Can be "first", "random", or an integer index to pick a specific example."""
