from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim

import lfrl.torch.pytorch_util as ptu
from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer


class MBRLTrainer(TorchTrainer):
    """Fits a dynamics-model ensemble to the transitions in a replay buffer.

    Each call to :meth:`train_from_buffer` pulls every transition from the
    buffer, holds out a random fraction, and takes bootstrapped minibatch
    gradient steps until the holdout loss (averaged over the best
    ``num_elites`` members) stops improving or a step budget is exhausted.
    """

    def __init__(
            self,
            ensemble,
            num_elites=None,
            learning_rate=1e-3,
            batch_size=256,
            optimizer_class=optim.Adam,
            weight_decay=1e-8,
            train_start=True,
            train_call_freq=1,
            **kwargs
    ):
        """
        Args:
            ensemble: model to train. This class reads ``ensemble_size``,
                ``obs_dim`` and ``action_dim`` from it, calls
                ``parameters()``, ``fit_input_stats()``, ``get_loss()``,
                ``train()`` and ``eval()`` on it, and writes ``elites``.
            num_elites: number of lowest-holdout-loss members averaged for
                the stopping criterion. A falsy value means "all members";
                larger values are clipped to the ensemble size.
            learning_rate: learning rate passed to the optimizer.
            batch_size: minibatch size *per ensemble member*.
            optimizer_class: called as
                ``optimizer_class(params, lr=..., weight_decay=...)``.
            weight_decay: weight decay passed to the optimizer.
            train_start: stored on the trainer; not read by this class.
            train_call_freq: after the first call, only every
                ``train_call_freq``-th call to :meth:`train_from_buffer`
                actually trains.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.ensemble = ensemble
        self.ensemble_size = ensemble.ensemble_size
        self.num_elites = min(num_elites, self.ensemble_size) if num_elites \
                          else self.ensemble_size

        self.obs_dim = ensemble.obs_dim
        self.action_dim = ensemble.action_dim
        self.batch_size = batch_size
        # Must be set before construct_optimizer(), which reads it.
        self.weight_decay = weight_decay
        self.train_start = train_start
        self.train_call_freq = train_call_freq

        self.optimizer = self.construct_optimizer(
            ensemble, optimizer_class, learning_rate)

        self._n_train_steps_total = 0
        self._train_calls = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()

    def construct_optimizer(self, model, optimizer_class, lr):
        """Build the optimizer over ``model.parameters()`` with weight decay."""
        return optimizer_class(model.parameters(), lr=lr, weight_decay=self.weight_decay)

    def train_from_buffer(
        self,
        replay_buffer,
        holdout_pct=0.05,
        max_epochs_since_update=5,
        max_grad_steps=1000,
        test_batch_size=8192,
    ):
        """Train the ensemble on all transitions currently in the buffer.

        Args:
            replay_buffer: object whose ``get_transitions()`` returns a 2-D
                array; the first ``obs_dim + action_dim`` columns are the
                model inputs and the remaining columns the targets, with the
                last ``obs_dim`` target columns converted to deltas here.
            holdout_pct: fraction of rows held out for early stopping.
            max_epochs_since_update: stop after this many consecutive epochs
                without a >1% relative improvement of the elite holdout loss.
            max_grad_steps: hard cap on the number of gradient steps.
            test_batch_size: holdout evaluation batch size (capped at 2048
                when no GPU is enabled).

        Raises:
            ValueError: if the buffer returns no transitions.
        """
        self._train_calls += 1
        if self._train_calls > 1 and (self._train_calls % self.train_call_freq != 0):
            return

        data = replay_buffer.get_transitions()
        n_total = data.shape[0]
        if n_total == 0:
            raise ValueError(
                'cannot train the model ensemble on an empty replay buffer')

        in_dim = self.obs_dim + self.action_dim
        x = data[:, :in_dim]                          # inputs  s, a
        # Copy the targets: the delta below is written in place, and a plain
        # slice would alias (and corrupt) the array handed out by the buffer.
        y = data[:, in_dim:].copy()                   # predict r, d, ns
        y[:, -self.obs_dim:] -= x[:, :self.obs_dim]   # predict delta

        # normalize network inputs
        self.ensemble.fit_input_stats(x)

        # generate holdout set
        inds = np.random.permutation(n_total)
        x, y = ptu.from_numpy(x[inds]), ptu.from_numpy(y[inds])

        # Keep at least 10 training rows, but never more rows than exist.
        n_train = min(max(int((1 - holdout_pct) * n_total), 10), n_total)
        x_train, y_train = x[:n_train], y[:n_train]
        x_test, y_test = x[n_train:], y[n_train:]
        if x_test.shape[0] == 0:
            # Too little data for a separate holdout split: monitor the
            # training rows instead so the stopping rule stays well defined.
            x_test, y_test = x_train, y_train

        if not ptu.gpu_enabled():
            # if not using GPU, we're probably running on EC2, and don't have much memory
            test_batch_size = min(test_batch_size, 2048)

        num_batches = int(np.ceil(n_train / self.batch_size))

        # train until holdout set convergence
        num_epochs, num_steps = 0, 0
        num_epochs_since_last_update = 0
        best_holdout_loss = float('inf')
        holdout_losses = holdout_errors = holdout_loss = None

        while (num_epochs_since_last_update < max_epochs_since_update
               and num_steps < max_grad_steps):
            self.ensemble.train()
            for _ in range(num_batches):
                # Batch-level bootstrap: every member gets its own sample
                # (with replacement) of the training rows.
                b_idxs = np.random.randint(
                    n_train, size=(self.ensemble_size * self.batch_size))
                x_batch = x_train[b_idxs].view(
                    self.ensemble_size, self.batch_size, -1)
                y_batch = y_train[b_idxs].view(
                    self.ensemble_size, self.batch_size, -1)

                loss = self.ensemble.get_loss(x_batch, y_batch)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                num_steps += 1
                if num_steps >= max_grad_steps:
                    break

            # stop training based on holdout loss improvement
            holdout_losses, holdout_errors = self._evaluate_holdout(
                x_test, y_test, test_batch_size)
            holdout_loss = self._elite_loss(holdout_losses)

            print('Num steps %d' % num_steps)
            print(holdout_loss, loss.item() / self.ensemble_size)

            if num_epochs == 0 or \
               self._is_improvement(best_holdout_loss, holdout_loss):
                best_holdout_loss = holdout_loss
                num_epochs_since_last_update = 0
            else:
                num_epochs_since_last_update += 1

            num_epochs += 1

        if holdout_losses is None:
            # No epoch ran (e.g. a non-positive step budget); still evaluate
            # once so elites and statistics are defined.
            holdout_losses, holdout_errors = self._evaluate_holdout(
                x_test, y_test, test_batch_size)
            holdout_loss = self._elite_loss(holdout_losses)

        # NOTE: the parameters kept are those of the final epoch, not the
        # ones that achieved the best holdout loss.
        self.ensemble.elites = np.argsort(holdout_losses)

        self._record_statistics(
            holdout_loss, holdout_losses, holdout_errors, num_epochs, num_steps)

    def _evaluate_holdout(self, x_test, y_test, test_batch_size):
        """Return per-member (losses, L2 errors), size-weighted over batches."""
        self.ensemble.eval()
        n_test = x_test.shape[0]
        holdout_losses = [0] * self.ensemble_size
        holdout_errors = [0] * self.ensemble_size

        with torch.no_grad():
            for bi in range(0, n_test, test_batch_size):
                ei = min(bi + test_batch_size, n_test)
                batch_losses, batch_errors = self.ensemble.get_loss(
                    x_test[bi:ei], y_test[bi:ei],
                    split_by_model=True, return_l2_error=True)
                for i in range(self.ensemble_size):
                    holdout_losses[i] += ptu.get_numpy(batch_losses[i]) * (ei - bi) / n_test
                    holdout_errors[i] += ptu.get_numpy(batch_errors[i]) * (ei - bi) / n_test

        return holdout_losses, holdout_errors

    def _elite_loss(self, holdout_losses):
        """Mean holdout loss over the ``num_elites`` best ensemble members."""
        return sum(sorted(holdout_losses)[:self.num_elites]) / self.num_elites

    @staticmethod
    def _is_improvement(best_loss, new_loss, threshold=0.01):
        """True if ``new_loss`` beats ``best_loss`` by more than ``threshold`` (relative)."""
        if best_loss == 0:
            # A relative change is undefined at zero; require a strict decrease.
            return new_loss < best_loss
        return (best_loss - new_loss) / abs(best_loss) > threshold

    def _record_statistics(self, holdout_loss, holdout_losses, holdout_errors,
                           num_epochs, num_steps):
        """Store diagnostics of this training call, at most once per epoch."""
        if not self._need_to_update_eval_statistics:
            return
        self._need_to_update_eval_statistics = False

        self.eval_statistics['Model Elites Holdout Loss'] = \
            np.mean(holdout_loss)
        self.eval_statistics['Model Holdout Loss'] = \
            np.mean(sum(holdout_losses)) / self.ensemble_size
        self.eval_statistics['Model Training Epochs'] = num_epochs
        self.eval_statistics['Model Training Steps'] = num_steps

        for i in range(self.ensemble_size):
            name = 'M%d' % (i+1)
            self.eval_statistics[name + ' Loss'] = \
                np.mean(holdout_losses[i])
            self.eval_statistics[name + ' L2 Error'] = \
                np.mean(holdout_errors[i])

    def train_from_torch(self, batch, idx=None):
        """Not supported; this trainer is driven by :meth:`train_from_buffer`."""
        raise NotImplementedError

    def get_diagnostics(self):
        """Return the statistics recorded by the most recent logged training call."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm statistics recording for the next training call."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """List of networks owned by this trainer (just the ensemble)."""
        return [
            self.ensemble
        ]

    def get_snapshot(self):
        """Return the objects to save in a snapshot (just the ensemble)."""
        return dict(
            ensemble=self.ensemble
        )


class MBRLLSTMTrainer(MBRLTrainer):
    """MBRLTrainer variant that trains on fixed-length sequences of transitions.

    Consecutive rows of the replay-buffer data are grouped into overlapping
    windows of ``train_seq_length`` steps, and each window is one training
    example for the (recurrent) ensemble.
    """

    def __init__(self, train_seq_length, **kwargs):
        """
        Args:
            train_seq_length: number of consecutive transitions per example.
            **kwargs: forwarded unchanged to ``MBRLTrainer.__init__``.
        """
        super().__init__(**kwargs)
        self.train_seq_length = train_seq_length

    def construct_optimizer(self, model, optimizer_class, lr):
        """Build the optimizer over ``model.parameters()``.

        Unlike the parent class, ``weight_decay`` is not passed to the
        optimizer here. NOTE(review): confirm this omission is intentional.
        """
        return optimizer_class(model.parameters(), lr=lr)

    def train_from_buffer(
        self,
        replay_buffer,
        holdout_pct=0.2,
        max_epochs_since_update=5,
        max_grad_steps=4000,
        test_batch_size=8192,
    ):
        """Train the ensemble on sequences built from the buffer's transitions.

        Args:
            replay_buffer: object whose ``get_transitions()`` returns a 2-D
                array; the first ``obs_dim + action_dim`` columns are the
                model inputs and the remaining columns the targets, with the
                last ``obs_dim`` target columns converted to deltas here.
            holdout_pct: fraction of sequences held out for early stopping.
            max_epochs_since_update: stop after this many consecutive epochs
                without a >1% relative improvement of the elite holdout loss.
            max_grad_steps: hard cap on the number of gradient steps.
            test_batch_size: holdout evaluation batch size (capped at 2048
                when no GPU is enabled).

        Raises:
            ValueError: if the buffer holds fewer than ``train_seq_length``
                transitions.
        """
        # Honour train_call_freq the same way the parent class does: always
        # train on the first call, afterwards only on every N-th call.
        self._train_calls += 1
        if self._train_calls > 1 and (self._train_calls % self.train_call_freq != 0):
            return

        data = replay_buffer.get_transitions()
        in_dim = self.obs_dim + self.action_dim
        x = data[:, :in_dim]                          # inputs  s, a
        # Copy the targets: the delta below is written in place, and a plain
        # slice would alias (and corrupt) the array handed out by the buffer.
        y = data[:, in_dim:].copy()                   # predict r, d, ns
        y[:, -self.obs_dim:] -= x[:, :self.obs_dim]   # predict delta

        # normalize network inputs
        self.ensemble.fit_input_stats(x)

        # have to make batch of sequences for LSTM training
        # NOTE: this assumes replay buffer/data is *correctly ordered*,
        #       i.e. the replay buffer does not wrap around!
        #     + it also assumes no resets
        # ALSO: the holdout set is currently not constructed properly
        #       (overlapping windows share transitions across the split)
        seq_len = self.train_seq_length
        num_ex = data.shape[0] - seq_len + 1
        if num_ex <= 0:
            raise ValueError(
                'need at least train_seq_length=%d transitions to build a '
                'training sequence, got %d' % (seq_len, data.shape[0]))

        # Row i of window_idx lists the transition indices i .. i+seq_len-1,
        # so indexing with it gathers every sliding window in one shot.
        window_idx = np.arange(num_ex)[:, None] + np.arange(seq_len)[None, :]
        x, y = x[window_idx], y[window_idx]

        # generate holdout set
        inds = np.random.permutation(num_ex)
        x, y = ptu.from_numpy(x[inds]), ptu.from_numpy(y[inds])

        n_train = int((1 - holdout_pct) * num_ex)
        x_train, y_train = x[:n_train], y[:n_train]
        x_test, y_test = x[n_train:], y[n_train:]
        if x_test.shape[0] == 0:
            # No separate holdout split (e.g. holdout_pct == 0): monitor the
            # training sequences so the stopping rule stays well defined.
            x_test, y_test = x_train, y_train

        if not ptu.gpu_enabled():
            # if not using GPU, we're probably running on EC2, and don't have much memory
            test_batch_size = min(test_batch_size, 2048)

        num_batches = int(np.ceil(n_train / self.batch_size))
        x_example_shape = x_train.shape[1:]
        y_example_shape = y_train.shape[1:]

        # train until holdout set convergence
        num_epochs, num_steps = 0, 0
        num_epochs_since_last_update = 0
        best_holdout_loss = float('inf')
        holdout_losses = holdout_errors = holdout_loss = None

        while (num_epochs_since_last_update < max_epochs_since_update
               and num_steps < max_grad_steps):
            self.ensemble.train()
            for _ in range(num_batches):
                # Bootstrap: every member gets its own sample (with
                # replacement) of the training sequences.
                b_idxs = np.random.randint(
                    n_train, size=(self.ensemble_size * self.batch_size))
                x_batch = x_train[b_idxs].view(
                    self.ensemble_size, self.batch_size, *x_example_shape)
                y_batch = y_train[b_idxs].view(
                    self.ensemble_size, self.batch_size, *y_example_shape)

                loss = self.ensemble.get_loss(x_batch, y_batch)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Count steps one at a time so max_grad_steps is a hard cap
                # rather than being overshot by up to a whole epoch.
                num_steps += 1
                if num_steps >= max_grad_steps:
                    break

            # stop training based on holdout loss improvement
            holdout_losses, holdout_errors = self._evaluate_holdout(
                x_test, y_test, test_batch_size)
            holdout_loss = self._elite_loss(holdout_losses)

            if num_epochs == 0 or \
               self._is_improvement(best_holdout_loss, holdout_loss):
                best_holdout_loss = holdout_loss
                num_epochs_since_last_update = 0
            else:
                num_epochs_since_last_update += 1

            num_epochs += 1

        if holdout_losses is None:
            # No epoch ran (e.g. a non-positive step budget); still evaluate
            # once so elites and statistics are defined.
            holdout_losses, holdout_errors = self._evaluate_holdout(
                x_test, y_test, test_batch_size)
            holdout_loss = self._elite_loss(holdout_losses)

        self.ensemble.elites = np.argsort(holdout_losses)

        self._record_statistics(
            holdout_loss, holdout_losses, holdout_errors, num_epochs, num_steps)

    def _evaluate_holdout(self, x_test, y_test, test_batch_size):
        """Return per-member (losses, L2 errors), size-weighted over batches."""
        self.ensemble.eval()
        n_test = x_test.shape[0]
        holdout_losses = [0] * self.ensemble_size
        holdout_errors = [0] * self.ensemble_size

        with torch.no_grad():
            for bi in range(0, n_test, test_batch_size):
                ei = min(bi + test_batch_size, n_test)
                batch_losses, batch_errors = self.ensemble.get_loss(
                    x_test[bi:ei], y_test[bi:ei],
                    split_by_model=True, return_l2_error=True)
                for i in range(self.ensemble_size):
                    holdout_losses[i] += ptu.get_numpy(batch_losses[i]) * (ei - bi) / n_test
                    holdout_errors[i] += ptu.get_numpy(batch_errors[i]) * (ei - bi) / n_test

        return holdout_losses, holdout_errors

    def _elite_loss(self, holdout_losses):
        """Mean holdout loss over the ``num_elites`` best ensemble members."""
        return sum(sorted(holdout_losses)[:self.num_elites]) / self.num_elites

    @staticmethod
    def _is_improvement(best_loss, new_loss, threshold=0.01):
        """True if ``new_loss`` beats ``best_loss`` by more than ``threshold`` (relative)."""
        if best_loss == 0:
            # A relative change is undefined at zero; require a strict decrease.
            return new_loss < best_loss
        return (best_loss - new_loss) / abs(best_loss) > threshold

    def _record_statistics(self, holdout_loss, holdout_losses, holdout_errors,
                           num_epochs, num_steps):
        """Store diagnostics of this training call, at most once per epoch."""
        if not self._need_to_update_eval_statistics:
            return
        self._need_to_update_eval_statistics = False

        self.eval_statistics['Model Elites Holdout Loss'] = \
            np.mean(holdout_loss)
        self.eval_statistics['Model Holdout Loss'] = \
            np.mean(sum(holdout_losses)) / self.ensemble_size
        self.eval_statistics['Model Training Epochs'] = num_epochs
        self.eval_statistics['Model Training Steps'] = num_steps

        for i in range(self.ensemble_size):
            name = 'M%d' % (i+1)
            self.eval_statistics[name + ' Loss'] = \
                np.mean(holdout_losses[i])
            self.eval_statistics[name + ' L2 Error'] = \
                np.mean(holdout_errors[i])
