import numpy as np

from collections import OrderedDict
import copy

from lfrl.policies.base.base import ExplorationPolicy
from lfrl.samplers import model_policy_rollout_torch_online


class PolicyEnsemble(ExplorationPolicy):
    """Exploration policy that acts with the best-scoring member of an ensemble.

    Every ``rollout_every`` calls to :meth:`get_action`, each candidate (the
    policy currently in use followed by every ensemble member) is handed to
    ``model_policy_rollout_torch_online`` together with the dynamics model.
    The candidate with the highest mean of the values that helper returns
    becomes the acting policy until the next re-evaluation.
    """

    def __init__(
            self,
            dynamics_model,
            policies,
            rollouts_per_policy=50,
            horizon=100,
            discount=.99,
            rollout_every=25,
    ):
        """Store the ensemble configuration and start with the first policy.

        Args:
            dynamics_model: Model passed as the first argument to
                ``model_policy_rollout_torch_online``.
            policies: Non-empty, indexable collection of candidate policies.
                The collection is stored by reference, not copied.
            rollouts_per_policy: Number of copies of the current observation
                that are fed to each rollout call.
            horizon: Forwarded to the rollout helper as ``max_path_length``.
            discount: Forwarded to the rollout helper as ``gamma``.
            rollout_every: The candidates are re-scored whenever the number of
                ``get_action`` calls made so far is a multiple of this value.

        Raises:
            ValueError: If ``policies`` is empty.
        """
        if len(policies) == 0:
            raise ValueError('PolicyEnsemble requires at least one policy')

        self.dynamics_model = dynamics_model
        self.policies = policies
        self.rollouts_per_policy = rollouts_per_policy
        self.horizon = horizon
        self.discount = discount
        self.rollout_every = rollout_every

        # Until the first re-evaluation the acting policy *is* the first
        # ensemble member (same object, not a copy).
        self._best_policy = policies[0]

        self.eval_statistics = OrderedDict()
        self._n_timesteps = 0

    def get_action(self, observation):
        """Return the acting policy's deterministic action for ``observation``.

        Args:
            observation: Current observation. It must support ``[None]``
                indexing (e.g. a NumPy array) because it is tiled into a batch
                on re-evaluation steps.

        Returns:
            A ``(action, info)`` tuple where ``action`` is the first element
            returned by the acting policy's ``get_action`` and ``info`` is an
            empty dict.
        """
        # The counter starts at zero, so the very first call re-evaluates.
        if self._n_timesteps % self.rollout_every == 0:
            self._select_best_policy(observation)
        self._n_timesteps += 1
        action, *_ = self._best_policy.get_action(observation, deterministic=True)
        return action, dict()

    def _select_best_policy(self, observation):
        """Score every candidate from ``observation`` and keep the best one."""
        # One identical start observation per rollout, stacked on a new axis 0.
        obs_copied = np.repeat(observation[None], self.rollouts_per_policy, axis=0)

        best_returns = -float('inf')
        # Work on a local so the acting policy is replaced only once every
        # candidate has been scored; a failing rollout leaves it untouched.
        best_policy = self._best_policy

        # Candidate 0 is the incumbent; candidates 1..N are the ensemble
        # members in order.  Unpacking (rather than list concatenation) lets
        # ``self.policies`` be any iterable, not just a list.
        for i, policy in enumerate([self._best_policy, *self.policies]):
            # NOTE(review): presumably one discounted return per rollout --
            # confirm against lfrl.samplers.
            returns = model_policy_rollout_torch_online(
                self.dynamics_model,
                policy,
                obs_copied,
                max_path_length=self.horizon,
                gamma=self.discount,
            )
            cur_returns = returns.mean()
            self.eval_statistics['Policy %d Returns' % i] = cur_returns
            # Strict comparison: on a tie the earlier candidate (and therefore
            # the incumbent) is kept.
            if cur_returns > best_returns:
                best_returns = cur_returns
                best_policy = policy

        # Snapshot the winner so the acting policy is a separate object from
        # the ensemble members (presumably because those keep being trained).
        # An incumbent that is already a private snapshot needs no new copy.
        if self._is_ensemble_member(best_policy):
            best_policy = copy.deepcopy(best_policy)
        self._best_policy = best_policy

        print(self.eval_statistics)

    def _is_ensemble_member(self, policy):
        """Return True if ``policy`` is (by identity) one of ``self.policies``."""
        return any(policy is member for member in self.policies)

    def _managed_policies(self):
        """Return the ensemble members plus the acting snapshot, de-duplicated."""
        managed = list(self.policies)
        # After the first re-evaluation the acting policy is a deep copy that
        # no longer lives in ``self.policies``; include it so it is not left
        # behind on device moves or mode switches.
        if not self._is_ensemble_member(self._best_policy):
            managed.append(self._best_policy)
        return managed

    def get_diagnostics(self):
        """Return the per-candidate mean returns from the latest re-evaluations.

        The returned object is the live ``OrderedDict`` held by this instance,
        not a copy.
        """
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Do nothing; epoch ends are deliberately not forwarded to the policies.

        Forwarding was already disabled by an early ``return``; the unreachable
        loop that followed it has been removed.
        """
        return

    def to(self, device):
        """Call ``to(device=device)`` on every member and on the acting policy."""
        for policy in self._managed_policies():
            policy.to(device=device)

    def train(self, mode):
        """Call ``train(mode)`` on every member and on the acting policy."""
        for policy in self._managed_policies():
            policy.train(mode)

    def eval(self, mode):
        """Call ``eval(mode)`` on every member and on the acting policy."""
        for policy in self._managed_policies():
            policy.eval(mode)
