import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Compute device: CUDA when it is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half-width of the uniform range used to initialise output-head weights.
WEIGHTS_FINAL_INIT = 3e-3
# Half-width of the uniform range used to initialise output-head biases.
BIAS_FINAL_INIT = 3e-4


def fan_in_uniform_init(tensor, fan_in=None):
    """Fill ``tensor`` in place from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.

    Utility function for initializing actor and critic layers.

    Args:
        tensor: Tensor (or parameter) to overwrite in place.
        fan_in: Value that sizes the sampling range. When omitted, the size
            of the tensor's last dimension is used. For a 2-D weight of an
            ``nn.Linear`` that is ``in_features``; for a 1-D bias it is the
            bias length itself, so pass the layer's ``in_features``
            explicitly if the bias should follow the layer fan-in.

    Returns:
        The same ``tensor``, so the call can be chained like the functions
        in ``torch.nn.init``.

    Raises:
        ValueError: If ``fan_in`` is not strictly positive (the bound would
            otherwise be infinite or NaN).
    """
    if fan_in is None:
        fan_in = tensor.size(-1)
    if fan_in <= 0:
        raise ValueError(
            "fan_in must be positive, got {}".format(fan_in)
        )

    bound = 1. / np.sqrt(fan_in)
    return nn.init.uniform_(tensor, -bound, bound)


class Actor(nn.Module):
    """Deterministic policy network mapping observations to bounded actions.

    Two fully connected hidden layers, each followed by layer normalisation
    and a ReLU, feed a linear output head that is squashed with ``tanh`` so
    every action component lies in ``[-1, 1]``.

    Args:
        hidden_size: Sequence whose first two entries are the widths of the
            first and second hidden layers.
        num_inputs: Number of input features per observation.
        action_space: Object exposing ``shape``; ``shape[0]`` is used as the
            number of action components.
    """

    def __init__(self, hidden_size, num_inputs, action_space):
        super().__init__()
        self.action_space = action_space
        num_outputs = action_space.shape[0]

        # Layer 1
        self.linear1 = nn.Linear(num_inputs, hidden_size[0])
        self.ln1 = nn.LayerNorm(hidden_size[0])

        # Layer 2
        self.linear2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.ln2 = nn.LayerNorm(hidden_size[1])

        # Output Layer
        self.mu = nn.Linear(hidden_size[1], num_outputs)

        # Weight Init
        # Hidden layers: both weight and bias are drawn from a range sized by
        # the layer's fan-in. The fan-in is passed explicitly for the bias,
        # because the helper's default (last dimension of the tensor) would
        # be the bias length, i.e. the layer's *output* width.
        for layer in (self.linear1, self.linear2):
            fan_in_uniform_init(layer.weight)
            fan_in_uniform_init(layer.bias, fan_in=layer.in_features)

        # Output head: small fixed ranges so initial actions stay near zero.
        nn.init.uniform_(self.mu.weight, -WEIGHTS_FINAL_INIT, WEIGHTS_FINAL_INIT)
        nn.init.uniform_(self.mu.bias, -BIAS_FINAL_INIT, BIAS_FINAL_INIT)

    def forward(self, inputs):
        """Return the ``tanh``-squashed action for each observation.

        Args:
            inputs: Observation tensor whose last dimension is ``num_inputs``.

        Returns:
            Tensor of actions with values in ``[-1, 1]`` and last dimension
            ``action_space.shape[0]``.
        """
        x = inputs

        # Layer 1
        x = self.linear1(x)
        x = self.ln1(x)
        x = F.relu(x)

        # Layer 2
        x = self.linear2(x)
        x = self.ln2(x)
        x = F.relu(x)

        # Output
        mu = torch.tanh(self.mu(x))
        return mu


class Critic(nn.Module):
    """Critic producing ``num_support`` spline values per state-action pair.

    The network predicts the knots of a monotone rational-quadratic spline
    on ``[0, 1]`` (bin widths, bin heights and knot derivatives) together
    with an affine scaling of the heights, and evaluates that spline at the
    bin-centre fractions ``tau_i = (i + 0.5) / num_support``.

    NOTE(review): the outputs look like quantile estimates of a return
    distribution -- confirm against the training code.

    Args:
        hidden_size: Sequence whose first two entries are the widths of the
            first and second hidden layers.
        num_inputs: Number of input features per observation.
        action_space: Object exposing ``shape``; ``shape[0]`` is the number
            of action components concatenated in before the second layer.
        num_support: Number of spline bins and of evaluation points. Must be
            at least 2.

    Raises:
        ValueError: If ``num_support`` is smaller than 2.
    """

    def __init__(self, hidden_size, num_inputs, action_space, num_support):
        super().__init__()
        # The output head has 3K - 1 units that ``forward`` splits into
        # chunks of K; that only yields the three pieces (W, H, D) it
        # unpacks when K >= 2, so fail early with a clear message.
        if num_support < 2:
            raise ValueError(
                "num_support must be at least 2, got {}".format(num_support)
            )

        self.action_space = action_space
        num_outputs = action_space.shape[0]

        self.num_support = num_support
        self.K = num_support  # number of spline bins

        # Lower bounds that keep every bin and knot derivative strictly
        # positive.
        self.min_bin_width = 1e-3
        self.min_bin_height = 1e-3
        self.min_derivative = 1e-3

        # Layer 1
        self.linear1 = nn.Linear(num_inputs, hidden_size[0])
        self.ln1 = nn.LayerNorm(hidden_size[0])

        # Layer 2
        # In the second layer the actions will be inserted also
        self.linear2 = nn.Linear(hidden_size[0] + num_outputs, hidden_size[1])
        self.ln2 = nn.LayerNorm(hidden_size[1])

        # Output layer (knots): K widths, K heights, K - 1 inner derivatives
        self.V = nn.Linear(hidden_size[1], (3 * self.K - 1))

        # Scale: heads for the affine transform applied to the heights
        self.alpha = nn.Linear(hidden_size[1], 1)
        self.beta = nn.Linear(hidden_size[1], 1)

        # Weight Init
        # The fan-in is passed explicitly for the bias, because the helper's
        # default (last dimension of the tensor) would be the bias length,
        # i.e. the layer's *output* width.
        for layer in (self.linear1, self.linear2, self.V):
            fan_in_uniform_init(layer.weight)
            fan_in_uniform_init(layer.bias, fan_in=layer.in_features)

        for head in (self.alpha, self.beta):
            nn.init.uniform_(head.weight, -WEIGHTS_FINAL_INIT, WEIGHTS_FINAL_INIT)
            nn.init.uniform_(head.bias, -BIAS_FINAL_INIT, BIAS_FINAL_INIT)

    def forward(self, inputs, actions):
        """Evaluate the predicted spline at the ``num_support`` bin centres.

        Args:
            inputs: Observation batch; assumed to be
                ``(batch, num_inputs)`` since it is concatenated with
                ``actions`` along dimension 1 after the first layer.
            actions: Action batch, ``(batch, action_space.shape[0])``.

        Returns:
            Tensor of shape ``(batch, num_support)`` holding the spline value
            at each ``tau_i``; non-decreasing along the last dimension.
        """
        batch_size = inputs.size(0)
        x = inputs

        # Layer 1
        x = self.linear1(x)
        x = self.ln1(x)
        x = F.relu(x)

        # Layer 2
        x = torch.cat((x, actions), 1)  # Insert the actions
        x = self.linear2(x)
        x = self.ln2(x)
        x = F.relu(x)

        # Output
        spline_param = self.V(x)
        scale_a = torch.exp(self.alpha(x))  # exp keeps the scale positive
        scale_b = self.beta(x)

        # split the last dimension into W (K), H (K) and D (K - 1)
        W, H, D = torch.split(spline_param, self.K, dim=1)
        W, H = torch.softmax(W, dim=1), torch.softmax(H, dim=1)
        # Blend with the minimum so each bin is >= min and they still sum to 1.
        W = self.min_bin_width + (1 - self.min_bin_width * self.K) * W
        H = self.min_bin_height + (1 - self.min_bin_height * self.K) * H
        D = self.min_derivative + F.softplus(D)
        # Add the two boundary knots with a fixed derivative value.
        # NOTE(review): the constant is used directly as the boundary
        # derivative (about 0.54), not passed through the softplus above;
        # kept as-is to preserve outputs -- confirm this is intended.
        boundary = float(np.log(np.exp(1 - self.min_derivative) - 1))
        D = F.pad(D, pad=(1, 1), mode='constant', value=boundary)

        # start and end x of each bin
        cumwidths = torch.cumsum(W, dim=-1)
        cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
        cumwidths[..., -1] = 1.0  # pin the last knot exactly to 1
        widths = cumwidths[..., 1:] - cumwidths[..., :-1]  # (batch_sz, K)

        # start and end y of each bin
        cumheights = torch.cumsum(H, dim=-1)
        cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
        cumheights = scale_a * cumheights + scale_b
        heights = cumheights[..., 1:] - cumheights[..., :-1]

        # get bin index for each tau
        # tau is placed on the device the network outputs live on instead of
        # a module-level default, so the critic works wherever it was moved.
        tau = torch.arange(
            0.5 * (1 / self.num_support), 1, 1 / self.num_support
        ).to(cumwidths.device)
        tau = tau.expand((batch_size, self.num_support))

        cumwidths_expand = cumwidths.unsqueeze(dim=1)
        cumwidths_expand = cumwidths_expand.expand(-1, self.num_support, -1)  # (batch_sz, num_support, K+1)

        bin_idx = self.searchsorted_(cumwidths_expand, tau)

        # collect the knot values of the bin that contains each tau
        input_cumwidths = cumwidths.gather(-1, bin_idx)
        input_bin_widths = widths.gather(-1, bin_idx)

        input_cumheights = cumheights.gather(-1, bin_idx)
        input_heights = heights.gather(-1, bin_idx)

        delta = heights / widths  # average slope of each bin

        input_delta = delta.gather(-1, bin_idx)

        input_derivatives = D.gather(-1, bin_idx)
        input_derivatives_plus_one = D[..., 1:].gather(-1, bin_idx)

        # calculate the rational-quadratic spline for each tau
        theta = (tau - input_cumwidths) / input_bin_widths  # position inside the bin

        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        outputs = input_cumheights + numerator / denominator

        return outputs

    def searchsorted_(self, bin_locations, inputs):
        """Return, for each input, the index of the bin that contains it.

        Counts how many entries of ``bin_locations`` along the last
        dimension are less than or equal to the input and subtracts one.

        Args:
            bin_locations: Bin edges along the last dimension; broadcast
                against ``inputs[..., None]``.
            inputs: Values to locate.

        Returns:
            Integer tensor of bin indices with the shape of ``inputs``.
        """
        return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1