import unittest
from torch import nn
import numpy as np
import inspect
import torch
from torch import sigmoid, tanh
from math import pi, sqrt
from torch.nn.functional import linear
from torch.nn.utils.rnn import PackedSequence
import warnings
from scipy.integrate import quadrature
from scipy.interpolate import PchipInterpolator as MonotonicSpline
from scipy.interpolate import interp1d
from pathlib import Path



def _filtered_eval(func, **kwargs):
    """
    Takes kwargs and passes ONLY the named parameters that are specified in the callable func
    :param func: Callable for which we'll filter the kwargs and then pass them
    :param kwargs:
    :return:
    """
    args = inspect.signature(func)
    right_ones = kwargs.keys() & args.parameters.keys()
    newargs = {key: kwargs[key] for key in right_ones}
    return func(**newargs)

def _no_grad_block_orthogonal(tensor, alpha):
    shp = tensor.shape
    assert shp[0] == shp[1], 'Non-square h->h'
    with torch.no_grad():
        blocks = [ torch.nn.init.orthogonal_(torch.zeros(2,2)) for i in range(shp[0]//2) ]
        q = block_diag(blocks)
        tensor.view_as(q).copy_(q)
        tensor.mul_(alpha)
    return tensor


def _initialize_only_hh(func, mode, iterator, hidden_size, gain, update_bias=0.):
    #assert hidden_size % 2 == 0, 'initialization requires even dimensional hidden states'  # changed by Ian for toy model
    for name, param in iterator:
        torch.nn.init.zeros_(param)
        # if 'bias' in name:
        #     torch.nn.init.zeros_(param)
        if 'bias_ih' in name and mode == GRU and '_no_grad_block_orthogonal' in inspect.getsource(func):
            torch.nn.init.constant_(param, val=update_bias)
        if 'weight_ih' in name:
            if mode not in ['LSTM','GRU']:
                torch.nn.init.normal_(param,mean=0., std=gain/sqrt(hidden_size))  # equivalent to Xavier Glorot's
                # normal initializer
            elif mode == 'LSTM':
                raise NotImplementedError
            elif mode == 'GRU':
                torch.nn.init.normal_(param[-hidden_size:, :], mean=0., std=gain/sqrt(hidden_size))
        elif 'weight_hh' in name:
            if mode not in ['LSTM','GRU']:
                func(param)
            elif mode == 'LSTM':
                raise NotImplementedError
            elif mode == 'GRU':
                func(param[-hidden_size:, :])

def _orthogonal(mode, iterator, hidden_size, alpha=1., gain=1.) -> None:
    """
    Orthogonal hidden->hidden initialization scaled by ``alpha``; input->hidden weights use a
    Gaussian with std ``gain / sqrt(hidden_size)`` and everything else is zeroed.

    :param mode: the torch RNN's ``mode`` string
    :param iterator: iterable of ``(name, parameter)`` pairs
    :param hidden_size: size of the hidden state
    :param alpha: gain passed to ``torch.nn.init.orthogonal_``
    :param gain: scale of the input->hidden Gaussian
    """
    def _init_hh(weight):
        return torch.nn.init.orthogonal_(weight, gain=alpha)

    _initialize_only_hh(_init_hh, mode, iterator, hidden_size, gain)


def _critical(mode, iterator, hidden_size, seq_length: int = 1_000, data_norm: float = 1.):
    if 'RNN' in mode:
        raise NotImplementedError
    elif mode != 'RNN':
        raise NotImplementedError


# Registry of named weight initializers; TestModel.test_init passes each key as ``init=``.
__initializers__ = dict(orthogonal=_orthogonal)


def _split_init_kwargs(kwargs):
    """
    Split kwargs into accepted by base class `kwargs` and initializer specific ones `init_kwargs`
    :param kwargs: kwargs passed to any RNN architecture derived from RNNBase
    :return: kwargs
    :return: init_kwargs
    """
    __initializer_opts__ = ['init', 'min_angle', 'max_angle', 'reflect', 'alpha', 'gain', 'data_norm', 'seq_length',
                            'update_bias']

    init_kwargs = {x: kwargs[x] for x in kwargs if x in __initializer_opts__}
    kwargs = {x: kwargs[x] for x in kwargs if x not in __initializer_opts__}

    return kwargs, init_kwargs

# class LYAPUNOV(ABC):
#
#     def lyapunov_exponents(self):
#         pass  # TODO: write out Benettin algorithm for computing LE


class RNN(nn.RNN):
    """
    ``nn.RNN`` with selectable weight initialization and a reproducible default initial state.

    Initializer options (``init``, ``alpha``, ``gain``, ...) are removed from the keyword arguments
    before they reach ``nn.RNN`` and are kept in ``self.init_kwargs``. When ``forward`` runs without
    ``hx``, the initial hidden state is drawn uniformly from [0, 1) by a generator rewound to the same
    state on every call. Every call, and every sequence in a batch, therefore starts from the same state.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments forwarded to ``nn.RNN``
        :param kwargs: keyword arguments; initializer options go to ``self.init_kwargs``,
                       the rest are forwarded to ``nn.RNN``
        """
        kwargs, init_kwargs = _split_init_kwargs(kwargs)
        # Set before nn.RNN.__init__, whose base class calls self.reset_parameters() (which reads it).
        self.init_kwargs = init_kwargs
        self.mt19937 = np.random.MT19937()
        # Snapshot of numpy's global legacy RNG state; forward() rewinds self.mt19937 to it.
        self.hh_seed = np.random.get_state()
        super(RNN, self).__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        """
        Initialize weights (in place) according to ``self.init_kwargs['init']``.

        ``None``/``"default"`` selects PyTorch's default RNN initialization and ``"orthogonal"`` selects
        ``self.orthogonal``. ``"limitcycle"`` selects ``self.limitcycle``, which is not defined in this class.
        Any other value raises NotImplementedError.
        """
        # Work on a copy: popping from self.init_kwargs would drop 'init', and every later
        # reset_parameters() call would silently fall back to the default initializer.
        kwargs = dict(self.init_kwargs)
        init = kwargs.pop('init', None)
        if init is None or init == "default":
            nn.RNN.reset_parameters(self)
        elif init == 'limitcycle':
            self.limitcycle(**kwargs)
        elif init == 'orthogonal':
            self.orthogonal(**kwargs)
        else:
            raise NotImplementedError(f"unknown init {init!r}")

    def orthogonal(self, **kwargs):
        """Apply ``_orthogonal`` to this module's parameters; kwargs it does not accept are dropped."""
        kwargs = {"mode": self.mode,
                  "iterator": self.named_parameters(),
                  "hidden_size": self.hidden_size,
                  **kwargs}
        _filtered_eval(_orthogonal, **kwargs)

    def Dfx(self, hx, input=None):
        """
        Jacobian of one tanh step h' = tanh(W_ih x + b_ih + W_hh h + b_hh) with respect to h.

        :param hx: hidden state(s), last dim ``hidden_size`` (leading batch dims allowed)
        :param input: input(s), last dim ``input_size``; defaults to zeros with hx's leading dims
        :return: ``diag(1 - h'^2) @ W_hh``, shape ``(..., hidden_size, hidden_size)``
        :raises NotImplementedError: for multi-layer or non-tanh networks
        """
        if self.num_layers > 1:
            raise NotImplementedError("Dfx only supports num_layers == 1")
        if self.nonlinearity != 'tanh':
            raise NotImplementedError("Dfx only supports the tanh nonlinearity")
        if input is None:
            input = hx.new_zeros(hx.shape[:-1] + (self.input_size,))
        Wih = self.weight_ih_l0
        Whh = self.weight_hh_l0
        # Bias parameters do not exist when the module was built with bias=False.
        bih = getattr(self, 'bias_ih_l0', None)
        bhh = getattr(self, 'bias_hh_l0', None)
        pre_activation = linear(input, Wih, bih) + linear(hx, Whh, bhh)
        return Whh - torch.diag_embed(torch.tanh(pre_activation) ** 2) @ Whh

    def forward(self, input, hx=None):
        """
        Run ``nn.RNN.forward``. If ``hx`` is None, use the seeded uniform [0, 1) initial state.

        Accepts batched tensors, unbatched ``(seq, features)`` tensors and ``PackedSequence``.
        """
        orig_input = input
        if isinstance(orig_input, PackedSequence):
            # PackedSequence has no .device/.dtype of its own; use its flat data tensor.
            data = orig_input.data
            max_batch_size = int(orig_input.batch_sizes[0])
        elif input.dim() == 2:
            data = input
            max_batch_size = None  # unbatched input: nn.RNN expects a 2-D hx
        else:
            data = input
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            self.mt19937.state = self.hh_seed
            random = np.random.RandomState(self.mt19937)
            rand = random.rand(self.num_layers * num_directions, 1, self.hidden_size)
            rand = torch.from_numpy(rand).to(device=data.device, dtype=data.dtype)
            if max_batch_size is None:
                hx = rand[:, 0, :].contiguous()
            else:
                hx = rand.repeat(1, max_batch_size, 1)
        return nn.RNN.forward(self, orig_input, hx=hx)

class GRU(nn.GRU):
    """
    ``nn.GRU`` with selectable weight initialization and a reproducible default initial state.

    Initializer options are removed from the keyword arguments before they reach ``nn.GRU`` and are
    kept in ``self.init_kwargs``. When ``forward`` runs without ``hx``, the initial hidden state is drawn
    uniformly from [0, 1) by a generator rewound to the same state on every call.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments forwarded to ``nn.GRU``
        :param kwargs: keyword arguments; initializer options go to ``self.init_kwargs``,
                       the rest are forwarded to ``nn.GRU``
        """
        kwargs, init_kwargs = _split_init_kwargs(kwargs)
        # Set before nn.GRU.__init__, whose base class calls self.reset_parameters() (which reads it).
        self.init_kwargs = init_kwargs
        self.mt19937 = np.random.MT19937()
        # Snapshot of numpy's global legacy RNG state; forward() rewinds self.mt19937 to it.
        self.hh_seed = np.random.get_state()
        super(GRU, self).__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        """
        Initialize weights (in place) according to ``self.init_kwargs['init']``.

        ``None``/``"default"`` selects PyTorch's default GRU initialization and ``"orthogonal"`` selects
        ``self.orthogonal``. ``"limitcycle"`` selects ``self.limitcycle``, which is not defined in this class.
        Any other value raises NotImplementedError.
        """
        # Work on a copy so later reset_parameters() calls still see 'init'.
        kwargs = dict(self.init_kwargs)
        init = kwargs.pop('init', None)
        if init is None or init == "default":
            nn.GRU.reset_parameters(self)
        elif init == 'limitcycle':
            self.limitcycle(**kwargs)
        elif init == 'orthogonal':
            self.orthogonal(**kwargs)
        else:
            raise NotImplementedError(f"unknown init {init!r}")

    def orthogonal(self, **kwargs):
        """Apply ``_orthogonal`` to this module's parameters; kwargs it does not accept are dropped."""
        kwargs = {"mode": self.mode,
                  "iterator": self.named_parameters(),
                  "hidden_size": self.hidden_size,
                  **kwargs}
        _filtered_eval(_orthogonal, **kwargs)

    def Dfx(self, hx, input=None):
        """
        Jacobian of one GRU step with respect to the hidden state, for a single-layer GRU.

        With r = σ(.), z = σ(.), n = tanh(i_n + r ⊙ h_n), where h_n = W_hn h + b_hn, and
        h' = (1 - z) ⊙ n + z ⊙ h:
            dh'/dh = diag(1 - z) dn/dh + diag(h - n) dz/dh + diag(z)
            dn/dh  = diag(1 - n²) (diag(h_n) dr/dh + diag(r) W_hn)

        :param hx: hidden state(s), last dim ``hidden_size`` (leading batch dims allowed)
        :param input: input(s), last dim ``input_size``; defaults to zeros with hx's leading dims
        :return: Jacobian of shape ``(..., hidden_size, hidden_size)``
        :raises NotImplementedError: for multi-layer networks
        """
        if self.num_layers > 1:
            raise NotImplementedError("Dfx only supports num_layers == 1")
        if input is None:
            input = hx.new_zeros(hx.shape[:-1] + (self.input_size,))
        Wih = self.weight_ih_l0
        Whh = self.weight_hh_l0
        bih = getattr(self, 'bias_ih_l0', None)
        bhh = getattr(self, 'bias_hh_l0', None)

        gh = torch.nn.functional.linear(hx, Whh, bhh)
        gi = torch.nn.functional.linear(input, Wih, bih)

        # Gate order in the stacked weights: reset (r), update (z), new (n).
        W_hr, W_hz, W_hn = Whh.chunk(3, 0)
        i_r, i_z, i_n = gi.chunk(3, -1)
        h_r, h_z, h_n = gh.chunk(3, -1)
        r = sigmoid(i_r + h_r)
        z = sigmoid(i_z + h_z)
        n = tanh(i_n + r * h_n)

        dr_dh = torch.diag_embed(r * (1 - r)) @ W_hr
        dz_dh = torch.diag_embed(z * (1 - z)) @ W_hz
        # Row scaling (diag on the left): d n_k / d h_j involves h_n_k, not h_n_j.
        dn_dh = torch.diag_embed(1 - n ** 2) @ (torch.diag_embed(h_n) @ dr_dh + torch.diag_embed(r) @ W_hn)

        return (torch.diag_embed(1 - z) @ dn_dh
                + torch.diag_embed(hx - n) @ dz_dh
                + torch.diag_embed(z))

    def forward(self, input, hx=None):
        """
        Run ``nn.GRU.forward``. If ``hx`` is None, use the seeded uniform [0, 1) initial state.

        Accepts batched tensors, unbatched ``(seq, features)`` tensors and ``PackedSequence``.
        """
        orig_input = input
        if isinstance(orig_input, PackedSequence):
            # PackedSequence has no .device/.dtype of its own; use its flat data tensor.
            data = orig_input.data
            max_batch_size = int(orig_input.batch_sizes[0])
        elif input.dim() == 2:
            data = input
            max_batch_size = None  # unbatched input: nn.GRU expects a 2-D hx
        else:
            data = input
            max_batch_size = input.shape[0] if self.batch_first else input.shape[1]
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            self.mt19937.state = self.hh_seed
            random = np.random.RandomState(self.mt19937)
            rand = random.rand(self.num_layers * num_directions, 1, self.hidden_size)
            rand = torch.from_numpy(rand).to(device=data.device, dtype=data.dtype)
            if max_batch_size is None:
                hx = rand[:, 0, :].contiguous()
            else:
                hx = rand.repeat(1, max_batch_size, 1)

        return nn.GRU.forward(self, orig_input, hx=hx)

class LSTM(nn.LSTM):
    """
    ``nn.LSTM`` with optional forget-gate bias initialization and a reproducible default initial state.

    When ``forward`` runs without ``hx``, both h0 and c0 are the same uniform [0, 1) draw from a
    generator rewound to the same state on every call.
    """

    def __init__(self, *args, init="forget_bias", **kwargs):
        """
        :param args: positional arguments forwarded to ``nn.LSTM``
        :param init: ``"forget_bias"`` (default) sets the forget-gate bias after PyTorch's default
                     init; any other value keeps PyTorch's default init only
        :param kwargs: initializer options go to ``self.init_kwargs``, the rest to ``nn.LSTM``
        """
        kwargs, init_kwargs = _split_init_kwargs(kwargs)
        # ``init`` is bound by the signature, so it never reaches _split_init_kwargs. Record it
        # explicitly, otherwise reset_parameters() would never see it.
        init_kwargs['init'] = init
        # Set before nn.LSTM.__init__, whose base class calls self.reset_parameters() (which reads it).
        self.init_kwargs = init_kwargs
        self.mt19937 = np.random.MT19937()
        # Snapshot of numpy's global legacy RNG state; forward() rewinds self.mt19937 to it.
        self.hh_seed = np.random.get_state()
        super(LSTM, self).__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        """
        Initialize weights (in place): PyTorch's default LSTM init, then the forget-gate bias
        if ``self.init_kwargs['init'] == "forget_bias"``.
        """
        # Read without popping so repeated resets keep honouring the chosen init.
        init = self.init_kwargs.get('init')
        nn.LSTM.reset_parameters(self)
        if init == "forget_bias":
            self.set_forget_bias()

    def set_forget_bias(self):
        """
        Zero all biases, then set the hidden->hidden forget-gate bias to 2, following
        http://proceedings.mlr.press/v37/jozefowicz15.pdf and https://ieeexplore.ieee.org/document/818041

        PyTorch stacks LSTM biases as (input, forget, cell, output) gates, each ``hidden_size``
        long, so the forget gate occupies ``[hidden_size:2*hidden_size]`` of the 1-D bias.
        """
        hidden = self.hidden_size
        with torch.no_grad():
            for name, param in self.named_parameters():
                if 'bias_ih' in name:
                    param.zero_()
                elif 'bias_hh' in name:
                    param.zero_()
                    param[hidden:2 * hidden] = 2.

    def Dfx(self, hx, input=None):
        """Jacobian of the LSTM update with respect to the hidden state; not implemented."""
        raise NotImplementedError

    def forward(self, input, hx=None):
        """
        Run ``nn.LSTM.forward``. If ``hx`` is None, use the seeded uniform [0, 1) state for both h0 and c0.

        Accepts batched tensors, unbatched ``(seq, features)`` tensors and ``PackedSequence``.
        """
        orig_input = input
        if isinstance(orig_input, PackedSequence):
            # PackedSequence has no .device/.dtype of its own; use its flat data tensor.
            data = orig_input.data
            max_batch_size = int(orig_input.batch_sizes[0])
        elif input.dim() == 2:
            data = input
            max_batch_size = None  # unbatched input: nn.LSTM expects 2-D h0/c0
        else:
            data = input
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            self.mt19937.state = self.hh_seed
            random = np.random.RandomState(self.mt19937)
            rand = random.rand(self.num_layers * num_directions, 1, self.hidden_size)
            rand = torch.from_numpy(rand).to(device=data.device, dtype=data.dtype)
            if max_batch_size is None:
                state = rand[:, 0, :].contiguous()
            else:
                state = rand.repeat(1, max_batch_size, 1)
            hx = (state, state)
        return nn.LSTM.forward(self, orig_input, hx=hx)

# linear dynamical system whose input is a linear transformation of the data
class LINEAR(nn.Module):
    """
    Linear dynamical system driven by an affine map of the input:

        h_0 = e + W_ih x_0 + b_ih,   h_t = W_hh h_{t-1} + W_ih x_t + b_ih   (t >= 1)

    where ``e ~ N(0, I)`` is drawn fresh on every forward call.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: ignored
        :param kwargs: must contain ``batch_first``, ``input_size`` and ``hidden_size``;
                       initializer options are stored in ``self.init_kwargs``
        """
        super(LINEAR, self).__init__()

        kwargs, init_kwargs = _split_init_kwargs(kwargs)
        self.init_kwargs = init_kwargs

        self.batch_first = kwargs['batch_first']

        self.input_size = kwargs['input_size']
        self.hidden_size = kwargs['hidden_size']

        self.weight_ih_l0 = nn.Linear(self.input_size, self.hidden_size)
        self.weight_hh_l0 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Kept for attribute parity with the RNN wrappers; forward() does not use them.
        self.mt19937 = np.random.MT19937()
        self.hh_seed = np.random.get_state()

    def reset_parameters(self) -> None:
        """
        Initialize weights (in place) according to ``self.init_kwargs['init']``.

        ``None``/``"default"`` reuses ``nn.RNN.reset_parameters`` on this module's parameters.
        """
        # Work on a copy so later reset_parameters() calls still see 'init'.
        kwargs = dict(self.init_kwargs)
        init = kwargs.pop('init', None)
        if init is None or init == "default":
            nn.RNN.reset_parameters(self)
        elif init == 'limitcycle':
            self.limitcycle(**kwargs)
        elif init == 'orthogonal':
            self.orthogonal(**kwargs)
        else:
            raise NotImplementedError(f"unknown init {init!r}")

    def orthogonal(self, **kwargs):
        """Apply ``_orthogonal`` to this module's parameters; kwargs it does not accept are dropped."""
        # NOTE(review): LINEAR never sets ``self.mode``, so this lookup raises AttributeError as written.
        kwargs = {"mode": self.mode,
                  "iterator": self.named_parameters(),
                  "hidden_size": self.hidden_size,
                  **kwargs}
        _filtered_eval(_orthogonal, **kwargs)

    def Dfx(self, hx, input=None):
        raise NotImplementedError

    def forward(self, x, hx=None):
        """
        Unroll the linear system over dim 1 of ``x``.

        :param x: input indexed as ``x[:, t]`` per time step, i.e. (batch, time, input_size).
                  NOTE(review): ``self.batch_first`` is not consulted; confirm callers pass batch-first data.
        :param hx: ignored; a fresh N(0, I) initial state is drawn on every call
        :return: ``(hiddens, hiddens)`` with ``hiddens`` of shape (batch, time, hidden_size)
        """
        n_batch, n_steps = x.shape[0], x.shape[1]
        # Use x.device directly: x.get_device() returns -1 for CPU tensors, which is not a valid device.
        device = x.device
        if n_steps == 0:
            empty = torch.zeros((n_batch, 0, self.hidden_size), dtype=torch.float, device=device)
            return empty, empty

        initial = torch.randn((n_batch, self.hidden_size), dtype=torch.float, device=device)
        hidden = initial + self.weight_ih_l0(x[:, 0])
        # Collect states and stack once so the output stays on the autograd graph
        # (in-place writes into a buffer under torch.no_grad() would detach it from the weights).
        states = [hidden]
        for t in range(1, n_steps):
            # Step t is driven by input t, consistent with h_0 being driven by x_0.
            hidden = self.weight_hh_l0(hidden) + self.weight_ih_l0(x[:, t])
            states.append(hidden)
        hiddens = torch.stack(states, dim=1)

        return hiddens, hiddens


# Simple affine transformation with no recurrence: the past does not influence the output
class REGRESSION(nn.Module):
    """
    Square affine map ``nn.Linear(input_size, input_size)`` applied along the last dimension of the input.

    Construct either the legacy way, ``REGRESSION(args, {'input_size': n})``, or with keywords,
    ``REGRESSION(input_size=n)``.
    """

    def __init__(self, args=(), kwargs=None, **extra_kwargs):
        """
        :param args: ignored (kept for the legacy ``(args, kwargs)`` call form)
        :param kwargs: dict of options; must provide ``input_size`` unless it is given as a keyword
        :param extra_kwargs: options passed as keywords; they override entries in ``kwargs``
        :raises KeyError: if ``input_size`` is missing
        """
        # super() takes no extra arguments; forwarding *args/**kwargs into it made the class unconstructible.
        super(REGRESSION, self).__init__()
        options = dict(kwargs or {}, **extra_kwargs)

        self.size = options['input_size']
        self.weights = nn.Linear(self.size, self.size)

        # Kept for attribute parity with the RNN wrappers; forward() does not use them.
        self.mt19937 = np.random.MT19937()
        self.hh_seed = np.random.get_state()

    def forward(self, x):
        """
        :param x: tensor whose last dimension is ``input_size``; moved to the weights' device first
        :return: ``(outputs, outputs)``, the same tensor twice, mirroring the RNN ``(output, h)`` pair
        """
        x = x.to(self.weights.weight.device)
        outputs = self.weights(x)
        return outputs, outputs



# Architectures exercised by TestModel.test_init.
__classes__ = [
    RNN,
    GRU,
]


class TestModel(unittest.TestCase):
    """Smoke tests for the RNN wrappers and initializers defined in this module."""

    @staticmethod
    def create_model():
        # create some default test model
        raise NotImplementedError

    @staticmethod
    def load_data():
        # create some fake data
        raise NotImplementedError

    def test_init(self):
        """Every class in ``__classes__`` can be built with every initializer in ``__initializers__``."""
        from itertools import product
        # ``model_cls`` instead of ``nn`` so the torch ``nn`` module name is not shadowed.
        for model_cls, init in product(__classes__, __initializers__.keys()):
            # subTest reports each failing (class, init) pair separately instead of stopping at the first.
            with self.subTest(model=model_cls.__name__, init=init):
                model_cls(10, 1024, num_layers=2, init=init)



if __name__ == "__main__":
    # Executed as a script: run this module's unittest test cases.
    unittest.main()
