import torch
from einops import rearrange

import utils.data_transforms as data_transforms
from configs.base_configs import ModelType


def prepare_mixture(mixture, config):
    """Convert a stacked mixture into the conditioning layout the model expects.

    Args:
        mixture: Mixture in the project's "stacked" layout (exact shape is
            defined by ``utils.data_transforms``; not visible here).
        config: Object exposing ``quantize_input`` and, for the windowed
            path, ``window_size`` and ``context_size``.

    Returns:
        The interleaved representation when ``config.quantize_input`` is
        truthy, otherwise windows with context.
    """
    if not config.quantize_input:
        return data_transforms.stacked_to_windows_with_context(
            mixture,
            config.window_size,
            config.context_size,
        )
    return data_transforms.stacked_to_interleaving(mixture)


def predict_quantized(model, mixture, config):
    """Run a single forward pass of a quantized model and return stacked output.

    The mixture is converted with ``prepare_mixture`` and passed to the model
    as the ``cond`` keyword argument; the model output is then converted from
    the interleaved layout back to the stacked layout.
    """
    conditioning = prepare_mixture(mixture, config)
    interleaved_preds = model(cond=conditioning)
    return data_transforms.interleaving_to_stacked(interleaved_preds)


def predict_quantized_llm(model, mixture, config, beam_k):
    """Generate with a quantized LLM-style model and return stacked output.

    Args:
        model: Object exposing ``generate(cond=..., beam_k=...)``.
        mixture: Mixture in the project's "stacked" layout.
        config: Configuration forwarded to ``prepare_mixture``.
        beam_k: Forwarded unchanged to ``model.generate`` (presumably a beam
            width; confirm against the model implementation).

    Returns:
        The generated sequence converted from interleaved to stacked layout.
    """
    conditioning = prepare_mixture(mixture, config)
    generated = model.generate(cond=conditioning, beam_k=beam_k)
    return data_transforms.interleaving_to_stacked(generated)


def predict_windows(model, mixture, config):
    """Run a single forward pass of a window-based model and return stacked output.

    The stacked mixture is cut into windows with context (sizes taken from
    ``config.window_size`` / ``config.context_size``), passed to the model as
    the ``input`` keyword argument, and the predictions are converted back
    to the stacked layout.
    """
    windows = data_transforms.stacked_to_windows_with_context(
        mixture,
        config.window_size,
        config.context_size,
    )
    window_preds = model(input=windows)
    return data_transforms.windows_to_stacked(window_preds)


def predict_windows_llm(model, mixture, config):
    """Generate with a window-based LLM-style model and return stacked output.

    The stacked mixture is cut into windows with context and handed to
    ``model.generate`` as ``cond``, together with the window and context
    sizes from ``config``; the result is converted back to the stacked layout.
    """
    window_size = config.window_size
    context_size = config.context_size
    windows = data_transforms.stacked_to_windows_with_context(mixture, window_size, context_size)
    generated = model.generate(cond=windows, window_size=window_size, context_size=context_size)
    return data_transforms.windows_to_stacked(generated)


def predict_wavenet(model, mixture, config):
    """Run a single forward pass of a WaveNet-style model and return stacked output.

    ``config`` is not used here; it is accepted so that this function has the
    same signature as the other ``predict_*`` helpers.
    """
    wavenet_input = data_transforms.stacked_to_wavenet(mixture)
    wavenet_preds = model(wavenet_input)
    return data_transforms.wavenet_to_stacked(wavenet_preds)


def single_generate(model, mixture, config, **kwargs):
    """Run one prediction pass using the predictor that matches ``config.model_type``.

    Args:
        model: Model object passed through to the selected predictor.
        mixture: Mixture in the project's "stacked" layout.
        config: Configuration exposing ``model_type`` plus whatever the
            selected predictor reads.
        **kwargs: Extra options. ``beam_k`` is required (and only read) when
            the model type is ``ModelType.QUANTIZED_LLM``; other keys are
            ignored here.

    Returns:
        Whatever the selected ``predict_*`` function returns.

    Raises:
        ValueError: If ``config.model_type`` matches none of the supported types.
        KeyError: If the model type is ``QUANTIZED_LLM`` and ``beam_k`` is missing.
    """
    model_type = config.model_type

    # Predictors sharing the (model, mixture, config) signature, tried in order.
    simple_predictors = (
        (ModelType.WAVENET, predict_wavenet),
        (ModelType.WINDOWS_LLM, predict_windows_llm),
        (ModelType.WINDOWS, predict_windows),
        (ModelType.QUANTIZED, predict_quantized),
    )
    for candidate, predictor in simple_predictors:
        if model_type == candidate:
            return predictor(model, mixture, config)

    # The quantized LLM is the only predictor that needs an extra argument.
    if model_type == ModelType.QUANTIZED_LLM:
        return predict_quantized_llm(model, mixture, config, beam_k=kwargs["beam_k"])

    raise ValueError("Unsupported model type")
    

@torch.no_grad()
def generate(model, input, config, base_signal_length, expansion, **kwargs):
    """Run the model on ``input``, optionally splitting a long signal into pieces.

    ``input`` is indexed as a 3-D tensor whose axis 0 is the batch and whose
    axis 1 is the axis that gets split into pieces of ``base_signal_length``
    (axis 2 is presumably channels; confirm against the data pipeline).

    Args:
        model: Model object forwarded to ``single_generate``.
        input: Tensor to process (see layout note above).
        config: Configuration forwarded to ``single_generate``.
        base_signal_length: Length along axis 1 of each piece passed to
            ``single_generate``.
        expansion: How to handle signals longer than ``base_signal_length``:

            * ``"id"`` - no splitting; axis 1 must equal ``base_signal_length``.
            * ``"concat"`` - split into non-overlapping pieces, process them
              as one larger batch and concatenate the outputs along axis 1.
            * ``"multidiff"`` - slide a window with stride
              ``kwargs["multidiff_step"]``, process every window and average
              the outputs wherever windows overlap.
        **kwargs: Forwarded to ``single_generate``. ``multidiff_step`` is
            required for the ``"multidiff"`` expansion.

    Returns:
        The model output. For ``"multidiff"`` the result has the batch size
        and axis-1 length of ``input`` and the last-axis size of the model
        output; integer accumulations are returned as floating point because
        of the final true division.

    Raises:
        AssertionError: If the length of axis 1 is incompatible with
            ``base_signal_length`` (and ``multidiff_step``).
        KeyError: If ``expansion == "multidiff"`` and ``multidiff_step`` is missing.
        ValueError: If ``expansion`` is not one of the supported values.
    """
    if expansion == "id":
        assert base_signal_length == input.shape[1], \
            "'id' expansion requires input.shape[1] == base_signal_length"
        return single_generate(model, input, config, **kwargs)
    elif expansion == "concat":
        assert input.shape[1] % base_signal_length == 0, \
            "'concat' expansion requires input.shape[1] to be a multiple of base_signal_length"
        batch_size = input.shape[0]
        # unfold appends the window axis last: (b, s, c) -> (b, p, c, s).
        input_parts = input.unfold(dimension=1, size=base_signal_length, step=base_signal_length)
        input_parts = rearrange(input_parts, "b p c s -> (b p) s c")
        output = single_generate(model, input_parts, config, **kwargs)
        output = rearrange(output, "(b p) s c -> b (p s) c", b=batch_size)
        return output
    elif expansion == "multidiff":
        multidiff_step = kwargs["multidiff_step"]
        assert input.shape[1] >= base_signal_length and \
                (input.shape[1] - base_signal_length) % multidiff_step == 0, \
            "'multidiff' expansion requires the windows to tile input.shape[1] exactly"
        batch_size = input.shape[0]
        total_length = input.shape[1]
        input_parts = input.unfold(dimension=1, size=base_signal_length, step=multidiff_step)
        input_parts = rearrange(input_parts, "b p c s -> (b p) s c")
        output_parts = single_generate(model, input_parts, config, **kwargs)
        output_parts = rearrange(output_parts, "(b p) s c -> b p s c", b=batch_size)

        # Average overlapping windows. The accumulators are sized from the
        # model output (not the input) so that a model whose output has a
        # different channel count or dtype than its input still works; the
        # promoted dtype keeps the input's precision when it is the wider one.
        acc_dtype = torch.promote_types(input.dtype, output_parts.dtype)
        output_sum = torch.zeros(
            (batch_size, total_length, output_parts.shape[-1]),
            dtype=acc_dtype,
            device=output_parts.device,
        )
        # The number of windows covering a position depends only on axis 1,
        # so a broadcastable (1, length, 1) counter is enough.
        output_cnt = torch.zeros((1, total_length, 1), dtype=acc_dtype, device=output_parts.device)
        for i in range(output_parts.shape[1]):
            start = i * multidiff_step
            end = start + base_signal_length
            # Integer increment: adding a Python float in place fails on integer tensors.
            output_cnt[:, start:end, :] += 1
            output_sum[:, start:end, :] += output_parts[:, i, :, :]
        return output_sum / output_cnt
    else:
        raise ValueError("Unsupported expansion type")
