from collections import OrderedDict
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from typing import Iterable, Optional, Dict

import rlkit.torch.models.ensemble
from rlkit.core.logging.logging import logger
import rlkit.torch.pytorch_util as ptu
from rlkit.core.rl_algorithms.torch_rl_algorithm import TorchTrainer


class BehaviorCloneTrainer(TorchTrainer):
    """
    Trains a behavior-cloning model on the transitions generated by some policy.

    This class was developed for MBOP, but it may also be useful for other offline RL
    algorithms that need to learn the behavior policy.

    In MBOP, the behavior clone is an ensemble model that predicts the action to take
    given the current state and the previous action: (s_t, a_{t-1}) |-> a_t.
    A more general mapping to learn would be s_t |-> a_t.

    Unlike the ensemble model in MBOP, many policy-constraint methods for offline RL
    model the behavior policy as a generative model (e.g. a VAE).

    This trainer therefore only assumes that the model maps an input (the state, or the
    state concatenated with the previous action) to an output (the action). The model must
    implement ``model.get_loss(x, y)``, returning a tuple whose first element is a
    differentiable loss tensor for the batch of input/target pairs; this trainer
    back-propagates that loss and steps its own optimizer.
    """
    def __init__(
            self,
            model,
            obs_dim,
            action_dim,
            learning_rate=1e-3,
            batch_size=256,
            optimizer_class=optim.Adam,
            obs_preproc=None,
            include_prev_action_as_input: bool = False,
            normalize_inputs: bool = False,
            normalize_outputs: bool = False,
            rng: Optional[np.random.Generator] = None,
            **kwargs,
    ):
        """
        Args:
            model: The behavior-cloning model (optionally an rlkit ``Ensemble``).
            obs_dim (int): Observation dimensionality, used to slice buffer rows.
            action_dim (int): Action dimensionality, used to slice buffer rows.
            learning_rate (float): Learning rate passed to ``optimizer_class``.
            batch_size (int): Mini-batch size used in ``train_from_buffer``.
            optimizer_class: Optimizer constructor called as ``optimizer_class(params, lr=...)``.
            obs_preproc (callable, optional): Applied to observations before training.
            include_prev_action_as_input (bool): If True, inputs are (s_t, a_{t-1}).
            normalize_inputs (bool): Fit the model's input statistics on the training inputs.
            normalize_outputs (bool): Fit the model's output statistics and standardize targets.
            rng (np.random.Generator, optional): Source of randomness; a fresh
                default generator is created when omitted.
            **kwargs: Optional ``num_elites`` (only stored when truthy) and
                ``max_num_test`` (default 1000, caps the holdout-set size).
        """
        super().__init__()

        self.model = model
        self.optimizer = optimizer_class(self.model.parameters(), lr=learning_rate)

        self.obs_preproc = obs_preproc
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.batch_size = batch_size
        self.include_prev_action_as_input = include_prev_action_as_input
        # The model is told whether its inputs include the previous action.
        self.model.include_prev_action_as_input = self.include_prev_action_as_input
        # `num_elites` is only defined when given; `_set_elite` checks for it with hasattr.
        if kwargs.get('num_elites', False):
            self.num_elites = kwargs['num_elites']
        self.eval_statistics = OrderedDict()

        self.max_num_test = kwargs.get('max_num_test', int(1e3))
        self._normalize_inputs = normalize_inputs
        self._normalize_outputs = normalize_outputs
        self._is_ensemble = isinstance(self.model, rlkit.torch.models.ensemble.Ensemble)

        self._rng: np.random.Generator = np.random.default_rng() if rng is None else rng

        network_dict = dict(
            behavior_clone=self.model,
        )
        if hasattr(self.model, 'elite_models'):
            network_dict['behavior_clone_elite_models'] = self.model.elite_models
        for key, module in network_dict.items():
            self.register_module(key, module)

        self.name = 'behavior_clone'

    def train_from_buffer(
            self,
            replay_buffer,
            holdout_pct=0.2,
            max_grad_steps=None,
            max_epochs_since_last_update=100,
            num_total_epochs=None,
            improvement_threshold: float = 0.01,
            use_best_parameters: bool = False,
            **kwargs
    ):
        """
        Trains the behavior-cloning model with the data in the buffer.

        Training runs epoch by epoch until one of these holds: the holdout score has not
        improved for ``max_epochs_since_last_update`` consecutive epochs,
        ``max_grad_steps`` gradient steps have been taken, or ``num_total_epochs``
        epochs have been run. A non-ensemble model whose ``trained`` flag is set is
        not retrained.

        Args:
            replay_buffer (SimpleReplayBuffer): Its ``get_transitions()`` must return
                all rows laid out as (obs, act, rew, done, next_obs).
            holdout_pct (float): The fraction of data to be used as test set
                (capped at ``self.max_num_test`` samples).
            max_grad_steps (int, optional): The maximum number of gradient steps to
                take for training; ``None`` or 0 means no limit.
            max_epochs_since_last_update (int): Number of consecutive epochs without
                holdout improvement after which training stops.
            num_total_epochs (int, optional): Hard cap on the number of epochs;
                ``None`` or 0 means no cap.
            improvement_threshold (float): Forwarded to ``maybe_get_best_weights``,
                which decides what counts as an improvement.
            use_best_parameters (bool): If True, restore the best weights found
                during training and select the elite models.
            **kwargs: Accepted but unused.
        """
        # Reset elite models
        if self._is_ensemble:
            self.model.reset_elite_models()
        elif getattr(self.model, 'trained', False):
            logger.log("Behavior cloning model already trained! Skipping training..", with_timestamp=True)
            return

        x, y = self._get_input_target_pairs(replay_buffer)
        x_train, y_train, x_test, y_test = self._split_holdout(x, y, holdout_pct)
        n_train = x_train.shape[0]

        # Standardize network inputs / outputs
        set_output_transforms = getattr(self.model, 'set_output_transforms', None)
        if callable(set_output_transforms):
            set_output_transforms(False)
        if self._normalize_inputs:
            self.model.fit_input_stats(x_train)
        if self._normalize_outputs:
            y_mean, y_std = self.model.fit_output_stats(y_train)
            y_train = (y_train - y_mean) / (y_std + 1e-8)
            y_test = (y_test - y_mean) / (y_std + 1e-8)

        # Initial evaluation; the best parameters seen during this round are kept here.
        self.model.eval()
        best_val_score = self.evaluate(self.model, x_test, y_test)
        best_weights: Optional[Dict] = None

        # train until holdout set convergence
        num_epochs, num_steps = 0, 0
        num_epochs_since_last_update = 0
        self.prev_params = deepcopy(self.model.state_dict())

        # Ensembles draw an independent index row per member so each member sees its own resample.
        idx_shape = (self.model.ensemble_size, n_train) if self._is_ensemble else (n_train,)
        num_batches = int(np.ceil(n_train / self.batch_size))

        while num_epochs_since_last_update < max_epochs_since_last_update and \
                (not max_grad_steps or num_steps < max_grad_steps):
            if num_total_epochs and num_epochs == num_total_epochs:
                break
            # Training indices are sampled uniformly with replacement every epoch.
            idxs = self._rng.integers(n_train, size=idx_shape)

            self.model.train()
            batch_losses = []
            for b in range(num_batches):
                # Slice along the last axis: works for both (n,) and (ensemble_size, n) indices.
                b_idxs = idxs[..., b * self.batch_size: (b + 1) * self.batch_size]
                x_batch = ptu.from_numpy(x_train[b_idxs])
                y_batch = ptu.from_numpy(y_train[b_idxs])

                self.optimizer.zero_grad()
                loss, _ = self.model.get_loss(x_batch, y_batch)
                loss.backward()
                self.optimizer.step()

                # Keep track of the training loss over time
                batch_losses.append(loss.detach())

            avg_batch_loss = torch.stack(batch_losses).mean().item()
            num_steps += num_batches

            # Check if the validation score on average has improved
            self.model.eval()
            eval_score = self.evaluate(self.model, x_test, y_test)
            maybe_best_weights, best_val_score = self.maybe_get_best_weights(
                self.model, best_val_score, eval_score, self.prev_params, improvement_threshold
            )

            # If there was an improvement, save the model parameters; otherwise, count the stale epoch
            if maybe_best_weights:
                best_weights = maybe_best_weights
                num_epochs_since_last_update = 0
            else:
                num_epochs_since_last_update += 1

            self.eval_statistics['Holdout Loss'] = ptu.get_numpy(eval_score.mean())
            self.eval_statistics['Training Epochs'] = num_epochs
            self.eval_statistics['Training Steps'] = num_steps
            self.eval_statistics['Training Loss'] = avg_batch_loss

            # The final epoch is logged after the loop (once the best weights may have been
            # restored), so skip it here. These mirror the loop's exit conditions exactly;
            # `>=` matters for the step limit because num_steps advances by num_batches.
            is_last_epoch = (
                num_epochs_since_last_update >= max_epochs_since_last_update
                or bool(max_grad_steps and num_steps >= max_grad_steps)
                or bool(num_total_epochs and num_epochs + 1 == num_total_epochs)
            )
            if not is_last_epoch:
                self.end_epoch(num_epochs)
            num_epochs += 1

        # Saving the best models
        if use_best_parameters:
            self._set_maybe_best_weights_and_elite(self.model, best_weights, best_val_score)
            if hasattr(self.model, 'elite_models'):
                self.register_module('behavior_clone_elite_models', self.model.elite_models)
        # Guard against a negative epoch index when no training epoch ran.
        self.end_epoch(max(num_epochs - 1, 0))

    def _get_input_target_pairs(self, replay_buffer):
        """Builds the (input, target) arrays for behavior cloning from the buffer.

        Returns:
            x: observations (after ``obs_preproc``, if set), concatenated with the
               previous action along the last axis when ``include_prev_action_as_input``.
            y: the actions to predict.
        """
        # This returns the entire array of (obs, act, rew, done, next_obs) from the buffer
        data = replay_buffer.get_transitions()
        obs = data[:, :self.obs_dim]
        act = data[:, self.obs_dim: self.obs_dim + self.action_dim]

        prev_act = None
        if self.include_prev_action_as_input:
            # Row t and row t + 1 are treated as consecutive steps only when
            # next_obs[t] exactly equals obs[t + 1]; other pairs are dropped.
            next_obs, obs = data[:-1, -self.obs_dim:], obs[1:]
            is_consecutive = (next_obs == obs).all(axis=1)
            # Rows that match up to floating-point tolerance (compared row-wise).
            num_close_rows = int(np.isclose(next_obs, obs).all(axis=1).sum())

            obs = next_obs[is_consecutive]
            prev_act = act[:-1][is_consecutive]
            act = act[1:][is_consecutive]

            if obs.shape[0] != num_close_rows:
                print("[behavior_clone trainer] Warning: some transitions are not included due to numerical issues")

        # Preprocess observations if necessary
        if self.obs_preproc is not None:
            obs = self.obs_preproc(obs)

        x = np.concatenate((obs, prev_act), axis=-1) if self.include_prev_action_as_input else obs
        return x, act

    def _split_holdout(self, x, y, holdout_pct):
        """Randomly splits aligned (x, y) into train and holdout sets.

        The holdout set has ``min(int(len(x) * holdout_pct), self.max_num_test)`` samples.

        Returns:
            (x_train, y_train, x_test, y_test)
        """
        perm = self._rng.permutation(x.shape[0])
        n_test = min(int(x.shape[0] * holdout_pct), self.max_num_test)
        train_idx, test_idx = perm[n_test:], perm[:n_test]
        return x[train_idx], y[train_idx], x[test_idx], y[test_idx]

    def _set_maybe_best_weights_and_elite(
            self,
            model,
            best_weights: Optional[dict],
            best_val_score: torch.Tensor
    ):
        """Restores ``best_weights`` into ``model`` (via ``_set_maybe_best_weights``) and then selects elites."""
        self._set_maybe_best_weights(model, best_weights, best_val_score)
        self._set_elite(model, best_val_score)

    def _set_elite(
            self,
            model,
            best_val_score
    ):
        """Marks the ``num_elites`` entries with the lowest validation scores as elites.

        Does nothing unless ``best_val_score`` holds more than one score (presumably one
        per ensemble member) and ``num_elites`` was configured.
        """
        if len(best_val_score.shape) > 0 and len(best_val_score) > 1 and hasattr(self, "num_elites"):
            sorted_indices = np.argsort(best_val_score.tolist())
            elite_models = sorted_indices[:self.num_elites]
            model.set_elite(elite_models)

    def train_from_torch(self, batch):
        """No-op: this trainer is trained through ``train_from_buffer``."""
        pass

    def get_diagnostics(self):
        """Returns the statistics recorded during the latest training epoch."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Calls ``_log_stats`` with the behavior-clone prefix and saves a snapshot for ``epoch``."""
        self._log_stats(epoch, prefix='Behavior Clone Learning/')
        snapshot = self.get_snapshot()
        logger.save_itr_params(epoch, snapshot, prefix='bc')

    def configure_logging(self, **kwargs):
        """Registers the model with ``wandb.watch``; ``kwargs`` are forwarded to it."""
        import wandb
        wandb.watch(self.model, **kwargs)

    def load(self, state_dict, prefix=''):
        """Loads the model weights and elite selection from a snapshot dictionary.

        Args:
            state_dict (dict): Snapshot containing ``'bc_model'`` and optionally
                ``'bc_elite_models'`` entries.
            prefix (str): If non-empty, keys are looked up as ``f"{prefix}/<key>"``.
                For backward compatibility, the elite entry falls back to the
                unprefixed ``'bc_elite_models'`` key.
        """
        def _key(base):
            return f"{prefix}/{base}" if prefix != '' else base

        name = _key('bc_model')
        if name in state_dict:
            try:
                self.model.load_state_dict(state_dict[name])
                if hasattr(self.model, 'trained'):
                    self.model.trained = True
            except RuntimeError:
                print(f"Failed to load state_dict[{name}]")

        elite_name = _key('bc_elite_models')
        if elite_name not in state_dict:
            elite_name = 'bc_elite_models'
        if elite_name in state_dict and hasattr(self.model, 'elite_models'):
            self.model.set_elite(state_dict[elite_name])
