import time

from dataclasses import dataclass

import torch
import transformers

from transformers import TrainerControl
from transformers import TrainerState
from transformers import TrainingArguments

from liger_kernel.utils import infer_device

# https://simple.wikipedia.org/wiki/Byte
# For memory, we use binary system
M_BIN_UNIT = 2**20
# For metrics (tflops), we use decimal system
T_DEC_UNIT = 10**12


def round_to_n_decimal(x, n):
    return round(x, n)


@dataclass
class Precision:
    """
    Precision is a dataclass to store the number of decimal points for each metric.
    """

    n_decimal_time: int
    n_decimal_memory: int
    n_decimal_TPS: int


@dataclass
class State:
    """
    State is a dataclass to store the internal state of the efficiency callback.
    """

    n_warmup_steps: int = 0
    total_peak_memory_allocated: float = float("-inf")
    total_peak_memory_reserved: float = float("-inf")

    step_start_time: float = 0.0
    elapsed_time: float = 0.0

    elapsed_step: int = 0

    step_start_tokens_seen: int = 0
    elapsed_tokens_seen: int = 0

    global_start_step: int = 0


@dataclass
class Time:
    """
    Time is a dataclass to store the time-related metrics.
    """

    step: int = 0
    step_time_sec: float = 0.0
    avg_step_time_sec: float = 0.0
    time_to_completion_sec: float = 0.0
    estimated_total_time_sec: float = 0.0


@dataclass
class Memory:
    """
    Memory is a dataclass to store the memory-related metrics.
    """

    step_peak_memory_allocated_MB: float = 0.0
    step_peak_memory_reserved_MB: float = 0.0
    total_peak_memory_allocated_MB: float = 0.0
    total_peak_memory_reserved_MB: float = 0.0


@dataclass
class TPS:
    """
    TPS is a dataclass to store the tokens per second metrics.
    """

    step_tokens_per_second: float = 0.0
    avg_tokens_per_second: float = 0.0


class EfficiencyCallback(transformers.TrainerCallback):
    """
    EfficiencyCallback is a callback to track the efficiency of the training process.
    The tracked stats include: step time, memory, and throughput.

    It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments.

    Args:
        n_warmup_steps: number of warmup steps
            The stats in the first n_warmup_steps will not be added into the aggregated stats
            This is because the first few steps might take longer due to jit compliation and other initialization overheads
        n_decimal_time: number of decimal points for time
        n_decimal_memory: number of decimal points for memory
        n_decimal_TPS: number of decimal points for TPS
    """

    def __init__(self, n_warmup_steps=2, n_decimal_time=2, n_decimal_memory=2, n_decimal_TPS=2):
        self.state = State(
            n_warmup_steps,
        )

        self.precision = Precision(n_decimal_time, n_decimal_memory, n_decimal_TPS)

        self.time = Time()
        self.memory = Memory()
        self.tps = TPS()
        self.device = infer_device()

    def on_init_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the end of the initialization of the [`Trainer`].
        """
        if not args.include_num_input_tokens_seen:
            raise Exception(
                'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second'
            )
        if args.logging_steps != 1:
            raise Exception("Please set logging_steps=1 to track the efficiency metrics accurately")

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # if loaded from checkpoints, global_start_step is not 1 but state.global_step
        self.state.global_start_step = state.global_step

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict[str, float],
        **kwargs,
    ):
        if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps):
            return
        else:
            # spread self.time, self.memory, self.tps to logs
            logs.update(self.time.__dict__)
            logs.update(self.memory.__dict__)
            logs.update(self.tps.__dict__)

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
        several inputs.
        """
        # memory
        getattr(torch, self.device).reset_peak_memory_stats()

        # time
        self.state.step_start_time = time.perf_counter()

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps):
            # The end the current step_start_tokens_seen is the start of next iteration

            # tokens
            self.state.step_start_tokens_seen = state.num_input_tokens_seen
            return

        # time
        current_time = time.perf_counter()
        step_time = current_time - self.state.step_start_time
        self.state.elapsed_time += step_time

        # step
        global_step = state.global_step
        self.state.elapsed_step += 1
        avg_step_time = self.state.elapsed_time / self.state.elapsed_step

        self.time.step = global_step
        self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time)
        self.time.avg_step_time_sec = round_to_n_decimal(avg_step_time, self.precision.n_decimal_time)
        self.time.time_to_completion_sec = round_to_n_decimal(
            avg_step_time * (state.max_steps - global_step),
            self.precision.n_decimal_time,
        )
        self.time.estimated_total_time_sec = round_to_n_decimal(
            avg_step_time * state.max_steps, self.precision.n_decimal_time
        )

        # memory
        step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated()
        step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved()

        self.memory.step_peak_memory_allocated_MB = round_to_n_decimal(
            step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory
        )
        self.state.total_peak_memory_allocated = max(self.state.total_peak_memory_allocated, step_peak_memory_allocated)
        self.memory.total_peak_memory_allocated_MB = round_to_n_decimal(
            self.state.total_peak_memory_allocated / M_BIN_UNIT,
            self.precision.n_decimal_memory,
        )

        self.memory.step_peak_memory_reserved_MB = round_to_n_decimal(
            step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory
        )

        self.state.total_peak_memory_reserved = max(self.state.total_peak_memory_reserved, step_peak_memory_reserved)

        self.memory.total_peak_memory_reserved_MB = round_to_n_decimal(
            self.state.total_peak_memory_reserved / M_BIN_UNIT,
            self.precision.n_decimal_memory,
        )

        # tokens
        step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen

        self.state.elapsed_tokens_seen += step_tokens_seen

        self.tps.step_tokens_per_second = round_to_n_decimal(
            step_tokens_seen / step_time,
            self.precision.n_decimal_TPS,
        )

        self.tps.avg_tokens_per_second = round_to_n_decimal(
            self.state.elapsed_tokens_seen / self.state.elapsed_time,
            self.precision.n_decimal_TPS,
        )

        # The end the current step_start_tokens_seen is the start of next iteration

        # tokens
        self.state.step_start_tokens_seen = state.num_input_tokens_seen