import queue
import threading
import gc

import torch
import torch.nn.functional as F
from transformers import (
    HubertModel,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    WavLMModel,
)

from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
from utils import get_gpu_count


class BalancedMultiGPUModel:
    """
    Distributes model inference workload across GPUs.

    One model replica and one feature extractor are loaded per GPU, and one
    daemon worker thread per GPU consumes tasks from its own queue. All
    workers report back through a single shared result queue.
    """
    def __init__(self, model_name, layer, max_gpus=None):
        """
        Load one model replica per GPU and start the worker threads.

        :param model_name: key into the mapping returned by get_model_config.
        :param layer: index of the hidden state returned for each input.
        :param max_gpus: forwarded to get_gpu_count.
        """
        self.layer = layer
        self.models = []
        self.extractors = []
        self.devices = []
        # Created before anything is loaded so that cleanup()/__del__ stay
        # safe even when model loading fails halfway through.
        self.gpu_queues = []
        self.result_queue = queue.Queue()
        self.workers = []

        ngpu = get_gpu_count(max_gpus)
        # The configuration does not depend on the GPU, so look it up once.
        ckpt, cls, _ = get_model_config(layer)[model_name]
        attn_impl = "eager" if cls is WavLMModel else "sdpa"

        for gpu_id in range(ngpu):
            device = f"cuda:{gpu_id}"
            self.devices.append(device)
            extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)

            model = cls.from_pretrained(
                ckpt,
                output_hidden_states=True,
                use_safetensors=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                attn_implementation=attn_impl
            )
            model.eval()
            model = model.to(device)

            for param in model.parameters():
                param.requires_grad = False

            self.extractors.append(extractor)
            self.models.append(model)

        self.gpu_queues = [queue.Queue() for _ in self.devices]
        for i in range(len(self.devices)):
            worker = threading.Thread(
                target=self._gpu_worker, args=(i,), daemon=True
            )
            worker.start()
            self.workers.append(worker)

    def _gpu_worker(self, gpu_id):
        """
        Worker loop for one GPU: consume tasks until a ``None`` sentinel.

        Each task is ``(signals, masks, use_mlm, task_id)``. The worker puts
        ``(task_id, tensor)`` on the result queue on success, or
        ``(task_id, exception)`` on failure.

        :param gpu_id: index into self.devices / self.models / self.extractors.
        """
        device = self.devices[gpu_id]
        model = self.models[gpu_id]
        extractor = self.extractors[gpu_id]
        tasks = self.gpu_queues[gpu_id]
        while True:
            task = tasks.get()
            if task is None:
                break
            signals, masks, use_mlm, task_id = task
            try:
                inputs = extractor(
                    signals, sampling_rate=SR, return_tensors="pt", padding=True
                )
                input_values = inputs.input_values.to(device, non_blocking=True)

                torch.cuda.empty_cache()

                orig_mode = model.training
                model.train(bool(use_mlm))
                try:
                    with torch.no_grad(), torch.amp.autocast(
                        device_type='cuda', dtype=torch.float16
                    ):
                        hs = model(
                            input_values, output_hidden_states=True
                        ).hidden_states[self.layer]
                finally:
                    # Restore the mode even when the forward pass fails, so a
                    # failed use_mlm task cannot leave the replica in train mode.
                    model.train(orig_mode)

                B, T, _ = hs.shape
                keep = []
                for b in range(B):
                    # Resample the mask to the T frames of the hidden states
                    # and keep only the frames it marks as active.
                    mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
                    mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
                    keep.append(hs[b][mask_t].cpu())

                del hs, input_values, inputs
                torch.cuda.empty_cache()

                if keep:
                    L_max = max(x.shape[0] for x in keep)
                    keep_padded = [
                        F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
                    ]
                    result = torch.stack(keep_padded, dim=0)
                else:
                    result = torch.empty(0, 0, 0)
                self.result_queue.put((task_id, result))
            except Exception as e:
                # Hand the failure to process_batch instead of killing the worker.
                self.result_queue.put((task_id, e))
            finally:
                torch.cuda.empty_cache()

    @staticmethod
    def _pad_and_concat(parts):
        """Zero-pad (B, L, D) tensors along L to a common length and concatenate along B."""
        if not parts:
            return torch.empty(0, 0, 0)
        longest = max(p.shape[1] for p in parts)
        padded = [F.pad(p, (0, 0, 0, longest - p.shape[1])) for p in parts]
        return torch.cat(padded, dim=0)

    def process_batch(self, signals, masks, use_mlm=False):
        """
        Split a batch into contiguous chunks, one per GPU, and gather the results.

        :param signals: sequence of waveforms.
        :param masks: sequence of masks aligned with signals; the workers call
            ``.float()`` on each item, so torch tensors are expected.
        :param use_mlm: when true the model runs the forward pass in train mode.
        :return: tensor of shape (len(signals), L_max, D), zero padded along the
            frame axis, in the input order; an empty (0, 0, 0) tensor for empty input.
        :raises Exception: the first exception reported by a worker.
        """
        if not signals:
            return torch.empty(0, 0, 0)
        n_devices = len(self.devices)
        batch_size = len(signals)
        split = (batch_size + n_devices - 1) // n_devices

        n_tasks = 0
        for start in range(0, batch_size, split):
            end = min(start + split, batch_size)
            self.gpu_queues[n_tasks % n_devices].put(
                (signals[start:end], masks[start:end], use_mlm, n_tasks)
            )
            n_tasks += 1

        # Always collect every submitted task before raising: results left
        # behind in the shared queue would be picked up by the next call.
        results = {}
        first_error = None
        for _ in range(n_tasks):
            tid, result = self.result_queue.get()
            if isinstance(result, Exception):
                if first_error is None:
                    first_error = result
            else:
                results[tid] = result
        if first_error is not None:
            raise first_error

        parts = [results[i] for i in range(n_tasks) if results[i].numel() > 0]
        # Each worker pads only to the longest item of its own chunk, so the
        # chunks must be brought to a common length before concatenation.
        return self._pad_and_concat(parts)

    def cleanup(self):
        """
        Stop the worker threads and drop the model references.

        Safe to call more than once and on a partially constructed instance.
        Call it explicitly: the worker threads reference the instance through
        their bound target, so it is not garbage collected while they run.
        """
        workers = getattr(self, "workers", [])
        if workers:
            for q in getattr(self, "gpu_queues", []):
                q.put(None)
            for w in workers:
                w.join(timeout=5.0)
            workers.clear()
        getattr(self, "models", []).clear()
        getattr(self, "extractors", []).clear()
        # Collect first so memory held by unreachable objects can be
        # returned when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()

    def __del__(self):
        self.cleanup()


def get_model_config(layer):
    """
    Build the lookup table of supported self-supervised models.

    :param layer: transformer layer attached to every model entry.
    :return: dict mapping a model name to a (checkpoint, model class, layer)
        tuple; the "raw" entry is (None, None, None).
    """
    checkpoints = (
        ("wavlm", "microsoft/wavlm-large", WavLMModel),
        ("wav2vec2", "facebook/wav2vec2-large-lv60", Wav2Vec2Model),
        ("hubert", "facebook/hubert-large-ll60k", HubertModel),
        ("wavlm_base", "microsoft/wavlm-base", WavLMModel),
        ("wav2vec2_base", "facebook/wav2vec2-base", Wav2Vec2Model),
        ("hubert_base", "facebook/hubert-base-ls960", HubertModel),
        ("wav2vec2_xlsr", "facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model),
    )
    # "raw" has no checkpoint or class and does not carry the layer.
    config = {"raw": (None, None, None)}
    for key, ckpt, model_cls in checkpoints:
        config[key] = (ckpt, model_cls, layer)
    return config


# Registry of the models stored by load_model, keyed by model name. Each value
# is a (wrapper, layer) tuple where wrapper is either a BalancedMultiGPUModel
# or an (extractor, model) pair.
_loaded_models = {}


def load_model(name, layer, max_gpus=None):
    """
    Load the chosen self-supervised model, releasing any previously loaded one.

    :param name: name of model ("raw"/"waveform" skips model loading).
    :param layer: chosen layer.
    :param max_gpus: maximal gpus to use.
    :return: (wrapper, layer) where wrapper is the string "raw", a
        BalancedMultiGPUModel (more than one GPU), or an (extractor, model)
        pair (single device).
    :raises KeyError: if name is not a key of get_model_config.
    """
    if _loaded_models:
        try:
            for model_data in list(_loaded_models.values()):
                is_pair = isinstance(model_data, tuple) and len(model_data) == 2
                if is_pair and isinstance(model_data[0], BalancedMultiGPUModel):
                    model_data[0].cleanup()
        finally:
            # Clearing the registry is what drops the model references;
            # do it even when a wrapper's cleanup fails.
            _loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()

    if name.lower() in {"raw", "waveform"}:
        return "raw", layer

    ngpu = get_gpu_count(max_gpus)
    if ngpu > 1:
        model = BalancedMultiGPUModel(name, layer, max_gpus)
        _loaded_models[name] = (model, layer)
        return model, layer

    ckpt, cls, layer_eff = get_model_config(layer)[name]
    extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)

    on_cuda = torch.cuda.is_available()
    device = "cuda:0" if on_cuda else "cpu"
    # Half precision is only used on CUDA. The embedding code feeds float32
    # inputs and relies on CUDA autocast to reconcile them with float16
    # weights, so the CPU fallback keeps float32 weights.
    dtype = torch.float16 if on_cuda else torch.float32
    attn_impl = "eager" if cls is WavLMModel else "sdpa"
    model = cls.from_pretrained(
        ckpt,
        output_hidden_states=True,
        use_safetensors=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl
    )
    model.eval()
    model = model.to(device)

    for param in model.parameters():
        param.requires_grad = False

    _loaded_models[name] = ((extractor, model), layer_eff)
    return (extractor, model), layer_eff


def cleanup_all_models():
    """
    Call this at the end of each experiment to ensure complete cleanup.

    Stops the workers of every registered BalancedMultiGPUModel, empties the
    registry of loaded models, then runs garbage collection and empties the
    CUDA cache.
    """
    try:
        for model_data in list(_loaded_models.values()):
            is_pair = isinstance(model_data, tuple) and len(model_data) == 2
            if is_pair and isinstance(model_data[0], BalancedMultiGPUModel):
                model_data[0].cleanup()
    finally:
        # Clearing the registry is what drops the model references; do it
        # even when a wrapper's cleanup fails.
        _loaded_models.clear()
        # Collect first so memory held by unreachable objects can be
        # returned when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()


def embed_batch_raw(signals, masks_audio):
    """
    Frame raw waveforms, for the case where self-supervised encoding is
    skipped and the waveform is pushed directly to diffusion maps.

    Each signal (without its last sample) is cut into windows of
    ENERGY_WIN_MS with a hop of ENERGY_HOP_MS. Only the frames selected by the
    mask are kept; if the mask selects nothing, the first frame is kept.

    :param signals: waveform signals.
    :param masks_audio: voice activity masks of sources, one per signal.
    :return: tensor of shape (num_signals, L_max, win), zero padded along the
        frame axis; an empty (0, 0, 0) tensor when there is nothing to embed.
    """
    pairs = list(zip(signals, masks_audio))
    if not pairs:
        # Same convention as the model-based embedders; torch.stack would
        # raise on an empty list.
        return torch.empty(0, 0, 0)

    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)
    reps = []
    for sig_np, mask_np in pairs:
        x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
        frames = x.unfold(0, win, hop)
        mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
        reps.append(frames[mask] if mask.any() else frames[:1])

    L_max = max(r.size(0) for r in reps)
    padded = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
    return torch.stack(padded, dim=0)


def embed_batch_single_gpu(
        signals, masks_audio, extractor, model, layer, use_mlm=False, max_batch=2
):
    """
    Encode signals on a single device in small sub-batches. See embed_batch.

    :param signals: waveform signals to encode.
    :param masks_audio: masks aligned with signals; ``.float()`` is called on
        each item, so torch tensors are expected.
    :param extractor: feature extractor applied to each sub-batch.
    :param model: model whose hidden states are read.
    :param layer: index of the hidden state to return.
    :param use_mlm: when true the model runs the forward pass in train mode.
    :param max_batch: number of signals per forward pass.
    :return: tensor of shape (len(signals), L_max, D), zero padded along the
        frame axis; an empty (0, 0, 0) tensor for empty input.
    :raises ValueError: if max_batch is smaller than 1.
    """
    if max_batch < 1:
        raise ValueError(f"max_batch must be >= 1, got {max_batch}")
    if not signals:
        return torch.empty(0, 0, 0)
    device = next(model.parameters()).device
    # Autocast is requested for CUDA only; disabling it on other devices
    # avoids a warning on every batch.
    use_autocast = device.type == "cuda"

    all_keeps = []
    for start in range(0, len(signals), max_batch):
        batch_signals = signals[start:start + max_batch]
        batch_masks = masks_audio[start:start + max_batch]

        inputs = extractor(batch_signals, sampling_rate=SR, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(device, non_blocking=True)

        orig_mode = model.training
        model.train(bool(use_mlm))
        try:
            with torch.no_grad(), torch.amp.autocast(
                device_type='cuda', dtype=torch.float16, enabled=use_autocast
            ):
                hs = model(input_values, output_hidden_states=True).hidden_states[layer]
        finally:
            # Restore the mode even if the forward pass raises.
            model.train(orig_mode)

        B, T, _ = hs.shape
        for b in range(B):
            # Resample the mask to the T frames of the hidden states and keep
            # only the frames it marks as active.
            mask_b = batch_masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
            mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
            all_keeps.append(hs[b][mask_t].cpu())

        del hs, input_values, inputs
        torch.cuda.empty_cache()

    if not all_keeps:
        return torch.empty(0, 0, 0)
    L_max = max(x.shape[0] for x in all_keeps)
    keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
    return torch.stack(keep_padded, dim=0)


def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
    """
    Encode a batch of signals using the self-supervised model chosen.

    :param signals: waveform signals to encode.
    :param masks_audio: voice activity masks of sources.
    :param model_wrapper: chosen model's wrapper: the string "raw", a
        BalancedMultiGPUModel, or an (extractor, model) pair.
    :param layer: transformer layer; only used for the (extractor, model)
        wrapper, since a BalancedMultiGPUModel carries its own layer.
    :param use_mlm: deprecated.
    :return: embedded signal representations by the model's layer, zero padded
        to a common number of frames; an empty (0, 0, 0) tensor for empty input.
    """
    if len(signals) == 0:
        return torch.empty(0, 0, 0)
    if model_wrapper == "raw":
        return embed_batch_raw(signals, masks_audio)
    if isinstance(model_wrapper, BalancedMultiGPUModel):
        # Guard against a non-positive BATCH_SIZE, which would give an
        # invalid range step.
        batch_size = max(1, min(BATCH_SIZE, 2))
        chunks = []
        for start in range(0, len(signals), batch_size):
            stop = start + batch_size
            batch_emb = model_wrapper.process_batch(
                signals[start:stop], masks_audio[start:stop], use_mlm
            )
            if batch_emb.numel() > 0:
                chunks.append(batch_emb)
            torch.cuda.empty_cache()

        if not chunks:
            return torch.empty(0, 0, 0)
        # Every chunk is padded only to its own longest item, so bring all
        # chunks to a common frame count before concatenating along the batch.
        longest = max(c.shape[1] for c in chunks)
        padded = [F.pad(c, (0, 0, 0, longest - c.shape[1])) for c in chunks]
        return torch.cat(padded, dim=0)

    extractor, model = model_wrapper
    return embed_batch_single_gpu(
        signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
    )