import sys
sys.path.append("..")

import utils.experiment as ue
import utils.preprocessing as up

import datetime
import argparse
import tensorflow as tf

# Command-line interface for the training run.
# --size and --sample are parsed as int so that a value given on the command
# line has the same type as the default (previously a CLI value was a str while
# the default was an int). Downstream int(...) conversions remain harmless.
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", required=True, help="name of the model")
parser.add_argument("-s", "--size", type=int, default=20, help="image size: 20px, 40px or 50px")
parser.add_argument("--sample", type=int, default=1, help="sampling frequency of training data")
parser.add_argument("--large_beta", action='store_true', help="using beta=1.25")

args = parser.parse_args()

name = args.name

# Activation shared by the non-linear layers below (p_x_z1 uses "linear").
activation = "softplus"

# Architecture hyper-parameters for each supported image size (side length in px).
# Only these six numbers differ between sizes; everything else in the layer
# configs is shared. x_dim equals size**2 in every entry -- presumably the
# flattened image; TODO confirm against utils.preprocessing.get_natural_ds.
_SIZE_SPECS = {
    20: {
        "x_dim": 400,
        "z1_dim": 450,
        "z2_dim": 70,
        "q_z1_x_neurons": 500,
        "q_z2_z1_neurons": [500, 300, 100],
        "p_z1_z2_neurons": 500,
    },
    40: {
        "x_dim": 1600,
        "z1_dim": 1800,
        "z2_dim": 250,
        "q_z1_x_neurons": 2000,
        "q_z2_z1_neurons": [1000, 500, 250],
        "p_z1_z2_neurons": 2000,
    },
    50: {
        "x_dim": 2500,
        "z1_dim": 2813,
        "z2_dim": 390,
        "q_z1_x_neurons": 3000,
        "q_z2_z1_neurons": [1500, 750, 500],
        "p_z1_z2_neurons": 3000,
    },
}

try:
    _spec = _SIZE_SPECS[int(args.size)]
except KeyError:
    # Unsupported size: raise a plain ValueError (no chained KeyError context).
    raise ValueError("Allowed image size is either 20px, 40px or 50px") from None

x_dim = _spec["x_dim"]
z1_dim = _spec["z1_dim"]
z2_dim = _spec["z2_dim"]

# q_z1_x: maps x (x_dim) -> z1 (z1_dim), 2 hidden layers, laplace output.
q_z1_x_configs = {
    "name": "q_z1_x",
    "input_shape": x_dim,
    "neurons": _spec["q_z1_x_neurons"],
    "hidden_layers": 2,
    "activation": activation,
    "output_shape": z1_dim,
    "distribution": "laplace",
}

# q_z2_z1: maps z1 -> z2 (z2_dim) through 3 hidden layers with per-layer widths.
q_z2_z1_configs = {
    "name": "q_z2_z1",
    "input_shape": z1_dim,
    # Copy so the table entry is never aliased by the config dict.
    "neurons": list(_spec["q_z2_z1_neurons"]),
    "hidden_layers": 3,
    "activation": activation,
    "output_shape": z2_dim,
    "distribution": "normal",
}

# p_x_z1: maps z1 -> x with no hidden layers and a linear activation.
p_x_z1_configs = {
    "name": "p_x_z1",
    "input_shape": z1_dim,
    "neurons": None,
    "hidden_layers": 0,
    "activation": "linear",
    "output_shape": x_dim,
    "distribution": "observation_normal",
}

# p_z1_z2: maps z2 -> z1 through a single hidden layer, laplace output.
p_z1_z2_configs = {
    "name": "p_z1_z2",
    "input_shape": z2_dim,
    "neurons": _spec["p_z1_z2_neurons"],
    "hidden_layers": 1,
    "activation": activation,
    "output_shape": z1_dim,
    "distribution": "laplace",
}

model_configs = {
    "q_z1_x_configs": q_z1_x_configs,
    "q_z2_z1_configs": q_z2_z1_configs,
    "p_z1_z2_configs": p_z1_z2_configs,
    "p_x_z1_configs": p_x_z1_configs,
}

# Schedules for the two beta weights passed to ue.Experiment. What
# "start"/"stop"/"init"/"final" mean is defined by utils.experiment; presumably
# beta moves from "init" to "final" between epochs "start" and "stop" -- TODO confirm.
_standard_beta = {"start": 500, "stop": 1400, "init": 0.1, "final": 1}

# Separate copies so beta1 and beta2 are never the same dict object.
beta1 = dict(_standard_beta)

if not args.large_beta:
    beta2 = dict(_standard_beta)
else:
    # --large_beta: beta2 ends at 1.25 and its schedule stops later (1650).
    beta2 = {"start": 500, "stop": 1650, "init": 0.1, "final": 1.25}

# Full configuration handed to ue.Experiment; artifacts go under a per-model directory.
experiment_configs = {
    "model_configs": model_configs,
    "beta1": beta1,
    "beta2": beta2,
    "experiment_directory": f"/home/documentation/experiments/{name}",
    "max_epochs": 7400,
}


# Build the experiment, attach the train/validation datasets, and start training.
exp = ue.Experiment(experiment_configs)

ds_train, ds_val = up.get_natural_ds(
    batch_size=128,
    image_size=int(args.size),
    subsample=int(args.sample),
)
exp.set_datasets(ds_train, ds_val)

# Timestamp (local time) keeps log directories of repeated runs distinct.
# It is taken after dataset loading, as before.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"/home/documentation/logs/{name}_{run_stamp}"
exp.train(log_dir=log_dir, learning_rate=1e-5)