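"""Optimize simulation parameters using Optuna.

Material parameters are tuned so that the simulated objective variable
matches the one stored in the input data (lowest RMSE).
"""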

import argparse
import copy
from distutils.util import strtobool
import pathlib
import subprocess
import sys

import femio
import numpy as np
import optuna
import siml


# Reference scale for the optimized parameter; the default search range of
# --parameter-range is [0.5 * PARAMETER_SCALE, PARAMETER_SCALE].
PARAMETER_SCALE = 2e-2


def main():
    """Parse command-line options and run the optimization study.

    Raises
    ------
    ValueError
        If ``--error-threshold`` is given and the best objective value
        exceeds it (raised from ``Study.run``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input_directory',
        type=pathlib.Path,
        help='Input data directory')
    parser.add_argument(
        'output_directory',
        type=pathlib.Path,
        help='Output directory')
    parser.add_argument(
        '-n', '--n-trial',
        type=int,
        default=100,
        help='The number of trials [100]')
    parser.add_argument(
        '-l', '--parameter-length',
        type=int,
        default=2,
        help='The number of parameters to optimize [2]')
    # Exactly two values (low, high) are used by the objective, so reject
    # any other count at parse time instead of failing inside a trial.
    parser.add_argument(
        '-p', '--parameter-range',
        type=float,
        nargs=2,
        metavar=('LOW', 'HIGH'),
        default=(.5 * PARAMETER_SCALE, PARAMETER_SCALE),
        help='Lower and upper bounds of each parameter [%(default)s]')
    parser.add_argument(
        '-x', '--variable-name',
        type=str,
        default='thermal_conductivity',
        help='The name of the variable to optimize')
    parser.add_argument(
        '-y', '--objective-name',
        type=str,
        default='TEMPERATURE',
        help='The name of the objective variable')
    parser.add_argument(
        '-r', '--read-npy',
        type=strtobool,
        default=0,
        help='If True, read femio npy files if they exist [False]')
    parser.add_argument(
        '-t', '--time-series',
        type=strtobool,
        default=1,
        help='If True, read femio data as time series [True]')
    parser.add_argument(
        '-e', '--error-threshold',
        type=float,
        default=None,
        help='If given, raise ValueError if the best objective value '
        'exceeds this threshold')
    args = parser.parse_args()

    study = Study(args)
    study.run()
    return


class Study:
    """Create (or reuse) an Optuna study and run the simulation search.

    The study is stored in ``<output_directory>/data.db`` (SQLite) and uses
    a TPE sampler with a median pruner.
    """

    def __init__(self, settings):
        """Prepare the output directory and the Optuna study.

        Parameters
        ----------
        settings: argparse.Namespace
            Parsed command-line options; ``input_directory`` and
            ``output_directory`` are read here, the rest by ``run`` and
            ``Objective``.
        """
        self.settings = settings

        # Encode the input path into the study name so studies for
        # different inputs can be told apart.
        input_info = str(self.settings.input_directory).replace('/', '-')
        study_name = f"{input_info}_{siml.util.date_string()}"
        self.output_directory = self.settings.output_directory
        self.output_directory.mkdir(parents=True, exist_ok=True)

        sqlite_file_name = self.output_directory / 'data.db'
        print(f"Save SQLite in: {sqlite_file_name}")
        self.storage = f"sqlite:///{sqlite_file_name}"

        self.study = optuna.create_study(
            study_name=study_name,
            storage=self.storage,
            sampler=optuna.samplers.TPESampler(),
            load_if_exists=True, pruner=optuna.pruners.MedianPruner())
        return

    def callback_print(self, study, frozen_trial):
        """Print the best trial found so far (Optuna callback).

        ``study.best_trial`` raises ``ValueError`` while no trial has
        completed (e.g. every finished trial was pruned). An exception in a
        callback would abort the optimization, so that case is reported
        instead.
        """
        try:
            best_trial = study.best_trial
        except ValueError:
            print('No completed trial yet')
            return
        print(f"Current best trial number: {best_trial.number}")
        print(f"Current best value: {best_trial.value}")
        return

    def callback_exit(self, study, frozen_trial):
        """Terminate the process via ``sys.exit()`` when invoked.

        Not registered by ``run``; as a callback it would stop the program
        after the first finished trial.
        """
        sys.exit()

    def run(self):
        """Perform the parameter search study.

        Runs ``settings.n_trial`` trials of ``Objective``, prints the best
        trial, and writes all trials to ``trials.csv`` in the output
        directory.

        Returns
        -------
        optuna.study.Study
            The study after optimization.

        Raises
        ------
        ValueError
            If ``settings.error_threshold`` is set and the best value
            exceeds it.
        """
        objective = Objective(self.settings)

        # Optimize; catch=() lets any exception in a trial propagate.
        callbacks = (self.callback_print,)
        self.study.optimize(
            objective, n_trials=self.settings.n_trial, catch=(),
            callbacks=callbacks)

        # Report the best result and dump all trials.
        print('=== Best Trial ===')
        print(self.study.best_trial)
        df = self.study.trials_dataframe()
        df.to_csv(self.output_directory / 'trials.csv')

        threshold = self.settings.error_threshold
        if threshold is not None and self.study.best_trial.value > threshold:
            raise ValueError('Results not converged')

        return self.study


class Objective():
    """Optuna objective: run a simulation and score it against a reference.

    For each trial the base FEM data is copied and the optimized material
    property is overwritten with suggested values. The simulation
    (``fistr1``) is then run, and the RMSE between the resulting and the
    reference objective variable is returned.
    """

    # Maps the ``dtype`` name of a hyperparameter setting to its converter.
    DICT_DTYPE = {
        'int': int,
        'float': float,
    }

    def __init__(self, settings):
        """Store settings and load the reference FEM data.

        The data read from ``settings.input_directory`` serves both as the
        template for new trial inputs and as the reference ("answer") in
        ``_evaluate``.
        """
        self.settings = settings
        self.base_fem_data = self._load_fem_data(self.settings.input_directory)
        return

    def __call__(self, trial):
        """Objective function for Optuna.

        Parameters
        ----------
        trial: optuna.trial.Trial

        Returns
        -------
        loss: float
            RMSE between the simulated and the reference objective variable.
        """
        output_directory = self.settings.output_directory \
            / f"trial{trial.number}"

        self._generate_simulation_input(trial, output_directory)
        self._perform_simulation(output_directory)
        objective = self._evaluate(output_directory)

        return objective

    def _load_fem_data(self, data_directory):
        """Read FrontISTR ('fistr') data from ``data_directory`` via femio."""
        fem_data = femio.FEMData.read_directory(
            'fistr', data_directory, read_npy=self.settings.read_npy,
            time_series=self.settings.time_series)
        return fem_data

    def _generate_simulation_input(self, trial, output_directory):
        """Write a simulation input with parameters suggested by ``trial``.

        Parameters
        ----------
        trial: optuna.trial.Trial
        output_directory: pathlib.Path
            The mesh is written to ``output_directory / 'mesh'``.
        """
        new_fem_data = copy.deepcopy(self.base_fem_data)
        # NOTE(review): fixed settings overridden for every trial; their
        # meaning is defined by femio/FrontISTR — confirm.
        new_fem_data.settings['frequency'] = 10
        new_fem_data.settings['beta'] = 1.
        parameters = np.array([
            trial.suggest_uniform(
                f"{self.settings.variable_name}_{i}",
                low=self.settings.parameter_range[0],
                high=self.settings.parameter_range[1])
            for i in range(self.settings.parameter_length)])
        # Overwrite the first column of the material's data[0, 0] array
        # (assumes it has ``parameter_length`` rows — TODO confirm).
        new_fem_data.materials[
            self.settings.variable_name].data[0, 0][:, 0] = parameters
        new_setting = new_fem_data.materials[self.settings.variable_name].data
        print(f"New setting: {new_setting}")
        new_fem_data.write('fistr', output_directory / 'mesh')
        return

    def _perform_simulation(self, output_directory):
        """Run ``fistr1`` inside ``output_directory``.

        stdout and stderr are redirected to ``fistr.log`` in that directory.
        The command runs without a shell (``cwd=`` instead of ``cd ... &&``),
        so directory names with spaces or shell metacharacters work.

        Raises
        ------
        subprocess.CalledProcessError
            If ``fistr1`` exits with a non-zero status.
        FileNotFoundError
            If ``fistr1`` is not on PATH.
        """
        print(f"Performing simulation in: {output_directory}")
        fistr_log_file = output_directory / 'fistr.log'
        with open(fistr_log_file, 'wb') as log:
            sp = subprocess.run(
                ['fistr1'], cwd=output_directory, stdout=log,
                stderr=subprocess.STDOUT, check=True)
        print(sp)
        return

    def _evaluate(self, output_directory):
        """Compare the simulation result with the reference; return RMSE.

        Also writes ``mesh.inp`` (UCD format) in ``output_directory``,
        holding the difference, the reference ("answer") and the result as
        nodal data.
        """
        fem_data = self._load_fem_data(output_directory)
        result = fem_data.nodal_data.get_attribute_data(
            self.settings.objective_name)
        answer = self.base_fem_data.nodal_data.get_attribute_data(
            self.settings.objective_name)
        diff = result - answer
        if len(diff.shape) == 3:
            # Presumably time-series data (step first) — store each slice
            # along the first axis as its own field. TODO confirm layout.
            diff_dict = {
                f"diff_{i}": d for i, d in enumerate(diff)}
            diff_dict.update({
                f"answer_{i}": d for i, d in enumerate(answer)})
            diff_dict.update({
                f"result_{i}": d for i, d in enumerate(result)})
        else:
            diff_dict = {'diff': diff, 'answer': answer, 'result': result}
        fem_data.nodal_data.reset()
        fem_data.nodal_data.update_data(
            fem_data.nodes.ids, diff_dict)
        fem_data.elemental_data.reset()
        fem_data.write('ucd', output_directory / 'mesh.inp')
        return self._rmse(result, answer)

    def _rmse(self, x, y):
        """Return the root-mean-square difference of ``x`` and ``y``."""
        return np.mean((x - y)**2)**.5

    def _suggest_parameter(self, trial, dict_hyperparameter):
        """Suggest one hyperparameter value from its setting dict.

        Parameters
        ----------
        trial: optuna.trial.Trial
        dict_hyperparameter: dict
            Must contain ``name`` and ``type``. Depending on ``type``:

            - ``categorical`` needs ``choices``, a list of dicts with
              ``id`` and ``value``.
            - ``discrete_uniform`` needs ``low``, ``high``, ``step`` and
              an optional ``dtype``.
            - ``uniform`` and ``int`` need ``low`` and ``high``.
            - ``loguniform`` needs ``low``, ``high`` and an optional
              ``dtype``.

            ``dtype`` ('int' or 'float') defaults to 'float'. The dict is
            not modified.

        Returns
        -------
        parameter
            The suggested value.

        Raises
        ------
        ValueError
            If ``type`` is not supported.
        """
        name = dict_hyperparameter['name']
        parameter_type = dict_hyperparameter['type']
        if parameter_type == 'categorical':
            # NOTE: Since list of list, list of dict, ... are not supported
            # by suggest_categorical, suggest the choice id instead and map
            # it back to its value.
            choices = dict_hyperparameter['choices']
            ids = [c['id'] for c in choices]
            id_to_value = {c['id']: c['value'] for c in choices}
            suggested_id = trial.suggest_categorical(name, ids)
            parameter = id_to_value[suggested_id]

        elif parameter_type == 'discrete_uniform':
            to_dtype = self.DICT_DTYPE[
                dict_hyperparameter.get('dtype', 'float')]
            parameter = to_dtype(
                trial.suggest_discrete_uniform(
                    name,
                    low=dict_hyperparameter['low'],
                    high=dict_hyperparameter['high'],
                    q=dict_hyperparameter['step']))

        elif parameter_type == 'uniform':
            parameter = trial.suggest_uniform(
                name,
                low=dict_hyperparameter['low'],
                high=dict_hyperparameter['high'])

        elif parameter_type == 'loguniform':
            to_dtype = self.DICT_DTYPE[
                dict_hyperparameter.get('dtype', 'float')]
            parameter = to_dtype(
                trial.suggest_loguniform(
                    name,
                    low=dict_hyperparameter['low'],
                    high=dict_hyperparameter['high']))

        elif parameter_type == 'int':
            parameter = trial.suggest_int(
                name,
                low=dict_hyperparameter['low'],
                high=dict_hyperparameter['high'])

        else:
            raise ValueError(f"Unsupported parameter type: {parameter_type}")

        print(f"\t{name}: {parameter}")
        return parameter

    def _create_dict_setting(self, trial):
        """Build a settings dict with hyperparameters suggested by ``trial``.

        NOTE(review): reads ``settings.optuna.hyperparameters`` and
        ``settings.optuna.setting``, which the CLI in this script does not
        define. It presumably expects a different settings object; confirm.
        """
        print('--')
        print(f"Trial: {trial.number}")
        print('Current hyperparameters:')
        hyperparameters = {
            dict_hyperparameter['name']:
            self._suggest_parameter(trial, dict_hyperparameter)
            for dict_hyperparameter
            in self.settings.optuna.hyperparameters}
        return self._generate_dict(
            self.settings.optuna.setting, hyperparameters)

    def _generate_dict(self, default_settings, dict_replace):
        """Recursively copy ``default_settings``, replacing matching strings.

        Any string equal to a key of ``dict_replace`` is replaced by the
        corresponding value. Lists and dicts are rebuilt, and numbers are
        returned as-is.

        Raises
        ------
        ValueError
            If a value of any other type is encountered.
        """
        if isinstance(default_settings, list):
            return [
                self._generate_dict(d, dict_replace) for d in default_settings]
        elif isinstance(default_settings, dict):
            return {
                key: self._generate_dict(value, dict_replace)
                for key, value in default_settings.items()}
        elif isinstance(default_settings, str):
            return dict_replace.get(default_settings, default_settings)
        elif isinstance(default_settings, (int, float)):
            return default_settings
        else:
            raise ValueError(
                f"Unknown data type: {default_settings.__class__}")


if __name__ == "__main__":
    # Entry point when executed as a script (not on import).
    main()
