import torch
import sys
sys.path.append('./AdjoinedNetwork')
from model import xresnet50, xresnet101, xresnet18
from model import resnet152 as xresnet152
import time

def get_improvement_factor(batch_size, image_size, model, maodel_compressed, oncuda, number_iterations):
    """Return how many times faster ``maodel_compressed`` runs a forward pass than ``model``.

    Both networks receive the same random input tensor
    ``torch.rand(batch_size, 3, image_size, image_size)`` on each of
    ``number_iterations`` iterations. The summed forward-pass wall-clock times
    are then compared.

    Args:
        batch_size: number of images in each random input batch.
        image_size: height and width of each random input image.
        model: the reference network; called as ``model(inp)``.
        maodel_compressed: the network compared against ``model``; called as
            ``maodel_compressed(inp)``. The misspelled name is kept so keyword
            callers keep working.
        oncuda: if True, the networks and inputs are moved with ``.cuda()``.
            For ``nn.Module`` this moves the caller's modules in place.
        number_iterations: number of timed forward passes per network.

    Returns:
        ``total time of model / total time of maodel_compressed`` as a float.
        Values above 1 mean the compressed network is faster.

    Raises:
        ZeroDivisionError: if the compressed network's total time is zero,
            e.g. when ``number_iterations`` is 0.
    """

    def _time_forward(net, inp):
        # Return the wall-clock seconds taken by a single ``net(inp)`` call.
        if oncuda:
            # CUDA kernels launch asynchronously. Drain earlier work so it is
            # not billed to this pass.
            torch.cuda.synchronize()
        start = time.perf_counter()
        net(inp)
        if oncuda:
            # Wait for this pass to actually finish before stopping the clock.
            torch.cuda.synchronize()
        return time.perf_counter() - start

    if oncuda:
        # Moving the networks is loop-invariant, so do it once up front.
        model = model.cuda()
        maodel_compressed = maodel_compressed.cuda()

    model_time = 0.
    compressed_time = 0.
    # Inference timing only: exclude autograd graph construction from the measurement.
    with torch.no_grad():
        for _ in range(number_iterations):
            inp = torch.rand(batch_size, 3, image_size, image_size)
            if oncuda:
                inp = inp.cuda()
            model_time += _time_forward(model, inp)
            compressed_time += _time_forward(maodel_compressed, inp)

    return model_time / compressed_time

def count_parameters(model, random_mask=False, mask_factor=0.9, *,
                     skip_prefixes=('10',), masked_prefixes=('5', '6', '7')):
    """Count the trainable parameters of ``model``, optionally discounting masked layers.

    Parameters are visited via ``model.named_parameters()``. Parameters with
    ``requires_grad=False`` are ignored. So are those whose name starts with any
    of ``skip_prefixes``.

    Args:
        model: the module whose parameters are counted.
        random_mask: if True, parameters whose name starts with any of
            ``masked_prefixes`` only contribute ``(1 - mask_factor)`` of their
            element count. This simulates a mask that removes ``mask_factor`` of
            them.
        mask_factor: fraction of elements removed by the simulated mask.
        skip_prefixes: parameter-name prefixes excluded from the count. The
            default ``('10',)`` presumably targets the last top-level child of
            the sequential xresnet models (e.g. a classifier head). TODO confirm.
        masked_prefixes: parameter-name prefixes affected by ``random_mask``.

    Matching is by plain string prefix, so ``'10'`` would also match a
    ``'100.weight'`` parameter.

    Returns:
        The parameter count. It is an ``int`` unless masking discounted some
        parameters, in which case it is a ``float`` (no rounding is applied).
    """
    total = 0
    for name, param in model.named_parameters():
        if name.startswith(skip_prefixes) or not param.requires_grad:
            continue
        n_elems = param.numel()
        if random_mask and name.startswith(masked_prefixes):
            n_elems = (1 - mask_factor) * n_elems
        total += n_elems
    return total

def analyze_size_with_random_masking(ratios=(2, 4, 16), mask_factor=0.5):
    """Print the parameter-count reduction of compressed xresnet18 models under random masking.

    The baseline is an uncompressed ``xresnet18()`` counted without masking.
    For each ratio, an ``xresnet18(compression_factor=ratio)`` is counted with
    ``count_parameters(..., random_mask=True, mask_factor=mask_factor)``. The
    ratio ``baseline / compressed`` is then printed.

    Args:
        ratios: compression factors to evaluate.
        mask_factor: fraction of elements treated as masked in the masked layers.
    """
    # The baseline count does not depend on the ratio, so compute it once.
    baseline_params = count_parameters(xresnet18())

    for ratio in ratios:
        print("\n Analyzing at alpha = {}".format(ratio))

        params_compressed = count_parameters(
            xresnet18(compression_factor=ratio), random_mask=True, mask_factor=mask_factor)
        # The model here is xresnet18; it was previously mislabelled "res50".
        print("Size improvement factor for {} = {}".format("res18", baseline_params / params_compressed))
        

def analyze_size(ratios=(4, 8, 16, 32)):
    """Print parameter-count reduction factors for compressed xresnet models.

    For each compression ratio, two sections are printed:

    1. Each of xresnet18/50/101/152 compared with its own uncompressed version.
    2. After a ``+`` separator, the compressed xresnet50/101/152 compared with
       an uncompressed xresnet18.

    Args:
        ratios: compression factors to evaluate.
    """
    model_names = ["res18", "res50", "res101", "res152"]
    constructors = [xresnet18, xresnet50, xresnet101, xresnet152]

    # Uncompressed counts do not depend on the ratio, so compute them once.
    baseline_params = [count_parameters(build()) for build in constructors]
    res18_params = baseline_params[0]

    for ratio in ratios:
        print("\n Analyzing at alpha = {}".format(ratio))

        compressed_params = [count_parameters(build(compression_factor=ratio)) for build in constructors]
        for model_type, params, params_compressed in zip(model_names, baseline_params, compressed_params):
            print("Size improvement factor for {} = {}".format(model_type, params / params_compressed))

        print("+"*30)
        # Skip res18 itself when comparing against the res18 baseline.
        for model_type, params_compressed in zip(model_names[1:], compressed_params[1:]):
            print("Size improvement compared to res18 for {} = {}".format(model_type, res18_params / params_compressed))

def analyze_time(batch_size, ratio, number_iterations, oncuda, image_sizes=(64, 128, 224, 256)):
    """Print forward-pass speed-up factors of compressed xresnet50/101/152 models.

    For every image size and model pair, ``get_improvement_factor`` is called
    and the resulting ``uncompressed time / compressed time`` is printed.

    Args:
        batch_size: images per random input batch.
        ratio: compression factor used to build the compressed models.
        number_iterations: timed forward passes per model and image size.
        oncuda: whether to run the benchmark with ``.cuda()`` tensors and models.
        image_sizes: square input resolutions to benchmark.
    """
    model_names = ["xresnet50", "xresnet101", "xresnet152"]
    models = [xresnet50(), xresnet101(), xresnet152()]
    models_compressed = [xresnet50(compression_factor=ratio), xresnet101(compression_factor=ratio), xresnet152(compression_factor=ratio)]

    for image_size in image_sizes:
        for model_type, model, model_compressed in zip(model_names, models, models_compressed):
            imp_factor = get_improvement_factor(batch_size, image_size, model, model_compressed, oncuda, number_iterations)
            print("Improvement factor for {} at size ({}, {}) = {}".format(model_type, image_size, image_size, imp_factor))


if __name__ == "__main__":
    batch_size = 1
    number_iterations = 5
    ratio = 4
    oncuda = False
    # When True, only the parameter-count analysis runs and timing is skipped.
    only_size = True

    # Uncomment to also run the unmasked size analysis for xresnet18/50/101/152.
    #analyze_size()
    analyze_size_with_random_masking()
    # A plain conditional replaces the site-provided ``exit()`` builtin, which is
    # not guaranteed to exist (e.g. under ``python -S``).
    if not only_size:
        analyze_time(batch_size, ratio, number_iterations, oncuda)
