from collections import namedtuple, OrderedDict
import platform
import rotor
import torch
import rotor.models as models


# One benchmark configuration. The field order here is also the order in which
# param_to_str joins the values.
parameter = namedtuple('parameter', ['NETWORK_NAME', 'NETWORK_DEPTH', 'IMAGE_SIZE', 'BATCH'])


def param_to_str(p):
    """Join every field of *p* with ':' (e.g. 'resnet:18:224:8').

    *p* may be any iterable; each element is converted with str().
    """
    return ':'.join(map(str, p))

# Depth -> constructor tables, one per supported network family.
resnets = {
    18: models.resnet18,
    34: models.resnet34,
    50: models.resnet50,
    101: models.resnet101,
    152: models.resnet152,
    200: models.resnet200,
    1001: models.resnet1001,
}

densenets = {
    121: models.densenet121,
    161: models.densenet161,
    169: models.densenet169,
    201: models.densenet201,
}

vggnets = {
    11: models.vgg11_bn,
    13: models.vgg13_bn,
    16: models.vgg16_bn,
    19: models.vgg19_bn,
}

# Only one Inception variant is registered, under the key 200.
inceptions = {200: models.Inception3}


# Family name -> (depth table, keyword arguments passed to every constructor
# in that table).
networks = {
    "resnet": (resnets, {"pretrained": False}),
    "densenet": (densenets, {"drop_rate": 0.25, "pretrained": False}),
    "vgg": (vggnets, {"pretrained": False}),
    "inception": (inceptions, {"aux_logits": False}),
}

def make_module(param, device=None):
    """Build the network described by *param*, optionally moving it to *device*.

    ``param.NETWORK_NAME`` selects a family in ``networks`` and
    ``param.NETWORK_DEPTH`` selects a constructor in that family's table. The
    constructor is called with the family's keyword arguments. If *device* is
    truthy, the module is moved there. Every parameter then receives a
    zero-filled ``.grad`` tensor of the same shape.

    Returns:
        The constructed module. Returns None, after printing a message, when
        the network name or depth is unknown.
    """
    try:
        (functions, kwargs) = networks[param.NETWORK_NAME]
    except KeyError:
        print("Unknown network name {}. Allowed Names: {}".format(param.NETWORK_NAME, list(networks.keys())))
        return None
    try:
        constructor = functions[param.NETWORK_DEPTH]
    except KeyError:
        print("Unknown {} number {}. Known values: {}".format(param.NETWORK_NAME, param.NETWORK_DEPTH, list(functions.keys())))
        return None

    # Construct outside the try: a KeyError raised inside a model constructor
    # is a real bug and must not be reported as an unknown depth.
    module = constructor(**kwargs)

    if device:
        module.to(device=device)

    # Pre-allocate gradient buffers. NOTE(review): presumably this makes
    # gradients resident before rotor measures memory — confirm.
    for p in module.parameters():
        p.grad = torch.zeros_like(p)

    return module

def get_shape(param):
    """Return the input tensor shape (BATCH, 3, IMAGE_SIZE, IMAGE_SIZE) for *param*."""
    side = param.IMAGE_SIZE
    return param.BATCH, 3, side, side

def make_parameters(**args):
    """Return one ``parameter`` per combination of the given candidate values.

    Each keyword names a ``parameter`` field and maps to an iterable of
    candidate values. The result is their Cartesian product. In the returned
    list the first keyword varies fastest and the last keyword slowest.
    """
    combos = [OrderedDict()]
    for field, candidates in args.items():
        grown = []
        for value in candidates:
            for base in combos:
                extended = OrderedDict(base)
                extended[field] = value
                grown.append(extended)
        combos = grown
    return [parameter(**combo) for combo in combos]

# Benchmark sweep. Each row is (family, depths, image size, batch exponents) and
# expands to every depth x batch combination, with batch = 2 ** exponent.
parameters = [
    config
    for family, depths, image_size, batch_exponents in (
        ('resnet', [18, 34, 50, 101, 152, 200], 224, range(3, 8)),
        ('resnet', [18, 34, 50, 101, 152, 200], 500, range(2, 7)),
        ('resnet', [18, 34, 50, 101, 152, 200], 1000, range(0, 4)),
        ('resnet', [1001], 224, range(0, 5)),
        ('resnet', [1001], 500, range(0, 3)),
        ('resnet', [1001], 1000, range(0, 1)),
        ('densenet', [121, 161, 169, 201], 224, range(4, 8)),
        ('densenet', [121, 161, 169, 201], 500, range(2, 5)),
        ('densenet', [121, 161, 169, 201], 1000, range(0, 4)),
        ('inception', [200], 224, range(3, 9)),
        ('inception', [200], 500, range(2, 7)),
        ('inception', [200], 1000, range(0, 5)),
    )
    for config in make_parameters(
        NETWORK_NAME=[family],
        NETWORK_DEPTH=depths,
        IMAGE_SIZE=[image_size],
        BATCH=[2 ** e for e in batch_exponents],
    )
]


# Write the measured chains as a Python module: importing all_chains.py yields
# `hostname` and a `chains` dict keyed by param_to_str(p).
# NOTE(review): assumes str(check.chain) is a valid Python expression — confirm.
with open("all_chains.py", "w") as f:
    print("hostname = '{}'".format(platform.node().split('.')[0]), file=f)
    print("chains = {}", file=f)
    for p in parameters:
        print("Making model", p)
        model = make_module(p, device="cuda")
        if model is None:
            # make_module has already printed the reason; skip this
            # configuration rather than handing None to rotor.Checkpointable.
            continue
        shape = get_shape(p)
        check = rotor.Checkpointable(model)

        print("Measuring", p)
        check.measure(torch.randn(*shape, device="cuda"))
        check.makeParams(None)
        key = param_to_str(p)

        print("chains[\"{key}\"] = {chain}".format(key=key, chain=check.chain), file=f)
        # Persist each result as soon as it exists, so a sweep that is killed
        # partway through keeps every finished configuration.
        f.flush()

        # Drop our references before the next network is built, so the old
        # weights are not still held on the GPU while the new one is allocated.
        del model, check
