import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from causal_nf.preparators.tabular_preparator import TabularPreparator
from causal_nf.utils.scalers import StandardTransform, IdentityTransform

class InMemoryDataset(Dataset):
    """Map-style dataset backed by a single float32 tensor held in memory.

    Indexing yields a ``(sample, 0)`` pair; the ``0`` is a placeholder so
    that consumers expecting ``(features, target)`` tuples keep working.
    """

    def __init__(self, data_matrix):
        """Copy *data_matrix* into a ``torch.float32`` tensor stored as ``self.data``."""
        as_float = torch.tensor(data_matrix, dtype=torch.float32)
        self.data = as_float

    def __len__(self):
        """Return the size of the first dimension of the stored tensor."""
        return self.data.size(0)

    def __getitem__(self, idx):
        """Return the entry at *idx* together with the dummy second value ``0``."""
        sample = self.data[idx]
        placeholder = 0  # dummy value, kept only for compatibility
        return sample, placeholder

class CustomPreparator(TabularPreparator):
    """Preparator for a tabular dataset and causal graph supplied in memory.

    Every sample is placed in one training split (``splits=[1.0, 0.0, 0.0]``),
    wrapped in an ``InMemoryDataset``. The graph is given as an adjacency
    list and stored as a dense ``(num_nodes, num_nodes)`` float32 tensor.
    """

    def __init__(
        self,
        data,
        adjacency,
        index_to_variable,
        discrete=False,
        batch_size=128,
        device="cpu",
        scale="default",
    ):
        """Build the data matrix, the dataset and the adjacency matrix.

        Args:
            data: Mapping from variable name to its samples. Each variable is
                assumed to be one-dimensional; the columns are stacked in the
                order given by ``index_to_variable``.
            adjacency: Mapping from a source node index to a list of
                destination node indices; each pair sets ``adj[src, dst] = 1``.
            index_to_variable: List of variable names; position is node index.
            discrete: Stored as ``self.discrete``; not otherwise used here.
            batch_size: Default batch size for ``get_dataloader_train``.
            device: Stored as ``self.device``; not otherwise used here.
            scale: Scaling mode forwarded to the parent constructor and read
                by ``get_scaler`` (``"default"`` or ``"min0_max1"``).
        """
        # These attributes are assigned before calling the parent constructor,
        # exactly as the parent receives them as keyword arguments below.
        self.name = "custom"
        self.splits = [1.0, 0.0, 0.0]
        self.shuffle_train = False
        self.single_split = "train"
        self.task = "modeling"
        self.k_fold = 1
        self.root = "./"
        self.loss = "default"
        self.scale = scale
        super().__init__(
            name=self.name,
            splits=self.splits,
            shuffle_train=self.shuffle_train,
            single_split=self.single_split,
            task=self.task,
            k_fold=self.k_fold,
            root=self.root,
            loss=self.loss,
            scale=self.scale,
        )

        self.index_to_variable = index_to_variable
        self.variable_to_index = {v: i for i, v in enumerate(index_to_variable)}
        self.num_nodes = len(index_to_variable)
        self.discrete = discrete
        self.device = device
        self.batch_size = batch_size

        # Data matrix of shape (num_samples, num_nodes); assumes every
        # variable is 1D for now.
        data_matrix = np.column_stack([data[v] for v in index_to_variable])
        self.data_matrix = data_matrix
        self.datasets = [InMemoryDataset(data_matrix)]

        # Convert the adjacency list into a dense adjacency matrix.
        adj = np.zeros((self.num_nodes, self.num_nodes), dtype=np.float32)
        for src, dsts in adjacency.items():
            for dst in dsts:
                adj[src, dst] = 1.0
        self._adjacency = torch.tensor(adj, dtype=torch.float32)

    def get_dataloader_train(self, batch_size=None, num_workers=0, shuffle=None):
        """Return a ``DataLoader`` over the single in-memory training dataset.

        Args:
            batch_size: Batch size; ``None`` falls back to ``self.batch_size``.
            num_workers: Forwarded to ``DataLoader``.
            shuffle: Whether to shuffle; ``None`` falls back to
                ``self.shuffle_train`` (set to ``False`` in ``__init__``).
        """
        if batch_size is None:
            batch_size = self.batch_size
        # Honor an explicit request from the caller instead of always
        # disabling shuffling; the default path still does not shuffle.
        if shuffle is None:
            shuffle = self.shuffle_train
        return DataLoader(
            self.datasets[0],
            batch_size=batch_size,
            shuffle=bool(shuffle),
            num_workers=num_workers,
        )

    def adjacency(self, as_numpy=False):
        """Return the adjacency matrix as a tensor, or as a NumPy array if *as_numpy*."""
        if as_numpy:
            return self._adjacency.numpy()
        return self._adjacency

    def feature_names(self, latex=False):
        """Return feature names.

        With ``latex=True`` the names are ``$x_{1}$ ... $x_{num_nodes}$``;
        otherwise a copy of ``index_to_variable`` is returned.
        """
        if latex:
            return [f"$x_{{{i+1}}}$" for i in range(self.num_nodes)]
        return list(self.index_to_variable)

    def get_features_train(self):
        """Return the full data matrix as a new float32 tensor."""
        return torch.tensor(self.data_matrix, dtype=torch.float32)

    def x_dim(self):
        """Return the number of features (one per node)."""
        return self.num_nodes

    def _x_dim(self):
        """Return the number of features (one per node)."""
        return self.num_nodes

    def num_samples(self):
        """Return the number of rows in the data matrix."""
        return self.data_matrix.shape[0]

    def _get_dataset_raw(self):
        """Return ``None``; not used for in-memory data."""
        return None

    def _loss(self, loss):
        """Return ``None``; not used for in-memory data."""
        return None

    def get_scaler(self, fit=True):
        """Create the scaler, optionally fit it on the training features, and store it.

        ``self._get_scaler`` is defined outside this class (presumably by the
        parent preparator). When *fit* is true the scaler is fitted over
        ``self.dims_scaler`` and ``self.scaler_transform`` is set according to
        ``self.scale``; for any other case it stays ``None``.

        Returns:
            The scaler object, also stored as ``self.scaler``.
        """
        scaler = self._get_scaler()
        self.scaler_transform = None
        if fit:
            x = self.get_features_train()
            scaler.fit(x, dims=self.dims_scaler)
            if self.scale in ["default"]:
                self.scaler_transform = IdentityTransform()
            elif self.scale in ["min0_max1"]:
                # NOTE(review): despite the mode name, this branch shifts by
                # the per-feature mean and scales by the per-feature std;
                # confirm that this is the intended transform.
                self.scaler_transform = StandardTransform(
                    shift=x.mean(0), scale=x.std(0)
                )

        self.scaler = scaler
        return scaler

    @property
    def dims_scaler(self):
        """Dimensions passed to ``scaler.fit``: the sample dimension only."""
        return (0,)