import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch_geometric.utils import get_laplacian, to_dense_adj
from torch_geometric.data import Data
import numpy as np
import matplotlib.pyplot as plt
import torchsde
import torch.nn.functional as F
import networkx as nx
import torch.distributions as D
import ot
from sklearn.model_selection import train_test_split
import matplotlib.animation as animation

# Workflow:
# 1. Generate node features by sampling from a mixture of Gaussians.
# 2. Connect two nodes drawn from the same Gaussian with probability p1, and two nodes
#    drawn from different Gaussians with probability p2.
# The mixture has three components: two in-distribution (IND) Gaussians and one
# out-of-distribution (OOD) Gaussian whose nodes are attached to IND nodes separately.
class MixtureOfGaussians:
    """A mixture of Gaussians, each with a diagonal covariance matrix."""

    def __init__(self, means, stds, weights=None):
        """
        Initialize a mixture of Gaussian distributions.

        Args:
            means (Tensor): Shape (num_components, dim) representing the mean of each Gaussian.
            stds (Tensor): Shape (num_components, dim) representing the standard deviation.
            weights (Tensor, optional): Shape (num_components,) representing mixture weights.
                They are normalized to sum to 1; uniform weights are used if omitted.
        """
        self.means = means
        self.stds = stds
        self.num_components = means.shape[0]
        self.dim = means.shape[1]

        if weights is None:
            self.weights = torch.ones(self.num_components) / self.num_components
        else:
            self.weights = weights / weights.sum()  # Normalize weights

        # Diagonal covariance: per-axis variance is std**2.
        self.components = [D.MultivariateNormal(mean, torch.diag(std**2)) for mean, std in zip(means, stds)]
        self.mixture = D.Categorical(self.weights)

    def sample(self, num_samples=1, component=None):
        """
        Sample from the mixture of Gaussians.

        Args:
            num_samples (int): Number of samples to draw (0 returns empty tensors).
            component (int, optional): If provided, samples from the specified Gaussian component.

        Returns:
            tuple: ``(samples, component_indices)`` where ``samples`` has shape
            (num_samples, dim) and ``component_indices`` has shape (num_samples,)
            and dtype ``torch.long``.
        """
        if num_samples == 0:
            # torch.multinomial (used by Categorical.sample) and torch.stack both
            # reject an empty request, so short-circuit with correctly shaped empties.
            return (torch.empty((0, self.dim), dtype=self.means.dtype),
                    torch.empty((0,), dtype=torch.long))

        if component is None:
            component_indices = self.mixture.sample((num_samples,))  # Sample mixture indices
        else:
            component_indices = torch.full((num_samples,), component, dtype=torch.long)

        # One draw per point from its own component; the order of RNG consumption
        # is kept point-by-point so seeded runs stay reproducible.
        samples = torch.stack([self.components[i].sample() for i in component_indices])
        return samples, component_indices

    def density(self, X, Y):
        """
        Compute the probability density function of the mixture at given points.

        Only meaningful for a 2-D mixture: X and Y are stacked into (x, y) points.

        Args:
            X, Y: Meshgrid arrays/tensors of identical shape (e.g. from ``np.meshgrid``).

        Returns:
            numpy.ndarray: Density values with the same shape as ``X``.
        """
        # Cast the grid to the parameters' dtype. np.meshgrid yields float64 while
        # the component parameters are typically float32, and log_prob cannot mix them.
        dtype = self.means.dtype
        X = torch.as_tensor(X, dtype=dtype)
        Y = torch.as_tensor(Y, dtype=dtype)
        pos = torch.stack([X, Y], dim=-1)  # Shape (*grid_shape, 2)

        Z = torch.zeros(X.shape, dtype=dtype)
        for weight, comp in zip(self.weights, self.components):
            Z = Z + weight * torch.exp(comp.log_prob(pos))

        return Z.numpy()


class RandomGraphFromGMM:
    """
    Random graph built over GMM samples.

    Nodes with component index 0 or 1 are in-distribution (IND). They are wired
    with probability p1 within the same component and p2 across components.
    Nodes with component index 2 are out-of-distribution (OOD). They are attached
    to IND nodes either at random (probability p_ood) or to their k_ood nearest
    IND neighbors.
    """

    def __init__(self, samples, component_indices, p1=0.7, p2=0.3, p_ood=0.3,
                 ood_connection='random', k_ood=3, plot=True):
        """
        Initializes the random graph.

        Args:
            samples (numpy.ndarray or torch.Tensor): Sampled points of shape (num_samples, dim).
            component_indices (numpy.ndarray or torch.Tensor): Corresponding component indices.
            p1 (float): Probability of connecting nodes from the same IND component.
            p2 (float): Probability of connecting nodes from different IND components.
            p_ood (float): Probability of connecting OOD nodes to IND nodes (for random mode).
            ood_connection (str): Either 'random' or 'knn' for OOD connection strategy.
            k_ood (int): Number of nearest IND neighbors (only used if ood_connection='knn').
            plot (bool): If True (default), plot the graph right after generating it.
        """
        self.samples = samples
        self.component_indices = component_indices
        self.num_samples = len(samples)
        self.p1 = p1
        self.p2 = p2
        self.p_ood = p_ood
        self.k_ood = k_ood
        self.ood_connection = ood_connection
        self.graph = nx.Graph()
        self.generate_graph()
        if plot:
            self.plot_graph()

    def generate_graph(self):
        """
        Generates a random graph following the given connection probabilities.

        The graph is cleared first, so calling this method again draws a fresh
        graph instead of adding more random edges to the existing one.

        Raises:
            ValueError: If ``ood_connection`` is neither 'random' nor 'knn'.
        """
        if self.ood_connection not in ('random', 'knn'):
            raise ValueError(
                f"ood_connection must be 'random' or 'knn', got {self.ood_connection!r}")

        self.graph.clear()

        # Add nodes
        for i, sample in enumerate(self.samples):
            self.graph.add_node(i, pos=sample, component=self.component_indices[i])

        # IND = components 0 and 1, OOD = component 2
        labels = np.asarray(self.component_indices)
        ind_indices = np.where(labels < 2)[0]
        ood_indices = np.where(labels == 2)[0]

        # IND-IND edges: one uniform draw per unordered pair (i < j). triu_indices
        # enumerates pairs in nested-loop order (i ascending, then j ascending), and
        # np.random.rand(n) consumes the legacy RNG stream exactly like n scalar
        # calls, so each pair is tested against the same draw as a per-pair loop.
        rows, cols = np.triu_indices(len(ind_indices), k=1)
        src = ind_indices[rows]
        dst = ind_indices[cols]
        connection_prob = np.where(labels[src] == labels[dst], self.p1, self.p2)
        connect = np.random.rand(len(src)) < connection_prob
        self.graph.add_edges_from(zip(src[connect].tolist(), dst[connect].tolist()))

        # Connect OOD samples based on chosen strategy
        if self.ood_connection == 'knn':
            # K-nearest neighbors approach
            from scipy.spatial.distance import cdist
            points = np.asarray(self.samples)
            num_ind = len(ind_indices)
            # Cannot pick more neighbors than there are IND nodes; argpartition
            # additionally requires kth < len(row).
            k = min(self.k_ood, num_ind)
            if k > 0 and len(ood_indices) > 0:
                distances = cdist(points[ood_indices], points[ind_indices])
                for ood_idx, row in zip(ood_indices.tolist(), distances):
                    if k < num_ind:
                        nearest = np.argpartition(row, k)[:k]
                    else:
                        nearest = np.arange(num_ind)
                    self.graph.add_edges_from(
                        (ood_idx, int(ind_indices[n])) for n in nearest)
        else:
            # Random connection approach: one draw per (OOD, IND) pair in row-major order.
            attach = np.random.rand(len(ood_indices), len(ind_indices)) < self.p_ood
            ood_rows, ind_cols = np.nonzero(attach)
            self.graph.add_edges_from(
                zip(ood_indices[ood_rows].tolist(), ind_indices[ind_cols].tolist()))

        print(f"number of edges: {len(self.graph.edges())}")
        print(f"number of nodes: {len(self.graph.nodes())}")

    def plot_graph(self):
        """
        Plots the generated graph with nodes colored by component.
        """
        pos = {i: self.samples[i] for i in range(self.num_samples)}
        node_colors = [(self.component_indices[i]+1) for i in range(self.num_samples)]

        plt.figure(figsize=(8, 6))
        nx.draw(self.graph, pos, node_color=node_colors, cmap=plt.cm.Set1,
                node_size=50, edge_color="gray", alpha=0.6)
        plt.title(f"Random Graph Generated from GMM Samples ({self.ood_connection} OOD connections)")
        plt.show()




def generate_toy_data(means=torch.tensor([[1.5, 1.5], [.0,.0], [8.0, 8.0]]),
                      stds=torch.tensor([[1.0, 1.0], [0.6, 0.6], [1.0, 1.0]]),
                      weights=torch.tensor([0.4, 0.4, 0.2]),
                      num_samples=1200, num_components=3, dim=2, plot=True):
    """
    Sample a toy dataset from a Gaussian mixture and split it into IND and OOD parts.

    Components 0 and 1 are treated as in-distribution (IND) and component 2 as
    out-of-distribution (OOD).

    Args:
        means (Tensor): Shape (num_components, dim) component means.
        stds (Tensor): Shape (num_components, dim) per-axis standard deviations.
        weights (Tensor): Shape (num_components,) mixture weights (normalized by
            MixtureOfGaussians).
        num_samples (int): Total number of points to draw.
        num_components (int): Number of components drawn in the scatter plot.
        dim (int): Unused; the dimension is taken from ``means``. Kept for
            backward compatibility.
        plot (bool): If True (default), show a scatter plot of the first two
            coordinates of the samples.

    Returns:
        tuple: (mog, all_samples, all_component_indices, ind_samples,
        ind_component_indices, ood_samples, ood_component_indices).
    """
    # Create mixture model
    mog = MixtureOfGaussians(means, stds, weights)

    # Generate samples
    all_samples, all_component_indices = mog.sample(num_samples)

    # Separate IND (components 0 and 1) from OOD (component 2)
    ind_mask = all_component_indices < 2
    ood_mask = all_component_indices == 2

    ind_samples = all_samples[ind_mask]
    ind_component_indices = all_component_indices[ind_mask]
    ood_samples = all_samples[ood_mask]
    ood_component_indices = all_component_indices[ood_mask]

    if plot:
        # Visualize the data
        plt.figure(figsize=(8, 6))
        colors = ["orange", "blue", "red"]
        for i in range(num_components):
            mask = all_component_indices == i
            label = "IND" if i < 2 else "OOD"
            # Cycle the palette so more than three components do not raise IndexError.
            plt.scatter(all_samples[mask, 0], all_samples[mask, 1],
                        color=colors[i % len(colors)], label=f"{label} (Component {i})", alpha=0.6)
        plt.xlabel("X-axis")
        plt.ylabel("Y-axis")
        plt.title("Mixture of Gaussians - IND and OOD Samples")
        plt.grid(True)
        plt.legend()
        plt.show()

    return mog, all_samples, all_component_indices, ind_samples, ind_component_indices, ood_samples, ood_component_indices





def create_pyg_data_with_split(graph_model, test_ratio=0.2, val_ratio=0.1,
                               ood_in_train=False, ood_has_label=False, seed=42):
    """
    Creates a PyTorch Geometric Data object from the graph model with train/val/test splits.

    Only IND nodes (label 0 or 1) are split into train/val/test. OOD nodes
    (label 2) are optionally added to the training set.

    Args:
        graph_model (RandomGraphFromGMM): The graph model containing the graph.
        test_ratio (float): Ratio of IND nodes to use for testing.
        val_ratio (float): Ratio of IND nodes to use for validation.
        ood_in_train (bool): Whether to include OOD nodes in training set.
        ood_has_label (bool): If OOD nodes are in training, whether they have labels.
        seed (int): Random seed for reproducibility (reseeds the global NumPy and
            torch RNGs).

    Returns:
        data (torch_geometric.data.Data): PyG Data object with masks and features.
        ``train_labeled_mask`` is attached only when ``ood_in_train`` is True and
        ``ood_has_label`` is False.

    Raises:
        ValueError: If a ratio is negative or ``test_ratio + val_ratio > 1``.
    """
    import torch
    from torch_geometric.data import Data
    import numpy as np

    if test_ratio < 0 or val_ratio < 0 or test_ratio + val_ratio > 1:
        raise ValueError(
            f"invalid split ratios: test_ratio={test_ratio}, val_ratio={val_ratio}")

    np.random.seed(seed)
    torch.manual_seed(seed)

    G = graph_model.graph
    num_nodes = len(G.nodes())

    # Extract node features and labels. as_tensor avoids the copy-construct
    # warning when the inputs are already tensors; clone() keeps Data from
    # aliasing the graph model's storage.
    x = torch.as_tensor(graph_model.samples, dtype=torch.float).clone()
    y = torch.as_tensor(graph_model.component_indices, dtype=torch.long).clone()

    # Create edge index with shape (2, num_edges). reshape(-1, 2) keeps the
    # shape and dtype correct even when the graph has no edges.
    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).reshape(-1, 2).t().contiguous()
    # Add reverse edges to make it undirected
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    # Create masks
    ind_mask = y < 2  # IND samples (components 0 and 1)
    ood_mask = y == 2  # OOD samples (component 2)
    # flatten() (not squeeze()) keeps a 1-D index tensor even with a single match.
    ind_indices = torch.nonzero(ind_mask).flatten()
    num_ind = ind_indices.size(0)

    # Shuffle IND indices
    perm = torch.randperm(num_ind)
    ind_indices = ind_indices[perm]

    # Calculate split sizes
    test_size = int(num_ind * test_ratio)
    val_size = int(num_ind * val_ratio)
    train_size = num_ind - test_size - val_size

    # Create train/val/test masks for IND nodes
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[ind_indices[:train_size]] = True
    val_mask[ind_indices[train_size:train_size + val_size]] = True
    test_mask[ind_indices[train_size + val_size:]] = True

    # Handle OOD nodes
    ood_indices = torch.nonzero(ood_mask).flatten()

    # Snapshot before OOD nodes are (optionally) added: these are the labeled training nodes.
    train_labeled_mask = train_mask.clone()
    if ood_in_train:
        # Include OOD nodes in training set
        train_mask[ood_indices] = True

    # OOD detection evaluation: test IND nodes plus all OOD nodes
    ood_eval_mask = test_mask | ood_mask

    # Create PyG Data object
    data = Data(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask,
        ood_mask=ood_mask,
        ind_mask=ind_mask,
        ood_eval_mask=ood_eval_mask
    )

    # Add train_labeled_mask if OOD nodes are in training but unlabeled
    if ood_in_train and not ood_has_label:
        data.train_labeled_mask = train_labeled_mask

    return data


if __name__ == "__main__":
    mog, all_samples, all_component_indices, ind_samples, ind_component_indices, ood_samples, ood_component_indices = generate_toy_data()
    # The constructor already generates and plots the graph, so no further
    # generate_graph()/plot_graph() calls are needed here.
    graph_model = RandomGraphFromGMM(all_samples, all_component_indices, ood_connection='random')
    data = create_pyg_data_with_split(graph_model, ood_in_train=False)
    print(data)