from __future__ import division
from __future__ import print_function
from __future__ import absolute_import


import numpy as np
from copy import deepcopy


def none_constructor(model_init_cfg, misc=None):
    """Return a ``GT`` placeholder that wraps no environment.

    ``model_init_cfg`` and ``misc`` are accepted only so this function has
    the same call signature as other model constructors; neither is used.
    """
    placeholder = GT(None)
    return placeholder


def compile_cost(init_obs, ac_seqs, cfg, gt_dynamics, numpy_reward_function,
                 traj_id=0, cem_type=None, tf_data_dict=None):
    """Roll candidate action sequences through ``gt_dynamics`` and sum costs.

    Arguments:
        init_obs (np.ndarray): Current observation. It is broadcast to every
            candidate via ``np.tile(init_obs[None], [nopt, 1])``, so it is
            presumably a 1-D vector -- TODO confirm at the call sites.
        ac_seqs (np.ndarray): ``nopt`` candidate action sequences (one per
            leading index); each is reshaped to ``[plan_hor, dU]``.
        cfg (dict-like): Must provide ``'plan_hor'`` and ``'dU'``.
        gt_dynamics: Object whose ``predict(obs, acs)`` returns a
            ``(next_obs, _)`` pair.
        numpy_reward_function (callable): Called with a dict holding
            ``'start_state'`` and ``'action'``; its negation is the per-step
            cost.
        traj_id (int): Currently unused (it was only needed by a removed
            expert-imitation reward variant); kept for interface compatibility.
        cem_type, tf_data_dict: Must be None; only the numpy path exists.

    Returns:
        np.ndarray: ``[nopt, 1]`` array with each candidate's summed cost.
    """
    assert cem_type is None
    assert tf_data_dict is None

    nopt = ac_seqs.shape[0]
    plan_hor = cfg['plan_hor']

    # Reshape each candidate to [plan_hor, dU], then move the time axis first
    # so that ac_seqs[step] holds the step-th action of every candidate.
    ac_seqs = np.reshape(ac_seqs, [-1, plan_hor, cfg['dU']])
    ac_seqs = np.transpose(ac_seqs, [1, 0, 2])

    cur_obs = np.tile(init_obs[None], [nopt, 1])
    total_cost = np.zeros([nopt, 1])

    for step in range(plan_hor):
        # Each step must use its own action; indexing with a fixed counter
        # would replay the first action for the whole horizon.
        cur_acs = ac_seqs[step]
        next_obs, _ = gt_dynamics.predict(cur_obs, cur_acs)

        # The cost is evaluated on the state the action is taken from
        # (start_state), not on next_obs.
        delta_cost = -numpy_reward_function(
            {'start_state': cur_obs, 'action': cur_acs})
        total_cost += delta_cost.reshape(total_cost.shape)
        cur_obs = next_obs

    return total_cost


class GT:
    """ @brief: groundtruth dynamics

    Predictions come from an environment's ``fdynamics`` method rather than
    a learned network; the network-building, training and saving methods are
    no-ops.
    """

    def __init__(self, params):
        """Initializes a class instance.

        Arguments:
            params (DotMap/None): Model parameters, or None to build a
                placeholder that wraps no environment. When not None, the
                following are read:
                .model_dir (str/None): (optional) Stored as ``self.model_dir``.
                    Defaults to None.
                .misc: Stored as ``self._misc_args``.
                .opt_cfg.obs_cost_fn, .opt_cfg.ac_cost_fn (callables): Their
                    negated sum forms the numpy reward function.
        """
        # Instance variables
        self.finalized = False
        self.layers, self.decays, self.optvars, self.nonoptvars = [], [], [], []
        self.scaler = None

        # Training objects (never populated by this class)
        self.optimizer = None
        self.sy_train_in, self.sy_train_targ = None, None
        self.train_op, self.mse_loss = None, None

        # Prediction objects
        self.sy_pred_in2d, self.sy_pred_mean2d_fac = None, None
        self.sy_pred_mean2d, self.sy_pred_var2d = None, None
        self.sy_pred_in3d, self.sy_pred_mean3d_fac = None, None
        self.num_nets = 1

        # Defined unconditionally so that a placeholder built with
        # params=None still exposes the same attributes.
        self.name = 'non_tensorflow'
        self.model_dir = None
        self._misc_args = None
        self._dynamics_env = None
        self._numpy_reward_function = None

        if params is None:
            return

        # the groundtruth dynamics environment
        self.model_dir = params.get('model_dir', None)
        self._misc_args = params.misc

        # TODO: switch to the humanoid imitation environment, e.g.
        #     misc_info = {'reset_type': 'gym', 'groundtruth_model': True,
        #                  'expert_amc_dir': params.il_cfg.expert_amc_dir,
        #                  'add_timestep_into_ob': True}
        #     from dmbrl.env import im_dmhumanoid
        #     self._dynamics_env = im_dmhumanoid.IMDMHumanoid(
        #         'cmu-humanoid-imitation', 1234, misc_info
        #     )
        #     self._numpy_reward_function = im_dmhumanoid.numpy_reward_function
        # or reuse params.env / params.env.reward directly.

        def reward_function(data_dict):
            # Reward is the negated sum of the configured observation and
            # action costs.
            action = data_dict['action']
            obs = data_dict['start_state']
            return - params.opt_cfg.obs_cost_fn(obs) - params.opt_cfg.ac_cost_fn(action)

        # The environment is currently hard-coded to the gym pendulum.
        from mbbl.env.gym_env import pendulum
        self._dynamics_env = pendulum.env(env_name='gym_pendulum', rand_seed=1234,
                                          misc_info={'reset_type': 'gym'})
        self._numpy_reward_function = reward_function

        # env should be reset in the Agent module
        self._dynamics_env.reset()

    def expert_obs(self, traj_id):
        """Return ``expert_obs(traj_id)`` from the wrapped environment.

        NOTE(review): requires an env that implements ``expert_obs``; confirm
        the currently constructed env provides it.
        """
        return self._dynamics_env.expert_obs(traj_id)

    @property
    def is_probabilistic(self):
        """True when more than one network is configured (``__init__`` sets 1)."""
        return self.num_nets > 1

    @property
    def is_tf_model(self):
        """Always False: this model does not build a TensorFlow graph."""
        return False

    @property
    def sess(self):
        """No TensorFlow session is associated with this model."""
        return None

    ###################################
    # Network Structure Setup Methods #
    ###################################

    def add(self, layer):
        """No-op: there are no layers to add."""
        pass

    def pop(self):
        """No-op: there are no layers to remove."""
        pass

    def finalize(self, optimizer, optimizer_args=None, *args, **kwargs):
        """Mark the model as finalized; all arguments are ignored."""
        self.finalized = True

    #################
    # Model Methods #
    #################

    def train(self, inputs, targets, batch_size=32, epochs=100,
              hide_progress=False, holdout_ratio=0.0, max_logging=5000):
        """No-op: ground-truth dynamics need no training."""
        pass

    def predict(self, observations, actions):
        """Step each (observation, action) row through the env's fdynamics.

        Arguments:
            observations (np.ndarray): One start state per leading index.
            actions (np.ndarray): Indexed in parallel with ``observations``.

        Returns:
            (np.ndarray, None): Stacked next states (one per input row) and
            None in place of a predictive variance.
        """
        fdynamics = self._dynamics_env.fdynamics
        end_state = [
            fdynamics({'start_state': observations[i], 'action': actions[i]})
            for i in range(observations.shape[0])
        ]
        return np.array(end_state), None

    def save(self, savedir=None):
        """No-op: there is nothing to save."""
        pass

    def _load_structure(self):
        """No-op: there is no stored structure to load."""
        pass

    #######################
    # Compilation methods #
    #######################

    def _compile_outputs(self, inputs):
        """No-op: no symbolic outputs are built."""
        return None

    def _compile_losses(self, inputs, targets):
        """No-op: no symbolic losses are built."""
        return None
