import torch

import functools
import numpy as np
import inspect

import torch.nn as nn
import torch.nn.functional as F

from torch import optim
from collections import namedtuple
from numpy import random
from scipy.optimize import brent, minimize_scalar, brentq, root_scalar # mellowmax paper recommends it

# Immutable record describing how to build an optimizer: the optimizer class
# (`constructor`), extra options for it (`kwargs`), and a learning-rate
# schedule (`lr_schedule`).
OptimizerSpec = namedtuple("OptimizerSpec", "constructor kwargs lr_schedule")


def default_optimizer_spec(learning_rate):
    """Return an OptimizerSpec for Adam with a constant learning rate.

    Args:
        learning_rate (float): value wrapped in a ConstantSchedule.

    Returns:
        OptimizerSpec: ``constructor`` is ``optim.Adam``, ``kwargs`` is an empty
        dict and ``lr_schedule`` is ``ConstantSchedule(learning_rate)``.
    """
    # `kwargs` must be a mapping (as in default_optimizer); presumably it is
    # unpacked as `constructor(params, **spec.kwargs)`, which fails for `()`.
    return OptimizerSpec(optim.Adam, {}, ConstantSchedule(learning_rate))

def gpu_assigner(is_gpu, gpu_id):
    """Pick the torch device to run on.

    Args:
        is_gpu: truthy to request CUDA device ``gpu_id``; falsy for the CPU.
        gpu_id: CUDA device index, used only when ``is_gpu`` is truthy.

    Returns:
        torch.device: ``cpu`` or ``cuda:<gpu_id>``.

    Raises:
        AssertionError: if a GPU is requested but CUDA is unavailable.
    """
    if not is_gpu:
        return torch.device("cpu")
    assert torch.cuda.is_available(), 'Cuda is not available'
    return torch.device("cuda:" + str(gpu_id))

def _is_method(func):
    spec = inspect.signature(func)
    return 'self' in spec.parameters

def _convert_arg_to_tensor(arg, device, description):
    """Return ``arg`` as a torch.Tensor on ``device`` where a conversion applies.

    numpy arrays go through ``torch.from_numpy`` (float64 arrays are cast to
    float32 first) and lists through ``torch.tensor``; both are then moved to
    ``device``. Tensors (including subclasses such as ``nn.Parameter``) and
    int/float/bool values are returned unchanged; tensors are NOT moved.

    Args:
        arg: the argument value to convert.
        device: torch device (or device string) for newly created tensors.
        description: where the argument came from, used in the error message
            (e.g. ``"argument in position 1"``).

    Raises:
        ValueError: if ``arg`` is of any other type.
    """
    if isinstance(arg, np.ndarray):
        tensor = torch.from_numpy(arg)
        if arg.dtype == np.double:
            # Downcast float64 arrays to float32.
            tensor = tensor.type(torch.float32)
        return tensor.to(device)
    if isinstance(arg, list):
        return torch.tensor(arg).to(device)
    if isinstance(arg, (torch.Tensor, int, float, bool)):
        return arg
    raise ValueError('Arguments should be Numpy arrays, but {} is not: {}'.format(description, type(arg)))


def convert_args_to_tensor(positional_args_list=None, keyword_args_list=None, device='cpu'):
    """Decorator factory that converts selected call arguments to torch.Tensor.

    numpy arrays and lists are converted (see ``_convert_arg_to_tensor``);
    tensors and Python scalars pass through untouched; anything else raises
    ValueError at call time.

    Args:
        positional_args_list (list of int, optional): indices of positional
            arguments to convert. If None, every positional argument is
            converted, except the first one when the decorated function has a
            parameter named ``self``.
        keyword_args_list (list of str, optional): names of keyword arguments
            to convert. If None, every keyword argument is converted.
        device (str or torch.device): device new tensors are moved to.

    Returns:
        Callable: a decorator that wraps the target function.
    """
    fixed_positions = None if positional_args_list is None else set(positional_args_list)
    fixed_keys = None if keyword_args_list is None else set(keyword_args_list)

    def decorator(func):
        # Only relevant when converting "all positional args": decide once,
        # at decoration time, whether argument 0 is `self` and must be skipped.
        skip_self = fixed_positions is None and _is_method(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if fixed_positions is None:
                positions = range(1 if skip_self else 0, len(args))
            else:
                positions = fixed_positions
            keys = kwargs.keys() if fixed_keys is None else fixed_keys

            new_args = [
                _convert_arg_to_tensor(arg, device, 'argument in position {}'.format(i))
                if i in positions else arg
                for i, arg in enumerate(args)
            ]
            new_kwargs = {
                key: _convert_arg_to_tensor(arg, device, 'keyword argument {!r}'.format(key))
                if key in keys else arg
                for key, arg in kwargs.items()
            }
            return func(*new_args, **new_kwargs)

        return wrapper

    return decorator

@convert_args_to_tensor([0], ['labels'])
def torch_one_hot(labels, one_hot_size):
    """One-hot encode a 1-D batch of labels.

    Args:
        labels: labels, converted to a tensor by the decorator; presumably an
            integer tensor of class indices (required for the indexing below).
        one_hot_size (int): number of columns (classes).

    Returns:
        torch.Tensor: shape ``(labels.shape[0], one_hot_size)``, default float
        dtype, on ``labels.device``, with a 1 at ``[row, labels[row]]``.
    """
    batch = labels.shape[0]
    dev = labels.device
    encoded = torch.zeros(batch, one_hot_size, device=dev)
    rows = torch.arange(batch, device=dev)
    encoded[rows, labels] = 1
    return encoded

def default_optimizer():
    """Return an OptimizerSpec for Adam with no extra kwargs and lr 1e-3.

    NOTE(review): ``lr_schedule`` here is the bare float ``1e-3``, whereas
    default_optimizer_spec wraps its rate in a ConstantSchedule. Confirm which
    form the consumer expects.
    """
    return OptimizerSpec(optim.Adam, {}, 1e-3)

def huber_loss(inp, target, delta=1.):
    # type: (Tensor, Tensor, float) -> Tensor
    """Element-wise Huber loss (no reduction).

    With ``e = |inp - target|``, returns ``0.5 * e**2`` where ``e < delta``
    and ``delta * e - 0.5 * delta**2`` elsewhere.
    See https://en.wikipedia.org/wiki/Huber_loss
    """
    abs_err = torch.abs(inp - target)
    quadratic = 0.5 * abs_err ** 2
    linear = abs_err * delta - (0.5 * delta ** 2)
    return torch.where(abs_err < delta, quadratic, linear)


def constant_exploration_schedule(value):
    """Return a ConstantSchedule that yields ``value`` at every timestep."""
    schedule = ConstantSchedule(value)
    return schedule

def stepwise_exploration_schedule(num_timesteps, outside_value=0.1, portion_decay=0.1, exploration_type='epsilon-greedy'):
    """Build a two-point PiecewiseSchedule for an exploration parameter.

    The schedule linearly interpolates from a start value at ``t = 0`` to the
    end value at ``t = num_timesteps * portion_decay`` and returns the end
    value for every ``t`` outside that interval.

    Args:
        num_timesteps: total number of training timesteps.
        outside_value: end value for ``'epsilon-greedy'`` (ignored for
            ``'resmax'``, which always ends at 1024).
        portion_decay: fraction of ``num_timesteps`` over which to anneal.
        exploration_type: ``'epsilon-greedy'`` (1.0 -> ``outside_value``) or
            ``'resmax'`` (0 -> 1024).

    Returns:
        PiecewiseSchedule

    Raises:
        NotImplementedError: for ``'softmax'``, which has no schedule defined.
        ValueError: for any other unrecognised ``exploration_type``.
    """
    if exploration_type == 'epsilon-greedy':
        starting_value = 1.0
    elif exploration_type == 'resmax':
        starting_value = 0
        outside_value = 1024
    elif exploration_type == 'softmax':
        # Previously fell through with `starting_value` unbound (UnboundLocalError).
        raise NotImplementedError('No stepwise schedule is defined for softmax exploration')
    else:
        raise ValueError('Unknown exploration_type: {!r}'.format(exploration_type))

    return PiecewiseSchedule(
        [
            (0, starting_value),
            (num_timesteps * portion_decay, outside_value),
        ], outside_value=outside_value
    )

class Schedule:
    """Base interface for a quantity that varies with the timestep ``t``."""

    def value(self, t):
        """Return the schedule's value at timestep ``t``; subclasses override."""
        raise NotImplementedError()


class ConstantSchedule(Schedule):
    """Schedule that returns the same value at every timestep."""

    def __init__(self, value):
        """
        Parameters
        ----------
        value: float
            The value returned by ``value(t)`` for every ``t``.
        """
        self._v = value

    def value(self, t):
        """Return the stored constant; ``t`` is ignored."""
        return self._v


def linear_interpolation(l, r, alpha):
    """Interpolate linearly: ``l`` at ``alpha == 0``, ``r`` at ``alpha == 1``."""
    span = r - l
    return l + alpha * span


class PiecewiseSchedule(Schedule):
    def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
        """Piecewise schedule.

        endpoints: [(int, float)]
            List of ``(time, value)`` pairs meaning the schedule outputs
            ``value`` when ``t == time``. Times must be sorted in increasing
            order. For ``time_a <= t < time_b`` between consecutive pairs
            ``(time_a, value_a)`` and ``(time_b, value_b)``, the output is
            ``interpolation(value_a, value_b, alpha)``, where ``alpha`` is the
            fraction of the interval from ``time_a`` to ``time_b`` covered by ``t``.
        interpolation: callable (float, float, float) -> float
            Receives the values left and right of ``t`` and ``alpha``;
            see ``linear_interpolation``.
        outside_value: float
            Returned when ``t`` falls outside every interval. If None, an
            AssertionError is raised in that case.
        """
        times = [point[0] for point in endpoints]
        assert times == sorted(times)
        self._interpolation = interpolation
        self._outside_value = outside_value
        self._endpoints = endpoints

    def value(self, t):
        """See Schedule.value"""
        segments = zip(self._endpoints[:-1], self._endpoints[1:])
        for (left_t, left_v), (right_t, right_v) in segments:
            if left_t <= t < right_t:
                alpha = float(t - left_t) / (right_t - left_t)
                return self._interpolation(left_v, right_v, alpha)

        # t lies outside every segment: fall back to the outside value.
        assert self._outside_value is not None
        return self._outside_value

class LinearSchedule(Schedule):
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Anneal linearly from ``initial_p`` to ``final_p``.

        The value moves linearly over ``schedule_timesteps`` steps and stays
        at ``final_p`` from then on.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps over which to anneal initial_p to final_p.
        final_p: float
            Value reached at (and kept after) ``schedule_timesteps``.
        initial_p: float
            Value at ``t = 0``.
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value"""
        progress = float(t) / self.schedule_timesteps
        if progress > 1.0:
            progress = 1.0  # clamp: hold final_p once annealing is done
        return self.initial_p + progress * (self.final_p - self.initial_p)

import warnings


def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """Compute the spatial output size of a 2-D convolution.

    Uses ``out = (in + 2*pad - dilation*(kernel_size - 1) - 1) // stride + 1``
    independently for each axis.

    Args:
        h_w: input size, an int (square input) or an ``(h, w)`` pair.
        kernel_size, stride, pad, dilation: an int (same for both axes) or an
            ``(h, w)`` pair. Tuples, lists and tuple subclasses such as
            ``torch.Size`` are all accepted as pairs.

    Returns:
        tuple: ``(h, w)`` output size.
    """
    def _pair(x):
        # Broadcast a scalar to both axes; pass sequences through as a tuple.
        return tuple(x) if isinstance(x, (tuple, list)) else (x, x)

    h_w = _pair(h_w)
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    pad = _pair(pad)
    dilation = _pair(dilation)  # previously not expanded, so tuple dilation crashed

    h = (h_w[0] + (2 * pad[0]) - (dilation[0] * (kernel_size[0] - 1)) - 1) // stride[0] + 1
    w = (h_w[1] + (2 * pad[1]) - (dilation[1] * (kernel_size[1] - 1)) - 1) // stride[1] + 1

    return h, w

def get_same_padding_size(kernel_size=1, stride=1, dilation=1):
    """Return a symmetric padding intended to mimic TensorFlow 'SAME' convs.

    Computes ``g = (stride - dilation*kernel_size + dilation - 1) / 2``.
    Returns 0 when ``g > 0``, otherwise ``ceil(|g|)`` as an int.

    NOTE(review): a single symmetric value cannot reproduce TF's asymmetric
    padding for even effective kernel sizes; confirm this is acceptable.
    """
    signed_gap = (stride - dilation*kernel_size + dilation - 1) / 2
    return 0 if signed_gap > 0 else int(np.ceil(np.abs(signed_gap)))

def mello_max(X, w):
    """Mellowmax operator ``mm_w(X) = log(mean(exp(w * X))) / w`` over a vector.

    Computed as ``C + log(mean(exp(w * (X - C)))) / w`` with ``C = max(X)``.
    The two forms are mathematically equal, but the shifted one avoids
    overflow in ``exp``. The exponentials use ``np.longdouble``, which is
    float128 where the platform provides it.

    Args:
        X: array-like of values; singleton dimensions are squeezed, so a
            ``(1, n)`` or ``(n, 1)`` array is treated as a length-n vector
            and a single value is a length-1 vector.
        w: the mellowmax temperature ``omega``. ``w == 0`` returns
            ``mean(X)``, the limit of the operator as ``w -> 0``.

    Returns:
        Scalar mellowmax value.

    Raises:
        NotImplementedError: if X has more than one non-singleton dimension.
    """
    # atleast_1d: a single q-value squeezes to 0-d and must still count as a vector.
    X = np.atleast_1d(np.squeeze(X))
    if X.ndim != 1:
        raise NotImplementedError('Mellowmax with more than one sample of q-values not implemented yet')
    if w == 0:
        # log(mean(exp(0))) / 0 would be 0/0 = NaN.
        return X.mean()
    C = np.max(X)
    vals = w * (X - C)
    # np.longdouble is portable; np.float128 does not exist on e.g. Windows / macOS arm64.
    exp = np.exp(vals.astype(np.longdouble))
    return C + np.log(exp.mean()) / w  # max-shift is more numerically stable according to author

def beta_equation(beta, X, w):
    """Residual of the mellowmax beta equation, whose root in ``beta`` is sought.

    Returns ``sum_i (X_i - mm) * exp(beta * (X_i - mm))`` with
    ``mm = mello_max(X, w)``, computed in ``np.longdouble`` precision.

    Args:
        beta: candidate inverse temperature.
        X: array-like of values; singleton dimensions are squeezed and a single
            value is promoted to a length-1 vector.
        w: mellowmax temperature ``omega`` passed to ``mello_max``.

    Raises:
        NotImplementedError: if X has more than one non-singleton dimension.
    """
    X = np.atleast_1d(np.squeeze(X))
    if X.ndim != 1:
        raise NotImplementedError('Solving beta equation with more than one sample of q-values not implemented yet')
    mm = mello_max(X, w)
    diffs = X - mm
    # np.longdouble is portable; np.float128 does not exist on e.g. Windows / macOS arm64.
    exps = np.exp(beta * diffs.astype(np.longdouble))
    return (diffs * exps).sum()
class RootFinder:
    """Solves the mellowmax beta equation with Brent's method.

    The most recent successful root is kept in ``self.temp`` and reused as a
    best-effort fallback whenever root finding fails.
    """

    def __init__(self):
        # Last root found; 1.0 until the first successful solve.
        self.temp = 1.

    def mellow_max_root_finder(self, X, omega, bracket=(-60000, 60000.)):
        """Find ``beta`` with ``beta_equation(beta, X, omega) == 0``.

        Args:
            X: values passed through to ``beta_equation``.
            omega: mellowmax temperature.
            bracket: ``(low, high)`` interval searched by ``brentq``.

        Returns:
            ``max(beta, 0.)``, where ``beta`` is the new root or, if root
            finding failed, the previous one.
        """
        try:
            self.temp = brentq(beta_equation, bracket[0], bracket[1], args=(X, omega))
        except (ValueError, RuntimeError):
            # brentq raises ValueError when the bracket ends have the same sign
            # and RuntimeError when it does not converge. Keep the previous
            # root; other errors (programming bugs, Ctrl-C) propagate instead
            # of being silently swallowed.
            print('----- Failed -----')
            print(X, self.temp)
        return max(self.temp, 0.)