import torch

import functools
import numpy as np
import inspect

import torch.nn as nn
import torch.nn.functional as F

from torch import optim
from collections import namedtuple
from numpy import random
from scipy.optimize import brent, minimize_scalar, brentq, root_scalar # mellowmax paper recommends it

# Lightweight record describing how to build an optimizer: the optimizer
# constructor, extra keyword arguments for it, and a learning-rate schedule.
OptimizerSpec = namedtuple(
    "OptimizerSpec",
    ("constructor", "kwargs", "lr_schedule"),
)


def default_optimizer_spec(learning_rate):
    """Build an Adam ``OptimizerSpec`` with a constant learning rate.

    Args:
        learning_rate: value wrapped in a ``ConstantSchedule`` and stored as
            the spec's ``lr_schedule``.

    Returns:
        OptimizerSpec: ``optim.Adam`` as the constructor, an empty ``kwargs``
        mapping, and the constant learning-rate schedule.
    """
    # ``kwargs`` is an empty dict (not an empty tuple) so that it can be
    # expanded with ``**`` when the optimizer is constructed.
    return OptimizerSpec(optim.Adam, {}, ConstantSchedule(learning_rate))

def gpu_assigner(is_gpu, gpu_id):
    """Select the torch device to run on.

    Args:
        is_gpu: truthy to request a CUDA device, falsy for the CPU.
        gpu_id: index appended to ``"cuda:"`` when a GPU is requested.

    Returns:
        torch.device: ``cpu`` or ``cuda:<gpu_id>``.

    Raises:
        AssertionError: if a GPU is requested but CUDA is not available.
    """
    if not is_gpu:
        return torch.device("cpu")

    assert torch.cuda.is_available(), 'Cuda is not available'
    return torch.device("cuda:" + str(gpu_id))

def _is_method(func):
    """Return True when ``func`` declares a parameter named ``self``.

    This is a name-based heuristic on the callable's signature; it does not
    check whether ``func`` is actually bound to a class.
    """
    return 'self' in inspect.signature(func).parameters

def convert_args_to_tensor(positional_args_list=None, keyword_args_list=None, device='cpu'):
    """Decorator factory that converts selected call arguments to torch.Tensor.

    Conversion rules applied to each selected argument:
      * ``np.ndarray``: converted with ``torch.from_numpy``; arrays of dtype
        ``np.double`` are additionally cast to ``torch.float32``; the result
        is moved to ``device``.
      * ``list``: converted with ``torch.tensor`` and moved to ``device``.
      * ``torch.Tensor``, ``int``, ``float``, ``bool``: passed through
        unchanged (tensors are NOT moved to ``device``).
      * anything else: ``ValueError``.

    Args:
        positional_args_list ([list]): [indices of the positional arguments to
        convert. If None, all positional arguments are converted, except the
        first one when the wrapped function has a ``self`` parameter]
        keyword_args_list ([list]): [names of the keyword arguments to convert.
        If None, all keyword arguments are converted]
        device ([str]): [pytorch will run on this device]

    Returns:
        A decorator; the decorated function receives the converted arguments.

    Raises:
        ValueError: (at call time) when a selected argument has an
            unsupported type. The message identifies the argument by position
            or by keyword name.
    """
    def _to_tensor(value, description):
        # Convert a single argument according to the rules above;
        # ``description`` is only used to build the error message.
        if isinstance(value, np.ndarray):
            tensor = torch.from_numpy(value)
            if value.dtype == np.double:
                tensor = tensor.type(torch.float32)
            return tensor.to(device)
        if isinstance(value, list):
            return torch.tensor(value).to(device)
        if isinstance(value, (torch.Tensor, int, float, bool)):
            return value
        raise ValueError(
            'Arguments should be Numpy arrays, but {} is not: {}'.format(description, type(value))
        )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if positional_args_list is None:
                positions = range(len(args))
                # Leave ``self`` untouched when wrapping a method.
                if _is_method(func):
                    positions = positions[1:]
            else:
                positions = positional_args_list

            converted_args = [
                _to_tensor(arg, 'argument in position {}'.format(index))
                if index in positions else arg
                for index, arg in enumerate(args)
            ]

            # Keyword arguments are reported by name: there is no meaningful
            # position for them (and there may be no positional args at all).
            converted_kwargs = {
                key: _to_tensor(value, "keyword argument '{}'".format(key))
                if keyword_args_list is None or key in keyword_args_list else value
                for key, value in kwargs.items()
            }

            return func(*converted_args, **converted_kwargs)

        return wrapper

    return decorator

@convert_args_to_tensor([0], ['labels'])
def torch_one_hot(labels, one_hot_size):
    """One-hot encode ``labels`` into a ``(labels.shape[0], one_hot_size)`` tensor.

    ``labels`` is run through ``convert_args_to_tensor`` first (whether passed
    positionally or by keyword). The result is created on ``labels.device``,
    filled with zeros, and set to 1 at ``[row, labels[row]]`` for every row.
    """
    num_rows = labels.shape[0]
    target_device = labels.device

    encoded = torch.zeros(num_rows, one_hot_size, device=target_device)
    row_index = torch.arange(num_rows, device=target_device)
    encoded[row_index, labels] = 1
    return encoded

def default_optimizer():
    """Return the default ``OptimizerSpec``: Adam, no extra kwargs, lr 1e-3.

    NOTE(review): ``lr_schedule`` is a bare float here, whereas
    ``default_optimizer_spec`` stores a ``ConstantSchedule`` - confirm which
    form the callers of this function expect.
    """
    spec_fields = dict(constructor=optim.Adam, kwargs={}, lr_schedule=1e-3)
    return OptimizerSpec(**spec_fields)

def huber_loss(inp, target, delta=1.):
    """Element-wise Huber loss between ``inp`` and ``target``.

    See https://en.wikipedia.org/wiki/Huber_loss

    Args:
        inp: tensor of predictions.
        target: tensor of targets.
        delta: threshold on the absolute error below which the loss is
            quadratic and at or above which it is linear.

    Returns:
        Tensor of per-element losses (no reduction is applied).
    """
    abs_error = torch.abs(inp - target)
    quadratic_part = 0.5 * abs_error ** 2
    linear_part = abs_error * delta - (0.5 * delta ** 2)
    return torch.where(abs_error < delta, quadratic_part, linear_part)

def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """
    Utility function for computing output of convolutions
    takes a tuple of (h,w) and returns a tuple of (h,w)

    Every argument may be given either as a single value (applied to both
    height and width) or as a ``(height, width)`` tuple; this includes
    ``dilation``.

    Each output dimension is computed as
    ``(size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1``.
    """
    def _pair(value):
        # Broadcast a scalar setting to (height, width).
        return value if isinstance(value, tuple) else (value, value)

    def _out_size(size, kernel, step, padding, dil):
        return (size + (2 * padding) - (dil * (kernel - 1)) - 1) // step + 1

    h_w = _pair(h_w)
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    pad = _pair(pad)
    dilation = _pair(dilation)

    h = _out_size(h_w[0], kernel_size[0], stride[0], pad[0], dilation[0])
    w = _out_size(h_w[1], kernel_size[1], stride[1], pad[1], dilation[1])

    return h, w

def get_same_padding_size(kernel_size=1, stride=1, dilation=1):
    """
    Compute the padding size intended to mimic the "same" padding of the
    tensorflow Conv2D implementation.

    The quantity ``(stride - dilation * kernel_size + dilation - 1) / 2`` is
    evaluated; when it is strictly positive no padding is needed and 0 is
    returned, otherwise the ceiling of its absolute value is returned as int.
    """
    half_gap = (stride - dilation*kernel_size + dilation - 1) / 2
    return 0 if half_gap > 0 else int(np.ceil(np.abs(half_gap)))

def stepwise_exploration_schedule(num_timesteps, outside_value=0.1, portion_decay=0.1, exploration_type='epsilon-greedy'):
    """Build a ``PiecewiseSchedule`` for the exploration parameter.

    The schedule interpolates from a starting value at ``t = 0`` to the final
    value at ``t = num_timesteps * portion_decay`` and returns the final value
    outside that interval.

    Args:
        num_timesteps: total number of timesteps; scaled by ``portion_decay``
            to get the end of the decay segment.
        outside_value: final value for ``'epsilon-greedy'``. It is ignored for
            ``'resmax'``, which always uses 1024.
        portion_decay: fraction of ``num_timesteps`` spent interpolating.
        exploration_type: ``'epsilon-greedy'`` (starts at 1.0) or ``'resmax'``
            (starts at 0, ends at 1024).

    Returns:
        PiecewiseSchedule: the configured schedule.

    Raises:
        NotImplementedError: for ``'softmax'``, which has no schedule defined.
        ValueError: for any other unrecognised ``exploration_type``.
    """
    if exploration_type == 'epsilon-greedy':
        starting_value = 1.0
    elif exploration_type == 'resmax':
        starting_value = 0
        outside_value = 1024
    elif exploration_type == 'softmax':
        # No start/end values are defined for softmax; fail explicitly rather
        # than hitting an unbound local variable further down.
        raise NotImplementedError(
            "stepwise exploration schedule is not implemented for exploration_type 'softmax'"
        )
    else:
        raise ValueError(
            'Unknown exploration_type: {!r}'.format(exploration_type)
        )

    return PiecewiseSchedule(
        [
            (0, starting_value),
            (num_timesteps * portion_decay, outside_value),
        ], outside_value=outside_value
    )

def episode_based_linear_exploration_schedule(value):
    """Return an ``EpisodeBasedLinearSchedule`` built from ``value``."""
    schedule = EpisodeBasedLinearSchedule(value)
    return schedule

# --------------- From here on, the code is adapted from https://github.com/berkeleydeeprlcourse/homework_fall2019 -------

def constant_exploration_schedule(value):
    """Return a ``ConstantSchedule`` that always yields ``value``."""
    schedule = ConstantSchedule(value)
    return schedule

class Schedule:
    """Base interface for schedules; subclasses override ``value``."""

    def value(self, t):
        """Return the value of the schedule at time ``t``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class ConstantSchedule(Schedule):
    """Schedule whose value does not depend on time."""

    def __init__(self, value):
        """Store the constant to return.

        Parameters
        ----------
        value: float
            Constant value of the schedule
        """
        self._v = value

    def value(self, t):
        """Return the stored constant; ``t`` is ignored."""
        constant = self._v
        return constant


class EpisodeBasedLinearSchedule(Schedule):
    """Schedule whose value is the initial value divided by ``t``."""

    def __init__(self, value):
        """Store the numerator of the schedule.

        Parameters
        ----------
        value: float
            Value that is divided by ``t`` on every call to ``value``
        """
        self._v = value

    def value(self, t):
        """Return ``value / t``.

        Raises AssertionError unless ``t > 0``.
        """
        assert t > 0
        numerator = self._v
        return numerator / t


def linear_interpolation(l, r, alpha):
    """Interpolate linearly from ``l`` (``alpha == 0``) to ``r`` (``alpha == 1``)."""
    span = r - l
    return l + alpha * span

class PiecewiseSchedule(Schedule):
    """Schedule defined by interpolating between consecutive endpoints."""

    def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
        """Piecewise schedule.

        endpoints: [(int, int)]
            list of pairs `(time, value)` meaning that the schedule should
            output `value` when `t == time`. The times must be sorted in
            increasing order. When t lies between two times, e.g.
            `(time_a, value_a)` and `(time_b, value_b)` with
            `time_a <= t < time_b`, the schedule outputs
            `interpolation(value_a, value_b, alpha)` where alpha is the
            fraction of the time between `time_a` and `time_b` that has
            passed at `t`.
        interpolation: lambda float, float, float: float
            a function taking the value to the left of t, the value to the
            right of t, and alpha (the fraction of the distance from the left
            endpoint to the right endpoint that t has covered). See
            linear_interpolation for an example.
        outside_value: float
            value returned when t falls outside all the intervals specified
            in `endpoints`. If None, an AssertionError is raised when such a
            value is requested.
        """
        times = [point[0] for point in endpoints]
        assert times == sorted(times)
        self._endpoints = endpoints
        self._interpolation = interpolation
        self._outside_value = outside_value

    def value(self, t):
        """See Schedule.value"""
        segments = zip(self._endpoints[:-1], self._endpoints[1:])
        for (start_t, start_v), (end_t, end_v) in segments:
            if not (start_t <= t < end_t):
                continue
            alpha = float(t - start_t) / (end_t - start_t)
            return self._interpolation(start_v, end_v, alpha)

        # t is not covered by any segment: fall back to the outside value,
        # which must have been provided.
        assert self._outside_value is not None
        return self._outside_value

class LinearSchedule(Schedule):
    """Schedule that anneals linearly from ``initial_p`` to ``final_p``."""

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. Once that many timesteps have passed, final_p is
        returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        initial_p: float
            initial output value
        final_p: float
            final output value
        """
        self.initial_p = initial_p
        self.final_p = final_p
        self.schedule_timesteps = schedule_timesteps

    def value(self, t):
        """See Schedule.value"""
        progress = float(t) / self.schedule_timesteps
        # Cap at 1.0 so the value stays at final_p after the anneal ends.
        fraction = min(progress, 1.0)
        total_change = self.final_p - self.initial_p
        return self.initial_p + fraction * total_change

import warnings
# def stable_exp(V):
    # # V is a numpy vector
    # print('V: ', V)
    # with warnings.catch_warnings():
        # try:
            # ex = np.exp(V-np.max(V))*np.exp(np.max(V))
        # except Warning:
            # print('----------Raised-----------')
            # print(V)
    # return np.exp(V-np.max(V))*np.exp(np.max(V))

# MellowMax Implementation
# def mello_max(X, w):
#     # X is a vector input
#     X = np.squeeze(X)
#     if len(X.shape) == 1:
#         C = np.max(X)
#         vals = w * (X - C)
#         exp = np.exp(vals.astype(np.float128))
#         return C + np.log(exp.mean()) / w #is more numerically stable according to author
#     else:
#         raise NotImplementedError('Solving beta equation with more than one sample of q-values not implemented yet')

# def mellow_max(q, omega):
#     c = np.max(q)  # to avoid overflow
#     return c + np.log(np.mean(np.exp((1 / omega) * (q - c)))) / (1 / omega)
#
#
# def mellow_max_optimize(beta, q, omega):
#     '''
#     function for brent q to optimize
#     '''
#     mm = mellow_max(q, omega)
#     return np.sum(
#         np.multiply(
#             np.exp(beta * (q - mellow_max(q))),
#             (q - mm)))
#
#
# def beta_equation(beta, X, w):
#     X = np.squeeze(X)
#     if len(X.shape) == 1:
#         mm = mello_max(X, w)
#         diffs = X - mm
#         exps = np.exp(beta * diffs.astype(np.float128))
#         return (diffs * exps).sum()
#     else:
#         raise NotImplementedError('Solving beta equation with more than one sample of q-values not implemented yet')
#
#
# class RootFinder:
#
#     def __init__(self):
#         self.temp = 1.
#
#     # def mellow_max_root_finder(self, X, omega, bracket=(-60000, 60000.)):
#     #     try:
#     #         self.temp = brentq(beta_equation, bracket[0], bracket[1], args=(X, omega))
#     #     except:
#     #         print('----- Failed -----')
#     #         print(X, self.temp)
#     #     # minimum = root_scalar(beta_equation, args=(X, omega), method='brentq', bracket=bracket)
#     #     return max(self.temp, 0.)
#     @staticmethod
#     def mellow_max_root_finder(q, omega):
#         """
#         Return probability distribution p over actions representing a stochastic policy
#
#         arguments:
#             q: values for each action for a fixed state
#
#         returns:
#             beta: the temperature value of softmax
#         """
#
#         q = q.flatten()
#
#         f = lambda beta: mellow_max_optimize(beta, q, omega)
#
#         bounds = np.array([-1e8, 1e8])
#
#         while True:
#             try:
#                 beta = brentq(f, bounds[0], bounds[1])
#                 break
#             except:
#                 bounds *= 10
#                 print('Failed Q: ', q)
#                 print('Bounds: ', bounds)
#         return beta


# class MellowMaxPolicy(object):
#     '''
#     from https://arxiv.org/pdf/1612.05628.pdf
#     '''
#
#     def __init__(self):
#         pass
#
#     def get_p(self, q):
#         """
#         Return probability distribution p over actions representing a stochastic policy
#
#         arguments:
#             q: values for each action for a fixed state
#
#         returns:
#             p: probability of each action
#         """
#
#         q = q.flatten()
#
#         def f(beta):
#             '''
#             function for brent q to optimize
#             '''
#             nonlocal q
#             mm = self.mellow_max(q)
#             return np.sum(
#                 np.multiply(
#                     np.exp(beta * (q - self.mellow_max(q))),
#                     (q - mm)))
#
#         bounds = np.array([-1e8, 1e8])
#
#         while True:
#             try:
#                 beta = brentq(f, bounds[0], bounds[1])
#                 break
#             except:
#                 bounds *= 10
#         p = softmax(beta * q)
#         return p
#
#     def mellow_max(self, q, omega):
#         c = np.max(q)  # to avoid overflow
#         return c + np.log(np.mean(np.exp((1 / omega) * (q - c)))) / (1 / omega)
#
#     def get_action(self, q):
#         """
#         Select action accoring to policy
#         arguments:
#             q: values for each action for a fixed state
#
#         returns:
#             a: the selected action
#         """
#
#         try:
#             p = self.get_p(q)
#             p /= p.sum()  # normalize for fp issues with sampling
#             num_actions = q.shape[0]
#             a = np.random.choice(num_actions, p=p)
#         except:
#             print(p)
#             assert False
#         return a


def mellow_max(q, omega):
    """Compute ``log(mean(exp(q / omega))) * omega`` in a shifted form.

    The maximum of ``q`` is subtracted before exponentiating and added back
    afterwards, to avoid overflow in ``exp``.

    Args:
        q: numpy array of values.
        omega: scalar; the code scales by ``1 / omega``.
    """
    inv_omega = 1 / omega
    shift = np.max(q)  # to avoid overflow
    mean_exp = np.mean(np.exp(inv_omega * (q - shift)))
    return shift + np.log(mean_exp) / inv_omega


def mellow_max_optimize(beta, q, omega):
    '''
    Objective whose root in ``beta`` is searched for with brentq.

    Computes ``sum((q - mm) * exp(beta * (q - mm)))`` where
    ``mm = mellow_max(q, omega)``.

    Args:
        beta: candidate value being evaluated by the root finder.
        q: numpy array of values.
        omega: parameter forwarded to ``mellow_max``.

    Returns:
        The extended-precision value of the sum.
    '''
    # np.longdouble is the portable spelling of the platform's extended
    # precision float; np.float128 is not defined on every platform.
    mm = mellow_max(q, omega).astype(np.longdouble)
    diffs = q - mm
    return np.sum(np.multiply(np.exp(beta * diffs), diffs))


class RootFinder:
    """Finds the root in beta of ``mellow_max_optimize`` and remembers it."""

    def __init__(self):
        self.temp = 1.0
        # Last value returned; reused as the fallback when the search fails.
        self.beta = 1.0
        # Number of calls in which the fallback had to be used.
        self.failed_count = 0

    def mellow_max_root_finder(self, q, omega):
        """
        Solve ``mellow_max_optimize(beta, q, omega) == 0`` for beta with brentq.

        The search starts on the bracket [-1e4, 1e4]. Whenever brentq raises,
        the bracket is widened by a factor of 10 and the search is retried.
        If the upper bound reaches +inf, the previously stored ``self.beta``
        is used instead, ``self.failed_count`` is incremented and the
        flattened q is printed.

        arguments:
            q: values for each action for a fixed state (flattened before use)
            omega: parameter forwarded to ``mellow_max_optimize``

        returns:
            beta: the root found (or the fallback); also stored in ``self.beta``
        """
        flat_q = q.flatten()

        def objective(candidate):
            return mellow_max_optimize(candidate, flat_q, omega)

        bracket = np.array([-1e4, 1e4])

        while True:
            try:
                root = brentq(objective, bracket[0], bracket[1])
            except Exception:
                bracket *= 10
                if bracket[1] == float('+inf'):
                    # Bracket can no longer grow: reuse the last known beta.
                    root = self.beta
                    self.failed_count += 1
                    print('Failed Q: ', flat_q)
                    break
            else:
                break

        self.beta = root
        return self.beta

