import torch
from torch import nn
from torch.distributions import Normal
import torch.nn.functional as F
import numpy as np

# Compute device shared by this module: CUDA when a GPU is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def initialize_weights_xavier(m, gain=1.0):
    """Xavier-uniform initialise an ``nn.Linear`` module in place.

    Intended to be passed to ``nn.Module.apply``; any module that is not an
    ``nn.Linear`` is left untouched.

    Args:
        m: Module to initialise.
        gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=gain)
    # Biases, when present, start at zero.
    if m.bias is not None:
        nn.init.zeros_(m.bias)


def create_linear_network(input_dim, output_dim, hidden_units=None,
                          hidden_activation=nn.ReLU(), output_activation=None,
                          initializer=initialize_weights_xavier):
    """Build a fully-connected MLP as an ``nn.Sequential``.

    Args:
        input_dim: Number of input features (int).
        output_dim: Number of output features (int).
        hidden_units: Sizes of the hidden layers, in order (list or tuple).
            ``None`` (the default) means no hidden layers.
        hidden_activation: Module inserted after every hidden ``nn.Linear``.
            The same instance is reused for every hidden layer (and the
            default instance is shared by every call), which is only safe for
            parameter-free activations such as ``nn.ReLU``.
        output_activation: Optional module appended after the output layer.
        initializer: Callable applied to every submodule via
            ``nn.Module.apply``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    if hidden_units is None:
        hidden_units = []
    assert isinstance(input_dim, int) and isinstance(output_dim, int)
    assert isinstance(hidden_units, (list, tuple))

    layers = []
    in_units = input_dim
    for out_units in hidden_units:
        layers.append(nn.Linear(in_units, out_units))
        layers.append(hidden_activation)
        in_units = out_units

    layers.append(nn.Linear(in_units, output_dim))
    if output_activation is not None:
        layers.append(output_activation)

    # Apply the caller-supplied initializer (previously this argument was
    # ignored and Xavier init was always used).
    return nn.Sequential(*layers).apply(initializer)


class BaseNetwork(nn.Module):
    """``nn.Module`` base class adding ``state_dict`` save/load helpers."""

    def save(self, path):
        """Write this module's ``state_dict`` to *path* with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path, map_location="cpu"):
        """Load a ``state_dict`` previously written by :meth:`save`.

        Tensors are first deserialised onto *map_location* (CPU by default).
        ``load_state_dict`` then copies them into this module's existing
        parameters and buffers. As a result, a checkpoint saved on a GPU can
        be loaded on a CPU-only machine, and the module stays on its current
        device.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(path, map_location=map_location))


class StateActionFunction(BaseNetwork):

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()

        self.min_bin_width = 1e-3
        self.min_bin_height = 1e-3
        self.min_derivative = 1e-3
        self.num_support = 32
        self.K = 32

        self.net = create_linear_network(
            input_dim=state_dim+action_dim,
            output_dim=256,
            hidden_units=hidden_units)
        self.fc_q = nn.Linear(256, 3 * self.K -1).apply(initialize_weights_xavier)
        # scale
        self.alpha = nn.Linear(256, 1).apply(initialize_weights_xavier)
        self.beta = nn.Linear(256, 1).apply(initialize_weights_xavier)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.net(x))
        # output
        spline_param = self.fc_q(x)
        # scale
        scale_a = self.alpha(x)
        scale_a = torch.exp(scale_a)
        scale_b = self.beta(x)

        # split the last dimention to W, H, D
        W, H, D = torch.split(spline_param, self.K, dim=1)
        W, H = torch.softmax(W, dim=1), torch.softmax(H, dim=1)
        W = self.min_bin_width + (1 - self.min_bin_width * self.K) * W
        H = self.min_bin_height + (1 - self.min_bin_height * self.K) * H
        D = self.min_derivative + F.softplus(D)
        D = F.pad(D, pad=(1,1))
        constant = np.log(np.exp(1- 1e-3) - 1)
        D[..., 0] = constant
        D[..., -1] = constant

        # start and end x of each bin
        cumwidths = torch.cumsum(W, dim=-1)
        cumwidths = F.pad(cumwidths, pad=(1,0), mode='constant', value=0.0)
        cumwidths[..., -1] = 1.0
        widths = cumwidths[..., 1:] - cumwidths[..., :-1]  # (batch_sz, K)

        # start and end y of each bin
        cumheights = torch.cumsum(H, dim=-1)
        cumheights = F.pad(cumheights, pad=(1,0), mode='constant', value=0.0)
        cumheights = scale_a * cumheights + scale_b
        heights = cumheights[..., 1:] - cumheights[..., :-1]
        
        # get bin index for each tau
        tau = torch.arange(0.5 * (1 / self.num_support), 1, 1 / self.num_support).to(device)
        tau = tau.expand((batch_size, self.num_support))

        cumwidths_expand = cumwidths.unsqueeze(dim=1)
        cumwidths_expand = cumwidths_expand.expand(-1, self.num_support, -1) # (batch_sz, num_support, K+1)
        
        bin_idx = self.searchsorted_(cumwidths_expand, tau)

        # collect number
        input_cumwidths = cumwidths.gather(-1, bin_idx)     # x_i 
        input_bin_widths = widths.gather(-1, bin_idx)       # x_i+1 - x_i

        input_cumheights = cumheights.gather(-1, bin_idx)   # y_i
        input_heights = heights.gather(-1, bin_idx)         # y_i+1 - y_i

        delta = heights / widths
        
        input_delta = delta.gather(-1, bin_idx)             # (y_i+1 - y_i) / (x_i+1 - x_i)

        input_derivatives = D.gather(-1, bin_idx)           # d_i
        input_derivatives_plus_one = D[..., 1:].gather(-1, bin_idx) # d_i+1
        
        # calculate quadratic spline for each tau
        theta = (tau - input_cumwidths) / input_bin_widths  # theta = (x - x_i) / (x_i+1 - x_i)
        
        theta_one_minus_theta = theta * (1 - theta)         # theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        outputs = input_cumheights + numerator / denominator
        return outputs

    def searchsorted_(self, bin_locations, inputs):
        return torch.sum(inputs[..., None] >= bin_locations, dim= -1) - 1


class TwinnedStateActionFunction(BaseNetwork):

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()

        self.net1 = StateActionFunction(state_dim, action_dim, hidden_units)
        self.net2 = StateActionFunction(state_dim, action_dim, hidden_units)

    def forward(self, states, actions):
        assert states.dim() == 2 and actions.dim() == 2

        x = torch.cat([states, actions], dim=1)
        value1 = self.net1(x)
        value2 = self.net2(x)
        return value1, value2


class GaussianPolicy(BaseNetwork):
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()

        self.net = create_linear_network(
            input_dim=state_dim,
            output_dim=2*action_dim,
            hidden_units=hidden_units)

    def forward(self, states):
        assert states.dim() == 2

        # Calculate means and stds of actions.
        means, log_stds = torch.chunk(self.net(states), 2, dim=-1)
        log_stds = torch.clamp(
            log_stds, min=self.LOG_STD_MIN, max=self.LOG_STD_MAX)
        stds = log_stds.exp_()

        # Gaussian distributions.
        normals = Normal(means, stds)

        # Sample actions.
        xs = normals.rsample()
        actions = torch.tanh(xs)

        # Calculate entropies.
        log_probs = normals.log_prob(xs) - torch.log(1 - actions.pow(2) + 1e-6)
        entropies = -log_probs.sum(dim=1, keepdim=True)

        return actions, entropies, torch.tanh(means)
