import torch
import libs.torchsde as torchsde
import torch.nn as nn
import torch.nn.functional as F
import torchdiffeq

from models.backbone import *


# used for decoding embedding to classes
class decoder_MLP(nn.Module):
    """Residual two-layer MLP head mapping hidden embeddings to ``c`` outputs.

    Dropout with rate ``cfg["dropout"]`` (active only in training mode) is
    applied to the input, to the residual sum, and to the final output.
    """

    def __init__(self, c, cfg):
        super().__init__()
        self.cfg = cfg
        width = cfg["hidden_channels"]
        self.m21 = nn.Linear(width, width)
        self.m22 = nn.Linear(width, c)

    def forward(self, x):
        rate = self.cfg["dropout"]
        train = self.training
        hidden = F.dropout(x, rate, training=train)
        # Residual update: h <- h + W21 tanh(h)
        hidden = F.dropout(hidden + self.m21(torch.tanh(hidden)), rate, training=train)
        projected = self.m22(torch.tanh(hidden))
        return F.dropout(projected, rate, training=train)


class decoder_MLP_simple(nn.Module):
    """Single linear read-out applied to tanh-squashed hidden embeddings."""

    def __init__(self, c, cfg):
        super().__init__()
        self.cfg = cfg
        self.m22 = nn.Linear(cfg["hidden_channels"], c)

    def forward(self, x):
        squashed = torch.tanh(x)
        return self.m22(squashed)

# used for encoding feature to embedding
class encoder_MLP(nn.Module):
    """Linear projection from ``d`` input features to ``cfg["hidden_channels"]``.

    The input is first moved to ``cfg["device"]`` and passed through dropout
    with rate ``cfg["input_dropout"]`` (active only in training mode).
    """

    def __init__(self, d, cfg):
        super().__init__()
        self.cfg = cfg
        self.m11 = nn.Linear(d, cfg["hidden_channels"])

    def forward(self, x):
        settings = self.cfg
        on_device = x.to(settings["device"])
        dropped = F.dropout(on_device, settings["input_dropout"], training=self.training)
        return self.m11(dropped)



class SGNN(torchsde.SDEIto):
    """Graph neural SDE classifier solved with the project's torchsde wrapper.

    Drift ``f_net`` is ``Drift(y, edge_index) - y``. Diffusion ``g_net`` is
    ``y - SFN(y, edge_index)``, where the SFN output is first left-multiplied
    by ``inverse_kernel`` when that attribute is set (it is not assigned in
    this class). The graph is the IND or OOD edge index, selected by
    ``ind_flag``; both edge indices must be assigned externally before
    ``forward`` runs.
    """

    def __init__(self, d, c, cfg, covariance_matrix=None):
        super(SGNN, self).__init__(noise_type="diagonal")
        self.input_encoder = encoder_MLP(d, cfg)
        self.bnin = nn.BatchNorm1d(cfg["hidden_channels"])
        self.f_encoder = Drift(cfg["hidden_channels"], cfg["hidden_channels"], cfg["hidden_channels"],
                               num_layers=1, dropout=cfg["dropout"], use_bn=cfg["use_bn"], type=cfg["encoder_gnn_type"])
        # SFN includes the graph information to model the dependency of noises.
        self.g_encoder = SFN(cfg["hidden_channels"], cfg["hidden_channels"],
                             cfg["hidden_channels"], num_layers=1,
                             dropout=cfg["dropout"], use_bn=cfg["use_bn"], type=cfg["encoder_gnn_type"])

        # Not read by any method in this class, but still a registered
        # submodule (part of state_dict), so it is kept.
        self.bng = nn.BatchNorm1d(cfg["hidden_channels"])

        self.output_decoder = decoder_MLP_simple(c, cfg)

        self.cfg = cfg
        self.time = self.cfg["time"]
        self.use_cholesky = False  # Always set to False as requested

        self.N = self.cfg["N"]
        self.ts = torch.tensor([0, self.time])
        self.device = self.cfg["device"]

        # Graph state; the edge indices are expected to be assigned by the caller.
        self.ind_flag = True
        self.ind_edge_index = None
        self.ood_edge_index = None
        self.ind_edge_weight = None
        self.ood_edge_weight = None
        self.c_size = c
        self.sdeint_fn = torchsde.sdeint_adjoint if self.cfg["adjoint"] else torchsde.sdeint

        # Learnable kernels
        self.learnable_matern_kernel = cfg.get("learnable_matern_kernel", False)
        self.learnable_diffusion_kernel = cfg.get("learnable_diffusion_kernel", False)

        if self.learnable_matern_kernel:
            from models.matern_kernel import Matern_Kernel_Module
            self.kernel_module = Matern_Kernel_Module(hidden_dim=cfg["hidden_channels"])
            self.covariance_matrix = None  # Will be computed during forward pass
        elif self.learnable_diffusion_kernel:
            from models.matern_kernel import Diffusion_Kernel_Module
            self.kernel_module = Diffusion_Kernel_Module(hidden_dim=cfg["hidden_channels"])
            self.covariance_matrix = None  # Will be computed during forward pass
        else:
            # Fixed covariance matrix (optional); converted to a float32 tensor.
            self.covariance_matrix = covariance_matrix
            if covariance_matrix is not None and not isinstance(covariance_matrix, torch.Tensor):
                self.covariance_matrix = torch.tensor(covariance_matrix, dtype=torch.float32)

    def set_covariance_matrix(self, covariance_matrix):
        """Set, update, or clear (with ``None``) the fixed covariance matrix.

        Ignored (with a printed warning) when a learnable kernel is in use.
        Non-tensor input is converted to a float32 tensor and moved to
        ``self.device`` when one is set.
        """
        if self.learnable_matern_kernel or self.learnable_diffusion_kernel:
            print("Warning: Using learnable kernel, fixed covariance matrix will be ignored")
            return

        # Previously torch.tensor(None) raised here; clearing now mirrors
        # SGNN_ODE.set_covariance_matrix.
        if covariance_matrix is None:
            self.covariance_matrix = None
            return

        if not isinstance(covariance_matrix, torch.Tensor):
            covariance_matrix = torch.tensor(covariance_matrix, dtype=torch.float32)
        self.covariance_matrix = covariance_matrix

        if hasattr(self, 'device') and self.device is not None:
            self.covariance_matrix = self.covariance_matrix.to(self.device)
        print(f"covariance_matrix: {self.covariance_matrix.shape}")

    def reset_parameters(self):
        """Reset the drift/diffusion GNNs and, if it supports it, the kernel module."""
        self.f_encoder.reset_parameters()
        self.g_encoder.reset_parameters()
        if hasattr(self, 'kernel_module'):
            if hasattr(self.kernel_module, 'reset_parameters'):
                self.kernel_module.reset_parameters()

    def _active_edge_index(self):
        """Return the IND or OOD edge index (chosen by ``ind_flag``) on ``self.device``."""
        source = self.ind_edge_index if self.ind_flag else self.ood_edge_index
        return source.to(self.device)

    def f_net(self, t, y):
        """Drift term: ``Drift(y, edge_index) - y`` (``t`` is unused)."""
        return self.f_encoder(y, self._active_edge_index()) - y

    def g_net(self, t, y):
        """Diffusion term: ``y - K @ SFN(y, edge_index)`` (``t`` is unused).

        ``K`` is ``self.inverse_kernel`` when that attribute exists and is not
        None; otherwise the SFN output is used as-is.
        """
        g_output = self.g_encoder(y, self._active_edge_index())

        # inverse_kernel is never assigned inside this class; reading it with
        # getattr avoids the AttributeError the previous direct access raised
        # while still honouring a value set from outside.
        inverse_kernel = getattr(self, "inverse_kernel", None)
        if inverse_kernel is not None:
            g_output = torch.matmul(inverse_kernel, g_output)

        return y - g_output

    def _sample_logits(self, node_embeddings, ts):
        """Integrate one SDE sample path from ``node_embeddings`` and decode its final state."""
        z = self.sdeint_fn(
            sde=self,
            y0=node_embeddings,
            ts=ts,
            covariance_matrix=self.covariance_matrix,
            method=self.cfg["method"],
            adaptive=self.cfg["adaptive"],
            rtol=self.cfg["rtol"],
            atol=self.cfg["atol"],
            names={'drift': 'f_net', 'diffusion': 'g_net'},
        )
        # Final time point, reshaped to (num_nodes, hidden_channels) for the decoder.
        hidden_embedding = z[-1].view(-1, self.cfg["hidden_channels"])
        return self.output_decoder(hidden_embedding)

    def forward(self, x, flag, device, n_trajectories=1):
        """Encode ``x``, integrate the SDE over ``[0, time]`` and decode.

        Args:
            x: Node feature tensor passed to ``input_encoder``.
            flag: Truthy selects the IND graph, falsy the OOD graph.
            device: Device used for edge indices, covariance and time grid.
            n_trajectories: Number of SDE sample paths to integrate.
        Returns:
            Logits of shape (num_nodes, c) when ``n_trajectories == 1``,
            otherwise a stack of shape (n_trajectories, num_nodes, c).
        """
        self.ind_flag = flag
        self.device = device

        node_embeddings = self.input_encoder(x)
        if self.cfg["use_bn"]:
            node_embeddings = self.bnin(node_embeddings)

        if self.learnable_matern_kernel or self.learnable_diffusion_kernel:
            edge_index = self._active_edge_index()
            num_nodes = node_embeddings.shape[0]
            # Only the kernel itself is kept; it is handed to sdeint as
            # covariance_matrix. No inverse is computed here.
            # NOTE(review): older comments said the inverse was computed for
            # g_net; confirm whether inverse_kernel should be derived here.
            if self.learnable_matern_kernel:
                kernel, k, nu = self.kernel_module(node_embeddings, edge_index=edge_index,
                                                   num_nodes=num_nodes)
            else:
                kernel, k = self.kernel_module(node_embeddings, edge_index=edge_index,
                                               num_nodes=num_nodes)
            self.covariance_matrix = kernel
        elif self.covariance_matrix is not None:
            self.covariance_matrix = self.covariance_matrix.to(device)

        ts = torch.linspace(0, self.time, self.N).to(device)

        if n_trajectories == 1:
            return self._sample_logits(node_embeddings, ts)
        return torch.stack(
            [self._sample_logits(node_embeddings, ts) for _ in range(n_trajectories)],
            dim=0,
        )


class SGNN_ODE(nn.Module):
    def __init__(self, d, c, cfg, covariance_matrix=None, noise_scale=1.0):
        """
        Neural ODE version of SGNN with noise injection into the vector field.

        Args:
            d: Input feature dimension
            c: Number of classes
            cfg: Configuration dictionary. Keys read directly by this class
                include "hidden_channels", "use_bn", "dropout",
                "encoder_gnn_type", "time", "N", "device" and optional
                "use_cholesky" (default False); ``forward`` also reads optional
                "method", "rtol", "atol" and "print_profile" (default True).
            covariance_matrix: Covariance matrix (or vector) for noise transformation
            noise_scale: Scaling factor for noise. Only applied when no
                covariance transform is applied (see ``_correlate_noise``).
        """
        super(SGNN_ODE, self).__init__()

        # Feature encoder
        self.input_encoder = encoder_MLP(d, cfg)
        self.bnin = nn.BatchNorm1d(cfg["hidden_channels"]) if cfg["use_bn"] else nn.Identity()

        # Two separate encoders for deterministic and stochastic parts
        self.f_encoder = Drift(cfg["hidden_channels"], cfg["hidden_channels"], cfg["hidden_channels"],
                               num_layers=1, dropout=cfg["dropout"], use_bn=cfg["use_bn"],
                               type=cfg["encoder_gnn_type"])

        self.g_encoder = SFN(cfg["hidden_channels"], cfg["hidden_channels"], cfg["hidden_channels"],
                             num_layers=1, dropout=cfg["dropout"], use_bn=cfg["use_bn"],
                             type=cfg["encoder_gnn_type"])

        # Decoder
        self.output_decoder = decoder_MLP_simple(c, cfg)

        # Store configuration
        self.cfg = cfg
        self.time = self.cfg["time"]
        self.N = self.cfg["N"]
        self.device = self.cfg["device"]
        self.noise_scale = noise_scale
        self.hidden_channels = self.cfg["hidden_channels"]
        self.use_cholesky = cfg.get("use_cholesky", False)

        # Graph structure variables; edge indices are assigned by the caller.
        self.ind_flag = True
        self.ind_edge_index = None
        self.ood_edge_index = None
        self.ind_edge_weight = None
        self.ood_edge_weight = None
        self.c_size = c

        # Store the covariance matrix
        self.covariance_matrix = None
        self.cholesky_factor = None
        self.set_covariance_matrix(covariance_matrix)

        # Edge-index cache, keyed on (device, ind_flag).
        self._cached_edge_index = None
        self._cached_device = None
        self._cached_flag = None

        # Per-forward profiling counters (reset at the start of each forward).
        self.vector_field_call_count = 0
        self.vector_field_time = 0.0

    def reset_parameters(self):
        """Reset parameters of the drift and diffusion GNN encoders."""
        self.f_encoder.reset_parameters()
        self.g_encoder.reset_parameters()

    def set_cholesky(self, cholesky):
        """Toggle use of the Cholesky factor; computes it from the stored covariance when enabling."""
        self.use_cholesky = cholesky
        if cholesky and self.covariance_matrix is not None:
            self.cholesky_factor = torch.linalg.cholesky(self.covariance_matrix)
            if hasattr(self, 'device') and self.device is not None:
                self.cholesky_factor = self.cholesky_factor.to(self.device)

    def set_covariance_matrix(self, covariance_matrix):
        """Set, update, or clear (with ``None``) the noise covariance matrix.

        Non-tensor input is converted to float32. When ``use_cholesky`` is
        set, the Cholesky factor is recomputed. Both are moved to
        ``self.device`` when one is set.
        """
        if covariance_matrix is None:
            self.covariance_matrix = None
            self.cholesky_factor = None
            return

        if not isinstance(covariance_matrix, torch.Tensor):
            covariance_matrix = torch.tensor(covariance_matrix, dtype=torch.float32)

        self.covariance_matrix = covariance_matrix

        if self.use_cholesky:
            self.cholesky_factor = torch.linalg.cholesky(self.covariance_matrix)

        if hasattr(self, 'device') and self.device is not None:
            self.covariance_matrix = self.covariance_matrix.to(self.device)
            if self.cholesky_factor is not None:
                self.cholesky_factor = self.cholesky_factor.to(self.device)

    def _get_edge_index(self):
        """Return the IND or OOD edge index (per ``ind_flag``) on ``self.device``.

        The moved tensor is cached. The cache is reused only while both the
        device and ``ind_flag`` are unchanged; keying on the device alone let
        a flag change return the previously selected graph.
        """
        flag = bool(self.ind_flag)
        if (
            self._cached_edge_index is not None
            and self._cached_device == self.device
            and self._cached_flag == flag
        ):
            return self._cached_edge_index

        source = self.ind_edge_index if flag else self.ood_edge_index
        edge_index = source.to(self.device)

        self._cached_edge_index = edge_index
        self._cached_device = self.device
        self._cached_flag = flag
        return edge_index

    def _correlate_noise(self, Z):
        """Apply the configured covariance transform to white noise ``Z``.

        Args:
            Z: Noise tensor of shape (N, d).
        Returns:
            Tensor of shape (N, d). With a stored matrix (its Cholesky factor
            when ``use_cholesky`` is set and available) whose shape matches:
            an (N, N) matrix mixes nodes (``M @ Z``), a (d, d) matrix mixes
            features (``Z @ M``), and a length-N / length-d vector scales rows
            / columns. With no matrix, or a shape matching neither N nor d,
            returns ``Z * noise_scale``.
        """
        # NOTE(review): noise_scale is not applied when a covariance transform
        # is used; presumably the matrix carries its own scale — confirm.
        if self.covariance_matrix is None:
            return Z * self.noise_scale

        if self.use_cholesky and self.cholesky_factor is not None:
            matrix = self.cholesky_factor
        else:
            matrix = self.covariance_matrix

        N, d = Z.shape
        if matrix.dim() == 2:
            # Node-mixing is checked first, so it wins when N == d.
            if matrix.shape[0] == N and matrix.shape[1] == N:
                return torch.matmul(matrix, Z)
            if matrix.shape[0] == d and matrix.shape[1] == d:
                return torch.matmul(Z, matrix)
        elif matrix.dim() == 1:
            if matrix.shape[0] == N:
                return Z * matrix.unsqueeze(1)
            if matrix.shape[0] == d:
                return Z * matrix.unsqueeze(0)
        return Z * self.noise_scale

    def vector_field(self, t, y):
        """
        Vector field for the ODE solver: ``f(y) + g(y) * correlated_noise``.

        Args:
            t: Time (tensor; only its magnitude scales the noise)
            y: State tensor of shape (N, d)
        Returns:
            vector_field: Tensor of shape (N, d)

        NOTE(review): fresh noise is drawn on every call, so adaptive solvers
        see a different field at each evaluation.
        """
        import time
        start_time = time.time()
        self.vector_field_call_count += 1

        edge_index = self._get_edge_index()

        # Order (drift, noise draw, diffusion) is kept so the RNG stream is
        # consumed in the same sequence as before.
        f_x = self.f_encoder(y, edge_index)  # Shape: (N, d)

        N, d = y.shape
        # Noise magnitude grows like sqrt(|t|); 1e-8 is added inside the sqrt.
        Z = torch.randn(N, d, device=y.device) * torch.sqrt(t.abs() + 1e-8)
        L_Z = self._correlate_noise(Z)

        g_x = self.g_encoder(y, edge_index)  # Shape: (N, d)

        vector_field = f_x + g_x * L_Z

        self.vector_field_time += time.time() - start_time
        return vector_field

    def _integrate(self, y0, ts, method, rtol, atol):
        """Solve the ODE from ``y0`` over ``ts`` and return the state at ``ts[-1]``."""
        trajectory = torchdiffeq.odeint(
            self.vector_field,
            y0,
            ts,
            method=method,
            rtol=rtol,
            atol=atol,
        )
        return trajectory[-1]

    def forward(self, x, flag, device, n_trajectories=1):
        """
        Forward pass through the neural ODE with runtime measurement.

        Args:
            x: Input features tensor
            flag: Boolean indicating whether to use IND or OOD graph
            device: Device to run computation on
            n_trajectories: Number of trajectories to sample (default: 1)

        Returns:
            logits: Output logits of shape (N, c) for one trajectory, or
            (n_trajectories, N, c) otherwise.

        Profiling lines are printed unless ``cfg["print_profile"]`` is falsy.
        Timings use host wall-clock ``time.time()``; on CUDA, kernels may
        still be running when it is read, so GPU work can be under-counted.
        """
        import time

        forward_start_time = time.time()

        # Reset profiling counters for this forward pass
        self.vector_field_call_count = 0
        self.vector_field_time = 0.0

        self.ind_flag = flag
        self.device = device

        # Invalidate the edge-index cache: the graph itself may have been replaced.
        self._cached_edge_index = None
        self._cached_device = None
        self._cached_flag = None

        if self.covariance_matrix is not None:
            self.covariance_matrix = self.covariance_matrix.to(device)
            if self.cholesky_factor is not None:
                self.cholesky_factor = self.cholesky_factor.to(device)

        encode_start = time.time()
        node_embeddings = self.input_encoder(x)
        if self.cfg["use_bn"]:
            node_embeddings = self.bnin(node_embeddings)
        encode_time = time.time() - encode_start

        ts = torch.linspace(0, self.time, self.N).to(device)

        method = self.cfg.get("method", "euler")
        rtol = self.cfg.get("rtol", 1e-3)
        atol = self.cfg.get("atol", 1e-3)

        solve_start = time.time()
        if n_trajectories == 1:
            logits = self.output_decoder(self._integrate(node_embeddings, ts, method, rtol, atol))
        else:
            logits = torch.stack(
                [
                    self.output_decoder(self._integrate(node_embeddings, ts, method, rtol, atol))
                    for _ in range(n_trajectories)
                ],
                dim=0,
            )
        solve_time = time.time() - solve_start

        total_forward_time = time.time() - forward_start_time

        if self.cfg.get("print_profile", True):
            print(f"Forward Pass Time: {total_forward_time:.4f} seconds")
            print(f"  Encoding Time: {encode_time:.4f} seconds")
            print(f"  ODE Solving Time: {solve_time:.4f} seconds")
            print(f"  Vector Field Calls: {self.vector_field_call_count}")
            print(f"  Vector Field Total Time: {self.vector_field_time:.4f} seconds")
            print(f"  Vector Field Avg Time: {self.vector_field_time/max(1, self.vector_field_call_count):.4f} seconds/call")

        return logits