import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv,GATConv,FiLMConv,GATv2Conv,GINEConv,GINConv,GIN,global_mean_pool
from torch_geometric.data import Data, DataLoader
from typing import Tuple
from torch.nn import BatchNorm1d
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import copy
from typing import Callable, Optional, Tuple, Union
import torch.nn as nn

from torch import Tensor
from torch.nn import ModuleList, ReLU, SELU

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
import torch.nn.utils as utils


def hyperbolic_distance(x, y, eps=1e-5):
    """Return the Poincare-ball distance between the points ``x`` and ``y``.

    Uses the closed form

        d(x, y) = arcosh(1 + 2 * ||x - y||^2 / ((1 - ||x||^2) * (1 - ||y||^2)))

    Args:
        x: Tensor holding one point. ``torch.norm`` is called without ``dim``,
            so the whole tensor is treated as a single vector.
        y: Tensor holding the other point, broadcastable against ``x``.
        eps: Margin that keeps the clamped norms strictly below 1 so the
            denominator cannot reach zero.

    Returns:
        A 0-dim tensor with the hyperbolic distance.
    """
    norm_x = torch.clamp(torch.norm(x), max=1 - eps)
    norm_y = torch.clamp(torch.norm(y), max=1 - eps)

    # The previous code squared ``2 * ||x - y||`` which yields 4 * ||x - y||^2;
    # the Poincare formula needs the factor 2 outside the square.
    sq_dist = torch.norm(x - y) ** 2
    denominator = (1 - norm_x ** 2) * (1 - norm_y ** 2)

    return torch.acosh(1 + 2 * sq_dist / denominator)
class TemporalDecayAttention(nn.Module):
    """Multi-head self-attention whose scores are scaled by a learnable time decay.

    Each head has width ``hidden_dim``; heads are concatenated and projected
    back to ``hidden_dim`` by ``fc_out``.
    """

    def __init__(self, input_dim, hidden_dim, n_heads, dropout=0.1):
        """
        Args:
            input_dim: Size of the last axis of the input sequence.
            hidden_dim: Per-head width and width of the returned features.
            n_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
        """
        super(TemporalDecayAttention, self).__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.query = nn.Linear(input_dim, hidden_dim * n_heads)
        self.key = nn.Linear(input_dim, hidden_dim * n_heads)
        self.value = nn.Linear(input_dim, hidden_dim * n_heads)
        self.dropout = nn.Dropout(dropout)

        # Temporal decay parameter (learnable scalar)
        self.decay_rate = nn.Parameter(torch.tensor(0.1))
        self.fc_out = nn.Linear(hidden_dim * n_heads, hidden_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all projections, zero their biases, reset the decay to 0.1."""
        projections = (self.query, self.key, self.value, self.fc_out)
        for layer in projections:
            torch.nn.init.xavier_uniform_(layer.weight)
        for layer in projections:
            if layer.bias is not None:
                torch.nn.init.zeros_(layer.bias)

        # Start from a small positive decay rate
        torch.nn.init.constant_(self.decay_rate, 0.1)

    def forward(self, inputs, time_diffs):
        """
        inputs: [batch_size, seq_len, input_dim]
        time_diffs: [batch_size, seq_len, seq_len] - time differences between elements

        Returns:
            output: [batch_size, seq_len, hidden_dim]
            attn_weights: [batch_size, n_heads, seq_len, seq_len]
        """
        # ``inputs`` is 3-D; unpacking ``inputs.size()`` into two names raised
        # ValueError for every valid input, so only the leading two axes are read.
        batch_size, seq_len = inputs.shape[:2]
        head_shape = (batch_size, seq_len, self.n_heads, self.hidden_dim)

        # [batch_size, n_heads, seq_len, hidden_dim]
        Q = self.query(inputs).view(head_shape).transpose(1, 2)
        K = self.key(inputs).view(head_shape).transpose(1, 2)
        V = self.value(inputs).view(head_shape).transpose(1, 2)

        # [batch_size, n_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.hidden_dim ** 0.5)

        # [batch_size, 1, seq_len, seq_len]; broadcast over the head axis.
        temporal_decay = torch.exp(-self.decay_rate * time_diffs.unsqueeze(1))
        scores = scores * temporal_decay

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)  # [batch_size, n_heads, seq_len, hidden_dim]

        # Merge the heads back into one feature axis.
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_dim * self.n_heads)

        output = self.fc_out(output)  # [batch_size, seq_len, hidden_dim]

        return output, attn_weights

class HyperGraph(MessagePassing):
    """Message-passing layer whose messages are gated by edge-derived weights.

    Edge attributes are embedded, mapped to ``out_channels`` weights, normalised
    across the channel axis and multiplied onto the transformed source-node
    features. Incoming messages are averaged per target node.
    """

    def __init__(self, in_channels, out_channels, edge_dim=40, edge_hidden_dim=20):
        """
        Args:
            in_channels: Width of the input node features.
            out_channels: Width of the output node features.
            edge_dim: Width of the edge attributes (previously fixed at 40).
            edge_hidden_dim: Width of the edge embedding (previously fixed at 20).
        """
        super(HyperGraph, self).__init__(aggr='mean')  # mean aggregation of messages
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Edge embedding layer
        self.edge_embedding = torch.nn.Linear(edge_dim, edge_hidden_dim)
        self.activation = nn.PReLU()
        self.activation2 = nn.PReLU()
        # Not used in forward; kept so the module's parameter set is unchanged.
        self.activation3 = nn.PReLU()
        self.weight_layer = torch.nn.Linear(edge_hidden_dim, out_channels)

        # Node transformation layer
        self.node_transform = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Propagate node features ``x`` along ``edge_index`` using ``edge_attr``.

        Self loops are added for every node; their edge attributes are filled with 1.
        """
        edge_index, edge_attr = add_self_loops(
            edge_index, edge_attr=edge_attr, fill_value=1, num_nodes=x.size(0))
        edge_embeddings = self.activation(self.edge_embedding(edge_attr))
        weights = self.weight_layer(edge_embeddings)
        node_attr = self.activation2(self.node_transform(x))
        return self.propagate(edge_index, x=node_attr, weights=weights)

    def message(self, x_j, weights):
        """Scale source-node features ``x_j`` by per-edge weights.

        The weights of each edge are divided by their sum over the channel axis.
        NOTE(review): nothing guards against that sum being zero.
        """
        weights = weights / weights.sum(dim=1, keepdim=True)
        return weights * x_j

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
def batch_corr(x):
    """Return per-batch correlation matrices of the columns of ``x``.

    Args:
        x: Tensor of shape (B, T, F); axis 1 holds the observations and axis 2
            the variables being correlated.

    Returns:
        Tensor of shape (B, F, F) with correlation coefficients, clamped to
        the range [-1.0, 2.0].
    """
    # Broadcasting keeps every batch item paired with its own statistics. The
    # earlier ``repeat(...).reshape(...)`` produced a (T, B, F) buffer that was
    # reinterpreted as (B, T, F), which mixes batch items whenever B > 1.
    xm = x - torch.mean(x, 1, keepdim=True)
    c = torch.bmm(xm.permute(0, 2, 1), xm)
    c = c / (x.size(1) - 1)

    d = torch.diagonal(c, offset=0, dim1=1, dim2=2)
    stddev = torch.pow(d, 0.5)  # (B, F)
    c = c.div(stddev.unsqueeze(1))  # divide column j by stddev[b, j]
    c = c.div(stddev.unsqueeze(2))  # divide row i by stddev[b, i]

    # NOTE(review): the upper bound is 2.0, not 1.0, as in the original code;
    # kept unchanged because callers threshold on these values.
    c = torch.clamp(c, -1.0, 2.0)
    return c
class GraphResidualBlock(nn.Module):
    """Stack of GCNConv layers with PReLU and dropout between them.

    The stack always has an input layer and an output layer; ``num_layers - 2``
    hidden-to-hidden layers are inserted in between. No skip connection is
    applied in ``forward``; ``create_residual_connection`` is only recorded.
    """

    def __init__(self, input_dim: int,hidden_channels: int , num_layers: int, output_dim: int,dropout:float):
        super().__init__()
        first = GCNConv(in_channels=input_dim, out_channels=hidden_channels)
        middle = [
            GCNConv(in_channels=hidden_channels, out_channels=hidden_channels)
            for _ in range(num_layers - 2)
        ]
        last = GCNConv(in_channels=hidden_channels, out_channels=output_dim)
        self.convs = torch.nn.ModuleList([first, *middle, last])
        self.activation = nn.PReLU()
        self.create_residual_connection = input_dim == output_dim
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Run every conv; all but the last are followed by PReLU and dropout."""
        *hidden_convs, output_conv = self.convs
        h = x
        for layer in hidden_convs:
            h = self.activation(layer(h, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return output_conv(h, edge_index)



class Net(torch.nn.Module):
    """Stack of FiLMConv layers with dropout before each hidden layer.

    NOTE(review): the layers are called with ``z3_t``, ``edge_attr`` and
    ``return_film`` arguments, so this presumably relies on a customised
    ``FiLMConv`` rather than the stock torch_geometric one - confirm.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout=0.0,num_relations=1,act1=None):
        """
        Args:
            in_channels: Input feature width of the first layer.
            hidden_channels: Width of every hidden layer.
            out_channels: Output width of the last layer.
            num_layers: Requested depth; ``num_layers - 2`` middle layers are
                created between the first and last layer.
            dropout: Dropout probability applied to the input of each
                non-final layer.
            num_relations: Forwarded to every ``FiLMConv``.
            act1: When not ``None`` a sigmoid is applied to the final output.
        """
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        # Only the presence of ``act1`` matters; it switches the output sigmoid.
        self.activation = act1 is not None

        # The same PReLU instance is handed to the first and the last layer.
        prelu = nn.PReLU()
        self.convs.append(FiLMConv(in_channels, hidden_channels, num_relations=num_relations, act=prelu))
        for _ in range(num_layers - 2):
            self.convs.append(FiLMConv(hidden_channels, hidden_channels, num_relations=num_relations, act=None))
        self.convs.append(FiLMConv(hidden_channels, out_channels, num_relations=num_relations, act=prelu))

    def forward(self, x,z2_t,z3_t, edge_index,edge_type,return_Film=False):
        """Run the stack and optionally return FiLM statistics of the last layer.

        Each layer is evaluated exactly once; the NaN diagnostics inspect that
        same result instead of triggering a second forward pass per layer.

        Returns:
            ``x`` when ``return_Film`` is false, otherwise
            ``(x, betas, gammas)`` where the last two are stacks of the
            per-item ``mean(dim=1)`` of what the final layer returned, detached
            and moved to the CPU.
        """
        for conv in self.convs[:-1]:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, z3_t, edge_index, edge_type, edge_attr=z2_t, return_film=False)
            if torch.isnan(x).any():
                print('a')  # debug marker: NaN produced by a hidden layer

        result = self.convs[-1](x, z3_t, edge_index, edge_type, edge_attr=z2_t, return_film=return_Film)
        if return_Film:
            x, betas, gammas = result
        else:
            x = result
        if torch.isnan(x).any():
            print('a')  # debug marker: NaN produced by the final layer

        if self.activation:
            x = torch.sigmoid(x)
        if not return_Film:
            return x

        beta_means = torch.stack([b.mean(dim=1).detach().cpu() for b in betas])
        gamma_means = torch.stack([g.mean(dim=1).detach().cpu() for g in gammas])
        return x, beta_means, gamma_means

class GraphResidualBlockAttn(nn.Module):
    """GATv2Conv layer followed by dropout and PReLU, with an optional skip.

    The skip connection is enabled only when ``input_dim == output_dim``.
    Edge features are expected to have width ``input_dim - 1`` (the
    ``edge_dim`` given to the convolution).
    """

    def __init__(self, input_dim: int, output_dim: int,nheads:int,dropout:float):
        super().__init__()
        self.conv = GATv2Conv(input_dim, output_dim, nheads, edge_dim=input_dim - 1)
        self.activation = nn.PReLU()
        self.create_residual_connection = input_dim == output_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index,edge_features):
        """Apply attention, dropout and PReLU; add ``x`` back when the skip is on."""
        attended = self.conv(x, edge_index, edge_features)
        out = self.activation(self.dropout(attended))
        if not self.create_residual_connection:
            return out
        return x + out
class GraphResFNN(nn.Module):
    """Wrapper holding a single ``GraphResidualBlock`` in a ``ModuleList``."""

    def __init__(self, input_dim: int,hidden_channels:int,num_layers:int, output_dim: int,dropout:float):
        super().__init__()
        self.input_dim = input_dim
        only_block = GraphResidualBlock(input_dim, hidden_channels, num_layers, output_dim, dropout)
        self.blocks = nn.ModuleList([only_block])

    def forward(self, x, edge_index,edge_attr=None):
        """Feed ``x`` through every block; ``edge_attr`` is accepted but not used."""
        out = x
        for stage in self.blocks:
            out = stage(out, edge_index)
        return out

def jaccard_index(edges1, edges2):
    """Return the Jaccard similarity of two edge sets.

    Each argument is a 2-D tensor whose columns are edges; a column becomes one
    tuple in the set. Returns ``0.0`` when both sets are empty.
    """
    pairs_a = {tuple(edge) for edge in edges1.t().tolist()}
    pairs_b = {tuple(edge) for edge in edges2.t().tolist()}

    combined = pairs_a | pairs_b
    if len(combined) == 0:
        return 0.0
    shared = pairs_a & pairs_b
    return len(shared) / len(combined)

def get_edge_index(batchx, bs, maxbs=0.98, num_nodes=50, seq_len=40):
    """Build an edge index from pairwise correlations inside each graph.

    ``batchx`` is reshaped to ``(-1, num_nodes, seq_len)``; within each group
    two nodes are connected when their correlation (from ``batch_corr``) lies
    in ``[bs, maxbs]``.

    Args:
        batchx: Tensor whose element count is a multiple of
            ``num_nodes * seq_len``.
        bs: Lower correlation threshold (inclusive).
        maxbs: Upper correlation threshold (inclusive).
        num_nodes: Nodes per graph (previously hard-coded to 50).
        seq_len: Values per node (previously hard-coded to 40).

    Returns:
        Long tensor of shape (2, E) holding node indices offset by
        ``group * num_nodes`` so they address rows of the flattened batch.
    """
    series = batchx.reshape(-1, num_nodes, seq_len).permute(0, 2, 1)
    bc = batch_corr(series)
    # Rows of ``hits`` are (group, node_i, node_j).
    hits = ((bc >= bs) & (bc <= maxbs)).nonzero(as_tuple=False).long()
    offset = hits[:, 0] * num_nodes
    return torch.stack([hits[:, 1] + offset, offset + hits[:, 2]], dim=0)
def get_edge_features(batchx,edge_index):
    """Return, for every edge, the row of its first endpoint minus the row of its second.

    ``edge_index[0]`` and ``edge_index[1]`` index rows of ``batchx``.
    """
    src, dst = edge_index[0], edge_index[1]
    return batchx[src] - batchx[dst]

class GraphHyperAttention(nn.Module):
    """Two chained ``HyperGraph`` layers mapping ``input_dim`` -> 10 -> 1.

    ``output_dim`` and ``nheads`` are accepted but not used.
    """

    def __init__(self, input_dim: int, output_dim: int,nheads:int):
        super().__init__()
        self.input_dim = input_dim
        stages = [HyperGraph(input_dim, 10), HyperGraph(10, 1)]
        self.blocks = nn.ModuleList(stages)

    def forward(self, x, edge_index,edge_features):
        """Apply every block with the same edges and edge features."""
        out = x
        for stage in self.blocks:
            # Every block type is invoked with the same three arguments.
            out = stage(out, edge_index, edge_features)
        return out

class GraphResAttention(nn.Module):
    """``GraphResidualBlockAttn`` (per-head width 256) followed by a one-head ``GATConv``.

    The ``GATConv`` input width is ``256 * nheads``.
    """

    def __init__(self, input_dim: int, output_dim: int,nheads:int,dropout:float):
        super().__init__()
        self.input_dim = input_dim

        embed_width = 256
        attention_stage = GraphResidualBlockAttn(input_dim, embed_width, nheads, dropout)
        concat_width = int(embed_width * nheads)
        readout_stage = GATConv(concat_width, output_dim, heads=1, dropout=dropout)
        self.blocks = nn.ModuleList([attention_stage, readout_stage])

    def forward(self, x, edge_index,edge_features):
        """Apply both stages; each receives the edges and the edge features."""
        out = x
        for stage in self.blocks:
            # Both stage types are called with identical arguments.
            out = stage(out, edge_index, edge_features)
        return out
class ResidualBlock(nn.Module):
    """Linear layer + (zero-probability) dropout + PReLU with an optional skip.

    The input is added back to the output only when ``input_dim == output_dim``.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation1 = nn.PReLU()
        self.create_residual_connection = input_dim == output_dim
        self.dropout = nn.Dropout(.0)

    def forward(self, x):
        """Return ``PReLU(dropout(linear(x.float())))``, plus ``x`` when the skip is on."""
        projected = self.linear(x.float())
        out = self.activation1(self.dropout(projected))
        if not self.create_residual_connection:
            return out
        return x + out
class ResFNN(nn.Module):
    """Feed-forward network of ``ResidualBlock`` layers and a final ``nn.Linear``."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Tuple[int], flatten: bool = False):
        """
        Args:
            input_dim: Width of the network input.
            output_dim: Width of the final linear layer.
            hidden_dims: Widths of the residual blocks. A single non-iterable
                value (e.g. an ``int``) is treated as one hidden layer.
            flatten: When true, inputs are reshaped to ``(batch, -1)`` first.
        """
        super(ResFNN, self).__init__()
        self.input_dim = input_dim
        self.flatten = flatten

        # Only the "not iterable" case is handled; a bare ``except`` here used
        # to hide unrelated errors raised while the blocks were being built.
        try:
            widths = tuple(hidden_dims)
        except TypeError:
            widths = (hidden_dims,)

        blocks = list()
        input_dim_block = input_dim
        for hidden_dim in widths:
            blocks.append(ResidualBlock(input_dim_block, hidden_dim))
            input_dim_block = hidden_dim
        blocks.append(nn.Linear(input_dim_block, output_dim))

        self.network = nn.Sequential(*blocks)
        # Plain list of the same layers, kept as a public attribute.
        self.blocks = blocks

    def forward(self, x):
        """Optionally flatten ``x`` and run it through the network."""
        if self.flatten:
            x = x.reshape(x.shape[0], -1)
        return self.network(x)





class GraphArFNN_Hyper(nn.Module):
    """Autoregressive generator that applies one FiLMConv layer per time step.

    After each step the sliding window is advanced by the generated column and
    the graph (edges and edge features) is rebuilt from the new window.
    """

    def __init__(self, input_dim1: int, output_dim: int,corr_thresh:float):
        """
        Args:
            input_dim1: Input width of the FiLMConv layer.
            output_dim: Output width of the FiLMConv layer.
            corr_thresh: Lower correlation threshold passed to ``get_edge_index``.
        """
        super(GraphArFNN_Hyper, self).__init__()
        # NOTE(review): the layer is moved to CUDA unconditionally.
        self.network = FiLMConv(input_dim1, output_dim, act='prelu').to('cuda')
        self.act = nn.PReLU()
        # Not read inside this class - presumably consumed by external code; verify.
        self.b_3 = [2.5, 10, 20]
        self.b_1 = [1.25, 5, 10]
        self.b_2 = [1.25, 5, 10]
        self.corr_thresh = corr_thresh

    def forward(self, data: Data, z, z2):
        """Generate ``z.shape[1]`` steps.

        Args:
            data: Graph batch providing ``x``, ``edge_index``, ``edge_attr``
                and ``batch``.
            z: Noise tensor; column ``t`` and its square are prepended to the
                window at step ``t``.
            z2: Accepted for interface compatibility; no longer read, since the
                tensors derived from it never reached the network.

        Returns:
            Tensor of shape ``(len(data.batch), z.shape[1], 1)``.
        """
        x, edge_index, edge_features, batch = data.x, data.edge_index, data.edge_attr, data.batch
        n_steps = z.shape[1]
        x = x[:, :n_steps]
        edge_index1 = edge_index.clone()
        x_generated = []
        for t in range(n_steps):
            z_t = z[:, t].reshape(-1, 1)
            # Column order fed to the network: [z_t ** 2, z_t, window].
            x_in = torch.cat([z_t ** 2, z_t, x], dim=-1)
            x_gen = self.network(x_in, edge_index1, edge_features)
            # Slide the window: drop the oldest column, append the new one.
            x = torch.cat([x[:, 1:], x_gen], dim=1)
            x_generated.append(x_gen)
            # Rebuild the graph from the updated window.
            edge_index1 = get_edge_index(x, self.corr_thresh)
            edge_features = get_edge_features(x, edge_index1)
        x_fake = torch.stack(x_generated, dim=1)
        return x_fake.reshape((len(batch), len(x_generated), 1))


class LearnablePositionalEncoding(nn.Module):
    """Adds a learned position embedding, scaled by a learnable ``alpha``, to the input."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``x + alpha * embedding(0..seq_len-1)``; axis 1 of ``x`` is the sequence axis."""
        index = torch.arange(x.size(1), device=x.device)[None, :]  # (1, seq_len)
        scaled = self.alpha * self.pos_embedding(index)  # (1, seq_len, d_model)
        return x + scaled

class TemporalAttention(nn.Module):
    """Attention pooling over the time axis of scalar series using one shared query."""

    def __init__(self, d_model):
        super().__init__()
        self.embedding = nn.Linear(1, d_model)  # lifts each scalar step to d_model
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.query = nn.Parameter(torch.randn(d_model))  # one query for all series
        self.scale = d_model ** 0.5

    def forward(self, x):
        """
        x: (N, T) tensor of scalar time series
        Returns:
          context: (N, d_model) weighted sum per series
          attn_weights: (N, T) attention weights per time step
        """
        n_series, _ = x.shape
        embedded = self.embedding(x.unsqueeze(-1))  # (N, T, d_model)

        keys = self.key(embedded)  # (N, T, d_model)
        values = self.value(embedded)  # (N, T, d_model)
        query = self.query[None, None, :].expand(n_series, 1, -1)  # (N, 1, d_model)

        logits = torch.bmm(query, keys.transpose(1, 2)) / self.scale  # (N, 1, T)
        weights = F.softmax(logits, dim=-1)  # (N, 1, T)
        pooled = torch.bmm(weights, values)  # (N, 1, d_model)

        return pooled.squeeze(1), weights.squeeze(1)


class ARFNN_Net(nn.Module):
    """Autoregressive generator built around a FiLM graph network (``Net``).

    At every step the current window is re-weighted by temporal attention,
    extended with noise columns, pushed through ``self.network`` to produce one
    new value per row, optionally amplified by a gated "jump" term, clamped and
    appended to the window.
    """
    def __init__(self, in_channels: int, hidden_channels: int,out_channels:int,num_layers:int,dropout:float,corr_thresh:float,num_relations:int):
        """Create the attention module, the FiLM network and the jump-gating layers.

        ``in_channels - 1`` is stored as ``self.p`` and used as the window width;
        ``corr_thresh`` is the lower threshold later passed to ``get_edge_index``.
        """
        super(ARFNN_Net, self).__init__()
        # Window width: forward() keeps only the first ``p`` columns of data.x.
        self.p=in_channels-1
        self.tempattn=TemporalAttention(self.p)
        # NOTE(review): the network is moved to CUDA unconditionally.
        self.network = Net(in_channels=in_channels+1, hidden_channels=hidden_channels,
                out_channels=out_channels, num_layers=num_layers,
                dropout=dropout,num_relations=num_relations).to('cuda')

        # ``linear`` and ``linear2`` are not referenced in forward() below.
        self.linear = nn.Linear(42, 1)
        self.linear.bias.data.fill_(0.0001)
        # Overwritten further down by ``self.act=nn.Sigmoid()``.
        self.act=nn.PReLU()
        self.linear2 = nn.Linear(2000, 2000)

        # Jump gate, applied layer by layer in forward():
        # Linear -> PReLU -> Linear -> Sigmoid.
        self.gating = nn.ModuleList()
        self.gating.append( nn.Linear(in_channels+1, int(in_channels * 2)))
        self.gating.append(nn.PReLU())
        self.gating.append(nn.Linear( int(in_channels * 2),out_channels))
        self.gating.append(nn.Sigmoid())
        # The b_* lists below are not read anywhere in this class -
        # presumably configuration consumed by external code; verify against callers.
        self.b_3 = [2.5, 10, 20]
        self.b_3a = [2.5, 10, 20]
        self.b_3b = [2.5, 10, 20]
        self.b_4 = [2.5, 10, 20]
        self.b_1 = [1.25, 5, 10]
        self.b_1ix = [1.25, 5, 10]
        self.b_1s = [1.25, 5, 10]
        self.b_1sa = [1.25, 5, 10]

        self.b_1six = [1.25, 5, 10]
        self.b_1saix = [1.25, 5, 10]
        self.b_2 = [1.25, 5, 10]
        self.b_2ix = [1.25, 5, 10]
        self.b_11 = [1.25, 5, 10]
        self.b_22 = [1.25, 5, 10]
        self.b_1a = [1.25, 5, 10]
        self.b_2a= [1.25, 5, 10]
        self.b_5 = [1.25, 5, 10]
        self.b_6 = [1.25, 5, 10]
        self.b_7 = [2.5, 10, 20]
        self.b_8 = [1.25, 5, 10]
        self.b_9 = [1.25, 5, 10]
        self.b_10 = [1.25, 5, 10]
        # b_11 is assigned a second time here with the same value.
        self.b_11 = [1.25, 5, 10]
        self.corr_thresh = corr_thresh
        # Replaces the PReLU assigned to ``self.act`` above.
        self.act=nn.Sigmoid()
        self.act2=nn.ReLU()

    def forward(self, data: Data, z,z_shared, z_shared1,z_jump,extra_edge_index=None,extra_edge_attr=None,desired_vol_mat=None,post_extra_edge_index=None,post_extra_edge_attr=None,post_z_shared1=None,event_t=None,ninst2=None,return_FiLM=False,actual=0,notreal=False):
        """Generate ``z.shape[1]`` steps autoregressively.

        Args:
            data: Graph batch providing ``x``, ``edge_index``, ``edge_attr``
                (used as ``edge_type``) and ``batch``; ``x2`` is read when present.
            z: Per-row noise; column ``t`` and its square are prepended to the window.
            z_shared: Indexed as ``z_shared[:, :, t]`` and passed to the network as ``z3_t``.
            z_shared1: Indexed as ``z_shared1[:, t]`` and passed to the network as ``z2_t``.
            z_jump: Noise driving the jump component.
            extra_edge_index, extra_edge_attr: When the index is given, these
                replace the graph from ``data`` and no edges are recomputed.
            desired_vol_mat: When given, generated values are rescaled to the
                magnitudes sliced from it.
            post_extra_edge_index, post_extra_edge_attr, post_z_shared1:
                Replacements used for steps ``t >= event_t``.
            event_t: Step index at which the ``post_*`` inputs take over.
            ninst2: Upper bound of the slice taken on axis 1 of ``desired_vol_mat``.
            return_FiLM: Forwarded to the network; when true the per-step
                betas/gammas are collected and returned.
            actual: Not used in this method.
            notreal: When false, the input window is kept as ``x_real1`` and
                replayed for steps before ``event_t`` in the ``desired_vol_mat`` branch.

        Returns:
            ``(x_fake, betas1, gammas1)`` with ``x_fake`` reshaped to
            ``(z.size(0), z.shape[1], 1)``; the two lists are empty unless
            ``return_FiLM`` is true.
        """
        if return_FiLM:
            betas1=[]
            gammas1=[]
        # Fall back to a batch without ``x2`` when the first unpacking fails.
        # NOTE(review): the bare ``except`` catches every exception type.
        try:
            x,x2, edge_index,edge_type, batch = data.x,data.x2, data.edge_index,data.edge_attr, data.batch
        except:
            x,  edge_index, edge_type, batch = data.x,  data.edge_index, data.edge_attr, data.batch
        x_generated = []

        x=x[:,:self.p]
        if not notreal:
            x_real1=x
        # Per-row standard deviation (plus 1e-2) used to normalise values
        # before the jump test.
        div_jump=(x.std(dim=1).reshape(-1,1)+1e-2)

        # Keep only entries whose normalised magnitude is at least 5; the rest are zeroed.
        x_jumps=x/div_jump
        x_jumps[x_jumps.abs()<5]=0
        print('percent jump',torch.mean(sum(x_jumps!=0)/sum(x_jumps==0)).item())
        x=x.clamp(-1.1,1.1)
        if desired_vol_mat is not None:
            # Keep the last half (along the last axis size) of the flattened target rows.
            pp=int(desired_vol_mat.size(2)/2)
            desired_vol_mat1=desired_vol_mat[:,:ninst2,:].reshape(x.size(0),-1)[:,-pp:]

        if extra_edge_index is not None:
            edge_index1=extra_edge_index
            edge_type=extra_edge_attr
        else:
            edge_index1=edge_index.clone()
        for t in range(z.shape[1]):
            # Only the attention weights are used; ``context`` is discarded.
            context, attn_weights=self.tempattn(x)
            x = attn_weights * x
            z_t = z[:, t]
            z_t2 = z[:, t]**2
            z_j=z_jump[:,t]# Shape: (B * N, feature_dim)
            z_j2 = z_jump[:, t]**2
            # Select the shared-noise source: ``post_z_shared1`` from ``event_t`` on.
            if event_t is not None:
                if t>=event_t:
                    z2_t = post_z_shared1[:, t]  # Shape: (B * N, 1)
                    z2_t2 = post_z_shared1[:, t] ** 2
                else:
                    z2_t = z_shared1[:, t]  # Shape: (B * N, 1)
                    z2_t2 = z_shared1[:, t] ** 2


            else:
                z2_t = z_shared1[:, t]  # Shape: (B * N, 1)
                z2_t2 = z_shared1[:,  t]**2

            z3_t = z_shared[:,:, t]  # Shape: (B * N, 1)
            z3_t2 = z_shared[:,:, t] ** 2
            # Pair each shared-noise value with its square.
            z2_t=torch.cat([z2_t.unsqueeze(1), z2_t2.unsqueeze(1)], dim=-1)
            z3_t=torch.cat([z3_t.unsqueeze(1), z3_t2.unsqueeze(1)], dim=1)

            # Network input columns: [z_t ** 2, z_t, window].
            x_in = torch.cat([z_t.reshape(-1,1), x], dim=-1)
            x_in = torch.cat([z_t2.reshape(-1, 1), x_in], dim=-1)
            # Gate input columns: [z_j ** 2, z_j, jump window].
            x_in_j = torch.cat([z_j.reshape(-1, 1), x_jumps], dim=-1)
            x_in_j = torch.cat([z_j2.reshape(-1, 1), x_in_j], dim=-1)

            # Switch to the post-event graph from ``event_t`` on.
            if event_t is not None:
                if t>=event_t:
                    edge_index1=post_extra_edge_index
                    edge_type = post_extra_edge_attr
            if desired_vol_mat is not None:

                if not  notreal:
                    # NOTE(review): ``event_t`` defaults to None and is compared
                    # with ``t`` here without a None check - confirm callers always set it.
                    if t < event_t:
                        # Before the event, replay column ``t`` of the stored input window.
                        x_gen = x_real1[:,  t]

                    else:
                        # NOTE(review): ``return_FiLM`` is forwarded but the result is
                        # treated as a single tensor in this branch - confirm.
                        x_gen1 = self.network(x_in,z2_t,z3_t, edge_index1,edge_type,return_FiLM)
                        size_adj=desired_vol_mat1[:,t].reshape(x_gen1.size())
                        # |size_adj / x_gen1| * x_gen1: target magnitude with the sign of x_gen1.
                        adj=torch.abs(size_adj/x_gen1)*x_gen1
                        mask=adj!=0
                        # Use the rescaled value where it is non-zero, else the raw output.
                        x_gen = torch.where(mask, adj.float(), x_gen1)
                        for layer in self.gating:
                            x_in_j = layer(x_in_j)
                        x_gen_jump=x_in_j
                        # probs = logits)
                        # Forward value is the hard 0/1 gate; the gradient flows
                        # through the soft gate output.
                        hard_gate = (x_gen_jump > 0.99).float()
                        gate = hard_gate + (x_gen_jump - x_gen_jump.detach())
                        x_jump_gen=x_gen * gate * z_j.reshape(-1, 1)
                        x_gen=x_gen+x_jump_gen
                        # x_gen[mask]=adj[mask].float()
                else:
                    # Same rescale-and-gate sequence as above, applied at every step.
                    x_gen1 = self.network(x_in, z2_t, z3_t, edge_index1, edge_type, return_FiLM)
                    size_adj = desired_vol_mat1[:, t].reshape(x_gen1.size())
                    adj = torch.abs(size_adj / x_gen1) * x_gen1
                    mask = adj != 0
                    x_gen = torch.where(mask, adj.float(), x_gen1)
                    for layer in self.gating:
                        x_in_j = layer(x_in_j)
                    x_gen_jump = x_in_j
                    # probs = logits)
                    hard_gate = (x_gen_jump > 0.99).float()
                    gate = hard_gate + (x_gen_jump - x_gen_jump.detach())
                    x_jump_gen = x_gen * gate * z_j.reshape(-1, 1)
                    x_gen = x_gen + x_jump_gen

            else:
                # Unconstrained generation; optionally collect FiLM statistics.
                if return_FiLM:
                    x_gen,betas,gammas= self.network(x_in,z2_t.to(x_in.device),z3_t, edge_index1, edge_type,return_FiLM)
                    betas1.append(betas)
                    gammas1.append(gammas)
                else:
                    x_gen = self.network(x_in,z2_t.to(x_in.device),z3_t, edge_index1, edge_type,return_FiLM)
                for layer in self.gating:
                    x_in_j = layer(x_in_j)
                x_gen_jump=x_in_j
                # probs = logits)
                # This branch uses a stricter gate threshold (0.999 instead of 0.99).
                hard_gate = (x_gen_jump > 0.999).float()
                gate = hard_gate + (x_gen_jump - x_gen_jump.detach())
                x_jump_gen=x_gen * gate * z_j.reshape(-1, 1)
                x_gen=x_gen+x_jump_gen

            # Progress output every 10th step.
            if t%10==0:
                print('r_pcnt',x_gen.abs().mean().item())
            x_gen = x_gen.clamp(-1.1, 1.1)
            # Debug markers: a magnitude above 1 after clamping, or a NaN.
            if x_gen.abs().max()>1:
                print('a')
            if torch.sum(torch.isnan(x_gen)).item()>0:
                print('a')
            # Slide both windows by one column.
            x = torch.cat([x[:, 1:], x_gen.reshape(-1,1)], dim=1)
            x_jump_gen=x_gen.reshape(-1,1)/div_jump
            x_jump_gen[x_jump_gen.abs() < 5] = 0
            x_jumps = torch.cat([x_jumps[:, 1:], x_jump_gen], dim=1)
            x_generated.append(x_gen.reshape(-1,1))
            # Without externally supplied edges, rebuild them from the new window;
            # retry with threshold .1 when there are fewer edges than rows.
            # NOTE(review): ``edge_type`` is not recomputed for the new edges - confirm.
            if extra_edge_index is  None:
                edge_index1 = get_edge_index(x,self.corr_thresh)
                if edge_index1.size(1)<x.size(0):
                    edge_index1 = get_edge_index(x,.1)
        print('percent jump generated',torch.mean(sum(x_jumps!=0)/sum(x_jumps==0)).item())

        x_fake = torch.stack(x_generated, dim=1)
        if return_FiLM:
            return x_fake.reshape(z.size(0), len(x_generated), 1),betas1,gammas1
        else:
            return x_fake.reshape(z.size(0),len(x_generated),1),[],[]





class GCNGMMN(nn.Module):
    """Autoregressive generator that applies a ``GraphResFNN`` on a fixed graph.

    Each step prepends one noise column to the sliding window, runs the
    network and appends its output to the window.
    """

    def __init__(self, input_dim1: int,hidden_channels:int,num_layers:int, output_dim: int,dropout:float):
        super().__init__()
        backbone = GraphResFNN(input_dim1, hidden_channels, num_layers, output_dim, dropout)
        self.network = backbone.to('cuda')
        self.act = nn.PReLU()
        # Not read inside this class.
        self.b_3, self.b_4 = [2.5, 10, 20], [2.5, 10, 20]
        self.b_1, self.b_2 = [1.25, 5, 10], [1.25, 5, 10]

    def forward(self, data: Data, z, z2,bs=2048):
        """Generate ``z.shape[1]`` steps; ``z2`` and ``bs`` are accepted but unused.

        Returns a tensor of shape ``(len(data.batch), z.shape[1], 1)``.
        """
        window, base_edges, _, batch = data.x, data.edge_index, data.edge_attr, data.batch
        n_steps = z.shape[1]
        window = window[:, :n_steps]
        edges = base_edges.clone()

        outputs = []
        for step in range(n_steps):
            noise_col = z[:, step].reshape(-1, 1)
            new_col = self.network(torch.cat([noise_col, window], dim=-1), edges)
            # Drop the oldest column, append the generated one.
            window = torch.cat([window[:, 1:], new_col], dim=1)
            outputs.append(new_col)

        stacked = torch.stack(outputs, dim=1)
        return stacked.reshape((len(batch), len(outputs), 1))






class GATGMMN(nn.Module):
    """Autoregressive generator using ``GraphResAttention`` on a graph rebuilt each step."""

    def __init__(self, input_dim1: int, output_dim: int,nheads:int,dropout:float):
        """
        Args:
            input_dim1: Input width of the attention network.
            output_dim: Output width of the attention network.
            nheads: Number of attention heads.
            dropout: Dropout probability forwarded to the attention network.
        """
        super(GATGMMN, self).__init__()
        # NOTE(review): the network is moved to CUDA unconditionally.
        self.network = GraphResAttention(input_dim1, output_dim, nheads, dropout).to('cuda')
        self.act = nn.PReLU()
        # Not read inside this class.
        self.b_3 = [2.5, 10, 20]
        self.b_1 = [1.25, 5, 10]
        self.b_2 = [1.25, 5, 10]
        self.b_4 = [2.5, 10, 20]

    def forward(self, data: Data, z, z2,bs=2048):
        """Generate ``z.shape[1]`` steps.

        Args:
            data: Graph batch providing ``x``, ``edge_index``, ``edge_attr``
                and ``batch``.
            z: Noise tensor; column ``t`` is prepended to the window at step ``t``.
            z2: Accepted for interface compatibility; no longer read, since the
                tensors derived from it never reached the network.
            bs: Passed to ``get_edge_index`` as the lower correlation threshold.
                NOTE(review): ``batch_corr`` clamps correlations to at most
                2.0, so the default 2048 selects no edges after the first
                step - confirm the intended default.

        Returns:
            Tensor of shape ``(len(data.batch), z.shape[1], 1)``.
        """
        x, edge_index, edge_features, batch = data.x, data.edge_index, data.edge_attr, data.batch
        n_steps = z.shape[1]
        x = x[:, :n_steps]
        edge_index1 = edge_index.clone()
        x_generated = []
        for t in range(n_steps):
            x_in = torch.cat([z[:, t].reshape(-1, 1), x], dim=-1)
            x_gen = self.network(x_in, edge_index1, edge_features)
            # Slide the window: drop the oldest column, append the new one.
            x = torch.cat([x[:, 1:], x_gen], dim=1)
            x_generated.append(x_gen)
            # Rebuild the graph from the updated window.
            edge_index1 = get_edge_index(x, bs)
            edge_features = get_edge_features(x, edge_index1)
        x_fake = torch.stack(x_generated, dim=1)
        return x_fake.reshape((len(batch), len(x_generated), 1))

class GraphArFNNAttentioncorr(nn.Module):
    """Autoregressive generator summing two ``GraphResAttention`` networks.

    ``network`` sees the signed window, ``network_std`` sees its absolute
    values; both run on the fixed graph taken from ``data``.
    """

    def __init__(self, input_dim1: int, output_dim: int, nheads: int, dropout: float = 0.0):
        """
        Args:
            input_dim1: Input width of both attention networks.
            output_dim: Output width of both attention networks.
            nheads: Number of attention heads.
            dropout: Dropout probability for both networks. Added with a
                default because ``GraphResAttention`` requires it; the former
                three-argument construction raised ``TypeError``.
        """
        super(GraphArFNNAttentioncorr, self).__init__()
        # NOTE(review): both networks are moved to CUDA unconditionally.
        self.network = GraphResAttention(input_dim1, output_dim, nheads, dropout).to('cuda')
        self.network_std = GraphResAttention(input_dim1, output_dim, nheads, dropout).to('cuda')
        self.act = nn.PReLU()
        self.linear2 = nn.Linear(2000, 2000)  # not used in forward
        # Not read inside this class.
        self.b_3 = [2.5, 10, 20]
        self.b_1 = [1.25, 5, 10]
        self.b_2 = [1.25, 5, 10]

    def forward(self, data: Data, z, z2, bs=2048):
        """Generate ``z.shape[1]`` steps; ``bs`` is accepted but unused.

        ``GraphResAttention.forward`` requires edge features, so
        ``data.edge_attr`` is passed to both networks (the previous
        two-argument calls raised ``TypeError``).

        Returns:
            Tensor of shape ``(len(data.batch), z.shape[1], 1)``.
        """
        x, edge_index, edge_features, batch = data.x, data.edge_index, data.edge_attr, data.batch
        n_steps = z.shape[1]
        x = x[:, :n_steps]
        edge_index1 = edge_index.clone()
        x_generated = []
        for t in range(n_steps):
            z_t = z[:, t]
            z2_t = z2[:, t]
            x_in = torch.cat([z_t.reshape(-1, 1), x], dim=-1)
            x_in_abs = torch.cat([z2_t.reshape(-1, 1), torch.abs(x)], dim=-1)
            x_gen = self.network(x_in, edge_index1, edge_features)
            x_gen = x_gen + self.network_std(x_in_abs, edge_index1, edge_features)
            # Slide the window: drop the oldest column, append the new one.
            x = torch.cat([x[:, 1:], x_gen], dim=1)
            x_generated.append(x_gen)
        x_fake = torch.stack(x_generated, dim=1)
        return x_fake.reshape((len(batch), len(x_generated), 1))
class GraphArFNNAbs(nn.Module):
    """Graph encoder followed by an autoregressive ``ResFNN`` roll-out."""

    def __init__(self, input_dim1: int,input_dim_lin:int, output_dim: int,output_dim_lin:int,hidden_dims_lin:Tuple[int],
                 hidden_channels: Optional[int] = None, num_layers: int = 2, dropout: float = 0.0):
        """
        Args:
            input_dim1: Input width of the graph encoder.
            input_dim_lin: Input width of the feed-forward network.
            output_dim: Output width of the graph encoder.
            output_dim_lin: Output width of the feed-forward network.
            hidden_dims_lin: Hidden widths of the feed-forward network.
            hidden_channels: Hidden width of the graph encoder; defaults to
                ``output_dim``.
            num_layers: Depth argument of the graph encoder.
            dropout: Dropout probability of the graph encoder.

        The last three arguments are new and defaulted: ``GraphResFNN``
        requires them, and the former two-argument construction raised
        ``TypeError``.
        """
        super(GraphArFNNAbs, self).__init__()
        if hidden_channels is None:
            hidden_channels = output_dim
        # NOTE(review): both sub-networks are moved to CUDA unconditionally.
        self.network = GraphResFNN(input_dim1, hidden_channels, num_layers, output_dim, dropout).to('cuda')
        self.network_lin = ResFNN(input_dim_lin, output_dim_lin, hidden_dims_lin).to('cuda')

    def forward(self, data: Data, z, z2,bs):
        """Encode the window once, then roll it forward ``z.shape[1]`` steps.

        Args:
            data: Graph batch providing ``x``, ``edge_index``, ``edge_attr``
                and ``batch``.
            z: Noise tensor; column ``t`` is reshaped to ``(bs, 1, -1)`` and
                prepended to the flattened window.
            z2: Accepted for interface compatibility; no longer read, since the
                tensors derived from it were never used.
            bs: Leading dimension used when reshaping the encoded window.

        Returns:
            Stack (along dim 1) of the newest window rows, each flattened to a
            single column.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        n_steps = z.shape[1]

        x = x[:, :n_steps]
        x_past = self.network(x, edge_index, edge_attr)
        x_past = x_past.reshape(bs, n_steps, -1)

        x_generated = []
        for t in range(n_steps):
            x_in = torch.cat([z[:, t].reshape(bs, 1, -1), x_past.reshape(bs, 1, -1)], dim=-1)
            x_gen = self.network_lin(x_in)
            # Slide the window along axis 1.
            x_past = torch.cat([x_past[:, 1:], x_gen], dim=1)
            x_generated.append(x_past[:, -1, :].reshape(-1, 1))
        x_fake = torch.stack(x_generated, dim=1)

        return x_fake




class GraphResFNNCorr(nn.Module):
    """Scores a graph batch by combining two MLP paths and two GCN paths.

    One MLP and one GCN stack see the raw features, the other pair sees their
    absolute values. The GCN outputs are summed, mean-pooled per graph and
    projected to one value, to which both MLP outputs are added.
    """

    def __init__(self, input_dim: int,input_dim_lin: int, output_dim: int,output_dim_lin: int,hidden_dims:Tuple[int],hidden_dims_lin:Tuple[int],
                 hidden_channels: Optional[int] = None, num_layers: int = 2, dropout: float = 0.0):
        """
        Args:
            input_dim: Node-feature width of the GCN paths.
            input_dim_lin: Input width of the MLP paths.
            output_dim: Output width of the GCN paths.
            output_dim_lin: Output width of the MLP paths.
            hidden_dims: Accepted but not used.
            hidden_dims_lin: Hidden widths of the MLP paths.
            hidden_channels: Hidden width of the GCN paths; defaults to
                ``output_dim``.
            num_layers: Depth argument of the GCN paths.
            dropout: Dropout probability of the GCN paths.

        The last three arguments are new and defaulted: ``GraphResidualBlock``
        requires them, and the former two-argument construction raised
        ``TypeError``.
        """
        super(GraphResFNNCorr, self).__init__()
        self.input_dim = input_dim
        if hidden_channels is None:
            hidden_channels = output_dim

        self.network_lin = self._build_mlp(int(input_dim_lin), hidden_dims_lin, output_dim_lin)
        self.networkstd_lin = self._build_mlp(int(input_dim_lin), hidden_dims_lin, output_dim_lin)

        self.network1 = nn.ModuleList()
        self.network2 = nn.ModuleList()
        self.network1.append(GraphResidualBlock(input_dim, hidden_channels, num_layers, output_dim, dropout))
        self.network2.append(GraphResidualBlock(input_dim, hidden_channels, num_layers, output_dim, dropout))

        self.fc = nn.Linear(output_dim, 1)

    @staticmethod
    def _build_mlp(in_width, hidden_widths, out_width):
        """Return ``ResidualBlock`` layers for ``hidden_widths`` followed by a Linear."""
        layers = []
        for width in hidden_widths:
            layers.append(ResidualBlock(in_width, width))
            in_width = width
        layers.append(nn.Linear(in_width, out_width))
        return nn.Sequential(*layers)

    def forward(self, x, edge_index,edge_attr, batch,bs):
        """Return one score per graph.

        ``edge_attr`` is accepted for interface compatibility; the GCN blocks
        take only ``(x, edge_index)``, so it is not forwarded (passing it
        raised ``TypeError``).
        """
        linx = x.reshape(bs, 1, -1)
        out = self.network_lin(linx)
        out2 = self.networkstd_lin(torch.abs(linx))

        x_abs = torch.abs(x)
        for layer in self.network1:
            x = layer(x, edge_index)
        for layer in self.network2:
            x_abs = layer(x_abs, edge_index)
        x_combined = x + x_abs

        x_pooled = global_mean_pool(x_combined, batch)
        output = self.fc(x_pooled)
        return output + out.squeeze(1) + out2.squeeze(1)
