"""Experiment that shows arbitrary off-policy behavior of TD."""

import argparse
from distutils.util import strtobool
import json
import os
from pathlib import Path
import pickle

import numpy as np
import ray
from ray import tune
import tensorflow as tf
import tensorflow_probability as tfp
from gym.envs.classic_control import MountainCarEnv
import matplotlib.pyplot as plt


from policy_evaluation.utils import PROJECT_ROOT, get_git_rev
from policy_evaluation import environments
from policy_evaluation import algorithms
from policy_evaluation import policies
from policy_evaluation import tasks
from policy_evaluation import value_functions

from .experiment_runner import (
    ExperimentRunner, datetime_stamp, set_gpu_memory_growth)


# Presumably enables TensorFlow GPU memory growth (helper lives in
# experiment_runner) — TODO confirm.
set_gpu_memory_growth(True)

tfd = tfp.distributions


# Discount factor shared by every algorithm preset and by the reference value
# function estimate.
DISCOUNT = 0.98
CURRENT_FILE_PATH = Path(__file__)
# Cache for generated datasets and value-function estimates:
# <this file's directory>/data/<this file's stem>/
CACHE_DIR = CURRENT_FILE_PATH.parent.joinpath('data', CURRENT_FILE_PATH.stem)


# Static experiment settings; `train` adds 'total_samples' and 'epoch_length'.
experiment_params = dict(
    n_episodes=1,
    episodic=False,
    name="mountain_car",
    title="Mountain Car",
    criterion="RMSE",
)

# Per-run settings; `train` adds 'run_eagerly' and the grid-searched 'seed'.
run_params = dict(
    num_samples=5,
    verbose=100,
)


environment_params = dict(
    class_name='MountainCar',
    config={},
)


def normalize_mountain_car_states(states, environment):
    """Rescale observations from the environment's bounds into [-1, 1].

    Args:
        states: Observations to rescale.
        environment: Environment whose ``observation_space.low`` / ``.high``
            give the original value range.

    Returns:
        The rescaled observations. As a sanity check, asserts that every
        value lies in [-1, 1] and that the overall mean is within 0.1 of 0.
    """
    space = environment.observation_space
    rescaled = environments.utils.rescale_values(
        states, space.low, space.high, new_low=-1.0, new_high=1.0)

    # Sanity checks on the rescaled data.
    assert np.all(np.abs(rescaled) <= 1.0)
    np.testing.assert_allclose(np.mean(rescaled), 0, atol=1e-1)

    return rescaled


def normalize_mountain_car_rewards(rewards):
    """Reward normalization hook; currently returns ``rewards`` unchanged.

    The very same object is returned, no copy is made.
    """
    return rewards


class CustomMountainCarEnv(MountainCarEnv):
    """``MountainCarEnv`` that can sample non-goal states and reset to them."""

    def uniform_random_state(self):
        """Draw a state from the observation space that is not a goal state.

        Rejection sampling: redraws while the sampled position (index 0) is
        at or beyond ``self.goal_position``.
        """
        while True:
            candidate = self.observation_space.sample()
            if not self.goal_position <= candidate[0]:
                return candidate

    def reset(self, state=None):
        """Reset the environment, optionally to an explicit ``state``.

        Args:
            state: State to start from. When ``None``, defers to the parent
                class's ``reset``.

        Returns:
            The initial observation; a copy of the internal state when
            ``state`` is given.
        """
        if state is not None:
            self.state = np.array(state)
            return self.state.copy()
        return super().reset()


class MountainCarPolicy(policies.BasePolicy):
    """Deterministic policy choosing action ``sign(inputs[..., 1]) + 1``.

    The action is 0, 1 or 2 for a negative, zero or positive second
    observation component. In gym's MountainCar that component is the
    velocity, and actions 0/1/2 mean push left / no push / push right.
    """

    def actions(self, inputs):
        """Return the int32 action for each input, with shape ``(..., 1)``."""
        return tf.cast(tf.sign(inputs[..., 1:2]) + 1.0, tf.int32)

    def log_probs(self, *args, **kwargs):
        """Return the elementwise log of :meth:`probs`."""
        return tf.math.log(self.probs(*args, **kwargs))

    def probs(self, inputs, actions):
        """Return the probability of ``actions`` under this policy.

        Entries are 1.0 where ``actions`` equals :meth:`actions` and 0.0
        elsewhere. The method then asserts that every entry is 1.0, i.e. the
        policy must only be queried on its own actions.
        """
        expected_actions = self.actions(inputs)
        # Compare explicitly and on a common dtype: tensor `==` depends on
        # TF's equality mode and misbehaves when dtypes differ (e.g. int64
        # actions coming from a dataset).
        matches = tf.equal(
            tf.cast(actions, expected_actions.dtype), expected_actions)
        probs = tf.where(matches, 1.0, 0.0)
        tf.debugging.assert_equal(probs, 1.0)
        return probs


class MountainCarExperimentRunner(ExperimentRunner):
    """Ray Tune trainable for Mountain Car value prediction.

    ``_setup`` builds the value-function approximator(s) and the algorithm
    named by ``variant['algorithm_params']['class_name']``, and wraps the
    algorithm in a ``ValuePredictionTask``. Every ``_train`` step records
    the algorithm's predictions on the reference states. When the result is
    marked done, they are plotted to ``figures/values.pdf`` and pickled to
    ``figures/values.pkl`` under the current working directory.
    """

    @staticmethod
    def _build_value_function(value_function_params):
        """Return a ``StateValueFunction`` backed by a freshly built MLP.

        The MLP has one ``Dense`` layer per entry of
        ``value_function_params['hidden_layer_sizes']``, each using
        ``value_function_params['activation']``, followed by a linear
        ``Dense(1)`` output layer.
        """
        layers = [
            tf.keras.layers.Dense(
                hidden_layer_size,
                activation=value_function_params['activation'])
            for hidden_layer_size
            in value_function_params['hidden_layer_sizes']
        ]
        layers.append(tf.keras.layers.Dense(1, activation='linear'))
        return value_functions.StateValueFunction(
            model=tf.keras.Sequential(layers))

    def _setup(self, variant):
        """Load this trial's data and build the algorithm and task.

        Args:
            variant: Trial config as assembled by ``train``. The built value
                functions and the algorithm are written into
                ``variant['algorithm_params']['config']`` and
                ``variant['task_params']['config']`` in place.

        Raises:
            ValueError: If the algorithm class name is not supported.
        """
        super(MountainCarExperimentRunner, self)._setup(variant)

        seed = variant['run_params']['seed']
        self.dataset = ray.get(variant['dataset_object_ids'][str(seed)])
        self.true_value_states, self.true_values = ray.get(
            variant['value_function_object_id'])

        algorithm_params = variant['algorithm_params']
        value_function_params = variant['value_function_params']
        class_name = algorithm_params['class_name']

        if class_name == 'BBORandomizedPrior':
            V_omega = self._build_value_function(value_function_params)
            V_phi = self._build_value_function(value_function_params)
            algorithm_params['config'].update({
                'V_omega': V_omega,
                'V_phi': V_phi,
            })
        elif class_name == 'TD0':
            V = self._build_value_function(value_function_params)
            algorithm_params['config'].update({'V': V})
        elif class_name in ('TDC', 'GTD2'):
            V_theta = self._build_value_function(value_function_params)
            # Call the model once on a single state so its weights are
            # built before the algorithm is constructed.
            V_theta.values(self.dataset['samples']['state_0'][:1, ...])
            algorithm_params['config'].update({'V_theta': V_theta})
        else:
            raise ValueError(class_name)

        self.algorithm = getattr(algorithms, class_name)(
            **algorithm_params['config'])

        task_params = variant['task_params']
        assert task_params['class_name'] == 'ValuePredictionTask', (
            task_params['class_name'])
        task_params['config'].update({'algorithm': self.algorithm})

        self.task = tasks.ValuePredictionTask(**task_params['config'])
        self.task.true_value_states = self.true_value_states
        self.task.true_values = self.true_values
        self.task.dataset = self.dataset

        # Predictions on `true_value_states`, one entry per `_train` call.
        self._all_values = []

    def _train(self, *args, **kwargs):
        """Run one training iteration and record the current predictions.

        When the result is marked done, the recorded predictions are also
        plotted and pickled.
        """
        result = super(MountainCarExperimentRunner, self)._train(
            *args, **kwargs)
        # NOTE(review): assumes every algorithm exposes its value function
        # as `V` (TDC/GTD2 are constructed with `V_theta`) — confirm.
        values = self.algorithm.V.values(self.true_value_states).numpy()
        self._all_values.append(values)
        if result[ray.tune.result.DONE]:
            self._plot_values()
            self._save_values()
        return result

    def _plot_values(self):
        """Contour-plot the true values and snapshots of the predictions.

        The first subplot shows ``self.true_values``. The others show up to
        24 entries of ``self._all_values`` at evenly spaced indices. Each
        subplot gets its own colorbar because every ``contourf`` picks its
        own levels. The figure is saved to ``<cwd>/figures/values.pdf``.
        """
        num_saved_values = len(self._all_values)

        max_subplot_columns = 5
        max_subplot_rows = 5

        num_subplots = min(
            1 + num_saved_values,
            max_subplot_columns * max_subplot_rows)

        num_subplot_rows = int(np.ceil(num_subplots / max_subplot_columns))
        num_subplot_columns = min(num_subplots, max_subplot_columns)

        default_figsize = plt.rcParams.get('figure.figsize')
        figsize = np.array(
            (num_subplot_columns, num_subplot_rows)) * default_figsize[0]
        # squeeze=False always yields a 2-D array of axes; with a single
        # subplot plt.subplots would otherwise return a bare Axes object.
        figure, axes = plt.subplots(
            num_subplot_rows, num_subplot_columns, figsize=figsize,
            squeeze=False)

        num_states = self.true_values.size
        states_per_side = int(np.sqrt(num_states))
        assert states_per_side ** 2 == num_states
        grid_shape = [states_per_side] * 2

        snapshot_indices = np.linspace(
            0, num_saved_values - 1, num_subplots - 1, dtype=int)
        values_to_plot = [self.true_values] + [
            self._all_values[i] for i in snapshot_indices]

        flat_axes = axes.flatten()
        for axis, values in zip(flat_axes, values_to_plot):
            contour = axis.contourf(
                self.true_value_states[..., 0].reshape(grid_shape),
                self.true_value_states[..., 1].reshape(grid_shape),
                values.reshape(grid_shape),
                levels=50)
            # Levels differ per subplot, so one shared colorbar would
            # misrepresent all but one of them.
            figure.colorbar(contour, ax=axis)

        # Hide unused cells in the last row of the grid.
        for axis in flat_axes[len(values_to_plot):]:
            axis.set_visible(False)

        figure_dir = os.path.join(os.getcwd(), 'figures')
        os.makedirs(figure_dir, exist_ok=True)
        figure_path = os.path.join(figure_dir, 'values.pdf')

        figure.savefig(figure_path)
        figure.clf()
        plt.close(figure)

    def _save_values(self):
        """Pickle the reference states/values and all recorded predictions.

        Written to ``<cwd>/figures/values.pkl``.
        """
        save_dir = os.path.join(os.getcwd(), 'figures')
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, 'values.pkl')
        with open(save_path, 'wb') as f:
            pickle.dump({
                'true_values': self.true_values,
                'true_value_states': self.true_value_states,
                '_all_values': self._all_values,
            }, f)


def _algorithm_spec(class_name, **config):
    """Build an algorithm preset whose config starts with ``gamma=DISCOUNT``."""
    return {'class_name': class_name, 'config': {'gamma': DISCOUNT, **config}}


# Algorithm presets, selected by key via `train(algorithm=...)`.
algorithm_params = {
    'bbo-rp': _algorithm_spec(
        'BBORandomizedPrior',
        num_phi_steps=10,
        phi_lr=3e-3,
        omega_lr=1e-2,
        prior_loc=0.0,
        prior_scale=3.0,
        prior_loss_weight=1e-1,
    ),
    'td0': _algorithm_spec('TD0', alpha=3e-4),
    'tdc': _algorithm_spec('TDC', alpha=1e-5, beta=1e-2),
    'gtd2': _algorithm_spec('GTD2', alpha=1e-5, beta=3e-5),
}


# Value-function MLP: a single hidden layer of 256 units. Tune resolves the
# activation per trial from the chosen algorithm: 'tanh' for TDC and GTD2,
# 'relu' for everything else.
value_function_params = {
    'hidden_layer_sizes': (256, ),
    'activation': tune.sample_from(
        lambda spec: (
            'tanh'
            if spec.get('config', spec)['algorithm_params']['class_name']
            in ('TDC', 'GTD2')
            else 'relu')),
}


def train(num_samples,
          num_steps,
          epoch_length,
          debug,
          use_wandb,
          algorithm='bbo-rp',
          experiment_name=None):
    """Prepare datasets and reference values, then launch the Tune runs.

    Args:
        num_samples: Number of independent trials, between 1 and 25. Each
            trial gets a distinct seed drawn without replacement from
            ``range(25)``, and that seed's dataset.
        num_steps: Forwarded as ``experiment_params['total_samples']``.
        epoch_length: Forwarded as ``experiment_params['epoch_length']``.
        debug: Starts Ray in local mode, sets ``run_params['run_eagerly']``
            and writes results under ``<PROJECT_ROOT>/data/debug``.
        use_wandb: Currently unused.
        algorithm: Key of a preset in the module-level ``algorithm_params``.
        experiment_name: Optional suffix for the timestamped experiment name.

    Raises:
        ValueError: If ``num_samples`` is outside [1, 25] or ``algorithm``
            is not a known preset. Both checks run before Ray is started.
    """
    # Validate cheaply before starting Ray or generating any data.
    if num_samples < 1 or 25 < num_samples:
        raise ValueError("num_samples must be between 1 and 25.")
    if algorithm not in algorithm_params:
        raise ValueError(
            f"Unknown algorithm {algorithm!r}; "
            f"expected one of {sorted(algorithm_params)}.")

    ray.init(resources={}, local_mode=debug, include_webui=False)

    seeds = np.sort(np.random.choice(25, num_samples, replace=False)).tolist()

    # Per-call copy instead of mutating the module-level `run_params`.
    trial_run_params = {
        **run_params,
        'run_eagerly': debug,
        'seed': tune.grid_search(seeds),
    }

    policy = MountainCarPolicy(input_shapes=[2], output_shape=[1])
    behavior_policy = target_policy = policy

    def load_or_create_cached(cache_path, create):
        """Return the object pickled at ``cache_path``, creating it if absent.

        On a miss, ``create()`` is called and its result pickled to
        ``cache_path``. If the write fails or is interrupted (e.g. Ctrl-C),
        the partial file is removed so a truncated pickle is not loaded on
        the next run.
        """
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                return pickle.load(f)

        value = create()
        os.makedirs(cache_path.parent, exist_ok=True)
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(value, f)
        except BaseException:
            if os.path.exists(cache_path):
                os.remove(cache_path)
            raise
        return value

    def generate_dataset(seed,
                         behavior_policy,
                         target_policy,
                         environment_params,
                         *args,
                         **kwargs):
        """Load or generate a dataset, then normalize states and rewards.

        The raw dataset is cached under ``CACHE_DIR/datasets`` with a key
        built from the seed, environment params, both policy configs and
        ``kwargs``. Normalization is applied after loading.
        """
        environment_params = environment_params.copy()
        environment = CustomMountainCarEnv(**environment_params['config'])

        # NOTE(review): `seed` only enters the cache key; it is not passed
        # to `generate_dataset` below — confirm generation is seeded.
        dataset_key = json.dumps({
            'seed': seed,
            'environment': environment_params,
            'behavior_policy': behavior_policy.get_config(),
            'target_policy': target_policy.get_config(),
            **kwargs,
        }, sort_keys=True, separators=',:')

        # NOTE(review): the raw JSON key is used as a file name, so very
        # long configs or '/' in values would break it. Left unchanged to
        # keep existing cache entries valid.
        cache_path = CACHE_DIR / 'datasets' / f'dataset-{dataset_key}'
        dataset = load_or_create_cached(
            cache_path,
            lambda: environments.utils.generate_dataset(
                environment,
                behavior_policy,
                target_policy,
                *args,
                **kwargs))

        samples = dataset['samples']
        samples['state_0'] = normalize_mountain_car_states(
            samples['state_0'], environment)
        samples['state_1'] = normalize_mountain_car_states(
            samples['state_1'], environment)
        samples['reward'] = normalize_mountain_car_rewards(
            samples['reward'])

        return dataset

    datasets = {
        seed: generate_dataset(
            seed,
            behavior_policy,
            target_policy,
            environment_params,
            num_samples=20000,  # num_steps,
            independent_samples=True)
        for seed in seeds
    }

    dataset_object_ids = {
        str(seed): ray.put(dataset, weakref=False)
        for seed, dataset in datasets.items()
    }

    def compute_value_function(policy, environment_params):
        """Load or estimate ``policy``'s value function from rollouts.

        Returns ``(states, values)`` with states normalized to [-1, 1].
        Results are cached under ``CACHE_DIR/value_functions``.
        """
        environment_params = environment_params.copy()
        environment = CustomMountainCarEnv(**environment_params['config'])

        discount = DISCOUNT
        num_rollouts = 1000

        # The estimate depends only on `policy`, so both policy fields of
        # the key come from it. The field names are kept so the key format
        # of existing cache entries is unchanged.
        # NOTE(review): `max_rollout_length` is not part of the key.
        policy_config = policy.get_config()
        value_function_key = json.dumps({
            'environment': environment_params,
            'behavior_policy': policy_config,
            'target_policy': policy_config,
            'discount': discount,
            'num_rollouts': num_rollouts,
        }, sort_keys=True, separators=',:')

        cache_path = (
            CACHE_DIR
            / 'value_functions'
            / f'value_function-{value_function_key}')
        value_function = load_or_create_cached(
            cache_path,
            lambda: environments.utils.estimate_value_function(
                environment,
                policy,
                discount=discount,
                num_rollouts=num_rollouts,
                max_rollout_length=1000))

        states, values = value_function
        states = normalize_mountain_car_states(states, environment)
        values = normalize_mountain_car_rewards(values)

        return states, values

    value_function = compute_value_function(target_policy, environment_params)
    value_function_object_id = ray.put(value_function, weakref=False)

    experiment_config = {
        'dataset_object_ids': dataset_object_ids,
        'value_function_object_id': value_function_object_id,
        'algorithm_params': algorithm_params[algorithm],
        'value_function_params': value_function_params,
        'experiment_params': {
            'total_samples': num_steps,
            'epoch_length': epoch_length,
            **experiment_params,
        },
        'run_params': trial_run_params,
        'environment_params': environment_params,
        'task_params': {
            'class_name': 'ValuePredictionTask',
            'config': {
                'criteria': ['RMSE', 'MSE', ],
                'batch_size': 512,
            },
        },
        'git_rev': get_git_rev(PROJECT_ROOT),
    }

    if experiment_name is not None:
        experiment_name = '-'.join((datetime_stamp(), experiment_name))
    else:
        experiment_name = datetime_stamp()

    local_dir = os.path.join(
        PROJECT_ROOT, 'data', ('debug' if debug else ''), 'mountain_car')

    tune.run(
        MountainCarExperimentRunner,
        name=experiment_name,
        config=experiment_config,
        resources_per_trial={
            'cpu': 2,
            'gpu': 0.0,
        },
        local_dir=local_dir,
        num_samples=1,  # Use seeds to generate multiple samples
        checkpoint_freq=0,
        checkpoint_at_end=False,
        max_failures=0,
        restore=None,
        with_server=False,
        scheduler=None,
        loggers=(ray.tune.logger.CSVLogger, ray.tune.logger.JsonLogger),
        reuse_actors=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mode',
        type=str,
        choices=('train', 'visualize'),
        default='train')

    parser.add_argument('--num-samples', type=int, default=25)
    parser.add_argument('--num-steps', type=int, default=5000)
    parser.add_argument('--epoch-length', type=int, default=25)
    # Reserved for the not-yet-implemented 'visualize' mode.
    parser.add_argument('--experiment-path', type=str, default=None)
    parser.add_argument('--experiment-name', type=str, default=None)
    # Reject unknown presets at parse time instead of deep inside `train`.
    parser.add_argument(
        '--algorithm',
        type=str,
        choices=tuple(algorithm_params),
        default='bbo-rp')
    parser.add_argument(
        '--use-wandb',
        type=lambda x: bool(strtobool(x)),
        nargs='?',
        const=True,
        default=False)
    parser.add_argument(
        '--debug',
        type=lambda x: bool(strtobool(x)),
        nargs='?',
        const=True,
        default=False,
        help="Whether or not to execute sequentially to allow breakpoints.")

    args = parser.parse_args()
    if args.mode == 'train':
        train(num_samples=args.num_samples,
              num_steps=args.num_steps,
              epoch_length=args.epoch_length,
              debug=args.debug,
              use_wandb=args.use_wandb,
              algorithm=args.algorithm,
              experiment_name=args.experiment_name)
    elif args.mode == 'visualize':
        # Visualization is not implemented yet.
        raise NotImplementedError(args.mode)
