# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

import torch
import torch.distributed as dist

from swift import get_logger
from swift.utils import broadcast_string, get_dist_setting, is_dist

# Module-level logger obtained from swift's `get_logger()` helper.
logger = get_logger()


@dataclass
class AnimateDiffArguments:
    """Arguments for AnimateDiff training.

    Fields are grouped as: model / motion-adapter selection, fine-tuning
    (LoRA) settings, optimization schedule, hub upload, validation sampling,
    dataset location, and noise-scheduler configuration.

    ``__post_init__`` fills in derived defaults (validation prompts file,
    learning rate, save interval). When running distributed, it also pins the
    CUDA device, offsets the seed per rank, initializes the process group if
    needed, and broadcasts ``output_dir`` so all ranks agree on it.
    """
    motion_adapter_id_or_path: Optional[str] = None
    motion_adapter_revision: Optional[str] = None

    model_id_or_path: Optional[str] = None
    model_revision: Optional[str] = None

    dataset_sample_size: Optional[int] = None

    sft_type: str = field(
        default='lora', metadata={'choices': ['lora', 'full']})

    output_dir: str = 'output'
    ddp_backend: str = field(
        default='nccl', metadata={'choices': ['nccl', 'gloo', 'mpi', 'ccl']})

    seed: int = 42

    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout_p: float = 0.05
    lora_dtype: Literal['fp16', 'bf16', 'fp32', 'AUTO'] = 'fp32'

    gradient_checkpointing: bool = False
    batch_size: int = 1
    num_train_epochs: int = 1
    # if max_steps >= 0, override num_train_epochs
    max_steps: int = -1
    # None: resolved to 1e-4 in __post_init__
    learning_rate: Optional[float] = None
    weight_decay: float = 0.01
    gradient_accumulation_steps: int = 16
    max_grad_norm: float = 1.
    lr_scheduler_type: str = 'cosine'
    warmup_ratio: float = 0.05

    eval_steps: int = 50
    # None: falls back to eval_steps in __post_init__
    save_steps: Optional[int] = None
    dataloader_num_workers: int = 1

    push_to_hub: bool = False
    # 'user_name/repo_name' or 'repo_name'
    hub_model_id: Optional[str] = None
    hub_private_repo: bool = False
    # The default must also appear in `choices`; otherwise a parser that
    # validates choices rejects the default when it is passed explicitly.
    push_hub_strategy: str = field(
        default='push_best',
        metadata={'choices': ['push_best', 'push_last', 'all_checkpoints']})
    # None: use env var `MODELSCOPE_API_TOKEN`
    hub_token: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'SDK token can be found in https://modelscope.cn/my/myaccesstoken'
        })

    ignore_args_error: bool = False  # True: notebook compatibility

    text_dropout_rate: float = 0.1

    # None: resolved in __post_init__ to configs/animatediff/validation.txt
    # next to this module.
    validation_prompts_path: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'The validation prompts file path, use configs/animatediff/validation.txt if None'
        })

    trainable_modules: str = field(
        default='.*motion_modules.*',
        metadata={
            'help':
            'The trainable modules, by default, the .*motion_modules.* will be trained'
        })

    mixed_precision: bool = True

    enable_xformers_memory_efficient_attention: bool = True

    num_inference_steps: int = 25
    guidance_scale: float = 8.
    sample_size: int = 256
    sample_stride: int = 4
    sample_n_frames: int = 16

    csv_path: Optional[str] = None
    video_folder: Optional[str] = None

    motion_num_attention_heads: int = 8
    motion_max_seq_length: int = 32
    num_train_timesteps: int = 1000
    # These hold fractional values, so they must be annotated as float.
    # An `int` annotation would make a type-driven CLI parser reject them.
    beta_start: float = 0.00085
    beta_end: float = 0.012
    beta_schedule: str = 'linear'
    steps_offset: int = 1
    clip_sample: bool = False

    use_wandb: bool = False

    def __post_init__(self) -> None:
        # Map deprecated arguments onto their current equivalents first.
        handle_compatibility(self)

        current_dir = os.path.dirname(__file__)
        if self.validation_prompts_path is None:
            # Default prompts file, located relative to this module's directory.
            self.validation_prompts_path = os.path.join(
                current_dir, 'configs/animatediff', 'validation.txt')
        if self.learning_rate is None:
            self.learning_rate = 1e-4
        if self.save_steps is None:
            # Checkpoint at the same cadence as evaluation by default.
            self.save_steps = self.eval_steps

        if is_dist():
            rank, local_rank, _, _ = get_dist_setting()
            torch.cuda.set_device(local_rank)
            self.seed += rank  # Avoid the same dropout
            # Initialize in advance (presumably so broadcast_string below
            # has a process group to communicate over).
            if not dist.is_initialized():
                dist.init_process_group(backend=self.ddp_backend)
            # Make sure to set the same output_dir when using DDP.
            self.output_dir = broadcast_string(self.output_dir)


@dataclass
class AnimateDiffInferArguments:
    """Arguments for AnimateDiff inference.

    Covers model / motion-adapter selection, the checkpoint to load, sampling
    settings, and noise-scheduler configuration. ``__post_init__`` maps the
    deprecated ``merge_lora_and_save`` flag onto ``merge_lora``.
    """

    motion_adapter_id_or_path: Optional[str] = None
    motion_adapter_revision: Optional[str] = None

    model_id_or_path: Optional[str] = None
    model_revision: Optional[str] = None

    sft_type: str = field(
        default='lora', metadata={'choices': ['lora', 'full']})

    ckpt_dir: Optional[str] = field(
        default=None, metadata={'help': '/path/to/your/vx-xxx/checkpoint-xxx'})
    eval_human: bool = False  # False: eval val_dataset

    seed: int = 42

    # other
    ignore_args_error: bool = False  # True: notebook compatibility

    validation_prompts_path: Optional[str] = None

    output_path: str = './generated'

    enable_xformers_memory_efficient_attention: bool = True

    num_inference_steps: int = 25
    guidance_scale: float = 7.5
    sample_size: int = 256
    sample_stride: int = 4
    sample_n_frames: int = 16

    motion_num_attention_heads: int = 8
    motion_max_seq_length: int = 32
    num_train_timesteps: int = 1000
    # These hold fractional values, so they must be annotated as float.
    # An `int` annotation would make a type-driven CLI parser reject them.
    beta_start: float = 0.00085
    beta_end: float = 0.012
    beta_schedule: str = 'linear'
    steps_offset: int = 1
    clip_sample: bool = False

    merge_lora: bool = False
    replace_if_exists: bool = False

    # compatibility. (Deprecated) When not None, overrides merge_lora.
    merge_lora_and_save: Optional[bool] = None

    def __post_init__(self) -> None:
        # Map deprecated arguments onto their current equivalents.
        handle_compatibility(self)


def handle_compatibility(
        args: Union[AnimateDiffArguments, AnimateDiffInferArguments]) -> None:
    """Apply deprecated argument aliases to ``args`` in place.

    Only inference arguments carry a legacy option: if the deprecated
    ``merge_lora_and_save`` was set (i.e. is not None), its value is copied
    onto ``merge_lora``. Training arguments are left untouched.
    """
    if not isinstance(args, AnimateDiffInferArguments):
        return
    legacy_merge = args.merge_lora_and_save
    if legacy_merge is None:
        return
    args.merge_lora = legacy_merge
