import torch
from torchinfo import summary
    
def count_data_footprint(state_dict, bytes_per_element=4):
    """Return the storage size, in bytes, of the tensors in a state dict.

    Entries whose value is not a ``torch.Tensor`` are skipped.

    Args:
        state_dict: mapping from names to values (e.g. ``module.state_dict()``).
        bytes_per_element: bytes charged per tensor element. Defaults to 4
            (float32-sized), applied uniformly regardless of each tensor's
            dtype. Pass ``None`` to use each tensor's own ``element_size()``.

    Returns:
        Total byte count as an ``int``.
    """
    counted_bytes = 0
    for param in state_dict.values():
        if not isinstance(param, torch.Tensor):
            continue
        elem_bytes = param.element_size() if bytes_per_element is None else bytes_per_element
        # numel() is the product of the shape (1 for a 0-d scalar tensor).
        counted_bytes += param.numel() * elem_bytes
    return counted_bytes

if __name__ == '__main__':

    import argparse
    import json
    import os
    import sys

    # Make the repository root (the parent of this script's directory) the
    # first import location so the `nets`, `utils` and `memory` packages
    # resolve. Derived from __file__ rather than by trimming a fixed-length
    # suffix off sys.path[0], so it does not depend on the directory's name
    # or on how the script was launched.
    sys.path[0] = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # The default config file is named for ResNet32Large.
    DEFAULT_SLT_CONFIG = 'nets/SLTNets/configs/config__resnet32large.json'

    parser = argparse.ArgumentParser(description='Memory Measurements')
    parser.add_argument("--model", choices=['ResNet20', 'DenseNet40', 'ResNet32Large'], default='ResNet32Large')
    parser.add_argument("--config", default=DEFAULT_SLT_CONFIG,
                        help="JSON file with SLT submodel configs, relative to the working directory "
                             "(default: %(default)s)")
    parser.add_argument("--config-key", default="0.125",
                        help='top-level key of the config file whose "values" list is evaluated '
                             '(default: %(default)s)')

    args = parser.parse_args()
    print(args)

    if args.model != 'ResNet32Large' and args.config == DEFAULT_SLT_CONFIG:
        print('WARNING: --config defaults to', DEFAULT_SLT_CONFIG, 'but --model is', args.model,
              '- pass --config with a config file made for this model.', file=sys.stderr)

    # INPUT_SHAPE is presumably (batch, channels, height, width) -- TODO confirm
    # against training_mem.
    if args.model == 'ResNet20':
        from nets.SubsetNets.resnet_cifar import ResNet20 as Net
        INPUT_SHAPE = (32, 3, 32, 32)
        ratios = [1.0, 0.5, 0.25, 0.125]
    elif args.model == 'ResNet32Large':
        from nets.SubsetNets.resnet_cifar import ResNet32Large as Net
        INPUT_SHAPE = (32, 3, 64, 64)
        ratios = [1.0, 0.5, 0.25, 0.125]
    elif args.model == 'DenseNet40':
        from nets.SubsetNets.densenet_cifar import DenseNet40 as Net
        INPUT_SHAPE = (32, 3, 64, 64)
        ratios = [1.0, 0.66, 0.33]
    else:
        raise NotImplementedError(args.model)

    from memory.mem_counted import training_mem, training_mem_individual

    # Memory of the submodel techniques: one measurement per scale_factor.
    for ratio in ratios:
        mem = training_mem(Net, {'scale_factor': ratio}, input_shape=INPUT_SHAPE)
        print('ratio ', ratio, ' mem ', mem,)

    if args.model == 'ResNet20':
        from nets.SLTNets.resnet_cifar import ResNet20 as Net
        from utils.SLT_submodel import extract_submodel_resnet_structure as extract_submodel
    elif args.model == 'ResNet32Large':
        from nets.SLTNets.resnet_cifar import ResNet32Large as Net
        from utils.SLT_submodel import extract_submodel_resnet_structure as extract_submodel
    elif args.model == 'DenseNet40':
        from nets.SLTNets.densenet_cifar import DenseNet40 as Net
        from utils.SLT_submodel import extract_submodel_densenet_structure as extract_submodel
    else:
        raise NotImplementedError(args.model)

    # NOTE(review): not used by the loop below.
    sd_standard = Net(scale_factor=1.0).state_dict()

    res = []

    # Each config is a [head scale, filled-up ratio] pair. Configs may be set
    # manually here instead of being loaded from file, e.g.:
    #   configs = [
    #       [0.49, 0.023255813953488372],
    #       [0.51, 0.046511627906976744],
    #       ...
    #   ]
    with open(args.config, 'r') as fd:
        configs = json.load(fd)[args.config_key]["values"]

    freeze_dicts_list = []
    full_keys = []
    data = 0

    # Optional per-config training depths, set MANUALLY (one entry per config,
    # indexed by config position). When empty, every config uses depth 1.
    freeze_values = []
    if freeze_values and len(freeze_values) < len(configs):
        raise ValueError('freeze_values has {} entries but there are {} configs'.format(
            len(freeze_values), len(configs)))

    for config_idx, config in enumerate(configs):

            # Fresh state dict per config; presumably extract_submodel may modify
            # its input -- confirm before hoisting this out of the loop.
            sd = Net().state_dict()

            freeze_value = freeze_values[config_idx] if freeze_values else 1

            sd_reduced, indices, freeze_dict = extract_submodel(config[0], config[1], sd, training_depth=freeze_value)
            # The freeze dict travels with the reduced state dict under 'frozen'.
            sd_reduced.update({'frozen': freeze_dict})

            mem = training_mem(Net, {}, input_shape=INPUT_SHAPE, sd=sd_reduced)
            freeze_dicts_list.append(freeze_dict)

            print('Config idx: ', config_idx, 'freeze_value: ', freeze_value,
                  'head scale', config[0], 'filled up ratio: ', config[1], 'memory: ', mem)
            res.append((config[0], config[1], mem))