import os, re, collections, itertools, inspect
from pathlib import Path
from typing import Optional, Union
from collections.abc import Sequence
import torch


def pad_seq(sequences: list[torch.Tensor], pad):
    """Stack variable-length tensors into one batch-first padded tensor.

    Args:
        sequences: Tensors to pad; forwarded unchanged to
            ``torch.nn.utils.rnn.pad_sequence``.
        pad: Fill value used for the padded positions.

    Returns:
        The result of ``pad_sequence`` called with ``batch_first=True``.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    return pad_sequence(sequences, padding_value=pad, batch_first=True)


def using_multiple_devices(devices: Union[int, list, str], num_nodes: int):
    """Tell whether a run is spread over more than one device.

    Args:
        devices: Either a device count (``int``), a sequence with one entry
            per device, or a comma-separated string with one entry per device.
        num_nodes: Number of nodes; more than one node always counts as
            multi-device.

    Returns:
        ``True`` when more than one node or more than one device is requested,
        ``False`` otherwise (including for unsupported ``devices`` types).
    """
    if num_nodes > 1:
        return True
    if isinstance(devices, int):
        return devices > 1
    # ``str`` must be tested before ``Sequence``: a string is itself a
    # Sequence, so the generic branch would count characters instead of
    # comma-separated entries.
    # NOTE(review): a comma-free string is treated as a single entry — confirm
    # this matches how the launcher interprets strings such as "2" or "auto".
    if isinstance(devices, str):
        return len(devices.split(",")) > 1
    if isinstance(devices, Sequence):
        return len(devices) > 1
    return False


def reduce_checkpoints(
    checkpoint_directory: str,
    task_name: str,
    expected_num_checkpoints: Optional[int] = None,
    rank_by_metrics_avg: bool = True,
    exclude_keys: Sequence[str] = ("seed",),
    show_max=True,
    show_median=False,
    show_mean=False,
    show_std=False,
):
    """Keep only the best checkpoint of a task and stamp score statistics in its name.

    Name of checkpoint could be as the following format:
    "{task name}_seed={random seed}_{metric name}={metric score}_{more metric name}={more metric score}.ckpt"
    Note metric name should not have '_' and metric score should have '.' inside
    for the regex expression to work.

    Every file in ``checkpoint_directory`` whose name contains ``task_name`` is
    treated as a checkpoint of the task. The checkpoints are ranked by their
    scores (higher is better), the best one is renamed so that each metric
    value is replaced by the selected statistics over all checkpoints, and all
    other checkpoints are deleted.

    Args:
        checkpoint_directory: Where to collect and reduce the checkpoints.
        task_name: Collect and reduce checkpoints associated with the task.
        expected_num_checkpoints: If set and the number of detected checkpoints
            differs, skip the reduction. Useful when several training
            processes each contribute one of N runs.
        rank_by_metrics_avg: With multiple metrics, rank by their average
            (``True``) or by the first metric only (``False``).
        exclude_keys: All ``<key>=<value>`` substrings are extracted as
            metrics; keys listed here are not treated as metrics.
        show_max: Include ``max<value>`` in the renamed checkpoint.
        show_median: Include ``median<value>`` in the renamed checkpoint.
        show_mean: Include ``mean<value>`` in the renamed checkpoint.
        show_std: Include ``std<value>`` in the renamed checkpoint.

    Returns:
        ``None``. The function returns early, touching no file, when the
        checkpoint count is unexpected, when no checkpoint is found, when a
        checkpoint name carries no metric, or when a score is not finite.
    """
    # Collect checkpoints and check for number of them
    ckpts: list[str] = [f for f in os.listdir(checkpoint_directory) if task_name in f]
    if expected_num_checkpoints and len(ckpts) != expected_num_checkpoints:
        return
    if not ckpts:
        # Nothing to rank; without this guard the tensor code below fails.
        return

    # Collect metric names and scores
    metric_pattern = re.compile(r"_([^_]+?)=([\.\d]+\d|nan|-?inf)")
    score_matrix = []
    for ckpt in ckpts:
        matches = [
            m for m in metric_pattern.findall(ckpt) if m[0] not in exclude_keys
        ]  # list[tuple[str,str]]
        if not matches:
            # A file without parsable metrics cannot be ranked. Abort before
            # any rename/delete rather than crash or guess.
            print(f"No metric found in checkpoint name '{ckpt}'.")
            return
        metric_names, scores = zip(*matches)  # tuple[str], tuple[str]
        score_matrix.append([float(s) for s in scores])
    score_matrix = torch.tensor(score_matrix)  # (#checkpoints, #metrics)
    if not score_matrix.isfinite().all():
        print("Some scores aren't finite.")
        return
    if (score_matrix > 1).sum() == 0:
        score_matrix *= 100  # converted to percent format

    # Ranking checkpoints (best first)
    sorting_key = (
        score_matrix.mean(dim=1) if rank_by_metrics_avg else score_matrix[:, 0]
    )
    ranking = sorting_key.argsort(descending=True).tolist()
    sorted_ckpts = [ckpts[i] for i in ranking]

    # Rename the best checkpoint with score statistics
    ckpt_dir = Path(checkpoint_directory)
    new_name = best_ckpt = sorted_ckpts[0]
    for i, metric_name in enumerate(metric_names):
        scores = score_matrix[:, i]
        statistics_strings = [
            f"max{scores.max():.2f}" if show_max else None,
            f"median{scores.median():.2f}" if show_median else None,
            f"mean{scores.mean():.2f}" if show_mean else None,
            f"std{scores.std():.2f}" if show_std else None,
        ]
        statistics_string = "-".join([s for s in statistics_strings if s])
        metric_string = f"_{metric_name}={statistics_string}"
        # Raw f-string keeps ``\.`` and ``\d`` as regex escapes; the callable
        # replacement inserts ``metric_string`` verbatim (no backslash
        # processing of the metric name).
        new_name = re.sub(
            rf"_{re.escape(metric_name)}=([\.\d]+\d)",
            lambda _match, replacement=metric_string: replacement,
            new_name,
        )
    os.rename(ckpt_dir / best_ckpt, ckpt_dir / new_name)

    # Remove other checkpoints
    for ckpt in sorted_ckpts[1:]:
        os.remove(ckpt_dir / ckpt)


def select_kwargs_for(func, **input_candidates):
    """Filter keyword candidates down to those named in ``func``'s signature.

    Args:
        func: Callable whose signature is inspected.
        **input_candidates: Candidate keyword arguments.

    Returns:
        A new dict holding only the candidates whose name is a parameter name
        of ``func``, in the order they were passed.
    """
    accepted = inspect.signature(func).parameters
    return {
        key: candidate
        for key, candidate in input_candidates.items()
        if key in accepted
    }


def concat_all_outputs(outputs, is_distributed=True, batch_dim=0):
    """Merge a list of per-step/per-device outputs into one combined output.

    Args:
        outputs: List with one entry per collected output. The type of the
            first entry decides how the list is merged.
        is_distributed: When ``False`` the outputs are returned untouched.
        batch_dim: Dimension along which tensors are concatenated.

    Returns:
        * tuple entries: a tuple in which each position is merged recursively
          (xxx_step returned multiple objects);
        * tensor entries: a single tensor concatenated along ``batch_dim``;
        * other non-string sequence entries: one flat list;
        * anything else (including strings and an empty list): ``outputs``
          unchanged.
    """
    if not is_distributed:
        return outputs

    assert isinstance(outputs, list)
    if not outputs:
        # Nothing to inspect or merge.
        return outputs

    first = outputs[0]
    if isinstance(first, tuple):  # xxx_step return multiple objects
        # ``zip`` yields tuples, so convert each group back to a list for the
        # recursive call, and keep concatenating along the same dimension.
        return tuple(
            concat_all_outputs(list(per_position), True, batch_dim)
            for per_position in zip(*outputs)
        )
    if isinstance(first, torch.Tensor):
        return torch.cat(outputs, dim=batch_dim)
    if isinstance(first, (str, bytes)):
        # Strings are Sequences too, but chaining them would split them into
        # single characters; treat them as atomic values.
        return outputs
    if isinstance(first, collections.abc.Sequence):
        return list(itertools.chain.from_iterable(outputs))
    return outputs


def recursive_getattr(object, name, *_default):
    """Resolve a dotted attribute path, like ``getattr`` applied step by step.

    Args:
        object: Object the lookup starts from.
        name: Attribute path whose segments are separated by ``"."``.
        *_default: At most one optional fallback value.

    Returns:
        The attribute at the end of the path, or the fallback when one was
        given and an ``AttributeError`` occurred during the lookup.

    Raises:
        AttributeError: A lookup step failed and no fallback was given.
        AssertionError: More than one fallback value was passed.
    """
    if _default:
        assert len(_default) == 1

    try:
        current = object
        for segment in name.split("."):
            current = getattr(current, segment)
    except AttributeError:
        if not _default:
            raise
        return _default[0]
    else:
        return current
