# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims


class BatchWarmupStrategy:
    """
    Ramp the number of gradient accumulation steps up to a target value during
    the first ``batch_warmup_steps`` training steps.

    Three ramp shapes are supported, selected by
    ``job_config.training.batch_warmup_strategy``:

    - ``"linear"``: the ratio of the target grows linearly from
      ``batch_warmup_start_ratio`` to 1.0.
    - ``"exponential"``: the ratio grows geometrically from
      ``batch_warmup_start_ratio`` to 1.0.
    - ``"step"``: the step count starts at ``start_grad_accum_steps`` and
      doubles on each of up to ``_NUM_STEP_INTERVALS`` plateaus, capped at the
      target.

    Once warmup is over (or when it is disabled) every query returns the
    target number of gradient accumulation steps.

    Note: ``parallel_dims`` and its ``pp_enabled`` flag are stored on the
    instance for callers, but none of the calculations in this class currently
    depend on them.
    """

    # Number of plateaus the "step" strategy spreads over the warmup period.
    _NUM_STEP_INTERVALS = 5

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        target_gradient_accumulation_steps: int,
        dp_degree: int,
    ):
        """
        Args:
            job_config: Job configuration; the ``training.enable_batch_warmup``,
                ``training.batch_warmup_strategy``,
                ``training.batch_warmup_steps`` and
                ``training.batch_warmup_start_ratio`` fields are read.
            parallel_dims: Parallelism description; only ``pp_enabled`` is read.
            target_gradient_accumulation_steps: Accumulation steps used once
                warmup has finished.
            dp_degree: Data-parallel degree, used by
                ``get_effective_batch_size``.
        """
        self.enabled = job_config.training.enable_batch_warmup
        self.strategy = job_config.training.batch_warmup_strategy
        self.warmup_steps = job_config.training.batch_warmup_steps
        self.start_ratio = job_config.training.batch_warmup_start_ratio
        self.target_grad_accum_steps = target_gradient_accumulation_steps
        self.parallel_dims = parallel_dims
        self.dp_degree = dp_degree

        # Starting accumulation steps, never fewer than one micro-step.
        self.start_grad_accum_steps = max(
            1, int(target_gradient_accumulation_steps * self.start_ratio)
        )

        # Recorded for callers; not consulted by the schedule computations.
        self.is_pp_enabled = parallel_dims.pp_enabled

        # The step schedule is static, so build it once up front.
        if self.strategy == "step":
            self._calculate_step_schedule()

    def _calculate_step_schedule(self):
        """
        Build ``self.step_schedule`` for the ``"step"`` strategy.

        The schedule is a list of ``(step_threshold, grad_accum_steps)`` pairs
        ordered by threshold; a training step uses the value of the first pair
        whose threshold is strictly greater than it. The value starts at
        ``start_grad_accum_steps`` and doubles from one plateau to the next,
        capped at the target.
        """
        # A plateau must be at least one step wide. Without this floor a warmup
        # shorter than the number of plateaus would produce all-zero thresholds
        # and the schedule would jump straight to the target.
        step_interval = max(1, self.warmup_steps // self._NUM_STEP_INTERVALS)

        schedule = []
        current_steps = self.start_grad_accum_steps
        for plateau in range(1, self._NUM_STEP_INTERVALS + 1):
            step_threshold = plateau * step_interval
            if step_threshold > self.warmup_steps:
                # Short warmups cannot fit every plateau; drop the remainder.
                break
            schedule.append((step_threshold, current_steps))
            current_steps = min(current_steps * 2, self.target_grad_accum_steps)

        # Any steps left before the end of warmup run at the target.
        schedule.append((self.warmup_steps, self.target_grad_accum_steps))
        self.step_schedule = schedule

    def get_gradient_accumulation_steps(self, current_step: int) -> int:
        """
        Get the gradient accumulation steps for the current training step.

        Args:
            current_step: Current training step (0-indexed)

        Returns:
            Number of gradient accumulation steps to use. This is the target
            value when warmup is disabled or ``current_step`` is at or past
            ``warmup_steps``.

        Raises:
            ValueError: If warmup is active and the configured strategy is not
                one of ``"linear"``, ``"exponential"`` or ``"step"``.
        """
        if not self.enabled or current_step >= self.warmup_steps:
            return self.target_grad_accum_steps

        if self.strategy == "linear":
            return self._linear_warmup(current_step)
        elif self.strategy == "exponential":
            return self._exponential_warmup(current_step)
        elif self.strategy == "step":
            return self._step_warmup(current_step)
        else:
            raise ValueError(f"Unknown batch warmup strategy: {self.strategy}")

    def _linear_warmup(self, current_step: int) -> int:
        """Linear increase from start_ratio to 1.0 over warmup_steps."""
        progress = current_step / self.warmup_steps
        current_ratio = self.start_ratio + progress * (1.0 - self.start_ratio)
        return max(1, int(self.target_grad_accum_steps * current_ratio))

    def _exponential_warmup(self, current_step: int) -> int:
        """
        Exponential increase from start_ratio to 1.0 over warmup_steps.

        A geometric ramp needs a strictly positive starting ratio. When the
        configured ``start_ratio`` is zero or negative, the ratio implied by
        ``start_grad_accum_steps`` (which is clamped to at least one step) is
        used instead of dividing by zero.
        """
        base_ratio = self.start_ratio
        if base_ratio <= 0:
            # max(1, ...) keeps the fallback finite even for a degenerate target.
            base_ratio = self.start_grad_accum_steps / max(
                1, self.target_grad_accum_steps
            )

        progress = current_step / self.warmup_steps
        # Geometric curve: base_ratio * (1 / base_ratio) ** progress
        current_ratio = base_ratio * ((1.0 / base_ratio) ** progress)
        return max(1, int(self.target_grad_accum_steps * current_ratio))

    def _step_warmup(self, current_step: int) -> int:
        """Step-wise increase based on the schedule built at construction."""
        for step_threshold, grad_accum_steps in self.step_schedule:
            if current_step < step_threshold:
                return grad_accum_steps
        return self.target_grad_accum_steps

    def get_loss_scaling_factor(self, current_step: int) -> float:
        """
        Get the loss scaling factor associated with the current warmup state.

        The factor is the ratio of the target accumulation steps to the
        accumulation steps in effect at ``current_step``, so it is larger while
        fewer micro-steps are being accumulated.

        Args:
            current_step: Current training step

        Returns:
            ``1.0`` when warmup is disabled or finished, otherwise
            ``target_grad_accum_steps / current_grad_accum_steps``.
        """
        if not self.enabled or current_step >= self.warmup_steps:
            return 1.0

        current_grad_accum = self.get_gradient_accumulation_steps(current_step)
        return self.target_grad_accum_steps / current_grad_accum

    def should_apply_warmup(self, current_step: int) -> bool:
        """Check if warmup should be applied at the current step."""
        return self.enabled and current_step < self.warmup_steps

    def get_effective_batch_size(self, current_step: int) -> int:
        """
        Calculate the batch-size multiplier in effect at the current step.

        The local batch size is not known to this class, so the result is only
        the ``grad_accum_steps * dp_degree`` component; callers multiply it by
        their local batch size to obtain the full effective batch size.

        Args:
            current_step: Current training step

        Returns:
            ``grad_accum_steps * dp_degree`` for ``current_step``.
        """
        grad_accum_steps = self.get_gradient_accumulation_steps(current_step)
        return grad_accum_steps * self.dp_degree

    def log_warmup_info(self, current_step: int) -> dict:
        """
        Get logging information about the current warmup state.

        Args:
            current_step: Current training step

        Returns:
            An empty dictionary when warmup is disabled; otherwise a dictionary
            with the keys ``batch_warmup_active``, ``grad_accum_steps``,
            ``loss_scaling_factor`` and ``warmup_progress`` (fraction of the
            warmup completed, ``1.0`` once warmup is over).
        """
        if not self.enabled:
            return {}

        grad_accum = self.get_gradient_accumulation_steps(current_step)
        loss_scaling = self.get_loss_scaling_factor(current_step)
        is_warmup = self.should_apply_warmup(current_step)

        return {
            "batch_warmup_active": is_warmup,
            "grad_accum_steps": grad_accum,
            "loss_scaling_factor": loss_scaling,
            "warmup_progress": min(1.0, current_step / self.warmup_steps)
            if is_warmup
            else 1.0,
        }
