import csv
import json
import os

import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm
import torch
from torch.utils.data import Dataset

from .bipartite_graphs import DagModuleGenerator
from .enhanced_graphs import EnhancedDagModuleGenerator
from .sems_vectorized import (
    init_params,
    simulate_data_linear,
    simulate_data_nn,
)


class DatasetLowRankGenerator:
    """Generate low-rank (feature/module) causal datasets by simulation.

    A DAG whose nodes are features and modules is sampled with the module
    generators from ``bipartite_graphs`` / ``enhanced_graphs``, SEM parameters
    are drawn with ``init_params`` and data are simulated with the functions
    from ``sems_vectorized`` according to the chosen parameters (e.g.
    mechanisms). Can generate datasets with 'hard stochastic' interventions.
    """

    # Defaults for ``enhanced_kwargs``. They are copied per instance so no
    # mutable default argument is shared between instances.
    DEFAULT_ENHANCED_KWARGS = {
        'mean_par': 5,
        'mean_chr': 5,
        'min_unique_par': 1,
        'min_unique_chr': 1,
        'min_upstream': -1,
        'min_downstream': -1,
    }

    def __init__(
        self,
        n_features,
        n_modules,
        p_vertex,
        p_module,
        n_samples,
        hard_intv=True,
        n_hidden=0,
        enhanced=False,
        rescale=True,
        nb_interventions=10,
        nb_test_interventions=1,
        min_nb_target=1,
        max_nb_target=3,
        noise_level=0.4,
        max_corr=1.0,
        min_corr=0.25,
        graph_dropout=0.0,
        dependent_dropout=True,
        conservative=False,
        uniform=False,
        cover=False,
        enhanced_kwargs=None,
        verbose=True,
    ):
        """
        Generate a dataset containing interventions. The setting is similar to the
        one in the DCDI paper, but with a low-rank structure. Save the lists of targets
        in a separate file.

        Args:
            n_features (int): Number of features
            n_modules (int): Number of modules
            p_vertex (float): Probability of adding a vertex connecting to a (parent) module
            p_module (float): Probability of adding a module connecting to a (parent) vertex
            n_samples (int): Number of samples
            hard_intv (bool): hard-intervention flag, forwarded to the simulation functions
            n_hidden (int): Number of hidden nodes. If set to 0, a linear model will be used.
            enhanced (bool): if True, use the enhanced generator
            rescale (bool): if True, standardize each feature column of the generated data
            nb_interventions (int): number of (training) interventional settings
            nb_test_interventions (int): number of held-out test settings; their targets
                                 are unions of two training targets
            min_nb_target (int): minimal number of targets per setting
            max_nb_target (int): maximal number of targets per setting. For a fixed
                                 number of target, one can make min_nb_target==max_nb_target
            noise_level (float): noise level in the data
            max_corr (float): maximal correlation between variables
            min_corr (float): minimal correlation between variables
            graph_dropout (float): forwarded unchanged to the simulation functions
            dependent_dropout (bool): forwarded unchanged to the simulation functions
            conservative (bool): if True, the intervention family should be
                                 conservative: i.e. every node is left un-intervened
                                 in at least one setting. This is not checked
                                 explicitly: every generated family includes an
                                 observational (empty-target) setting, which already
                                 guarantees it.
            uniform (bool): if True, set noise distribution to uniform
            cover (bool): if True, make sure that all nodes have been
                                 intervened on in at least one setting (rejection
                                 sampling when max_nb_target > 1; with
                                 max_nb_target == 1 nodes are targeted in turn, up to
                                 nb_interventions settings).
            enhanced_kwargs (dict, optional): arguments for the enhanced generator.
                                 Missing keys fall back to DEFAULT_ENHANCED_KWARGS.
                                 Only mean_par, mean_chr, min_unique_par and
                                 min_unique_chr are forwarded to the generator; all
                                 keys are saved in the metadata.
            verbose (bool): if True, print messages / progress bars to inform users
        """
        self.n_features = n_features
        self.n_modules = n_modules
        self.p_vertex = p_vertex
        self.p_module = p_module
        self.n_samples = n_samples
        self.n_hidden = n_hidden
        self.enhanced = enhanced
        self.enhanced_kwargs = {**self.DEFAULT_ENHANCED_KWARGS, **(enhanced_kwargs or {})}
        self.rescale = rescale
        self.verbose = verbose
        self.uniform = uniform
        self.max_corr = max_corr
        self.min_corr = min_corr
        self.graph_dropout = graph_dropout
        self.dependent_dropout = dependent_dropout

        self.simulation_function_nn = simulate_data_nn
        self.simulation_function = simulate_data_linear
        self.graph = None

        # attributes related to interventional data
        self.hard = hard_intv
        self.nb_interventions = nb_interventions
        self.nb_test_interventions = nb_test_interventions
        self.min_nb_target = min_nb_target
        self.max_nb_target = max_nb_target
        self.noise_level = noise_level
        self.conservative = conservative
        self.cover = cover

        self.feature_alphabet = None

    def _sample_dag(self):
        """Sample a new feature/module DAG and fresh SEM parameters."""
        if self.verbose:
            print("Sampling the DAG")
        if self.enhanced:
            self.generator = EnhancedDagModuleGenerator(
                features=self.n_features,
                modules=self.n_modules,
                mean_par=self.enhanced_kwargs['mean_par'],
                mean_chr=self.enhanced_kwargs['mean_chr'],
                min_unique_par=self.enhanced_kwargs['min_unique_par'],
                min_unique_chr=self.enhanced_kwargs['min_unique_chr'],
                verbose=self.verbose
            )
        else:
            self.generator = DagModuleGenerator(
                features=self.n_features,
                modules=self.n_modules,
                p_module=self.p_module,
                p_vertex=self.p_vertex,
            )
        self.graph, self.causal_order = self.generator()

        # node indices of features / modules, following the order of self.graph.nodes
        node_types = [self.graph.nodes[node]["type"] for node in self.graph.nodes]
        self.feature_list = np.where([t == "node" for t in node_types])[0]
        self.module_list = np.where([t == "module" for t in node_types])[0]

        # Ordered feature and module lists
        self.node_causal_order = self.causal_order.astype(int)  # nodes include both features and modules
        self.feature_list_ordered = self.causal_order[np.isin(self.causal_order, self.feature_list)]
        self.module_list_ordered = self.causal_order[np.isin(self.causal_order, self.module_list)]

        # sort nodes and modules to alphabetical order
        self.feature_alphabet = np.sort(self.feature_list_ordered)
        self.module_alphabet = np.sort(self.module_list_ordered)

        if self.verbose:
            print("Init sem")
        self.causal_order, self.weights = init_params(self.graph, self.n_hidden, self.max_corr, self.min_corr)

    def _simulate(self, n_samples, targets):
        """Draw ``n_samples`` rows from the current SEM with node-level ``targets``.

        Uses the neural-network simulator when ``n_hidden > 0`` and the linear one
        otherwise; both receive the same trailing arguments. The returned columns
        are indexed by graph node, so callers slice features/modules out of it.
        """
        common = (
            targets,
            set(self.feature_list),
            self.hard,
            self.dependent_dropout,
            self.uniform,
            self.noise_level,
            self.graph_dropout,
        )
        if self.n_hidden > 0:
            return self.simulation_function_nn(
                n_samples, self.weights[0], self.weights[1], self.causal_order, *common
            )
        return self.simulation_function(n_samples, self.weights, self.causal_order, *common)

    def generate(self, intervention=True, resample_dag=True):
        """Simulate a dataset and store it on the instance.

        With ``intervention=True`` the samples are split as evenly as possible
        over ``nb_interventions`` interventional settings, one observational
        setting (empty target) and ``nb_test_interventions`` test settings.
        Results are stored as ``data``, ``data_factor``, ``mask_intervention``,
        ``regimes`` and ``adj_int``.

        Args:
            intervention (bool): if True, generate interventional data; otherwise
                only observational samples.
            resample_dag (bool): if True (or if no DAG exists yet), sample a new
                DAG and new SEM parameters before simulating.
        """
        if self.graph is None or resample_dag:
            self._sample_dag()

        mask_intervention = []
        regimes = []
        # plan intervention scheme, perform them and sample to put together a dataset
        if intervention:
            data = np.zeros((self.n_samples, self.n_features))
            data_factor = np.zeros((self.n_samples, self.n_modules))

            # split n_samples into `div` almost-equal integer parts
            div = self.nb_interventions + self.nb_test_interventions + 1
            base, extra = divmod(self.n_samples, div)
            points_per_interv = [base + (1 if x < extra else 0) for x in range(div)]

            # randomly pick targets (the last one is the observational setting)
            target_list = self._pick_targets(nb_max_iteration=10000)
            # get intervention-to-feature matrix
            self.adj_int = np.zeros((self.nb_interventions, self.n_features))
            for i in range(self.nb_interventions):
                self.adj_int[i, target_list[i]] = 1
            # test targets are unions of two training targets
            target_list.extend(self._pick_test_targets(target_list))

            start = 0
            for j in tqdm(range(div), desc="interventions", disable=not self.verbose):
                n_points = points_per_interv[j]
                # targets are feature indices; the simulators expect node indices
                targets = np.array([self.feature_list[t] for t in target_list[j]])
                samples = self._simulate(n_points, targets)

                end = start + n_points
                data[start:end, :] = samples[:, self.feature_list]
                data_factor[start:end, :] = samples[:, self.module_list]
                start = end

                # here add at the feature level, not node
                mask_intervention.extend([target_list[j]] * n_points)
                # Settings 0..nb_interventions-1 get labels 1..nb_interventions and the
                # observational setting (index nb_interventions) gets nb_interventions + 1.
                # Test settings get negative labels starting at -2.
                # NOTE(review): the observational setting is not labeled 0 here;
                # confirm this matches what downstream loaders expect.
                if j <= self.nb_interventions:
                    regime = j + 1
                else:
                    regime = self.nb_interventions - 1 - j
                regimes.extend([regime] * n_points)
        else:
            # generate the datasets with no intervention
            samples = self._simulate(self.n_samples, np.array([-1]))
            # slice factors from the full node matrix, before dropping module columns
            data = samples[:, self.feature_list]
            data_factor = samples[:, self.module_list]
            # no interventional settings in this mode
            self.adj_int = np.zeros((0, self.n_features))

        if self.rescale:
            data = StandardScaler().fit_transform(data)

        # dump into class
        self.data = data
        self.data_factor = data_factor
        self.mask_intervention = mask_intervention
        self.regimes = regimes

    def _pick_targets(self, nb_max_iteration=100000):
        """Randomly pick the feature targets of the training settings.

        Returns ``nb_interventions`` targets followed by one empty target for
        the observational setting.

        Raises:
            ValueError: if ``cover`` is requested and no covering family was
                found within ``nb_max_iteration`` rejection-sampling rounds.
        """
        nodes = range(self.n_features)

        if self.max_nb_target == 1:
            if self.cover:
                intervention = np.random.permutation(self.n_features)
            else:
                intervention = np.random.choice(
                    self.n_features,
                    min(self.n_features, self.nb_interventions),
                    replace=False,
                )
            # exactly nb_interventions settings, so the observational setting
            # appended below stays at index nb_interventions
            targets = [[t] for t in intervention[: self.nb_interventions]]
            missing = self.nb_interventions - len(targets)
            targets.extend([np.random.choice(self.n_features)] for _ in range(missing))
        else:
            for _ in range(nb_max_iteration):
                targets = []
                for _ in range(self.nb_interventions):
                    nb_targets = int(np.random.randint(
                        self.min_nb_target, self.max_nb_target + 1, 1
                    )[0])
                    intervention = np.random.choice(
                        self.n_features, nb_targets, replace=False
                    )
                    targets.append(np.sort(intervention))
                # rejection sampling on plain coverage, as documented for `cover`
                # (`_is_covering_pairs` implements the stricter paired-order condition)
                if not self.cover or self._is_covering(nodes, targets):
                    break
            else:
                raise ValueError(
                    "Could not generate appropriate targets: "
                    "exceeded the maximal number of iterations."
                )

        targets.append([])

        return targets

    def _pick_test_targets(self, train_targets):
        """Pick ``nb_test_interventions`` test targets as unions of two training targets.

        Each returned target is a sorted list of feature indices. With fewer than
        three training targets no random pair is drawn: the union of all of them
        is used for every test setting.
        """
        def _combine_targets(*targets):
            out = set()
            for target in targets:
                out.update(int(t) for t in target)
            return sorted(out)

        if len(train_targets) <= 2:
            combined = _combine_targets(*train_targets)
            return [list(combined) for _ in range(self.nb_test_interventions)]
        targets = []
        for _ in range(self.nb_test_interventions):
            first, second = np.random.choice(len(train_targets), 2, replace=False)
            targets.append(_combine_targets(train_targets[first], train_targets[second]))
        return targets

    def _check_test_targets(self, train_targets, test_targets):
        """Return True iff no test target equals (as a set) a training target."""
        for x in test_targets:
            for y in train_targets:
                if set(x) == set(y):
                    return False
        return True

    def _is_conservative(self, elements, lists):
        """Return True iff every element is absent from at least one list."""
        for e in elements:
            conservative = False

            for list_ in lists:
                if e not in list_:
                    conservative = True
                    break
            if not conservative:
                return False
        return True

    def _is_covering(self, elements, lists):
        """Return True iff the union of ``lists`` is exactly ``elements``."""
        return set(elements) == self._union(lists)

    def _is_covering_pairs(self, elements, lists):
        """Check the paired-order condition on a family of targets.

        Returns True iff, for every ordered pair ``(i, j)`` of distinct elements,
        some setting in ``lists`` intervenes on ``i`` but not on ``j``.
        Elements are assumed to be the indices ``0..len(elements)-1``.
        """
        n = len(elements)
        separated = np.eye(n, dtype=bool)
        for targets in lists:
            in_setting = np.zeros(n, dtype=bool)
            in_setting[np.asarray(targets, dtype=int)] = True
            # accumulate (OR) so pairs separated by earlier settings stay separated
            separated |= np.outer(in_setting, ~in_setting)
        return bool(np.all(separated))

    def _union(self, lists):
        """Return the set union of all elements in ``lists``."""
        union_set = set()

        for _list in lists:
            union_set = union_set.union(set(_list))
        return union_set

    def save_data(self, folder, i):
        """Save graph, intervention matrix, data, targets, regimes and metadata.

        Every file name is suffixed with the dataset index ``i``. Requires
        ``generate`` to have been called first.
        """
        if self.verbose:
            print("saving")
        os.makedirs(folder, exist_ok=True)
        dag_path = os.path.join(folder, f"DAG{i}.npy")
        np.save(dag_path, self.generator.bipartite_half)
        dag_int_path = os.path.join(folder, f"DAG_int{i}.npy")
        np.save(dag_int_path, self.adj_int)

        np.save(os.path.join(folder, f"U{i}.npy"), self.generator.U)
        np.save(os.path.join(folder, f"V{i}.npy"), self.generator.V)
        np.save(os.path.join(folder, f"adj{i}.npy"), self.generator.adjacency_matrix)
        np.save(os.path.join(folder, f"module{i}.npy"), self.generator.module_list)
        np.save(os.path.join(folder, f"order{i}.npy"), self.node_causal_order)
        np.save(os.path.join(folder, f"module_ordered{i}.npy"), self.module_list_ordered)
        np.save(os.path.join(folder, f"feature_ordered{i}.npy"), self.feature_list_ordered)
        np.save(os.path.join(folder, f"module_alphabet{i}.npy"), self.module_alphabet)
        np.save(os.path.join(folder, f"feature_alphabet{i}.npy"), self.feature_alphabet)

        # save data; interventional runs also get one CSV row of targets per sample
        if len(self.mask_intervention) == 0:
            np.save(os.path.join(folder, f"data{i}.npy"), self.data)
        else:
            np.save(os.path.join(folder, f"data_interv{i}.npy"), self.data)

            with open(os.path.join(folder, f"intervention{i}.csv"), "w", newline="") as f:
                csv.writer(f).writerows(self.mask_intervention)
        # save regimes, one per line
        if self.regimes is not None:
            with open(os.path.join(folder, f"regime{i}.csv"), "w", newline="") as f:
                csv.writer(f).writerows([regime] for regime in self.regimes)

        # save simulation metadata as json
        metadata = {
            "n_samples": self.n_samples,
            "n_features": self.n_features,
            "n_modules": self.n_modules,
            "p_vertex": self.p_vertex,
            "p_module": self.p_module,
            "n_hidden": self.n_hidden,
            "rescale": self.rescale,
            "hard_intv": self.hard,
            "nb_interventions": self.nb_interventions,
            "nb_test_interventions": self.nb_test_interventions,
            "min_nb_target": self.min_nb_target,
            "max_nb_target": self.max_nb_target,
            "noise_level": self.noise_level,
            "max_corr": self.max_corr,
            "min_corr": self.min_corr,
            "graph_dropout": self.graph_dropout,
            "dependent_dropout": self.dependent_dropout,
            "conservative": self.conservative,
            "uniform": self.uniform,
            "cover": self.cover,
            "enhanced": self.enhanced,
            "enhanced_kwargs": self.enhanced_kwargs,
        }
        metadata_path = os.path.join(folder, f"metadata{i}.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f)


class SimulationDataset(Dataset):
    """
    A generic class for simulation data loading and extraction, as well as pre-filtering of interventions.

    Regimes with negative labels are treated as test (unseen) interventions:
    they are dropped by default and are the only ones kept when ``load_test`` is True.
    NOTE: the 0-th regime should always be the observational one
    """

    def __init__(
        self,
        file_path,
        i_dataset,
        intervention=True,
        unknown_intervention=False,
        soft_intervention=False,
        load_test=False
    ) -> None:
        """Simulation Dataset

        Args:
            file_path (str): path to the data
            i_dataset (int): index of the dataset
            intervention (bool, optional): Whether data is interventional. Defaults to True.
            unknown_intervention (bool, optional): Whether intervention targets are unknown;
                if True, ``__getitem__`` returns all-ones masks. Defaults to False.
            soft_intervention (bool, optional): Whether interventions are soft; if True,
                ``__getitem__`` returns all-ones masks. Defaults to False.
            load_test (bool, optional): Whether to load the test set (negative regimes)
                instead of the training set. Defaults to False.
        """
        super().__init__()
        self.file_path = file_path
        self.i_dataset = i_dataset
        self.intervention = intervention
        self.unknown_intervention = unknown_intervention
        self.soft_intervention = soft_intervention
        # load data
        all_data, all_masks, all_regimes = self.load_data()
        all_regimes = np.asarray(all_regimes)
        # index of all regimes, even if not used in the regimes_to_ignore case
        self.all_regimes_list = np.unique(all_regimes)

        # test regimes are negative, training/observational ones non-negative
        if load_test:
            self.regimes_to_ignore = self.all_regimes_list[self.all_regimes_list >= 0]
        else:
            self.regimes_to_ignore = self.all_regimes_list[self.all_regimes_list < 0]
        # boolean mask even when there are no samples (a plain list would become float)
        to_keep = ~np.isin(all_regimes, self.regimes_to_ignore)

        self.data = all_data[to_keep]
        self.regimes = all_regimes[to_keep]
        masks = [mask for mask, keep in zip(all_masks, to_keep) if keep]
        self.masks = np.array(masks, dtype=object)

        self.num_regimes = np.unique(self.regimes).shape[0]
        self.num_samples = self.data.shape[0]
        self.dim = self.data.shape[1]

    def __getitem__(self, idx):
        """Return ``(sample, mask, regime)`` for sample ``idx``.

        ``mask`` has zeros at the intervened features when targets are known and
        interventions are hard; otherwise it is all ones.
        """
        sample = self.data[idx].astype(np.float32)
        mask = np.ones((self.dim,), dtype=np.float32)
        if self.intervention and not self.unknown_intervention and not self.soft_intervention:
            # binarize mask from the list of targets
            for j in self.masks[idx]:
                mask[j] = 0
        return sample, mask, self.regimes[idx]

    def __len__(self):
        return self.data.shape[0]

    def load_data(self):
        """
        Load the data, the intervention masks and the regimes.

        Returns:
            tuple: ``(data, masks, regimes)``. ``masks`` is a list with one list of
            target indices per CSV row (empty when ``intervention`` is False), and
            ``regimes`` is a 1-D integer array.
        """
        if self.intervention:
            name_data = f"data_interv{self.i_dataset}.npy"
        else:
            name_data = f"data{self.i_dataset}.npy"

        # Load data
        self.data_path = os.path.join(self.file_path, name_data)
        # NOTE(review): allow_pickle=True can execute arbitrary code when loading
        # untrusted files; the generator saves plain numeric arrays, so consider
        # disabling it if no object arrays need to be read.
        data = np.load(self.data_path, allow_pickle=True)

        # Load intervention masks and regimes
        masks = []
        if self.intervention:
            regime_path = os.path.join(self.file_path, f"regime{self.i_dataset}.csv")
            # genfromtxt returns a 0-d array for a single-row file; keep it 1-D
            regimes = np.atleast_1d(np.genfromtxt(regime_path, delimiter=",")).astype(int)

            interv_path = os.path.join(
                self.file_path, f"intervention{self.i_dataset}.csv"
            )
            # one row of target indices per sample; empty rows are observational
            with open(interv_path, "r", newline="") as f:
                masks = [[int(x) for x in row] for row in csv.reader(f)]
        else:
            regimes = np.zeros(data.shape[0], dtype=int)

        return data, masks, regimes

    def convert_masks(self, idxs):
        """
        Convert mask index to mask vectors
        :param np.ndarray idxs: indices of mask to convert
        :return: masks
        Example:
            if self.masks[i] = [1,4]
                self.dim = 10 then
            masks[i] = [1,0,1,1,0,1,1,1,1,1]
        """
        masks = torch.ones((idxs.shape[0], self.dim))
        for row, i in enumerate(idxs):
            for j in self.masks[i]:
                masks[row, j] = 0

        return masks
