import ctypes
import os
import shutil
import subprocess
from typing import Optional

import cloudpickle
import lightning.pytorch as pl
import pandas as pd
import torch
from hydra.utils import get_class
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig

# Windows FILE_ATTRIBUTE_REPARSE_POINT bit, as returned in the attribute mask from
# kernel32.GetFileAttributesW. is_junction() tests this bit to detect directory junctions.
FILE_ATTRIBUTE_REPARSE_POINT = 0x0400


@rank_zero_only
def rank_zero_print(*values, **print_kwargs):
    """
    Forward the arguments to the built-in ``print``, but only on the rank-zero process.

    The ``rank_zero_only`` decorator turns the call into a no-op on every other rank,
    so this can be called unconditionally from distributed code without duplicated output.

    Args:
        *values: Objects to print, passed straight through to ``print``.
        **print_kwargs: Keyword options for ``print`` (``sep``, ``end``, ``file``, ``flush``).
    """
    print(*values, **print_kwargs)


def rm(path: str) -> None:
    """
    Remove a file, directory, or symbolic link/junction at the specified path.

    This function handles:
      - Symbolic links, including dangling ones whose target no longer exists.
        The link itself is unlinked and its target is left untouched.
      - Directory junctions (Windows). They are unlinked, never recursed into.
      - Real directories, which are removed recursively with ``shutil.rmtree``.
      - Regular files, which are removed with ``os.remove``.

    Does nothing if nothing exists at ``path``.

    Args:
        path (str): Path to the file, directory, or link to remove.

    Returns:
        None
    """
    # lexists (not exists): exists() follows links and reports False for a dangling
    # symlink, which would leave a stale link behind.
    if not os.path.lexists(path):
        return

    if os.path.islink(path):
        os.unlink(path)
    elif os.path.isdir(path):
        # Junctions are not reported by islink(), so check for them explicitly and
        # remove only the link itself rather than handing it to rmtree.
        if is_junction(path):
            os.unlink(path)
        else:
            shutil.rmtree(path)
    else:
        os.remove(path)


def is_junction(path: str) -> bool:
    """
    Check if a given directory path on Windows is a directory junction.

    Returns False when not running on Windows, when ``path`` is not an existing
    directory, or when its attributes cannot be queried. On Windows, any directory
    whose attributes carry ``FILE_ATTRIBUTE_REPARSE_POINT`` is reported. This
    presumably includes directory symlinks as well as junctions.

    Args:
        path (str): Path to check.

    Returns:
        bool: True if the path is a directory reparse point (junction), False otherwise.
    """
    # ctypes.windll only exists on Windows. Junctions are a Windows-only concept,
    # so there is nothing to detect elsewhere.
    if os.name != 'nt':
        return False
    if not os.path.isdir(path):
        return False
    attrs = ctypes.windll.kernel32.GetFileAttributesW(path)
    # INVALID_FILE_ATTRIBUTES (0xFFFFFFFF) surfaces as -1 with ctypes' default
    # c_int return type. The path vanished or could not be queried.
    if attrs == -1:
        return False
    return bool(attrs & FILE_ATTRIBUTE_REPARSE_POINT)


def create_latest_symlink(log_dir: str) -> None:
    """
    Create or replace a link named 'latest', placed next to ``log_dir``, that points to it.

    A symbolic link is tried first. If that fails on Windows (often because symlink
    creation needs admin privileges), a directory junction is created instead via
    ``cmd /c mklink /J``. On other platforms there is no fallback, so the ``OSError``
    from ``os.symlink`` is re-raised.

    Args:
        log_dir (str): The directory the link should point to. May be relative to the
            current working directory and may end with a path separator.

    Returns:
        None

    Raises:
        OSError: If the symlink cannot be created on a non-Windows platform.
        subprocess.CalledProcessError: If the Windows ``mklink /J`` fallback fails.
    """
    # Use an absolute, normalized path. A relative symlink target would be resolved
    # against the link's own directory (yielding a broken link). A trailing separator
    # would make dirname() return log_dir itself, placing 'latest' inside log_dir.
    log_dir = os.path.abspath(log_dir)
    latest_link = os.path.join(os.path.dirname(log_dir), 'latest')
    rm(latest_link)  # Remove existing link/junction if any
    try:
        # target_is_directory only matters on Windows; it is ignored elsewhere.
        os.symlink(log_dir, latest_link, target_is_directory=True)
    except OSError:
        if os.name != 'nt':
            raise
        # Junctions do not require symlink privileges on Windows.
        subprocess.check_call([
            'cmd', '/c',
            'mklink', '/J',
            latest_link,
            log_dir,
        ])


def setup_logger(model, cfg: Optional[DictConfig] = None, log_dir=None, name=None) -> TensorBoardLogger:
    """
    Create a TensorBoardLogger for ``model``.

    Args:
        model: Model whose ``name`` attribute is used as the logger name when ``name`` is None.
        cfg (DictConfig, optional): Config providing ``root_dir`` and ``dataset.name``.
            Only read when ``log_dir`` is None.
        log_dir (str, optional): ``save_dir`` for the logger. Defaults to
            ``os.path.join(cfg.root_dir, cfg.dataset.name)``.
        name (str, optional): Logger name. Defaults to ``model.name``.

    Returns:
        TensorBoardLogger: Logger with ``save_dir=log_dir`` and ``name=name``.

    Raises:
        ValueError: If both ``log_dir`` and ``cfg`` are None, so no directory can be derived.
    """
    if log_dir is None:
        if cfg is None:
            raise ValueError("setup_logger requires either `log_dir` or `cfg` to be provided.")
        log_dir = os.path.join(cfg.root_dir, cfg.dataset.name)
    if name is None:
        name = model.name
    return TensorBoardLogger(save_dir=log_dir, name=name)


def maybe_save(df: pd.DataFrame, fname: str, output_directory: Optional[str] = None) -> None:
    """
    Write ``df`` to ``<output_directory>/<fname>`` as CSV, or do nothing when no directory is given.

    The DataFrame index is not written.

    Args:
        df (pd.DataFrame): Data to write.
        fname (str): File name, joined onto ``output_directory``.
        output_directory (str, optional): Destination directory. When None, the call is a no-op.
    """
    if output_directory is None:
        return  # Saving is optional; nothing requested.
    target_path = os.path.join(output_directory, fname)
    df.to_csv(target_path, index=False)


def maybe_load(path) -> Optional[pd.DataFrame]:
    """
    Load a CSV file into a DataFrame, returning None if it cannot be read.

    Args:
        path: Anything ``pd.read_csv`` accepts (file path, path-like, or buffer).

    Returns:
        pd.DataFrame or None: The loaded data, or None if reading/parsing raised.
    """
    try:
        return pd.read_csv(path)
    except Exception:
        # Best-effort load: any read or parse failure (missing file, empty file,
        # malformed CSV, invalid path type, ...) yields None. Unlike a bare `except:`,
        # this lets KeyboardInterrupt and SystemExit propagate.
        return None


def setup_model_from_checkpoint(ckpt, map_location=None):
    """
    Rebuild a model from a checkpoint whose hyperparameters record the model class.

    Args:
        ckpt: Path to the checkpoint file.
        map_location: Forwarded to the model class's ``load_from_checkpoint`` to remap
            tensor storages (e.g. ``"cpu"`` for a GPU-trained checkpoint on a CPU-only
            host). The default None keeps ``load_from_checkpoint``'s own default.

    Returns:
        tuple: ``(model, cfg)``. ``model`` is the loaded model. ``cfg`` is the checkpoint's
        ``'hyper_parameters'`` entry, whose ``'_target_'`` key is resolved (via Hydra's
        ``get_class``) to the model class.
    """
    # Only the hyperparameters are needed here, so deserialize onto the CPU. This works
    # without CUDA and avoids placing a second copy of the weights on the GPU. Indexing
    # immediately drops the rest of the checkpoint dict before the model load below.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
    cfg = torch.load(ckpt, map_location='cpu', weights_only=False)['hyper_parameters']
    # Instantiate and load the model
    model_cls = get_class(cfg['_target_'])
    model = model_cls.load_from_checkpoint(ckpt, map_location=map_location)
    return model, cfg


def save_model_no_lightning(model, save_path):
    """
    Save ``model`` to ``save_path`` as a checkpoint dict without going through a Lightning Trainer.

    The saved dict has ``'state_dict'``, ``'hyper_parameters'`` (``model.hparams``) and
    ``'pytorch-lightning_version'`` entries. If the model defines ``on_save_checkpoint``,
    it is called with the dict first so it can add or modify entries. The dict is
    serialized with ``torch.save`` using ``cloudpickle`` (pickle protocol 4), presumably
    so that hyperparameters holding objects plain pickle rejects can still be saved.

    This is best-effort: any exception is reported on stdout and not re-raised.

    Args:
        model: Model exposing ``state_dict()`` and ``hparams``.
        save_path: Destination passed to ``torch.save``.
    """
    try:
        obj = {
            "state_dict": model.state_dict(),
            "hyper_parameters": model.hparams,
            "pytorch-lightning_version": pl.__version__,
        }
        if hasattr(model, 'on_save_checkpoint'):
            model.on_save_checkpoint(obj)
        torch.save(obj, save_path,
                   pickle_module=cloudpickle,
                   pickle_protocol=4)
    except Exception as e:
        # Keep the best-effort contract, but say what failed: str(e) alone drops the
        # exception type and the destination (e.g. a KeyError prints only the key).
        print(f"Failed to save model to {save_path}: {e!r}")
