from torch import nn
import torch
from models.gcl import GCL, GCL_rf, E_GCL, E_GCL_vel

from torch_geometric.nn import GCNConv, GINConv, global_add_pool, global_mean_pool, global_max_pool
# Make float32 the default floating-point dtype for tensors created after import.
# (The former extra call to torch.set_default_tensor_type(torch.FloatTensor) was
# redundant with this one and is deprecated since PyTorch 2.1, so it was removed.)
torch.set_default_dtype(torch.float32)

# Width of the graph-level output produced by MC_AE_EGNN's final linear layer.
OUTPUT_DIM = 16
class AE_parent(nn.Module):
    """Base class for the graph autoencoders in this module.

    Subclasses implement ``encode`` (and usually ``decode``). This class provides
    a shared distance-based adjacency decoder, ``decode_from_x``. ``forward``
    only runs the encoder and returns its output.
    """

    def __init__(self):
        super(AE_parent, self).__init__()

    def encode(self, nodes, edges, edge_attr):
        """Encode a graph into embeddings. Implemented by subclasses (returns None here)."""
        pass

    def decode(self, x):
        """Decode embeddings into a predicted adjacency. Implemented by subclasses (returns None here)."""
        pass

    def decode_from_x(self, x, linear_layer=None, C=10, b=-1, remove_diagonal=True):
        """Predict a dense adjacency matrix from pairwise squared differences of ``x``.

        Args:
            x: node embeddings; ``x.size(0)`` is the number of nodes.
            linear_layer: optional module mapping each pair's element-wise squared
                difference vector to a single logit. When None, the logit is
                ``C * sum(squared differences) + b``.
            C: scale applied when ``linear_layer`` is None (number or tensor).
            b: bias applied when ``linear_layer`` is None (number or tensor).
            remove_diagonal: if True, zero the diagonal (self-loop) entries.

        Returns:
            ``(n_nodes, n_nodes)`` tensor of sigmoid edge probabilities.
        """
        n_nodes = x.size(0)
        # diff[i, j] = (x[j] - x[i]) ** 2, element-wise over the trailing dims.
        diff = (x.unsqueeze(0) - x.unsqueeze(1)) ** 2
        diff = diff.view(n_nodes ** 2, -1)
        if linear_layer is not None:
            logits = linear_layer(diff)
        else:
            logits = C * torch.sum(diff, dim=1) + b

        adj_pred = torch.sigmoid(logits).view(n_nodes, n_nodes)
        if remove_diagonal:
            # Build the mask where the prediction lives instead of relying on a
            # ``self.device`` attribute that this base class never defines.
            eye = torch.eye(n_nodes, device=adj_pred.device, dtype=adj_pred.dtype)
            adj_pred = adj_pred * (1 - eye)
        return adj_pred

    def forward(self, data,  edge_attr=None):
        """Run only the encoder on ``data`` and return its output (no decoding).

        NOTE(review): this calls ``encode(data, edge_attr)``, which matches
        ``MC_AE_EGNN.encode``; subclasses whose ``encode`` takes
        ``(nodes, edges, edge_attr)`` need a different call — confirm usage.
        """
        x = self.encode(data, edge_attr)
        return x


class AE(AE_parent):
    """Graph autoencoder whose encoder is a stack of ``GCL`` layers.

    Encoder: ``n_layers`` GCL layers (each given ``edges_in_nf=1``) followed by a
    linear projection to ``embedding_nf`` features. Decoder: ``decode_from_x``,
    optionally scored by a learnable ``nn.Linear(embedding_nf, 1)``.
    """

    def __init__(self, hidden_nf, embedding_nf=32, noise_dim=1, device='cpu', act_fn=nn.SiLU(), learnable_dec=1, n_layers=4, attention=0):
        super(AE, self).__init__()
        self.hidden_nf = hidden_nf
        self.embedding_nf = embedding_nf
        self.noise_dim = noise_dim
        self.device = device
        self.n_layers = n_layers

        # --- Encoder ---
        # The first layer reads noise_dim features, or 1 feature when noise is disabled.
        first_in_nf = max(1, noise_dim)
        self.add_module("gcl_0", GCL(first_in_nf, hidden_nf, hidden_nf, edges_in_nf=1, act_fn=act_fn, attention=attention, recurrent=False))
        for idx in range(1, n_layers):
            self.add_module(f"gcl_{idx}", GCL(hidden_nf, hidden_nf, hidden_nf, edges_in_nf=1, act_fn=act_fn, attention=attention))
        self.fc_emb = nn.Linear(hidden_nf, embedding_nf)

        # --- Decoder ---
        # Learnable per-pair scorer, or None to use decode_from_x's fixed C/b.
        self.fc_dec = nn.Linear(embedding_nf, 1) if learnable_dec else None
        self.to(device)

    def decode(self, x):
        """Turn embeddings ``x`` into a predicted adjacency via ``decode_from_x``."""
        return self.decode_from_x(x, linear_layer=self.fc_dec)

    def encode(self, nodes, edges, edge_attr=None):
        """Run the GCL stack and project the result to ``embedding_nf`` features.

        When ``noise_dim`` is non-zero, the node features are ignored and replaced
        by fresh Gaussian noise on every call; only ``nodes.size(0)`` is used.
        """
        if self.noise_dim:
            nodes = torch.randn(nodes.size(0), self.noise_dim).to(self.device)
        h = nodes
        # gcl_0 is always constructed, so it runs even when n_layers < 1.
        for idx in range(max(1, self.n_layers)):
            h, _ = self._modules[f"gcl_{idx}"](h, edges, edge_attr=edge_attr)
        return self.fc_emb(h)


class AE_rf(AE_parent):
    """Graph autoencoder whose encoder is a stack of ``GCL_rf`` layers.

    Node embeddings start as Gaussian noise of width ``embedding_nf`` and are
    refined by ``n_layers`` GCL_rf layers. The decoder is ``decode_from_x`` with
    a learnable scalar scale ``w`` and bias ``b``.
    """

    def __init__(self, embedding_nf=32, nf=64, device='cpu', n_layers=4, act_fn=nn.SiLU(), reg=1e-3, clamp=False):
        super(AE_rf, self).__init__()
        self.embedding_nf = embedding_nf
        self.device = device
        self.n_layers = n_layers

        ### Encoder
        # NOTE(review): ``self.gcl`` is not used by ``encode``. It is kept so the
        # module's parameters / state_dict keys stay unchanged.
        self.gcl = GCL_rf(nf, reg=reg)
        for i in range(n_layers):
            self.add_module("gcl_%d" % i, GCL_rf(nf, act_fn=act_fn, reg=reg, edge_attr_nf=1, clamp=clamp))

        ### Decoder
        # Create the parameters in place and let ``self.to`` move them. Calling
        # ``.to(device)`` on an nn.Parameter returns a plain, non-leaf tensor when
        # the device changes. That tensor would not be registered, so w/b would
        # never be trained.
        self.w = nn.Parameter(-0.1 * torch.ones(1))
        self.b = nn.Parameter(torch.ones(1))
        self.to(self.device)

    def decode(self, x):
        """Predict an adjacency from embeddings using the learned scale/bias."""
        return self.decode_from_x(x, C=self.w, b=self.b)

    def encode(self, nodes, edges, edge_attr=None):
        """Refine random node embeddings with the GCL_rf stack.

        Only ``nodes.size(0)`` is read; the starting embeddings are Gaussian noise.
        """
        x = torch.randn(nodes.size(0), self.embedding_nf).to(self.device)
        for i in range(self.n_layers):
            x, _ = self._modules["gcl_%d" % i](x, edges, edge_attr=edge_attr)
        return x



class AE_EGNN(AE_parent):
    """Graph autoencoder whose encoder is a stack of ``E_GCL`` layers.

    Coordinates of width ``K`` are initialised from Gaussian noise and updated by
    each layer; the final coordinates are the embedding. The decoder is
    ``decode_from_x`` with a learnable scalar scale ``w`` and bias ``b``.
    """

    def __init__(self, hidden_nf, K=8, device='cpu', act_fn=nn.SiLU(), n_layers=4, reg = 1e-3, clamp=False):
        super(AE_EGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.K = K
        self.device = device
        self.n_layers = n_layers
        self.reg = reg  # per-layer coordinate shrink factor used in encode

        ### Encoder
        # First layer takes 1 input node feature and is non-recurrent; the rest are recurrent.
        self.add_module("gcl_0", E_GCL(1, self.hidden_nf, self.hidden_nf, edges_in_d=1, act_fn=act_fn, recurrent=False, clamp=clamp))
        for i in range(1, n_layers):
            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=1, act_fn=act_fn, recurrent=True, clamp=clamp))

        ### Decoder
        # Create the parameters in place and let ``self.to`` move them.
        # ``nn.Parameter(...).to(device)`` yields an unregistered plain tensor
        # whenever the device changes, so w/b would silently never be trained.
        self.w = nn.Parameter(-0.1 * torch.ones(1))
        self.b = nn.Parameter(torch.ones(1))
        self.to(self.device)

    def decode(self, x):
        """Predict an adjacency from embeddings using the learned scale/bias."""
        return self.decode_from_x(x, C=self.w, b=self.b)

    def encode(self, h, edges, edge_attr=None):
        """Run the E_GCL stack on random initial coordinates and return the coordinates.

        Args:
            h: node features fed to the first layer; ``h.size(0)`` sets the node count.
            edges: edge structure passed to each layer.
            edge_attr: optional edge features passed to each layer.
        """
        coords = torch.randn(h.size(0), self.K).to(self.device)
        for i in range(self.n_layers):
            h, coords, _ = self._modules["gcl_%d" % i](h, edges, coords, edge_attr=edge_attr)
            # Shrink by a factor (1 - reg). Done out of place so the tensor
            # returned by the layer is never mutated (safe for autograd).
            coords = coords - self.reg * coords
        return coords

class MC_AE_EGNN(AE_parent):
    """Graph-level encoder built from a stack of ``E_GCL_vel`` layers.

    ``encode`` embeds per-node features ``[|data.pos|, data.x]`` and runs them
    through ``n_layers`` E_GCL_vel layers. It then sum-pools the node states per
    graph (``global_add_pool`` over ``data.batch``) and projects the result to
    ``OUTPUT_DIM`` features.

    NOTE: the constructor calls ``torch.set_default_dtype(dtype)``. This changes
    the default dtype for the whole process, not only for this model.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_edge_nf, hidden_node_nf, hidden_coord_nf,
                 device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0, recurrent=False,
                 norm_diff=False, tanh=False, num_vectors=1, reg=1e-3, K=8,
                 update_vel=False, dtype=torch.float32, with_coords=True):
        # Set the (global) default dtype before any submodule is constructed.
        self.dtype = dtype
        torch.set_default_dtype(self.dtype)

        super(MC_AE_EGNN, self).__init__()
        self.hidden_edge_nf = hidden_edge_nf
        self.hidden_node_nf = hidden_node_nf
        self.hidden_coord_nf = hidden_coord_nf
        self.device = device
        self.n_layers = n_layers
        self.update_vel = update_vel  # stored only; not read by this class
        self.with_coords = with_coords
        self.reg = reg  # per-layer coordinate shrink factor used in encode
        self.K = K

        # Input embedding. Expects the width of [|pos|, x] to be in_node_nf + 1
        # (TODO confirm which part contributes the +1).
        self.embedding = nn.Linear(in_node_nf + 1, self.hidden_node_nf).to(dtype=self.dtype)
        self.relu = nn.LeakyReLU(0.1)
        self.linear = nn.Linear(self.hidden_node_nf, OUTPUT_DIM).to(dtype=self.dtype)

        # Arguments shared by every E_GCL_vel layer. Only the vector in/out counts
        # and the last_layer flag differ between first, middle and last layers.
        layer_dims = (self.hidden_node_nf, self.hidden_node_nf, self.hidden_edge_nf,
                      self.hidden_node_nf, self.hidden_coord_nf)
        common = dict(edges_in_d=2 * in_edge_nf, act_fn=act_fn, coords_weight=coords_weight,
                      recurrent=recurrent, norm_diff=norm_diff, tanh=tanh, k=self.K)

        self.add_module("gcl_%d" % 0, E_GCL_vel(*layer_dims, num_vectors_out=num_vectors,
                                                **common).to(dtype=self.dtype))
        for i in range(1, n_layers - 1):
            self.add_module("gcl_%d" % i, E_GCL_vel(*layer_dims, num_vectors_in=num_vectors,
                                                    num_vectors_out=num_vectors,
                                                    **common).to(dtype=self.dtype))
        # With n_layers == 1 this replaces gcl_0, so the single layer is the last layer.
        self.add_module("gcl_%d" % (n_layers - 1), E_GCL_vel(*layer_dims, num_vectors_in=num_vectors,
                                                             last_layer=True,
                                                             **common).to(dtype=self.dtype))

        # Decoder scale / bias used by decode().
        self.w = nn.Parameter(-0.1 * torch.ones(1, dtype=self.dtype))
        self.b = nn.Parameter(torch.ones(1, dtype=self.dtype))

        self.to(self.device)

    def decode(self, x):
        """Predict an adjacency from embeddings using the learned scale/bias."""
        return self.decode_from_x(x, C=self.w, b=self.b)

    def encode(self, data, edge_attr=None):
        """Encode a batch of graphs into one ``OUTPUT_DIM``-wide vector per graph.

        Args:
            data: batched graph object. Reads ``edge_index``, ``pos``, ``x``,
                ``edge_attr`` and ``batch``. Nothing here moves it to
                ``self.device``, so it presumably already lives there.
            edge_attr: optional edge features overriding ``data.edge_attr``.
                When None (the default), ``data.edge_attr`` is used.

        Returns:
            Tensor with one row per graph in ``data.batch`` and ``OUTPUT_DIM`` columns.
        """
        edges = data.edge_index.to(dtype=torch.long)
        # abs() removes the sign of the positional features ("almost invariant
        # ambiguity" per the original note — presumably a sign ambiguity).
        h = torch.cat([torch.abs(data.pos.to(dtype=self.dtype)), data.x], dim=-1)
        eigvals = None  # eigenvalue input to the layers is currently disabled
        # Previously the argument was silently overwritten; honour it when given.
        if edge_attr is None:
            edge_attr = data.edge_attr
        edge_attr = edge_attr.to(dtype=self.dtype)
        coords = torch.zeros_like(data.x.to(dtype=self.dtype))

        h = self.relu(self.embedding(h))
        for i in range(self.n_layers):
            h, new_coords, _ = self._modules["gcl_%d" % i](h, edges, coords, edge_attr=edge_attr,
                                                           eigvals=eigvals, batch=data.batch)
            if self.with_coords:
                # Shrink by (1 - reg). Done out of place so the layer's output
                # tensor is never mutated (safe for autograd).
                coords = new_coords - self.reg * new_coords

        # Sum node states within each graph, then project.
        h = global_add_pool(h, data.batch)
        return self.linear(h)

class Baseline(nn.Module):
    """Trivial baseline that always predicts an empty graph.

    ``forward`` returns an all-zero ``(n_nodes, n_nodes)`` adjacency multiplied by
    a single parameter ``dummy``, plus a vector of ones of length ``n_nodes``.
    """

    def __init__(self, device='cpu'):
        super(Baseline, self).__init__()
        # Presumably exists so the zero prediction is attached to a trainable parameter.
        self.dummy = nn.Parameter(torch.ones(1))
        self.device = device
        self.to(device)

    def forward(self, nodes, b, c):
        """Return ``(zeros(n, n) * dummy, ones(n))`` where ``n = nodes.size(0)``.

        ``b`` and ``c`` are ignored.
        NOTE(review): the second output is created on the default device rather
        than ``self.device`` — confirm whether callers rely on that.
        """
        n_nodes = nodes.size(0)
        adj = torch.zeros(n_nodes, n_nodes, device=self.device) * self.dummy
        return adj, torch.ones(n_nodes)


def normalizer(x):
    """Center ``x`` by subtracting its mean along dimension 0.

    Returns a new tensor of the same shape; the input is left untouched.
    """
    column_mean = x.mean(dim=0, keepdim=True)
    return x - column_mean