import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from collections import OrderedDict
import math 

class Sine(nn.Module):
    """Sinusoidal activation computing ``sin(alpha * x)``.

    ``alpha`` is stored as a plain attribute (it is not registered as a
    parameter) and is printed once when the module is constructed.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha
        print(self.alpha)

    def forward(self, x):
        """Scale ``x`` by ``alpha`` and return its element-wise sine."""
        scaled = self.alpha * x
        return scaled.sin()

class Swish(nn.Module):
    """Swish-style activation with a learnable slope and a fixed divisor.

    Computes ``x * sigmoid(x * softplus(beta)) / div`` where ``beta`` is a
    learnable one-element parameter initialised to 0.5 and ``div`` is a
    plain (non-learnable) attribute. ``div`` is printed on construction.
    """

    def __init__(self, div=1.1):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor([0.5]))
        self.div = div
        print(self.div)

    def forward(self, x):
        # The in-place ops only touch freshly created temporaries, never ``x``.
        gate = torch.sigmoid_(x * F.softplus(self.beta))
        return (x * gate).div_(self.div)


class MLP(nn.Module):
    """Fully connected network with an activation between consecutive linear layers.

    The network is ``Linear(input, hidden) -> act`` followed by ``depth - 2``
    blocks of ``Linear(hidden, hidden) -> act`` and a final
    ``Linear(hidden, output)`` with no activation after it.

    Args:
        args: configuration object. ``args.init`` is read and, when truthy,
            triggers :meth:`init_weights`. ``args.swish_div`` is optional and
            sets the divisor of every ``Swish`` layer (default 1.1).
        input_size: width of the input features.
        hidden_size: width of every hidden layer.
        output_size: width of the output.
        depth: total number of linear layers requested; values below 2 still
            yield the first and last linear layers.
        activation: one of ``'tanh'``, ``'relu'``, ``'sin'``, ``'silu'``,
            ``'swish'``.

    Raises:
        NotImplementedError: if ``activation`` is not one of the names above.
    """

    def __init__(self, args, input_size, hidden_size, output_size, depth, activation):
        super().__init__()

        if activation == 'tanh':
            self.activation = torch.nn.Tanh
        elif activation == 'relu':
            self.activation = torch.nn.ReLU
        elif activation == 'sin':
            self.activation = Sine
        elif activation == 'silu':
            self.activation = torch.nn.SiLU
        elif activation == 'swish':
            self.activation = Swish
        else:
            raise NotImplementedError(f'unsupported activation: {activation!r}')

        # Swish is the only activation that takes a constructor argument.
        # Previously the configured divisor was read from ``args`` but never
        # forwarded, so ``args.swish_div`` had no effect.
        activation_kwargs = {}
        if activation == 'swish':
            activation_kwargs['div'] = getattr(args, 'swish_div', 1.1)

        layers = [nn.Linear(input_size, hidden_size), self.activation(**activation_kwargs)]
        for _ in range(depth - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(self.activation(**activation_kwargs))
        layers.append(nn.Linear(hidden_size, output_size))

        self.layers = nn.Sequential(*layers)

        if args.init:
            self.init_weights()

    def init_weights(self):
        """Xavier-uniform (gain 1.0) the weights and zero the biases of every linear/conv layer."""
        for m in self.layers.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                print('init_xavier in MLP module')
                nn.init.xavier_uniform_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.layers(x)
       
class CodeBilinear(nn.Module):
    """Additive two-input affine layer.

    Despite the name, the two inputs are combined by a sum of two linear
    maps plus a bias: ``input1 @ B.T + input2 @ A.T + bias``.

    Args:
        in1_features: width of ``input1`` (the features).
        in2_features: width of ``input2`` (the code).
        out_features: width of the output.
        t_scale: factor applied to ``A`` right after its initialisation.
        device, dtype: forwarded to ``torch.empty`` when allocating parameters.
    """

    __constants__ = ['in1_features', 'in2_features', 'out_features']
    in1_features: int
    in2_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in1_features: int, in2_features: int, out_features: int, t_scale=1.0, device=None, dtype=None) -> None:
        super().__init__()
        alloc = dict(device=device, dtype=dtype)

        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features

        # Registration order (A, B, bias) is kept stable on purpose.
        self.A = nn.Parameter(torch.empty(out_features, in2_features, **alloc))
        self.B = nn.Parameter(torch.empty(out_features, in1_features, **alloc))
        self.bias = nn.Parameter(torch.empty(out_features, **alloc))

        self.t_scale = t_scale
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``A``, ``B`` and ``bias``; ``A`` is additionally scaled by ``t_scale``."""
        negative_slope = math.sqrt(5)

        nn.init.kaiming_uniform_(self.A, a=negative_slope)
        self.A.data.mul_(self.t_scale)

        nn.init.kaiming_uniform_(self.B, a=negative_slope)

        # Bias range depends on the feature input width only.
        limit = 1 / math.sqrt(self.in1_features)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        """Return ``W_f * feat + W_c * code + bias`` (``input1`` = feat, ``input2`` = code)."""
        from_feat = F.linear(input1, self.B)
        from_code = F.linear(input2, self.A)
        return from_feat + from_code + self.bias
        
class MFNBase(nn.Module):
    """
    Base class for multiplicative filter networks.
    Adapted from https://github.com/boschresearch/multiplicative-filter-networks

    Subclasses must provide a ``filters`` attribute: an ``nn.ModuleList`` of
    ``n_layers + 1`` filters, each producing ``hidden_size`` outputs.

    This class itself builds ``n_layers + 1`` ``CodeBilinear`` layers (the
    first takes ``in_size`` features, the rest ``hidden_size``) and a final
    ``nn.Linear`` projection to ``out_size``.
    """

    def __init__(self, in_size, hidden_size, code_size, out_size, n_layers, t_scale=1):
        super().__init__()
        self.first = 3

        mixers = [CodeBilinear(in_size, code_size, hidden_size, t_scale)]
        for _ in range(int(n_layers)):
            mixers.append(CodeBilinear(hidden_size, code_size, hidden_size, t_scale))
        self.bilinear = nn.ModuleList(mixers)

        self.output_bilinear = nn.Linear(hidden_size, out_size)
        print(f'MFNBase scale {t_scale}')

    def forward(self, x, code):
        """Default recursion; meant to be overridden by subclasses.

        The first bilinear layer receives a zeroed copy of ``x`` so only the
        code and bias contribute to the initial amplitude.
        """
        hidden = self.filters[0](x) * self.bilinear[0](x * 0., code)

        layer_idx = 1
        n_filters = len(self.filters)
        while layer_idx < n_filters:
            hidden = self.filters[layer_idx](x) * self.bilinear[layer_idx](hidden, code)
            layer_idx += 1

        return self.output_bilinear(hidden)

class FourierLayerPhase(nn.Module):
    """
    Sine filter as used in FourierNet, with an extra code-dependent phase term.
    Adapted from https://github.com/boschresearch/multiplicative-filter-networks

    Output is ``sin(x @ weight.T + code @ weight_2.T + bias)``.

    Args:
        in_features: width of ``x``.
        in_features_2: width of ``code``.
        out_features: width of the output.
        weight_scale: factor applied to ``weight`` after initialisation.
        t_scale: factor applied to ``weight_2`` after initialisation.
    """

    def __init__(self, in_features, in_features_2, out_features, weight_scale, t_scale=1.0):
        super().__init__()
        self.out_features = out_features
        self.in_features = in_features

        # Registration order (weight, weight_2, bias) is kept stable.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_2 = nn.Parameter(torch.empty(out_features, in_features_2))
        self.bias = nn.Parameter(torch.empty(out_features))

        self.weight_scale = weight_scale
        self.t_scale = t_scale
        self.reset_parameters()
        print(f'fourier layer phase filter scale {t_scale}')

    def reset_parameters(self) -> None:
        """Re-draw both weight matrices and the bias, then apply the two scales."""
        slope = math.sqrt(5)
        for matrix in (self.weight, self.weight_2):
            nn.init.kaiming_uniform_(matrix, a=slope)

        # Phase offsets are drawn uniformly from [-pi, pi).
        self.bias.data.uniform_(-np.pi, np.pi)

        self.weight.data.mul_(self.weight_scale)
        self.weight_2.data.mul_(self.t_scale)

    def forward(self, x, code):
        """Return ``sin(W_x * x + W_c * c + bias)``; the code shifts the phase."""
        argument = F.linear(x, self.weight) + F.linear(code, self.weight_2)
        return torch.sin(argument + self.bias)

class FourierNetBoth(MFNBase):
    """Multiplicative filter network whose code modulates both phase and amplitude.

    The sine filters (``self.filters``) receive the code as a phase shift,
    while the ``CodeBilinear`` layers built by the base class receive it as
    an additive amplitude term. Extra keyword arguments are accepted and
    ignored.
    """

    def __init__(self, in_size, hidden_size, code_size, out_size, n_layers=3, input_scale=256.0, t_scale=1, filter_scale=1, **kwargs):
        super().__init__(in_size, hidden_size, code_size, out_size, n_layers, t_scale)

        filter_bank = []
        for _ in range(n_layers + 1):
            # The input scale is shared out across the n_layers + 1 filters.
            per_filter_scale = input_scale / np.sqrt(n_layers + 1)
            filter_bank.append(
                FourierLayerPhase(in_size, code_size, hidden_size, per_filter_scale, filter_scale)
            )
        self.filters = nn.ModuleList(filter_bank)

    def forward(self, x, code):
        """Run the multiplicative recursion and project to the output size."""
        # The first amplitude layer sees a zeroed copy of x, so only the code
        # (and bias) contribute to it.
        hidden = self.filters[0](x, code) * self.bilinear[0](x * 0., code)

        layer_idx = 1
        n_filters = len(self.filters)
        while layer_idx < n_filters:
            hidden = self.filters[layer_idx](x, code) * self.bilinear[layer_idx](hidden, code)
            layer_idx += 1

        return self.output_bilinear(hidden)

class Derivative(nn.Module):
    """Right-hand side of the latent ODE, implemented by an ``MLP``.

    The state is a pair ``(x, theta)``. The MLP receives their concatenation
    along the last dimension; ``theta`` itself is given a zero derivative.
    The time argument ``t`` is accepted but not used.
    """

    def __init__(self, args, input_size, hidden_size, output_size, depth, activation):
        super().__init__()
        self.layers = MLP(args, input_size, hidden_size, output_size, depth, activation)

    def get_grad(self, t, x_theta):
        """Return only the derivative of ``x`` for the state ``(x, theta)``."""
        state, params = x_theta
        joint = torch.cat((state, params), dim=-1)
        return self.layers(joint)

    def forward(self, t, x_theta):
        """Return ``(dx/dt, zeros_like(theta))`` for the state ``(x, theta)``."""
        state, params = x_theta
        joint = torch.cat((state, params), dim=-1)
        d_state = self.layers(joint)
        d_params = torch.zeros_like(params)
        return (d_state, d_params)

class ODEBlock(nn.Module):
    """Integrates the latent state with a caller-supplied ODE solver.

    Args:
        args: configuration object providing ``ode_code_size``, ``param_dim``,
            ``ode_hidden_size``, ``ode_depth``, ``ode_activation`` and
            ``ode_method``; ``step_size`` is optional (treated as 0 if absent).
        odeint_func: solver callable invoked as
            ``odeint_func(func, y0=..., t=..., method=...[, options=...])``.
    """

    def __init__(self, args, odeint_func):
        super().__init__()

        code_dim = args.ode_code_size
        self.odefunc = Derivative(
            args,
            code_dim + args.param_dim,
            args.ode_hidden_size,
            code_dim,
            args.ode_depth,
            args.ode_activation,
        )
        self.method = args.ode_method
        self.odeint_func = odeint_func
        print(self.odeint_func, self.method)
        self.step_size = getattr(args, 'step_size', 0)

    def forward(self, t, z0, theta):
        """Solve from ``(z0, theta)`` over times ``t`` and return the first solver output."""
        # A step size is only passed to the solver when it is positive.
        extra = {'options': dict(step_size=self.step_size)} if self.step_size > 0 else {}
        solution = self.odeint_func(self.odefunc, y0=(z0, theta), t=t, method=self.method, **extra)
        # The solver returns one entry per state component: (z, theta).
        return solution[0]

class FISRwithODE(nn.Module):
    """Couples a latent ODE (dynamics) with a Fourier-feature decoder.

    Args:
        args: configuration object; the fields read here are
            ``spatial_size``, ``isr_hidden_size``, ``ode_code_size``,
            ``isr_output_size``, ``isr_depth``, ``fourier_scale``,
            ``fourier_scale_t``, ``fourier_scale_filter``, ``num_init_cond``,
            ``fix_points_t`` and ``batch_theta`` (plus whatever ``ODEBlock``
            reads).
        odeint_func: ODE solver callable handed to ``ODEBlock``.
    """

    def __init__(self, args, odeint_func):
        super().__init__()

        # Dynamics model.
        self.ode_block = ODEBlock(args, odeint_func)

        # Decoder.
        self.isr_block = FourierNetBoth(
            in_size=args.spatial_size,
            hidden_size=args.isr_hidden_size,
            code_size=args.ode_code_size,
            out_size=args.isr_output_size,
            n_layers=args.isr_depth - 2,
            input_scale=args.fourier_scale,
            t_scale=args.fourier_scale_t,
            filter_scale=args.fourier_scale_filter,
        )

        # Learnable embeddings for the initial conditions in the training set.
        initial_codes = torch.zeros(args.num_init_cond, args.ode_code_size).float()
        self.z0 = nn.Parameter(initial_codes)

        # Learnable embeddings for predicted solutions, used for consistency
        # regularization.
        consistency_codes = torch.zeros(
            args.num_init_cond, args.fix_points_t, args.batch_theta, args.ode_code_size
        )
        self.z_consistency = nn.Parameter(consistency_codes)

    def get_dyn(self, t, z_in, theta):
        """Integrate the dynamics from a detached ``z_in`` over the flattened times ``t``."""
        flat_t = t.reshape(-1)
        start = z_in.detach()
        # (batch_t, batch_theta, code_size)
        return self.ode_block(flat_t, start, theta)

    def get_dyn_grad(self, t, z, theta):
        """Evaluate the ODE right-hand side for ``z`` at the state ``(z, theta)``."""
        rhs = self.ode_block.odefunc
        return rhs.get_grad(t, (z, theta))

    def forward(self, x, z):
        """Decode coordinates ``x`` conditioned on the code ``z``."""
        return self.isr_block(x, z)
    
