import gc
import threading
import warnings
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch

# Silence the Hugging Face warning emitted when Wav2Vec2 weights are only
# partially initialized from a checkpoint.
warnings.filterwarnings("ignore", message="Some weights of Wav2Vec2Model")


def get_gpu_count(max_gpus=None):
    """
    Get the number of available GPUs.

    :param max_gpus: maximum number of GPUs to utilize.
    """
    ngpu = torch.cuda.device_count()
    if max_gpus is not None:
        ngpu = min(ngpu, max_gpus)
    return ngpu


def clear_gpu_memory():
    """
    Clear cached GPU memory on every visible device, then run the garbage
    collector.
    """
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            with torch.cuda.device(i):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()


def get_gpu_memory_info(verbose=False):
    """
    Print per-GPU memory usage.

    :param verbose: if True, print the info; otherwise return immediately.
    """
    if not verbose:
        return
    for i in range(torch.cuda.device_count()):
        try:
            free_b, total_b = torch.cuda.mem_get_info(i)
            free_gb = free_b / 1024**3
            total_gb = total_b / 1024**3
        except Exception:
            # Fallback for PyTorch builds without mem_get_info.
            total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
            free_gb = total_gb - (torch.cuda.memory_reserved(i) / 1024**3)
        mem_allocated = torch.cuda.memory_allocated(i) / 1024**3
        print(
            f"GPU {i}: {mem_allocated:.2f}GB allocated, "
            f"{free_gb:.2f}GB free / {total_gb:.2f}GB total"
        )


class GPUWorkDistributor:
    """
    Distribute work across multiple GPUs, routing each call to the least
    loaded device (capped at two concurrent GPU slots).
    """

    def __init__(self, max_gpus=None):
        ngpu = get_gpu_count(max_gpus)
        nslots = max(1, min(ngpu, 2))
        self.gpu_locks = [threading.Lock() for _ in range(nslots)]
        self.gpu_load = [0 for _ in range(nslots)]
        self.ngpu = ngpu

    def get_least_loaded_gpu(self):
        return int(np.argmin(self.gpu_load))

    def execute_on_gpu(self, func, *args, **kwargs):
        # CPU-only fallback: drop any 'device' kwarg and run directly.
        if self.ngpu == 0:
            kwargs.pop("device", None)
            return func(*args, **kwargs)
        gid = self.get_least_loaded_gpu()
        # Hold the per-GPU lock for the whole call so at most one job runs
        # on a given device at a time.
        with self.gpu_locks[gid]:
            self.gpu_load[gid] += 1
            try:
                with torch.cuda.device(gid):
                    kwargs["device"] = f"cuda:{gid}"
                    result = func(*args, **kwargs)
                    torch.cuda.empty_cache()
                    return result
            finally:
                self.gpu_load[gid] -= 1


@dataclass
class Mixture:
    mixture_id: str
    refs: list[Path]
    systems: dict[str, list[Path]]
    speaker_ids: list[str]


def canonicalize_mixtures(mixtures, systems=None):
    """
    Normalize either supported 'mixtures' format into a list of Mixture
    records: dicts keyed by 'mixture_id'/'references'/'systems', or the
    legacy list-of-lists of {'id', 'ref'} dicts with a separate 'systems'
    mapping.
    """
    canon = []
    if not mixtures:
        return canon

    def as_paths(seq):
        return [p if isinstance(p, Path) else Path(str(p)) for p in seq]

    def speaker_id_from_ref(ref_path, idx, mixture_id):
        stem = (ref_path.stem or "").strip()
        if not stem:
            stem = f"spk{idx:02d}"
        return f"{mixture_id}__{stem}"

    if isinstance(mixtures[0], dict):
        for m in mixtures:
            mid = str(m.get("mixture_id") or m.get("id") or "").strip()
            if not mid:
                raise ValueError("Each mixture must include 'mixture_id'.")
            refs = as_paths(m.get("references", []))
            if not refs:
                raise ValueError(f"Mixture {mid}: 'references' must be non-empty.")
            sysmap = {}
            if isinstance(m.get("systems"), dict):
                for algo, outs in m["systems"].items():
                    sysmap[str(algo)] = as_paths(outs)
            spk_ids = [speaker_id_from_ref(r, i, mid) for i, r in enumerate(refs)]
            canon.append(Mixture(mid, refs, sysmap, spk_ids))
        return canon

    if isinstance(mixtures[0], list):
        for i, group in enumerate(mixtures):
            mid = f"mix_{i:03d}"
            refs, spk_ids = [], []
            for d in group:
                if not isinstance(d, dict) or "ref" not in d or "id" not in d:
                    raise ValueError(
                        "Legacy mixtures expect dicts with 'id' and 'ref'."
                    )
                rp = Path(d["ref"])
                refs.append(rp)
                spk_ids.append(f"{mid}__{str(d['id']).strip()}")
            sysmap = {}
            if systems:
                for algo, per_mix in systems.items():
                    if mid in per_mix:
                        sysmap[algo] = as_paths(per_mix[mid])
            canon.append(Mixture(mid, refs, sysmap, spk_ids))
        return canon

    raise ValueError("Unsupported 'mixtures' format.")
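# Illustrative usage sketch (kept as comments so nothing runs on import).
# 'refs/spk0.wav' etc. are hypothetical paths and 'extract_embeddings' is a
# hypothetical callable that accepts a 'device' keyword argument.
#
#     distributor = GPUWorkDistributor(max_gpus=2)
#     emb = distributor.execute_on_gpu(extract_embeddings, "refs/spk0.wav")
#
#     mixtures = [{
#         "mixture_id": "mix_000",
#         "references": ["refs/spk0.wav", "refs/spk1.wav"],
#         "systems": {"baseline": ["out/s0.wav", "out/s1.wav"]},
#     }]
#     canonicalize_mixtures(mixtures)[0].speaker_ids
#     # -> ['mix_000__spk0', 'mix_000__spk1']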
def safe_cov_torch(X):
    """
    Compute the covariance matrix of X, regularized if it is rank-deficient.

    :param X: (n_samples, n_features) tensor.
    :return: regularized (n_features, n_features) covariance matrix.
    """
    Xc = X - X.mean(dim=0, keepdim=True)
    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    # Add a small ridge term when the sample covariance is singular so the
    # matrix stays invertible.
    if torch.linalg.matrix_rank(cov) < cov.shape[0]:
        cov += torch.eye(cov.shape[0], device=cov.device) * 1e-6
    return cov


def mahalanobis_torch(x, mu, inv):
    """
    Compute the Mahalanobis distance between x and mu under the inverse
    covariance matrix inv.

    :param x: point to compute the distance for.
    :param mu: mean that x is measured against.
    :param inv: inverse covariance matrix.
    :return: Mahalanobis distance.
    """
    diff = x - mu
    diff_T = diff.transpose(-1, -2) if diff.ndim >= 2 else diff
    # The 1e-6 term guards against tiny negative values from numerical error.
    return torch.sqrt(diff @ inv @ diff_T + 1e-6)
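

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch on synthetic data, not part of
    # the module's API): Mahalanobis distance of one observation from the
    # mean of a random (n_samples, n_features) batch, using the regularized
    # covariance above.
    torch.manual_seed(0)
    X = torch.randn(200, 8)
    mu = X.mean(dim=0)
    inv_cov = torch.linalg.inv(safe_cov_torch(X))
    dist = mahalanobis_torch(X[0], mu, inv_cov)
    print(f"Mahalanobis distance of X[0] from the batch mean: {dist.item():.4f}")
    get_gpu_memory_info(verbose=torch.cuda.is_available())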