import hydra
import omegaconf
import os
import yaml

import torch
from torch.utils.tensorboard import SummaryWriter

from model import MLP
from optimize import optimize


def main(config_path="configs/PWA.yaml"):
    """Train the MLP dynamics model over successive ReLU-compression rounds.

    Loads the YAML config at ``config_path`` and creates
    ``<cwd>/experiments/<experiment_name>/`` with ``models/`` and
    ``tensorboard/`` sub-directories, plus a copy of the config.
    It then runs ``len(config['compression']['num_relu_remain']) + 1`` rounds
    of ``optimize``. After every round it:

    * calls ``model.update_bounds`` with the returned activation,
    * saves the model as ``compression_round_<k>``,
    * logs ``model.mask_prob`` and ``model.mask`` to TensorBoard as images,
    * re-masks the model with the next ``num_relu_remain`` entry.
      This step is skipped after the final round.

    Args:
        config_path: Path to the YAML experiment config. Defaults to the
            previously hard-coded ``configs/PWA.yaml``.

    Raises:
        SystemExit: With status 1 if the experiment directory already exists.
    """
    # get config; use a context manager so the file handle is closed
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    # create training directory; letting makedirs raise FileExistsError avoids
    # the race between a separate os.path.exists check and the creation
    train_dir = os.path.join(os.getcwd(), 'experiments', config['experiment_name'])
    try:
        os.makedirs(train_dir)
    except FileExistsError:
        print("Experiment name already exists.")
        # non-zero status so callers/schedulers see the failure; unlike the
        # site-provided exit(), SystemExit is always available
        raise SystemExit(1) from None

    # create model save directory
    os.makedirs(os.path.join(train_dir, 'models'))

    # create tensorboard directory and writer
    writer = SummaryWriter(log_dir=os.path.join(train_dir, "tensorboard"))
    try:
        # save config file
        with open(os.path.join(train_dir, "config.yaml"), 'w') as outfile:
            yaml.dump(config, outfile, default_flow_style=False)

        # create the dynamics model
        model = MLP(config)

        # start training
        global_iteration = 0
        num_relu_remain = config['compression']['num_relu_remain']

        ##### explore and learn
        for idx_compress_round in range(len(num_relu_remain) + 1):
            # Round 0 passes n_remain=0. Round k>0 passes num_relu_remain[k-1],
            # the same count the mask was updated with at the end of round k-1.
            if idx_compress_round != 0:
                n_remain = num_relu_remain[idx_compress_round - 1]
            else:
                n_remain = 0

            ### optimize the dynamics model
            model, global_iteration, activation = optimize(
                idx_compress_round,
                config,
                model,
                global_iteration,
                writer,
                n_remain=n_remain)

            model.update_bounds(activation)

            model.save_model(f"compression_round_{idx_compress_round}")

            # plot the masks onto tensorboard in grayscale, with 0 = black and 1 = white
            for i, layer in enumerate(model.mask_prob):
                # per the original author: height = #relus, width = the three probabilities
                writer.add_image(f"mask_prob_layer_{i}", torch.repeat_interleave(layer, 30, dim=1),
                                 global_step=global_iteration, dataformats='HW')
            for i, layer in enumerate(model.mask):
                # per the original author: dark = 0 = zero, grey = 0.5 = relu, white = 1 = ID;
                # dividing by 2.0 maps mask values {0, 1, 2} (assumed) into [0, 1]
                writer.add_image(f"mask_layer_{i}", torch.repeat_interleave(layer, 90, dim=0) / 2.0,
                                 global_step=global_iteration, dataformats='WH')

            # prune for the next round (no next round after the last one)
            if idx_compress_round < len(num_relu_remain):
                model.update_mask_based_on_mask_prob(n_remain=num_relu_remain[idx_compress_round])
    finally:
        # flush and release pending TensorBoard event files even if training raises
        writer.close()

if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
