import torch
import torchvision
import torchvision.models as models
from pytorchcv.model_provider import get_model as ptcv_get_model
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
import time
from utils import *

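# Quantization tensors of interest (names as they presumably appear in the
# quantized checkpoint / quantized model — verify against the checkpoint keys):
# convbn_scaling_factor
# fc_scaling_factor
# weight_integer
# bias_integer
# act_scaling_factor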

# Experiment inputs pulled out as constants so they can be changed without
# editing the load calls below.
CHECKPOINT_PATH = (
    "/home/admin1/Syh/Training-free-quant/QAT/result/"
    "resnet18_modelsize_6.7_a8_97B/checkpoint/quantized_checkpoint.pth.tar"
)
# Fall back to CPU so the script still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the quantized checkpoint onto the same device as the model below, so its
# tensors can be compared directly (e.g. with torch.equal) against the model's
# parameters regardless of which device the checkpoint was saved from.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint1 = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# print(checkpoint1['bias_integer'])
# for key in checkpoint1:
#     print(key)

# # print(checkpoint1['model_state_dict'])
# model_state_dict = checkpoint1['model_state_dict']
# for name in model_state_dict.keys():
#     print(name)

# Pretrained pytorchcv ResNet-18, loaded for inspection alongside the checkpoint.
model = ptcv_get_model("resnet18", pretrained=True).to(DEVICE)
# Public submodule attribute access instead of the private ``_modules`` dict.
print(model.features.init_block)
# for name, module in model.named_modules():
#     print(name)

# tensor1 = checkpoint1['model_state_dict']['stage4.unit2.quant_convbn2.conv.weight']
# print(tensor1.shape)

# tensor2 = model.features.stage4.unit2.body.conv2.conv.weight
# print(tensor2.shape)
# print(torch.equal(tensor1, tensor2))

# print(tensor1)
# print("-"*100)
# print(tensor2)
# features = getattr(model, 'features')
# init_block = getattr(features, 'init_block')

# weights = models.ResNet18_Weights.DEFAULT
# model = models.resnet18(weights=weights)
# for name, param in model.named_parameters():
#     print(name)
# print("-"*100)
# for name, module in model.named_modules():
#     print(name)

# model = q_resnet18(model)
# for name, param in model.named_parameters():
#     print(name)

# conv1 = getattr(model, 'quant_init_block_convbn')
# print(conv1.conv.weight)

# print(torch.equal(conv1.conv.weight, init_block.conv.conv.weight))

# model.eval()
# model = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']], inplace=True)
# print(model.qconfig)

# backbone = ptcv_get_model("resnet18", pretrained=True)
# print(backbone.output.out_features)

# model = ptcv_get_model("resnet18", pretrained=True)
# # print(model.output.weight)
# model = q_resnet18(model)
# # for name, module in model.named_modules():
# #     print(name)
# #     print("-"*80)
# #     print(module)
# print(model.quant_init_block_convbn.weight_integer)

# model = ptcv_get_model("mobilenetv2_w1", pretrained=True)
# model = q_mobilenetv2_w1(model)
# print(model.features)
# print("-"*100)

# weights = torchvision.models.MobileNet_V2_Weights.DEFAULT
# model = torchvision.models.mobilenet_v2(weights=weights)
# print(model)