import networkx as nx
import numpy as np
import torch
from custom_models.cnf_src.causal_nf.preparators.scm.batch_generator import (
    BatchGenerator,
)
from custom_models.cnf_src.causal_nf.preparators.tabular_preparator import (
    TabularPreparator,
)
from custom_models.cnf_src.causal_nf.utils.io import dict_to_cn
from custom_models.cnf_src.causal_nf.utils.scalers import StandardTransform
from torch.distributions import Independent, Normal
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils.sparse import dense_to_sparse

class CustomDataset(Dataset):
    """In-memory tabular dataset served as plain tensors or as PyG graphs.

    With ``type="torch"`` every item is the pair ``(row, row)``; with
    ``type="pyg"`` every item is a ``torch_geometric.data.Data`` object whose
    node features are the row reshaped to one column, and whose connectivity
    comes from ``causal_graph``.

    Args:
        features: 2-D tensor indexed as ``[row, column]``.
        labels: Stored unchanged on ``self.labels``.
        type: ``"torch"`` or ``"pyg"``.
        causal_graph: Directed graph used to build the adjacency matrix;
            required when ``type="pyg"``.
        use_edge_attr: When true, ``edge_attr`` exposes an identity matrix
            with one row per edge; otherwise ``edge_attr`` is ``None``.

    Raises:
        ValueError: If ``type="pyg"`` is requested without a causal graph.
    """

    def __init__(
        self,
        features,
        labels,
        type: str = "torch",
        causal_graph: "nx.DiGraph | None" = None,
        use_edge_attr: bool = False,
    ):
        if type == "pyg" and causal_graph is None:
            raise ValueError("type='pyg' requires a causal_graph")

        # Add small Gaussian noise to constant columns. A column is constant
        # when every row equals the first row in that column.
        # NOTE(review): presumably this keeps the per-column standard
        # deviation non-zero for later standardisation -- confirm.
        if features.shape[0] > 0:
            constant_columns = (features == features[0]).all(dim=0)
            if constant_columns.any():
                # Work on a copy so the caller's tensor is left untouched.
                features = features.clone()
                noise = torch.randn(
                    features.shape[0],
                    int(constant_columns.sum()),
                    device=features.device,
                )
                features[:, constant_columns] += noise * 0.01

        self.x = features
        self.y = features
        self.labels = labels
        self.type = type
        self.use_edge_attr = use_edge_attr
        if causal_graph is not None:
            self.adjacency = torch.from_numpy(nx.to_numpy_array(causal_graph)).bool()
            self.num_nodes = causal_graph.number_of_nodes()

        # VACA has a PyTorch Geometric backend
        if self.type == "pyg":  # pytorch geometric
            self.edge_index = dense_to_sparse(self.adjacency)[0]
            self._edge_attr = torch.eye(self.edge_index.shape[-1])
            self.node_ids = torch.eye(self.num_nodes)

    def __len__(self):
        """Return the number of rows in the feature tensor."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the item at ``index`` in the representation chosen by ``type``.

        Raises:
            ValueError: If ``self.type`` is neither ``"torch"`` nor ``"pyg"``.
        """
        if self.type == "torch":
            return self.x[index], self.x[index]
        if self.type == "pyg":
            return Data(
                x=self.x[index].reshape(-1, 1),
                edge_index=self.edge_index,
                node_ids=self.node_ids,
                edge_attr=self.edge_attr,
            )
        raise ValueError(f"Unknown dataset type {self.type}")

    @property
    def edge_attr(self):
        """Identity edge attributes when ``use_edge_attr`` is set, else ``None``."""
        if self.use_edge_attr:
            return self._edge_attr
        return None


class CustomPreparator(TabularPreparator):
    def __init__(
        self, dataset_df, causal_graph, name, use_edge_attr: bool = False, **kwargs
    ):
        """Build the preparator from a DataFrame and its causal graph.

        Args:
            dataset_df: DataFrame whose values (via ``to_numpy()``) become the
                feature matrix and whose column names label the features.
                Assumes the column order matches the graph's node order --
                TODO confirm against callers.
            causal_graph: networkx directed graph; its adjacency matrix is
                stored as a boolean tensor on ``self.causal_graph``.
            name: Forwarded to the base preparator.
            use_edge_attr: Forwarded to ``CustomDataset`` for the ``"pyg"`` type.
            **kwargs: Forwarded to the base constructor; must contain
                ``"type"`` (``"torch"`` or ``"pyg"``).

        Raises:
            KeyError: If ``kwargs`` has no ``"type"`` entry.
            ValueError: If the type is neither ``"torch"`` nor ``"pyg"``.
            AssertionError: If the configured split is not ``[0.8, 0.2]``.
        """
        self.dataset = None
        self.nx_causal_graph = causal_graph
        self.causal_graph = torch.from_numpy(nx.to_numpy_array(causal_graph)).bool()
        self.num_nodes = causal_graph.number_of_nodes()
        self.edge_index = dense_to_sparse(self.causal_graph)[0]
        # Must be iterable: the intervention/ATE list builders loop over it.
        # FIXME: no intervention targets are configured yet.
        self.intervention_index_list = []
        super().__init__(name=name, task="modeling", **kwargs)
        self.type = kwargs["type"]

        # Step 1: Convert DataFrame to tensors
        features = torch.tensor(dataset_df.to_numpy()).double()
        column_names = list(dataset_df.columns)

        if self.type == "pyg":
            # One graph object per row, all sharing the same connectivity.
            self.data_list = [
                Data(x=row, edge_index=self.edge_index) for row in features
            ]
            dataset_dl = CustomDataset(
                features,
                column_names,
                type="pyg",
                causal_graph=causal_graph,
                use_edge_attr=use_edge_attr,
            )
        elif self.type == "torch":
            dataset_dl = CustomDataset(features, column_names)
        else:
            raise ValueError(f"Unknown dataset type {self.type}")

        self.datasets = [dataset_dl]

        assert self.split == [0.8, 0.2], f"Unsupported split {self.split}"

    def adjacency(self, add_diag=False):
        adj = self.causal_graph
        if add_diag:
            adj += torch.eye(self.num_nodes).bool()
        return adj

    @classmethod
    def params(cls, dataset):
        """Return the constructor parameters derived from a dataset config.

        A plain ``dict`` is first converted with ``dict_to_cn``; any other
        object is handed to ``TabularPreparator.params`` unchanged.
        """
        config = dict_to_cn(dataset) if isinstance(dataset, dict) else dataset
        return TabularPreparator.params(config)

    @classmethod
    def loader(cls, dataset_df, causal_graph, dataset, device=torch.device("cpu")):
        """Alternate constructor: build a preparator named ``"custom"``.

        The keyword arguments come from ``CustomPreparator.params(dataset)``
        and ``device`` is passed alongside them.
        """
        init_kwargs = CustomPreparator.params(dataset)
        return cls(dataset_df, causal_graph, "custom", device=device, **init_kwargs)

    def _x_dim(self):
        if self.type == "torch":
            return self.num_nodes
        elif self.type == "pyg":
            return 1

    def edge_attr_dim(self):
        if self.dataset.use_edge_attr:
            return self.datasets[0].edge_attr.shape[-1]
        else:
            return None

    def get_batch_generator(self):
        if self.type == "pyg":
            return BatchGenerator(
                node_dim=1,
                num_nodes=self.num_nodes,
                edge_index=self.edge_index,
                device="cpu",
                node_ids=self.datasets[0].node_ids,
            )

    def get_deg(self):
        # Compute the in-degree histogram tensor

        # Compute in-degrees of all nodes
        in_degrees = dict(self.nx_causal_graph.in_degree()).values()

        # Compute histogram of in-degrees
        max_degree = max(in_degrees)
        hist, _ = np.histogram(
            list(in_degrees), bins=range(max_degree + 2), density=False
        )

        # Convert histogram to PyTorch tensor
        deg_histogram = torch.tensor(hist, dtype=torch.float)
        return deg_histogram.float()

    def get_intervention_list(self):
        x = self.get_features_train().numpy()

        perc_idx = [25, 50, 75]

        percentiles = np.percentile(x, perc_idx, axis=0)
        int_list = []
        for i in self.intervention_index_list:
            percentiles_i = percentiles[:, i]
            values_i = []
            for perc_name, perc_value in zip(perc_idx, percentiles_i):
                values_i.append({"name": f"{perc_name}p", "value": perc_value})

            for value in values_i:
                value["value"] = round(value["value"], 2)
                value["index"] = i
                int_list.append(value)

        return int_list

    def diameter(self):
        """Return the largest diameter over the strongly connected components.

        The graph is rebuilt from the adjacency matrix with the diagonal
        added; the result is never below 0.
        """
        graph = nx.from_numpy_array(
            self.adjacency(True).numpy(), create_using=nx.DiGraph
        )
        component_diameters = [
            nx.diameter(graph.subgraph(nodes))
            for nodes in nx.strongly_connected_components(graph)
        ]
        return max([0, *component_diameters])

    def longest_path_length(self):
        """Return the length of the longest directed path in the causal graph.

        The graph is rebuilt from the adjacency matrix without self-loops.
        Uses ``nx.from_numpy_array`` (the same constructor ``diameter`` uses)
        because ``nx.from_numpy_matrix`` no longer exists in NetworkX 3.x.

        Returns:
            The path length as a Python ``int``.
        """
        adjacency = self.adjacency(False).numpy()
        graph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
        return int(nx.algorithms.dag.dag_longest_path_length(graph))

    def get_ate_list(self):
        x = self.get_features_train().numpy()

        perc_idx = [25, 50, 75]

        percentiles = np.percentile(x, perc_idx, axis=0)
        int_list = []
        for i in self.intervention_index_list:
            percentiles_i = percentiles[:, i]
            values_i = []
            values_i.append(
                {"name": "25_50", "a": percentiles_i[0], "b": percentiles_i[1]}
            )
            values_i.append(
                {"name": "25_75", "a": percentiles_i[0], "b": percentiles_i[2]}
            )
            values_i.append(
                {"name": "50_75", "a": percentiles_i[1], "b": percentiles_i[2]}
            )
            for value in values_i:
                value["a"] = round(value["a"], 2)
                value["b"] = round(value["b"], 2)
                value["index"] = i
                int_list.append(value)

        return int_list

    def get_ate_list_2(self):
        x = self.get_features_train()

        x_mean = x.mean(0)
        x_std = x.std(0)
        int_list = []
        for i in self.intervention_index_list:
            x_mean_i = x_mean[i].item()
            x_std_i = x_std[i].item()
            values_i = []
            values_i.append({"name": "mu_std", "a": x_mean_i, "b": x_mean_i + x_std_i})
            values_i.append({"name": "mu_-std", "a": x_mean_i, "b": x_mean_i - x_std_i})
            values_i.append(
                {"name": "-std_std", "a": x_mean_i - x_std_i, "b": x_mean_i + x_std_i}
            )
            for value in values_i:
                value["a"] = round(value["a"], 2)
                value["b"] = round(value["b"], 2)
                value["index"] = i
                int_list.append(value)

        return int_list

    def intervene(self, index, value, shape):
        if len(shape) == 1:
            shape = (shape[0], 7)

        x = self.get_features_train()
        cond = x[..., index].floor() == int(value)
        x = x[cond, :]

        return x[: shape[0]]

    def compute_ate(self, index, a, b, num_samples=10000):
        ate = torch.rand((6)) * 2 - 1.0
        return ate

    def compute_counterfactual(self, x_factual, index, value):
        x_cf = torch.randn_like(x_factual)
        x_cf[:, index] = value

        return x_cf

    def log_prob(self, x):
        px = Independent(
            Normal(
                torch.zeros(x.shape[1], device=x.device),
                torch.ones(x.shape[1], device=x.device),
            ),
            1,
        )
        return px.log_prob(x)

    def _loss(self, loss):
        if loss in ["default", "forward"]:
            return "forward"
        else:
            raise NotImplementedError(f"Wrong loss {loss}")

    def get_dataloader_train(self, batch_size, num_workers=0, shuffle=None):
        assert isinstance(self.datasets, list)

        dataset = self.datasets[0]
        shuffle = self.shuffle_train if shuffle is None else shuffle
        loader_train = self._data_loader(
            dataset, batch_size, shuffle=shuffle, num_workers=num_workers
        )

        return loader_train

    def get_dataloaders(self, batch_size, num_workers=0):
        assert isinstance(self.datasets, list)

        # Splitting the dataset
        dataset = self.datasets[0]

        train_size = int(len(dataset) * self.split[0])
        val_size = len(dataset) - train_size

        dataset_train, dataset_val = torch.utils.data.random_split(
            dataset, [train_size, val_size]
        )
        self.dataset = dataset_train.dataset

        # Creating the train DataLoader
        loader_train = self._data_loader(
            dataset_train.dataset, batch_size, shuffle=False, num_workers=num_workers
        )

        # Creating the validation DataLoader
        loader_val = self._data_loader(
            dataset_val.dataset, batch_size, shuffle=False, num_workers=num_workers
        )

        loaders = [loader_train, loader_val]

        return loaders

    def _split_dataset(self, dataset_raw):
        """Build one prepared ``GermanDataset`` per entry of ``self.split``.

        ``dataset_raw`` is not used. The dataset built for the first split is
        also stored on ``self.dataset``.

        NOTE(review): ``GermanDataset`` is not imported in the visible part
        of this file, so calling this method presumably raises ``NameError``
        -- confirm whether it is still needed.
        """
        datasets = []

        for idx, _ in enumerate(self.split):
            split_dataset = GermanDataset(
                root_dir=self.root, split=self.split_names[idx], seed=self.k_fold
            )
            split_dataset.prepare_data()
            split_dataset.set_add_noise(self.add_noise)

            if idx == 0:
                self.dataset = split_dataset
            datasets.append(split_dataset)

        return datasets

    def _get_dataset(self, num_samples, split_name):
        raise NotImplementedError

    def get_scaler(self, fit=True):
        scaler = self._get_scaler()
        self.scaler_transform = None
        if fit:
            x = self.get_features_train()
            scaler.fit(x, dims=self.dims_scaler)
            if self.scale in ["default", "std"]:
                self.scaler_transform = StandardTransform(
                    shift=x.mean(0).to(self.device), scale=x.std(0).to(self.device)
                )
                print("scaler_transform", self.scaler_transform)

        self.scaler = scaler

        return scaler

    def get_scaler_info(self):
        if self.scale in ["default", "std"]:
            return [("std", None)]
        else:
            raise NotImplementedError

    @property
    def dims_scaler(self):
        return (0,)

    def _get_dataset_raw(self):
        return None

    def _transform_dataset_pre_split(self, dataset_raw):
        return dataset_raw

    def post_process(self, x, inplace=False):
        if not inplace:
            x = x.clone()
        dims = self.dataset.binary_dims
        min_values = self.dataset.binary_min_values
        max_values = self.dataset.binary_max_values

        x[..., dims] = x[..., dims].floor().float()
        x[..., dims] = torch.clamp(x[..., dims], min=min_values, max=max_values)

        return x

    def feature_names(self, latex=False):
        return self.dataset.column_names

    def _plot_data(
        self,
        batch=None,
        title_elem_idx=None,
        batch_size=None,
        df=None,
        hue=None,
        **kwargs,
    ):
        """Delegate to the base plot with an empty title and histogram diagonals.

        Extra keyword arguments are accepted but not forwarded to the base
        implementation.
        """
        return super()._plot_data(
            batch=batch,
            title_elem_idx=title_elem_idx,
            batch_size=batch_size,
            df=df,
            title="",
            hue=hue,
            diag_plot="hist",
        )
