import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.modules.module import T

import numpy as np

from dynamics.utils import RollingNormalizer


def swish(x):
    """Swish activation: scale the input element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate


class EnsembleDenseLayer(nn.Module):
    def __init__(self, n_in, n_out, ensemble_size, non_linearity='leaky_relu'):
        """
        linear + activation Layer
        there are `ensemble_size` layers
        computation is done using batch matrix multiplication
        hence forward pass through all models in the ensemble can be done in one call

        weights are initialized per ensemble member with
            - xavier uniform for 'swish'
            - kaiming normal for 'leaky_relu' and 'tanh'
            - xavier normal for 'linear'
        biases are always initialized to zeros

        Args:
            n_in: size of input vector
            n_out: size of output vector
            ensemble_size: number of models in the ensemble
            non_linearity: 'linear', 'swish', 'tanh' or 'leaky_relu'

        Raises:
            ValueError: if `non_linearity` is not one of the supported names
        """

        super().__init__()

        # Resolve initializer and activation together so an unsupported name fails
        # here, instead of leaving all-zero weights and no activation attribute.
        if non_linearity == 'swish':
            init_fn, activation = nn.init.xavier_uniform_, swish
        elif non_linearity == 'leaky_relu':
            init_fn, activation = nn.init.kaiming_normal_, F.leaky_relu
        elif non_linearity == 'tanh':
            init_fn, activation = nn.init.kaiming_normal_, torch.tanh
        elif non_linearity == 'linear':
            init_fn, activation = nn.init.xavier_normal_, self._identity
        else:
            raise ValueError(
                f"unsupported non_linearity {non_linearity!r}; "
                "expected 'linear', 'swish', 'tanh' or 'leaky_relu'"
            )

        weights = torch.zeros(ensemble_size, n_in, n_out, dtype=torch.float32)
        biases = torch.zeros(ensemble_size, 1, n_out, dtype=torch.float32)

        # Each slice along dim 0 is one ensemble member's (n_in, n_out) matrix;
        # the slices are views, so the in-place initializers fill `weights`.
        # NOTE(review): the matrices are stored as (n_in, n_out), while torch's
        # initializers appear to derive fan-in from the second dimension, so the
        # kaiming variants may be scaling by n_out here - confirm this is intended.
        for member_weights in weights:
            init_fn(member_weights)

        self.weights = nn.Parameter(weights)
        self.biases = nn.Parameter(biases)
        self.non_linearity = activation

    @staticmethod
    def _identity(x):
        """Pass-through activation used for 'linear' layers (picklable, unlike a lambda)."""
        return x

    def forward(self, inp):
        """
        Apply every ensemble member to its own slice of the input in one batched call.

        Args:
            inp: tensor of shape (ensemble_size, batch, n_in)

        Returns:
            tensor of shape (ensemble_size, batch, n_out) after the activation
        """
        # biases + inp @ weights, batched over the ensemble dimension
        op = torch.baddbmm(self.biases, inp, self.weights)
        return self.non_linearity(op)


class EnsembleModel(nn.Module):
    """
    Ensemble of probabilistic dynamics models.

    Every ensemble member maps a (state, action) pair to the mean and variance of the
    state delta and, when `mod_reward` is set, of the reward as well. All members are
    evaluated in a single batched pass through a stack of `EnsembleDenseLayer`s.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        mod_reward=False,
        hidden_units=128,
        ensemble_size=10,
        lr=1e-3,
        min_logvar=-5.0,
        max_logvar=-1.0,
        lr_decay=0.999,
        norm_actions=False,
        num_layers=2
    ):
        """
        :param state_dim: size of the state vector
        :param action_dim: size of the action vector
        :param mod_reward: if True the reward is predicted as one extra output dimension
        :param hidden_units: width of every hidden layer
        :param ensemble_size: number of models in the ensemble
        :param lr: initial learning rate of the Adam optimizer
        :param min_logvar: lower bound of the predicted log-variance
        :param max_logvar: upper bound of the predicted log-variance
        :param lr_decay: multiplicative factor applied by `decay_lr`
        :param norm_actions: if True actions are normalized before entering the network
        :param num_layers: number of leaky-relu layers placed before the linear output layer
        """
        assert num_layers >= 2, "minimum depth of model is 2"

        super().__init__()

        # Number of predicted quantities: one per state dimension, plus the reward
        # when it is modelled. The network emits a mean and a log-variance for each.
        target_dim = state_dim + 1 if mod_reward else state_dim

        layers = [
            EnsembleDenseLayer(action_dim + state_dim, hidden_units, ensemble_size, non_linearity='leaky_relu')
        ]
        for _ in range(num_layers - 1):
            layers.append(
                EnsembleDenseLayer(hidden_units, hidden_units, ensemble_size, non_linearity='leaky_relu')
            )
        layers.append(
            EnsembleDenseLayer(hidden_units, 2 * target_dim, ensemble_size, non_linearity='linear')
        )

        self.layers = nn.Sequential(*layers)

        self.to('cpu')

        self.normalizer = RollingNormalizer(state_dim=state_dim, action_dim=action_dim)

        # Use a more robust optimizer
        self.model_optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        self.mod_reward = mod_reward
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.ensemble_size = ensemble_size
        self.device = 'cpu'

        self.norm_actions = norm_actions
        self.lr = lr
        self.lr_decay = lr_decay

        # Bounds of the predicted log-variance (see `propagate_network`).
        # An alternative pair of bounds (-20. / -.5) was left commented out here.
        self.min_logvar = torch.tensor(min_logvar)
        self.max_logvar = torch.tensor(max_logvar)

    @staticmethod
    def _build_single_model(input_dim, hidden_units, output_dim, num_layers):
        """
        Utility method to construct a single (non-ensemble) network.

        :param input_dim: size of the input vector
        :param hidden_units: width of every hidden layer
        :param output_dim: size of the output vector
        :param num_layers: number of Linear + LeakyReLU pairs placed before the final Linear layer
        :return: an `nn.Sequential` holding the layers
        """
        layers = []
        # First layer
        layers.append(nn.Linear(input_dim, hidden_units))
        layers.append(nn.LeakyReLU())

        # Additional hidden layers
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_units, hidden_units))
            layers.append(nn.LeakyReLU())

        # Final layer
        layers.append(nn.Linear(hidden_units, output_dim))

        return nn.Sequential(*layers)

    def pre_norm_inputs(self, states_, actions_):
        """
        Normalize the network inputs.

        States are normalized whenever a normalizer is set; actions only when
        `norm_actions` is enabled.

        :param states_: raw states
        :param actions_: raw actions
        :return: tuple (states, actions) ready to be fed to `propagate_network`
        """
        if self.normalizer is None:
            return states_, actions_

        states = self.normalizer.normalize_states(states_)
        actions = actions_
        if self.norm_actions:
            actions = self.normalizer.normalize_actions(actions)
        return states, actions

    def pre_norm_outputs(self, outputs):
        """
        Normalize training targets.

        :param outputs: state deltas; when `mod_reward` is set the reward is expected in the
            last position of the final dimension
        :return: normalized targets with the same layout as `outputs`
        """
        if self.normalizer is None:
            return outputs

        if not self.mod_reward:
            return self.normalizer.normalize_state_deltas(outputs)

        # Ellipsis indexing addresses the last dimension whatever the rank, so the
        # 2-D targets built by `train_models` work as well as 3-D tensors.
        state_deltas = self.normalizer.normalize_state_deltas(outputs[..., :-1])
        rewards = self.normalizer.normalize_rewards(outputs[..., -1:])
        return torch.cat([state_deltas, rewards], dim=-1)

    def post_denorm_outputs(self, delta_mean, var):
        """
        Map predicted means and variances from normalized space back to raw space.

        :param delta_mean: predicted means; the reward mean is the last entry of the final
            dimension when `mod_reward` is set
        :param var: predicted variances, same layout as `delta_mean`
        :return: tuple (means, variances) in raw space
        """
        if self.normalizer is None:
            return delta_mean, var

        if not self.mod_reward:
            delta_mean = self.normalizer.denormalize_state_delta_means(delta_mean)
            var = self.normalizer.denormalize_state_delta_vars(var)
            return delta_mean, var

        # Slice state and reward parts from the *network output* before denormalizing
        # either of them, so the reward is never read from an already-denormalized
        # state tensor.
        state_mean = self.normalizer.denormalize_state_delta_means(delta_mean[..., :-1])
        rew_mean = self.normalizer.denormalize_rewards_means(delta_mean[..., -1:])

        state_var = self.normalizer.denormalize_state_delta_vars(var[..., :-1])
        rew_var = self.normalizer.denormalize_rewards_vars(var[..., -1:])

        delta_mean = torch.cat([state_mean, rew_mean], dim=-1)
        var = torch.cat([state_var, rew_var], dim=-1)
        return delta_mean, var

    def propagate_network(self, states, actions):
        """
        Propagate the ensemble of models

        :param states: (already normalized) states, 2-D with one row per sample
        :param actions: (already normalized) actions, 2-D with one row per sample
        :return: tuple (delta_mean, var), each with a leading ensemble dimension of size
            `ensemble_size`, both still in normalized space
        """
        x = torch.cat([states, actions], dim=1)
        # every ensemble member receives its own copy of the same batch
        x = x.unsqueeze(0).repeat(self.ensemble_size, 1, 1)
        out = self.layers(x.float())
        # first half of the last dimension holds the means, second half the raw log-variances
        delta_mean, raw_logvar = torch.split(out, out.size(2) // 2, dim=2)

        # squash the raw output into (min_logvar, max_logvar) before exponentiating
        squashed = torch.sigmoid(raw_logvar)
        bounded_logvar = self.min_logvar + (self.max_logvar - self.min_logvar) * squashed
        var_s = torch.exp(bounded_logvar)

        return delta_mean, var_s

    def forward(self, states, actions):
        """
        Takes in states and actions, normalizes them, and passes them through the ensemble of models and denormalizes

        :param states: raw states, 2-D with one row per sample
        :param actions: raw actions, 2-D with one row per sample
        :return: tuple (means, variances) with a leading ensemble dimension; the state part of the
            means is the predicted next state (input state + predicted delta), and the last entry
            of the final dimension is the reward prediction when `mod_reward` is set
        """

        states_norm, actions_norm = self.pre_norm_inputs(states, actions)
        means, var_s = self.propagate_network(states_norm, actions_norm)

        # Denormalize the means and variances
        means, var_s = self.post_denorm_outputs(means, var_s)

        # The network predicts state deltas: add the input states back onto the state part
        # only, leaving the reward prediction (if any) untouched.
        if self.mod_reward:
            state_means, reward_means = means[:, :, :-1], means[:, :, -1:]
            state_vars, reward_vars = var_s[:, :, :-1], var_s[:, :, -1:]
            next_state_means = states + state_means
            means = torch.cat([next_state_means, reward_means], dim=-1)
            var_s = torch.cat([state_vars, reward_vars], dim=-1)
        else:
            means = states + means

        return means, var_s

    def forward_all(self, states, actions):
        """
        Return the predictions of every ensemble member; delegates to `forward`.

        :param states: raw states
        :param actions: raw actions
        :return: tuple (means, variances) exactly as returned by `forward`
        """
        next_state_means, next_state_vars = self(states, actions)
        return next_state_means, next_state_vars

    def train_models(self, states, actions, diff_states, rewards):
        """
        Train the ensemble of models with one optimizer step

        :param states: raw states
        :param actions: raw actions
        :param diff_states: raw state deltas used as regression targets
        :param rewards: raw rewards; only used when `mod_reward` is set, in which case they are
            concatenated to `diff_states` along dim 1
        :return: the scalar loss value of this step as a Python float
        """

        # Prepare input and targets
        if self.mod_reward:
            targets = torch.cat([diff_states, rewards], dim=1)
        else:
            targets = diff_states

        # Zero gradients
        self.model_optimizer.zero_grad()

        # Normalize inputs and outputs
        states_norm, actions_norm = self.pre_norm_inputs(states, actions)
        targets = self.pre_norm_outputs(targets)

        # Get predictions and compute loss
        mean_preds, var_preds = self.propagate_network(states_norm, actions_norm)

        # Gaussian negative log-likelihood (up to constants); the targets broadcast
        # over the leading ensemble dimension of the predictions.
        batch_loss = (mean_preds - targets) ** 2 / var_preds + torch.log(var_preds)
        batch_loss = torch.mean(batch_loss)
        batch_loss.backward()
        self.model_optimizer.step()

        return batch_loss.item()

    def decay_lr(self):
        """Decay learning rate of every optimizer parameter group by `lr_decay`"""
        for param_group in self.model_optimizer.param_groups:
            param_group['lr'] *= self.lr_decay
