import os
from typing import Optional, Tuple, List, Union, Callable
from tqdm import tqdm

import math
import argparse

import time
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from tqdm import trange
import cv2 
import onnx
import itertools
from collections import defaultdict

from cifar10_resnet.resnet import resnet2b, resnet4b

from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm, PerturbationLinear

# Per-channel CIFAR-10 statistics used to standardize images.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
# Keep the statistics on the GPU when one exists, but fall back to the CPU so
# that merely importing this module does not crash on CUDA-less machines.
_stats_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mu = torch.Tensor(cifar10_mean).view(3, 1, 1).to(_stats_device)
std = torch.Tensor(cifar10_std).view(3, 1, 1).to(_stats_device)


def normalize(X):
    """Standardize ``X`` channel-wise with the CIFAR-10 mean and std.

    Args:
        X: Tensor whose channel axis broadcasts against a ``(3, 1, 1)``
            tensor, e.g. ``(N, 3, H, W)`` or ``(3, H, W)``.

    Returns:
        ``(X - mu) / std`` computed on the device ``X`` lives on.
    """
    # Move the statistics to X's device so CPU and GPU inputs both work
    # regardless of where ``mu``/``std`` were created.
    return (X - mu.to(X.device)) / std.to(X.device)

def classificaion_based_on_constant_bounds(model, x_L, x_U, batch_size, target_class):
    """Check, per input box, whether ``target_class`` provably wins.

    For every sample ``k`` the input region is the element-wise box
    ``[x_L[k], x_U[k]]``. CROWN output bounds are computed with auto_LiRPA and
    a sample counts as verified when the lower bound of the target output is
    strictly greater than the upper bound of every other output.

    Args:
        model: Network to bound (wrapped in a ``BoundedModule`` here).
        x_L: Lower corners of the input boxes, batched along dim 0.
        x_U: Upper corners of the input boxes, same shape as ``x_L``.
        batch_size: Number of boxes bounded per ``compute_bounds`` call.
        target_class: Index of the output that must dominate all others.

    Returns:
        Boolean tensor of shape ``(x_L.shape[0],)`` on ``x_L.device``; entry
        ``k`` is True iff sample ``k`` was verified.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_samples = x_L.shape[0]
    mask_verified_all = torch.zeros(num_samples, dtype=torch.bool, device=x_L.device)
    if num_samples == 0:
        return mask_verified_all

    # The tuple passed to BoundedModule is only a dummy input for tracing the
    # graph, so one batch is enough; tracing on the whole dataset would run a
    # forward pass over all samples at once.
    lirpa_model = BoundedModule(model, (x_L[:batch_size],))

    # Ceil division: the final, possibly smaller, batch must be processed too.
    # Flooring here would leave the trailing samples permanently "unverified".
    num_batches = (num_samples + batch_size - 1) // batch_size
    mask_other = None

    # The bounds are only compared, never differentiated.
    with torch.no_grad():
        for i in range(num_batches):
            start = i * batch_size
            stop = min(start + batch_size, num_samples)
            x_L_batch = x_L[start:stop]
            x_U_batch = x_U[start:stop]
            ptb = PerturbationLpNorm(x_L=x_L_batch, x_U=x_U_batch)
            bounded_x = BoundedTensor(x_L_batch, ptb)
            lb, ub = lirpa_model.compute_bounds(bounded_x, method="crown")

            if mask_other is None:
                # Selects every output column except the target one; it is
                # identical for all batches, so build it once.
                mask_other = torch.ones(lb.size(1), dtype=torch.bool, device=lb.device)
                mask_other[target_class] = False

            # Verified iff lb[target] > ub[j] for every other output j.
            mask_verified = (lb[:, target_class].unsqueeze(-1) > ub[:, mask_other]).all(dim=1)
            mask_verified_all[start:stop] = mask_verified
            print(f"Batch {i + 1}/{num_batches} processed.")
            print(f"Verified rate: {mask_verified.float().mean().item() * 100:.2f}%")
    return mask_verified_all

if __name__ == "__main__":
    # Script entry point: load a CIFAR-10 ResNet and a file of per-pixel input
    # bounds, verify each bounded region for the target class, then save the
    # per-sample verification mask.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='resnet2b', help='model name')
    parser.add_argument('--data', type=str, default='res_classification_truck.npz', help='dataset name')
    parser.add_argument('--target_class', type=int, default=9, help='target class for classification (0 for plane)')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('--output', type=str, default='all_results.npz',
                        help='path of the .npz file the verification mask is written to')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model
    if args.model == 'resnet2b':
        model = resnet2b()
    elif args.model == 'resnet4b':
        model = resnet4b()
    else:
        raise ValueError("Unsupported model")
    # Load the pre-trained weights. map_location lets a checkpoint that was
    # saved from a GPU be restored when this script falls back to the CPU.
    checkpoint = torch.load(f'./cifar10_resnet/{args.model}.pth', map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    # Verification must see inference-time behaviour of any train/eval
    # dependent layers.
    model.eval()

    # Load image bounds; the context manager closes the underlying file.
    # The arrays are expected to have shape (N, 32, 32, 3) with values in
    # [0, 255].
    with np.load(args.data) as data:
        x_L = data['output_lb']
        x_U = data['output_ub']
    x_L = torch.from_numpy(x_L).to(torch.float32).to(device)
    x_U = torch.from_numpy(x_U).to(torch.float32).to(device)
    x_L = x_L.permute(0, 3, 1, 2)/255.0  # Change to (N, C, H, W), scale to [0, 1]
    x_U = x_U.permute(0, 3, 1, 2)/255.0  # Change to (N, C, H, W), scale to [0, 1]
    # Dividing by the positive per-channel std keeps x_L <= x_U element-wise.
    # After this step the tensors are standardized, i.e. no longer in [0, 1].
    x_L = normalize(x_L)
    x_U = normalize(x_U)

    # Perform classification based on constant bounds
    mask_verified = classificaion_based_on_constant_bounds(
        model, x_L, x_U, args.batch_size, args.target_class)
    print(mask_verified.shape)
    np.savez_compressed(args.output, all_results=mask_verified.detach().cpu().numpy())
    print(f"\nTotal verified rate: {mask_verified.float().mean().item() * 100:.2f}%")
