import torch
import torch.autograd
from torch import nn
from torch.distributions import Independent, Normal
from torch.distributions.utils import probs_to_logits
from torch.nn import functional as F
from torch.nn.init import calculate_gain
import torch.nn.init as init
from abc import ABC, abstractmethod
from collections import OrderedDict
from utils.utils import weighted_logsumexp, inv_softplus
import numpy as np
import math
from mpi4py import MPI


class NamedParamsNNPolicy(ABC, nn.Module):
    """
    Interface for policy with named parameters. These policies are used for lookahead gradients.

    Layers are registered as parameters named ``<prefix>layer<i>weight`` and
    ``<prefix>layer<i>bias`` (``i`` starting at 1) so that a forward pass can be
    evaluated with any dictionary of parameters, e.g. the result of `sgd_update`.

    Args:
        bias (bool): whether `create_named_layers` also creates bias parameters
    """
    def __init__(self, bias=True):
        nn.Module.__init__(self)
        self.bias = bias

    def sgd_update(self, loss, lr, params=None, create_graph=True):
        """
        Apply one step of gradient descent on the loss function `loss` and return
        the updated parameters. The parameters stored in the module are not modified.

        Args:
            loss (torch.Tensor): the loss for gradient calculation
            lr (float, torch.Tensor or dict): step size. If a dict (including
                OrderedDict) is given, the step size of parameter ``name`` is
                ``softplus(lr[name + "lr"])``; otherwise `lr` is applied to
                every parameter as is.
            params (OrderedDict): parameters wrt which the gradient should be calculated
                (defaults to all named parameters of the module)
            create_graph: create gradient graph for gradient

        Returns:
            OrderedDict mapping each parameter name to ``param - step * grad``
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        grads = torch.autograd.grad(loss, tuple(params.values()), create_graph=create_graph)

        # Any dict can carry per-parameter learning rates, not only OrderedDict.
        per_param_lr = isinstance(lr, dict)
        updated_params = OrderedDict()
        for (name, param), grad in zip(params.items(), grads):
            step = F.softplus(lr[f"{name}lr"]) if per_param_lr else lr
            updated_params[name] = param - step * grad

        return updated_params

    def create_named_layers(self, layer_sizes, extra_dims=(), param_dicts=None, prefix=""):
        """
        Creates named layers and add them as modules

        Args:
            layer_sizes (tuple): tuple with layer sizes (inp, sizes*, out)
            extra_dims (tuple): tuple with extra dimensions (eg. for options we use number of options)
            param_dicts (tuple): parameters can belong to multiple param dictionaries (example: inner_params, subpolicy)
            prefix (str): name of the part of the model
        """
        # A network with only input and output sizes gets a different initialisation.
        single_layer = len(layer_sizes) == 2
        for i in range(1, len(layer_sizes)):
            fan_in, fan_out = layer_sizes[i - 1], layer_sizes[i]
            created = OrderedDict()

            weight_name = '{0}layer{1}weight'.format(prefix, i)
            # Weights are stored as (*extra_dims, in, out), i.e. ready for `x @ W`.
            weight_params = nn.Parameter(torch.Tensor(*extra_dims, fan_in, fan_out))
            self._init_weight(weight=weight_params.data, input_dim=fan_in, single_layer=single_layer)
            self.register_parameter(name=weight_name, param=weight_params)
            created[weight_name] = weight_params

            if self.bias:
                bias_name = '{0}layer{1}bias'.format(prefix, i)
                bias_params = nn.Parameter(torch.Tensor(*extra_dims, fan_out))
                self._init_bias(bias=bias_params.data, input_dim=fan_in, single_layer=single_layer)
                self.register_parameter(name=bias_name, param=bias_params)
                created[bias_name] = bias_params

            # Share the very same Parameter objects with every requested dictionary.
            if param_dicts is not None:
                for param_dict in param_dicts:
                    for name, param in created.items():
                        param_dict[name] = param

    def named_forward_pass(self, inp, params, num_layers, nonlinearity, prefix=""):
        """
        Performs a forward pass through the named part of the network

        Args:
            inp (torch.Tensor): input
            params (OrderedDict): a dictionary with parameters that should be used
            num_layers (int): number of layers in the named part
            nonlinearity : non-linearity function that should be used
            prefix (str): name of the part of the model

        Returns:
            Output from the named part of the model. With 2-D weights the result is
            2-D; with 3-D weights ``(p, in, out)`` a 2-D input is replicated ``p``
            times and the result has a leading dimension of size ``p``.

        Raises:
            RuntimeError: if the input/weight ranks are not (2, 2), (2, 3) or (3, 3)
        """
        output = inp
        for i in range(1, num_layers + 1):
            weight = params['{0}layer{1}weight'.format(prefix, i)]
            bias_name = '{0}layer{1}bias'.format(prefix, i)
            bias = params[bias_name] if bias_name in params else None

            if output.dim() == 2 and weight.dim() == 2:
                if bias is not None:
                    output = torch.addmm(bias, output, weight)
                else:
                    output = torch.mm(output, weight)
            else:
                if output.dim() == 2 and weight.dim() == 3:
                    # Feed the same batch to every one of the stacked networks.
                    output = output.repeat(weight.shape[0], 1, 1)
                if output.dim() != 3 or weight.dim() != 3:
                    raise RuntimeError(f"Wrong dims in named forward pass {output.dim()} {weight.dim()}")
                if bias is not None:
                    # bias is (p, out); unsqueeze so it broadcasts over the batch dimension
                    output = torch.baddbmm(bias.unsqueeze(dim=1), output, weight)
                else:
                    output = torch.bmm(output, weight)

            # No non-linearity after the last layer.
            if i < num_layers:
                output = nonlinearity(output)

        return output

    def synchronize(self, comm):
        """
        Broadcast parameter values to all threads in the comm
        :param comm: comm to broadcast to

        Rank 0 sends its parameter values; every other rank overwrites its own.
        NOTE(review): buffers are sent as MPI.FLOAT, so parameters are presumably
        float32 CPU tensors - confirm.
        """
        rank = comm.Get_rank()
        for p in self.parameters():
            if rank == 0:
                global_theta = p.data.view(-1).numpy()
            else:
                global_theta = np.empty_like(p.data.view(-1))
            # Collective call: has to be executed by every rank for every parameter.
            comm.Bcast([global_theta, MPI.FLOAT], root=0)
            if rank != 0:
                p.data = torch.from_numpy(global_theta).reshape(p.data.shape)

    def _init_weight(self, weight, input_dim, single_layer,  nonlinearity="leaky_relu", a=math.sqrt(5)):
        """
        Function for weight initialization (similar to pytorch)

        Fills `weight` in place with values drawn uniformly from [-bound, bound].
        For a single-layer network the bound is ``sqrt(3) * gain / sqrt(input_dim)``
        with ``gain = calculate_gain(nonlinearity, a)``; otherwise it is fixed to 0.1.
        """
        if single_layer:
            gain = calculate_gain(nonlinearity, a)
            std = gain / math.sqrt(input_dim)
            bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
        else:
            bound = 0.1
        with torch.no_grad():
            weight.uniform_(-bound, bound)

    def _init_bias(self, bias, input_dim, single_layer):
        """
        Function for bias initialization (similar to pytorch)

        Fills `bias` in place: uniform in [-1/sqrt(input_dim), 1/sqrt(input_dim)]
        for a single-layer network, zeros otherwise.
        """
        with torch.no_grad():
            if single_layer:
                bound = 1 / math.sqrt(input_dim)
                init.uniform_(bias, -bound, bound)
            else:
                init.zeros_(bias)


class ContinuousHead(object):
    """
    Diagonal-Gaussian action head.

    Distribution parameters are packed along the last axis: the first
    ``output_dim`` entries are raw means (squashed with tanh and scaled by
    ``max_action``), everything after them is interpreted as log standard
    deviations.
    """

    def __init__(self, output_dim, *args, **kwargs):
        self.max_action = 1
        self.min_action = -1
        self.output_dim = output_dim

    def sample_action(self, policy_params):
        """Sample one action per row of a 2-D ``policy_params`` tensor."""
        split_at = self.output_dim
        raw_means = policy_params[:, :split_at]
        log_stds = policy_params[:, split_at:]
        loc = torch.tanh(raw_means) * self.max_action
        scale = torch.exp(log_stds)
        # result has one row per input row and one column per action dimension
        return torch.normal(mean=loc, std=scale)

    def selected_action_log_probs(self, policy_params, actions):
        """
        Log-density of ``actions`` under every option's Gaussian.

        ``policy_params`` is indexed as batch x options x dist_params and
        ``actions`` as batch x action_dim; the result is batch x options.
        """
        split_at = self.output_dim
        loc = torch.tanh(policy_params[:, :, :split_at]) * self.max_action
        scale = torch.exp(policy_params[:, :, split_at:])
        per_dim_normal = Normal(loc=loc, scale=scale)
        # Treat the last axis as the event so the log-probs are summed over action dimensions.
        gaussian = Independent(per_dim_normal, reinterpreted_batch_ndims=1)
        # Insert an options axis so the same action is scored under every option.
        actions_per_option = actions[:, None, :]
        return gaussian.log_prob(actions_per_option)


class DiscreteHead(object):
    """
    Categorical action head: a softmax over temperature-scaled logits.
    """

    def __init__(self, output_dim, temperature, *args, **kwargs):
        self.output_dim = output_dim
        self.temperature = temperature

    def sample_action(self, policy_params):
        """Sample one action index per row of ``policy_params`` (logits)."""
        scaled_logits = policy_params / self.temperature
        action_probs = torch.softmax(scaled_logits, dim=-1)
        return torch.multinomial(action_probs, num_samples=1)

    def selected_action_log_probs(self, policy_params, actions):
        """
        Log-probability of the taken actions under every option.

        ``policy_params`` is indexed as batch x options x actions and ``actions``
        holds one index per batch row (batch x 1); the result is batch x options.
        """
        scaled_logits = policy_params / self.temperature
        all_log_probs = torch.log_softmax(scaled_logits, dim=-1)
        num_options = all_log_probs.shape[1]
        # Repeat each row's action index for every option so gather picks it per option.
        gather_index = actions[:, :, None].expand(-1, num_options, 1)
        picked = all_log_probs.gather(2, gather_index)
        return picked.squeeze(dim=2)


class CategoricalOptionsPolicy(NamedParamsNNPolicy):
    """
    Categorical policy with options

    The policy consists of an optional shared base network, a policy over
    options, and one termination network and one sub-policy network per option.

    Args:
        obs_dim (int): size of an observation
        action_dim (int): number of discrete actions, or size of a continuous action
        opts: configuration object. The attributes read here are `no_bias`,
            `temperature_options`, `temperature_terminations`, `max_option_prob`,
            `learn_lr_inner`, `lr_inner`, `options`, `fixed_std`,
            `hidden_sizes_base` (shared hidden layers),
            `hidden_sizes_option` (policy over options' hidden layers),
            `hidden_sizes_termination` (termination hidden layers),
            `hidden_sizes_subpolicy` (sub-policy hidden layers) and
            `termination_prior`.
        action_type (str): "discrete" or "continuous"
        nonlinearity : nonlinearity applied between layers
    """

    def __init__(self, obs_dim, action_dim, opts, action_type, nonlinearity):
        NamedParamsNNPolicy.__init__(self, bias=not opts.no_bias)  # do not use bias to have tabular policy
        assert action_type in ("discrete", "continuous")
        self.no_bias = opts.no_bias

        # Softmax / sigmoid temperatures; the action temperature is fixed to 1.
        self.temperature_options = opts.temperature_options
        self.temperature_actions = 1
        self.temperature_terminations = opts.temperature_terminations
        # Upper bound on any option probability, enforced in clip_option_probs.
        self.max_option_prob = opts.max_option_prob

        self.learn_lr_inner = opts.learn_lr_inner
        self.initial_lr_inner = opts.lr_inner

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.options = opts.options
        self.action_type = action_type
        # fixed_std > 0: use this constant std; fixed_std <= 0: use the learned log-std parameter.
        self.fixed_std = opts.fixed_std
        if action_type == "discrete":
            self.subpolicy_head = DiscreteHead(self.action_dim, self.temperature_actions)
        else:
            self.subpolicy_head = ContinuousHead(self.action_dim)

        self.hidden_sizes_shared = tuple(opts.hidden_sizes_base)
        self.hidden_sizes_option = tuple(opts.hidden_sizes_option)
        self.hidden_sizes_termination = tuple(opts.hidden_sizes_termination)
        self.hidden_sizes_subpolicy = tuple(opts.hidden_sizes_subpolicy)

        # Set up parameter dicts
        self.inner_params = OrderedDict()  # Adjusted in inner grad steps (Options)
        self.outer_params = OrderedDict()  # Adjusted in outer grad steps (Subpolicies and Terminations)

        self.base_params = OrderedDict()  # Parameters of shared layers
        self.option_params = OrderedDict()  # Parameters of option layers
        self.termination_params = OrderedDict()  # Parameters of termination layers
        self.subpolicy_params = OrderedDict()  # Parameters of subpolicy layers

        self.nonlinearity = nonlinearity

        self.create_nns()
        if self.learn_lr_inner:
            # One learnable learning rate per option-policy parameter; they are
            # registered as outer parameters.
            self.lr_params = OrderedDict()
            self.create_lr_params(lr=self.initial_lr_inner, source_param_dict=self.option_params,
                                  param_dicts=(self.lr_params, self.outer_params))
        if opts.termination_prior > 0:
            self.set_terminations(opts.termination_prior)

    def forward(self, obs, params=None):
        """
        Evaluates all parts of the policy for a batch of observations.

        Args:
            obs (torch.Tensor): observations (assumed to be batch x obs_dim - TODO confirm)
            params (OrderedDict): can use different parameters than the ones in the network

        Returns:
            Tuple (option_probs, termination_probs, subpolicy_params) indexed as
            observations x options, observations x options and
            observations x options x dist_params. For continuous actions the
            log-stds are appended to the sub-policy outputs along the last axis.
        """
        # If no other params are specified used defaults
        if params is None:
            params = OrderedDict(self.named_parameters())

        # If there are shared layers use them
        if len(self.hidden_sizes_shared) == 0:
            base = obs
        else:
            base = self.named_forward_pass(obs, params, len(self.hidden_sizes_shared), self.nonlinearity, prefix="base")

        # Get termination logits; the per-option networks output options x observations x 1,
        # the trailing singleton is squeezed away here and the axes are swapped further below
        termination_params = self.named_forward_pass(base, params, len(self.hidden_sizes_termination) + 1,
                                                     self.nonlinearity, prefix="termination").squeeze(dim=2)
        termination_probs = torch.sigmoid(termination_params / self.temperature_terminations)

        # Get option selection probs (observations x options)
        option_params = self.named_forward_pass(base, params, len(self.hidden_sizes_option) + 1,
                                                self.nonlinearity, prefix="options")
        option_probs = F.softmax(option_params / self.temperature_options, dim=1)
        option_probs = self.clip_option_probs(option_probs)


        # Get subpolicy distribution params for all options (options x observations x dist_params)
        subpolicy_params = self.named_forward_pass(base, params, len(self.hidden_sizes_subpolicy) + 1, self.nonlinearity,
                                                  prefix="subpolicy")

        # Move the observation axis to the front: observations x options (x dist_params)
        termination_probs = termination_probs.transpose(0, 1)
        subpolicy_params = subpolicy_params.transpose(0, 1)

        if self.action_type != "discrete":
            if self.fixed_std <= 0:
                # NOTE(review): the log-std is read from the module's own parameter, not
                # from `params`, so a "subpolicylogstd" entry in `params` is ignored here.
                subpolicy_params = torch.cat([subpolicy_params,
                                              self.subpolicy_params["subpolicylogstd"].expand_as(subpolicy_params)],
                                             dim=-1)
            else:
                # Constant std: append log(fixed_std) for every action dimension
                subpolicy_params = torch.cat([subpolicy_params,
                                              torch.log(torch.ones_like(subpolicy_params) * self.fixed_std)], dim=-1)

        return option_probs, termination_probs, subpolicy_params

    def get_action(self, obs, active_option=None, params=None):
        """
        Samples an action and a new option based on current observation and option. It is also possible to specify
        network parameters.

        The sampled option and termination are converted with `.item()`, so `obs`
        has to hold a single observation.

        Args:
            obs (torch.Tensor): observation
            active_option (int): active option
            params (OrderedDict): can use different parameters than the ones in the network

        Returns:
            Tuple (action, active_option, termination): the sampled action, the
            option that produced it, and the termination sample (0 when no option
            was active yet).
        """
        # If no params are specified use default model params
        if params is None:
            params = OrderedDict(self.named_parameters())

        if len(self.hidden_sizes_shared) == 0:
            base = obs
        else:
            base = self.named_forward_pass(obs, params, len(self.hidden_sizes_shared), self.nonlinearity, prefix="base")

        # If current option is not specified (first action) sample an option otherwise update option
        if active_option is None:
            active_option = self.sample_option(base, params).item()
            termination = 0
        else:
            termination = self.sample_termination(inp=base, active_option=active_option, params=params)
            if termination:
                active_option = self.sample_option(inp=base, params=params).item()
            else:
                active_option = active_option

        action = self.sample_action(inp=base, active_option=active_option, params=params)
        return action, active_option, termination

    def get_update_data(self, traj_data, params=None):
        """
        Adds per-trajectory update quantities to `traj_data`.

        Args:
            traj_data (dict): must contain "observations" and "actions", each an
                iterable with one entry per trajectory
            params (OrderedDict): params to use (uses default model params by default)

        Returns:
            The same `traj_data` object with the lists "action_log_probs",
            "posterior_option_probs" and "entropies" added (one entry per trajectory).
        """
        # If no params are specified use default model params
        if params is None:
            params = OrderedDict(self.named_parameters())

        # Process every trajectory separately, then regroup the per-trajectory triples into three lists
        action_log_probs, m_s, entropies = map(list, zip(*[self._single_traj_update_data(obs=p, actions=q, params=params)
                                                           for p, q in zip(traj_data["observations"], traj_data["actions"])]))
        traj_data["action_log_probs"] = action_log_probs
        traj_data["posterior_option_probs"] = m_s
        traj_data["entropies"] = entropies
        return traj_data

    def _single_traj_update_data(self, obs, actions, params=None):
        """
        Calculates log-probabilities from the trajectory data

        Args:
           obs (torch.Tensor): observations of one trajectory
           actions (torch.Tensor): actions taken in that trajectory (cast to long for discrete actions)
           params (OrderedDict): params to use (uses default model params by default)

        Returns:
            Tuple (marginalized_selected_action_log_probs, m_s, entropies): the
            log-probabilities of the taken actions marginalized over options, the
            per-step option probabilities used as weights (observations x options),
            and the entropy of the policy over options at every step.
        """
        # If no params are specified use default model params
        if params is None:
            params = OrderedDict(self.named_parameters())

        # Discrete action indices are used for gather, which needs a long tensor
        if self.action_type == "discrete":
            actions = actions.long()

        option_probs, termination_probs, policy_params = self.forward(obs, params)

        # Calculate transition probs (observation x options x options): entry [t, j, i] is the
        # probability of moving from option i to option j, i.e. terminate i and pick j, plus
        # staying in i on the diagonal. For every i it sums to 1 over j (dim=1).
        transition_probs = torch.matmul(option_probs[:, :, None], termination_probs[:, None, :]) \
                           + torch.diag_embed(1 - termination_probs[:, :])

        # Select the actions probabilities according to actions which were performed (observations x options)
        selected_action_log_probs = self.subpolicy_head.selected_action_log_probs(policy_params=policy_params, actions=actions)
        selected_action_probs = torch.exp(selected_action_log_probs)
        m_s = []
        for i in range(actions.shape[0]):
            # If the new episode starts we set m (probability of being in an option) to policy over option probabilities
            if i == 0:
                m_s.append(option_probs[i])
            # Otherwise perform an update according to IOPG: normalize the previous step's c_vec
            # and push it through the transition matrix of the current step
            else:
                m_s.append(torch.matmul(transition_probs[i], c_vec / torch.sum(c_vec)))
            # Calculate c_vec which is probability of being in options * probability of performing an action
            # in that option
            c_vec = m_s[-1] * selected_action_probs[i]
            # Rescale by the (detached) maximum; the scale cancels in the normalization above
            c_vec = c_vec * (1 / torch.max(c_vec).detach())

        # Calculate action probabilities with one operation
        m_s = torch.stack(m_s, dim=0)
        # weighted_logsumexp is a project helper; presumably log(sum_o m_s * exp(log_prob)) over options - confirm
        marginalized_selected_action_log_probs = weighted_logsumexp(input=selected_action_log_probs, dim=1,
                                                                    weights=m_s, keepdim=True)
        # Entropy of the policy over options, -sum(p * log p), one value per observation
        option_entropies = -torch.sum(probs_to_logits(option_probs) * option_probs, dim=-1, keepdim=True)
        entropies = option_entropies
        return marginalized_selected_action_log_probs, m_s, entropies

    def sample_option(self, inp, params):
        """
        Samples an option index from the (clipped) policy over options.

        Args:
            inp (torch.Tensor): output of the shared base (or the raw observation)
            params (OrderedDict): parameters to use

        Returns:
            Tensor with one sampled option index per row of `inp`
        """
        weights = self.named_forward_pass(inp=inp, params=params, num_layers=len(self.hidden_sizes_option) + 1,
                                          nonlinearity=self.nonlinearity, prefix="options")
        probs = F.softmax(weights / self.temperature_options, dim=-1)
        probs = self.clip_option_probs(probs)
        option = torch.multinomial(probs, num_samples=1)
        return option

    def sample_termination(self, inp, active_option, params):
        """
        Samples whether the active option terminates.

        Args:
            inp (torch.Tensor): output of the shared base (or the raw observation);
                has to hold a single observation because the sample is read with `.item()`
            active_option (int): index of the active option
            params (OrderedDict): parameters to use

        Returns:
            The Bernoulli sample as a Python number (1.0 terminate, 0.0 continue)
        """
        weights = self.named_forward_pass(inp=inp, params=params,
                                          num_layers=len(self.hidden_sizes_termination) + 1,
                                          nonlinearity=self.nonlinearity, prefix="termination")
        # Use sigmoid to get termination probability (0, 1); only the active option's network output is used
        termination_prob = torch.sigmoid(weights[active_option, :, :] / self.temperature_terminations)
        termination = torch.bernoulli(termination_prob).item()
        return termination

    def sample_action(self, inp, active_option, params):
        """
        Samples an action from the sub-policy of the active option.

        Args:
            inp (torch.Tensor): output of the shared base (or the raw observation)
            active_option (int): index of the option whose sub-policy is used
            params (OrderedDict): parameters to use

        Returns:
            Action sampled by the sub-policy head
        """
        # The sub-policy networks output options x observations x dist_params; keep the active option only
        subpolicy_params = self.named_forward_pass(inp=inp, params=params,
                                                    num_layers=len(self.hidden_sizes_subpolicy) + 1,
                                                    nonlinearity=self.nonlinearity, prefix="subpolicy")[active_option, :, :]
        if self.action_type != "discrete":
            if self.fixed_std <= 0:
                # NOTE(review): the log-std comes from the module's own parameter, not from
                # `params`, and is added as a single row, so this presumably expects one observation.
                subpolicy_params = torch.cat([subpolicy_params,
                                               self.subpolicy_params["subpolicylogstd"][active_option][None, :]], dim=-1)
            else:
                subpolicy_params = torch.cat([subpolicy_params,
                                               torch.log(torch.ones_like(subpolicy_params) * self.fixed_std)], dim=-1)
        return self.subpolicy_head.sample_action(policy_params=subpolicy_params)

    def create_nns(self):
        """
        Creates all parameters of the policy and files them into the parameter dicts.

        The policy over options goes into `inner_params`; base, termination and
        sub-policy parameters (and the log-std, if any) go into `outer_params`.
        """
        base_layer_sizes = (self.obs_dim,) + self.hidden_sizes_shared

        # If there are no shared layers we do not need base part
        if len(self.hidden_sizes_shared) == 0:
            options_layer_sizes = (self.obs_dim,) + self.hidden_sizes_option + (self.options,)
            termination_layer_sizes = (self.obs_dim,) + self.hidden_sizes_termination + (1,)
            policy_layer_sizes = (self.obs_dim,) + self.hidden_sizes_subpolicy + (self.action_dim,)

        else:
            # The heads take the output of the last shared layer as their input
            options_layer_sizes = (self.hidden_sizes_shared[-1],) + self.hidden_sizes_option + (self.options,)
            termination_layer_sizes = (self.hidden_sizes_shared[-1],) + self.hidden_sizes_termination + (1,)
            policy_layer_sizes = (self.hidden_sizes_shared[-1],) + self.hidden_sizes_subpolicy + \
                                 (self.action_dim,)
            # If there are shared layers create them
            self.create_named_layers(layer_sizes=base_layer_sizes, param_dicts=(self.outer_params, self.base_params),
                                     prefix="base")

        # Create all the rest (policy over options, options x terminations and options x policies)
        self.create_named_layers(layer_sizes=options_layer_sizes, param_dicts=(self.inner_params, self.option_params),
                                 prefix="options")
        self.create_named_layers(layer_sizes=termination_layer_sizes, extra_dims=(self.options,),
                                 param_dicts=(self.outer_params, self.termination_params), prefix="termination")
        self.create_named_layers(layer_sizes=policy_layer_sizes, extra_dims=(self.options,),
                                 param_dicts=(self.outer_params, self.subpolicy_params),
                                 prefix="subpolicy")

        # A learnable log-std (options x action_dim, initialised to 0) is created whenever no
        # fixed std is configured; note that this does not depend on the action type.
        if self.fixed_std <= 0:
            log_std_params = nn.Parameter(torch.zeros(self.options, self.action_dim))
            log_std_name = "subpolicylogstd"
            self.register_parameter(name=log_std_name, param=log_std_params)
            self.outer_params[log_std_name] = log_std_params
            self.subpolicy_params[log_std_name] = log_std_params

    def create_lr_params(self, lr, source_param_dict, param_dicts):
        """
        Creates one learnable scalar learning-rate parameter per source parameter.

        Each parameter is registered under the name ``<source name>lr``, which is
        the key `sgd_update` looks up when it is given a dictionary of learning rates.

        Args:
            lr (float): initial learning rate
            source_param_dict (OrderedDict): parameters that get their own learning rate
            param_dicts (tuple): dictionaries the new parameters are added to
        """
        for key in source_param_dict.keys():
            name = f"{key}lr"
            # Stored through inv_softplus (project helper), presumably so that the softplus
            # applied in sgd_update recovers `lr` initially - confirm
            lr_param = nn.Parameter(inv_softplus(torch.tensor(lr)))
            self.register_parameter(name=name, param=lr_param)
            for param_dict in param_dicts:
                param_dict[name] = lr_param

    def set_terminations(self, termination_prob):
        """
        Overwrites every entry of the first termination layer's weights with
        ``log(termination_prob / (1 - termination_prob))``.

        Args:
            termination_prob (float): prior termination probability; presumably
                meant to lie strictly between 0 and 1 - confirm
        """
        weight_value = torch.log(torch.tensor([termination_prob / (1 - termination_prob)]))
        termination_weights = torch.ones_like(self.terminationlayer1weight.data) * weight_value

        self.terminationlayer1weight.data = termination_weights

    def clip_option_probs(self, option_probs):
        """
        Mixes the option probabilities with the uniform distribution.

        The mixing coefficient is chosen such that an input probability of 1 is
        mapped to `max_option_prob`. With a single option the input is returned unchanged.

        Args:
            option_probs (torch.Tensor): probabilities over options

        Returns:
            Mixed probabilities with the same shape as the input
        """
        if self.options == 1:
            return option_probs
        else:
            mix_coeff = (self.options * self.max_option_prob - 1) / (self.options - 1)
            return option_probs * mix_coeff + (1-mix_coeff) * (1 / self.options)


def net_test():
    """
    Check that `named_forward_pass` reproduces a plain `nn.Sequential` MLP.

    The weights of a three-layer reference network are copied into a
    `NamedParamsNNPolicy` under the layer naming scheme and both networks are
    evaluated on the same random batch.

    Raises:
        AssertionError: if the outputs differ by more than floating point
            tolerance; the message holds the largest absolute difference
    """
    state_dim = 8
    options = 4
    batch_size = 200
    prefix = ""
    states = torch.rand(batch_size, state_dim)
    use_bias = True
    net = nn.Sequential(
            nn.Linear(state_dim, 64, bias=use_bias),
            nn.ReLU(),
            nn.Linear(64, 64, bias=use_bias),
            nn.ReLU(),
            nn.Linear(64, options, bias=use_bias)
    )

    outputs1 = net(states)

    policy = NamedParamsNNPolicy(bias=use_bias)

    # Positions of the linear layers inside the Sequential container
    module_indices = [0, 2, 4]
    for layer_number, module_index in enumerate(module_indices, start=1):
        linear = net[module_index]
        # nn.Linear stores weights as (out, in); named_forward_pass multiplies by (in, out)
        weight = linear.weight.data.clone().transpose(0, 1)
        weight_name = '{0}layer{1}weight'.format(prefix, layer_number)
        policy.register_parameter(name=weight_name, param=nn.Parameter(weight))
        if use_bias:
            bias = linear.bias.data.clone()
            bias_name = '{0}layer{1}bias'.format(prefix, layer_number)
            policy.register_parameter(name=bias_name, param=nn.Parameter(bias))
    outputs2 = policy.named_forward_pass(states, params=OrderedDict(policy.named_parameters()), num_layers=3,
                                        nonlinearity=torch.relu, prefix=prefix)
    # The two paths may use different kernels, so compare with a tolerance instead of bit-exact equality.
    assert torch.allclose(outputs1, outputs2, atol=1e-5), torch.max(torch.abs(outputs1-outputs2))


def multinet_test():
    """
    Check that the batched (3-D weight) path of `named_forward_pass` matches a
    list of independent `nn.Sequential` MLPs, one per option.

    The weights of all reference networks are stacked along a leading options
    axis and copied into a `NamedParamsNNPolicy`; both are evaluated on the
    same random batch.

    Raises:
        AssertionError: if the outputs differ by more than floating point
            tolerance; the message holds the largest absolute difference
    """
    state_dim = 8
    action_dim = 5
    options = 4
    batch_size = 200
    prefix = ""
    states = torch.rand(batch_size, state_dim)
    use_bias = True
    nets = [
        nn.Sequential(
            nn.Linear(state_dim, 64, bias=use_bias),
            nn.ReLU(),
            nn.Linear(64, 64, bias=use_bias),
            nn.ReLU(),
            nn.Linear(64, action_dim, bias=use_bias)
        ) for _ in range(options)
    ]
    # Reference output stacked as options x batch x action_dim
    outputs1 = torch.cat([net(states)[None, :, :] for net in nets], dim=0)

    policy = NamedParamsNNPolicy(bias=use_bias)

    # Positions of the linear layers inside every Sequential container
    module_indices = [0, 2, 4]
    for layer_number, module_index in enumerate(module_indices, start=1):
        linears = [net[module_index] for net in nets]
        # Stack to (options, out, in), then swap to the (options, in, out) layout used by bmm
        weight = torch.cat([linear.weight.data.clone()[None, :, :] for linear in linears], dim=0).transpose(1, 2)
        weight_name = '{0}layer{1}weight'.format(prefix, layer_number)
        policy.register_parameter(name=weight_name, param=nn.Parameter(weight))
        if use_bias:
            bias = torch.cat([linear.bias.data.clone()[None, :] for linear in linears], dim=0)
            bias_name = '{0}layer{1}bias'.format(prefix, layer_number)
            policy.register_parameter(name=bias_name, param=nn.Parameter(bias))
    outputs2 = policy.named_forward_pass(states, params=OrderedDict(policy.named_parameters()), num_layers=3,
                                        nonlinearity=torch.relu, prefix=prefix)
    # Batched and per-network matmuls may round differently, so compare with a tolerance.
    assert torch.allclose(outputs1, outputs2, atol=1e-5), torch.max(torch.abs(outputs1-outputs2))


if __name__ == "__main__":
    # Run the stacked multi-network check first, then the single-network one.
    for check in (multinet_test, net_test):
        check()