import multiprocessing
import os
import pickle
import time
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, TypeVar

import mlflow

T = TypeVar("T")

from pickle import Unpickler

import torch
from joblib import delayed, Parallel
from tqdm import tqdm


def parallel_call_wrapper(
    func: Callable[[Any], T],
    func_kwargs_list: List[Dict[str, Any]],
    backend="multiprocessing",
    n_parallel=None,
) -> List[T]:
    """Call ``func`` once per kwargs dict, optionally through joblib.

    Args:
        func: Callable invoked as ``func(**kwargs)`` for every entry.
        func_kwargs_list: One dict of keyword arguments per call.
        backend: Passed to ``joblib.Parallel`` as ``backend``.
        n_parallel: Number of jobs. When ``None``, it is set to the value
            returned by ``_get_affinity()`` minus one, with a floor of 1.

    Returns:
        The results of the calls, one per entry of ``func_kwargs_list``.
    """
    if n_parallel is None:
        # Leave one of the accessible CPUs free, but always keep at least one job.
        n_parallel = max(1, _get_affinity() - 1)

    if n_parallel == 1:
        # Plain in-process loop: no worker processes, which makes debugging easier.
        outputs = []
        for call_kwargs in tqdm(func_kwargs_list):
            outputs.append(func(**call_kwargs))  # type: ignore[call-arg]
        return outputs

    runner = Parallel(n_jobs=n_parallel, backend=backend)
    jobs = (
        delayed(func)(**call_kwargs) for call_kwargs in tqdm(func_kwargs_list)
    )
    return runner(jobs)  # type: ignore[call-arg]


def _get_affinity() -> int:
    """Return how many CPUs are available.

    ``os.sched_getaffinity`` is not available on all platforms; when the
    attribute is missing, ``multiprocessing.cpu_count`` is used instead.
    """
    try:
        usable_cpus = os.sched_getaffinity(0)  # type: ignore[attr-defined]
    except AttributeError:
        return multiprocessing.cpu_count()
    return len(usable_cpus)


def copy_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return a newly allocated tensor holding a copy of ``x``'s values."""
    duplicate = torch.empty_like(x)
    # ``copy_`` fills ``duplicate`` in place from ``x``.
    duplicate.copy_(x)
    return duplicate


def filter_kwargs(func: Callable, **kwargs: Any) -> Dict[str, Any]:
    r"""Given a function, select only the arguments that are applicable.

    A keyword argument is kept when its name matches one of the parameter
    names reported by ``inspect.signature(func)``. Names that ``func`` would
    only accept through a ``**kwargs`` catch-all are therefore dropped.

    Args:
        func: Callable whose signature decides which arguments are kept.
        **kwargs: Candidate keyword arguments.

    Returns:
        Dict[str, Any]: A dictionary containing only the arguments that are applicable.
    """
    if not kwargs:
        # Nothing to filter; skip signature inspection entirely.
        return {}
    # Inspect the signature once instead of once per candidate argument.
    accepted = signature(func).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}


def save_w_pickle(obj: Any, path: str, filename: Optional[str] = None) -> None:
    """Save object obj in file path/filename.pkl

    Args:
        obj: Object to serialize with ``pickle``.
        path: Target directory, or the full file path when ``filename`` is
            ``None`` (it is then split into directory and file name, the
            same way ``load_w_pickle`` splits it).
        filename: File name inside ``path``. A ``.pkl`` suffix is appended
            when missing.
    """
    if filename is None:
        # Mirror load_w_pickle: the last path component is the file name.
        filename = os.path.basename(path)
        path = os.path.dirname(path)
    if not filename.endswith(".pkl"):
        filename += ".pkl"
    if path:
        # An empty directory part means "current directory"; makedirs("") would raise.
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, filename), "wb") as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_w_pickle(path: str, filename: Optional[str] = None) -> Any:
    """Load object from file path/filename.pkl

    Args:
        path: Directory containing the file, or the full file path when
            ``filename`` is ``None`` (it is then split into directory and
            file name).
        filename: File name inside ``path``. A ``.pkl`` suffix is appended
            when missing.

    Returns:
        The unpickled object.

    Raises:
        EOFError, UnicodeDecodeError, pickle.UnpicklingError: Re-raised with
            the offending file path in the message and the original error
            chained as ``__cause__``. The original exception type is kept so
            that callers can catch these errors specifically (for example to
            retry).
    """
    if filename is None:
        filename = os.path.basename(path)
        path = os.path.dirname(path)
    if not filename.endswith(".pkl"):
        filename += ".pkl"
    p = os.path.join(path, filename)
    # SECURITY NOTE: unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(p, "rb") as fd:
        try:
            # The file size is the progress bar total; the reader wrapper
            # advances the bar as the unpickler consumes bytes.
            with TQDMBytesReader(fd, total=os.path.getsize(p)) as pbfd:
                return Unpickler(pbfd).load()
        except EOFError as e:
            raise EOFError(f"EOFError with {p}") from e
        except UnicodeDecodeError as e:
            # UnicodeDecodeError needs its five structured arguments, so the
            # path is appended to the reason field.
            raise UnicodeDecodeError(
                e.encoding,
                e.object,
                e.start,
                e.end,
                f"{e.reason} (UnicodeDecodeError with {p})",
            ) from e
        except pickle.UnpicklingError as e:
            raise pickle.UnpicklingError(f"UnpicklingError with {p}") from e


def safe_load_W_pickle(
    path: str, filename: Optional[str] = None, n_trials=3, time_sleep=2
) -> Any:
    """Make several attempts to load an object from file path/filename.pkl

    Args:
        path: Forwarded to ``load_w_pickle``.
        filename: Forwarded to ``load_w_pickle``.
        n_trials: Number of failed attempts after which the error is re-raised.
        time_sleep: Seconds to sleep between two attempts.

    Returns:
        The object returned by ``load_w_pickle``.

    Raises:
        pickle.UnpicklingError, EOFError: When raised by ``load_w_pickle`` on
            the ``n_trials``-th failed attempt.
        UnicodeDecodeError: Re-raised immediately (no retry) after printing
            the file path.
    """
    failed_attempts = 0
    while True:
        try:
            return load_w_pickle(path=path, filename=filename)
        except (pickle.UnpicklingError, EOFError):
            failed_attempts += 1
            if failed_attempts >= n_trials:
                raise
            # Wait before the next attempt.
            time.sleep(time_sleep)
        except UnicodeDecodeError:
            # Not retried: print which file was being read, then propagate.
            if filename is None:
                filename = os.path.basename(path)
                path = os.path.dirname(path)
            print(os.path.join(path, filename))
            raise


class TQDMBytesReader:
    """File-object wrapper that advances a tqdm progress bar on every read.

    ``read`` and ``readline`` delegate to the wrapped file object and report
    the number of bytes obtained to the progress bar. Entering and exiting
    the context is forwarded to the progress bar.
    """

    def __init__(self, fd, **kwargs):
        """Wrap ``fd``; ``kwargs`` are passed to the ``tqdm`` constructor."""
        self.fd = fd
        from tqdm import tqdm as make_progress_bar

        self.tqdm = make_progress_bar(**kwargs)

    def _advance(self, chunk):
        """Report ``len(chunk)`` to the progress bar and hand ``chunk`` back."""
        self.tqdm.update(len(chunk))
        return chunk

    def read(self, size=-1):
        """Read from the wrapped file object, updating the progress bar."""
        return self._advance(self.fd.read(size))

    def readline(self):
        """Read one line from the wrapped file object, updating the progress bar."""
        return self._advance(self.fd.readline())

    def __enter__(self):
        self.tqdm.__enter__()
        return self

    def __exit__(self, *args, **kwargs):
        # Propagate the progress bar's own __exit__ result.
        return self.tqdm.__exit__(*args, **kwargs)


def time_formatter(t: float, show_ms: bool = False) -> str:
    """Convert a duration in seconds to a str `[d:]HH:MM:SS[.mmm]`

    The day count is only shown when it is at least one, and it is not capped
    (400 days is rendered as ``400:00:00:00``). A negative duration is
    formatted from its absolute value and prefixed with ``-``.

    Args:
        t (float): time in seconds
        show_ms (bool, optional): Whether to show milliseconds (``.mmm``) on
            top of d:HH:MM:SS. Defaults to False.

    Returns:
        str: The formatted duration.
    """
    sign = "-" if t < 0 else ""
    t = abs(t)
    whole_seconds = int(t)

    ms_suffix = ""
    if show_ms:
        frac = f"{t - whole_seconds:.3f}"  # "0.xxx", or "1.000" after rounding up
        if frac.startswith("1"):
            # The fraction rounded up to a full second: carry it over.
            whole_seconds += 1
            frac = "0.000"
        ms_suffix = frac[1:]  # drop the leading "0" -> ".xxx"

    n_day, remainder = divmod(whole_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)

    ts = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    if n_day > 0:
        ts = f"{n_day}:{ts}"
    return f"{sign}{ts}{ms_suffix}"


def success_count():
    """Unimplemented placeholder; takes no arguments and returns ``None``."""
    return None


def suceess_rate():
    """Unimplemented placeholder; takes no arguments and returns ``None``.

    NOTE(review): the name looks like a misspelling of ``success_rate``; it is
    kept as-is because callers may reference it.
    """
    return None


def average_size():
    """Unimplemented placeholder; takes no arguments and returns ``None``."""
    return None


def average_size_reduction():
    """Unimplemented placeholder; takes no arguments and returns ``None``."""
    return None


def token_match_ration():
    """Unimplemented placeholder; takes no arguments and returns ``None``.

    NOTE(review): the name looks like a misspelling of ``token_match_ratio``;
    it is kept as-is because callers may reference it.
    """
    return None
