# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import sys
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch.nn as nn
from peft import PeftModel
from transformers import Trainer

from swift.trainers.optimizers.galore import create_optimizer_and_scheduler
from swift.utils import get_dist_setting

if TYPE_CHECKING:
    from swift.trainers import TrainingArguments


def calculate_max_steps(args: "TrainingArguments", dataset) -> int:
    """Return the total number of optimizer update steps for a training run.

    Args:
        args: Training arguments. Reads ``max_steps``,
            ``per_device_train_batch_size``, ``gradient_accumulation_steps``
            and ``num_train_epochs``.
        dataset: Training dataset; must support ``len()`` unless
            ``args.max_steps`` is a positive number.

    Returns:
        ``args.max_steps`` when it is set and positive. Otherwise the number
        of epochs multiplied by the update steps per epoch (rounded up),
        where one update step consumes
        ``per_device_train_batch_size * gradient_accumulation_steps * world_size``
        samples and every epoch counts for at least one step.
    """
    # An explicit step budget always wins over the epoch-based estimate.
    if args.max_steps and args.max_steps > 0:
        return args.max_steps

    num_samples = len(dataset)
    _, _, world_size, _ = get_dist_setting()
    samples_per_update = (
        args.per_device_train_batch_size
        * args.gradient_accumulation_steps
        * world_size
    )
    # Floor division drops the incomplete last batch; clamp so that tiny
    # datasets still yield one update per epoch.
    updates_per_epoch = max(num_samples // samples_per_update, 1)
    return math.ceil(args.num_train_epochs * updates_per_epoch)


def create_galore_optimizer(args: "TrainingArguments", model, dataset):
    """Create the GaLore optimizer together with its learning-rate scheduler.

    Args:
        args: Training arguments. Reads ``galore_config``, ``learning_rate``
            and ``weight_decay``; ``galore_config`` is reset to ``None``
            before returning.
        model: Model whose parameters are optimized.
        dataset: Training dataset, used only to compute the number of
            training steps.

    Returns:
        The ``(optimizer, lr_scheduler)`` pair produced by
        ``create_optimizer_and_scheduler``.
    """
    total_steps = calculate_max_steps(args, dataset)
    hyper_params = {
        "lr": args.learning_rate,
        "weight_decay": args.weight_decay,
    }
    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model, args, args.galore_config, total_steps, **hyper_params
    )
    # trainer cannot serialize galore_config
    args.galore_config = None
    return optimizer, lr_scheduler


def create_lorap_optimizer(args: "TrainingArguments", model, dataset):
    """Create the optimizer for LoRA+ training.

    If the model exposes ``create_optimizer_param_groups`` its result is used
    as the parameter groups. When the model has no such method, or the method
    returns ``None``, the trainable parameters are split into a weight-decay
    group and a no-weight-decay group using
    ``Trainer.get_decay_parameter_names``.

    Args:
        args: Training arguments. Reads ``learning_rate`` and
            ``weight_decay``, and is passed to
            ``Trainer.get_optimizer_cls_and_kwargs`` to pick the optimizer
            class and its keyword arguments.
        model: Model whose parameters are optimized.
        dataset: Unused; present so all optimizer factories share one
            signature.

    Returns:
        ``(optimizer, None)`` - no learning-rate scheduler is created here.
    """
    optimizer_grouped_parameters = None
    if hasattr(model, "create_optimizer_param_groups"):
        # Lora+ parameter groups
        optimizer_grouped_parameters = model.create_optimizer_param_groups(
            lr=args.learning_rate, weight_decay=args.weight_decay
        )

    if optimizer_grouped_parameters is None:
        # Default parameter groups.
        # A set keeps the per-parameter membership test O(1); a plain list of
        # names would make the partition quadratic in the parameter count.
        decay_parameters = set(Trainer.get_decay_parameter_names(None, model))
        decay_params, no_decay_params = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name in decay_parameters:
                decay_params.append(param)
            else:
                no_decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None


def _parse_muon_optim_args(optim_args_str):
    """Parse a ``"key=value,key=value"`` string into a keyword-argument dict.

    Values that are valid Python literals (``0.95``, ``5``, ``True``) are
    converted to the corresponding objects; anything else is kept as the raw
    string. Empty entries caused by stray commas are skipped.

    Raises:
        ValueError: If an entry has no ``=`` or an empty key.
    """
    import ast

    parsed = {}
    if not optim_args_str:
        return parsed
    for mapping in optim_args_str.replace(" ", "").split(","):
        if not mapping:
            continue
        key, sep, raw_value = mapping.partition("=")
        if not sep or not key:
            raise ValueError(
                f"Invalid optim_args entry {mapping!r}; expected 'key=value'."
            )
        try:
            parsed[key] = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Not a Python literal: pass the text through unchanged.
            parsed[key] = raw_value
    return parsed


def create_muon_optimizer(args: "TrainingArguments", model, dataset):
    """Create a Muon optimizer using the implementation from the Moonlight repo.

    The Moonlight repository is cloned when ``args.local_repo_path`` is not
    set, and its ``examples`` directory is put on ``sys.path`` so that
    ``toy_train.Muon`` can be imported.

    Trainable parameters with at least two dimensions whose names contain
    neither the embedding key nor the lm-head key are handed to Muon as
    ``muon_params``; every other trainable parameter goes to
    ``adamw_params``.

    Args:
        args: Training arguments. Reads ``local_repo_path`` (and sets it after
            cloning), ``optim_args``, ``learning_rate``, ``weight_decay``,
            ``adam_beta1``, ``adam_beta2`` and ``adam_epsilon``.
        model: Model whose parameters are optimized; must expose
            ``model_meta.model_arch``.
        dataset: Unused; present so all optimizer factories share one
            signature.

    Returns:
        ``(optimizer, None)`` - no learning-rate scheduler is created here.

    Raises:
        ValueError: If ``args.optim_args`` contains a malformed entry.
    """
    from swift.llm import git_clone_github, get_model_arch

    if not args.local_repo_path:
        args.local_repo_path = git_clone_github(
            "https://github.com/MoonshotAI/Moonlight.git"
        )
    examples_dir = os.path.join(args.local_repo_path, "examples")
    # Do not grow sys.path with a duplicate entry on every call.
    if examples_dir not in sys.path:
        sys.path.append(examples_dir)
    from toy_train import Muon

    # Extra keyword arguments for Muon are forwarded with literal values
    # converted, since the raw command-line text is always a string.
    # NOTE(review): assumes Muon's extra hyperparameters are numeric/boolean
    # rather than strings - confirm against toy_train.Muon.
    optim_args = _parse_muon_optim_args(args.optim_args)

    model_arch = get_model_arch(model.model_meta.model_arch)
    # NOTE(review): getattr also covers a missing arch entry; presumably
    # get_model_arch can return None for unregistered architectures - confirm.
    embed_key = getattr(model_arch, "embedding", None) or "embed_tokens"
    lm_head_key = getattr(model_arch, "lm_head", None) or "lm_head"

    # Single pass: each trainable parameter lands in exactly one of the lists.
    muon_params, adamw_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        use_muon = (
            param.ndim >= 2
            and embed_key not in name
            and lm_head_key not in name
        )
        if use_muon:
            muon_params.append(param)
        else:
            adamw_params.append(param)

    return (
        Muon(
            lr=args.learning_rate,
            wd=args.weight_decay,
            muon_params=muon_params,
            adamw_params=adamw_params,
            adamw_betas=(args.adam_beta1, args.adam_beta2),
            adamw_eps=args.adam_epsilon,
            **optim_args,
        ),
        None,
    )


def get_param_startswith(
    model, chosen_prefix: List[str], rejected_prefix: Optional[List[str]] = None
) -> List[Tuple[str, nn.Parameter]]:
    """Collect trainable parameters selected by name prefix.

    A parameter is returned when it requires grad, its name starts with none
    of ``rejected_prefix`` and with at least one of ``chosen_prefix``.

    For a ``PeftModel`` the prefixes are matched against the parameter names
    of ``model.model``, and the returned names are re-qualified with
    ``base_model.model.``.

    Args:
        model: Model to inspect.
        chosen_prefix: Name prefixes to select; ``None`` or empty selects
            nothing.
        rejected_prefix: Name prefixes to exclude; they take precedence over
            ``chosen_prefix``.

    Returns:
        ``(name, parameter)`` pairs in ``named_parameters()`` order.
    """
    chosen_prefix = chosen_prefix or []
    rejected_prefix = rejected_prefix or []
    if not chosen_prefix:
        return []

    wrapped_by_peft = isinstance(model, PeftModel)
    inner_model = model.model if wrapped_by_peft else model

    matched = []
    for name, param in inner_model.named_parameters():
        if not param.requires_grad:
            continue
        # Rejection wins over selection.
        if any(name.startswith(prefix) for prefix in rejected_prefix):
            continue
        if any(name.startswith(prefix) for prefix in chosen_prefix):
            full_name = f"base_model.model.{name}" if wrapped_by_peft else name
            matched.append((full_name, param))
    return matched


def create_multimodal_optimizer(args: "TrainingArguments", model, dataset):
    """ViT/Aligner/LLM use different learning rates.

    Trainable parameters are collected per component with
    ``get_param_startswith`` using the prefixes of the model architecture
    (``vision_tower`` excluding ``aligner``, ``aligner``, ``language_model``).
    Each component is split into a no-weight-decay group and a weight-decay
    group according to ``Trainer.get_decay_parameter_names``; empty groups are
    omitted.

    Args:
        args: Training arguments. Reads ``vit_lr``, ``aligner_lr``,
            ``learning_rate`` and ``weight_decay``. A component learning rate
            of ``None`` falls back to ``learning_rate``.
        model: Model whose parameters are optimized; must expose
            ``model_meta.model_arch``.
        dataset: Unused; present so all optimizer factories share one
            signature.

    Returns:
        ``(optimizer, None)`` - no learning-rate scheduler is created here.
    """
    from swift.llm import get_model_arch

    decay_parameters = set(Trainer.get_decay_parameter_names(None, model))
    model_arch = get_model_arch(model.model_meta.model_arch)
    vit_parameters = get_param_startswith(
        model, model_arch.vision_tower, model_arch.aligner
    )
    aligner_parameters = get_param_startswith(model, model_arch.aligner)
    llm_parameters = get_param_startswith(model, model_arch.language_model)
    optimizer_grouped_parameters = []
    for lr, parameters in zip(
        [args.vit_lr, args.aligner_lr, args.learning_rate],
        [vit_parameters, aligner_parameters, llm_parameters],
    ):
        if lr is None:
            lr = args.learning_rate
        # Partition by name rather than by comparing the weight-decay value:
        # with args.weight_decay == 0 a value-based test would select the
        # no-decay parameters twice and drop the decay parameters entirely.
        no_decay_params = [p for n, p in parameters if n not in decay_parameters]
        decay_params = [p for n, p in parameters if n in decay_parameters]
        for wd, params in (
            (0.0, no_decay_params),
            (args.weight_decay, decay_params),
        ):
            if not params:
                continue
            optimizer_grouped_parameters.append(
                {
                    "params": params,
                    # Each group carries its own weight decay, so the
                    # no-decay group really is exempt.
                    "weight_decay": wd,
                    "lr": lr,
                }
            )
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args, model)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None


# Registry mapping an optimizer name to the factory that builds it.
# Add your own optimizers here, use --optimizer xxx to train.
# Every factory registered below takes (args, model, dataset) and returns an
# (optimizer, lr_scheduler) tuple; lr_scheduler is None for all of them except
# "galore", which returns the scheduler created with its optimizer.
optimizers_map = {
    "galore": create_galore_optimizer,
    "lorap": create_lorap_optimizer,
    "muon": create_muon_optimizer,
    "multimodal": create_multimodal_optimizer,
}
