import torch
import torchvision
import time
import argparse
import hubconf
from quant.fold_bn import search_fold_and_reset_bn
from quant import *
import torch.nn as nn
from data.imagenet import build_imagenet_data

def get_train_samples(train_dataloader, num_samples):
    """Collect the first ``num_samples`` inputs from a data loader.

    Each batch is indexed with ``batch[0]`` (any further elements, presumably
    labels, are ignored). The collected tensors are concatenated along dim 0
    and truncated to ``num_samples`` rows.

    Args:
        train_dataloader: iterable of batches whose first element is a tensor.
        num_samples: number of rows to return.

    Returns:
        A tensor with at most ``num_samples`` rows (fewer only if the loader is
        exhausted first).

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    train_data = []
    collected = 0
    for batch in train_dataloader:
        inputs = batch[0]
        train_data.append(inputs)
        # Count the rows actually gathered. Multiplying the number of batches by
        # the *current* batch size is wrong when batch sizes differ, and can stop
        # before num_samples rows have been collected.
        collected += inputs.size(0)
        if collected >= num_samples:
            break
    if not train_data:
        raise ValueError('train_dataloader yielded no batches')
    return torch.cat(train_data, dim=0)[:num_samples]

# Command-line options. ArgumentDefaultsHelpFormatter appends each option's default to its --help text.
parser = argparse.ArgumentParser(description='running parameters', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# general parameters for data and model
parser.add_argument('--batch_size', default=256, type=int, help='mini-batch size for data loader')
parser.add_argument('--workers', default=24, type=int, help='number of workers for data loader')
parser.add_argument('--data_path', default='', type=str, help='path to ImageNet data', required=True)

# quantization parameters
parser.add_argument('--n_bits_w', default=4, type=int, help='bitwidth for weight quantization')
parser.add_argument('--channel_wise', action='store_true', help='apply channel_wise quantization for weights')
parser.add_argument('--n_bits_a', default=4, type=int, help='bitwidth for activation quantization')
parser.add_argument('--act_quant', action='store_true', help='apply activation quantization')
parser.add_argument('--disable_8bit_head_stem', action='store_true',
                    help='do not force the first and last layers to 8-bit')
# NOTE(review): --test_before_calibration is not read anywhere in this script.
parser.add_argument('--test_before_calibration', action='store_true')
parser.add_argument('--bit_cfg', type=str, default="None",
                    help='Python expression evaluated with eval() and passed to '
                         'QuantModel.set_mixed_precision; "None" skips mixed precision')

# weight calibration parameters
parser.add_argument('--num_samples', default=1024, type=int, help='size of the calibration dataset')
parser.add_argument('--iters_w', default=20000, type=int, help='number of iteration for adaround')
parser.add_argument('--weight', default=0.01, type=float, help='weight of rounding cost vs the reconstruction loss.')
# NOTE(review): --sym is not read anywhere in this script; the calibration kwargs hard-code asym=True.
parser.add_argument('--sym', action='store_true', help='symmetric reconstruction, not recommended')
parser.add_argument('--b_start', default=20, type=int, help='temperature at the beginning of calibration')
parser.add_argument('--b_end', default=2, type=int, help='temperature at the end of calibration')
parser.add_argument('--warmup', default=0.2, type=float, help='in the warmup period no regularization is applied')
# NOTE(review): --step is not read anywhere in this script.
parser.add_argument('--step', default=20, type=int, help='record snn output per step')
parser.add_argument('--use_bias', action='store_true', help='fix weight bias and variance after quantization')
parser.add_argument('--vcorr', action='store_true', help='use variance correction')
parser.add_argument('--bcorr', action='store_true', help='use bias correction')

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Previously trained checkpoint whose final FC layer is printed below.
# NOTE(review): machine-specific absolute path; consider exposing it as a CLI option.
CHECKPOINT_PATH = "/home/admin1/Syh/Training-free-quant/PTQ/result/imagenet/resnet18/Test/checkpoint/model_best.pth.tar"

# map_location remaps stored tensors onto `device`. Without it, a checkpoint
# saved from GPU tensors cannot be loaded on a CPU-only host.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
state_dict = checkpoint['state_dict']

# Debug output: the stored FC weight and its weight_quantizer.alpha tensor. The
# same two tensors are printed for the calibrated model near the end of the script.
print(state_dict['model.fc.weight'])
print(state_dict['model.fc.weight_quantizer.alpha'])
print('-'*150)
# Debug pause so the output above can be read before data loading starts.
time.sleep(10)

# build imagenet data loader
train_dataloader, val_dataloader, test_dataloader = build_imagenet_data(data_path=args.data_path, batch_size=args.batch_size, workers=args.workers)

# Full-precision ResNet-18 from the local hubconf module. This calls the
# function directly instead of eval() on a formatted string, which did the same thing.
model = hubconf.resnet18(pretrained=True)
model.to(device)


# build quantization parameters
wq_params = {'n_bits': args.n_bits_w, 'channel_wise': args.channel_wise, 'scale_method': 'mse'}
aq_params = {'n_bits': args.n_bits_a, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.act_quant}
qnn = QuantModel(model=model, weight_quant_params=wq_params, act_quant_params=aq_params)

if not args.disable_8bit_head_stem:
    print('Setting the first and the last layer to 8-bit')
    qnn.set_first_last_layer_to_8bit()

if args.bit_cfg != "None":
    print('Setting each layer to different bit')
    # NOTE(review): eval() runs arbitrary code taken from the command line. If
    # --bit_cfg is always a plain literal (e.g. a list), ast.literal_eval would
    # be a safer drop-in; confirm the expected format before switching.
    qnn.set_mixed_precision(eval(args.bit_cfg))

def recon_model(model: nn.Module, quant_model=None, recon_kwargs=None):
    """
    Block reconstruction. For the first and last layers, we can only apply layer reconstruction.

    Walks the children of ``model`` recursively:
      * a ``QuantModule`` is passed to ``layer_reconstruction``;
      * a ``BaseQuantBlock`` is passed to ``block_reconstruction``;
      * any other module is descended into.
    A module whose ``ignore_reconstruction`` is True is skipped, and the skip is printed.

    Args:
        model: module whose children are reconstructed (recursively).
        quant_model: quantized model passed as the first argument to the
            reconstruction routines. Defaults to the module-level ``qnn``, as before.
        recon_kwargs: keyword arguments forwarded to the reconstruction
            routines. Defaults to the module-level ``kwargs``, as before.
    """
    # Fall back to the module globals so existing `recon_model(qnn)` calls keep working.
    # The globals are looked up only at call time, so they must be defined by then.
    if quant_model is None:
        quant_model = qnn
    if recon_kwargs is None:
        recon_kwargs = kwargs

    for name, module in model.named_children():
        if isinstance(module, QuantModule):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of layer {}'.format(name))
            else:
                print('Reconstruction for layer {}'.format(name))
                layer_reconstruction(quant_model, module, **recon_kwargs)
        elif isinstance(module, BaseQuantBlock):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of block {}'.format(name))
            else:
                print('Reconstruction for block {}'.format(name))
                block_reconstruction(quant_model, module, **recon_kwargs)
        else:
            # Forward the resolved arguments so the recursion uses the same
            # quantized model and settings as the top-level call.
            recon_model(module, quant_model, recon_kwargs)

# Draw args.num_samples training inputs to use as the calibration set.
cali_data = get_train_samples(train_dataloader, num_samples=args.num_samples)

# Initialize weight quantization parameters
# Positional args are presumably (weight_quant, act_quant), matching the keyword
# call further below: weight quantization on, activation quantization off. One
# forward pass on the first 256 calibration samples performs the initialization.
qnn.set_quant_state(True, False)
_ = qnn(cali_data[:256].to(device))


# Kwargs for weight rounding calibration
# recon_model() reads this module-level `kwargs` (and `qnn`) as globals, so this
# must be defined before recon_model(qnn) runs.
# NOTE(review): asym is hard-coded to True and args.sym is never consulted.
kwargs = dict(cali_data=cali_data, iters=args.iters_w, weight=args.weight, asym=True,
                  b_range=(args.b_start, args.b_end), warmup=args.warmup, act_quant=False, opt_mode='mse')

# for name, _ in qnn.named_parameters():
#     print(name)
# print(qnn)
# print(qnn.model.layer4[1].conv2.weight_quantizer)
# time.sleep(10)

# Reconstruct every QuantModule / BaseQuantBlock in the quantized model.
recon_model(qnn)
# Keep weight-only quantization and apply the bias/variance correction flags.
qnn.set_quant_state(weight_quant=True, act_quant=False)
qnn.set_bias_state(args.use_bias, args.vcorr, args.bcorr)

# Debug output: the calibrated FC weight and its weight_quantizer.alpha. The same
# tensors are printed from the loaded checkpoint at the start of the script.
print(qnn.model.fc.weight)
print(qnn.model.fc.weight_quantizer.alpha)
# NOTE(review): ~17-minute debug pause with no code after it; remove once inspection is done.
time.sleep(1000)

# model_state_dict = model.state_dict()
# for name in model_state_dict.keys():
#     print(name)
# time.sleep(10)
# print(model_state_dict['layer1.0.conv1.weight'].shape)

# Strip the 'model.' prefix
# new_state_dict = {}
# for k, v in state_dict.items():
#     if k.startswith('model.'):
#         # Strip the 'model.' prefix
#         new_key = k[len('model.'):]
#         # Keep only parameters that do not contain '.weight_quantizer.alpha'
#         if '.weight_quantizer.alpha' not in new_key:
#             new_state_dict[new_key] = v
#     else:
#         # Keep the original key, provided it does not contain '.weight_quantizer.alpha'
#         if '.weight_quantizer.alpha' not in k:
#             new_state_dict[k] = v

# Filter out parameters that are not needed
# filtered_state_dict = {k: v for k, v in new_state_dict.items() if k in model_state_dict }







