import os
from typing import Optional, Tuple, List, Union, Callable
from tqdm import tqdm

import math
import argparse

import time
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from tqdm import trange
import cv2 
import onnx
import itertools
from collections import defaultdict

from gatenet.gatenet import GateNet
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm, PerturbationLinear

def _worst_case_pose_error(lb, ub, gt):
    """Return, per row, the largest L2 distance from ``gt`` to any point of the box ``[lb, ub]``.

    The farthest point of an axis-aligned box from ``gt`` is the corner that
    takes, in every coordinate, whichever bound is farther from ``gt``.
    Checking only the two corners ``lb`` and ``ub`` can underestimate that
    distance (e.g. gt=0, lb=(-1,-3), ub=(3,1): both corners give sqrt(10), the
    true maximum is sqrt(18)) and would wrongly certify a sample.
    """
    per_coord = np.maximum(np.abs(lb - gt), np.abs(ub - gt))
    return np.linalg.norm(per_coord, axis=1)


def classificaion_based_on_constant_bounds(model, x_L, x_U, batch_size, camera_poses_gt,
                                           *, threshold=3.5, device=None):
    """Certify pose predictions over per-sample input boxes using CROWN bounds.

    For each sample ``i`` the input set is the box ``x_L[i] <= x <= x_U[i]``.
    auto_LiRPA's CROWN pass returns element-wise output bounds ``lb`` / ``ub``.
    A sample counts as verified when the worst-case point of the output box
    ``[lb, ub]`` lies strictly within ``threshold`` (Euclidean distance) of
    the ground-truth pose.

    Args:
        model: torch module to bound.
        x_L, x_U: numpy arrays of per-sample lower / upper input bounds. They
            are permuted with ``(0, 3, 1, 2)`` before reaching the model, i.e.
            they are expected channels-last.
        batch_size: number of samples bounded per CROWN call.
        camera_poses_gt: array with one ground-truth pose row per sample,
            row-aligned with ``x_L``.
        threshold: maximum allowed worst-case L2 error (default 3.5).
        device: device for the input tensors. Defaults to the device of the
            model's first parameter, or ``'cuda'`` if the model has none.

    Returns:
        ``(total_num, results)``: the number of verified samples, and a float
        array of length ``x_L.shape[0]`` holding 1.0 for verified samples and
        0.0 otherwise.
    """
    num_samples = x_L.shape[0]
    results = np.zeros(num_samples)
    if num_samples == 0:
        # Nothing to bound: avoids building a BoundedModule from an empty probe
        # and dividing by zero in the summary line.
        return 0, results

    if device is None:
        # Follow the model instead of assuming CUDA, so CPU runs work too.
        device = next((p.device for p in model.parameters()), torch.device('cuda'))

    def to_nchw(arr):
        # Bounds are stored channels-last; the model consumes NCHW tensors.
        return torch.Tensor(arr).to(device).permute(0, 3, 1, 2)

    lirpa_model = BoundedModule(model, (to_nchw(x_L[:1]),))
    num_batches = (num_samples - 1) // batch_size + 1
    total_num = 0
    for i in range(num_batches):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        x_L_batch = to_nchw(x_L[batch])
        x_U_batch = to_nchw(x_U[batch])
        ptb = PerturbationLpNorm(x_L=x_L_batch, x_U=x_U_batch)
        bounded_x = BoundedTensor(x_L_batch, ptb)
        lb, ub = lirpa_model.compute_bounds(bounded_x, method="crown")

        max_diff = _worst_case_pose_error(
            lb.detach().cpu().numpy(),
            ub.detach().cpu().numpy(),
            camera_poses_gt[batch],
        )
        print(">>>>>>", max_diff)
        verified = max_diff < threshold
        num_below_threshold = int(np.count_nonzero(verified))
        results[batch] = verified.astype(np.float16)
        print(f"{i}/{num_batches}: {num_below_threshold / max_diff.shape[0] * 100:.2f}%")
        total_num += num_below_threshold
    print(f"overall accurcy: {total_num / num_samples}")
    return total_num, results

if __name__ == "__main__":
    # Verify GateNet pose estimates over per-image input boxes loaded from --data,
    # save the per-sample verification results, then run the model on one
    # reference image.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='resnet2b', help='model name')
    parser.add_argument('--data', type=str, default='downsampled_planes_4831x_pose_npz.npz', help='dataset name')
    parser.add_argument('--target_class', type=int, default=0, help='target class for classification (0 for plane)')
    parser.add_argument('--batch_size', type=int, default=5, help='batch size')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = {
        'input_shape': (3, 80, 80),
        'output_shape': (3,),  # 3 pose components (presumably X, Y, Z) -- TODO confirm
        'l2_weight_decay': 1e-4,
        'batch_norm_decay': 0.99,
        'batch_norm_epsilon': 1e-3
    }

    model = GateNet(config)

    # Load the pre-trained weights. map_location lets a checkpoint saved on a
    # GPU be loaded on a CPU-only machine.
    checkpoint = torch.load('./gatenet/checkpoint_epoch_100.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Load per-image lower/upper input bounds. They are stored channels-last
    # (presumably (N, 80, 80, 3) to match config['input_shape'] -- TODO confirm)
    # and are permuted to NCHW by the verification function.
    with np.load(args.data) as data:
        x_L = data['output_lb']
        x_U = data['output_ub']

    # Quick visual sanity check of the first and last lower-bound images.
    plt.figure(0)
    plt.imshow(x_L[0])
    plt.figure(1)
    plt.imshow(x_L[-1])
    plt.show()

    with np.load('camera_poses_x.npz') as poses_gt:
        camera_poses_gt = poses_gt['camera_poses']

    # Ensure lower <= upper element-wise. Both results must come from the
    # ORIGINAL arrays: deriving the upper bound from the already-clamped lower
    # bound always yields x_U and collapses inverted intervals to a point.
    lower = np.minimum(x_L, x_U)
    upper = np.maximum(x_L, x_U)

    with torch.no_grad():
        # The function returns a count of verified samples, not a ratio.
        num_verified, overall_pose_estimation_result = classificaion_based_on_constant_bounds(
            model, lower, upper, args.batch_size, camera_poses_gt.T)

    print(overall_pose_estimation_result.shape)
    np.savez_compressed('all_results_pose_estimation_4831x.npz',
                        all_results=overall_pose_estimation_result)

    # Run the model on a single reference image (channels-last -> NCHW).
    image_gt_fn = '../../pose_estimator_plane/test_images_8_npz/img_4381.npz'
    with np.load(image_gt_fn) as img_file:
        img_gt = img_file['img']
    img_gt_tensor = torch.Tensor(img_gt[None]).permute(0, 3, 1, 2).to(device)
    with torch.no_grad():
        res = model(img_gt_tensor)
    print(res)
