import torch
import torch.nn as nn

import sys

from models.s4_model import get_s4block 

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from typing import Union, Tuple, List, Optional
from libs.positional_encoding_module import RotaryEmbedding, apply_rotary_pos_emb, SirenNet
from libs.basics import PreNorm, PostNorm, GeAct, MLP, masked_instance_norm, GroupNorm, InstanceNorm
# from libs.attention import LowRankKernel
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_, orthogonal_
from libs.positional_encoding_module import RotaryEmbedding, GaussianFourierFeatureTransform, \
    apply_rotary_pos_emb, apply_2d_rotary_pos_emb, SirenNet, apply_3d_rotary_pos_emb


class PoolingReducer(nn.Module):
    """Project channel-last features, average out extra spatial axes, apply a feed-forward head.

    The input is laid out as ``(b, nx, ..., c)``. The first spatial axis
    ``nx`` is always kept; every axis between it and the channel axis is
    mean-pooled away, so the result is ``(b, nx, out_dim)``. A 3-D input
    ``(b, nx, c)`` has nothing to pool and only goes through the projection
    and the feed-forward head.

    Args:
        in_dim: size of the channel axis of the input.
        hidden_dim: width of the bias-free input projection and of the MLP.
        out_dim: size of the channel axis of the output.
    """

    def __init__(self,
                 in_dim,
                 hidden_dim,
                 out_dim):
        super().__init__()
        self.to_in = nn.Linear(in_dim, hidden_dim, bias=False)
        # NOTE(review): the wrapper is built with in_dim although it runs on
        # the hidden_dim-wide output of to_in. Every caller in this file
        # passes in_dim == hidden_dim; confirm against PreNorm before using
        # different sizes.
        self.out_ffn = PreNorm(in_dim, MLP([hidden_dim, hidden_dim, out_dim], GeAct(nn.GELU())))

    def forward(self, x):
        """Reduce ``x`` of shape ``(b, nx, ..., c)`` to ``(b, nx, out_dim)``."""
        x = self.to_in(x)
        # Axes strictly between the first spatial axis and the channel axis.
        pooled_axes = tuple(range(2, x.dim() - 1))
        # Guard the 1-D case explicitly: an empty ``dim`` tuple must not reach
        # ``mean`` because it can be interpreted as "reduce over every axis".
        if pooled_axes:
            x = x.mean(dim=pooled_axes)
        x = self.out_ffn(x)
        return x  # b nx c

class LowRankKernel(nn.Module):
    """Multi-head low-rank kernel ``K[b, h, i, j] = <q_i, k_j> * scaling`` along one axis.

    Queries are projected from ``u_x`` and keys from ``u_y`` with bias-free
    linear layers, optionally standardized over the point axis, optionally
    modulated by a positional embedding, and contracted over the per-head
    feature axis.

    Args:
        dim: feature size of ``u_x`` / ``u_y``.
        dim_head: per-head feature size of the query/key projections.
        heads: number of heads.
        positional_embedding: one of ``'rff'``, ``'rotary'``, ``'learnable'``
            or ``'none'``.
        pos_dim: number of coordinate components; the rotary path supports
            1 or 2.
        normalize: standardize q and k over the point axis before use.
        softmax: apply a softmax over the last kernel axis.
        residual: add a learnable diagonal term ``gamma * I / n``.
        dropout: dropout probability applied to the kernel (before the
            optional softmax); values <= 1e-6 disable dropout.
        scaling: constant multiplier applied to the raw q-k products.
    """

    def __init__(self,
                 dim,
                 dim_head,
                 heads,
                 positional_embedding='rotary',
                 pos_dim=1,
                 normalize=False,
                 softmax=False,
                 residual=True,
                 dropout=0,
                 scaling=1,
                 ):
        super().__init__()
        # Not used by this class; kept so external attribute access still works.
        self.layers = nn.ModuleList([])
        self.dim_head = dim_head
        self.heads = heads
        self.normalize = normalize
        self.residual = residual
        if dropout > 1e-6:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        self.to_q = nn.Linear(dim, dim_head*heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head*heads, bias=False)

        assert positional_embedding in ['rff', 'rotary', 'learnable', 'none']
        self.positional_embedding = positional_embedding
        self.pos_dim = pos_dim

        if positional_embedding == 'rff':
            self.pos_emb = GaussianFourierFeatureTransform(pos_dim, dim_head, scale=1,
                                                           learnable=False, num_heads=heads)
        elif positional_embedding == 'rotary':
            # The head features are split evenly between coordinate components.
            self.pos_emb = RotaryEmbedding(dim_head//self.pos_dim, min_freq=1/64)
        elif positional_embedding == 'learnable':
            self.pos_emb = nn.Sequential(
                GaussianFourierFeatureTransform(pos_dim, dim_head * heads // 2, scale=1,
                                                learnable=True),
                nn.Linear(dim_head * heads // 2, dim_head*heads, bias=False),
                nn.GELU(),
                nn.Linear(dim_head*heads, dim_head*heads, bias=False))
        else:
            pass  # 'none': no positional module is created
        self.init_gain = 0.02   # small gain instead of 1 / np.sqrt(dim_head)
        self.initialize_qk_weights()
        self.softmax = softmax

        if self.residual:
            # Build gamma in the default floating dtype: a NumPy float64
            # scalar would otherwise yield a float64 parameter inside an
            # otherwise float32 module.
            self.gamma = nn.Parameter(
                torch.tensor(1 / np.sqrt(dim_head), dtype=torch.get_default_dtype()),
                requires_grad=True)
        else:
            self.gamma = 0
        self.scaling = scaling

    def initialize_qk_weights(self):
        """Re-initialize the query/key projections with Xavier-uniform at ``init_gain``."""
        xavier_uniform_(self.to_q.weight, gain=self.init_gain)
        xavier_uniform_(self.to_k.weight, gain=self.init_gain)

    def normalize_wrt_domain(self, x):
        """Standardize ``x`` over its second-to-last axis (zero mean, unit std, eps 1e-5)."""
        x = (x - x.mean(dim=-2, keepdim=True)) / (x.std(dim=-2, keepdim=True) + 1e-5)
        return x

    def forward(self, u_x, u_y=None, pos_x=None, pos_y=None):
        """Return the kernel of shape ``(b, heads, n_x, n_y)``.

        Args:
            u_x: query source, ``(b, n, c)``.
            u_y: key source, ``(b, n, c)``; defaults to ``u_x``.
            pos_x: coordinates of the query points. The 1-D rotary path
                indexes it as un-batched ``(n, 1)``, the 2-D rotary path as
                batched ``(b, n, 2)``.
            pos_y: coordinates of the key points; defaults to ``pos_x``.

        Raises:
            ValueError: a positional embedding is configured but ``pos_x``
                is ``None``.
            NotImplementedError: rotary embedding with ``pos_dim`` other
                than 1 or 2.
        """
        if u_y is None:
            u_y = u_x

        n = u_y.shape[1]

        q = self.to_q(u_x)
        k = self.to_k(u_y)

        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        if self.normalize:
            q = self.normalize_wrt_domain(q)
            k = self.normalize_wrt_domain(k)

        if self.positional_embedding != 'none' and pos_x is None:
            raise ValueError('positional embedding is not none but pos is None')

        if self.positional_embedding != 'rotary' and \
                self.positional_embedding != 'none' and \
                self.positional_embedding != 'rff':
            # 'learnable': multiplicative embedding.
            # NOTE(review): q/k are (b, h, n, d) here while pos_emb ends in a
            # Linear of width dim_head*heads - confirm the shapes broadcast.
            pos_x_emb = self.pos_emb(pos_x)
            if pos_y is None:
                pos_y_emb = pos_x_emb
            else:
                pos_y_emb = self.pos_emb(pos_y)
            q = q * pos_x_emb
            k = k * pos_y_emb
        elif self.positional_embedding == 'rff':

            pos_x_emb = self.pos_emb(pos_x, unfold_head=True)
            if pos_y is None:
                pos_y_emb = pos_x_emb
            else:
                pos_y_emb = self.pos_emb(pos_y, unfold_head=True)

            # Duplicate q and k along the feature axis before modulating;
            # presumably the Fourier features are twice as wide as dim_head
            # (TODO confirm).
            q_ = torch.cat((q, q), dim=-1)
            k_ = torch.cat((k, k), dim=-1)
            q = q_ * pos_x_emb
            k = k_ * pos_y_emb

        elif self.positional_embedding == 'rotary':
            if self.pos_dim == 2:
                assert pos_x.shape[-1] == 2
                q_freqs_x = self.pos_emb.forward(pos_x[..., 0], q.device)
                q_freqs_y = self.pos_emb.forward(pos_x[..., 1], q.device)
                q_freqs_x = repeat(q_freqs_x, 'b n d -> b h n d', h=q.shape[1])
                q_freqs_y = repeat(q_freqs_y, 'b n d -> b h n d', h=q.shape[1])

                if pos_y is None:
                    k_freqs_x = q_freqs_x
                    k_freqs_y = q_freqs_y
                else:
                    k_freqs_x = self.pos_emb.forward(pos_y[..., 0], k.device)
                    k_freqs_y = self.pos_emb.forward(pos_y[..., 1], k.device)
                    k_freqs_x = repeat(k_freqs_x, 'b n d -> b h n d', h=k.shape[1])
                    k_freqs_y = repeat(k_freqs_y, 'b n d -> b h n d', h=k.shape[1])

                q = apply_2d_rotary_pos_emb(q, q_freqs_x, q_freqs_y)
                k = apply_2d_rotary_pos_emb(k, k_freqs_x, k_freqs_y)
            elif self.pos_dim == 1:
                assert pos_x.shape[-1] == 1

                # One set of frequencies is shared by every batch element and head.
                q_freqs = self.pos_emb.forward(pos_x[..., 0], q.device).unsqueeze(0)
                q_freqs = repeat(q_freqs, '1 n d -> b h n d', b=q.shape[0], h=q.shape[1])

                if pos_y is None:
                    k_freqs = q_freqs
                else:
                    k_freqs = self.pos_emb.forward(pos_y[..., 0], k.device).unsqueeze(0)
                    k_freqs = repeat(k_freqs, '1 n d -> b h n d', b=q.shape[0], h=q.shape[1])

                q = apply_rotary_pos_emb(q, q_freqs)
                k = apply_rotary_pos_emb(k, k_freqs)
            else:
                raise NotImplementedError('Currently doesnt support relative embedding > 2 dimensions')
        else:  # 'none': q and k are used as they are
            pass

        # If the grid is not uniform, quadrature weights would have to be considered here.
        K = torch.einsum('bhid,bhjd->bhij', q, k) * self.scaling
        K = self.dropout(K)
        if self.softmax:
            K = F.softmax(K, dim=-1)
        if self.residual:
            # Allocate the identity on the target device directly instead of
            # building it on the CPU and copying it on every call.
            # NOTE(review): the n x n identity only matches K when u_x and
            # u_y have the same number of points.
            K = K + self.gamma * torch.eye(n, device=q.device).view(1, 1, n, n) / n
        return K


class FABlock1D(nn.Module):
    """Factorized attention block for a single spatial axis.

    A value projection of the input is mixed along the sequence axis by a
    low-rank kernel computed from a second projection of the same input, and
    the result is passed through a group-norm + two-layer output head.

    Args:
        dim: channel size of the input.
        dim_head: per-head size of the value projection.
        latent_dim: channel size fed to the kernel.
        heads: number of heads.
        dim_out: channel size of the output.
        use_rope: use rotary positional embedding in the kernel, else none.
        kernel_multiplier: the kernel works with ``dim_head * kernel_multiplier``
            features per head.
        scaling_factor: kernel scaling used unless ``kernel_multiplier > 4``,
            in which case ``1 / sqrt(dim_head * kernel_multiplier)`` is used.
    """

    def __init__(self,
                 dim,
                 dim_head,
                 latent_dim,
                 heads,
                 dim_out,
                 use_rope=True,
                 kernel_multiplier=3,
                 scaling_factor=1.0):
        super().__init__()

        self.dim = dim
        self.latent_dim = latent_dim
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = heads * dim_head

        # in_norm is created but not applied in forward.
        self.in_norm = nn.LayerNorm(dim)
        self.to_v = nn.Linear(self.dim, inner_dim, bias=False)
        self.to_in = nn.Linear(self.dim, self.dim, bias=False)

        # Single-axis reducer; there is no second axis to handle in 1D.
        self.to_x = nn.Sequential(
            PoolingReducer(self.dim, self.dim, self.latent_dim),
        )

        use_softmax = False
        kernel_dim = dim_head * kernel_multiplier
        if kernel_multiplier > 4 or use_softmax:
            kernel_scaling = 1 / np.sqrt(kernel_dim)
        else:
            kernel_scaling = scaling_factor

        self.low_rank_kernel_x = LowRankKernel(
            self.latent_dim, kernel_dim, heads,
            positional_embedding='rotary' if use_rope else 'none',
            residual=False,  # no diagonal bias is added to the kernel
            softmax=use_softmax,
            scaling=kernel_scaling)

        self.to_out = nn.Sequential(
            GroupNorm(heads, inner_dim, domain_wise=True, affine=False),
            nn.Linear(inner_dim, dim_out, bias=False),
            nn.GELU(),
            nn.Linear(dim_out, dim_out, bias=False))

    def forward(self, u, pos_lst):
        """Apply the block.

        Args:
            u: input of shape ``(b, l, dim)``; it is used as given, without
                an input normalization.
            pos_lst: sequence whose first entry holds the coordinates handed
                to the kernel as ``pos_x``.

        Returns:
            Tensor of shape ``(b, l, dim_out)``.
        """
        values = self.to_v(u)
        latent = self.to_x(self.to_in(u))

        kernel = self.low_rank_kernel_x(latent, pos_x=pos_lst[0])

        values = rearrange(values, 'b l (h c) -> b h l c', h=self.heads)
        # The contraction runs over the first kernel index.
        mixed = torch.einsum('bhls,bhlc->bhsc', kernel, values)
        mixed = rearrange(mixed, 'b h l c -> b l (h c)', h=self.heads)
        return self.to_out(mixed)



class FABlock2D(nn.Module):
    """Factorized attention block for two spatial axes.

    One low-rank kernel is computed per axis; the value projection of the
    input is mixed first along the x axis and then along the y axis before
    the group-norm + two-layer output head.

    Args:
        dim: channel size of the input.
        dim_head: per-head size of the value projection.
        latent_dim: channel size fed to the kernels.
        heads: number of heads.
        dim_out: channel size of the output.
        use_rope: use rotary positional embedding in the kernels, else none.
        kernel_multiplier: the kernels work with ``dim_head * kernel_multiplier``
            features per head.
        scaling_factor: kernel scaling used unless ``kernel_multiplier > 4``,
            in which case ``1 / sqrt(dim_head * kernel_multiplier)`` is used.
    """

    def __init__(self,
                 dim,
                 dim_head,
                 latent_dim,
                 heads,
                 dim_out,
                 use_rope=True,
                 kernel_multiplier=3,
                 scaling_factor=1.0):
        super().__init__()

        self.dim = dim
        self.latent_dim = latent_dim
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = heads * dim_head

        self.in_norm = nn.LayerNorm(dim)
        self.to_v = nn.Linear(self.dim, inner_dim, bias=False)
        self.to_in = nn.Linear(self.dim, self.dim, bias=False)

        def axis_reducer(pattern=None):
            # Optionally move the axis of interest to position 1, then reduce.
            stages = [] if pattern is None else [Rearrange(pattern)]
            stages.append(PoolingReducer(self.dim, self.dim, self.latent_dim))
            return nn.Sequential(*stages)

        self.to_x = axis_reducer()
        self.to_y = axis_reducer('b nx ny c -> b ny nx c')

        positional_encoding = 'rotary' if use_rope else 'none'
        use_softmax = False
        kernel_dim = dim_head * kernel_multiplier
        if kernel_multiplier > 4 or use_softmax:
            kernel_scaling = 1 / np.sqrt(kernel_dim)
        else:
            kernel_scaling = scaling_factor

        def axis_kernel():
            # Same configuration for every axis; no diagonal bias is added.
            return LowRankKernel(self.latent_dim, kernel_dim, heads,
                                 positional_embedding=positional_encoding,
                                 residual=False,
                                 softmax=use_softmax,
                                 scaling=kernel_scaling)

        self.low_rank_kernel_x = axis_kernel()
        self.low_rank_kernel_y = axis_kernel()

        self.to_out = nn.Sequential(
            GroupNorm(heads, inner_dim, domain_wise=True, affine=False),
            nn.Linear(inner_dim, dim_out, bias=False),
            nn.GELU(),
            nn.Linear(dim_out, dim_out, bias=False))

    def forward(self, u, pos_lst):
        """Apply the block.

        Args:
            u: channel-last input of shape ``(b, nx, ny, dim)``.
            pos_lst: pair ``(pos_x, pos_y)`` of per-axis coordinates.

        Returns:
            Tensor of shape ``(b, nx, ny, dim_out)``.
        """
        u = self.in_norm(u)
        values = self.to_v(u)
        u = self.to_in(u)

        # NOTE(review): the kernels rearrange their input as 3-D (b, n, c),
        # so to_x / to_y have to deliver one feature vector per grid line -
        # confirm that PoolingReducer reduces the other axis.
        latent_x = self.to_x(u)
        latent_y = self.to_y(u)

        pos_x, pos_y = pos_lst
        kernel_x = self.low_rank_kernel_x(latent_x, pos_x=pos_x)
        kernel_y = self.low_rank_kernel_y(latent_y, pos_x=pos_y)

        mixed = rearrange(values, 'b i l (h c) -> b h i l c', h=self.heads)
        mixed = torch.einsum('bhij,bhjmc->bhimc', kernel_x, mixed)  # mix along x
        mixed = torch.einsum('bhlm,bhimc->bhilc', kernel_y, mixed)  # mix along y
        mixed = rearrange(mixed, 'b h i l c -> b i l (h c)', h=self.heads)
        return self.to_out(mixed)


class FABlock3D(nn.Module):
    """Factorized attention block for three spatial axes.

    One low-rank kernel is computed per axis; the value projection of the
    input is mixed along x, then y, then z before the group-norm + two-layer
    output head.

    Args:
        dim: channel size of the input.
        dim_head: per-head size of the value projection.
        latent_dim: channel size fed to the kernels.
        heads: number of heads.
        dim_out: channel size of the output.
        use_rope: use rotary positional embedding in the kernels, else none.
        kernel_multiplier: the kernels work with ``dim_head * kernel_multiplier``
            features per head.
        scaling_factor: kernel scaling used unless ``kernel_multiplier > 4``,
            in which case ``1 / sqrt(dim_head * kernel_multiplier)`` is used.
    """

    def __init__(self,
                 dim,
                 dim_head,
                 latent_dim,
                 heads,
                 dim_out,
                 use_rope=True,
                 kernel_multiplier=3,
                 scaling_factor=1.0):
        super().__init__()

        self.dim = dim
        self.latent_dim = latent_dim
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = heads * dim_head

        self.in_norm = nn.LayerNorm(dim)
        self.to_v = nn.Linear(self.dim, inner_dim, bias=False)
        self.to_in = nn.Linear(self.dim, self.dim, bias=False)

        def axis_reducer(pattern=None):
            # Optionally move the axis of interest to position 1, then reduce.
            stages = [] if pattern is None else [Rearrange(pattern)]
            stages.append(PoolingReducer(self.dim, self.dim, self.latent_dim))
            return nn.Sequential(*stages)

        self.to_x = axis_reducer()
        self.to_y = axis_reducer('b nx ny nz c -> b ny nx nz c')
        self.to_z = axis_reducer('b nx ny nz c -> b nz nx ny c')

        positional_encoding = 'rotary' if use_rope else 'none'
        use_softmax = False
        kernel_dim = dim_head * kernel_multiplier
        if kernel_multiplier > 4 or use_softmax:
            kernel_scaling = 1 / np.sqrt(kernel_dim)
        else:
            kernel_scaling = scaling_factor

        def axis_kernel():
            # Same configuration for every axis; no diagonal bias is added.
            return LowRankKernel(self.latent_dim, kernel_dim, heads,
                                 positional_embedding=positional_encoding,
                                 residual=False,
                                 softmax=use_softmax,
                                 scaling=kernel_scaling)

        self.low_rank_kernel_x = axis_kernel()
        self.low_rank_kernel_y = axis_kernel()
        self.low_rank_kernel_z = axis_kernel()

        self.to_out = nn.Sequential(
            GroupNorm(heads, inner_dim, domain_wise=True, affine=False),
            nn.Linear(inner_dim, dim_out, bias=False),
            nn.GELU(),
            nn.Linear(dim_out, dim_out, bias=False))

    def forward(self, u, pos_lst):
        """Apply the block.

        Args:
            u: channel-last input of shape ``(b, nx, ny, nz, dim)``.
            pos_lst: triple ``(pos_x, pos_y, pos_z)`` of per-axis coordinates.

        Returns:
            Tensor of shape ``(b, nx, ny, nz, dim_out)``.
        """
        u = self.in_norm(u)
        values = self.to_v(u)
        u = self.to_in(u)

        # NOTE(review): the kernels rearrange their input as 3-D (b, n, c),
        # so each reducer has to deliver one feature vector per grid plane -
        # confirm that PoolingReducer reduces the two other axes.
        latent_x = self.to_x(u)
        latent_y = self.to_y(u)
        latent_z = self.to_z(u)
        pos_x, pos_y, pos_z = pos_lst

        kernel_x = self.low_rank_kernel_x(latent_x, pos_x=pos_x)
        kernel_y = self.low_rank_kernel_y(latent_y, pos_x=pos_y)
        kernel_z = self.low_rank_kernel_z(latent_z, pos_x=pos_z)

        mixed = rearrange(values, 'b i l r (h c) -> b h i l r c', h=self.heads)
        mixed = torch.einsum('bhij,bhjmsc->bhimsc', kernel_x, mixed)  # mix along x
        mixed = torch.einsum('bhlm,bhimsc->bhilsc', kernel_y, mixed)  # mix along y
        mixed = torch.einsum('bhrs,bhilsc->bhilrc', kernel_z, mixed)  # mix along z
        mixed = rearrange(mixed, 'b h i l r c -> b i l r (h c)', h=self.heads)

        return self.to_out(mixed)


class Factformer(nn.Module):
    """Stack of 1-D factorized attention blocks with a memory layer over time in the middle.

    The first half of the blocks mixes along the spatial axis for every time
    step independently, the memory layer then mixes along the time axis for
    every spatial point, and the second half of the blocks mixes along space
    again.

    Args:
        n_layers: number of ``FABlock1D`` blocks.
        d_model: hidden channel size.
        dim_head: per-head size; the kernel uses ``dim_head * kernel_multiplier``.
        latent_dim: output size of the projection feeding the kernel.
        heads: number of attention heads.
        dim_out: output size of every block. Blocks are used residually and
            followed by ``LayerNorm(d_model)`` / ``Linear(d_model, 1)``, so
            this has to equal ``d_model``.
        kernel_multiplier: number of function bases per head dimension used
            to compute the kernel.
        use_rope: use rotary positional encoding in the blocks.
        scaling_factor: factor modulating the kernel.
        memory_augmented: use an S4 block as memory layer; otherwise the
            memory branch contributes zeros.
        d_state: state size of the memory layer.
        norm: apply ``LayerNorm(d_model)`` to every block output.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 n_layers,
                 d_model,
                 dim_head,
                 latent_dim,
                 heads,
                 dim_out,
                 kernel_multiplier,
                 use_rope,
                 scaling_factor,
                 memory_augmented,
                 d_state,
                 norm=False,
                 **kwargs,
                 ):
        super().__init__()

        self.blocks = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(n_layers):
            # Keyword arguments on purpose: FABlock1D declares use_rope
            # before kernel_multiplier, so passing these two positionally in
            # the order of this signature swaps them.
            # NOTE: checkpoints trained with the swapped call used
            # dim_head * int(use_rope) kernel features and will not match
            # the q/k projection shapes built here.
            self.blocks.append(
                FABlock1D(d_model,
                          dim_head,
                          latent_dim,
                          heads,
                          dim_out,
                          use_rope=use_rope,
                          kernel_multiplier=kernel_multiplier,
                          scaling_factor=scaling_factor))
            self.norms.append(nn.LayerNorm(d_model) if norm else nn.Identity())
        self.n_blocks = len(self.blocks)

        # Inputs carry a single scalar channel.
        self.in_proj = nn.Linear(1, d_model)

        self.memory_augmented = memory_augmented
        if self.memory_augmented:
            self.memory_layer = get_s4block(n_layers=1, s4block_args={"s4block_type": "S4Block",
                                                                      "d_state": d_state})[0](d_model)
            # Every parameter of the memory layer starts at zero.
            for p in self.memory_layer.parameters():
                p.data.fill_(0.0)
            self.memory_norm = nn.LayerNorm(d_model)
        else:
            # Stand-in with the same call signature that contributes zeros.
            self.memory_layer = lambda x, state: (0 * x, None)
            self.memory_norm = nn.Identity()
        # Created on the default device like every other parameter, so the
        # model can be built on hosts without CUDA and moved with ``.to()``.
        self.initial_memory_state = nn.Parameter(
            torch.zeros(d_model, d_state, dtype=torch.cfloat))

        self.out_proj = nn.Linear(d_model, 1)

    def _init_weights(self, module):
        """Initialization hook for ``nn.Module.apply``; nothing in this class calls it."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.0002)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, grid, batch_dt=None):
        """Run the model on a space-time field.

        Args:
            x: input of shape ``(B, S, T, V)``; ``in_proj`` expects ``V == 1``.
            grid: spatial coordinates of shape ``(B, S, V)``. Only ``grid[0]``
                is used, i.e. every batch element is assumed to share the
                grid of the first one.
            batch_dt: unused.

        Returns:
            Tensor of shape ``(B, S, T, 1)``.
        """
        B, S, T, V = x.shape

        x = rearrange(x, 'b s t v -> (b t) s v')  # ((B, T), S, V)
        # The blocks only ever see the first grid, so there is no need to
        # materialize one copy per time step.
        pos_lst = [grid[0]]

        x = self.in_proj(x)

        half = self.n_blocks // 2
        for block, norm in zip(self.blocks[:half], self.norms[:half]):
            x = norm(block(x, pos_lst)) + x

        # Memory layer: treat every spatial point as a sequence over time.
        x = rearrange(x, '(b t) s h -> (b s) h t', t=T, s=S)
        x_ = x
        state = repeat(self.initial_memory_state, ' h n -> (b s) h n', b=B, s=S)
        x = self.memory_layer(x, state=state)[0]
        # The norm acts on the channel axis, hence the transposes.
        x = self.memory_norm(x.transpose(-1, -2)).transpose(-1, -2) + x_
        x = rearrange(x, '(b s) h t -> (b t) s h', t=T, s=S)

        for block, norm in zip(self.blocks[half:], self.norms[half:]):
            x = norm(block(x, pos_lst)) + x

        x = self.out_proj(x)

        x = rearrange(x, '(b t) s v -> b s t v', b=B)
        return x

    def predict(self,
                x_,  # (B, S, V)
                grid,
                n_timesteps,
                **kwargs):
        """Roll the model out autoregressively for ``n_timesteps`` steps.

        Args:
            x_: initial state of shape ``(B, S, V)``.
            grid: spatial coordinates, forwarded to ``forward``.
            n_timesteps: number of steps to predict; must be at least 1.
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape ``(B, S, n_timesteps, 1)``.

        Raises:
            ValueError: ``n_timesteps`` is smaller than 1.
        """
        if n_timesteps < 1:
            raise ValueError(f'n_timesteps must be at least 1, got {n_timesteps}')

        x_ = x_.unsqueeze(-2)  # (B, S, 1, V)
        if self.memory_augmented:
            x = x_
            # TODO: use states (this is very inefficient) - every step
            # re-runs the model on the whole history.
            for _ in range(n_timesteps):
                y = self.forward(x, grid)
                x = torch.cat((x_, y), dim=-2)
            return y

        # Without memory each step only depends on the previous prediction.
        out = []
        for _ in range(n_timesteps):
            x_ = self.forward(x_, grid)
            out.append(x_)
        return torch.cat(out, dim=-2)
                