import logging

import numpy as np
import torch
from tqdm import tqdm

from unrealpose.pose.models import darknet
from unrealpose.pose.refine.preprocess2d import gen_video_specific_rot
from .utils import non_max_suppression, scale_round_coords_inplace, interpolate_missing_detections_inplace


logger = logging.getLogger(__name__)


def load_yolov3(config):
    """Construct the YoloV3 (people) detector and return it in a model dict.

    Reads ``config.YOLO.ARCH_CFG_PATH``, ``config.YOLO.IMG_SIZE`` and
    ``config.YOLO.WEIGHTS_PATH``. The network is built with
    ``darknet.Darknet``, its weights are loaded with
    ``darknet.load_darknet_weights``, ``fuse()`` is called on it, and it is
    moved to the first CUDA device.

    Returns:
        dict: ``{'yolo_model': <the loaded Darknet model>}``.
    """
    detector = darknet.Darknet(config.YOLO.ARCH_CFG_PATH, config.YOLO.IMG_SIZE)

    weights_path = config.YOLO.WEIGHTS_PATH
    logger.info('=> loading YoloV3 (people) from {}'.format(weights_path))
    # The loader's return value is not needed here.
    darknet.load_darknet_weights(detector, weights_path)

    # NOTE(review): fuse() presumably merges layers for faster inference -
    # confirm against the Darknet implementation.
    detector.fuse()
    detector.to(torch.device('cuda', 0))

    return {'yolo_model': detector}


def yolov3_inference(config,
                     loader,
                     dataset,
                     input_video_basename,
                     output_dir,
                     model_dict):
    """Run YoloV3 over every frame and pick one target person box per frame.

    For each image the detections of class 0 (people) are inspected:

    * no person                -> box ``(0, 0, 0, 0)`` and flag
      ``config.ERROR.NO_PEOPLE_DETECTED``;
    * exactly one person       -> that box, flag ``config.ERROR.PASS``;
    * several people           -> the box with the largest area is selected.
      The flag is ``PASS`` only when that box is also the first row of the
      person detections and its area is at least
      ``config.YOLO.SINGLE_PERSON_THRES`` times the second largest area;
      otherwise ``config.ERROR.MULTIPLE_PEOPLE_DETECTED``. The remaining
      boxes are stored as candidates.

    After all batches are processed, missing detections are filled in by
    ``interpolate_missing_detections_inplace``.

    Args:
        config: configuration object; ``config.YOLO.CONF_THRES``,
            ``config.YOLO.NMS_THRES``, ``config.YOLO.SINGLE_PERSON_THRES``
            and the ``config.ERROR.*`` codes are read here.
        loader: iterable yielding batches of network inputs. Per the
            original author's note these are RGB tensors of shape
            ``[N, 3, h_resize, w_resize]`` with values in 0.0 - 1.0.
        dataset: object supporting ``len()`` (number of frames) and exposing
            ``hw`` as ``(frame_height, frame_width)`` of the original video.
        input_video_basename: passed to ``gen_video_specific_rot`` to look up
            a per-video rotation.
        output_dir: currently unused; kept for interface compatibility.
        model_dict: mapping of models; must contain ``'yolo_model'``. Every
            model in it is switched to eval mode.

    Returns:
        dict with keys

        * ``'target_person'``: int32 array ``[N, 4]`` of ``x, y, w, h``
          (top-left corner plus size); undetected frames have zero ``w, h``
          unless interpolation filled them in.
        * ``'candidate_person'``: list of length N; each entry is either an
          empty list or a one-element list holding a ``[K-1, 4]`` array of
          the non-selected person boxes as ``x, y, w, h``.
        * ``'img_level_bad_flags'``: int32 array with one error code per
          processed image.
    """
    device = torch.device('cuda', 0)
    for model in model_dict.values():
        model.eval()

    nsamples = len(dataset)
    ori_frame_h, ori_frame_w = dataset.hw

    img_level_bad_flags = []  # one error code per image
    target_detections = np.zeros((nsamples, 4), dtype=np.int32)
    candidate_detections = [[] for _ in range(nsamples)]

    video_specific_rot = gen_video_specific_rot(input_video_basename)

    # Debug
    # NOTE(review): this unconditionally overrides any non-zero rotation, so
    # the 90-degree branches below are currently unreachable. Confirm whether
    # the override is still wanted before removing either side.
    if video_specific_rot != 0:
        logger.warning('=> Warning, manually specified rotation 0')
        video_specific_rot = 0

    logger.info('=> Video specific rotation (detection): %d', video_specific_rot)

    count = 0
    with torch.no_grad():
        # Wrap the loader itself (not enumerate(loader)) so tqdm can read its
        # length and display a total / ETA.
        for darknet_img in tqdm(loader):
            # darknet_img: RGB [N, 3, h_resize, w_resize] 0.0 - 1.0

            if video_specific_rot == 90:
                # clock-wise rotation
                rotated_darknet_img = darknet_img.transpose(2, 3).flip(3)  # [N, 3, w_resize, h_resize]
                frame_h, frame_w = ori_frame_w, ori_frame_h
            else:
                rotated_darknet_img = darknet_img
                frame_h, frame_w = ori_frame_h, ori_frame_w
            nimgs = len(rotated_darknet_img)

            rotated_darknet_img = rotated_darknet_img.to(device, non_blocking=False)  # [N, 3, w_r, h_r]
            # e.g. [8, 6552, 85]; the model accepts input [N, c, h, w]
            yolo_pred, _ = model_dict['yolo_model'](rotated_darknet_img)

            # list with one entry per image; each entry is None or a [k, 7]
            # tensor (k detected objects in that image)
            det_list = non_max_suppression(yolo_pred, config.YOLO.CONF_THRES, config.YOLO.NMS_THRES)

            detection_batch = np.zeros((nimgs, 4), dtype=np.int32)

            for idx, one_img_det in enumerate(det_list):
                # top-left corner + size; all zeros marks a bad detection
                x, y, w, h = 0, 0, 0, 0
                # nothing detected if one_img_det is None
                if one_img_det is not None and len(one_img_det) > 0:
                    # change from the coord of darkimg to the coord of rotated frame
                    scale_round_coords_inplace(rotated_darknet_img.shape[2:], one_img_det[:, :4], np.array([
                        frame_h, frame_w], dtype=np.int32))

                    # keep only rows whose last column (class id) is 0, i.e. people
                    people_det = one_img_det[one_img_det[:, -1] == 0]
                    if len(people_det) == 0:
                        img_level_bad_flags.append(config.ERROR.NO_PEOPLE_DETECTED)
                    elif len(people_det) == 1:
                        # the only detected person
                        select_det = people_det[0]
                        xyxy = select_det[:4].cpu().numpy().astype(np.int32)
                        # clip to image size (in place)
                        np.clip(xyxy, 0, [frame_w - 1, frame_h - 1, frame_w - 1, frame_h - 1], xyxy)
                        x, y = xyxy[:2]
                        w, h = xyxy[2] - x, xyxy[3] - y
                        img_level_bad_flags.append(config.ERROR.PASS)
                    else:
                        cand_xyxy = people_det[:, :4].cpu().numpy().astype(np.int32)
                        # Areas (and therefore the ranking) are computed from
                        # the boxes before they are clipped to the image.
                        cand_area = (cand_xyxy[:, 2] - cand_xyxy[:, 0]) * (cand_xyxy[:, 3] - cand_xyxy[:, 1])
                        sorted_idx = np.argsort(cand_area)[::-1]  # descending order
                        # clip to image size (in place, column by column)
                        np.clip(cand_xyxy[:, 0], 0, frame_w - 1, cand_xyxy[:, 0])
                        np.clip(cand_xyxy[:, 2], 0, frame_w - 1, cand_xyxy[:, 2])
                        np.clip(cand_xyxy[:, 1], 0, frame_h - 1, cand_xyxy[:, 1])
                        np.clip(cand_xyxy[:, 3], 0, frame_h - 1, cand_xyxy[:, 3])
                        xyxy = cand_xyxy[sorted_idx[0]]  # [4], the biggest box
                        x, y = xyxy[:2]
                        w, h = xyxy[2] - x, xyxy[3] - y
                        # PASS only if the biggest box clearly dominates the
                        # runner-up AND is the first person row.
                        # NOTE(review): presumably row 0 is the most confident
                        # detection - confirm against non_max_suppression.
                        if cand_area[sorted_idx[0]] >= cand_area[sorted_idx[1]] * config.YOLO.SINGLE_PERSON_THRES and \
                                sorted_idx[0] == 0:
                            img_level_bad_flags.append(config.ERROR.PASS)
                        else:
                            # the biggest box is still used, but flagged
                            img_level_bad_flags.append(
                                config.ERROR.MULTIPLE_PEOPLE_DETECTED)  # multiple people bad flag
                        candidate_detections[count + idx].append(cand_xyxy[sorted_idx[1:]])  # [K-1, 4]
                else:
                    img_level_bad_flags.append(config.ERROR.NO_PEOPLE_DETECTED)

                # detection_batch is int32; store the integer box directly
                detection_batch[idx] = (x, y, w, h)

            target_detections[count:count + nimgs] = detection_batch
            count += nimgs  # the last batch has a smaller number than batch size

    # flags
    img_level_bad_flags = np.array(img_level_bad_flags, dtype=np.int32)

    # interpolate missing detections for short missing clips
    logger.info('=> interpolate missing detections')
    interpolate_missing_detections_inplace(config, target_detections, img_level_bad_flags)

    # Return candidate boxes for visualization
    # empty detections have zero w,h ; all returned detections are clipped to image size
    # rotate back
    if video_specific_rot == 90:
        target_detections = np.stack((target_detections[..., 1],
                                      ori_frame_h - target_detections[..., 2] - target_detections[..., 0],
                                      target_detections[..., 3],
                                      target_detections[..., 2]),
                                     axis=1)
    for item in candidate_detections:
        if not item:
            continue
        candidates = item[0]  # [k, 4], x1, y1, x2, y2
        if video_specific_rot == 90:
            x, y = candidates[:, 0], candidates[:, 1]
            w, h = candidates[:, 2] - candidates[:, 0], candidates[:, 3] - candidates[:, 1]
            item[0] = np.stack((y, ori_frame_h - x - w, h, w), axis=1)
        else:
            # convert x1, y1, x2, y2 -> x, y, w, h in place
            candidates[:, 2] = candidates[:, 2] - candidates[:, 0]
            candidates[:, 3] = candidates[:, 3] - candidates[:, 1]

    record = {
        'target_person': target_detections,  # [N, 4]
        'candidate_person': candidate_detections,  # list [[], [k,4], [], ...]
        'img_level_bad_flags': img_level_bad_flags
    }

    return record
