import argparse
import os
import time
from datetime import datetime

import gin
import numpy as np
import torch
from torch.quasirandom import SobolEngine

from bounce import settings
from bounce.bounce import Bounce
from bounce.benchmarks import Benchmark, Labs
from bounce.util.benchmark import ParameterType

def sample_one_hot(n_realizations, n_points: int) -> torch.Tensor:
    """Draw uniformly random one-hot encodings for a sequence of categorical parameters.

    Parameter ``i`` owns a contiguous block of ``n_realizations[i]`` columns. The blocks
    are laid out in order, so block ``i`` starts at column ``sum(n_realizations[:i])``.
    In every row exactly one column of every block is set to 1; the chosen category is
    drawn uniformly with ``torch.randint``.

    Args:
        n_realizations: Number of categories of each parameter (iterable of ints).
        n_points: Number of rows (samples) to draw.

    Returns:
        A ``(n_points, sum(n_realizations))`` double tensor containing only 0s and 1s.
        If ``n_realizations`` is empty, the result has zero columns.
    """
    sizes = [int(n) for n in n_realizations]
    one_hot = torch.zeros((n_points, sum(sizes)), dtype=torch.double)
    if not sizes:
        # Nothing to encode; torch.cat would fail on an empty list.
        return one_hot

    # One uniformly drawn category index per (row, parameter), shape (n_points, len(sizes)).
    categories = torch.cat([torch.randint(0, n, (n_points, 1)) for n in sizes], dim=1)

    # Block i starts at the exclusive cumulative sum of the widths before it.
    # Using i * width_i instead is only correct when all parameters have equal cardinality.
    sizes_t = torch.tensor(sizes)
    offsets = torch.cumsum(sizes_t, dim=0) - sizes_t
    one_hot.scatter_(1, offsets + categories, 1.0)
    return one_hot


# Command-line entry point. By default it runs Bounce with the given gin configuration.
# With --random-search it evaluates randomly sampled points of the configured benchmark
# instead, and stores the inputs and function values as CSV files.
if __name__ == '__main__':
    then = time.time()
    parser = argparse.ArgumentParser(
        prog='Bounce',
        description='Bayesian Optimization',
        epilog='For more information, please contact the author.'
    )

    parser.add_argument(
        '--gin-files', type=str, nargs="+", default=['configs/default.gin'],
        help='Path to the config file'
    )
    parser.add_argument(
        '--gin-bindings', type=str, nargs="+", default=[],
        help='Additional gin bindings, applied after the config files'
    )
    parser.add_argument('--random-search', action='store_true', help='Run random search')

    args = parser.parse_args()

    gin.parse_config_files_and_bindings(args.gin_files, args.gin_bindings)

    if not args.random_search:
        alg = Bounce()
        alg.run()
    else:
        # Reuse the values bound to the algorithm's gin configurable, so random search
        # uses the same benchmark, budget and output directory.
        alg_bindings = gin.get_bindings(settings.NAME)
        results_dir = alg_bindings['results_dir']
        benchmark: Benchmark = alg_bindings['benchmark']
        n_points = alg_bindings['maximum_number_evaluations']

        if benchmark.is_continuous:
            x_init = SobolEngine(benchmark.dim, scramble=True).draw(n_points)
        elif benchmark.is_binary:
            x_init = torch.randint(0, 2, (n_points, benchmark.dim))
        elif benchmark.is_categorical:
            # NOTE(review): assumes representation_dim equals the sum of all parameters'
            # n_realizations for a purely categorical benchmark — confirm.
            x_init = sample_one_hot([p.n_realizations for p in benchmark.parameters], n_points)
        elif benchmark.is_ordinal:
            x_init = torch.randint(0, benchmark.representation_dim, (n_points, benchmark.dim))
        elif benchmark.is_mixed:
            # Fill each variable type into its own columns of the full representation.
            x_init = torch.zeros((n_points, benchmark.representation_dim), dtype=torch.double)

            if benchmark.n_continuous > 0:
                x_init[:, benchmark.continuous_indices] = SobolEngine(
                    benchmark.n_continuous, scramble=True
                ).draw(n_points).to(torch.double)
            if benchmark.n_categorical > 0:
                categorical_parameters = [
                    p for p in benchmark.parameters if p.type == ParameterType.CATEGORICAL
                ]
                x_init[:, benchmark.categorical_indices] = sample_one_hot(
                    [p.n_realizations for p in categorical_parameters], n_points
                )
            if benchmark.n_binary > 0:
                x_init[:, benchmark.binary_indices] = torch.randint(
                    0, 2, (n_points, benchmark.n_binary), dtype=torch.double
                )
        else:
            # Without this branch x_init would be unbound and the script would fail later
            # with an unhelpful NameError.
            raise ValueError(
                f'Random search does not support benchmark {type(benchmark).__name__}'
            )

        if isinstance(benchmark, Labs):
            # Map {0, 1} samples to {-1, 1} (presumably the Labs input domain — confirm).
            x_init = x_init * 2 - 1
        fxs = np.array([benchmark(_x).item() for _x in x_init])

        results_dir = os.path.join(
            results_dir,
            f'{type(benchmark).__name__}_d_{benchmark.dim}_flip_{benchmark.flip}_random_search'
        )
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(results_dir, exist_ok=True)

        # Current time with microseconds, so repeated runs do not overwrite each other.
        timestr = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
        np.savetxt(
            os.path.join(results_dir, f'x_init_{timestr}.csv'),
            x_init.detach().cpu().numpy(),
            delimiter=',',
        )
        np.savetxt(os.path.join(results_dir, f'fxs_{timestr}.csv'), fxs, delimiter=',')

    gin.clear_config()
    now = time.time()
    print(f"Total time: {now - then:.2f} seconds")
