import time
import torch
from fvcore.nn import FlopCountAnalysis
import librosa
from inference_error import InferenceHandler
import hydra
from omegaconf import DictConfig


def load_audio(file_path, sr=16000):
    """Load *file_path* with librosa at sample rate *sr* and return a float32 tensor.

    The clip duration in seconds (sample count divided by *sr*) is printed
    as a side effect.
    """
    waveform, _loaded_rate = librosa.load(file_path, sr=sr)
    duration = len(waveform) / sr
    print(f"Loaded {file_path} with length {duration} seconds.")
    return torch.tensor(waveform, dtype=torch.float32)


def _sync_device(device):
    """Wait for queued CUDA work on *device*; a no-op for non-CUDA devices."""
    # torch.cuda.synchronize() raises on CPU-only machines, yet
    # benchmark_model explicitly falls back to CPU when CUDA is missing.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _is_empty_batch(batch):
    """Return True when *batch* is None or contains no elements."""
    return batch is None or batch.numel() == 0


def _resolve_decoder_inputs(pt_batch, pt_mask_batch, num_rows, sos_token, device):
    """Return (decoder_input_ids, decoder_attention_mask), defaulting to a lone start token.

    When no usable prompt is present, the decoder gets one ``sos_token`` per
    row so its batch dimension matches the encoder batch (``num_rows``), even
    for a short final batch. Assumes prompt tensors are 2-D with the batch on
    dim 0 -- TODO confirm against InferenceHandler._batching.
    """
    if pt_batch is None or pt_batch.shape[1] == 0:
        pt_batch = torch.full(
            (num_rows, 1), sos_token, dtype=torch.long, device=device
        )
        # Any incoming mask described the empty prompt, not this fallback.
        pt_mask_batch = None
    if pt_mask_batch is None:
        pt_mask_batch = torch.ones_like(pt_batch, dtype=torch.long, device=device)
    return pt_batch, pt_mask_batch


class _KeywordForwardAdapter(torch.nn.Module):
    """Expose a keyword-argument model call as a positional forward for fvcore tracing."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self, mistake_inputs, score_inputs, decoder_input_ids, decoder_attention_mask
    ):
        outputs = self.model(
            mistake_inputs=mistake_inputs,
            score_inputs=score_inputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        # Tracing needs tensor outputs; unwrap exactly like the timed loop does.
        return outputs if isinstance(outputs, torch.Tensor) else outputs.logits


def benchmark_model(
    model,
    mistake_file,
    score_file,
    prompt_path,
    batch_size=1,
    num_runs=10,
    *,
    dataset=None,
    verbose=True,
):
    """Measure average forward-pass latency and FLOPs of *model* on one audio pair.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode. Inputs are built with ``InferenceHandler``'s preprocessing and
    batching helpers. Every non-empty batch is then run ``num_runs`` times
    under ``torch.no_grad()``.

    Args:
        model: Module called with keyword arguments ``mistake_inputs``,
            ``score_inputs``, ``decoder_input_ids`` and
            ``decoder_attention_mask``. It must expose
            ``model.config.use_prompt`` and
            ``model.config.decoder_start_token_id``.
        mistake_file: Audio file loaded as the ``mistake_inputs`` source.
        score_file: Audio file loaded as the ``score_inputs`` source.
        prompt_path: Forwarded to ``InferenceHandler._preprocess``.
        batch_size: Forwarded to ``InferenceHandler._batching``.
        num_runs: Number of timed passes over all non-empty batches; must be >= 1.
        dataset: Forwarded to ``InferenceHandler._postprocess_prompt_batch``
            when ``model.config.use_prompt`` is true. TODO confirm which
            value the handler expects.
        verbose: Print debug information. The per-batch prints run inside the
            timed region and therefore inflate the measured latency.

    Returns:
        ``(avg_latency, total_flops)``. ``avg_latency`` is wall-clock seconds
        per executed forward pass. ``total_flops`` is fvcore's
        ``FlopCountAnalysis.total()`` for the first non-empty batch.

    Raises:
        ValueError: If ``num_runs < 1`` or every batch is empty.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    mistake_audio = load_audio(mistake_file)
    score_audio = load_audio(score_file)

    inference_handler = InferenceHandler(model=model, device=device)

    mistake_inputs, score_inputs, mistake_frame_times, score_frame_times, prompt_row = (
        inference_handler._preprocess(mistake_audio, score_audio, prompt_path)
    )
    mistake_inputs_tensor = torch.from_numpy(mistake_inputs).to(device)
    score_inputs_tensor = torch.from_numpy(score_inputs).to(device)

    if model.config.use_prompt:
        prompt_row = inference_handler._postprocess_prompt_batch(prompt_row, dataset)
        prompt_tokens = torch.cat(prompt_row["prompt_tokens"], dim=0).to(device)
        prompt_masks = torch.cat(prompt_row["prompt_masks"], dim=0).to(device)
    else:
        prompt_tokens, prompt_masks = None, None

    if verbose:
        print(f"mistake_frame_times type: {type(mistake_frame_times)}")
        print(
            f"mistake_frame_times shape: {getattr(mistake_frame_times, 'shape', 'None')}"
        )

    (
        mistake_batches,
        score_batches,
        _,
        _,
        prompt_tokens_batch,
        prompt_masks_batch,
    ) = inference_handler._batching(
        mistake_inputs_tensor,
        score_inputs_tensor,
        mistake_frame_times,
        score_frame_times,
        prompt_tokens,
        prompt_masks,
        batch_size=batch_size,
    )

    # Guard against the batching helper returning None instead of a list when
    # no prompt is used; zip() over None would raise a TypeError.
    num_batches = len(mistake_batches)
    if prompt_tokens_batch is None:
        prompt_tokens_batch = [None] * num_batches
    if prompt_masks_batch is None:
        prompt_masks_batch = [None] * num_batches

    # Validate batches and build decoder fallbacks once, outside the timed
    # region, so only the model forward is measured.
    sos_token = model.config.decoder_start_token_id
    runnable = []
    for mistake_batch, score_batch, pt_batch, pt_mask_batch in zip(
        mistake_batches, score_batches, prompt_tokens_batch, prompt_masks_batch
    ):
        if _is_empty_batch(mistake_batch) or _is_empty_batch(score_batch):
            print("WARNING: skipping forward pass because batch is empty or invalid.")
            continue
        decoder_ids, decoder_mask = _resolve_decoder_inputs(
            pt_batch, pt_mask_batch, mistake_batch.shape[0], sos_token, device
        )
        runnable.append((mistake_batch, score_batch, decoder_ids, decoder_mask))

    if not runnable:
        raise ValueError("No non-empty batches to benchmark.")

    _sync_device(device)
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            for mistake_batch, score_batch, decoder_ids, decoder_mask in runnable:
                if verbose:
                    print("----- Debugging Model Inputs -----")
                    print(f"mistake_batch shape: {mistake_batch.shape}")
                    print(f"score_batch shape: {score_batch.shape}")
                    print(f"decoder_input_ids shape: {decoder_ids.shape}")
                    print(f"decoder_attention_mask shape: {decoder_mask.shape}")

                outputs = model(
                    mistake_inputs=mistake_batch,
                    score_inputs=score_batch,
                    decoder_input_ids=decoder_ids,
                    decoder_attention_mask=decoder_mask,
                )
                logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
                next_token = torch.argmax(logits[:, -1, :], dim=-1)

                if verbose:
                    print(f"Model output type: {type(outputs)}")
                    print(f"Generated token: {next_token}")
    _sync_device(device)
    total_time = time.perf_counter() - start_time

    # Divide by forward passes actually executed; skipped batches cost nothing.
    avg_latency = total_time / (num_runs * len(runnable))
    print(f"Average Forward Pass Latency per Batch: {avg_latency:.6f} seconds")

    first_mistake, first_score, first_ids, first_mask = runnable[0]
    if verbose:
        print("DEBUG: Checking shapes before FLOP analysis...")
        print(f"mistake_batches[0]: {first_mistake.shape}")
        print(f"score_batches[0]: {first_score.shape}")
        print(f"decoder_input_ids[0]: {first_ids.shape}")

    # fvcore expects a tuple of positional inputs; a dict would be fed to the
    # model as its first positional argument. The adapter maps positions back
    # to the keyword arguments the model is called with above.
    with torch.no_grad():
        flops = FlopCountAnalysis(
            _KeywordForwardAdapter(model),
            (first_mistake, first_score, first_ids, first_mask),
        )
        total_flops = flops.total()
    print(f"FLOPs: {total_flops}")

    return avg_latency, total_flops


@hydra.main(config_path=None, config_name=None, version_base="1.1")
def main(cfg: DictConfig):
    """Load the checkpointed model described by the Hydra config and benchmark it.

    Config keys read:
        path: Checkpoint passed to ``load_from_checkpoint`` (required).
        model._target_ / model.config: Model class and its config.
        optim: Forwarded as ``optim_cfg`` to ``load_from_checkpoint``.
        mistake_file, score_file: Audio inputs (required, e.g. supplied as
            ``+mistake_file=...`` on the command line).
        prompt_file: Forwarded to ``benchmark_model`` (optional, default "").

    Raises:
        ValueError: If ``path``, ``mistake_file`` or ``score_file`` is
            missing or empty.
    """
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not cfg.get("path"):
        raise ValueError("Model path must be specified in the config file")

    print("Loading model...")
    model_cls = hydra.utils.get_class(cfg.model._target_)
    model = model_cls.load_from_checkpoint(
        cfg.path, config=cfg.model.config, optim_cfg=cfg.optim
    )
    model.eval()
    # Device placement is left to benchmark_model, which falls back to CPU;
    # an unconditional model.cuda() here would crash on CPU-only hosts.

    mistake_file = cfg.get("mistake_file", "")
    score_file = cfg.get("score_file", "")
    prompt_file = cfg.get("prompt_file", "")
    if not mistake_file or not score_file:
        raise ValueError(
            "mistake_file and score_file must be set in the config "
            "(e.g. +mistake_file=... +score_file=...)"
        )

    benchmark_model(model, mistake_file, score_file, prompt_file)


if __name__ == "__main__":
    main()
