import os
import argparse
import numpy as np
import math
from scipy import optimize
import random
import time
import torch
from run_manage import RunManager
from models import MobileNetV2_ImageNet, MobileNetV2_CIFAR10, TrainRunConfig, ResNet_ImageNet, ResNet_CIFAR10
from utils.pytorch_utils import DFS_bit
from metrics import *
import ORM
from torchsummary import summary

# Command-line interface of the bit-allocation script. Every parsed value is
# also forwarded verbatim to TrainRunConfig(**vars(args)) by main().
parser = argparse.ArgumentParser()

# ----- model config -----
parser.add_argument('--path', type=str,
                    help="run directory; main() loads "
                         "'<path>/checkpoint/model_best.pth.tar' from it")
# ``--model`` is required. The previous default ("vgg") was not one of the
# choices (argparse does not validate defaults against ``choices``) and no
# network is built for it in main(), so omitting the flag always ended in an
# unbound-variable crash.
# NOTE(review): 'mobilenet' is accepted here, but main() constructs no
# network for it.
parser.add_argument('--model', type=str, required=True,
                    choices=['resnet50', 'mobilenetv2', 'mobilenet', 'resnet18'])
# Evaluated with eval() in main() and passed to the network constructor as
# ``cfg``; only pass trusted values.
parser.add_argument('--cfg', type=str, default="None")
parser.add_argument('--manual_seed', default=0, type=int,
                    help='seed for random, numpy and torch')
parser.add_argument("--model_size", default=0, type=float,
                    help='size budget used in the optimization constraint '
                         '(same unit as the sizes printed by the script)')
parser.add_argument("--beta", default=1, type=float,
                    help='scale inside exp(-beta * score) when weighting layers')
parser.add_argument('--quant_type', type=str, default='QAT', choices=['QAT', 'PTQ'],
                    help='selects the per-layer bit search space')

# ----- dataset config -----
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'])
# presumably the dataset root consumed by TrainRunConfig -- TODO confirm
parser.add_argument('--save_path', type=str, default='/Path/to/Dataset')

# ----- runtime config -----
parser.add_argument('--gpu', help='gpu available', default='0')
parser.add_argument('--train_batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=24)
parser.add_argument("--local_rank", default=0, type=int)

# Device the sampled input batch is moved to before feature extraction.
device = "cuda" if torch.cuda.is_available() else "cpu"

def _build_net(args, run_config):
    """Construct the network selected by ``--model`` and ``--dataset``.

    Args:
        args: Parsed command-line namespace; ``model``, ``dataset`` and
            ``cfg`` are read.
        run_config: Run configuration; ``data_provider.n_classes`` supplies
            the number of output classes.

    Returns:
        The newly constructed network (no checkpoint loaded yet).

    Raises:
        ValueError: If no network is defined here for ``args.model``.
    """
    if args.model not in ('resnet50', 'resnet18', 'mobilenetv2'):
        # Previously this fell through and crashed later on an unbound ``net``.
        raise ValueError(
            'unsupported --model %r: no network is defined for it in this script'
            % (args.model,))

    n_classes = run_config.data_provider.n_classes
    # SECURITY NOTE: ``--cfg`` is evaluated with eval(), so it executes
    # arbitrary Python. It must only ever come from a trusted command line.
    # (Left as eval() on purpose: swapping in a restricted parser could reject
    # cfg expressions that currently work.)
    cfg = eval(args.cfg)
    on_imagenet = args.dataset == 'imagenet'

    if args.model == 'mobilenetv2':
        mobilenet_cls = MobileNetV2_ImageNet if on_imagenet else MobileNetV2_CIFAR10
        return mobilenet_cls(num_classes=n_classes, cfg=cfg)

    resnet_cls = ResNet_ImageNet if on_imagenet else ResNet_CIFAR10
    depth = 50 if args.model == 'resnet50' else 18
    return resnet_cls(depth=depth, num_classes=n_classes, cfg=cfg)


def _load_best_checkpoint(run_manager, run_dir):
    """Load ``<run_dir>/checkpoint/model_best.pth.tar`` into ``run_manager.net``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    best_model_path = '%s/checkpoint/model_best.pth.tar' % run_dir
    print(best_model_path)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(best_model_path):
        raise FileNotFoundError('wrong path: %s' % best_model_path)

    # Without CUDA the tensors must be remapped to the CPU; with CUDA the
    # default placement of torch.load is kept.
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(best_model_path, map_location=map_location)
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    # strict=False: missing / unexpected keys in the checkpoint are tolerated.
    run_manager.net.load_state_dict(checkpoint, strict=False)


def _solve_bit_fractions(theta, params, size_budget, quant_type, length):
    """Solve the per-layer bit-fraction allocation with SLSQP.

    Minimizes ``sum(x[i] * theta[i])`` subject to
    ``size_budget - sum(x[i] * params[i]) >= 0`` and per-layer bounds of
    (0.25, 0.5) for PTQ or (0.5, 1.0) otherwise. The caller multiplies the
    fractions by 8 to obtain bit-widths.

    Args:
        theta: Per-layer objective coefficients (indexable, >= ``length`` items).
        params: Per-layer sizes used in the constraint.
        size_budget: Budget left for the optimized layers.
        quant_type: ``'PTQ'`` or anything else (treated as QAT).
        length: Number of layers / optimization variables.

    Returns:
        The ``scipy.optimize.OptimizeResult`` of the solve.
    """
    import warnings

    def func(x, sign=1.0):
        """Objective: weighted sum of the bit fractions."""
        return sum(x[i] * sign * theta[i] for i in range(length))

    def constrain_func(x):
        """Inequality constraint: non-negative while the budget is respected."""
        used = sum(x[i] * params[i] for i in range(length))
        return np.array([size_budget - used])

    # Bit search space: (0.25, 0.5) for PTQ and (0.5, 1.0) for QAT.
    layer_bounds = (0.25, 0.5) if quant_type == 'PTQ' else (0.5, 1.0)
    bnds = tuple(layer_bounds for _ in range(length))
    cons = {'type': 'ineq', 'fun': constrain_func}

    # NOTE(review): the all-ones start point lies outside the PTQ bounds;
    # it is kept unchanged from the original script.
    result = optimize.minimize(func, x0=[1 for _ in range(length)], method='SLSQP',
                               bounds=bnds, constraints=cons)
    if not result.success:
        # E.g. an infeasible --model_size; the returned x may break the budget.
        warnings.warn('SLSQP did not converge (%s); the reported bit '
                      'configuration may violate the size budget' % (result.message,))
    return result


def main():
    """Search a per-layer bit-width configuration for a trained network.

    Steps:
      1. Parse the CLI, pin the visible GPU(s) and seed all RNGs.
      2. Build the run config and the network, then load the best checkpoint
         from ``<path>/checkpoint/model_best.pth.tar``.
      3. Run ``net.feature_extract`` on one training batch and score every
         layer with ``cal_score`` (printed as "entropy").
      4. Turn the scores into weights ``-exp(-beta * score)`` and solve a
         size-constrained linear objective with SLSQP.
      5. Round the solution to integer bit-widths and print the resulting
         model size and BOPs.

    Raises:
        ValueError: If ``--model`` has no network defined in this script.
        FileNotFoundError: If the checkpoint file is missing.
    """
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # Guarded so the script also runs on CPU-only hosts, matching the
    # module-level ``device`` fallback and the CPU checkpoint loading path.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)

    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)

    # Prepare run config: every CLI option is forwarded as a keyword argument.
    run_config = TrainRunConfig(**vars(args))
    if args.local_rank == 0:
        print('Run config:')
        for k, v in vars(args).items():
            print('\t%s: %s' % (k, v))

    net = _build_net(args, run_config)

    # Build run manager and restore the best checkpoint into its network.
    run_manager = RunManager(args.path, net, run_config)
    _load_best_checkpoint(run_manager, args.path)

    # Feature extraction on a single training batch.
    data_loader = run_manager.run_config.train_dataloader
    data = next(iter(data_loader))
    data = data[0].to(device)
    n = data.size()[0]

    start = time.time()
    with torch.no_grad():
        features = net.feature_extract(data, args.quant_type)

    # NOTE: an earlier ORM-based variant (pairwise ORM.orm over
    # ORM.gram_linear of the features) was prototyped here; the script now
    # scores layers with cal_score instead.
    length = len(features)

    # Sum every feature over dim 1; for PTQ the last feature is left untouched.
    n_reduced = length - 1 if args.quant_type == "PTQ" else length
    for i in range(n_reduced):
        features[i] = torch.sum(features[i], dim=1)

    # Per-layer score from cal_score (reported as "entropy").
    entropy = cal_score(features=features, batch_size=n)
    sum_entropy = np.sum(entropy)
    print("Sum = {}".format(sum_entropy))
    print("entropy = {}".format(entropy))
    print("entropy_norm = {}".format((entropy / sum_entropy).tolist()))

    print(len(entropy))
    print("-"*80)

    # Layer weights -e^(-beta * score); negated because the solver minimizes.
    theta = np.negative(np.array(
        [1 * math.exp(-1 * args.beta * entropy[i]) for i in range(len(entropy))]
    ))

    # Layerwise sizes and FLOPs; sizes are rescaled by 1024 * 1024
    # (presumably to the "Mb" unit used in the prints below -- TODO confirm).
    params, first_last_size = net.cfg2params_perlayer(net.cfg, length, args.quant_type)
    FLOPs, first_last_flops = net.cfg2flops_layerwise(net.cfg, length, args.quant_type)
    params = [i/(1024*1024) for i in params]
    first_last_size = first_last_size/(1024*1024)

    result = _solve_bit_fractions(theta, params, args.model_size - first_last_size,
                                  args.quant_type, length)

    # Fractions -> bit-widths (x * 8). resnet18 uses DFS_bit on the reversed
    # layer order and the result is flipped back; other models simply round.
    if args.model == "resnet18":
        prun_bitcfg, _ = DFS_bit(result.x[::-1] * 8, [params[length - i - 1] for i in range(length)])
        prun_bitcfg = [prun_bitcfg[length - i - 1] for i in range(length)]
    else:
        prun_bitcfg = np.around(result.x * 8)
    end = time.time()
    print("Use", end - start, "seconds. ")

    # Normalize to plain Python ints whatever numeric type the entries have.
    optimize_cfg = [int(b) for b in prun_bitcfg]
    print(optimize_cfg)
    print("Quantization model is", np.sum(np.array(optimize_cfg) * np.array(params) / 8) + first_last_size, "Mb")
    print("Original model is", np.sum(np.array(params)) * 4 + first_last_size * 4, "Mb")
    # NOTE(review): the factors 8*8 and 5 are kept from the original formula;
    # their meaning (presumably activation / first-last bit-widths) is not
    # visible in this file.
    print('Quantization model BOPs is',
          (first_last_flops * 8*8 + sum([FLOPs[i] * optimize_cfg[i] *5 for i in range(length)])) / 1e9)

# Script entry point: run the bit-allocation search only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()