import numpy as np
import torch
from torch.cuda import synchronize
from omegaconf import DictConfig

from pado.core.tracker import TimeTracker
from pado.utils import default_parser, load_config, override_config_by_cli, count_params
from model import ReuseAttnConformerCTC


def run(cfg: DictConfig, batch_size: int, seq_length: int, num_iters: int, num_warmup: int = 5):
    """Benchmark the forward latency of the ReuseAttnConformerCTC encoder on ``cuda:0``.

    Builds the model from ``cfg["model"]`` and prints the encoder parameter count. It then runs
    ``num_iters`` forward passes of ``network.encoder`` on a random dummy batch and prints the
    mean duration per batch and per sample. The first ``num_warmup`` iterations are excluded
    from the mean.

    Args:
        cfg: Parsed configuration; only ``cfg["model"]`` is read here.
        batch_size: Number of sequences in the dummy batch (must be >= 1).
        seq_length: Number of time steps in every dummy sequence (must be >= 1).
        num_iters: Total number of timed forward passes, including the warm-up ones.
        num_warmup: Number of leading iterations discarded before averaging. Early GPU
            iterations are typically unrepresentative. Defaults to 5, the value that was
            previously hard-coded.

    Raises:
        ValueError: If ``batch_size`` or ``seq_length`` is not positive, if ``num_warmup`` is
            negative, or if ``num_iters`` does not exceed ``num_warmup``. In that last case no
            timed samples would remain and the reported mean would be NaN.
    """
    # Validate before building the model / touching the GPU so bad arguments fail fast.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")
    if num_warmup < 0:
        raise ValueError(f"num_warmup must be >= 0, got {num_warmup}")
    if num_iters <= num_warmup:
        raise ValueError(
            f"num_iters ({num_iters}) must be greater than num_warmup ({num_warmup}) "
            "so that at least one timed iteration remains"
        )

    # -------------------------------------------------------------------------------------------------------- #
    # Network
    # -------------------------------------------------------------------------------------------------------- #
    network = ReuseAttnConformerCTC(cfg["model"])

    device = torch.device("cuda:0")
    network = network.to(device)
    network.eval()
    param_count, param_count_elements = count_params(network.encoder.parameters())
    print("-" * 64)
    print(f"Encoder Parameters: {param_count}, elements: {param_count_elements}")

    # NOTE(review): the dummy input's last dimension is the encoder's hidden_dim, i.e. this assumes
    # the encoder consumes (batch, seq, hidden_dim) inputs directly — confirm against the model.
    feature_dim = network.encoder.hidden_dim
    dummy_input = torch.randn(batch_size, seq_length, feature_dim, dtype=torch.float32, device=device)
    # Every sequence is given the full length, so the batch contains no padding.
    dummy_length = torch.full((batch_size,), fill_value=seq_length, dtype=torch.long, device=device)

    tracker = TimeTracker()

    durations = []
    with torch.no_grad():
        for _ in range(num_iters):
            # Synchronize before starting and before stopping the timer. The measurement then
            # covers the queued GPU work of the forward pass, not just the asynchronous launches.
            synchronize(device)
            tracker.reset()

            network.encoder(dummy_input, dummy_length)

            synchronize(device)
            durations.append(tracker.update())

    # Drop the warm-up iterations; float64 avoids needless precision loss when averaging.
    timed = np.asarray(durations[num_warmup:], dtype=np.float64)
    # Durations are treated as seconds here (scaled by 1000 to print milliseconds).
    mean_duration = float(np.mean(timed))
    print("=" * 64)
    print(f"Batch size: {batch_size}, Length: {seq_length}")
    print(f"Average duration (ms): {mean_duration * 1000:.6f}")
    print(f"Average duration (ms/sample): {mean_duration / batch_size * 1000:.6f}")


if __name__ == '__main__':
    # Script entry point: extend the project's default CLI with the benchmark knobs,
    # load the config (with CLI overrides applied), and run the benchmark.
    cli = default_parser()
    for flag, default_value, help_text in (
        ("--batch_size", 16, "Batch size"),
        ("--seq_length", 512, "Sequence length"),
        ("--num_iters", 25, "Number of iterations"),
    ):
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    cli_args = cli.parse_args()

    config = override_config_by_cli(load_config(cli_args.config), cli_args.script_args)
    run(config, cli_args.batch_size, cli_args.seq_length, cli_args.num_iters)
