import os
import math
import time
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F


def sample_gumbel(shape, eps=1e-10, device=None):
    """Draw standard Gumbel noise, computed as -log(-log(U)) for uniform U.

    Args:
        shape: shape of the noise tensor to generate.
        eps: small constant added inside both logarithms so that neither
            ever evaluates log(0).
        device: device the returned tensor is placed on. When None, CUDA is
            used if it is available and the CPU otherwise, so the function
            no longer crashes on machines without a GPU.

    Returns:
        A tensor of the requested shape on `device`.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Sample with the default (CPU) generator and then move the result, so
    # seeded runs draw the same values as before on GPU machines.
    U = torch.rand(shape).to(device)
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Return a relaxed categorical sample: softmax((logits + Gumbel noise) / temperature).

    Args:
        logits: unnormalized log-probabilities; the softmax is taken over
            the last dimension.
        temperature: divisor applied to the perturbed logits before the
            softmax.

    Returns:
        A tensor with the same shape as `logits`.
    """
    # Move the noise onto the logits' device so the addition cannot fail with
    # a device mismatch (this is a no-op when both already share a device).
    noise = sample_gumbel(logits.size()).to(logits.device)
    y = logits + noise
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature=0.5, batch_size=1, hard=False):
    """
    Straight-through Gumbel-softmax.
    input: logits of shape [n_hidden, n_class]
    return: tensor of shape [batch_size, n_hidden, n_class]; a relaxed sample
            when hard is False, otherwise a one-hot vector over the last
            dimension whose gradient is that of the relaxed sample
    """
    # the unpacking doubles as a check that logits is 2-D
    _n_hidden, _n_class = logits.size()
    batched_logits = logits.unsqueeze(0).repeat(batch_size, 1, 1)

    soft = gumbel_softmax_sample(batched_logits, temperature)

    if hard:
        # one-hot encode the most likely class of every row
        winners = soft.max(dim=-1)[1]
        one_hot = torch.zeros_like(soft).scatter_(-1, winners.unsqueeze(-1), 1)
        # forward value is the one-hot vector, gradient flows through soft
        return (one_hot - soft).detach() + soft

    return soft

class MLP(nn.Module):
    """MLP dynamics model whose hidden ReLUs can be individually fixed.

    The network maps a flattened history of observations and actions to a
    residual that is added to the most recent observation. Every hidden unit
    carries a mask value that selects its activation function:

        0 -> the unit outputs zero
        1 -> the unit is a regular ReLU (or, when `config['train']['gumbel']`
             is set, a Gumbel-softmax mixture of zero / ReLU / identity)
        2 -> the unit is the identity

    `mask_prob` holds one learnable [n_units, 3] logit table per hidden layer
    with the same column order (0: zero, 1: ReLU, 2: identity).
    """

    def __init__(self, config):
        """Build the network, masks and activation bounds from `config`.

        Keys read from `config`:
            config['data']['state_dim'], config['data']['action_dim']
            config['train']['gumbel'], config['train']['n_history'],
            config['train']['architecture'] (iterable of hidden widths)
            config['planning']['big_m']
            config['experiment_name']
        """
        super(MLP, self).__init__()

        self.config = config

        self.state_dim = config['data']['state_dim']
        self.action_dim = config['data']['action_dim']

        self.gumbel = config['train']['gumbel']  # if we use gumbel softmax
        # how many steps in history we are using as input
        self.n_history = config['train']['n_history']

        architecture = config['train']['architecture']
        flat_input_dim = (self.state_dim + self.action_dim) * self.n_history

        # [Linear, ReLU] per hidden layer followed by the output Linear.
        # forward() only calls the Linear entries (even indices and the last
        # one) and applies the activations itself.
        self.layers = []
        in_features = flat_input_dim
        for width in architecture:
            self.layers.append(nn.Linear(in_features, width))
            self.layers.append(nn.ReLU())
            in_features = width
        self.layers.append(nn.Linear(in_features, self.state_dim))
        self.model = nn.Sequential(*self.layers)

        # three options per unit, in index order: 0 Zero, 1 ReLU, 2 ID;
        # initialised to a uniform distribution
        N = 3
        self.mask_prob = nn.ParameterList(
            [nn.Parameter(torch.log(torch.ones(width, N) / N))
             for width in architecture])

        # create the initial mask, in which no relus are masked out
        device = self._default_device()
        self.mask = [
            torch.ones(1, width, dtype=torch.float32, device=device)
            for width in architecture
        ]

        # create the initial bounds: lb starts at +inf and ub at -inf so that
        # the first update_bounds() call overwrites them
        self.lb = [np.full(width, np.inf) for width in architecture]
        self.ub = [np.full(width, -np.inf) for width in architecture]
        if not config['planning']['big_m']:
            # without big-M, forward() records the input of every hidden
            # layer instead of its pre-activation, so the bounds are shifted
            # by one layer
            self.lb.insert(0, np.full(flat_input_dim, np.inf))
            del self.lb[-1]
            self.ub.insert(0, np.full(flat_input_dim, -np.inf))
            del self.ub[-1]

        # set up model save directory
        cwd = os.getcwd()
        train_dir = os.path.join(cwd, 'experiments', config['experiment_name'])
        self.model_dir = os.path.join(train_dir, 'models')

    @staticmethod
    def _default_device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self,
                input,  # dict: w/ keys ['observation', 'action']
                ):  # torch.FloatTensor B x state_dim
        """
        input['observation'].shape is [B, n_history, obs_dim]
        input['action'].shape is [B, n_history, action_dim]

        mask: one [1, #units] tensor per hidden layer
        1 means not affected, 0 means negative/zero, 2 means positive/ID

        Returns (output, activations):
            output: [B, state_dim], the predicted residual plus the last
                observation in the history
            activations: list of numpy arrays, one per hidden layer; the
                pre-activation of the layer when config['planning']['big_m']
                is set, otherwise the input of the layer
        """
        # [B, n_history, obs_dim]
        state = input['observation']
        # [B, n_history, action_dim]
        action = input['action']

        # the unpacking doubles as a check that state is 3-D
        B, _, _ = state.shape

        # flatten the observation and action inputs, then concatenate them
        # thus of shape (B, n_history * obs_dim + n_history * action_dim)
        x = torch.cat([state.view(B, -1), action.view(B, -1)], 1).float()

        big_m = self.config['planning']['big_m']

        mask_ID, mask_ReLU = [], []
        for m in self.mask:
            # keep the masks on the same device as the data
            m = m.to(x.device)
            is_id = (m == 2).float()
            is_zero = (m == 0).float()
            is_relu = (m == 1).float()
            assert torch.all(torch.eq(is_relu, 1. - (is_id + is_zero))), \
                "mask entries must be 0, 1 or 2"
            assert torch.sum(is_id) + torch.sum(is_zero) + torch.sum(is_relu) \
                == m.size(1), "mask entries must be 0, 1 or 2"
            mask_ID.append(is_id)
            mask_ReLU.append(is_relu)

        activations = []

        # go through all the hidden layers (the Linear modules sit at the
        # even indices of self.model)
        for k, i in enumerate(range(0, len(self.layers) - 1, 2)):
            if not big_m:
                activations.append(x)
            # pass it through the affine function
            a = self.model[i](x)
            if big_m:
                activations.append(a)

            # units fixed to ID pass the pre-activation through unchanged;
            # units fixed to ZERO contribute nothing, so they need no term
            relu_ID = a * mask_ID[k]

            if self.gumbel:
                # the activation is a weighted sum of the three functional
                # choices, with weights sampled from mask_prob
                a_ID = a
                a_ZERO = a * 0.
                a_ReLU = F.relu(a)

                mask_type = gumbel_softmax(
                    self.mask_prob[k],
                    batch_size=B,
                    temperature=self.config['train']['gumbel_settings']['temperature'],
                    hard=self.config['train']['gumbel_settings']['hard'])

                a = mask_type[:, :, 0] * a_ZERO + \
                    mask_type[:, :, 1] * a_ReLU + \
                    mask_type[:, :, 2] * a_ID
            else:
                a = F.relu(a)

            # the gumbel / relu output is only used on the undetermined units
            x = relu_ID + a * mask_ReLU[k]

        # apply the final affine layer
        output = self.model[len(self.layers) - 1](x)

        activations = [d.detach().cpu().numpy() for d in activations]

        # output: B x state_dim
        # always predict the residual
        output = output + state[:, -1]

        return output, activations

    def update_mask_based_on_mask_prob(self, n_remain):
        """Fix all but `n_remain` undetermined units to zero or identity.

        Undetermined units (mask value 1) are ranked by
        max(P(zero), P(identity)) taken from softmax(mask_prob). The
        `n_remain` units with the lowest score stay ReLU; every other
        undetermined unit is set to 0 when P(zero) is the larger of the two
        and to 2 when P(identity) is. Units that were already fixed keep
        their value. `n_remain` is clamped to the range
        [0, number of undetermined units].
        """
        candidates = []  # (score, idx_relu, idx_node, probabilities)
        new_mask = []

        n_neuron = 0
        n_determined = 0

        for idx_relu, layer_mask in enumerate(self.mask):
            # convert once per layer instead of reading tensors per neuron
            current = layer_mask.detach().cpu().numpy()
            prob = F.softmax(self.mask_prob[idx_relu], -1).detach().cpu().numpy()

            n_neuron_this_layer = current.shape[1]
            updated = np.ones((1, n_neuron_this_layer))

            for idx_node in range(n_neuron_this_layer):
                n_neuron += 1

                if current[0, idx_node] != 1:  # if the Relu is already determined
                    # making the mask persistent across pruning iterations
                    updated[0, idx_node] = current[0, idx_node]
                    n_determined += 1
                else:
                    p = prob[idx_node]
                    candidates.append((max(p[0], p[2]), idx_relu, idx_node, p))

            new_mask.append(updated)

        candidates.sort(key=lambda x: x[0])  # sort based on probability

        # the units with the lowest probability of not being a Relu stay Relu
        n_keep = min(max(n_remain, 0), len(candidates))
        v_clip = candidates[n_keep - 1][0] if n_keep > 0 else 0.

        # change all other undetermined Relus to ID or zero
        for v, idx_relu, idx_node, p in candidates[n_keep:]:
            if v == p[0]:
                new_mask[idx_relu][0, idx_node] = 0
            elif v == p[2]:
                new_mask[idx_relu][0, idx_node] = 2

        print("")
        print("Updated mask...")
        print('n_neuron', n_neuron)
        print('n_determined', n_determined)
        print('n_remain', n_remain)
        print('remain_ratio', n_remain / float(n_neuron) if n_neuron else 0.)
        print('clip value', v_clip)
        print("")

        # keep every mask on the device it lived on before the update
        self.mask = [
            torch.tensor(new, dtype=torch.float32, device=old.device)
            for new, old in zip(new_mask, self.mask)
        ]
        return

    def rollout_model(self,
                      state_init,  # [B, n_history, obs_dim]
                      action_seq,  # [B, n_history + n_rollout - 1, action_dim]
                      ):
        """
        Rolls out the dynamics model for the given number of steps

        Returns a dict with
            'state_pred': [B, n_rollout, obs_dim] tensor of predictions
            'activation': list with the activations returned by forward()
                for every rollout step
        """
        assert len(state_init.shape) == 3
        assert len(action_seq.shape) == 3

        n_history = state_init.shape[1]
        n_tmp = action_seq.shape[1]

        # if state_init and action_seq have same size in dim=1
        # then we are just doing 1 step prediction
        n_rollout = n_tmp - n_history + 1
        assert n_rollout > 0, "n_rollout = %d must be greater than 0" % (n_rollout)

        state_cur = state_init  # [B, n_history, obs_dim]
        state_pred_list = []
        activation_list = []

        for i in range(n_rollout):
            # [B, n_history, action_dim]
            actions_cur = action_seq[:, i:i + n_history]

            model_input = {'observation': state_cur, 'action': actions_cur}

            # [B, obs_dim]
            obs_pred, activations = self.forward(model_input)

            activation_list.append(activations)
            state_pred_list.append(obs_pred)

            # slide the history window: drop the oldest state, append the
            # prediction
            state_cur = torch.cat([state_cur[:, 1:], obs_pred.unsqueeze(1)], 1)

        # [B, n_rollout, obs_dim]
        state_pred_tensor = torch.stack(state_pred_list, dim=1)

        return {'state_pred': state_pred_tensor,
                'activation': activation_list}

    def save_model(self, name):
        """Save the state dict and the pickled module under self.model_dir.

        Writes '<model_dir>/<name>_state_dict.pth' and
        '<model_dir>/<name>_model.pth', creating the target directory first
        if it does not exist yet.
        """
        save_base_path = '%s/%s' % (self.model_dir, name)
        target_dir = os.path.dirname(save_base_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        # save both the model in binary form, and also the state dict
        torch.save(self.state_dict(), save_base_path + "_state_dict.pth")
        torch.save(self, save_base_path + "_model.pth")

    def update_bounds(self, activation):
        """Widen self.lb / self.ub so they cover the given activations.

        activation: [data 1: [rollout step 1: [layer 1 activations:
        [batch_size, #neurons at layer 1], layer 2 activations, ...],
        rollout step 2...], data 2: ...]
        """
        batch_size = self.config['train']['batch_size']
        for data in activation:
            for rollout in data:
                for i, layer in enumerate(rollout):
                    assert layer.shape[0] == batch_size
                    assert len(layer.shape) == 2
                    # element-wise running minimum / maximum over the batch
                    self.lb[i] = np.minimum(self.lb[i], layer.min(axis=0))
                    self.ub[i] = np.maximum(self.ub[i], layer.max(axis=0))
