from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
import math
from pathlib import Path
from typing import Any, Callable, Literal

from litgpt.config import Config

# from litgpt.model import GPT
from litgpt.utils import num_parameters, parse_devices
from mup import get_shapes, load_base_shapes, make_base_shapes

from saws.config.base_config import BaseConfig
from saws.config.ConfigWrapper import ConfigWrapper
from saws.config.data_config import DataHandler, preprocess_wikitext
from saws.config.eval_config import EvalHandler
from saws.config.log_args import LoggingArgs
from saws.config.warmstart_config import WarmstartConfig as WarmstartConfig
from saws.lr_schedule import LRScheduler
from saws.model import GPT_Scales


def resolve_model_config(
    model_config: Config | None = None,
    model_config_path: Path | None = None,
    model_checkpoint_dir: Path | None = None,
    model_name: str | None = None,
) -> Config:
    """Turn the model-related training arguments into a model configuration.

    Sources, in the order they are considered:

    1. ``model_config`` given as a dict -> ``Config(**model_config)``.
    2. ``model_config`` given as anything else that is not ``None`` -> returned untouched.
    3. ``model_config_path`` (or ``<model_checkpoint_dir>/model_config.yaml`` when no
       path is given and the checkpoint directory exists) -> ``Config.from_file``.
    4. ``model_name`` -> ``Config.from_name``.

    Raises:
        ValueError: if ``model_config`` is ``None`` and either both or neither of
            ``model_name`` / ``model_config_path`` are usable.
    """
    # Normalise the checkpoint directory first; it may supply the default config path.
    checkpoint_dir = None if model_checkpoint_dir is None else Path(model_checkpoint_dir)
    if model_config_path is None and checkpoint_dir is not None and checkpoint_dir.is_dir():
        model_config_path = checkpoint_dir / "model_config.yaml"

    # An explicit config always wins over path / name.
    if model_config is not None:
        if isinstance(model_config, dict):
            return Config(**model_config)
        return model_config

    path_given = bool(model_config_path)
    name_given = bool(model_name)

    if path_given and name_given:
        raise ValueError("Only one of `model_name` or `model_config_path` can be set.")
    if path_given and model_name is None:
        return Config.from_file(model_config_path)
    if name_given and model_config_path is None:
        return Config.from_name(model_name)
    raise ValueError("Please specify `model_name` or `model_config_path`")


def resolve_train_steps(
    max_tokens: int | None = None,
    max_train_steps: int | None = None,
    tokens_per_param: int | None = None,
    micro_batch_size: int | None = None,
    block_size: int | None = None,
    trainable_params: int | None = None,
    tokens_per_step: int | None = None,
    accumulation_iters: int = 1,
    devices: int = 1,
    deepseek_hparams: bool = True,
) -> int:
    """Compute the number of training steps from one of several budget specifications.

    Exactly one of ``max_train_steps``, ``tokens_per_param`` or ``max_tokens`` may be
    set. Supported combinations:

        max_train_steps
        tokens_per_param, trainable_params  (converted to a token budget first)
        max_tokens, tokens_per_step
        max_tokens, micro_batch_size, block_size

    A token budget is converted to steps by floor-dividing it by the tokens consumed
    per step: ``tokens_per_step`` if given, otherwise
    ``micro_batch_size * block_size * accumulation_iters * devices``.

    Args:
        max_tokens: Total token budget.
        max_train_steps: Explicit number of steps; returned as-is.
        tokens_per_param: Token budget expressed per trainable parameter.
        micro_batch_size: Per-step micro batch size (used with ``block_size``).
        block_size: Sequence length (used with ``micro_batch_size``).
        trainable_params: Parameter count; required with ``tokens_per_param``.
        tokens_per_step: Tokens consumed by one step.
        accumulation_iters: Multiplier applied to the per-step token count.
        devices: Multiplier applied to the per-step token count.
        deepseek_hparams: When true, ``max_train_steps`` is rejected.

    Returns:
        The resolved number of training steps.

    Raises:
        ValueError: if no budget is given, if more than one is given, if the inputs
            needed to convert the chosen budget are missing, or if the chosen budget
            is ``0`` so that no step count can be derived.
    """
    if max_train_steps is None and tokens_per_param is None and max_tokens is None:
        raise ValueError("One of max_train_steps, tokens_per_param, max_tokens must be set.")
    # Truthiness (not `is None`) is used on purpose: a value of 0 counts as "unset" here.
    if sum(bool(budget) for budget in (max_train_steps, tokens_per_param, max_tokens)) > 1:
        raise ValueError(
            f"Only one of max_train_steps={max_train_steps}, "
            f"tokens_per_param={tokens_per_param}, max_tokens={max_tokens} can be set."
        )
    if deepseek_hparams and (max_tokens is None or tokens_per_param is None) and max_train_steps:
        raise ValueError("When using `deepseek_hparams`, one should use either `max_tokens` or `tokens_per_param`")

    train_steps: int | None = None

    if max_train_steps:
        train_steps = max_train_steps

    if tokens_per_param:
        if not trainable_params:
            raise ValueError(
                f"when tokens_per_param={tokens_per_param} is set, trainable_params={trainable_params} "
                f"must also be set."
            )
        max_tokens = trainable_params * tokens_per_param

    if max_tokens:
        if tokens_per_step:
            train_steps = int(max_tokens // tokens_per_step)
        elif micro_batch_size and block_size:
            train_steps = int(max_tokens // (micro_batch_size * block_size * accumulation_iters * devices))
        else:
            raise ValueError(
                f"Either tokens_per_step="
                f"{tokens_per_step} or both batch_size={micro_batch_size} "
                f"and block_size={block_size} must be set with max_tokens={max_tokens}"
            )

    if train_steps is None:
        # Only reachable when the single supplied budget is 0: it passed the `is None`
        # check above but none of the truthiness branches assigned a value.
        raise ValueError(
            f"Could not resolve training steps from max_train_steps={max_train_steps}, "
            f"tokens_per_param={tokens_per_param}, max_tokens={max_tokens}: "
            f"a budget of 0 is treated as unset."
        )

    return train_steps


def get_mup_base_shape(target_config: Config | ConfigWrapper, base_scales: dict[str, int] | None) -> dict | None:
    """Build muP base shapes for ``target_config``.

    Two configs are derived from ``target_config``: a "base" one and a "delta" one.
    For every ``name -> scale`` entry in ``base_scales`` the base config gets
    ``scale`` and the delta config gets ``2 * scale`` for that attribute. A
    ``GPT_Scales`` model is instantiated from each (with ``mup_init=True``) and the
    two sets of shapes are handed to ``make_base_shapes``.

    When ``base_scales`` is ``None`` both configs are left equal to ``target_config``.
    """
    narrow_cfg = ConfigWrapper.from_config(target_config)
    wide_cfg = deepcopy(narrow_cfg)

    if base_scales is not None:
        for dim_name, dim_size in base_scales.items():
            setattr(narrow_cfg, dim_name, dim_size)
            setattr(wide_cfg, dim_name, dim_size * 2)
        # Re-wrap the underlying configs after the scaled attributes were overwritten.
        narrow_cfg = ConfigWrapper.from_config(narrow_cfg.config)
        wide_cfg = ConfigWrapper.from_config(wide_cfg.config)

    narrow_model = GPT_Scales(narrow_cfg, mup_init=True)
    narrow_shapes = get_shapes(narrow_model)
    wide_model = GPT_Scales(wide_cfg, mup_init=True)
    wide_shapes = get_shapes(wide_model)

    return make_base_shapes(narrow_shapes, wide_shapes)


@dataclass
class TrainConfig(BaseConfig):
    """Configuration to specify a recipe for the model training. This class initializes all the necessary values used
    during the training based on the initialization arguments.

    Note:
        This object initializes a Config object which is an input to the GPT model.
        We never initialize or load model weights inside TrainConfig
    Note:
        The arguments that are not appended to ``self.ignore_fields`` are not allowed to change during the lifecycle
        of this object. This is because, those arguments are written into the yaml files when the config is saved,
        and loaded using those exact values again. Check true_weight_decay attribute for an example.
    Note:
        Avoid putting paths inside config objects as they are not reliable, and require to be reset for
         every training experiment.

    Derived attributes set in ``__post_init__`` (not dataclass fields): ``trainable_params``, ``train_steps``,
    ``early_stopping_train_steps``, ``lr_scheduler_args`` and ``logging_args``.

    """

    micro_batch_size: int
    """The batch per iteration."""
    block_size: int
    """Max sequence length/context length/block size."""
    weight_decay: float
    """Weight Decay for AdamW optimizer."""
    max_val_steps: int
    """N of validation steps on validation data."""
    max_micro_batch_size: int | None = None
    """Hardware specific maximum batch size to enforce."""
    max_lr: float | None = None
    """The maximum Learning Rate."""
    accumulation_iters: int = 1
    """Number of accumulation iters per device."""
    devices: int | str = "auto"
    """The number of devices to be trained on."""

    # model config
    model_config: ConfigWrapper | Config | None = None
    """Config object for model config."""
    model_config_path: Path | None = None
    """Config Path for the Config object, ignored if model_config provided."""
    model_name: str | None = None
    """Model name to load from HF hub."""
    weight_init_type: Literal["plain", "scaled", "GPT-NeoX", "DeepSeek"] | None = None
    """Model weight initialization."""

    # weight tying
    share_embeddings: bool = False

    # LR scheduler
    min_lr: float = 0.
    """Minimum learning rate that the scheduler is bounded to."""
    warmup_fraction: float | None = None
    """Fraction of the training steps used for learning-rate warmup."""
    scheduler_type: Literal["constant", "cosine", "linear"] = "constant"
    """Type of scheduler for the `middle` section, that is, between the warmup and cooldown."""
    scheduler_args: dict | None = None
    """All torch scheduler arguments."""
    scheduler_lr_decay_factor: float | None = None
    """LR decay factor forwarded to the scheduler for the `middle` section (see `LRScheduler` for exact semantics)."""
    cooldown_fraction: float | None = None
    """Fraction of steps to cooldown schedule."""
    cooldown_type: Literal["linear"] = "linear"
    cooldown_lr_decay_factor: float = 0.

    # training length
    max_train_steps: int | None = None
    """Max training steps to train for."""
    tokens_per_param: int | None = None
    """Used to calculate train_steps if train_steps not provided."""
    max_tokens: int | None = None
    """Total token budget; used to calculate train_steps if max_train_steps is not provided."""

    early_stopping_max_train_steps: int | None = None
    """Number of training steps after which the training early stops."""
    early_stopping_tokens_per_param: int | None = None
    """Used to calculate early_stopping_train_steps if early_stopping_max_train_steps is not provided."""
    early_stopping_max_tokens: int | None = None
    """Number of tokens after which the training early stops."""

    # train details
    clip_max_norm: int | None = None
    clip_max_val: float | None = None
    validate_every: int = 5
    """Number of steps after which to validate the model."""
    z_loss_eps: float | None = None
    """Epsilon value for Z loss."""

    # optimizer
    adam_beta_1: float = 0.9
    """Adam beta_1."""
    adam_beta_2: float = 0.95
    """Adam beta_2."""
    adam_eps: float = 1e-8
    """Adam epsilon."""
    independent_wd: bool = False
    """Whether to use independent weight decay during AdamW."""

    # MuParam width
    mup_base_scales: dict[str, int] | int | None = None
    """Dict of scaling dimension to base scale. A bare int is interpreted as the base scale of `d_model`."""
    mup_base_shape_path: str | Path | None = None
    """The path of the base model shape; only used when `mup_base_scales` is None."""

    # DeepSeek hyperparameters
    deepseek_hparams: bool = False
    """Changes the learning rate, accumulation iters, and weight initialization to match deepseek's algorithm based on
    compute."""

    # logging details
    tracked_metrics: dict[str, int] | None = None
    global_log_step: int = 1

    # seeding
    seed: int = 444

    # checkpoint management
    # NOTE: no trailing comma here -- a trailing comma would turn the default into the tuple ("lit_model",).
    base_name: str = "lit_model"
    """Name for the checkpoint state file.
    Stores files in the format of f"{base_name}.pth" or f"{base_name}_{step}.pth" depending
    on other arguments and their interaction.
    """
    load_state_path: Path | None = None
    """Path to load checkpoint, random states, etc. for continued training.
    Equivalent to `load_dir` in `Checkpointer`.
    If run exists, and this is None, starts training from scratch.
    """
    save_state_path: Path | None = None
    """Path to save checkpoint, random states, etc. for continued training.
    Equivalent to `save_dir` in `Checkpointer`.
    """
    update_every_k: int | None = 100
    """Number of steps after which to update the checkpoint.
    Retains the same name for the checkpoint and replaces/updates it.
    """
    save_every_k: int | None = None
    """Number of steps after which to save the state.
    Saves the state with the step appended to the filename as f"{base_name}_{step}.pth"
    """
    save_total_k: int | None = None
    """Total number of checkpoints to save.
    Saves the total k checkpoints based on the total number of steps.
    NOTE: Overrides `save_every_k` and `update_every_k` if set.
    """
    save_top_k: int | None = None  # TODO: Implement this
    """Number of top checkpoints to save.
    Saves the top k checkpoints based on the best pretraining loss seen.
    """
    save_last_k: int | None = None  # TODO: Implement this
    """Saves the last k checkpoints.
    Saves the last k checkpoints based on the last k steps saved as f"{base_name}_{step}.pth".
    """
    save_weights_only: bool = False  # TODO: Implement this
    """Saves only the weights of the model.
    """
    save_step_list: list[int] | None = None  # TODO: Implement this
    """Save the state at specific steps.
    For example: [100, 200, 300] will save files as
    f"{base_name}_100.pth", f"{base_name}_200.pth", f"{base_name}_300.pth", etc.
    """

    # warmstarting setting
    warmstart_config: dict[str, Any] | None = None
    """Keyword arguments for `WarmstartConfig`; replaced by the built `WarmstartConfig` in `__post_init__`.
    Stays None when neither this nor `warmstart_config_path` is given."""
    warmstart_config_path: Path | None = None
    """Path to load the `WarmstartConfig` from; mutually exclusive with `warmstart_config`."""

    # layer-freezing setting
    layers_to_train: int | None = None  # trains all layers when None or a large-large number

    def __post_init__(self) -> None:
        """Resolve the model config and derive all training quantities from the raw arguments.

        Raises:
            ValueError: if `devices` cannot be resolved to an int, if the DeepSeek compute cannot be
                calculated, if the training length cannot be resolved, or if `max_lr` ends up as None.
            AssertionError: if early stopping exceeds the training length or if both warmstart sources are set.
        """
        super().__post_init__()
        if self.load_state_path is not None and isinstance(self.load_state_path, str):
            self.load_state_path = Path(self.load_state_path)
        if self.save_state_path is not None and isinstance(self.save_state_path, str):
            self.save_state_path = Path(self.save_state_path)

        self.ignore_fields.extend(["model_config_path", "model_name"])
        self.model_config = resolve_model_config(
            self.model_config, self.model_config_path, self.load_state_path, self.model_name
        )
        # override model block_size
        self.model_config.block_size = self.block_size
        self.model_config = ConfigWrapper.from_config(self.model_config)
        # Lazily filled cache behind the `mup_base_shape` property.
        self._mup_base_shape: dict | None = None

        self.trainable_params = num_parameters(GPT_Scales(self.model_config), requires_grad=True)
        # TODO: fix parse_devices to be more strict
        self.devices = parse_devices(self.devices)

        if isinstance(self.devices, str):
            raise ValueError("`devices` is wrongly initialized and should be an int")

        # check for max batch size allowed
        if self.max_micro_batch_size is not None and self.max_micro_batch_size < self.micro_batch_size:
            _total_batch_size = self.micro_batch_size * self.accumulation_iters  # the effective batch size
            micro_batch_size_candidate = self.max_micro_batch_size
            # Walk down from the hardware limit until the candidate divides the total evenly (or reaches 1).
            while (
                _total_batch_size % (micro_batch_size_candidate * self.accumulation_iters * self.devices) != 0
                and micro_batch_size_candidate > 1
            ):
                # TODO: a harder check to see micro batch size value is not too low
                # For max. GPU utilization, more sensible to increase _total_batch_size slightly
                micro_batch_size_candidate -= 1
            self.micro_batch_size = micro_batch_size_candidate
            self.accumulation_iters = int(_total_batch_size / (self.micro_batch_size * self.devices))

        if self.deepseek_hparams:
            # NOTE(review): presumably DeepSeek's per-token compute estimate and scaling-law fits -- confirm
            # the formula and constants against the reference they were taken from.
            model_scale = (
                72 * self.model_config.n_layer * (self.model_config.config.n_embd**2)
                + 12 * self.model_config.n_layer * self.model_config.d_model * self.model_config.block_size
            )
            if self.max_tokens:
                compute = model_scale * self.max_tokens
            elif self.tokens_per_param and self.trainable_params:
                compute = model_scale * self.tokens_per_param * self.trainable_params
            else:
                raise ValueError("An error has occurred during DeepSeek's compute calculation")

            optim_deepseek_lr = 0.3119 * (compute**-0.125)
            optim_deepseek_effective_batch_size = 0.2920 * (compute**0.3271)

            # Both `max_lr` and `accumulation_iters` are overwritten by the compute-derived values.
            self.max_lr = optim_deepseek_lr
            self.accumulation_iters = round(
                optim_deepseek_effective_batch_size
                / (self.devices * self.micro_batch_size * self.model_config.block_size)
            )

            # Just a check for when the rounding is too low
            if self.accumulation_iters <= 0:
                self.accumulation_iters = 1

        self.train_steps = resolve_train_steps(
            max_tokens=self.max_tokens,
            max_train_steps=self.max_train_steps,
            tokens_per_param=self.tokens_per_param,
            micro_batch_size=self.micro_batch_size,
            block_size=self.block_size,
            trainable_params=self.trainable_params,
            accumulation_iters=self.accumulation_iters,
            devices=self.devices,
            deepseek_hparams=self.deepseek_hparams,
        )

        if (
            self.early_stopping_max_train_steps is None
            and self.early_stopping_max_tokens is None
            and self.early_stopping_tokens_per_param is None
        ):
            self.early_stopping_train_steps = None
        else:
            self.early_stopping_train_steps = resolve_train_steps(
                max_tokens=self.early_stopping_max_tokens,
                max_train_steps=self.early_stopping_max_train_steps,
                tokens_per_param=self.early_stopping_tokens_per_param,
                micro_batch_size=self.micro_batch_size,
                block_size=self.block_size,
                trainable_params=self.trainable_params,
                accumulation_iters=self.accumulation_iters,
                devices=self.devices,
                deepseek_hparams=self.deepseek_hparams,
            )
            assert self.early_stopping_train_steps <= self.train_steps, (
                f"Early stopping steps should not be more than the total training steps. "
                f"Early stopping steps are {self.early_stopping_train_steps} and training steps are {self.train_steps}."
            )

        # Create LR scheduler args
        self.lr_scheduler_args = dict(
            max_steps=self.train_steps,
            warmup_fraction=self.warmup_fraction,
            scheduler_type=self.scheduler_type,
            scheduler_args=self.scheduler_args,
            scheduler_lr_decay_factor=self.scheduler_lr_decay_factor,
            cooldown_fraction=self.cooldown_fraction,
            cooldown_type=self.cooldown_type,
            cooldown_lr_decay_factor=self.cooldown_lr_decay_factor,
        )

        # Adjust checkpoint setting
        if self.save_total_k is not None:
            # override the save_every_k and update_every_k
            self.save_every_k = int(self.train_steps // self.save_total_k)
            self.update_every_k = None

        if self.max_lr is None:
            raise ValueError("`max_lr` should not be `None`")

        self.tracked_metrics = {} if self.tracked_metrics is None else self.tracked_metrics

        self.logging_args = LoggingArgs(
            tracked_metrics=self.tracked_metrics,
            global_log_step=self.global_log_step,
            log_dir=None,
        )

        # resolve warmstart config
        assert self.warmstart_config is None or self.warmstart_config_path is None, \
            "Only one of `warmstart_config` or `warmstart_config_path` can be set."
        if self.warmstart_config_path:
            self.warmstart_config = WarmstartConfig(base_model_path=Path(__file__)).from_path(
                self.warmstart_config_path
            )
        elif self.warmstart_config is not None and not isinstance(self.warmstart_config, WarmstartConfig):
            # Only unpack when kwargs were actually supplied: `**None` raises TypeError, which used to make
            # the default (no warmstarting) impossible to construct. An already-built WarmstartConfig is kept.
            self.warmstart_config = WarmstartConfig(**self.warmstart_config)

    @property
    def true_weight_decay(self) -> float:
        """Weight decay value to use: `weight_decay / max_lr` when `independent_wd` is set, else `weight_decay`."""
        # TODO: account for optimizer groups?
        if self.independent_wd:
            return self.weight_decay / self.max_lr
        return self.weight_decay

    @property
    def mup_base_shape(self) -> dict | None:
        """muP base shapes, computed or loaded on first access and cached afterwards.

        Loaded from `mup_base_shape_path` when `mup_base_scales` is None and a path is set; otherwise built from
        `mup_base_scales` (an int is first converted to `{"d_model": <int>}`, which also rewrites the field).
        Returns None when neither source is configured.
        """
        if self._mup_base_shape is not None:
            return self._mup_base_shape
        if self.mup_base_scales is None and self.mup_base_shape_path is not None:
            self._mup_base_shape = load_base_shapes(str(self.mup_base_shape_path))
            return self._mup_base_shape
        if isinstance(self.mup_base_scales, int):
            self.mup_base_scales = {"d_model": self.mup_base_scales}
        if isinstance(self.mup_base_scales, dict):
            self._mup_base_shape = get_mup_base_shape(self.model_config, self.mup_base_scales)

        return self._mup_base_shape

    @classmethod
    def from_yaml(cls, yaml_config: dict[str, Any], yaml_hook: Callable | None = None) -> TrainConfig:
        """Build a `TrainConfig` from a parsed yaml dict.

        `yaml_hook`, when given, is applied to the dict first. The `model_config` entry is replaced in place by a
        `ConfigWrapper` before the dict is unpacked into the constructor.
        """
        if yaml_hook is not None:
            yaml_config = yaml_hook(yaml_config)
        try:
            yaml_config["model_config"] = ConfigWrapper.from_yaml(yaml_config["model_config"])
        except TypeError:
            # Depending on if the train_config was saved with defaults or not
            # the model_config might have extra arguments
            yaml_config["model_config"] = ConfigWrapper.from_config(Config(**yaml_config["model_config"]))
        return cls(**yaml_config)


@dataclass
class PipelineConfig(BaseConfig):
    """Groups the data, train and (optional) eval configuration of one pipeline.

    Every sub-config may be passed directly; a missing one is loaded from its
    ``*_config_path`` when that path is set and exists. ``data_config`` and
    ``train_config`` must be available after loading.
    """

    data_config_path: Path | None = None
    train_config_path: Path | None = None
    eval_config_path: Path | None = None

    data_config: DataHandler | None = None
    train_config: TrainConfig | None = None
    eval_config: EvalHandler | None = None

    def __post_init__(self) -> None:
        """Load missing sub-configs from disk and sync the data block size with the train config."""
        super().__post_init__()

        # (attribute name, class providing `from_path`), processed in this order.
        loaders = (
            ("data_config", DataHandler),
            ("train_config", TrainConfig),
            ("eval_config", EvalHandler),
        )
        for attr_name, loader_cls in loaders:
            if getattr(self, attr_name) is not None:
                continue
            source_path = getattr(self, f"{attr_name}_path")
            if source_path and source_path.exists():
                setattr(self, attr_name, loader_cls.from_path(path=source_path))

        assert self.data_config is not None
        assert self.train_config is not None
        self.data_config.block_size = self.train_config.block_size

    @classmethod
    def from_yaml(cls, yaml_config: dict[str, Any], yaml_hook: Callable | None = None) -> PipelineConfig:
        """Build a ``PipelineConfig`` from a parsed yaml dict, converting each section in place.

        ``yaml_hook`` is forwarded to ``TrainConfig.from_yaml`` only. A missing or
        empty ``eval_config`` section becomes ``None``.
        """
        yaml_config["data_config"] = DataHandler.from_yaml(yaml_config["data_config"])
        yaml_config["train_config"] = TrainConfig.from_yaml(yaml_config["train_config"], yaml_hook)

        eval_section = yaml_config.get("eval_config")
        yaml_config["eval_config"] = EvalHandler.from_yaml(eval_section) if eval_section else None

        return cls(**yaml_config)


if __name__ == "__main__":
    # No-op: running this module directly performs no action.
    pass