#!/usr/bin/env python3

from __future__ import annotations

from typing import Literal

import os
import wandb
from wandb.sdk.wandb_run import Run


def wandb_init(
    project: str | None,
    exp_name: str | None,
    wandb_id: bool | None,
    run_id: int,
    tags: list[str] | None = None,
    notes: str | None = None,
    mode: Literal['online', 'offline', 'disabled'] | None = None,
    resume_from_checkpoint: int | None = None,
    experiment_settings: dict | None = None,
) -> Run:
    r"""
    Start (or resume) a W&B run for one repetition of an experiment.

    The run is named ``"{exp_name}_run.{run_id}"`` and grouped under
    ``exp_name``. When ``wandb_id`` is truthy, that same string is also
    passed as the explicit W&B run ID (and recorded in the config under
    ``"id"``); otherwise W&B picks the ID itself and ``"id"`` is ``None``.

    Args:
        project (str | None): W&B project name.
        exp_name (str | None): Experiment name; used as the run group and
            as the prefix of the run name.
        wandb_id (bool | None): If truthy, use the run name as the run ID.
        run_id (int): Index of this run within the experiment.
        tags (list[str] | None): Tags. Defaults to None.
        notes (str | None): Notes. Defaults to None.
        mode (Literal['online', 'offline', 'disabled'] | None): W&B
            mode. Defaults to None.
        resume_from_checkpoint (int | None): Checkpoint ID to resume
            from. When not None, the run is started with ``resume="allow"``.
            Defaults to None.
        experiment_settings (dict | None): Extra config entries, merged on
            top of the base config (they win on key clashes). Defaults to
            None.

    Returns:
        Run: The run object returned by ``wandb.init``.
    """
    name = f"{exp_name}_run.{run_id}"
    explicit_id = name if wandb_id else None
    resume_policy = None if resume_from_checkpoint is None else "allow"

    config = {
        "exp_name": exp_name,
        "id": explicit_id,
        "run_id": run_id,
        "resume_from_checkpoint": resume_from_checkpoint,
    }
    # User-provided settings are applied last so they override base keys.
    if experiment_settings:
        config.update(experiment_settings)

    return wandb.init(
        project=project,
        group=exp_name,
        name=name,
        id=explicit_id,
        config=config,
        notes=notes,
        tags=tags,
        save_code=True,
        job_type="optimization",
        mode=mode,
        resume=resume_policy,
        # Any run already active in this process is finished before the
        # new one starts.
        reinit="finish_previous",
    )

def get_wandb_artifact(
    entity: str,
    project: str,
    exp_name: str,
    run_id: int,
    checkpoint_id: int,
) -> str:
    r"""
    Download the latest version of one checkpoint artifact.

    The artifact ``{entity}/{project}/{exp_name}_run.{run_id}_ckpt_{checkpoint_id}.pt``
    is fetched at alias ``latest`` with type ``"state"`` and downloaded
    locally.

    Args:
        entity (str): W&B entity (username or team).
        project (str): W&B project name.
        exp_name (str): Experiment name.
        run_id (int): Run ID.
        checkpoint_id (int): Checkpoint ID.

    Returns:
        str: Local path of the checkpoint file. This path is built from the
        download directory plus the artifact name, so it assumes the artifact
        stores a single file named exactly like the artifact.
    """
    file_name = f"{exp_name}_run.{run_id}_ckpt_{checkpoint_id}.pt"
    qualified_name = f"{entity}/{project}/{file_name}:latest"

    local_dir = wandb.Api().artifact(qualified_name, type="state").download()
    return os.path.join(local_dir, file_name)


#### FUTURE ME: this needs to change.
#### When wandb_run_id is False, we don't know the run ID
#### (it was generated by W&B), so the lookup below fails.
#### Instead, either provide the state files manually,
#### or download them all automatically and
#### then filter them by problem_name and method.
def get_wandb_artifacts(
    project: str,
    exp_name: str,
    run_id: int,
    entity: str | None = None,
) -> list[str]:
    r"""
    Download all checkpoint ("state") artifacts logged by a run.

    The run is looked up by the path ``[entity/]project/{exp_name}_run.{run_id}``,
    so this only works when the run was created with that string as its
    W&B run ID. Artifacts whose type is not ``"state"`` are skipped. Artifacts
    whose name contains ``"ckpt_0"`` (the algorithm's initialization) are
    also skipped.

    Args:
        project (str): W&B project name.
        exp_name (str): Experiment name.
        run_id (int): Run ID.
        entity (str | None): W&B entity (username or team).
            Defaults to None.

    Returns:
        list[str]: Local paths to the downloaded checkpoint files, in the
        order the run's artifacts are listed. May be empty if the run logged
        artifacts but none of them qualify.

    Raises:
        RuntimeError: If no logged artifacts are found for the run, or if a
            qualifying artifact contains no files.
    """
    api = wandb.Api()
    run_name = f"{project}/{exp_name}_run.{run_id}"
    if entity is not None:
        run_name = f"{entity}/{run_name}"
    run = api.run(run_name)

    # Materialize the artifact listing once and reuse it for the loop.
    # Calling logged_artifacts() a second time would re-run the query.
    logged = list(run.logged_artifacts())
    if not logged:
        raise RuntimeError(
            f"No logged artifacts found for run: {run_name}"
        )

    path_list: list[str] = []
    for art in logged:
        # ckpt_0 is just initializations of the algorithm
        if art.type != "state" or "ckpt_0" in art.name:
            continue
        # Each checkpoint artifact is assumed to hold a single file. Check
        # for one before downloading, instead of letting a bare
        # StopIteration escape.
        first_file = next(iter(art.files()), None)
        if first_file is None:
            raise RuntimeError(
                f"Artifact {art.name} of run {run_name} contains no files"
            )
        artifact_dir = art.download()
        path_list.append(os.path.join(artifact_dir, first_file.name))
    return path_list