import numpy as np
import torch
import os
from common import util, functional
from models.ensemble_dynamics import EnsembleModel
from operator import itemgetter
from common.normalizer import StandardNormalizer
from copy import deepcopy

from joblib import Parallel, delayed

import warnings
warnings.filterwarnings("ignore")


def clip_norm(x_square, max_grad_norm):
    """Return the divisor that brings a norm down to ``max_grad_norm``.

    Args:
        x_square: squared L2 norm of the quantity being clipped.
        max_grad_norm: largest norm allowed after clipping.

    Returns:
        ``sqrt(x_square) / max_grad_norm`` when that ratio exceeds 1, and 1
        otherwise, so dividing by the result never enlarges the input.
    """
    norm_ratio = (x_square ** 0.5) / max_grad_norm
    return np.max([1, norm_ratio])


class TransitionModel:
    def __init__(self,
                 obs_space,
                 action_space,
                 static_fns,
                 lr,
                 noise_multiplier,
                 max_grad_norm,
                 local_epochs,
                 local_batch_size,
                 reward_penalty_coef,
                 no_clip,
                 uncertainty_estimator,
                 clipping_mode,
                 holdout_ratio=0.1,
                 inc_var_loss=False,
                 use_weight_decay=False,
                 load_model_name=None,
                 **kwargs):
        """Build the ensemble dynamics model and its training state.

        Args:
            obs_space: object whose ``shape[0]`` is the observation dimension.
            action_space: object whose ``shape[0]`` is the action dimension.
            static_fns: stored as-is; ``predict`` calls its ``termination_fn``.
            lr: step size applied to gradients in ``update_single_user``.
            noise_multiplier: scales the Gaussian noise added in ``update``.
            max_grad_norm: clipping bound; also scales the noise in ``update``.
            local_epochs: passes over a user's data in ``update_single_user``.
            local_batch_size: minibatch size used in ``update_single_user``.
            reward_penalty_coef: weight of the uncertainty penalty in ``predict``.
            no_clip: when truthy, ``update_single_user`` skips clipping.
            uncertainty_estimator: ``'max_aleatoric'`` or ``'max_diff'``.
            clipping_mode: ``'per_layer'`` or ``'per_model'``.
            holdout_ratio: stored on the instance; not used in this class body.
            inc_var_loss: enables the variance-weighted loss in ``model_loss``.
            use_weight_decay: adds the model's decay loss during local training.
            load_model_name: if given, a pickled model is loaded from
                ``models/saved_models/<name>.pt`` and replaces the fresh one.
            **kwargs: must contain a ``'model'`` dict (with ``'ensemble_size'``)
                that is forwarded to ``EnsembleModel``.
        """
        obs_dim = obs_space.shape[0]
        action_dim = action_space.shape[0]

        # --- model -----------------------------------------------------
        self.device = util.device
        self.model = EnsembleModel(obs_dim=obs_dim, action_dim=action_dim, device=self.device, **kwargs['model'])
        if load_model_name is not None:
            checkpoint_path = os.path.join('models', 'saved_models', load_model_name + '.pt')
            self.model = torch.load(checkpoint_path)
        self.networks = {
            "model": self.model
        }

        # --- plain configuration --------------------------------------
        self.static_fns = static_fns
        self.lr = lr
        self.reward_penalty_coef = reward_penalty_coef
        self.deterministic = True
        self.uncertainty_estimator = uncertainty_estimator
        self.obs_space = obs_space
        self.holdout_ratio = holdout_ratio
        self.inc_var_loss = inc_var_loss
        self.use_weight_decay = use_weight_decay
        self.model_train_timesteps = 0

        # --- input normalisation --------------------------------------
        self.obs_normalizer = StandardNormalizer()
        self.act_normalizer = StandardNormalizer()

        # --- clipping / noise settings --------------------------------
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.clipping_mode = clipping_mode
        self.no_clip = no_clip
        # Two extra layers are counted only when the model is not deterministic.
        self.nb_layers_per_model = len(self.model.hidden_dims) + (1 - self.deterministic) * 2
        ensemble_size = kwargs["model"]["ensemble_size"]
        self.nb_layers = self.nb_layers_per_model * ensemble_size
        self.nb_models = ensemble_size

        # --- FedAvg settings ------------------------------------------
        self.local_epochs = local_epochs
        self.local_batch_size = local_batch_size

        # lr=1.0: ``update`` writes the already-scaled step into ``param.grad``.
        self.model_optimizer = torch.optim.SGD(self.model.parameters(), lr=1.0, momentum=0.9)

        # One zeroed entry per layer (parameter name without its last
        # component), plus one entry for each log-variance bound.
        logvar_names = ('min_logvar', 'max_logvar')
        self.grad_norms = {}
        for param_name, _ in self.model.named_parameters():
            head = param_name.partition('.')[0]
            key = head if head in logvar_names else param_name.rpartition('.')[0]
            self.grad_norms[key] = 0

    @torch.no_grad()
    def eval_data(self, data, update_elite_models=False):
        """Compute the per-ensemble-member MSE on ``data``.

        Args:
            data: mapping with keys ``observations``, ``actions``,
                ``next_observations`` and ``rewards``.
            update_elite_models: when true, the ``num_elite`` members with the
                lowest MSE are stored in ``self.model.elite_model_idxes``.

        Returns:
            ``(mse_losses, None)`` where ``mse_losses`` is a NumPy array with
            one value per ensemble member.
        """
        fields = itemgetter("observations", 'actions', 'next_observations', 'rewards')(data)
        obs, actions, next_obs, rewards = (torch.Tensor(field) for field in fields)

        # Targets use the raw (un-normalised) observations.
        delta_obs = next_obs - obs
        norm_obs, norm_actions = self.transform_obs_action(obs, actions)
        model_input = torch.cat([norm_obs, norm_actions], dim=-1).to(self.device)

        # Inference is chunked because the evaluated buffer can grow large.
        predictions = functional.minibatch_inference(
            args=[model_input],
            rollout_fn=self.model.predict,
            batch_size=10000,
            cat_dim=1,
        )
        targets = torch.cat((delta_obs, rewards), dim=1).to(self.device)
        per_model_mse, _ = self.model_loss(predictions, targets, mse_only=True)

        if update_elite_models:
            ranking = np.argsort(per_model_mse.cpu().numpy())
            self.model.elite_model_idxes = ranking[:self.model.num_elite]
        return per_model_mse.detach().cpu().numpy(), None

    def reset_normalizers(self):
        """Reset the observation normalizer, then the action normalizer."""
        for normalizer in (self.obs_normalizer, self.act_normalizer):
            normalizer.reset()

    def update_normalizer(self, obs, action):
        """Feed ``obs`` to the observation normalizer and ``action`` to the action normalizer."""
        pairs = ((self.obs_normalizer, obs), (self.act_normalizer, action))
        for normalizer, samples in pairs:
            normalizer.update(samples)

    def transform_obs_action(self, obs, action):
        """Return ``(obs, action)`` after passing each through its normalizer's ``transform``."""
        scaled_obs = self.obs_normalizer.transform(obs)
        scaled_action = self.act_normalizer.transform(action)
        return scaled_obs, scaled_action

    def clip_gradients(self, parameters_dict):  # parameters_dict contains gradient data
        """Scale down groups of tensors whose joint L2 norm exceeds the budget.

        Tensors are grouped either per layer (parameter name without its last
        component) or per ensemble member (first component of the name),
        depending on ``self.clipping_mode``.  Each group is divided by
        ``clip_norm(squared_norm, threshold)``, where the threshold is
        ``max_grad_norm / sqrt(nb_layers_per_model)`` for ``'per_layer'`` and
        ``max_grad_norm / sqrt(nb_models)`` for ``'per_model'``.

        Args:
            parameters_dict: mapping (or iterable of ``(name, tensor)`` pairs)
                from parameter name to the tensor to clip.

        Returns:
            A new dict with the same keys holding the clipped tensors.

        Raises:
            ValueError: if ``self.clipping_mode`` is not ``'per_layer'`` or
                ``'per_model'``.  Previously this case silently returned an
                empty dict.
        """
        if self.clipping_mode not in ('per_layer', 'per_model'):
            raise ValueError(
                "Unknown clipping_mode {!r}; expected 'per_layer' or 'per_model'".format(self.clipping_mode))
        if not isinstance(parameters_dict, dict):
            parameters_dict = dict(parameters_dict)

        per_layer = self.clipping_mode == 'per_layer'
        logvar_names = ('min_logvar', 'max_logvar')

        # Resolve once which group every parameter belongs to.
        group_of = {}
        for name in parameters_dict:
            cur_model = name.split('.')[0]
            if per_layer and cur_model not in logvar_names:
                group_of[name] = '.'.join(name.split('.')[:-1])
            else:
                group_of[name] = cur_model

        # Accumulators for the squared norm of every group.
        if per_layer:
            squared_norms = dict.fromkeys(self.grad_norms, 0)
            nb_groups = self.nb_layers_per_model
        else:
            squared_norms = {'min_logvar': 0, 'max_logvar': 0}
            for i in range(self.nb_models):
                squared_norms[f'model_{i}'] = 0
            nb_groups = self.nb_models

        for name, param in parameters_dict.items():
            # A single .item() call per tensor (one host sync instead of three).
            squared_norms[group_of[name]] += param.detach().norm(2).item() ** 2

        threshold = self.max_grad_norm / np.sqrt(nb_groups)
        clipping_factors = {group: clip_norm(sq_norm, threshold) for group, sq_norm in squared_norms.items()}

        new_params = {}
        for name, param in parameters_dict.items():
            group = group_of[name]
            if not per_layer and group in logvar_names:
                # NOTE(review): per-model clipping leaves the log-variance
                # bounds untouched, as in the original code — confirm this is
                # intended for the sensitivity analysis.
                new_params[name] = param
            else:
                new_params[name] = param / clipping_factors[group]

        return new_params

    def update_single_user(self, batch, base_params):
        obs_batch, action_batch, next_obs_batch, reward_batch = \
            itemgetter("observations", 'actions', 'next_observations', 'rewards')(batch)
        obs_batch = torch.Tensor(obs_batch)
        action_batch = torch.Tensor(action_batch)
        next_obs_batch = torch.Tensor(next_obs_batch)
        reward_batch = torch.Tensor(reward_batch)

        delta_obs_batch = next_obs_batch - obs_batch
        obs_batch, action_batch = self.transform_obs_action(obs_batch, action_batch)

        # Split batches
        batch_indices = np.arange(obs_batch.shape[0])
        np.random.shuffle(batch_indices)
        nb_batches = int(np.ceil(obs_batch.shape[0] / self.local_batch_size))

        # Clone base params into current params
        current_params = {}
        for name, param in base_params.items():
            current_params[name] = param.clone().detach()

        for epoch in range(self.local_epochs):
            for batch_idx in range(nb_batches):
                cur_batch_indices = batch_indices[
                                    batch_idx * self.local_batch_size: (batch_idx + 1) * self.local_batch_size]
                cur_obs_batch, cur_action_batch = obs_batch[cur_batch_indices], action_batch[cur_batch_indices]
                cur_delta_obs_batch, cur_reward_batch = delta_obs_batch[cur_batch_indices], reward_batch[
                    cur_batch_indices]

                # predict with model
                model_input = torch.cat([cur_obs_batch, cur_action_batch], dim=-1).to(self.device)
                predictions = self.model.predict(model_input)

                # compute training loss and gradients
                groundtruths = torch.cat((cur_delta_obs_batch, cur_reward_batch), dim=-1).to(self.device)
                train_mse_losses, train_var_losses = self.model_loss(predictions, groundtruths)
                # Sum MSE losses (with mean reductions) over ensembles... What if we average instead?
                train_mse_loss = torch.sum(train_mse_losses)
                train_var_loss = torch.sum(train_var_losses)
                train_transition_loss = train_mse_loss + train_var_loss
                train_transition_loss += 0.01 * torch.sum(self.model.max_logvar) - 0.01 * torch.sum(
                    self.model.min_logvar)  # why
                if self.use_weight_decay:
                    decay_loss = self.model.get_decay_loss()
                    train_transition_loss += decay_loss
                else:
                    decay_loss = None
                # update transition model
                self.model_optimizer.zero_grad()  # Reset gradients
                train_transition_loss.backward(create_graph=False)  # Compute gradient from loss

                gap_params = {}
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        current_params[name] += - self.lr * param.grad
                    gap_params[name] = current_params[name] - base_params[name]
                if not self.no_clip:
                    clipped_gap_params = self.clip_gradients(gap_params)  # Clip gap_params
                else:
                    clipped_gap_params = gap_params

                for name, param in self.model.named_parameters():
                    current_params[name] = base_params[name] + clipped_gap_params[name]
                    param.data = current_params[name]

        losses = {'mse_loss': train_mse_loss.item(),
                  'var_loss': train_var_loss.item(),
                  'decay_loss': decay_loss.item() if decay_loss is not None else 0}

        return current_params, losses

    def update(self, data_batches, ep_weights, weight_multiplier):
        """Aggregate the users' local updates, add noise and step the model.

        Each element of ``data_batches`` is trained independently with
        ``update_single_user`` (in parallel through joblib).  The weighted
        parameter offsets are summed, Gaussian noise with standard deviation
        ``noise_multiplier * max_grad_norm * weight_multiplier`` is added, and
        the result is applied through ``self.model_optimizer``.

        Args:
            data_batches: iterable with one data mapping per user.
            ep_weights: per-user weights, indexed like ``data_batches``.
            weight_multiplier: common factor applied to every user weight and
                to the noise scale.

        Returns:
            Dict of logging scalars.  Loss values come from the first user
            only.  ``loss/train_model_loss`` is the sum of that user's MSE and
            variance losses (it previously duplicated the variance loss); the
            regularisation terms are not included because
            ``update_single_user`` does not report them separately.
        """
        # Accumulator for the weighted offsets and a snapshot of the shared
        # starting point handed to every user.
        grad_avg = {}
        base_params = {}
        for name, param in self.model.named_parameters():
            grad_avg[name] = torch.zeros_like(param.data)
            base_params[name] = param.clone().detach()

        update_results = Parallel(n_jobs=8)(
            delayed(self.update_single_user)(batch, base_params) for batch in data_batches)
        updated_params = [res[0] for res in update_results]
        updated_losses = [res[1] for res in update_results]

        for i, current_params in enumerate(updated_params):
            grad_weight = ep_weights[i] * weight_multiplier
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    clipped_grad = current_params[name] - base_params[name]
                    grad_avg[name] += grad_weight * clipped_grad  # Averaged on the fly

        # Manual step: the optimizer is created with lr=1.0 in __init__, so
        # writing the negated update into ``grad`` makes step() move the
        # parameters along ``update`` (subject to momentum).
        self.model_optimizer.zero_grad()
        noise_std = self.noise_multiplier * self.max_grad_norm * weight_multiplier
        for name, param in self.model.named_parameters():
            scaled_noise = torch.normal(
                mean=0,
                std=noise_std,
                size=grad_avg[name].size(),
                device=self.device
            )
            update = grad_avg[name] + scaled_noise
            param.grad = - update
        self.model_optimizer.step()

        reported = updated_losses[0]
        return {
            "loss/train_model_loss_mse": reported['mse_loss'],
            "loss/train_model_loss_var": reported['var_loss'],
            "loss/train_model_loss": reported['mse_loss'] + reported['var_loss'],
            "loss/decay_loss": reported['decay_loss'],
            "misc/max_std": self.model.max_logvar.mean().item(),
            "misc/min_std": self.model.min_logvar.mean().item()
        }

    def model_loss(self, predictions, groundtruths, mse_only=False):
        pred_means, pred_logvars = predictions
        if self.inc_var_loss and not mse_only:
            # Average over batch and dim, sum over ensembles.
            inv_var = torch.exp(-pred_logvars)
            mse_losses = torch.mean(torch.mean(torch.pow(pred_means - groundtruths, 2) * inv_var, dim=-1), dim=-1)
            var_losses = torch.mean(torch.mean(pred_logvars, dim=-1), dim=-1)
        elif mse_only:
            mse_losses = torch.mean(torch.pow(pred_means - groundtruths, 2), dim=(1, 2))
            var_losses = None
        else:
            assert 0
        return mse_losses, var_losses

    @torch.no_grad()
    def predict(self, obs, act, deterministic=False):
        """
        predict next_obs and rew
        """
        if len(obs.shape) == 1:
            obs = obs[None, ]
            act = act[None, ]
        if not isinstance(obs, torch.Tensor):
            obs = torch.FloatTensor(obs).to(self.device)
        if not isinstance(act, torch.Tensor):
            act = torch.FloatTensor(act).to(self.device)

        scaled_obs, scaled_act = self.transform_obs_action(obs, act)

        model_input = torch.cat([scaled_obs, scaled_act], dim=-1).to(self.device)
        pred_diff_means, pred_diff_logvars = self.model.predict(model_input)
        pred_diff_means = pred_diff_means.detach().cpu().numpy()
        # add curr obs for next obs
        obs = obs.detach().cpu().numpy()
        act = act.detach().cpu().numpy()
        ensemble_model_stds = pred_diff_logvars.exp().sqrt().detach().cpu().numpy()

        if self.deterministic:
            pred_diff_means = pred_diff_means
        else:
            pred_diff_means = pred_diff_means + np.random.normal(size=pred_diff_means.shape) * ensemble_model_stds

        num_models, batch_size, _ = pred_diff_means.shape
        model_idxes = np.random.choice(self.model.elite_model_idxes, size=batch_size)
        batch_idxes = np.arange(0, batch_size)

        pred_diff_samples = pred_diff_means[model_idxes, batch_idxes]

        next_obs, rewards = pred_diff_samples[:, :-1] + obs, pred_diff_samples[:, -1]
        terminals = self.static_fns.termination_fn(obs, act, next_obs)

        # penalty rewards
        penalty_learned_var = True
        if self.reward_penalty_coef != 0:
            if not penalty_learned_var:
                ensemble_means_obs = pred_diff_means[:, :, :-1]
                mean_obs_means = np.mean(ensemble_means_obs, axis=0)  # average predictions over models
                diffs = ensemble_means_obs - mean_obs_means
                normalize_diffs = False
                if normalize_diffs:
                    obs_dim = next_obs.shape[1]
                    obs_sigma = self.model.scaler.cached_sigma[0, :obs_dim]
                    diffs = diffs / obs_sigma
                dists = np.linalg.norm(diffs, axis=2)  # distance in obs space
                penalty = np.max(dists, axis=0)  # max distances over models
            else:
                if self.uncertainty_estimator == 'max_aleatoric':
                    penalty = np.amax(np.linalg.norm(ensemble_model_stds, axis=2), axis=0)
                elif self.uncertainty_estimator == 'max_diff':
                    diffs = []
                    for i in range(pred_diff_means.shape[0]):
                        for j in range(i+1, pred_diff_means.shape[0]):
                            diffs.append(pred_diff_means[i, :, :] - pred_diff_means[j, :, :])
                    diffs = np.array(diffs)
                    penalty = np.amax(np.linalg.norm(diffs, axis=2), axis=0)
                penalized_rewards = rewards - self.reward_penalty_coef * penalty
        else:
            penalized_rewards = rewards

        assert (type(next_obs) == np.ndarray)
        info = {'penalty': penalty, 'penalized_rewards': penalized_rewards}
        penalized_rewards = penalized_rewards[:, None]
        terminals = terminals[:, None]
        return next_obs, penalized_rewards, terminals, info

    def update_best_snapshots(self, val_losses):
        """Snapshot every ensemble member whose validation loss improved.

        A member counts as improved when its loss dropped by more than 1%
        relative to its best recorded loss.

        Args:
            val_losses: sequence of validation losses, one per member.

        Returns:
            True if at least one snapshot was refreshed, False otherwise.
        """
        updated = False
        for i, current_loss in enumerate(val_losses):
            best_loss = self.best_snapshot_losses[i]
            improvement = (best_loss - current_loss) / best_loss
            if improvement > 0.01:
                self.best_snapshot_losses[i] = current_loss
                self.save_model_snapshot(i)
                updated = True
        return updated

    def reset_best_snapshots(self):
        """Snapshot every ensemble member's current weights and reset the best losses to 1e10."""
        ensemble_size = self.model.ensemble_size
        members = self.model.ensemble_models
        snapshots = []
        for idx in range(ensemble_size):
            # Deep copy so later training does not alter the stored weights.
            snapshots.append(deepcopy(members[idx].state_dict()))
        self.model_best_snapshots = snapshots
        self.best_snapshot_losses = [1e10] * ensemble_size

    def save_model_snapshot(self, idx):
        """Store a deep copy of ensemble member ``idx``'s state dict as its best snapshot."""
        member_state = self.model.ensemble_models[idx].state_dict()
        self.model_best_snapshots[idx] = deepcopy(member_state)

    def load_best_snapshots(self):
        """Hand the stored best snapshots to the model's ``load_state_dicts``."""
        best_snapshots = self.model_best_snapshots
        self.model.load_state_dicts(best_snapshots)

    def save_model(self, info):
        """Pickle every registered network to ``<log_path>/models/ite_<info>/<name>.pt``.

        Args:
            info: label identifying the checkpoint; it is formatted into the
                directory name.
        """
        model_save_dir = os.path.join(util.logger.log_path, 'models', "ite_{}".format(info))
        # exist_ok avoids the check-then-create race and also creates the
        # intermediate 'models' directory.
        os.makedirs(model_save_dir, exist_ok=True)
        for network_name, network in self.networks.items():
            save_path = os.path.join(model_save_dir, network_name + ".pt")
            torch.save(network, save_path)

    def load_model(self, info):
        """Restore every registered network from ``<log_path>/models/ite_<info>/<name>.pt``.

        The previous body was a copy of ``save_model``: it created the
        directories and overwrote the checkpoint instead of reading it.

        Args:
            info: label of the checkpoint, as passed to ``save_model``.

        Raises:
            FileNotFoundError: propagated from ``torch.load`` when a
                checkpoint file is missing (no directories are created).
        """
        model_save_dir = os.path.join(util.logger.log_path, 'models', "ite_{}".format(info))
        for network_name in list(self.networks):
            load_path = os.path.join(model_save_dir, network_name + ".pt")
            # SECURITY: torch.load unpickles arbitrary objects; only load
            # checkpoints written by a trusted run.
            # NOTE(review): the files hold whole pickled modules
            # (torch.save(network)); PyTorch releases whose torch.load
            # defaults to weights_only=True reject such files — confirm the
            # supported version.
            self.networks[network_name] = torch.load(load_path, map_location=self.device)

        # Re-bind everything that referenced the replaced model object.  The
        # optimizer is rebuilt with the same settings as in __init__, because
        # the old one still points at the discarded parameters.
        self.model = self.networks["model"]
        self.model_optimizer = torch.optim.SGD(self.model.parameters(), lr=1.0, momentum=0.9)
