import logging
import math
import os
import pathlib
import sys
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, List, Optional, Sequence, Tuple
from omegaconf import DictConfig, OmegaConf
import dinov3.distributed as distributed
from dinov3.logging import cleanup_logging, setup_logging
from dinov3.utils import fix_random_seeds, get_conda_env, get_sha
logger = logging.getLogger("dinov3")
@dataclass
class DinoV3SetupArgs:
    config_file: str
    pretrained_weights: str | None = None
    shard_unsharded_model: bool = False
    output_dir: str = ""
    opts: List[Any] = field(default_factory=lambda: [])
    def __post_init__(self):
        if OmegaConf.is_config(self.opts):
            self.opts = OmegaConf.to_object(self.opts)
def apply_scaling_rules_to_cfg(cfg):
    assert distributed.is_enabled(), "Setup distributed to get global size !"
    if "schedules" in cfg:
        return cfg
    if cfg.optim.scaling_rule == "linear_wrt_256":
        old_lr = cfg.optim.lr
        cfg.optim.lr *= cfg.train.batch_size_per_gpu * distributed.get_world_size() / 256.0
        logger.info(f"linear scaling learning rate; old: {old_lr}, new: {cfg.optim.lr}")
    elif cfg.optim.scaling_rule == "sqrt_wrt_1024":
        old_lr = cfg.optim.lr
        cfg.optim.lr *= 4 * math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_world_size() / 1024.0)
        logger.info(f"sqrt scaling learning rate; old: {old_lr}, new: {cfg.optim.lr}")
    return cfg
def write_config(cfg, output_dir, name="config.yaml"):
    logger.info(OmegaConf.to_yaml(cfg))
    output_dir = os.path.abspath(output_dir)
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path
def get_default_config() -> DictConfig:
    p = pathlib.Path(__file__).parent / "ssl_default_config.yaml"
    return OmegaConf.load(p)
def get_cfg_from_args(args: DinoV3SetupArgs, multidistillation=False, strict=True):
    overrides = [*args.opts]
    if args.output_dir is not None:
        overrides.append(f"train.output_dir={os.path.realpath(args.output_dir)}")
    cfg = OmegaConf.load(args.config_file)
    opts_cfg = OmegaConf.from_cli(overrides)
    if multidistillation:
        cfg = OmegaConf.merge(cfg, opts_cfg)
    else:
        default_cfg = get_default_config()
        if strict:
            OmegaConf.set_struct(default_cfg, True)
        cfg = OmegaConf.merge(default_cfg, cfg, opts_cfg)
    return cfg
def setup_config(args: DinoV3SetupArgs, strict_cfg=True):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args, strict=strict_cfg)
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    if args.output_dir is not None:
        write_config(cfg, args.output_dir)
    apply_scaling_rules_to_cfg(cfg)
    return cfg
def _enumerate_all_subgroup_ranks(all_subgroup_rank_spans: Sequence[Tuple[int, int]]):
    """Expands a specification of process subgroups from spans to enumerated ranks.
    Args:
       all_group_rank_spans: a sequence of rank spans (first rank, last rank),
           one for each process group. Example: ((0, 1), (2, 3), (4, 7)).
    """
    for first, last in all_subgroup_rank_spans:
        assert first <= last
    return tuple(tuple(range(first, last + 1)) for first, last in all_subgroup_rank_spans)
def setup_multidistillation(args: DinoV3SetupArgs):
    base_output_dir = args.output_dir
    os.makedirs(args.output_dir, exist_ok=True)
    base_cfg = OmegaConf.load(args.config_file)
    assert base_cfg.multidistillation.enabled
    global_batch_size = base_cfg.multidistillation.global_batch_size
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_rank()
    all_subgroup_rank_spans = tuple(
        (student.ranks_range[0], student.ranks_range[1] - 1) for student in base_cfg.multidistillation.students
    )
    all_subgroup_ranks = _enumerate_all_subgroup_ranks(all_subgroup_rank_spans)
    distributed.new_subgroups(all_subgroup_ranks)
    found = False
    for student in base_cfg.multidistillation.students:
        if rank in range(*student.ranks_range):
            found = True
            break
    assert found, "rank of worker not in defined range"
    name = student.name
    config_path = student.config_path
    n_gpus = student.ranks_range[1] - student.ranks_range[0]
    assert global_batch_size % n_gpus == 0
    total_n_gpus = distributed.get_world_size()
    args.output_dir = os.path.join(base_output_dir, name)
    args.opts += [f"train.output_dir={args.output_dir}"]
    args.opts += [f"train.batch_size_per_gpu={global_batch_size // total_n_gpus}"]
    args.config_file = os.path.abspath(config_path)
    default_cfg = get_default_config()
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, base_cfg, OmegaConf.from_cli(args.opts))
    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    fix_random_seeds(seed + rank)
    write_config(cfg, args.output_dir)
    apply_scaling_rules_to_cfg(cfg)
    return cfg
def setup_job(
    output_dir: Optional[str] = None,
    distributed_enabled: bool = True,
    logging_enabled: bool = True,
    seed: Optional[int] = 0,
    restrict_print_to_main_process: bool = True,
    distributed_timeout: timedelta | None = None,
):
    """
    Setup methods that should be done in every fairvit job
    Initializes logging, distributed, random seeds and other utilities.
    """
    if output_dir is not None:
        output_dir = os.path.realpath(output_dir)
        os.makedirs(output_dir, exist_ok=True)
    if logging_enabled:
        setup_logging(
            output=output_dir,
            level=logging.INFO,
            log_to_stdout_only_in_main_process=restrict_print_to_main_process,
        )
    if distributed_enabled:
        distributed.enable(
            overwrite=True,
            nccl_async_error_handling=True,
            restrict_print_to_main_process=restrict_print_to_main_process,
            timeout=distributed_timeout,
        )
    if seed is not None:
        rank = distributed.get_rank()
        fix_random_seeds(seed + rank)
    logger = logging.getLogger("dinov3")
    logger.info("git:\n  {}\n".format(get_sha()))
    conda_env_name, conda_env_path = get_conda_env()
    logger.info(f"conda env name: {conda_env_name}")
    logger.info(f"conda env path: {conda_env_path}")
    logger.info(f"python path: {sys.path}")
def exit_job(distributed_enabled: bool = True, logging_enabled: bool = True):
    if distributed_enabled:
        distributed.disable()
    if logging_enabled:
        cleanup_logging()
