import numpy as np

class PredictEnv:
    """Simulated environment backed by a learned ensemble dynamics model.

    The wrapped ``model`` must expose ``predict(inputs)`` (and, for the
    non-``'pytorch'`` model type, ``predict(inputs, factored=True)`` plus
    ``random_inds(batch_size)``) returning per-member means and variances of
    shape ``[num_networks, batch_size, 1 + obs_dim]``.  The first output
    column is treated as the reward and the remaining columns as the change
    in observation.
    """

    def __init__(self, model, env_name, model_type, logger, penalty_coeff=1.0):
        """Store the model and configuration.

        Args:
            model: Ensemble dynamics model (see class docstring).
            env_name: Environment name; selects the termination rule.
            model_type: ``'pytorch'`` or anything else (alternative backend).
            logger: Logger object; stored but not used by this class.
            penalty_coeff: Weight of the uncertainty penalty subtracted from
                predicted rewards.  ``0`` disables the penalty.
        """
        self.model = model
        self.env_name = env_name
        self.model_type = model_type
        self.penalty_coeff = penalty_coeff
        self.logger = logger

    def _termination_fn(self, env_name, obs, act, next_obs):
        """Return a ``[batch_size, 1]`` boolean array of terminal flags.

        Args:
            env_name: Environment name used to pick the termination rule.
            obs: Current observations, 2-D ``[batch_size, obs_dim]``.
            act: Actions, 2-D ``[batch_size, act_dim]``.
            next_obs: Predicted next observations, 2-D.

        Environments without a dedicated rule never terminate.
        """
        assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

        if env_name.lower().startswith("hopper"):
            height = next_obs[:, 0]
            angle = next_obs[:, 1]
            # The magnitude bound must apply to |state|, not to the boolean
            # mask, otherwise large negative values are never caught.
            not_done = (
                np.isfinite(next_obs).all(axis=-1)
                & (np.abs(next_obs[:, 1:]) < 100).all(axis=-1)
                & (height > .7)
                & (np.abs(angle) < .2)
            )
        elif env_name == "Walker2d-v2":
            # NOTE(review): exact-name match, unlike the case-insensitive
            # prefix match used for hopper -- confirm whether other walker2d
            # variants should share this rule.
            height = next_obs[:, 0]
            angle = next_obs[:, 1]
            not_done = (
                (height > 0.8)
                & (height < 2.0)
                & (angle > -1.0)
                & (angle < 1.0)
            )
        else:
            # halfcheetah and any unrecognised environment: never terminal.
            not_done = np.ones(len(obs), dtype=bool)

        return ~not_done[:, None]

    def _get_logprob(self, x, means, variances):
        """Compute the ensemble log-likelihood of ``x`` and mean disagreement.

        Args:
            x: Samples, ``[batch_size, dim]``.
            means: Per-member Gaussian means, ``[num_networks, batch_size, dim]``.
            variances: Per-member diagonal variances, same shape as ``means``.

        Returns:
            Tuple ``(log_prob, stds)``: the log of the summed per-member
            densities (``[batch_size]``) and the standard deviation of the
            means across members averaged over the last axis (``[batch_size]``).
        """
        k = x.shape[-1]

        ## [ num_networks, batch_size ]
        log_prob = -1 / 2 * (
            k * np.log(2 * np.pi)
            + np.log(variances).sum(-1)
            + (np.power(x - means, 2) / variances).sum(-1)
        )

        ## [ batch_size ]
        prob = np.exp(log_prob).sum(0)

        ## [ batch_size ]
        log_prob = np.log(prob)

        stds = np.std(means, 0).mean(-1)

        return log_prob, stds

    def reset(self):
        """Sample an initial observation (hopper environments only).

        Returns:
            1-D array: the perturbed initial positions without the first
            coordinate, followed by the perturbed velocities clipped to
            ``[-10, 10]``.

        Raises:
            NotImplementedError: For any non-hopper environment name.
        """
        if not self.env_name.lower().startswith("hopper"):
            raise NotImplementedError

        qpos = np.array([0., 1.25, 0., 0., 0., 0.]) + np.random.uniform(low=-.005, high=.005, size=6)
        qvel = np.zeros(6) + np.random.uniform(low=-.005, high=.005, size=6)
        # np.concatenate takes a single sequence of arrays.
        return np.concatenate((qpos[1:], np.clip(qvel, -10, 10)))

    def step(self, obs, act, qdiff=None, total_step=None, deterministic=False, return_unpenalized=False):
        """Predict one transition with the learned model.

        Args:
            obs: Observation(s), ``[obs_dim]`` or ``[batch_size, obs_dim]``.
            act: Action(s), ``[act_dim]`` or ``[batch_size, act_dim]``.
            qdiff: Unused; kept for interface compatibility.
            total_step: Unused; kept for interface compatibility.
            deterministic: If true, use the predicted means instead of
                sampling from the predicted Gaussians.
            return_unpenalized: If true, also return the raw rewards.

        Returns:
            ``(next_obs, penalized_rewards, terminals, info)``, or
            ``(next_obs, penalized_rewards, unpenalized_rewards, terminals,
            info)`` when ``return_unpenalized`` is set.  ``info`` holds the
            selected member's ``'mean'`` and ``'std'`` (with the terminal
            column inserted after the reward), plus ``'log_prob'`` and
            ``'dev'`` from ``_get_logprob``.  For 1-D inputs, the leading
            batch axis is stripped from everything except ``'log_prob'`` and
            ``'dev'``.
        """
        return_single = len(obs.shape) == 1
        if return_single:
            obs = obs[None]
            act = act[None]

        inputs = np.concatenate((obs, act), axis=-1)
        if self.model_type == 'pytorch':
            ensemble_model_means, ensemble_model_vars = self.model.predict(inputs)
        else:
            ensemble_model_means, ensemble_model_vars = self.model.predict(inputs, factored=True)
        # The model predicts observation deltas; column 0 is the reward.
        # Note: this updates the array returned by the model in place.
        ensemble_model_means[:, :, 1:] += obs
        ensemble_model_stds = np.sqrt(ensemble_model_vars)

        if deterministic:
            ensemble_samples = ensemble_model_means
        else:
            noise = np.random.normal(size=ensemble_model_means.shape)
            ensemble_samples = ensemble_model_means + noise * ensemble_model_stds

        # Pick one ensemble member per batch element.
        _, batch_size, _ = ensemble_model_means.shape
        if self.model_type == 'pytorch':
            model_idxes = np.random.choice(self.model.elite_model_idxes, size=batch_size)
        else:
            model_idxes = self.model.random_inds(batch_size)
        batch_idxes = np.arange(0, batch_size)

        samples = ensemble_samples[model_idxes, batch_idxes]
        model_means = ensemble_model_means[model_idxes, batch_idxes]
        model_stds = ensemble_model_stds[model_idxes, batch_idxes]

        log_prob, dev = self._get_logprob(samples, ensemble_model_means, ensemble_model_vars)

        rewards, next_obs = samples[:, :1], samples[:, 1:]
        terminals = self._termination_fn(self.env_name, obs, act, next_obs)

        return_means = np.concatenate((model_means[:, :1], terminals, model_means[:, 1:]), axis=-1)
        return_stds = np.concatenate((model_stds[:, :1], np.zeros((batch_size, 1)), model_stds[:, 1:]), axis=-1)

        unpenalized_rewards = rewards
        if self.penalty_coeff != 0:
            # Uncertainty penalty: largest L2 norm of the predicted std
            # across ensemble members, per batch element.
            penalty = np.amax(np.linalg.norm(ensemble_model_stds, axis=2), axis=0)
            penalty = np.expand_dims(penalty, 1)
            assert penalty.shape == rewards.shape
            penalized_rewards = rewards - self.penalty_coeff * penalty
        else:
            penalized_rewards = rewards

        if return_single:
            next_obs = next_obs[0]
            return_means = return_means[0]
            return_stds = return_stds[0]
            unpenalized_rewards = unpenalized_rewards[0]
            penalized_rewards = penalized_rewards[0]
            terminals = terminals[0]

        info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev}
        if return_unpenalized:
            return next_obs, penalized_rewards, unpenalized_rewards, terminals, info
        return next_obs, penalized_rewards, terminals, info
