import csv
import os

import numpy as np
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
from tqdm.autonotebook import tqdm

from .bipartite_graphs import IntvDagModuleGenerator
from .enhanced_graphs import IntvEnhancedDagModuleGenerator
from .sems_ifg_vectorized import (
    init_params,
    simulate_data_linear,
    simulate_data_nn,
)


class DatasetLowRankGeneratorIFG:
    """Generate low-rank interventional datasets with untargeted interventions.

    Data is simulated with the functions from ``sems_ifg_vectorized``. This is an
    extension of the original DatasetLowRankGenerator class that mimics the effect
    of untargeted interventions: explicit intervention nodes are added to the graph
    and these interventions are assumed to affect the latent factors (modules).
    """

    # Defaults for the enhanced generator; user-supplied ``enhanced_kwargs`` override them.
    _DEFAULT_ENHANCED_KWARGS = {
        'mean_par': 5,
        'mean_chr': 5,
        'min_unique_par': 1,
        'min_unique_chr': 1,
        'min_upstream': -1,
        'min_downstream': -1
    }

    def __init__(
        self,
        n_features,
        n_modules,
        p_vertex,
        p_module,
        n_samples,
        hard_intv=True,
        n_hidden=0,
        enhanced=False,
        rescale=True,
        nb_interventions=None,
        nb_test_interventions=1,
        min_nb_target=1,
        max_nb_target=1,
        noise_level=0.4,
        max_corr=1.0,
        min_corr=0.25,
        graph_dropout=0.0,
        dependency_dropout=True,
        alpha=0.1,
        scale=0.1,
        conservative=False,
        uniform=False,
        cover=False,
        enhanced_kwargs=None,
        verbose=True
    ):
        """
        Generate a dataset containing interventions. The setting is similar to the
        one in the DCDI paper, but with a low-rank structure. Save the lists of targets
        in a separate file.

        Args:
            n_features (int): Number of features (observed variables)
            n_modules (int): Number of modules (latent factors)
            p_vertex (float): Probability of adding a vertex connecting to a (parent) module
            p_module (float): Probability of adding a module connecting to a (parent) vertex
            n_samples (int): Total number of samples, split almost evenly over all regimes
            hard_intv (bool): if True, apply hard interventions to factors; per-sample
                              intervention masks are only recorded in this case
            n_hidden (int): Number of hidden nodes. If set to 0, a linear model will be used.
            enhanced (bool): if True, use the enhanced generator
            rescale (bool): if True, standardize each feature column of the data
            nb_interventions (int or None): number of training intervention regimes
                              (one intervention node each). Defaults to n_modules.
            nb_test_interventions (int): number of held-out test regimes; each one
                              combines the sources of two training interventions
            min_nb_target (int): minimal number of targets per setting
            max_nb_target (int): maximal number of targets per setting. For a fixed
                                 number of target, one can make min_nb_target==max_nb_target
            noise_level (float): noise level in the data
            max_corr (float): maximal correlation between variables
            min_corr (float): minimal correlation between variables
            graph_dropout (float): forwarded to the simulation function as its dropout
            dependency_dropout (bool): forwarded to the simulation function
            alpha (float): shape parameter for the gamma distribution
            scale (float): scale parameter for the gamma distribution
            conservative (bool): stored only. NOTE(review): not enforced by
                                 _pick_targets; the observational (empty) target that
                                 is always appended already makes the family conservative.
            uniform (bool): if True, set noise distribution to uniform
            cover (bool): if True, make sure that all modules have been
                          intervened on in at least one setting.
            enhanced_kwargs (dict or None): overrides for the enhanced generator
                              arguments (mean_par, mean_chr, min_unique_par,
                              min_unique_chr; min_upstream/min_downstream are stored
                              but not forwarded). Missing keys use the defaults.
            verbose (bool): if True, print messages to inform users
        """
        self.n_features = n_features
        self.n_modules = n_modules
        self.p_vertex = p_vertex
        self.p_module = p_module
        self.n_samples = n_samples
        self.n_hidden = n_hidden
        self.enhanced = enhanced
        # Fresh dict per instance: never share/mutate a default or the caller's dict.
        self.enhanced_kwargs = {**self._DEFAULT_ENHANCED_KWARGS, **(enhanced_kwargs or {})}
        self.rescale = rescale
        self.verbose = verbose
        self.uniform = uniform
        self.max_corr = max_corr
        self.min_corr = min_corr
        self.dropout = graph_dropout
        self.dependency_dropout = dependency_dropout
        self.alpha = alpha
        self.scale = scale

        self.simulation_function_nn = simulate_data_nn
        self.simulation_function = simulate_data_linear
        self.graph = None
        # training targets of the current DAG, reused when generate(resample_dag=False)
        self._train_target_list = None

        # attributes related to interventional data
        self.hard = hard_intv
        # by default, one intervention regime per module
        self.nb_interventions = self.n_modules if nb_interventions is None else nb_interventions
        # computed after resolving the None default above (features + modules + intervention nodes)
        self.n_total_vertices = n_features + n_modules + self.nb_interventions
        self.nb_test_interventions = nb_test_interventions
        self.min_nb_target = min_nb_target
        self.max_nb_target = max_nb_target
        self.noise_level = noise_level
        self.conservative = conservative
        self.cover = cover

    def _get_downstream_nodes(self, graph, modules):
        """Return (as a list) the union of the successors of ``modules`` in ``graph``."""
        return list(set().union(*(graph.successors(module) for module in modules)))

    def _sample_dag(self):
        """Sample a new DAG (with intervention nodes) and initialise the SEM parameters.

        Returns:
            list: the training target list used to build the DAG (one entry per
            intervention, followed by an empty, observational entry).
        """
        if self.verbose:
            print("Sampling the DAG")
        if self.enhanced:
            self.generator = IntvEnhancedDagModuleGenerator(
                features=self.n_features,
                modules=self.n_modules,
                intvs=self.nb_interventions,
                mean_par=self.enhanced_kwargs['mean_par'],
                mean_chr=self.enhanced_kwargs['mean_chr'],
                min_unique_par=self.enhanced_kwargs['min_unique_par'],
                min_unique_chr=self.enhanced_kwargs['min_unique_chr'],
                verbose=self.verbose
            )
        else:
            self.generator = IntvDagModuleGenerator(
                features=self.n_features,
                modules=self.n_modules,
                p_module=self.p_module,
                p_vertex=self.p_vertex,
            )
        # randomly pick modules as targets
        target_list = self._pick_targets()
        self.graph, self.causal_order = self.generator(target_list)
        # keep a copy so that test targets appended later do not leak into it
        self._train_target_list = list(target_list)

        node_types = [self.graph.nodes[node]["type"] for node in self.graph.nodes]
        self.feature_list = np.where([t == "node" for t in node_types])[0]
        self.module_list = np.where([t == "module" for t in node_types])[0]
        self.intv_list = np.where([t == "intervention" for t in node_types])[0]
        assert len(self.module_list) == self.n_modules, "Number of modules is not correct"
        assert len(self.intv_list) == self.nb_interventions, "Number of interventions is not correct"

        # Ordered feature and module lists
        self.node_causal_order = self.causal_order.astype(int)  # nodes include both features and modules
        self.feature_list_ordered = self.causal_order[np.isin(self.causal_order, self.feature_list)]
        self.module_list_ordered = self.causal_order[np.isin(self.causal_order, self.module_list)]

        # sort nodes and modules to alphabetical order
        self.feature_alphabet = np.sort(self.feature_list_ordered)
        self.module_alphabet = np.sort(self.module_list_ordered)

        if self.verbose:
            print("Init sem")
        self.causal_order, self.weights = init_params(self.graph, self.n_hidden, self.max_corr, self.min_corr)
        return target_list

    def generate(self, resample_dag=True):
        """Sample (or reuse) the DAG and simulate the full dataset.

        Regime layout (rows are split almost evenly):
          * j < nb_interventions: training intervention j, labelled j + 1;
          * j == nb_interventions: observational regime, labelled nb_interventions + 1;
          * j > nb_interventions: test regimes, labelled nb_interventions - 1 - j (-2, -3, ...).

        Args:
            resample_dag (bool): if True (or no DAG exists yet), sample a new DAG and
                new training targets; otherwise reuse the previous ones.

        Sets self.data, self.data_factor, self.perturb_features,
        self.mask_intervention, self.regimes and self.int2module.
        """
        if self.graph is None or resample_dag:
            target_list = self._sample_dag()
        else:
            # reuse the previous training targets (copy: test targets are appended below)
            target_list = list(self._train_target_list)

        mask_intervention = []
        regimes = []
        self.int2module = self.generator.W.copy()
        # plan intervention scheme, perform them and sample to put together a dataset
        data = np.zeros((self.n_samples, self.n_features))
        data_factor = np.zeros((self.n_samples, self.n_modules))
        perturb_features = np.zeros((self.n_samples, self.nb_interventions))

        num = self.n_samples
        div = self.nb_interventions + self.nb_test_interventions + 1

        # split `num` into `div` almost equal whole numbers, see
        # https://stackoverflow.com/questions/20348717/algo-for-dividing-a-number-into-almost-equal-whole-numbers/20348992
        points_per_interv = [
            num // div + (1 if x < num % div else 0) for x in range(div)
        ]

        # randomly pick test targets as combinations of the training ones
        test_combs, test_target_list = self._pick_test_targets(target_list)
        target_list.extend(test_target_list)

        # perform interventions
        assert div == len(target_list)
        start = 0
        for j in tqdm(range(div), desc="interventions"):
            n_points = points_per_interv[j]
            # these interventions are at the module level, must convert into modules
            # track downstream features of intervened modules
            target_modules = [self.module_list[t] for t in target_list[j]]
            targets = self._get_downstream_nodes(self.graph, target_modules)
            if j < self.nb_interventions:
                intv_src = {self.intv_list[j]}
            elif j > self.nb_interventions:
                intv_src = {self.intv_list[t] for t in test_combs[j - self.nb_interventions - 1]}
            else:
                # observational regime: no intervention node is active
                intv_src = set()

            # generate the datasets with the given interventions
            if self.n_hidden > 0:
                dataset = self.simulation_function_nn(
                    n_points,
                    self.weights[0],
                    self.weights[1],
                    self.causal_order,
                    targets,
                    intv_src,
                    self.intv_list,
                    set(self.feature_list),
                    self.hard,
                    self.alpha,
                    self.scale,
                    self.noise_level,
                    self.dropout,
                    self.dependency_dropout,
                    self.uniform
                )
            else:
                dataset = self.simulation_function(
                    n_points,
                    self.weights,
                    self.causal_order,
                    targets,
                    intv_src,
                    self.intv_list,
                    set(self.feature_list),
                    self.hard,
                    self.alpha,
                    self.scale,
                    self.noise_level,
                    self.dropout,
                    self.dependency_dropout,
                    self.uniform
                )

            # split simulated columns into factors, intervention nodes and features
            factors = dataset[:, self.module_list]
            intv_feature = dataset[:, self.intv_list]
            dataset = dataset[:, self.feature_list]
            end = start + n_points
            data[start:end, :] = dataset
            data_factor[start:end, :] = factors
            # perturbation information
            perturb_features[start:end, :] = intv_feature
            start = end

            # masks are at the feature level: features downstream (via V) of the targeted factors
            if self.hard:
                if len(target_list[j]) > 0:
                    downstream_nodes = list(np.where(self.generator.V[np.array(target_list[j])].sum(0))[0])
                else:
                    downstream_nodes = []
                mask_intervention.extend([downstream_nodes] * n_points)
            # NOTE(review): for soft interventions no mask rows are recorded, so
            # mask_intervention is then shorter than data.

            if j <= self.nb_interventions:  # training set (incl. observational regime)
                regimes.extend([j + 1] * n_points)
            else:  # we use negative integers to encode test(unseen) interventions
                regimes.extend([self.nb_interventions - 1 - j] * n_points)

        if self.rescale:
            # fit_transform returns a new array; it does not modify `data` in place
            data = StandardScaler().fit_transform(data)

        # dump into class
        self.data = data
        self.data_factor = data_factor
        self.perturb_features = perturb_features
        self.mask_intervention = mask_intervention
        self.regimes = regimes

    def _pick_targets(self, nb_max_iteration=100000):
        """Randomly pick the module targets of each training intervention.

        Args:
            nb_max_iteration (int): maximum number of rejection-sampling attempts
                (only used when max_nb_target > 1).

        Returns:
            list: nb_interventions target collections followed by an empty
            (observational) target.

        Raises:
            ValueError: if no covering family is found within nb_max_iteration tries.
        """
        factors = range(self.n_modules)

        if self.max_nb_target == 1:
            if self.cover:
                intervention = np.random.permutation(self.n_modules)
            else:
                intervention = np.random.choice(
                    self.n_modules, self.nb_interventions, replace=False
                )
            targets = [[t] for t in intervention]
            if len(targets) < self.nb_interventions:
                # pad with random single-module targets
                targets.extend(
                    [[np.random.choice(self.n_modules)] for _ in range(self.nb_interventions - len(targets))]
                )
        else:
            not_correct = True
            n_iter = 0
            while not_correct and n_iter < nb_max_iteration:
                n_iter += 1
                targets = []
                # pick targets randomly
                for _ in range(self.nb_interventions):
                    nb_targets = np.random.randint(
                        self.min_nb_target, self.max_nb_target + 1, 1
                    )
                    intervention = np.random.choice(
                        self.n_modules, nb_targets, replace=False
                    )
                    targets.append(intervention)
                # apply rejection sampling
                not_correct = self.cover and not self._is_covering(factors, targets)

            # fail only if the last attempt was still rejected
            if not_correct:
                raise ValueError(
                    "Could not generate appropriate targets. "
                    "Exceeded the maximal number of iterations"
                )
            targets = [np.sort(t) for t in targets]

        # observational regime
        targets.append([])
        return targets

    def _pick_test_targets(self, train_targets):
        """Pick test targets as combinations of two training targets.

        Returns:
            tuple: (test_combs, targets) with nb_test_interventions entries each;
            test_combs[k] holds the two training indices that were combined.
        """
        def _combine_targets(*targets):
            return sorted(set().union(*(set(t) for t in targets)))

        if self.nb_interventions == 1:
            # only one training target: it can only be combined with itself
            n_test = self.nb_test_interventions
            return [(0, 0)] * n_test, [train_targets[0]] * n_test
        has_enough_combs = self.nb_interventions * (self.nb_interventions - 1) >= self.nb_test_interventions
        test_combs = [
            np.random.choice(self.nb_interventions, 2, replace=not has_enough_combs)
            for _ in range(self.nb_test_interventions)
        ]
        targets = [
            _combine_targets(train_targets[comb[0]], train_targets[comb[1]])
            for comb in test_combs
        ]
        return test_combs, targets

    def _check_test_targets(self, train_targets, test_targets):
        """Return True if no test target equals (as a set) any training target."""
        return not any(
            set(x) == set(y) for x in test_targets for y in train_targets
        )

    def _is_conservative(self, elements, lists):
        """Return True if every element is absent from at least one list."""
        return all(any(e not in list_ for list_ in lists) for e in elements)

    def _is_covering(self, elements, lists):
        """Return True if the union of ``lists`` equals the set of ``elements``."""
        return set(elements) == self._union(lists)

    def _is_covering_pairs(self, elements, lists):
        """Check the pairwise ordering condition on a family of target sets.

        Returns True if, for every ordered pair (a, b) of distinct elements, some
        target set contains a but not b. Elements are assumed to be the indices
        0..len(elements)-1, since targets are used directly as indices.
        """
        n = len(elements)
        perturbed = np.eye(n, dtype=bool)
        for targets in lists:
            in_target = np.zeros(n, dtype=bool)
            in_target[np.asarray(targets, dtype=int)] = True
            # pair (a, b) is separated when a is targeted and b is not
            perturbed |= np.outer(in_target, ~in_target)
        return bool(np.all(perturbed))

    def _union(self, lists):
        """Return the union of all iterables in ``lists`` as a set."""
        return set().union(*lists)

    def save_data(self, folder, i):
        """Save the graph structure, data, masks and regimes of dataset ``i`` to ``folder``.

        The folder is created if needed.
        """
        if self.verbose:
            print("saving")
        os.makedirs(folder, exist_ok=True)

        # save DAG / structure
        dag_path = os.path.join(folder, f"DAG{i}.npy")
        np.save(dag_path, self.generator.bipartite_half)
        dag_int_path = os.path.join(folder, f"DAG_int{i}.npy")
        np.save(dag_int_path, self.generator.intv_half)

        np.save(os.path.join(folder, f"U{i}.npy"), self.generator.U)
        np.save(os.path.join(folder, f"V{i}.npy"), self.generator.V)
        np.save(os.path.join(folder, f"W{i}.npy"), self.int2module)
        np.save(os.path.join(folder, f"adj{i}.npy"), self.generator.adjacency_matrix)
        np.save(os.path.join(folder, f"module{i}.npy"), self.generator.module_list)
        np.save(os.path.join(folder, f"order{i}.npy"), self.node_causal_order)
        np.save(os.path.join(folder, f"module_ordered{i}.npy"), self.module_list_ordered)
        np.save(os.path.join(folder, f"feature_ordered{i}.npy"), self.feature_list_ordered)
        np.save(os.path.join(folder, f"module_alphabet{i}.npy"), self.module_alphabet)
        np.save(os.path.join(folder, f"feature_alphabet{i}.npy"), self.feature_alphabet)

        # save data
        if self.nb_interventions == 0:
            data_path = os.path.join(folder, f"data{i}.npy")
            np.save(data_path, self.data)
        else:
            data_path = os.path.join(folder, f"data_interv{i}.npy")
            np.save(data_path, self.data)
            data_perturb_path = os.path.join(folder, f"perturb_features{i}.npy")
            np.save(data_perturb_path, self.perturb_features)
            np.save(os.path.join(folder, f"data_factor{i}.npy"), self.data_factor)

            data_path = os.path.join(folder, f"intervention{i}.csv")
            with open(data_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerows(self.mask_intervention)

        # save regimes, one per row
        if self.regimes is not None:
            regime_path = os.path.join(folder, f"regime{i}.csv")
            with open(regime_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerows([regime] for regime in self.regimes)


class SimulationDatasetIFG(Dataset):
    """
    A generic class for simulation data loading and extraction, as well as pre-filtering of interventions.

    Each item is ``(sample, mask, regime)``. ``sample`` is the feature vector
    concatenated with the perturbation (intervention-node) features. ``mask`` has
    one entry per feature, and entries listed in the sample's intervention row are
    set to 0.
    NOTE: the 0-th regime should always be the observational one
    """

    def __init__(
        self,
        file_path,
        i_dataset,
        intervention=True,
        unknown_intervention=False,
        soft_intervention=False,
        load_test=False
    ) -> None:
        """Simulation Dataset

        Args:
            file_path (str): path to the data
            i_dataset (int): index of the dataset
            intervention (bool, optional): Whether data is interventional. Defaults to True.
            unknown_intervention (bool, optional): Whether intervention targets are unknown. Defaults to False.
            soft_intervention (bool, optional): Whether interventions are soft. Defaults to False.
            load_test (bool, optional): If True, keep only the test regimes (negative
                labels). Otherwise keep only the regimes with label >= 0. Defaults to False.
        """
        super().__init__()
        self.file_path = file_path
        self.i_dataset = i_dataset
        self.intervention = intervention
        self.unknown_intervention = unknown_intervention
        self.soft_intervention = soft_intervention
        # load data
        all_data, all_perturb, all_masks, all_regimes = self.load_data()
        all_regimes = np.asarray(all_regimes)
        # index of all regimes, even if not used in the regimes_to_ignore case
        self.all_regimes_list = np.unique(all_regimes)
        # determine which regimes to ignore (test regimes are encoded as negative labels)
        if load_test:
            self.regimes_to_ignore = self.all_regimes_list[self.all_regimes_list >= 0]
        else:
            self.regimes_to_ignore = self.all_regimes_list[self.all_regimes_list < 0]
        # boolean row mask; stays boolean even when there are no regimes
        to_keep = ~np.isin(all_regimes, self.regimes_to_ignore)
        data = all_data[to_keep]
        perturbation = all_perturb[to_keep]
        # masks may be empty (non-interventional data or soft interventions)
        masks = [mask for mask, keep in zip(all_masks, to_keep) if keep]
        regimes = all_regimes[to_keep]

        self.data = np.concatenate([data, perturbation], axis=1)
        self.regimes = regimes
        self.masks = np.array(masks, dtype=object)

        self.num_regimes = np.unique(self.regimes).shape[0]
        self.num_samples = self.data.shape[0]
        self.dim = self.data.shape[1]
        self.dim_intv = all_perturb.shape[1]

    def __getitem__(self, idx):
        """Return ``(sample, mask, regime)`` for row ``idx`` (sample and mask as float32)."""
        n_features = self.dim - self.dim_intv
        sample = self.data[idx].astype(np.float32)
        mask = np.ones((n_features,), dtype=np.float32)
        if self.intervention and not self.unknown_intervention and not self.soft_intervention:
            # binarize the mask: zero out the feature indices listed for this sample
            mask[np.asarray(self.masks[idx], dtype=int)] = 0
        return sample, mask, self.regimes[idx]

    def __len__(self):
        return self.data.shape[0]

    def load_data(self):
        """
        Load the data, perturbation features, intervention masks and regimes.

        Returns:
            tuple: (data, perturb_features, masks, regimes). ``masks`` is a list with
            one list of int per row of the intervention CSV (empty when
            ``intervention`` is False). ``regimes`` is a 1-D int array.
        """
        name_data = f"data_interv{self.i_dataset}.npy"
        name_perturb = f"perturb_features{self.i_dataset}.npy"

        # Load data
        # NOTE: allow_pickle=True executes pickled content; only load trusted files.
        self.data_path = os.path.join(self.file_path, name_data)
        data = np.load(self.data_path, allow_pickle=True)
        self.perturb_path = os.path.join(self.file_path, name_perturb)
        perturb_features = np.load(self.perturb_path, allow_pickle=True)

        # Load intervention masks and regimes
        masks = []
        if self.intervention:
            interv_path = os.path.join(
                self.file_path, f"intervention{self.i_dataset}.csv"
            )
            regimes = np.genfromtxt(
                os.path.join(self.file_path, f"regime{self.i_dataset}.csv"),
                delimiter=",",
            )
            # genfromtxt returns a 0-d array for a single-row file; force 1-D
            regimes = np.atleast_1d(regimes).astype(int)

            # read masks (an empty CSV row means "no intervened feature")
            with open(interv_path, "r", newline="") as f:
                masks = [[int(x) for x in row] for row in csv.reader(f)]
        else:
            regimes = np.array([0] * data.shape[0])

        return data, perturb_features, masks, regimes

