import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict
import pandas as pd
import pdb
import math
import torch.nn.functional as F
from typing import List, Literal, Tuple, Dict
from enum import Enum

def sine_init(m, scale=1.0):
    """Re-draw ``m.weight`` in place from a symmetric uniform distribution.

    The half-width of the interval is ``scale * sqrt(6 / fan_in) / 30``, where
    ``fan_in`` is the size of the last dimension of ``m.weight``. Objects
    without a ``weight`` attribute are left untouched, which makes the function
    usable with ``nn.Module.apply``.

    Args:
        m: Object (typically an ``nn.Module``) that may expose ``weight``.
        scale: Multiplier applied to the half-width of the interval.
    """
    if not hasattr(m, 'weight'):
        return

    fan_in = m.weight.size(-1)
    half_width = scale * np.sqrt(6 / fan_in) / 30
    # In-place sampling must not be recorded by autograd.
    with torch.no_grad():
        m.weight.uniform_(-half_width, half_width)


def first_layer_sine_init(m, scale=1.0):
    """Re-draw ``m.weight`` in place from ``U(-scale / fan_in, scale / fan_in)``.

    ``fan_in`` is the size of the last dimension of ``m.weight``. Objects
    without a ``weight`` attribute are skipped.

    Args:
        m: Object (typically an ``nn.Module``) that may expose ``weight``.
        scale: Numerator of the interval half-width.
    """
    if not hasattr(m, 'weight'):
        return

    fan_in = m.weight.size(-1)
    half_width = scale / fan_in
    # In-place sampling must not be recorded by autograd.
    with torch.no_grad():
        m.weight.uniform_(-half_width, half_width)


def output_layer_sine_init(m):
    """Re-draw ``m.weight`` in place from ``U(-b, b)`` with ``b = sqrt(6 / fan_in) / 30``.

    ``fan_in`` is the size of the last dimension of ``m.weight``. Objects
    without a ``weight`` attribute are skipped.

    Args:
        m: Object (typically an ``nn.Module``) that may expose ``weight``.
    """
    if not hasattr(m, 'weight'):
        return

    fan_in = m.weight.size(-1)
    half_width = np.sqrt(6 / fan_in) / 30
    # In-place sampling must not be recorded by autograd.
    with torch.no_grad():
        m.weight.uniform_(-half_width, half_width)


def init_weights_normal(m):
    """Apply Kaiming-normal initialisation to the weight of a plain ``nn.Linear``.

    Only objects whose type is exactly ``nn.Linear`` are touched; subclasses
    and every other module type are left unchanged.

    Args:
        m: Module to (possibly) re-initialise in place.
    """
    is_plain_linear = type(m) is nn.Linear
    if is_plain_linear and hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight)

class Sine(nn.Module):
    """Parameter-free activation computing ``sin(30 * input)`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return the sine of ``input`` after scaling it by the fixed factor 30."""
        scaled = 30 * input
        return torch.sin(scaled)


class Functa(nn.Module):
    """Sine-activated MLP whose hidden pre-activations can be modulated.

    ``self.net`` is an ``nn.Sequential`` of ``num_hidden_layers + 2`` linear
    layers: an input layer, ``num_hidden_layers`` hidden layers and an output
    layer. Every layer except the last is followed by a sine whose argument
    depends on ``mod_type`` (see ``forward``).

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_hidden_layers: Number of hidden-to-hidden linear layers.
        hidden_features: Width of every hidden layer.
        mod_type: One of ``"scale"``, ``"shift"``, ``"film"``, ``"spatial"``;
            any other value selects the unmodulated ``sin(w0 * x)`` path.
        mod_dim: Only read for ``"film"``, where ``mod_dim // 2`` is the index
            at which ``mod`` is split into a scale part and a shift part.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features, 
                mod_type, mod_dim):
        super().__init__()

        self.w0 = 30.0
        self.mod_type = mod_type
        self.mod_dim = mod_dim

        # Layers are created and re-initialised one after another so that the
        # sequence of random draws is input, hidden..., output.
        layers = [self._uniform_linear(in_features, hidden_features, 1 / in_features)]

        hidden_bound = np.sqrt(6 / hidden_features) / self.w0
        layers.extend(
            self._uniform_linear(hidden_features, hidden_features, hidden_bound)
            for _ in range(num_hidden_layers)
        )
        layers.append(self._uniform_linear(hidden_features, out_features, hidden_bound))

        self.net = nn.Sequential(*layers)

    @staticmethod
    def _uniform_linear(fan_in, fan_out, bound):
        """Build ``nn.Linear(fan_in, fan_out)`` with weights drawn from ``U(-bound, bound)``."""
        linear = nn.Linear(fan_in, fan_out)
        with torch.no_grad():
            linear.weight.uniform_(-bound, bound)
        return linear

    def forward(self, x, mod):
        """Run the network, modulating each hidden pre-activation with ``mod``.

        For hidden layer ``i`` with pre-activation ``h`` the sine argument is:

        * ``"scale"``:   ``(mod[:, i] + 1) * h``
        * ``"shift"``:   ``h + mod[:, i]``
        * ``"film"``:    ``(mod[:k, i] + 1) * h + mod[k:, i]`` with ``k = mod_dim // 2``
        * ``"spatial"``: ``h + mod[:, :, i]``
        * otherwise:     ``w0 * h`` (``mod`` is not read)

        The slices must broadcast against ``h``; presumably the first axis of
        ``mod`` matches the hidden width -- TODO confirm against callers.

        Args:
            x: Input tensor whose last dimension is ``in_features``.
            mod: Modulation tensor indexed as shown above.

        Returns:
            Output of the final linear layer (no activation applied).
        """
        kind = self.mod_type
        hidden_layers = self.net[:-1]

        for idx, linear in enumerate(hidden_layers):
            pre = linear(x)

            if kind == "scale":
                arg = (mod[:, idx] + 1) * pre
            elif kind == "shift":
                arg = pre + mod[:, idx]
            elif kind == "film":
                # mod_dim is only touched here, so it may be anything for other types.
                half = self.mod_dim // 2
                arg = (mod[:half, idx] + 1) * pre + mod[half:, idx]
            elif kind == "spatial":
                arg = pre + mod[:, :, idx]
            else:
                arg = self.w0 * pre

            x = torch.sin(arg)

        return self.net[-1](x)






class INR_GFM(nn.Module):
    """Coordinate network built from an input layer, ``GFM_layer`` blocks and an output layer.

    Structure (in registration order): ``input_layer`` (linear + activation),
    ``net`` (``num_hidden_layers`` ``GFM_layer`` modules of width
    ``hidden_features``) and ``output_layer`` (linear, optionally followed by
    the activation). The same activation instance is shared by all of them.

    Args:
        in_features: Size of the last dimension of ``coords``.
        out_features: Size of the last dimension of the output.
        num_hidden_layers: Number of ``GFM_layer`` blocks.
        hidden_features: Width of the hidden representation.
        outermost_linear: If true the output layer has no activation and is
            initialised with ``output_layer_sine_init``.
        nonlinearity: Key into the activation table; only ``'sine'`` exists.
        scale: Forwarded to ``sine_init`` / ``first_layer_sine_init``.
        weight_init: Stored as ``self.weight_init`` when given (otherwise the
            activation's default initialiser is stored); it is not applied to
            any layer inside this constructor.
        n_fourier_bases: Forwarded to every ``GFM_layer``.
        high_freq: Forwarded as ``high_freq_num``.
        low_freq: Forwarded as ``low_freq_num``.
        phi_dim: Forwarded as ``phi_num``.
        oneway: Selects how ``mod`` is sliced per layer in ``forward``.

    Raises:
        KeyError: If ``nonlinearity`` is not a known activation name.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=True, nonlinearity='sine', scale=1, weight_init=None,
                 n_fourier_bases=49, high_freq=50, low_freq=50, phi_dim=100, oneway=False):
        super().__init__()

        self.scale = scale
        self.n_fourier_bases = n_fourier_bases
        self.oneway = oneway

        # name -> (activation module, default initialiser, input-layer initialiser)
        activation_table = {
            'sine': (Sine(), lambda m: sine_init(m, scale), lambda m: first_layer_sine_init(m, scale)), 
        }
        activation, default_init, input_init = activation_table[nonlinearity]

        self.weight_init = default_init if weight_init is None else weight_init

        self.input_layer = nn.Sequential(
            nn.Linear(in_features, hidden_features), activation
        )

        hidden_blocks = [
            GFM_layer(
                n_input_dims=hidden_features,
                n_output_dims=hidden_features,
                n_fourier_bases=n_fourier_bases,
                is_reparam=True,
                activation=activation,
                high_freq_num=high_freq,
                low_freq_num=low_freq,
                phi_num=phi_dim,
            )
            for _ in range(num_hidden_layers)
        ]
        self.net = nn.ModuleList(hidden_blocks)
        self.net_len = len(self.net)

        output_modules = [nn.Linear(hidden_features, out_features)]
        if not outermost_linear:
            output_modules.append(activation)
        self.output_layer = nn.Sequential(*output_modules)

        # The default initialiser runs first and is then overwritten by the
        # input-layer initialiser; both calls are kept so the random stream
        # consumed during construction is unchanged.
        self.input_layer.apply(default_init)
        if input_init is not None:
            self.input_layer.apply(input_init)

        if outermost_linear:
            self.output_layer.apply(output_layer_sine_init)

    def forward(self, coords, mod):
        """Evaluate the network at ``coords`` with per-layer modulations ``mod``.

        Hidden block ``i`` receives ``mod[:, i]`` when ``self.oneway`` is true
        and ``mod[:, :, i]`` otherwise, so the last axis of ``mod`` indexes the
        hidden blocks.

        Args:
            coords: Input tensor whose last dimension is ``in_features``.
            mod: Modulation tensor, 2-D if ``oneway`` else 3-D.

        Returns:
            Output of ``output_layer``.
        """
        hidden = self.input_layer(coords)

        for idx, block in enumerate(self.net):
            layer_mod = mod[:, idx] if self.oneway else mod[:, :, idx]
            hidden = block(hidden, layer_mod)

        return self.output_layer(hidden)


class GFM_layer(nn.Module):
    """Linear layer whose weight is a learned mixture of fixed cosine bases.

    The effective weight is ``lamb @ bases`` with

    * ``bases``: frozen parameter of shape
      ``((low_freq_num + high_freq_num) * phi_num, n_fourier_bases)`` holding
      ``alpha * cos(freq * x + phi)`` sampled at ``n_fourier_bases`` points;
    * ``lamb``: trainable parameter of shape ``(n_output_dims, n_bases)``.

    The resulting ``(n_output_dims, n_fourier_bases)`` matrix is applied to the
    last dimension of the input, so ``n_fourier_bases`` has to equal the input
    width for ``forward`` to work.

    Args:
        n_input_dims: Width of the input.
        n_output_dims: Width of the output.
        n_fourier_bases: Number of sample points per basis; defaults to
            ``n_input_dims`` when ``None``.
        bias: Whether to add a (zero-initialised) trainable bias.
        is_reparam: Stored on the instance; not used by this class.
        is_first: Accepted for interface compatibility; not used.
        activation: Callable applied to the output, or ``None`` for identity.
        high_freq_num: Number of integer frequencies ``1..high_freq_num``.
        low_freq_num: Number of fractional frequencies ``k / low_freq_num``.
        phi_num: Number of equally spaced phases in ``[0, 2*pi)``.

    Raises:
        ValueError: If both ``high_freq_num`` and ``low_freq_num`` are zero.
    """

    def __init__(self, n_input_dims, n_output_dims, n_fourier_bases=None,
                 bias=True, is_reparam=False, is_first=False,
                 activation=None, high_freq_num=50, low_freq_num=50, phi_num=100):
        super().__init__()
        self.is_reparam = is_reparam
        self.n_input_dims = n_input_dims
        self.n_output_dims = n_output_dims
        # The generated weight is (n_output_dims, n_fourier_bases) and multiplies
        # the input's last dimension, so the input width is the only usable
        # default (``None`` previously failed inside ``np.linspace``).
        self.n_fourier_bases = n_input_dims if n_fourier_bases is None else n_fourier_bases
        self.activation = activation

        self.high_freq_num = high_freq_num
        self.low_freq_num = low_freq_num
        self.phi_num = phi_num
        self.alpha = 0.01
        self.bases = self.init_bases()
        self.lamb = self.init_lamb()

        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.empty(n_output_dims))
            self.init_bias()

    def init_bases(self):
        """Build the frozen cosine dictionary.

        Rows are ordered low frequencies first, then high frequencies; within
        one frequency the rows iterate over the phases.

        Returns:
            ``nn.Parameter`` with ``requires_grad=False``.

        Raises:
            ValueError: If there is no frequency at all.
        """
        phi_set = np.array([2 * math.pi * i / self.phi_num for i in range(self.phi_num)])
        high_freq = np.array([i + 1 for i in range(self.high_freq_num)], dtype=float)
        low_freq = np.array([(i + 1) / self.low_freq_num for i in range(self.low_freq_num)])

        freqs = np.concatenate([low_freq, high_freq])
        if freqs.size == 0:
            raise ValueError(
                "GFM_layer needs at least one frequency: "
                "high_freq_num and low_freq_num are both 0"
            )

        # The sampling window spans one period of the slowest available frequency.
        slowest = low_freq[0] if low_freq.size else high_freq.min()
        T_max = 2 * math.pi / slowest
        points = np.linspace(-T_max / 2, T_max / 2, self.n_fourier_bases)

        # Axes: (frequency, phase, sample point); flattening the first two
        # reproduces the row order "for freq: for phi".
        angles = freqs[:, None, None] * points[None, None, :] + phi_set[None, :, None]
        table = np.cos(angles).reshape(freqs.size * phi_set.size, points.size)

        bases = torch.tensor(table, dtype=torch.get_default_dtype())
        bases = self.alpha * bases
        return nn.Parameter(bases, requires_grad=False)

    def init_lamb(self):
        """Draw the mixing coefficients.

        Column ``i`` is sampled from ``U(-b_i, b_i)`` with
        ``b_i = sqrt(6 / m) / ||bases[i]||_2 / 30`` where ``m`` is the number
        of bases.

        Returns:
            Trainable ``nn.Parameter`` of shape ``(n_output_dims, m)``.
        """
        m = (self.low_freq_num + self.high_freq_num) * self.phi_num
        lamb = torch.zeros(self.n_output_dims, m)

        with torch.no_grad():
            for i in range(m):
                row_norm = torch.norm(self.bases[i, :], p=2)
                # Converted to a plain float so ``uniform_`` gets scalar bounds.
                bound = float(np.sqrt(6 / m) / row_norm / 30)
                lamb[:, i].uniform_(-bound, bound)

        return nn.Parameter(lamb, requires_grad=True)

    def init_bias(self):
        """Reset the bias to zeros in place."""
        with torch.no_grad():
            nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor, fourier_mod: torch.Tensor = None):
        """Apply the layer, optionally modulated by ``fourier_mod``.

        ``fourier_mod`` is split along its last axis at ``n_fourier_bases``:
        the first part is transposed and added to ``lamb`` before the weight is
        generated, the remainder (divided by 30) is added to the output. A 2-D
        ``fourier_mod`` is used as is; any other rank is treated as a single
        vector and given a leading axis of size 1. The transposed first part
        must broadcast against ``lamb``.

        Args:
            input: Tensor whose last dimension equals ``n_fourier_bases``.
            fourier_mod: Modulation tensor, or ``None`` for no modulation.

        Returns:
            ``activation(out)`` or ``out`` itself when no activation was given.
        """
        weight_lambda = self.lamb
        weights_mod_bias = None

        if fourier_mod is not None:
            if fourier_mod.dim() == 2:
                lambda_mod_bias = fourier_mod[:, :self.n_fourier_bases]
                weights_mod_bias = fourier_mod[:, self.n_fourier_bases:]
            else:
                lambda_mod_bias = fourier_mod[:self.n_fourier_bases].unsqueeze(0)
                weights_mod_bias = fourier_mod[self.n_fourier_bases:].unsqueeze(0)
            weight_lambda = weight_lambda + lambda_mod_bias.T

        weight = torch.matmul(weight_lambda, self.bases)

        out = torch.matmul(input, weight.transpose(0, 1))
        # ``bias=False`` leaves ``self.bias`` as None; adding it would raise.
        if self.bias is not None:
            out = out + self.bias
        if weights_mod_bias is not None:
            out = out + weights_mod_bias / 30

        # ``activation=None`` (the constructor default) means identity.
        if self.activation is None:
            return out
        return self.activation(out)

