import os
from typing import Optional, Tuple, List, Union, Callable
from tqdm import tqdm

import math
import argparse

import time
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from tqdm import trange
import cv2 
import onnx
import itertools
from collections import defaultdict

from cifar10_resnet.resnet import resnet2b, resnet4b

from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm, PerturbationLinear
import onnx2pytorch

from yolo.yolo_utils import *



if __name__ == "__main__":
    # Verification driver: for every sample in the dataset, take the
    # element-wise input interval [x_L, x_U], bound the detector output over
    # that interval with auto_LiRPA (CROWN), and mark the sample "verified"
    # when the certified lower bound of at least one anchor confidence in the
    # nominal best grid cell exceeds `conf_threshold`.
    # Results are written to `all_results_yolo.npz` (key `all_results`),
    # periodically during the run and once more at the end.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    detection_model_path = 'yolo/TinyYOLO.onnx'
    detection_model = ONNXYOLOModel(detection_model_path, input_size=52, device=device)
    detection_model.to(device)
    detection_model.eval()

    # Indexing an NpzFile materialises the array, so the file handle can be
    # released as soon as both bound arrays are read.
    # NOTE(review): the arrays are presumably laid out as (N, 52, 52, 3)
    # (channels-last), given the permute(0, 3, 1, 2) below - confirm.
    with np.load('downsampled_planes_yolo_npz.npz') as data:
        x_L = data['output_lb']
        x_U = data['output_ub']

    KA = 5    # anchor boxes predicted per grid cell
    NC = 20   # class scores per anchor box
    GRID = 13  # the flat model output is reshaped to (B, (1 + NC + 4) * KA, GRID, GRID)

    # Specification threshold on the certified lower bound of the confidence.
    conf_threshold = -0.2
    results_path = 'all_results_yolo.npz'

    # The bounded wrapper depends only on the model and a dummy input of the
    # right shape, so it is built once instead of once per sample.
    bounded_model = BoundedModule(
        detection_model, (torch.zeros(1, 3 * 52 * 52, device=device),))

    verify_result = np.zeros(x_L.shape[0])
    # NOTE: the output-selection matrix below is derived from the first sample
    # of each batch only, so BS is expected to stay at 1.
    BS = 1
    for j in range(0, x_L.shape[0], BS):
        # Move the bounds to the same device as the model (not a hard-coded
        # 'cuda', which fails on CPU-only hosts) and go channels-first.
        x_L_batch = torch.Tensor(x_L[j:j+BS]).to(device).permute(0, 3, 1, 2)
        x_U_batch = torch.Tensor(x_U[j:j+BS]).to(device).permute(0, 3, 1, 2)

        # The model takes in flatten input
        x_L_batch = x_L_batch.reshape(x_L_batch.shape[0], -1)
        x_U_batch = x_U_batch.reshape(x_U_batch.shape[0], -1)

        # Nominal point: the centre of the input interval.
        unperturbed_input = (x_L_batch + x_U_batch) / 2

        # Load image bounds
        bounded_input = BoundedTensor(unperturbed_input, PerturbationLpNorm(
            x_L=x_L_batch,
            x_U=x_U_batch
        ))

        # The nominal forward pass is only used to locate the best grid cell,
        # so no autograd graph is needed.
        with torch.no_grad():
            flatten_test_output = detection_model(unperturbed_input)
        test_output = flatten_test_output.reshape(-1, (1 + NC + 4) * KA, GRID, GRID)
        B, abC, H, W = test_output.shape    # Should be 1, 125, 13, 13

        # Channel layout used below: [0, KA) confidences,
        # [KA, (1 + NC) * KA) class scores, [(1 + NC) * KA, end) box regression.
        prediction = test_output.permute(0, 2, 3, 1).contiguous().view(B, -1, abC)
        conf_pred = prediction[..., :KA].contiguous().view(B, -1, 1)[0]  # [H*W*KA, 1]
        # Not needed for verification; kept because the optional visualization
        # snippet at the end of the file reads it.
        txtytwth_pred = prediction[..., (1 + NC) * KA:].contiguous().view(B, -1, 4)[0]  # [H*W*KA, 4]

        # Pick out the index of the max confidence
        _, max_conf_idx = torch.max(conf_pred, dim=0)

        # conf_pred is flattened as (row * W + col) * KA + anchor; decode with
        # exact integer arithmetic.
        d2, cell_rem = divmod(int(max_conf_idx.item()), W * KA)  # d2: grid row
        d3, d1 = divmod(cell_rem, KA)                            # d3: grid col, d1: anchor
        # assert max_conf.item() == test_output[0, d1, d2, d3]

        # Offset of the selected grid cell inside one (H, W) channel plane of
        # the flat model output (channel * H * W + row * W + col).
        cell_offset = d2 * W + d3

        # The KA confidences of the anchor boxes that share the selected cell.
        conf_idx_same_grid = [a * H * W + cell_offset for a in range(KA)]

        # The four box-regression outputs (channels (1 + NC) * KA + 4 * d1 + k)
        # of the single box that had the highest nominal confidence.
        # Bounding these is optional: to bound confidences only, pass just the
        # first KA rows of the selection matrix as C.
        reg_base = (1 + NC) * KA + d1 * 4
        reg_idx_same_box = [(reg_base + k) * H * W + cell_offset for k in range(4)]

        # Specification matrix: each row selects one flat output, so
        # compute_bounds only bounds these KA + 4 values.
        selected_outputs = conf_idx_same_grid + reg_idx_same_box
        C_conf_with_reg = torch.zeros(1, len(selected_outputs), flatten_test_output.size(1),
                                      device=flatten_test_output.device,
                                      dtype=flatten_test_output.dtype)
        for row, col in enumerate(selected_outputs):
            C_conf_with_reg[0, row, col] = 1.0

        with torch.no_grad():
            ret = bounded_model.compute_bounds(bounded_input, C=C_conf_with_reg, method="crown")

        # Specification: At least one of the confidences of the five bounding
        # boxes in the same grid cell is greater than the threshold.
        # ret[0] holds the lower bounds; its first KA columns are the confidences.
        lower_confs = ret[0][:, :KA]
        above_threshold = lower_confs > conf_threshold
        res = above_threshold.sum(dim=1).detach().cpu().numpy()
        verify_result[j:j+BS] = (res > 0).astype(np.float16)
        if torch.any(above_threshold):
            print(f"{j}: verified")
        else:
            print(f"{j}: {lower_confs.max()}")
        # Periodic checkpoint so a long run can be inspected or recovered.
        if j % 100 == 0:
            np.savez_compressed(results_path, all_results=verify_result)
    print(f'Verified: {verify_result.sum()/len(verify_result)}')
    np.savez_compressed(results_path, all_results=verify_result)



    # #######################################################################################################
    # # Optional: Visualize the bounding box if you need
    # # Example: Visualizing the bounding box of the unperturbed output

    # # Get the anchors of that bounding box
    # anchors = detection_model.anchor_boxes[max_conf_idx]
    # # Decode the bounding box
    # bboxes = detection_model.decode_boxes(anchors, txtytwth_pred[max_conf_idx])
    # # Generate the image with bounding box
    # img = plot_bbox_labels(torch_to_cv2(bounded_input[0].reshape(3, 52, 52)),
    #                        bboxes[0], label=None, cls_color=[255, 0, 0])\
    # # Save the image
    # cv2.imwrite('test.png', img)
    # #####################################################################################################
