

import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torch_tensorrt

import os
import argparse
import time
import numpy as np

import _init_paths
from config import cfg, update_config
from utils.transforms import transform_preds
from core.num_convert import convert
from dataset import coco


def _print_name_value(name_value, full_arch_name):
    """Print one row of metrics as a Markdown table.

    Args:
        name_value: mapping of metric name -> numeric value; its iteration
            order decides the column order.
        full_arch_name: label printed in the first column of the data row.
            Labels longer than 15 characters are cut to their first 8
            characters followed by '...'.
    """
    header_cells = ' '.join(f'| {metric}' for metric in name_value)
    print(f'| Arch {header_cells} |')
    print('|---' * (len(name_value) + 1) + '|')

    if len(full_arch_name) > 15:
        arch_label = full_arch_name[:8] + '...'
    else:
        arch_label = full_arch_name
    value_cells = ' '.join(f'| {score:.3f}' for score in name_value.values())
    print(f'| {arch_label} {value_cells} |')

def throughput_benchmark(model, input_shape=(32, 3, 256, 192), device='cuda', dtype=torch.bfloat16,
                         use_amp=True, nwarmup=10, nruns=500,
                         n_bit_integer=None, n_bit_fractional=0, narrow_ratio=0.5):
    """Benchmark end-to-end inference: model forward, sigmoid and ``convert``.

    A random batch of shape ``input_shape`` is pushed through ``model``
    ``nwarmup`` times untimed, then ``nruns`` times timed. The mean per-batch
    latency and the resulting images/second are printed.

    Args:
        model: callable whose output unpacks into two tensors.
        input_shape: shape of the random input batch. ``input_shape[0]`` is the
            batch size used for the throughput figure; ``input_shape[2]`` is used
            to derive the default ``n_bit_integer``.
        device: device for the dummy input (e.g. ``'cuda'`` or ``'cuda:0'``).
        dtype: autocast dtype.
        use_amp: whether autocast is enabled.
        nwarmup: number of untimed warm-up iterations.
        nruns: number of timed iterations; must be at least 1.
        n_bit_integer: integer bit width forwarded to ``convert``. When None it
            is ``ceil(log2(int(input_shape[2] * narrow_ratio)))``, the same
            formula the ``__main__`` block applies to ``cfg.MODEL.IMAGE_SIZE[1]``.
        n_bit_fractional: fractional bit width forwarded to ``convert``.
        narrow_ratio: forwarded to ``convert`` as ``narrow_ratio``.

    Raises:
        ValueError: if ``nruns`` is smaller than 1 (the mean would be undefined).
    """
    if nruns < 1:
        raise ValueError(f'nruns must be >= 1, got {nruns}')

    # Explicit parameters (instead of module-level globals that only exist when
    # this file runs as a script) keep the function usable after an import.
    if n_bit_integer is None:
        n_range = int(input_shape[2] * narrow_ratio)
        n_bit_integer = int(torch.ceil(torch.log2(torch.tensor(n_range))))

    # autocast expects a bare device type ('cuda'), not an indexed device ('cuda:0').
    device_type = torch.device(device).type

    def _sync():
        # CUDA kernels launch asynchronously; wait for them so each timing
        # covers the actual GPU work. Nothing to wait for on other devices.
        if device_type == 'cuda':
            torch.cuda.synchronize()

    dummy_input = torch.randn(input_shape, device=device)

    timings = []
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype, enabled=use_amp):
        for _ in range(nwarmup):
            model(dummy_input)
        _sync()

        for _ in range(nruns):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            xx, yy = model(dummy_input)

            xxx = xx.detach().clone().sigmoid()
            yyy = yy.detach().clone().sigmoid()

            convert(xxx, yyy, n_bit_integer, n_bit_fractional, narrow_ratio=narrow_ratio)
            _sync()
            timings.append(time.perf_counter() - start_time)

    mean_time = np.mean(timings)
    print(f'Input shape: {input_shape}, Average throughput: {input_shape[0]/mean_time:.3f} images/second')
    print(f"batched ({input_shape[0]}) inference time: {mean_time*1000:.3f} milliseconds")


if __name__ == "__main__":
    def _str_to_bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` turns every non-empty string (including "False") into
        True, so a flag declared that way can never be switched off.
        """
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    def _to_numpy(tensor):
        """Copy a tensor to a host numpy array; numpy has no bfloat16 dtype, so upcast that case."""
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', required=True, type=str)
    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir', type=str, default='')
    parser.add_argument('--logDir', type=str, default='')
    parser.add_argument('--dataDir', type=str, default='')
    parser.add_argument('--prevModelDir', type=str, default='')
    parser.add_argument("--model", default="convnextv2_tiny", type=str)
    parser.add_argument("--bs", type=int, default=32)
    parser.add_argument("--channels_last", default=True, type=_str_to_bool)
    parser.add_argument('--drop_path_rate', default=0, type=float)
    parser.add_argument('--layer_scale_init_value', default=1e-6, type=float)
    parser.add_argument('--head_init_scale', default=1.0, type=float)
    parser.add_argument('--model_key', default='model|module', type=str)
    parser.add_argument('--ep_path', default='HashPose_L.ep', type=str,
                        help='path of the serialized torch.export program to evaluate')
    args = parser.parse_args()
    update_config(cfg, args)

    torch.backends.cuda.matmul.allow_tf32 = True
    cudnn.benchmark = True
    cudnn.deterministic = False
    cudnn.enabled = True

    # Quantization settings handed to convert(): the integer bit width covers
    # the narrowed coordinate range int(IMAGE_SIZE[1] * narrow_ratio).
    narrow_ratio = 0.5
    device = 'cuda'
    n_range = int(cfg.MODEL.IMAGE_SIZE[1] * narrow_ratio)
    n_bit_integer = int(torch.ceil(torch.log2(torch.tensor(n_range))))
    n_bit_fractional = 0

    memory_format = torch.channels_last if args.channels_last else torch.contiguous_format

    print('loading tensorrt model ...')
    trt_model_new = torch.export.load(args.ep_path).module()

    throughput_benchmark(trt_model_new, input_shape=(1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    if cfg.DATASET.DATASET == 'coco':
        selected_dataset = coco
    else:
        # Without this branch selected_dataset is unbound and the script dies
        # with an unhelpful NameError.
        raise ValueError(f'unsupported dataset {cfg.DATASET.DATASET!r}; only "coco" is supported')
    valid_dataset = selected_dataset(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
        narrow_ratio
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True
    )

    num_samples = len(valid_dataset)
    all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3), dtype=np.float32)
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    output_size = [cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]]

    with torch.no_grad():
        # Targets and target weights yielded by the dataset are not needed for evaluation.
        for i, (images, _, _, _, _, meta) in enumerate(valid_loader):
            images = images.to(device=device, memory_format=memory_format)

            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                xx, yy = trt_model_new(images)
                xxx = xx.detach().clone().sigmoid()
                yyy = yy.detach().clone().sigmoid()
                quota, nmaly = convert(xxx, yyy, n_bit_integer, n_bit_fractional, narrow_ratio=narrow_ratio)
            num_images = images.size(0)
            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            # One host transfer per batch instead of one per sample.
            quota_np = _to_numpy(quota)
            n_batch, n_channel = quota.shape[0:2]
            xxz = np.zeros([n_batch, n_channel, 2], dtype=float)
            for p_i in range(n_batch):
                # Map coordinates back to the original image frame using the
                # sample's center/scale.
                xxz[p_i] = transform_preds(quota_np[p_i], c[p_i], s[p_i], output_size)
            all_preds[idx:idx + num_images, :, 0:2] = xxz[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = _to_numpy(nmaly)
            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
            all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
            all_boxes[idx:idx + num_images, 5] = score
            image_path.extend(meta['image'])
            idx += num_images
            if i % 100 == 0:
                print(f'processing {i}-th sample ...')

        output_dir = './'
        name_values, perf_indicator = valid_dataset.evaluate(
            cfg, all_preds, output_dir, all_boxes, image_path,
            filenames, imgnums
        )

        model_name = cfg.MODEL.NAME
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

   

    


