import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import numpy as np
import scipy

def init_weights(net, init_dict, gain=1, input_class=None):
    """Initialise selected attributes of ``net`` and all its sub-modules in place.

    Args:
        net: module to initialise; every sub-module (and ``net`` itself) is
            visited through ``net.apply``.
        init_dict: maps an attribute name (e.g. ``'weight'``, ``'bias'``) to
            one of the schemes ``'normal'``, ``'xavier'``, ``'kaiming'``,
            ``'orthogonal'``, ``'uniform'``, ``'zeros'``, ``'very_small'``,
            ``'xavier1D'`` or ``'identity'``.
        gain: std for ``'normal'``, gain for ``'xavier'``/``'orthogonal'``,
            multiplier of ``1e-3`` for ``'very_small'`` and numerator of the
            std for ``'xavier1D'``.  The other schemes ignore it.
        input_class: when given, only modules whose exact type is this class
            are touched (subclasses are not matched).

    Raises:
        NotImplementedError: a visited module owns an attribute listed in
            ``init_dict`` whose scheme name is unknown.
    """
    def init_func(m):
        if input_class is not None and type(m) is not input_class:
            return
        for key, value in init_dict.items():
            param = getattr(m, key, None)
            if param is None:
                # Module has no such attribute (or e.g. bias=False): skip it.
                continue
            if value == 'normal':
                nn.init.normal_(param.data, 0.0, gain)
            elif value == 'xavier':
                nn.init.xavier_normal_(param.data, gain=gain)
            elif value == 'kaiming':
                nn.init.kaiming_normal_(param.data, a=0, mode='fan_in')
            elif value == 'orthogonal':
                nn.init.orthogonal_(param.data, gain=gain)
            elif value == 'uniform':
                nn.init.uniform_(param.data)
            elif value == 'zeros':
                nn.init.zeros_(param.data)
            elif value == 'very_small':
                nn.init.constant_(param.data, 1e-3*gain)
            elif value == 'xavier1D':
                # numel() is a plain int, so take the square root with ** 0.5.
                nn.init.normal_(param.data, 0.0, gain / param.numel() ** 0.5)
            elif value == 'identity':
                nn.init.eye_(param.data)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % value)

    # Without this call the closure above is never executed and nothing is
    # initialised.
    net.apply(init_func)
#activation functions
class quadratic(nn.Module):
    """Element-wise square activation: ``x -> x**2``.

    NOTE(review): a second class named ``quadratic`` is defined further down
    in this file and rebinds the name, so this definition is shadowed once
    the module has finished importing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tensor.pow(2) is what ``x ** 2`` dispatches to.
        return x.pow(2)

class quadratic(nn.Module):
    """One-sided quadratic activation: ``x * relu(x)``.

    Equals ``x**2`` for positive inputs and zero otherwise.

    NOTE(review): this class reuses the name of an earlier ``quadratic``
    class in this file and therefore replaces it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        positive_part = F.relu(x)
        return x * positive_part

class cos(nn.Module):
    """Module wrapper applying the element-wise cosine."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.cos()

class sin(nn.Module):
    """Module wrapper applying the element-wise sine."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.sin()

class swish(nn.Module):
    """Swish activation: ``sigmoid(x) * x``.

    NOTE(review): another class named ``swish`` appears later in this file
    and rebinds the name.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x

class relu2(nn.Module):
    """ReLU raised to a power: ``relu(x) ** order``.

    Args:
        order: exponent applied to the rectified input (default 2).

    The parameter ``a`` is registered (initialised to one) but is not used by
    ``forward``.
    """

    def __init__(self, order=2):
        super().__init__()
        self.a = nn.Parameter(torch.ones(1))
        self.order = order

    def forward(self, x):
        rectified = F.relu(x)
        return torch.pow(rectified, self.order)

class leakyrelu2(nn.Module):
    """Leaky ReLU of a scaled input, raised to a power.

    Computes ``leaky_relu(a * x) ** order`` where ``a`` is a learnable scalar
    initialised to one.

    Args:
        order: exponent applied after the leaky ReLU (default 2).
    """

    def __init__(self, order=2):
        super().__init__()
        self.a = nn.Parameter(torch.ones(1))
        self.order = order

    def forward(self, x):
        # The scale is moved to the input's device before multiplying.
        scaled = self.a.to(x.device) * x
        return torch.pow(F.leaky_relu(scaled), self.order)

class mod_softplus(nn.Module):
    """Shifted softplus: ``softplus(x) + x/2 - log(2)`` (zero at ``x = 0``)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # log(2) is held in a one-element tensor on the input's device, so the
        # result broadcasts against shape (1,).
        log_two = torch.tensor([2.0]).log().to(device=x.device)
        return F.softplus(x) + x / 2 - log_two

class mod_softplus2(nn.Module):
    """Scaled symmetric softplus.

    Computes ``d * (1 + d) * (2*softplus(x) - x - 2*log(2))``; the bracketed
    term is zero at ``x = 0``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, d):
        # One-element tensor holding 2*log(2), placed on the input's device.
        two_log_two = 2 * torch.tensor([2.0]).log().to(device=x.device)
        centred = 2 * F.softplus(x) - x - two_log_two
        return d * (1 + d) * centred

class mod_softplus3(nn.Module):
    """Softplus written as ``relu(x) + softplus(-|x|)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        tail = F.softplus(-x.abs())
        return F.relu(x) + tail

class swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    NOTE(review): this rebinds the name of an earlier ``swish`` class defined
    in this file.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * x.sigmoid()

class soft2(nn.Module):
    """Smooth rectifier: ``sqrt(x**2 + 1) / 2 + x / 2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        root = (x.pow(2) + 1).sqrt()
        return root / 2 + x / 2

class soft3(nn.Module):
    """Activation returning ``log(sigmoid(-x))``, i.e. ``-softplus(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # log-sigmoid is exposed through torch.nn.functional (imported as F),
        # not as a documented function on the top-level ``torch`` namespace.
        return F.logsigmoid(-x)
class Shallow(nn.Module):
    """Two-layer network: Linear -> ``quadratic`` activation -> Linear.

    Args:
        input_size: width of the input and of the hidden layer.
        out_size: width of the output.
    """

    def __init__(self, input_size, out_size):
        super().__init__()
        stages = [
            nn.Linear(input_size, input_size),
            quadratic(),
            nn.Linear(input_size, out_size),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class PositiveLinear(nn.Linear):
    """``nn.Linear`` subclass constructed from keyword arguments only.

    Args:
        **args: keyword arguments forwarded unchanged to ``nn.Linear``
            (e.g. ``in_features``, ``out_features``, ``bias``).

    NOTE(review): despite the name, nothing in this class constrains the
    weights to be positive.
    """

    def __init__(self, **args):
        # Forward the arguments; nn.Linear cannot be built without its sizes.
        super().__init__(**args)


class MLP(nn.Module):
    """Fully connected network with optional batch normalisation.

    Layout: ``fc1`` (never biased) -> optional BatchNorm1d -> ``act`` ->
    ``layers`` hidden stages -> ``out``.

    Args:
        input_size: width of the input features.
        hidden_size: width of every hidden layer.
        layers: number of hidden stages placed in ``mid``.
        out_size: width of the output.
        act: activation module; the same instance is reused at every stage.
        bn: when true, each stage is Linear -> BatchNorm1d -> act; otherwise
            it is a bias-free Linear -> act.
        bias: whether the output layer has a bias.
    """

    def __init__(self, input_size, hidden_size, layers, out_size, act=nn.LeakyReLU(), bn=True, bias=True):
        super().__init__()

        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.bn = nn.BatchNorm1d(hidden_size) if bn else None

        hidden_stack = []
        for _ in range(layers):
            if bn:
                hidden_stack.extend(
                    (nn.Linear(hidden_size, hidden_size), nn.BatchNorm1d(hidden_size), act))
            else:
                hidden_stack.extend(
                    (nn.Linear(hidden_size, hidden_size, bias=False), act))
        self.mid = nn.Sequential(*hidden_stack)

        self.out = nn.Linear(hidden_size, out_size, bias=bias)
        self.act = act

    def forward(self, x, cond=None):
        """Run the network on ``x``; ``cond`` is accepted but not used."""
        hidden = self.fc1(x)
        if self.bn:
            hidden = self.bn(hidden)
        hidden = self.mid(self.act(hidden))
        return self.out(hidden)

class W0(nn.Module):
    """Sum of an MLP on ``x`` and an averaged MLP on offsets ``W - x``.

    Args:
        d: input dimension (last axis of ``W`` and input width of both MLPs).
        w_width: number of learned anchor rows in ``W``.
        d_z: output width of both MLPs.

    ``W`` has shape ``(w_width, 1, d)`` and is orthogonally initialised.
    """

    def __init__(self, d, w_width, d_z):
        super().__init__()
        anchors = torch.rand(w_width, 1, d, requires_grad=True)
        self.W = nn.Parameter(nn.init.orthogonal_(anchors))
        self.b = MLP(d, 64, 2, d_z, bn=False)
        self.phi = MLP(d, 64, 2, d_z, bn=False)

    def forward(self, x, cond=None):
        """Return ``b(x) + mean_0(phi(W - x))``; ``cond`` is not used.

        Assumes ``x`` broadcasts against ``W`` (e.g. shape ``(batch, d)``)
        -- TODO confirm with callers.
        """
        direct = self.b(x)
        # Average the per-anchor responses over the leading (anchor) axis.
        averaged = self.phi(self.W - x).mean(0)
        return direct + averaged

from scipy.interpolate import interp1d
import math

class ICNN(nn.Module):
    """Input-convex style network, optionally conditioned on a second input.

    Args:
        input_size: width of ``x``.
        width: width of the hidden ``z`` path.
        depth: ``depth - 1`` hidden blocks are stacked in ``mid``.
        cond_size: width of the conditioning input; ``0`` disables
            conditioning and selects ``ICNN_block`` instead of ``PICNN_block``.
        cond_width: width of the conditioning path ``u`` (conditional case).
        fn0: activation applied after the first linear layer.
        fn: activation used inside the hidden blocks.
        fnu: activation of the conditioning path (conditional case).

    In the unconditional case ``out_z`` and ``out_x`` are created, but
    ``forward`` returns the squared norm of the last hidden state instead of
    using them.
    """

    def __init__(self, input_size, width, depth, 
            cond_size=0, 
            cond_width=0, 
            fn0=relu2(order=2), 
            fn=nn.LeakyReLU(), 
            fnu=nn.LeakyReLU()):

        super().__init__()

        self.fn0 = fn0
        self.fn = fn
        self.cond_size = cond_size

        self.fc0 = nn.Linear(input_size, width, bias=False)

        n_hidden = depth - 1
        if cond_size > 0:
            self.uc0 = nn.Linear(cond_size, cond_width, bias=True)
            self.cc0 = nn.Linear(cond_size, width, bias=False)
            hidden_blocks = [
                PICNN_block(input_size, width, width, cond_width, fn, fnu)
                for _ in range(n_hidden)
            ]
            self.pout = PICNN_block(input_size, width, 1, cond_width, nn.Softplus(), fnu)
        else:
            hidden_blocks = [ICNN_block(input_size, width, fn) for _ in range(n_hidden)]
            self.out_z = nn.Linear(width, 1, bias=False)
            self.out_x = nn.Linear(input_size, 1, bias=True)

        self.mid = nn.Sequential(*hidden_blocks)
        init_weights(self, {'weight': 'orthogonal', 'bias': 'zeros'}, gain=1)

    def _run_hidden(self, x, cond):
        """Apply the first layer and the ``mid`` stack; return its tuple."""
        z0 = self.fc0(x)
        if self.cond_size > 0:
            u0 = self.uc0(cond)
            c0 = self.cc0(cond)
            return self.mid((x, self.fn0(z0 + c0), u0))
        return self.mid((x, self.fn0(z0)))

    def forward(self, x, cond=None):
        """Return the scalar output per sample (last axis of size one)."""
        hidden = self._run_hidden(x, cond)
        if self.cond_size > 0:
            _, z, _ = self.pout(hidden)
            return z
        _, z = hidden
        # Unconditional output: squared norm of the last hidden state.
        return (z ** 2).sum(-1, keepdims=True)

    def up_to_last(self, x, cond=None):
        """Return the hidden state ``z`` produced by the ``mid`` stack."""
        return self._run_hidden(x, cond)[1]

    def second_to_last(self,x,z0):
        """Run ``fn0`` and the unconditional ``mid`` stack on a given ``z0``."""
        return self.mid((x, self.fn0(z0)))[1]

    def clamp(self):
        """Clamp ``z``-path weights to be non-negative, in place.

        Covers ``out_z`` (unconditional case) and every block in ``mid``;
        ``pout`` is not clamped here.
        """
        if self.cond_size == 0:
            self.out_z.weight.data.clamp_(min=0)
        for layer in self.mid:
            layer.clamp()

class ICNN_block(nn.Module):
    """One hidden stage: ``z' = fn(lin_x(x) + lin_z(z))``.

    Args:
        x_size: width of the raw input ``x``.
        zi_size: width of the hidden state ``z`` (input and output).
        fn: activation applied to the pre-activation.
        no_x: when true the ``lin_x(x)`` skip term is left out.
    """

    def __init__(self, x_size, zi_size, fn, no_x=False):
        super().__init__()
        self.lin_x = nn.Linear(x_size, zi_size, bias=True)
        self.lin_z = nn.Linear(zi_size, zi_size, bias=False)
        self.fn = fn
        self.no_x = no_x

    def forward(self, input_):
        """Map ``(x, z)`` to ``(x, z')``; ``x`` is passed through untouched."""
        x, z = input_[0], input_[1]
        pre_act = self.lin_z(z) if self.no_x else self.lin_x(x) + self.lin_z(z)
        return (x, self.fn(pre_act))

    def clamp(self):
        """Clamp the ``lin_z`` weights to be non-negative, in place."""
        self.lin_z.weight.data.clamp_(min=0)

class PICNN_block(nn.Module):
    """One conditioned stage operating on ``(x, z, u)``.

    Args:
        x_size: width of the raw input ``x``.
        zi_size: width of the incoming hidden state ``z``.
        zout_size: width of the outgoing hidden state.
        ui_size: width of the conditioning path ``u`` (input and output).
        fn: activation for the new hidden state; skipped when falsy.
        fnu: activation for the conditioning path.
    """

    def __init__(self, x_size, zi_size, zout_size, ui_size, fn, fnu):
        super().__init__()

        # Conditioning path update.
        self.lin_u_hat = nn.Linear(ui_size, ui_size, bias=True)

        # Projections of u: additive term, gate on z, gate on x.
        self.lin_u = nn.Linear(ui_size, zout_size, bias=True)
        self.lin_uz = nn.Linear(ui_size, zi_size, bias=True)
        self.lin_ux = nn.Linear(ui_size, x_size, bias=True)

        # Bias-free maps applied to the gated x and z.
        self.lin_x = nn.Linear(x_size, zout_size, bias=False)
        self.lin_z = nn.Linear(zi_size, zout_size, bias=False)

        self.fn = fn
        self.fnu = fnu

    def forward(self, input_):
        """Map ``(x, z, u)`` to ``(x, z', u')``; ``x`` is passed through."""
        x, z, u = input_[0], input_[1], input_[2]

        u_next = self.fnu(self.lin_u_hat(u))

        z_gate = F.relu(self.lin_uz(u))  # non-negative gate on z
        z_term = self.lin_z(z * z_gate)
        x_term = self.lin_x(x * self.lin_ux(u))
        u_term = self.lin_u(u)
        z_next = z_term + x_term + u_term

        if self.fn:
            z_next = self.fn(z_next)

        return (x, z_next, u_next)

    def clamp(self):
        """Clamp the ``lin_z`` weights to be non-negative, in place."""
        self.lin_z.weight.data.clamp_(min=0)

class PositionalEncodingLayer(nn.Module):
    """Sinusoidal feature expansion ``sin(x * scale)``.

    Args:
        L: number of frequency levels; each input feature yields ``2 * L``
            outputs.
        device: device on which the frequency table is first created.

    ``scale`` has shape ``(1, 2 * L)`` and interleaves ``2**k * pi`` with
    ``2**k * pi + pi`` for ``k = 0 .. L-1``.  It is registered as a
    non-persistent buffer, so it follows ``module.to(...)`` / ``.cuda()`` but
    adds no key to ``state_dict``.

    NOTE(review): the second column scales the input by ``2**k * pi + pi``
    rather than adding a phase shift -- confirm that this is intended.
    """

    def __init__(self, L=20, device='cpu'):
        super().__init__()
        base = 2**torch.arange(0, L)*math.pi
        shifted = 2**torch.arange(0, L)*math.pi + math.pi
        scale = torch.stack((base, shifted), 1).view(1, -1).to(device)
        # persistent=False keeps existing checkpoints loadable.
        self.register_buffer('scale', scale, persistent=False)

    def forward(self, x):
        """Expand the last axis of ``x`` from ``k`` to ``k * 2 * L`` features."""
        out_shape = list(x.shape)[:-1] + [-1]
        # (..., k, 1) @ (1, 2L) -> (..., k, 2L), then flatten the last two axes.
        return torch.sin(x.unsqueeze(-1) @ self.scale).view(*out_shape)

class HistoryMLP(MLP):
    """MLP over time ``t``, optionally joined with ``x`` and an interpolated history.

    Args:
        t_history, history: sample points and values handed to
            ``scipy.interpolate.interp1d`` (only when ``use_hist`` is true).
        input_size: base input width used when ``pe`` is falsy.
        hidden_size, layers, out_size, act, bn, bias: forwarded to ``MLP``.
        in_x: adds one to the input width when true.
        use_hist: adds one to the input width and enables the history lookup.
        pe: number of positional-encoding levels; when truthy the base input
            width becomes ``2 * pe`` and ``t`` is encoded with
            ``PositionalEncodingLayer``.

    NOTE(review): when ``use_hist`` is true, ``forward`` concatenates the raw
    ``t`` and never applies the positional encoding, although the input width
    was derived from ``2 * pe`` -- confirm the intended combination.
    """

    def __init__(self, t_history, history, input_size, hidden_size, layers, out_size, act=nn.LeakyReLU(), bn=False, bias=False, in_x=True, use_hist=True, pe=False):
        base_width = 2 * pe if pe else input_size
        extra_width = int(use_hist) + int(in_x)
        super().__init__(base_width + extra_width, hidden_size, layers, out_size, act, bn, bias)

        self.use_hist = use_hist
        if use_hist:
            self.interpolator = interp1d(t_history, history)
        self.pe = PositionalEncodingLayer(L=pe) if pe else None

    def forward(self, x, t):
        """Assemble the network input from ``x`` (may be ``None``) and ``t``."""
        if self.use_hist:
            # History value at time t, looked up through the scipy interpolator.
            hist = torch.tensor(self.interpolator(t)).float()
            pieces = [hist, t]
        elif self.pe is not None:
            pieces = [self.pe(t)]
        else:
            pieces = [t]

        if x is not None:
            pieces.insert(0, x)

        # A single piece is used as-is; otherwise join along the last axis.
        x_in = pieces[0] if len(pieces) == 1 else torch.cat(pieces, -1)
        return super().forward(x_in)

class SplineRegression(torch.nn.Module):
    """Scalar regression on a fixed B-spline basis with learnable weights.

    Args:
        input_range: ``(low, high)`` pair; used to place the knots when
            ``knots`` is an int.
        order: degree ``k`` of the B-spline basis.
        knots: number of evenly spaced knots over ``input_range``, or an
            explicit sequence of knot positions.

    The basis has ``num_knots + order - 1`` functions; a single linear layer
    maps them to one output per input element.
    """

    def __init__(
            self,
            input_range,
            order=3,
            knots=10):
        super().__init__()
        if isinstance(knots, int):
            knots = np.linspace(input_range[0], input_range[1], knots)
        num_basis = len(knots) + order - 1

        # Repeat the boundary knots ``order`` extra times on each side.
        padded_knots = np.hstack([knots[0]*np.ones(order),
                                  knots,
                                  knots[-1]*np.ones(order)])
        # Identity coefficients make the spline return every basis function.
        self.basis_funcs = scipy.interpolate.BSpline(
            padded_knots, np.eye(num_basis), k=order)
        self.linear = torch.nn.Linear(num_basis, 1)

    def forward(self, x):
        """Evaluate the regression element-wise; the output has ``x``'s shape.

        The basis is evaluated by SciPy on the CPU, so no gradient flows back
        to ``x``; only the linear layer is differentiable.
        """
        x_shape = x.shape
        if torch.is_tensor(x):
            # Detach and move to CPU so tensors that require grad or live on
            # another device can be handed to SciPy.
            flat = x.detach().cpu().numpy().reshape(-1)
        else:
            flat = np.asarray(x).reshape(-1)
        basis = self.basis_funcs(flat)
        weight = self.linear.weight
        # Match the linear layer's device and dtype.
        basis = torch.from_numpy(basis).to(device=weight.device, dtype=weight.dtype)
        return self.linear(basis).reshape(x_shape)

