
from functools import partial
from torch.nn import init
from torch.nn.parameter import Parameter
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from functools import reduce
import operator


from rational_kat_cu.kat_rational import KAT_Group

 
class KAN(nn.Module):
    """Two-layer feed-forward block with a ``KAT_Group`` activation before each linear layer.

    The data flow is ``act1 -> drop1 -> fc1 -> act2 -> drop2 -> fc2``; note that
    each activation is applied to the *input* of its linear layer rather than to
    its output.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Output width of ``fc1``; falls back to ``in_features``
            when falsy.
        out_features: Output width of ``fc2``; falls back to ``in_features``
            when falsy.
        act_cfg: Mapping whose ``'act_init'`` entry holds the ``mode`` passed to
            the first and second ``KAT_Group`` (indices 0 and 1). ``None``
            selects ``dict(type="KAT", act_init=["identity", "gelu"])``.
        bias: Whether both linear layers carry a bias term.
        drop: Dropout probability used by both dropout layers.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_cfg=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        # Build the default per instance: a dict literal in the signature would
        # be one shared mutable object across every call.
        if act_cfg is None:
            act_cfg = dict(type="KAT", act_init=["identity", "gelu"])
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        act_modes = act_cfg['act_init']

        # Submodules are created in the same order as before so that
        # state-dict layout and seeded initialisation are unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act1 = KAT_Group(mode=act_modes[0])
        self.drop1 = nn.Dropout(drop)
        self.act2 = KAT_Group(mode=act_modes[1])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply activation, dropout and linear projection twice in sequence."""
        x = self.fc1(self.drop1(self.act1(x)))
        x = self.fc2(self.drop2(self.act2(x)))
        return x
 

 
 
 
# Lookup table from activation name to a ready-made activation module.
# The instances are created once at import time and shared by every user
# of the table.
ACTIVATION = dict(
    Sigmoid=nn.Sigmoid(),
    Tanh=nn.Tanh(),
    ReLU=nn.ReLU(),
    LeakyReLU=nn.LeakyReLU(0.1),
    ELU=nn.ELU(),
    GELU=nn.GELU(),
)
 
 
 
class MLP(torch.nn.Module):
    """Fully connected network whose hidden layers are residual.

    An input projection followed by the activation produces the hidden state;
    each of the ``n_layer`` hidden layers then adds ``act(linear(state))`` to
    that state, and a final linear layer maps it to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden state.
        output_dim: Size of the last dimension of the output.
        n_layer: Number of residual hidden layers (0 means none).
        act: Callable activation applied after the input projection and inside
            every residual layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layer, act):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layer = n_layer
        self.act = act

        self.input = torch.nn.Linear(input_dim, hidden_dim)
        residual_layers = []
        for _ in range(n_layer):
            residual_layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
        self.hidden = torch.nn.ModuleList(residual_layers)
        self.output = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the network output for ``x``."""
        state = self.act(self.input(x))
        for layer in self.hidden:
            # Residual update: keep the running state and add the new branch.
            state = state + self.act(layer(state))
        return self.output(state)
 

    
#---------------------------------------------------------
# Registry of tensor initialization functions, keyed by short name
#---------------------------------------------------------
# Each entry wraps one in-place initializer from ``torch.nn.init``; call it as
# ``init_s[name](tensor)`` and it returns the tensor it initialized.
init_s = dict(
    uniform=partial(init.uniform_),
    normal=partial(init.normal_),
    ones=partial(init.ones_),
    zeros=partial(init.zeros_),
    xavier_u=partial(init.xavier_uniform_),
    xavier_n=partial(init.xavier_normal_),
    kaiming_u=partial(init.kaiming_uniform_),
    kaiming_n=partial(init.kaiming_normal_),
    trunc_n=partial(init.trunc_normal_),
    orthogonal=partial(init.orthogonal_),
)
 
class KHINRNet(nn.Module):
    """Coordinate network that multiplies per-block features built from Gabor filters.

    ``forward`` encodes the observations ``y`` into ``n_mode`` tokens, passes
    them through ``n_block`` KAN blocks, and for every block combines
    (a) a Gabor filter of the raw query coordinates, (b) the block's pooled
    token modulated by the Gabor-encoded query and (c) a learned latent row
    selected by ``idx``. The per-block results are multiplied element-wise and
    decoded by ``out_mlp``.

    Args:
        n_block: Number of blocks (Gabor filters, weight sets and KAN blocks).
            Must be at least 1 for ``forward`` to work, because the product
            over blocks has no initial value.
        n_mode: Output width of ``mode_mlp``, i.e. the number of tokens.
        n_dim: Hidden feature width used throughout.
        n_head: Stored on the instance; not otherwise used in this class.
        n_layer: Number of residual layers in ``branch_mlp`` and ``out_mlp``.
        x_dim: Number of coordinate channels.
        y1_dim: Number of extra channels that ``y`` carries after its first
            ``x_dim`` channels.
        y2_dim: Number of output channels.
        attn: Accepted for interface compatibility; unused in this class.
        act: Key into ``ACTIVATION`` selecting the MLP activation.
        data: Dataset key selecting the number of rows of the latent table;
            one of ``"ssh"``, ``"chl"``, ``"gst"``, ``"sst"``.

    Raises:
        ValueError: If ``data`` is not one of the supported dataset keys.
        KeyError: If ``act`` is not a key of ``ACTIVATION``.
    """

    # Number of rows of the learnable latent table for each supported dataset.
    _LATENT_ROWS = {"ssh": 365, "chl": 579, "gst": 1024, "sst": 360}

    def __init__(self, n_block, n_mode, n_dim, n_head, n_layer, x_dim, y1_dim, y2_dim, attn, act, data):
        super().__init__()

        # Fail fast with an exception instead of printing and terminating the
        # interpreter from inside a constructor.
        try:
            n_latent_rows = self._LATENT_ROWS[data]
        except (KeyError, TypeError):
            raise ValueError(
                f"No {data} data found for model init; "
                f"expected one of {sorted(self._LATENT_ROWS)}"
            ) from None

        self.n_block = n_block
        self.n_mode = n_mode
        self.n_dim = n_dim
        self.n_head = n_head
        self.n_layer = n_layer

        self.act = ACTIVATION[act]

        self.x_dim = x_dim
        self.y1_dim = y1_dim
        self.y2_dim = y2_dim

        # Shared Gabor encoder for coordinates (applied to both the observation
        # coordinates and the query coordinates in ``forward``).
        self.trunk_mlp = nn.Sequential(
            GaborLayer(self.x_dim, self.n_dim, 64 / np.sqrt(6), 1 / np.sqrt(6), 1)
        )

        # One Gabor filter of the raw query coordinates per block.
        # (Original author's experiment note on this line: "idea 1.755".)
        self.filters = nn.ModuleList(
            [
                GaborLayer(self.x_dim, self.n_dim, self.n_dim / np.sqrt(n_block + 1), 1 / (n_block + 1), 1)
                for _ in range(n_block)
            ]
        )

        # Per-block square weights for the filter, token and latent terms,
        # plus a per-block bias vector.
        self.Weight_g = nn.ParameterList()
        self.Weight_z = nn.ParameterList()
        self.Weight_lat = nn.ParameterList()
        self.bias = nn.ParameterList()

        bias_bound = 1 / math.sqrt(2)
        for _ in range(self.n_block):
            weight_g = nn.Parameter(torch.empty(self.n_dim, self.n_dim))
            weight_z = nn.Parameter(torch.empty(self.n_dim, self.n_dim))
            block_bias = nn.Parameter(torch.empty(self.n_dim))
            weight_lat = nn.Parameter(torch.empty(self.n_dim, self.n_dim))

            # Initialisation order (g, z, lat, then bias) is kept so that a
            # fixed seed reproduces the same parameters as before.
            for weight in (weight_g, weight_z, weight_lat):
                init.kaiming_uniform_(weight, a=math.sqrt(5))
            init.uniform_(block_bias, -bias_bound, bias_bound)

            self.Weight_g.append(weight_g)
            self.Weight_z.append(weight_z)
            self.bias.append(block_bias)
            self.Weight_lat.append(weight_lat)

        self.branch_mlp = MLP(self.x_dim + self.y1_dim, self.n_dim, self.n_dim, self.n_layer, self.act)
        self.out_mlp = MLP(self.n_dim, self.n_dim, self.y2_dim, self.n_layer, self.act)

        self.mode_mlp = torch.nn.Linear(self.n_dim, self.n_mode)

        self.Wv = torch.nn.Linear(self.n_dim, self.n_dim)
        self.ln = torch.nn.LayerNorm(self.n_dim)

        # Held in a Sequential, but ``forward`` iterates the blocks one by one
        # to collect every intermediate output.
        self.attn_blocks = torch.nn.Sequential(
            *[KAN(self.n_dim, self.n_dim, self.n_dim) for _ in range(0, self.n_block)]
        )

        # Learnable latent table, one row per index, initialised to zero.
        self.latents = nn.Parameter(torch.zeros(n_latent_rows, n_dim, dtype=torch.float32))

    def _init_weights(self, module):
        """Re-initialise ``module`` if it is a Linear, Embedding or LayerNorm.

        Linear/Embedding weights are drawn from N(0, 0.0002^2) and Linear
        biases are zeroed; LayerNorm gets unit weight and zero bias. This
        helper is not invoked anywhere in this class.
        """
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.0002)
            if isinstance(module, torch.nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def forward(self, query, y, idx):
        """Evaluate the network at ``query`` conditioned on ``y`` and ``idx``.

        Args:
            query: Query coordinates with ``x_dim`` channels in the last
                dimension (presumably ``(batch, n_query, x_dim)``; confirm
                against the caller).
            y: Observations with ``x_dim + y1_dim`` channels in the last
                dimension; the first ``x_dim`` channels are treated as
                coordinates. The einsum below requires it to be 3-D.
            idx: Index (or indices) into the rows of ``self.latents``; the
                result must be 2-D, one latent row per batch element.

        Returns:
            Output of ``out_mlp``, with ``y2_dim`` channels in the last
            dimension.

        Raises:
            TypeError: If the model was built with ``n_block == 0`` (the
                product over an empty list of block outputs is undefined).
        """
        obs_coord_feat = self.trunk_mlp(y[..., :self.x_dim])

        coords = query
        query_feat = self.trunk_mlp(query)

        # Softmax over axis 1 (the axis contracted by the einsum below).
        score_encode = torch.softmax(self.mode_mlp(obs_coord_feat), dim=1)

        obs_feat = self.branch_mlp(y)
        values = self.Wv(self.ln(obs_feat))

        # Weighted pooling of the values into one token per mode.
        tokens = torch.einsum("bij,bic->bjc", score_encode, values)

        # Run the KAN blocks sequentially, keeping every intermediate output.
        block_tokens = []
        for block in self.attn_blocks:
            tokens = block(tokens)
            block_tokens.append(tokens)

        # Average each block's tokens and modulate the encoded query with it.
        modulated = [
            block_out.mean(dim=1).unsqueeze(1) * query_feat
            for block_out in block_tokens
        ]

        latent = self.latents[idx, :]

        outputs = []
        for i, z in enumerate(modulated):
            linear_coord = self.filters[i](coords) @ self.Weight_g[i]
            linear_latent = z @ self.Weight_z[i]

            # latent @ Weight_lat[i].T, then broadcast across axis 1.
            linear_latent_time = torch.einsum('bj,oj->bo', latent, self.Weight_lat[i])
            linear_latent_time = linear_latent_time.unsqueeze(1)

            outputs.append(linear_coord + linear_latent + self.bias[i] + linear_latent_time)

        # Multiplicative combination of all block outputs.
        final_output = reduce(operator.mul, outputs)

        return self.out_mlp(final_output)
 
 
# Gabor layer

class GaborLayer(nn.Module):
    """Gabor-style filter: a sinusoid of a linear map times a Gaussian envelope.

    For input ``x`` the layer returns
    ``sin(1.5 * param * linear(x)) * exp(-0.5 * gamma * ||x - mu||^2)``,
    with one centre ``mu`` row, one ``gamma`` and one ``param`` per output
    feature.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Number of filters (size of the last output dimension).
        weight_scale: Factor applied to the initial linear weights, together
            with ``sqrt(gamma)`` of the corresponding output feature.
        alpha: Concentration of the Gamma distribution ``gamma`` is drawn from.
        beta: Rate of the Gamma distribution ``gamma`` is drawn from.
    """

    def __init__(self, in_features, out_features, weight_scale, alpha=1.0, beta=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

        # Envelope centres, drawn uniformly from [-1, 1).
        centres = 2 * torch.rand(out_features, in_features) - 1
        self.mu = nn.Parameter(centres)

        # Envelope sharpness, one Gamma(alpha, beta) sample per output feature.
        gamma_prior = torch.distributions.gamma.Gamma(alpha, beta)
        self.gamma = nn.Parameter(gamma_prior.sample((out_features,)))

        # Rescale the default linear init row-wise and redraw the phases.
        self.linear.weight.data *= weight_scale * torch.sqrt(self.gamma[:, None])
        self.linear.bias.data.uniform_(-np.pi, np.pi)

        # Learnable per-feature multiplier inside the sinusoid.
        self.param = nn.Parameter(torch.rand(out_features))

    def forward(self, x):
        """Return the filter response for ``x`` (last dimension ``in_features``)."""
        # Squared Euclidean distance to every centre, expanded as
        # ||x||^2 + ||mu||^2 - 2 x.mu
        x_sq = x.pow(2).sum(-1).unsqueeze(-1)
        mu_sq = self.mu.pow(2).sum(-1).unsqueeze(0)
        cross = (2 * x) @ self.mu.T
        sq_dist = x_sq + mu_sq - cross

        carrier = torch.sin(1.5 * self.param * self.linear(x))
        envelope = torch.exp(-0.5 * sq_dist * self.gamma[None, :])
        return carrier * envelope
 
 
