# %%
import matplotlib.pyplot as plt
import numpy as np
import math
import torch
from torchvision import datasets, transforms
from torch import nn
# %%

class ToyChainLoader:
    """Sample Gaussian data from a fixed precision matrix and build random masks.

    The precision matrix ``Theta`` is loaded from a ``.npy`` file and the data
    are drawn from a zero-mean multivariate normal with covariance
    ``Sigma = Theta^{-1}``. A training pool and a test pool are sampled once at
    construction time; batches are then drawn from those pools uniformly with
    replacement.

    Args:
        batch_size: Number of rows returned per batch.
        device: Torch device on which all tensors are created.
        path: Location of the ``.npy`` file holding the precision matrix.
        n_train: Number of samples in the training pool.
        n_test: Number of samples in the test pool.
        show_plot: If True, display a heatmap of the loaded precision matrix
            (``plt.show()`` may block, depending on the matplotlib backend).

    Raises:
        ValueError: If the loaded array is not a square 2-D matrix.
    """

    def __init__(self, batch_size=32, device='cuda', *,
                 path='./glasso_loader_Theta_true.npy',
                 n_train=2048, n_test=1000, show_plot=True):
        self.batch_size = batch_size
        self.device = device

        prec = torch.tensor(np.load(path), device=self.device, dtype=torch.float32)
        if prec.dim() != 2 or prec.size(0) != prec.size(1):
            raise ValueError(
                f"precision matrix must be square 2-D, got shape {tuple(prec.shape)}")
        if show_plot:
            self._plot_precision(prec)
        self.dim = prec.size(0)

        self.Theta = prec.clone()
        self.Sigma = torch.inverse(prec)

        # One distribution object, sampled twice: train pool first, then test pool.
        dist = torch.distributions.MultivariateNormal(
            torch.zeros(self.dim, device=self.device), self.Sigma)
        self.xdata = dist.sample((n_train,))
        self.xtestdata = dist.sample((n_test,))

        print("Toy dataset with dim =", self.dim, "and batch size =", batch_size, "on", device)

    @staticmethod
    def _plot_precision(prec):
        """Show a heatmap of the precision matrix ``prec`` with matplotlib."""
        plt.imshow(prec.cpu().numpy(), cmap='viridis')
        plt.title("Precision Matrix Visualization")
        plt.colorbar(label='Precision Value')
        plt.xlabel("Dimension Index")
        plt.ylabel("Dimension Index")
        plt.show()

    def _sample_rows(self, data):
        """Draw ``self.batch_size`` rows of ``data`` uniformly with replacement.

        Every row of ``data`` is reachable. The upper bound is ``data.size(0)``
        (exclusive), not ``data.size(0) - batch_size``.
        """
        idx = torch.randint(0, data.size(0), (self.batch_size,), device=self.device)
        return data[idx]

    def _process_batch(self, batch1, batch2):
        """Build a random mask shared by the train and test batch paths.

        Both public batch methods route through here so that train and test
        batches receive exactly the same masking logic. For each row, one
        anchor dimension is chosen uniformly and set to 1.0 in the mask. Each
        neighbour of the anchor (anchor+1 and anchor-1, when inside
        ``[0, dim)``) is then also set with independent probability 1/2. The
        masked dimensions therefore form one contiguous run of length 1 to 3.

        Args:
            batch1: First batch; its row count sets the mask's row count.
            batch2: Second batch, returned unchanged.

        Returns:
            ``(batch1, batch2, mask)``. Both batches are returned unmodified.
            ``mask`` is a float tensor of shape ``(batch1.size(0), self.dim)``
            holding 1.0 for masked dimensions and 0.0 elsewhere.
        """
        n = batch1.size(0)
        rows = torch.arange(n, device=self.device)

        # One anchor dimension per row.
        anchor = torch.randint(0, self.dim, (n,), device=self.device)
        mask = torch.zeros((n, self.dim), device=self.device)
        mask[rows, anchor] = 1.0

        # Coin flip for the right neighbour; not possible when anchor is the last dim.
        take_next = (torch.rand(n, device=self.device) < 0.5) & (anchor < self.dim - 1)
        mask[rows[take_next], anchor[take_next] + 1] = 1.0

        # Coin flip for the left neighbour; not possible when anchor is dim 0.
        take_prev = (torch.rand(n, device=self.device) < 0.5) & (anchor > 0)
        mask[rows[take_prev], anchor[take_prev] - 1] = 1.0

        return batch1, batch2, mask

    def get_batch(self):
        """Return ``(batch1, batch2, mask)`` with both batches drawn from the training pool."""
        batch1 = self._sample_rows(self.xdata)
        batch2 = self._sample_rows(self.xdata)
        return self._process_batch(batch1, batch2)

    def get_test_batch(self):
        """Return ``(batch1, batch2, mask)`` with both batches drawn from the test pool."""
        testbatch1 = self._sample_rows(self.xtestdata)
        testbatch2 = self._sample_rows(self.xtestdata)
        return self._process_batch(testbatch1, testbatch2)

# --- Usage Example ---
if __name__ == "__main__":
    import networkx as nx

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loader = ToyChainLoader(batch_size=99, device=device)

    # Get Train Batch
    v1, v2, mask = loader.get_batch()
    print(f"Train Batch: {v1.shape, v2.shape}")
    print(v1*mask)
    print(v2*(1-mask))
    print(f"Index: {mask[:5, :].cpu().numpy()}")

    # Get Test Batch (the third return value is a mask, not an index)
    test_v1, test_v2, test_mask = loader.get_test_batch()
    print(f"Test Batch:  {test_v1.shape, test_v2.shape}")
    print(f"Index: {test_mask[:5, :].cpu().numpy()}")

    # Visualize the adjacency structure in the precision matrix as a graph.
    # Use the stored precision matrix rather than re-inverting Sigma, which
    # only adds round-off noise. The .copy() is required: on CPU, .numpy()
    # shares memory with loader.Theta, and fill_diagonal would otherwise
    # overwrite the loader's matrix.
    adjacency = loader.Theta.cpu().numpy().copy()
    np.fill_diagonal(adjacency, 0)
    # A non-zero Theta_ij marks an edge whatever its sign; thresholding the raw
    # value (Theta > 1e-5) would silently drop negative off-diagonal entries.
    adjacency = np.abs(adjacency) > 1e-5
    G = nx.from_numpy_array(adjacency)
    plt.figure(figsize=(6, 6))
    pos = nx.spring_layout(G, seed=0)  # fixed seed -> reproducible layout
    nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray')
    plt.title("Graphical Model Structure from Precision Matrix")
    plt.show()


# %%
