from argument_parser import argument_parser
from train_som import train_som
from train_full_model import train_full_model
from train_combined_model import train_model
import os

import torch
from utils import utils
from datetime import datetime
from os.path import join
from sampling.custom_lhs import *


def run_lhs_som(filename, lhs_samples=1):
    """Draw a Latin Hypercube sampling of hyper-parameters for the SOM-only mode.

    Each hyper-parameter is given to ``SOMLHS`` as a two-element range
    (presumably ``[low, high]``, as the values suggest).

    :param filename: file the sampler writes its parameter sets to, via
        ``SOMLHS.write_params_file``.
    :param lhs_samples: number of samples requested from the sampler.
    :return: the value returned by calling the ``SOMLHS`` instance.
    """
    ranges = {
        'n_max': [15, 250],
        'at': [0.95, 0.999],
        'lr': [0.001, 0.05],
        'lr_push': [0.01, 0.7],
        'ds_beta': [0.001, 0.2],
        'eps_ds': [0.01, 0.05],
        'ld': [0.01, 0.4],
        'epochs': [25, 70],
        'gamma': [0.1, 2.5],
        'seed': [1, 200000],
    }
    sampler = SOMLHS(**ranges)

    # Sample first, then write the params file (same order as the sibling helpers).
    samples = sampler(lhs_samples)
    sampler.write_params_file(filename)
    return samples


# DECAY + NEWUPDATE
def run_lhs_autoencoder_model(filename, lhs_samples=1):
    """Draw a Latin Hypercube sampling of hyper-parameters for the combined SOM + autoencoder mode.

    Besides the SOM ranges, this covers the autoencoder-side settings
    (``som_in``, ``batch_size``, ``alpha``, ``beta``) and ``lr_decay``. Each
    range is a two-element list handed to ``AutoEncoderModelLHS``.

    :param filename: file the sampler writes its parameter sets to, via
        ``AutoEncoderModelLHS.write_params_file``.
    :param lhs_samples: number of samples requested from the sampler.
    :return: the value returned by calling the ``AutoEncoderModelLHS`` instance.
    """
    ranges = {
        'som_in': [20, 28],
        'batch_size': [5, 9],
        'alpha': [0.03, 0.1],
        'beta': [0.001, 1.0],
        'n_max': [128, 300],
        'at': [0.98, 0.999],
        'lr': [0.1, 0.4],
        'lr_decay': [0.7, 0.9],
        'lr_push': [0.5, 1.0],
        'ds_beta': [0.01, 0.2],
        'eps_ds': [0.5, 0.7],
        'ld': [0.1, 0.5],
        'epochs': [100, 200],
        'gamma': [2.0, 6.0],
        'seed': [1231, 1455454],
    }
    sampler = AutoEncoderModelLHS(**ranges)

    # NOTE(review): an alternative call, sampler(lhs_samples, custom_dist_at='exp_inv'),
    # was kept commented out here; confirm the sampler still supports it before re-enabling.
    samples = sampler(lhs_samples)
    sampler.write_params_file(filename)
    return samples


def run_lhs_full_model(filename, lhs_samples=1):
    """Draw a Latin Hypercube sampling of hyper-parameters for the full (CNN + SOM) model.

    In addition to the SOM ranges, this covers the convolutional part
    (``n_conv``, ``lr_cnn``, ``max_pool``, ``max_pool2d_size``, ``filters_pow``,
    ``kernel_size``) and ``som_in``. Each range is a two-element list handed
    to ``FullModelLHS``.

    :param filename: file the sampler writes its parameter sets to, via
        ``FullModelLHS.write_params_file``.
    :param lhs_samples: number of samples requested from the sampler.
    :return: the value returned by calling the ``FullModelLHS`` instance.
    """
    ranges = {
        'n_conv': [2, 5],
        'lr_cnn': [0.00001, 0.001],
        'som_in': [10, 100],
        'max_pool': [0, 1],
        'max_pool2d_size': [3, 4],
        'filters_pow': [2, 6],
        'kernel_size': [0.5, 3.5],
        'n_max': [10, 150],
        'at': [0.85, 0.999],
        'lr': [0.0001, 0.005],
        'lr_push': [0.01, 1.0],
        'ds_beta': [0.001, 0.5],
        'eps_ds': [0.01, 0.1],
        'ld': [0.05, 0.5],
        'epochs': [70, 200],
        'gamma': [0.14, 4.0],
        'seed': [1, 200000],
    }
    sampler = FullModelLHS(**ranges)

    samples = sampler(lhs_samples)
    sampler.write_params_file(filename)
    return samples


if __name__ == '__main__':
    # Script entry point. It parses the CLI and prepares the output folder, the
    # optional TensorBoard writer and the torch device. It then runs one of:
    #   --som_only -> train_som (once on --dataset, or once per --train_paths line)
    #   --combined -> train_model with the 'autoencoder' custom model
    #   otherwise  -> train_full_model
    # Every mode gets its hyper-parameter sets through _load_parameters.

    def _load_parameters(cli_args, default_params_file, run_lhs):
        """Return the hyper-parameter sets for the selected training mode.

        The params file is ``cli_args.params_file`` when given, else
        ``default_params_file``. With ``--lhs``, a fresh sampling is drawn by
        ``run_lhs(params_file, cli_args.lhs_samples)``; the run_lhs_* helpers
        also write it to that file. Otherwise the file is read with
        ``utils.read_params``. Either way, the result is passed through
        ``utils.parameters_start_stop`` with ``--start_idx``/``--stop_idx``.
        """
        if cli_args.params_file is not None:
            params_file = cli_args.params_file
        else:
            params_file = default_params_file

        if cli_args.lhs:
            parameters = run_lhs(params_file, cli_args.lhs_samples)
        else:
            parameters = utils.read_params(params_file)
        return utils.parameters_start_stop(parameters, cli_args.start_idx, cli_args.stop_idx)

    # Argument Parser
    args = argument_parser()

    # Normalise to a trailing "/" and make sure the folder exists. The
    # slash-terminated string is what the training functions receive.
    out_folder = args.out_folder if args.out_folder.endswith("/") else args.out_folder + "/"
    os.makedirs(out_folder, exist_ok=True)

    writer = None
    if args.tensorboard:
        # SummaryWriter is not imported anywhere in this file (at most via the
        # star import), so --tensorboard could hit a NameError. Importing it
        # here means runs without TensorBoard don't need it installed.
        from torch.utils.tensorboard import SummaryWriter

        tensorboard_root = args.tensorboard_root
        # Create the root itself. os.path.dirname() of a bare name like "runs"
        # is "", and os.makedirs("") raises.
        if tensorboard_root:
            os.makedirs(tensorboard_root, exist_ok=True)

        # Name the run after the last component of out_folder. The previous
        # split("/")[1] gave "" for a single-level folder such as "results/".
        run_name = os.path.basename(os.path.normpath(out_folder))
        tensorboard_folder = join(tensorboard_root,
                                  run_name + "_" + datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
        writer = SummaryWriter(tensorboard_folder)
        print("tensorboard --logdir=" + tensorboard_folder)

    use_cuda = torch.cuda.is_available() and args.cuda
    if use_cuda:
        torch.cuda.init()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    root = args.root
    test_root = args.test_root
    dataset_path = args.dataset
    norm_type = args.norm_type
    coil20_unprocessed = args.coil20_unprocessed
    print_debug = args.print
    batch_size = args.batch_size
    seed = args.seed

    # Test paths are only meaningful alongside a list of train paths.
    train_paths = utils.read_lines(args.train_paths) if args.train_paths is not None else None
    test_paths = utils.read_lines(args.test_paths) if args.test_paths is not None and train_paths is not None else None

    if args.som_only:
        parameters = _load_parameters(args, "arguments/default_som.lhs", run_lhs_som)

        # Build (train_path, test_root, test_path) triples, one per train_som run.
        if train_paths is None:
            # Single run on --dataset with no separate test set.
            runs = [(dataset_path, None, None)]
        else:
            # Fail before any (long) training starts, not midway through the loop.
            if test_paths is not None and len(test_paths) < len(train_paths):
                raise ValueError(
                    f"test_paths has {len(test_paths)} entries but train_paths has "
                    f"{len(train_paths)}; every train path needs a matching test path")
            # Pair each train path positionally with its test path (if any).
            runs = [(train_path, test_root, test_paths[i] if test_paths is not None else None)
                    for i, train_path in enumerate(train_paths)]

        for train_path, run_test_root, test_path in runs:
            train_som(root=root, train_path=train_path, test_root=run_test_root, test_path=test_path,
                      norm_type=norm_type, parameters=parameters, out_folder=out_folder,
                      batch_size=batch_size, device=device, use_cuda=use_cuda,
                      workers=args.workers, evaluate=args.eval, summ_writer=writer,
                      coil20_unprocessed=coil20_unprocessed,
                      save=args.save, load=args.load, model=args.model, semi=args.semi,
                      labels_sampling=args.labels_sampling, n_labels=args.n_labels)

    elif args.combined:
        parameters = _load_parameters(args, "arguments/default_autoencoder.lhs",
                                      run_lhs_autoencoder_model)

        # Only the autoencoder variant exists for now; add others (e.g., VAE) as needed.
        custom_model = 'autoencoder'

        train_model(root=root, dataset_path=dataset_path, parameters=parameters, device=device,
                    use_cuda=use_cuda, out_folder=out_folder, debug=args.debug, n_samples=args.n_samples,
                    batch_size=batch_size, coil20_unprocessed=coil20_unprocessed, alpha=args.alpha, beta=args.beta,
                    som_in=args.som_in, epochs=args.epochs, seed=seed, save=args.save, load=args.load,
                    model=args.model, lr_decay=args.lr_decay, semi=args.semi, custom_model=custom_model,
                    labels_sampling=args.labels_sampling, n_labels=args.n_labels, balanced=args.balanced,
                    evaluate=args.eval, use_wandb=args.wandb, wandb_project=args.wandb_project,
                    landscape=args.landscape)

    else:
        parameters = _load_parameters(args, "arguments/default_full_model.lhs", run_lhs_full_model)

        train_full_model(root=root, dataset_path=dataset_path, parameters=parameters, device=device, use_cuda=use_cuda,
                         out_folder=out_folder, debug=args.debug, n_samples=args.n_samples, batch_size=batch_size,
                         summ_writer=writer, print_debug=print_debug, coil20_unprocessed=coil20_unprocessed,
                         labels_sampling=args.labels_sampling, n_labels=args.n_labels)