import numpy as np
import torch

from collections import OrderedDict

from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer
import lfrl.torch.pytorch_util as ptu


class MPPITrainer(TorchTrainer):
    """Trainer that optionally fits a terminal value network for an MPPI policy.

    When a ``value_network`` is supplied, ``train_from_torch`` regresses it
    onto n-step discounted returns computed from consecutive transitions stored
    in ``replay_buffer``. When no value network is supplied, training is a
    no-op.
    """

    def __init__(
            self,
            policy,
            replay_buffer=None,
            value_network=None,

            value_lr=1e-3,
            value_grad_steps=64,
            value_batch_size=32,
            value_horizon=32,
    ):
        """Store the training configuration and build the value optimizer.

        Args:
            policy: object exposing ``discount``, ``dynamics_model`` and
                ``end_epoch(epoch)`` (all are read/called by this class).
            replay_buffer: buffer used to train the value function. Must
                provide ``get_transitions()``, ``max_replay_buffer_size()``,
                ``total_entries``, ``top()``, ``obs_dim()`` and
                ``action_dim()``.
            value_network: optional ``torch.nn.Module`` mapping states to
                values. If ``None``, no optimizer is created and
                ``terminal_func`` is not set.
            value_lr: Adam learning rate for the value network.
            value_grad_steps: number of gradient steps per training call.
            value_batch_size: mini-batch size of each gradient step.
            value_horizon: number of reward steps summed before bootstrapping
                with the value network.
        """
        super().__init__()

        self.policy = policy
        # NOTE: this should be the buffer used to train the value function.
        self.replay_buffer = replay_buffer
        self.value_network = value_network

        self.value_lr = value_lr
        self.value_grad_steps = value_grad_steps
        self.value_batch_size = value_batch_size
        self.value_horizon = value_horizon

        if self.value_network is not None:
            self.value_optim = torch.optim.Adam(self.value_network.parameters(), lr=self.value_lr)

            self.terminal_func = self.get_terminal_values

        self.eval_statistics = OrderedDict()
        self._need_to_update_eval_statistics = True

    def get_terminal_values(self, states):
        """Return the value network's predictions for ``states`` as a flat 1-D tensor."""
        return self.value_network(states).view(-1)

    def train_from_torch(self, batch):
        """Run ``value_grad_steps`` regression steps on n-step return targets.

        ``batch`` is ignored; training data is drawn directly from
        ``self.replay_buffer``. Returns without doing anything when there is
        no value network, when the buffer does not yet hold more than
        ``value_horizon`` transitions, or when no samples are requested.
        """
        if self.value_network is None:
            return

        transitions = self.replay_buffer.get_transitions()
        effective_length = transitions.shape[0] - self.value_horizon
        num_samples = self.value_grad_steps * self.value_batch_size
        if effective_length <= 0 or num_samples <= 0:
            # Not enough data for a full horizon (np.random.randint would
            # raise), or nothing to train on (``loss`` would be undefined).
            return
        inds = np.random.randint(0, effective_length, size=num_samples)

        # Do some work so that we don't sample from the most recent horizon
        # transitions. Once the circular buffer has wrapped, ``top()`` is used
        # as the start of the valid window, so offsets are shifted by it and
        # wrapped around:
        #   top <  value_horizon: [xxx (top) OOOOOOOO (max_size + top - horizon) xxx]
        #   top >= value_horizon: [OOO (top - horizon) xxxxxx (top) OOOOOOOO]
        max_size = self.replay_buffer.max_replay_buffer_size()
        if self.replay_buffer.total_entries > max_size:
            inds = (inds + self.replay_buffer.top()) % max_size

        # Transition rows are laid out as [obs, action, reward, ..., next_obs].
        obs_dim, action_dim = self.replay_buffer.obs_dim(), self.replay_buffer.action_dim()
        states = transitions[:, :obs_dim]
        rewards = transitions[:, obs_dim + action_dim]

        # One target per sampled start index (not per stored transition).
        targets = np.zeros(num_samples)
        discount = 1
        for k in range(self.value_horizon):
            cur_inds = (inds + k) % max_size
            targets += discount * rewards[cur_inds]
            discount *= self.policy.discount

        # Bootstrap from the state reached after the summed rewards, i.e. the
        # observation stored at step t + horizon. (Using next_obs at that index
        # would skip the reward of step t + horizon.)
        final_inds = (inds + self.value_horizon) % max_size
        with torch.no_grad():
            final_states = ptu.from_numpy(states[final_inds])  # batch this?
            # Flatten so an (N, 1) network output adds elementwise to targets.
            targets += discount * ptu.get_numpy(self.get_terminal_values(final_states))

        states = ptu.from_numpy(states[inds])
        targets = ptu.from_numpy(targets)
        for i in range(self.value_grad_steps):
            bi, ei = i * self.value_batch_size, (i + 1) * self.value_batch_size
            # Flattened predictions avoid (B, 1) - (B,) broadcasting to (B, B).
            cur_values = self.get_terminal_values(states[bi:ei])
            loss = ((cur_values - targets[bi:ei]) ** 2).mean()

            self.value_optim.zero_grad()
            loss.backward()
            self.value_optim.step()

        if self._need_to_update_eval_statistics:
            # Loss of the last mini-batch; ``item()`` already yields a Python float.
            self.eval_statistics['Value Loss'] = loss.item()

    def get_diagnostics(self):
        """Return the ordered dict of statistics recorded during training."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Forward the end-of-epoch signal to the policy and re-arm statistics logging."""
        self.policy.end_epoch(epoch)
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """List of networks exposed by this trainer (only the policy's dynamics model).

        NOTE(review): the value network is not included here; confirm whether
        the framework relies on this list for device placement / train mode.
        """
        return [self.policy.dynamics_model]

    def get_snapshot(self):
        """Return the objects to checkpoint: the dynamics model and the value network."""
        return dict(
            dynamics_model=self.policy.dynamics_model,
            value_network=self.value_network,
        )
