import torch
import torch.nn as nn
import argparse
import os
import random
import numpy as np
import time
import hubconf
from data.imagenet import build_imagenet_data
from quant import *
from metrics import *
import statistics
import numpy as np
import logging
import shutil
from quant.fold_bn import search_fold_and_reset_bn, search_fold_and_remove_bn

def seed_all(seed=1029):
    """Seed every random number generator this script relies on.

    Covers Python's ``random``, the ``PYTHONHASHSEED`` environment variable,
    NumPy, and PyTorch (CPU plus the current and all CUDA devices), then
    switches cuDNN to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def hook_fn_forward(module, input, output):
    """Forward hook: remember the latest output produced by ``module``.

    The output is stored in the module-level ``feature_maps`` mapping under
    the module object itself, replacing whatever an earlier forward pass
    stored there. ``input`` is accepted to match the hook signature but is
    not used.
    """
    feature_maps.update({module: output})

class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self, name, fmt=':f'):
        # ``fmt`` is a format-spec suffix (e.g. ':6.3f') used by __str__.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the last value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds '{name} {val<fmt>} ({avg<fmt>})' and fills it from the
        # instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))
    
class ProgressMeter(object):
    """Prints a one-line progress report made of a batch counter and meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # The counter slot is padded to the width of the total batch count,
        # so '[  7/100]' lines up with '[100/100]'.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[{}/{}]'.format(slot, slot.format(num_batches))

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages for each ``k`` in ``topk``.

    Args:
        output: Score tensor; the top-k search runs along dimension 1.
        target: Tensor of ground-truth class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, holding the
        percentage of rows whose target is among the ``k`` highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_rows = target.size(0)

        # Indices of the best `largest_k` scores, transposed so that row r
        # holds every sample's r-th best guess.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_rows)
            for k in topk
        ]
    
@torch.inference_mode()
def validate_model(val_dataloader, model, device=None, print_freq=100, step=0):
    """Evaluate ``model`` on ``val_dataloader`` and report top-1/top-5 accuracy.

    Args:
        val_dataloader: Iterable of ``(images, labels)`` batches; must support ``len``.
        model: Module to evaluate. It is switched to eval mode.
        device: Device to run on. When given, the model is moved there;
            when ``None``, the device of the model's first parameter is used.
        print_freq: A progress line is printed every ``print_freq`` batches.
        step: Step index included in the summary line written via ``logging``.

    Returns:
        The running top-1 accuracy (``AverageMeter.avg``) over all batches.
    """
    if device is not None:
        model.to(device)
    else:
        device = next(model.parameters()).device

    timer = AverageMeter('Time', ':6.3f')
    acc_top1 = AverageMeter('Acc@1', ':6.2f')
    acc_top5 = AverageMeter('Acc@5', ':6.2f')
    reporter = ProgressMeter(
        len(val_dataloader),
        [timer, acc_top1, acc_top5],
        prefix='Test: '
    )

    model.eval()

    tick = time.time()
    for batch_idx, (images, labels) in enumerate(val_dataloader):
        images, labels = images.to(device), labels.to(device)

        logits = model(images)

        # Fold this batch's accuracies into the running means, weighted by batch size.
        batch_top1, batch_top5 = accuracy(logits, labels, topk=(1, 5))
        n_images = images.size(0)
        acc_top1.update(batch_top1[0], n_images)
        acc_top5.update(batch_top5[0], n_images)

        # Wall-clock time spent on this batch (including data loading).
        timer.update(time.time() - tick)
        tick = time.time()

        if batch_idx % print_freq == 0:
            reporter.display(batch_idx)

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=acc_top1, top5=acc_top5))
    logging.info('step: {step}  * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(step=step, top1=acc_top1, top5=acc_top5))

    return acc_top1.avg

def get_train_samples(train_dataloader, num_samples):
    """Gather the first ``num_samples`` inputs from ``train_dataloader``.

    Only element 0 of each batch is kept. Iteration stops as soon as the
    number of batches seen times the current batch's size reaches
    ``num_samples``; the concatenation is then trimmed to ``num_samples``
    along dimension 0 (fewer are returned if the loader runs out first).
    """
    collected = []
    for n_batches, batch in enumerate(train_dataloader, start=1):
        inputs = batch[0]
        collected.append(inputs)
        if n_batches * inputs.size(0) >= num_samples:
            break
    stacked = torch.cat(collected, dim=0)
    return stacked[:num_samples]

def save_checkpoint(state, is_best, filename=None):
    """Write ``state`` to ``<filename>/checkpoint.pth.tar``.

    Despite its name, ``filename`` is the *directory* that receives the
    checkpoint; it is created (including parents) when missing.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When truthy, the freshly written checkpoint is also copied
            to ``<filename>/model_best.pth.tar``.
        filename: Target directory. The ``None`` default is not usable and
            raises ``TypeError``; callers must pass a path.
    """
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(filename, exist_ok=True)

    checkpoint_path = os.path.join(filename, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(filename, 'model_best.pth.tar'))

# Best top-1 accuracy seen so far; main() declares it global and raises it
# after each validation pass.
best_acc1 = 0
# Most recent forward output of every hooked module, keyed by the module
# object; written by hook_fn_forward and read back in main().
feature_maps = {}

def _parse_bit_cfg(bit_cfg):
    """Turn the raw ``--bit_cfg`` string into a per-layer bit-width config.

    Args:
        bit_cfg: Either the sentinel string ``"None"`` (no mixed precision)
            or a Python list/tuple literal such as ``"[8,4,4,8]"``.

    Returns:
        A ``(parsed, bit_list)`` pair. ``parsed`` is the literal as written
        (``None`` for the sentinel); ``bit_list`` is an independent list of
        ints (empty for the sentinel) that the caller may modify.
    """
    import ast  # local import: the file's import header does not provide ast

    if bit_cfg == "None":
        return None, []
    # literal_eval accepts only literals, unlike eval() on a CLI string.
    parsed = ast.literal_eval(bit_cfg)
    return parsed, [int(bits) for bits in parsed]


def _report_entropy(qnn, data, bit_config_list, batch_size):
    """Run ``data`` through ``qnn`` and print/log the entropy of the hooked outputs.

    The forward pass refreshes the module-level ``feature_maps``. Outputs
    that are not 2-D are summed over dimension 1 before being handed to
    ``cal_score`` (a project helper; its result is treated as array-like).
    The bit-weighted sum is reported only when a per-layer bit list exists.
    """
    with torch.no_grad():
        _ = qnn(data)

    features = [
        fmap if fmap.dim() == 2 else torch.sum(fmap, dim=1)
        for fmap in feature_maps.values()
    ]
    entropy = cal_score(features=features, batch_size=batch_size)
    sum_entropy = np.sum(entropy)

    messages = []
    if bit_config_list:
        # Indexing (rather than zip) keeps the IndexError when fewer bit-widths
        # than entropy entries were supplied.
        weighted = sum(bit_config_list[idx] * score for idx, score in enumerate(entropy))
        messages.append("Sum of bit * entropy = {}".format(weighted))
    messages.append("Sum of entropy = {}".format(sum_entropy))
    messages.append("entropy = {}".format(entropy))
    messages.append("entropy_norm = {}".format((entropy / sum_entropy).tolist()))

    for message in messages:
        print(message)
    for message in messages:
        logging.info(message)


def main():
    """Command-line entry point for the quantization experiment.

    Steps, in order:
      1. Parse arguments, create ``<save_path>/<arch>/<name>/{logs,checkpoint}``
         and configure logging to ``logs/log.txt`` plus a stream handler.
      2. Build the ImageNet loaders, load the pretrained model from ``hubconf``
         and wrap it in ``QuantModel``; register forward hooks on the
         ``QuantModule`` layers whose name contains ``conv1``, ``conv2`` or ``fc``.
      3. Optionally set the first/last layer to 8 bit and apply ``--bit_cfg``.
      4. Run 100 rounds of weight reconstruction; after each round validate,
         save a checkpoint, and report entropy whenever top-1 improves.
      5. Optionally calibrate activation quantization and validate once more.
      6. Print and log the mean and sample standard deviation of the
         per-round top-1 accuracies.

    Updates the module-level ``best_acc1``.
    """
    global best_acc1
    parser = argparse.ArgumentParser(description='running parameters', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # general parameters for data and model
    parser.add_argument('--seed', default=1005, type=int, help='random seed for results reproduction')
    parser.add_argument('--arch', default='resnet18', type=str, help='model architecture', choices=['resnet18', 'mobilenetv2'])
    parser.add_argument('--batch_size', default=256, type=int, help='mini-batch size for data loader')
    parser.add_argument('--workers', default=24, type=int, help='number of workers for data loader')
    parser.add_argument('--data_path', default='', type=str, help='path to ImageNet data', required=True)
    parser.add_argument('--save_path', type=str, default='/home/admin1/Syh/Training-free-quant/PTQ/result/imagenet')
    parser.add_argument('--name', type=str, default='3.0Mb_batch=64_a8')

    # quantization parameters
    parser.add_argument('--n_bits_w', default=4, type=int, help='bitwidth for weight quantization')
    parser.add_argument('--channel_wise', action='store_true', help='apply channel_wise quantization for weights')
    parser.add_argument('--n_bits_a', default=4, type=int, help='bitwidth for activation quantization')
    parser.add_argument('--act_quant', action='store_true', help='apply activation quantization')
    parser.add_argument('--disable_8bit_head_stem', action='store_true')
    parser.add_argument('--test_before_calibration', action='store_true')
    parser.add_argument('--bit_cfg', type=str, default="None")

    # weight calibration parameters
    parser.add_argument('--num_samples', default=1024, type=int, help='size of the calibration dataset')
    parser.add_argument('--iters_w', default=20000, type=int, help='number of iteration for adaround')
    parser.add_argument('--weight', default=0.01, type=float, help='weight of rounding cost vs the reconstruction loss.')
    parser.add_argument('--sym', action='store_true', help='symmetric reconstruction, not recommended')
    parser.add_argument('--b_start', default=20, type=int, help='temperature at the beginning of calibration')
    parser.add_argument('--b_end', default=2, type=int, help='temperature at the end of calibration')
    parser.add_argument('--warmup', default=0.2, type=float, help='in the warmup period no regularization is applied')
    parser.add_argument('--step', default=20, type=int, help='record snn output per step')
    parser.add_argument('--use_bias', action='store_true', help='fix weight bias and variance after quantization')
    parser.add_argument('--vcorr', action='store_true', help='use variance correction')
    parser.add_argument('--bcorr', action='store_true', help='use bias correction')

    # activation calibration parameters
    parser.add_argument('--iters_a', default=5000, type=int, help='number of iteration for LSQ')
    parser.add_argument('--lr', default=4e-4, type=float, help='learning rate for LSQ')
    parser.add_argument('--p', default=2.4, type=float, help='L_p norm minimization for LSQ')

    args = parser.parse_args()
    args.save_path = os.path.join(args.save_path, args.arch, args.name)

    # Creating the two leaf directories also creates save_path itself.
    for sub_dir in ('logs', 'checkpoint'):
        os.makedirs(os.path.join(args.save_path, sub_dir), exist_ok=True)

    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%d-%b-%y %H:%M:%S',
                        filename=os.path.join(args.save_path, 'logs', 'log.txt'))
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info(args)

    # mixed_precision_cfg goes to the model untouched; bit_config_list is the
    # working copy used for the entropy report (and may get its ends set to 8).
    mixed_precision_cfg, bit_config_list = _parse_bit_cfg(args.bit_cfg)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True

    # build imagenet data loader
    train_dataloader, val_dataloader, test_dataloader = build_imagenet_data(data_path=args.data_path, batch_size=args.batch_size, workers=args.workers)

    # One fixed test batch, reused for every entropy report. It follows the
    # selected device so the script also runs without CUDA.
    data = next(iter(test_dataloader))
    data = data[0].to(device)

    # load model (args.arch is restricted by argparse `choices`)
    model = getattr(hubconf, args.arch)(pretrained=True)
    model.to(device)
    model.eval()

    # build quantization parameters
    wq_params = {'n_bits': args.n_bits_w, 'channel_wise': args.channel_wise, 'scale_method': 'mse'}
    aq_params = {'n_bits': args.n_bits_a, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.act_quant}
    qnn = QuantModel(model=model, weight_quant_params=wq_params, act_quant_params=aq_params)

    for name, module in qnn.named_modules():
        # Register hooks only on QuantModule layers.
        if isinstance(module, QuantModule):
            if 'conv1' in name or 'conv2' in name or 'fc' in name:
                module.register_forward_hook(hook_fn_forward)

    qnn.to(device)
    qnn.eval()

    if not args.disable_8bit_head_stem:
        print('Setting the first and the last layer to 8-bit')
        qnn.set_first_last_layer_to_8bit()
        if bit_config_list:
            bit_config_list[0] = 8
            bit_config_list[-1] = 8
        print(bit_config_list)

    if mixed_precision_cfg is not None:
        print('Setting each layer to different bit')
        qnn.set_mixed_precision(mixed_precision_cfg)

    def recon_model(model: nn.Module):
        """
        Block reconstruction. For the first and last layers, we can only apply layer reconstruction.

        Reads `kwargs` from the enclosing scope at call time, so the weight
        and activation phases below can rebind it between calls.
        """
        for name, module in model.named_children():
            if isinstance(module, QuantModule):
                if module.ignore_reconstruction is True:
                    print('Ignore reconstruction of layer {}'.format(name))
                    continue
                else:
                    print('Reconstruction for layer {}'.format(name))
                    layer_reconstruction(qnn, module, **kwargs)
            elif isinstance(module, BaseQuantBlock):
                if module.ignore_reconstruction is True:
                    print('Ignore reconstruction of block {}'.format(name))
                    continue
                else:
                    print('Reconstruction for block {}'.format(name))
                    block_reconstruction(qnn, module, **kwargs)
            else:
                recon_model(module)

    cali_data = get_train_samples(train_dataloader, num_samples=args.num_samples)

    # Initialize weight quantization parameters
    qnn.set_quant_state(True, False)
    _ = qnn(cali_data[:256].to(device))

    if args.test_before_calibration:
        print('Quantized accuracy before brecq: {}'.format(validate_model(val_dataloader, qnn)))
        _report_entropy(qnn, data, bit_config_list, args.batch_size)

    # Kwargs for weight rounding calibration
    kwargs = dict(cali_data=cali_data, iters=args.iters_w, weight=args.weight, asym=True,
                  b_range=(args.b_start, args.b_end), warmup=args.warmup, act_quant=False, opt_mode='mse')

    weight_quant_acc_top1 = []
    num_rounds = 100
    for step in range(1, num_rounds + 1):
        print("Step {}".format(step))
        # Start calibration
        recon_model(qnn)
        qnn.set_quant_state(weight_quant=True, act_quant=False)
        qnn.set_bias_state(args.use_bias, args.vcorr, args.bcorr)
        acc1 = validate_model(val_dataloader, qnn, step=step, device=device)
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if is_best:
            _report_entropy(qnn, data, bit_config_list, args.batch_size)

        weight_quant_acc_top1.append(acc1.item())

        save_checkpoint({
            'arch': args.arch,
            'state_dict': qnn.state_dict(),
        }, is_best, os.path.join(args.save_path, 'checkpoint'))

        # `step` is not reused by any inner loop, so this reports the real round.
        print('{} - Weight quantization accuracy: {}'.format(step, acc1.item()))
        qnn.set_bias_state(False, False, False)

    if args.act_quant:
        # Initialize activation quantization parameters
        qnn.set_quant_state(True, True)
        with torch.no_grad():
            _ = qnn(cali_data[:256].to(device))
        # Disable output quantization because network output
        # does not get involved in further computation
        qnn.disable_network_output_quantization()
        # Kwargs for activation rounding calibration
        kwargs = dict(cali_data=cali_data, iters=args.iters_a, act_quant=True, opt_mode='mse', lr=args.lr, p=args.p)
        recon_model(qnn)
        qnn.set_quant_state(weight_quant=True, act_quant=True)
        qnn.set_bias_state(args.use_bias, args.vcorr, args.bcorr)
        acc1 = validate_model(val_dataloader, qnn)
        print('Full quantization (W{}A{}) accuracy: {}'.format(args.n_bits_w, args.n_bits_a, acc1))
        logging.info('Full quantization (W{}A{}) accuracy: {}'.format(args.n_bits_w, args.n_bits_a, acc1))
        qnn.set_bias_state(False, False, False)

    print("-"*100)

    print("Acc@1 Result: {}".format(weight_quant_acc_top1))
    logging.info("Acc@1 Result: {}".format(weight_quant_acc_top1))

    average_top1 = statistics.mean(weight_quant_acc_top1)
    std_1 = np.std(weight_quant_acc_top1, ddof=1)  # sample standard deviation
    print("Acc@1 : Average={:.2f}, Std={:.2f}".format(average_top1, std_1))
    logging.info("Acc@1 : Average={:.2f}, Std={:.2f}".format(average_top1, std_1))

# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()