import argparse
import random
import copy
import json
import tqdm
import numpy as np
from multiprocessing import Process, Queue
import gc

import resource
import torch
from timeit import default_timer as timer

def filter_state_dict(state_dict, model):
    """Drop every non-trainable parameter of *model* from *state_dict*.

    A parameter counts as non-trainable when its ``requires_grad`` flag is
    false. The dictionary is modified in place and the same object is
    returned. Entries that are not listed by ``model.named_parameters()``
    are left untouched.

    Raises:
        KeyError: if a non-trainable parameter name is absent from
            *state_dict*.
    """
    frozen_names = [
        param_name
        for param_name, tensor in model.named_parameters()
        if not tensor.requires_grad
    ]
    for param_name in frozen_names:
        state_dict.pop(param_name)
    return state_dict

def sizeof_state_dict(state_dict):
    """Return the size of *state_dict* in bytes at 4 bytes per element.

    The element count of each entry is the product of its ``shape`` (an
    empty shape counts as one element). Every element is charged a fixed
    4 bytes; the entry's actual dtype is not inspected.
    """
    bytes_per_element = 4
    total_elements = 0
    for tensor in state_dict.values():
        n_elements = 1
        for extent in tensor.shape:
            n_elements *= extent
        total_elements += n_elements
    return bytes_per_element * total_elements
    

def profile(model_class, model_args, q, n_batches=16):
    """Measure training time, peak-memory growth and state-dict sizes of one model configuration.

    Builds ``model_class(**model_args)``, loads ``tmp/state_dict.pt`` into it
    (non-strict), and runs ``n_batches`` SGD steps on the CPU using the single
    batch stored in ``tmp/inputs.pt`` / ``tmp/targets.pt``.

    Args:
        model_class: callable that returns the model when called with
            ``**model_args``.
        model_args: keyword arguments forwarded to ``model_class``.
        q: queue-like object; the result is delivered through ``q.put``.
        n_batches: number of training steps to time.

    Returns:
        None. A single list is put on ``q``:
        ``[forward_times, backward_times, data_up, data_down, memory]`` where
        the two time lists hold one wall-clock duration (seconds) per step,
        ``data_down`` is ``sizeof_state_dict`` of the saved state dict,
        ``data_up`` is ``sizeof_state_dict`` of the model's state dict after
        removing parameters with ``requires_grad == False``, and ``memory`` is
        the growth of ``ru_maxrss / 10**6`` over the course of the function.
    """
    # set number of threads torch can use
    torch.set_num_threads(4)
    # Prefer the 'fbgemm' quantized engine when this torch build supports it, else use 'qnnpack'.
    torch.backends.quantized.engine = 'fbgemm' if 'fbgemm' in torch.backends.quantized.supported_engines else 'qnnpack'
    lr = 0.005
    torch_device = torch.device('cpu')
    # NOTE: despite the "per_epoch" names, one entry is appended per training step (batch).
    time_per_epoch_forward = []
    time_per_epoch_bw = []

    #previous memory is considered overhead
    # NOTE(review): ru_maxrss units are platform-dependent (presumably kilobytes on
    # Linux, which would make this value gigabytes) - confirm for the target platform.
    max_mem_start = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/(10**6)
    model = model_class(**model_args)
    model.to(torch_device)
    # strict=False: keys that do not match between the file and the model are tolerated.
    model.load_state_dict(torch.load('tmp/state_dict.pt'), strict=False)

    inputs, targets = torch.load('tmp/inputs.pt'), torch.load('tmp/targets.pt')

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr)
    loss_f = torch.nn.CrossEntropyLoss()
    model.train()

    # The same batch is reused for every step.
    for _ in range(n_batches):
        inputs, targets = inputs.to(torch_device), targets.to(torch_device)

        # Forward timing covers only the model call.
        t_start = timer()
        outputs = model(inputs)
        t_end = timer()

        time_per_epoch_forward.append(t_end - t_start)

        # "Backward" timing covers loss computation, backward pass, optimizer step and gradient reset.
        t_start = timer()
        loss = loss_f(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        t_end = timer()
        # Cleanup happens after t_end, so it is excluded from the timed region.
        del loss
        del outputs
        gc.collect()

        time_per_epoch_bw.append(t_end - t_start)
    # Peak RSS is sampled before the state dict is reloaded below.
    max_mem_end = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/(10**6)

    # data_down: size of the full saved state dict; data_up: size of the trainable part only.
    data_down = sizeof_state_dict(torch.load("tmp/state_dict.pt"))
    data_up = sizeof_state_dict(filter_state_dict(model.state_dict(), model))

    q.put([time_per_epoch_forward, time_per_epoch_bw, data_up, data_down, max_mem_end - max_mem_start])
    return None


# Lazily imported model classes: algorithm -> model name -> (module path, class name).
# Only the module of the selected model is imported.
_MODEL_REGISTRY = {
    'Subset': {
        'SResNet18': ('nets.SubsetNets.ResNet.resnet', 'SResNet18'),
        'SResNet50': ('nets.SubsetNets.ResNet.resnet', 'SResNet50'),
        'SMobileNet': ('nets.SubsetNets.MobileNet.mobilenet', 'SMobileNet'),
        'SDenseNet': ('nets.SubsetNets.DenseNet.densenet', 'SDenseNet40'),
        'SMobileNetLarge': ('nets.SubsetNets.MobileNet.mobilenet', 'SMobileNetLarge'),
        'STransformer': ('nets.SubsetNets.Transformer.transformer', 'STransformer'),
        'STransformerSeq2Seq': ('nets.SubsetNets.Transformer.transformer', 'STransformerSeq2Seq'),
    },
    'CoCoFL': {
        'QResNet18': ('nets.QuantizedNets.ResNet.resnet', 'QResNet18'),
        'QResNet50': ('nets.QuantizedNets.ResNet.resnet', 'QResNet50'),
        'MobileNet': ('nets.QuantizedNets.MobileNet.mobilenet', 'MobileNet'),
        'QMobileNet': ('nets.QuantizedNets.MobileNet.mobilenet', 'QMobileNet'),
        'FMobileNet': ('nets.QuantizedNets.MobileNet.mobilenet', 'FMobileNet'),
        'QMobileNetLarge': ('nets.QuantizedNets.MobileNet.mobilenet', 'QMobileNetLarge'),
        'QDenseNet': ('nets.QuantizedNets.DenseNet.densenet', 'QDenseNet40'),
        'QTransformer': ('nets.QuantizedNets.Transformer.transformer', 'QTransformer'),
        'QTransformerSeq2Seq': ('nets.QuantizedNets.Transformer.transformer', 'QTransformerSeq2Seq'),
    },
}


def _resolve_model_class(algorithm, model_name):
    """Import and return the model class registered for (*algorithm*, *model_name*).

    Raises:
        ValueError: if the model is not registered for the algorithm. This
            replaces the late ``NameError`` an unmatched combination used to
            cause.
    """
    import importlib

    try:
        module_path, class_name = _MODEL_REGISTRY[algorithm][model_name]
    except KeyError:
        raise ValueError(
            f"model '{model_name}' is not available for algorithm '{algorithm}'"
        ) from None
    return getattr(importlib.import_module(module_path), class_name)


def _make_dummy_batch(model_name):
    """Return a random ``(inputs, targets)`` batch of 32 samples for *model_name*.

    The batch layout is chosen by substring match on the model name; the more
    specific names are tested first.

    Raises:
        ValueError: if no layout is known for *model_name*.
    """
    if 'MobileNetLarge' in model_name:
        inputs = torch.rand((32, 3, 256, 256), dtype=torch.float)
        targets = torch.randint(0, 9, (32,), dtype=torch.long)
    elif 'ResNet' in model_name or 'DenseNet' in model_name or 'MobileNet' in model_name:
        inputs = torch.rand((32, 3, 32, 32), dtype=torch.float)
        targets = torch.randint(0, 9, (32,), dtype=torch.long)
    elif 'TransformerSeq2Seq' in model_name:
        inputs = torch.randint(0, 79, (32, 80), dtype=torch.long)
        targets = torch.randint(0, 80, (32,), dtype=torch.long)
    elif 'Transformer' in model_name:
        inputs = torch.randint(0, 15999, (32, 512), dtype=torch.long)
        targets = torch.randint(0, 1, (32,), dtype=torch.long)
    else:
        raise ValueError(f"no dummy batch layout known for model '{model_name}'")
    return inputs, targets


def _window_freeze_configs(n_layers):
    """Enumerate ``{'freeze': [...]}`` configs that leave out one contiguous index window.

    For every window length ``k`` in ``1 .. n_layers - 1`` and every start
    ``i`` at which the window fits, the config lists all indices of
    ``range(n_layers)`` except ``i .. i + k - 1``.
    """
    configs = []
    for k in range(1, n_layers):
        # The window must fit entirely inside range(n_layers).
        for i in range(n_layers - k + 1):
            configs.append({'freeze': list(range(i)) + list(range(i + k, n_layers))})
    return configs


def _run_profile(model_class, model_kwargs, poll_interval=1.0):
    """Run ``profile`` in a fresh process and return the list it puts on the queue.

    The result is fetched *before* joining the worker, and the worker's
    liveness is polled so that a crashed worker raises instead of blocking
    the parent forever.

    Raises:
        RuntimeError: if the worker exits without delivering a result.
    """
    from queue import Empty

    q = Queue()
    p = Process(target=profile, args=(model_class, model_kwargs, q))
    p.start()
    try:
        while True:
            try:
                return q.get(timeout=poll_interval)
            except Empty:
                if not p.is_alive():
                    break
        # The worker is gone; pick up a result it may have flushed right before exiting.
        try:
            return q.get(timeout=poll_interval)
        except Empty:
            raise RuntimeError(
                f"profiling worker exited with code {p.exitcode} without a result "
                f"(config: {model_kwargs})"
            ) from None
    finally:
        p.join()


def _mean_step_time(record):
    """Return mean forward time plus mean backward time of one measurement record."""
    return np.mean(record['time_forward']) + np.mean(record['time_backward'])


def _relative_table(records, config_key, reference_value):
    """Build the output table with time/data/memory relative to the reference config.

    Args:
        records: measurement records holding ``config_key``, ``time_forward``,
            ``time_backward``, ``data_up`` and ``memory``.
        config_key: ``'freeze'`` or ``'keep_factor'``.
        reference_value: value of ``config_key`` that identifies the reference
            configuration all other rows are divided by.

    Returns:
        List of row dicts, sorted by mean step time in descending order. Rows
        of the reference configuration also carry the absolute ``__debug_*``
        values.

    Raises:
        RuntimeError: if no record matches ``reference_value``.
    """
    ordered = sorted(records, key=_mean_step_time, reverse=True)
    reference = next((r for r in ordered if r[config_key] == reference_value), None)
    if reference is None:
        # Raised explicitly: an assert would be stripped when running with -O.
        raise RuntimeError("Maxium Configuration is missing. Abort!")

    max_time = _mean_step_time(reference)
    max_up = reference['data_up']
    max_mem = np.mean(reference['memory'])
    print('max_time: ', max_time, 'max_up: ', max_up,'max_memory: ', max_mem)

    table = []
    for record in ordered:
        row = {
            config_key: record[config_key],
            'time': round(_mean_step_time(record) / max_time, 5),
            'data': round(record['data_up'] / max_up, 5),
            'memory': round(np.mean(record['memory']) / max_mem, 5),
        }
        if record[config_key] == reference_value:
            row.update({
                '__debug_max_time_in_s': round(max_time, 5),
                # max_up is already a byte count, so it must not be multiplied
                # by 4 again. The key name is kept for compatibility although
                # the value is expressed in units of 10**9 bytes.
                '__debug_max_data_in_bytes': round(max_up / (10**9), 5),
                '__debug_max_mem_in_gb': round(max_mem, 5),
            })
        table.append(row)
    return table


if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser(description='Process some integers.')

    parser.add_argument('--algorithm', default='CoCoFL', const='all', nargs='?', choices=['Subset', 'CoCoFL'])
    parser.add_argument('--model', default='QResNet18', const='all', nargs='?', choices=[
                            'SResNet18', 'QResNet18', 'QResNet50', 'SResNet50', #for femnist/cifar/cinic experiments
                            'MobileNet', 'QMobileNet', 'FMobileNet', 'SMobileNet', #for femnist/cifar/cinic experiments
                            'MobileNetLarge', 'QMobileNetLarge', 'SMobileNetLarge', #for xchest experiments
                            'QDenseNet', 'SDenseNet', #for femnist/cifar/cinic experiments
                            'QTransformer', 'STransformer', #for IMDB experiments
                            'QTransformerSeq2Seq', 'STransformerSeq2Seq']) # for shakespeare experiments
    parser.add_argument('--architecture', default='x64', const='all', nargs='?', choices=['x64', 'arm'])
    parser.add_argument("--epochs", default=1, type=int)

    args = parser.parse_args()
    print(args)

    # Fails fast with a ValueError when the model does not belong to the algorithm.
    model = _resolve_model_class(args.algorithm, args.model)

    # The worker processes read their inputs from tmp/; the table goes to tables/.
    os.makedirs('tmp', exist_ok=True)
    os.makedirs('tables', exist_ok=True)

    torch.save(model().state_dict(), "tmp/state_dict.pt")
    inputs, targets = _make_dummy_batch(args.model)
    torch.save(inputs, "tmp/inputs.pt")
    torch.save(targets, "tmp/targets.pt")

    if args.algorithm == 'CoCoFL':
        config_key, reference_value = 'freeze', []

        list_of_configs = _window_freeze_configs(model.n_freezable_layers())
        list_of_configs += [{'freeze': []} for _ in range(3)]
        random.shuffle(list_of_configs)
        list_of_configs *= args.epochs
        # The list is consumed in reverse, so these reference runs are executed first.
        list_of_configs += [{'freeze': []} for _ in range(5)]

    elif args.algorithm == 'Subset':
        config_key, reference_value = 'keep_factor', 1.0

        if 'Transformer' in args.model:
            space = [1.0, 0.9, 0.8125, 0.75, 0.65, 0.5, 0.4, 0.25, 0.125, 0.0625]
        else:
            space = list(np.linspace(0.1, 1.0, num=50, endpoint=True))

        list_of_configs = [{'keep_factor' : round(item, 4)} for item in space]
        list_of_configs *= args.epochs
        random.shuffle(list_of_configs)
        # The list is consumed in reverse, so these reference runs are executed first.
        list_of_configs += [{'keep_factor' : 1.0} for _ in range(5)]

    # One record per distinct configuration, in order of first appearance;
    # repeated runs of a configuration extend its time/memory samples.
    measurements = {}

    for kwargs in tqdm.tqdm(reversed(list_of_configs), total=len(list_of_configs)):
        t_fw, t_bw, data_up, data_down, memory = _run_profile(model, kwargs)

        value = kwargs[config_key]
        if config_key == 'freeze':
            signature = tuple(value)
        else:
            signature = float(value)

        record = measurements.get(signature)
        if record is not None:
            record['time_forward'] += t_fw
            record['time_backward'] += t_bw
            record['memory'] += [memory]
            continue

        if config_key == 'freeze':
            record = {'freeze' : list(sorted(value)),
                      'time_forward' : t_fw,
                      'time_backward' : t_bw,
                      'max' : model.n_freezable_layers(),
                      'data_down' : data_down,
                      'data_up' : data_up,
                      'memory' : [memory]}
        else:
            record = {'keep_factor' : value,
                      'time_forward' : t_fw,
                      'time_backward' : t_bw,
                      'data_down' : data_down,
                      'data_up' : data_up,
                      'memory' : [memory]}
        measurements[signature] = record

    #create a table with all memory/time/data values relative to full-training
    table_string = f"tables/table__{args.algorithm}_{args.architecture}_{args.model}.json"

    out = _relative_table(list(measurements.values()), config_key, reference_value)

    with open(table_string, 'w') as fd:
        json.dump(out, fd, indent=4)