import csv
import json
import os
import pdb

import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from .spn_graphs import SPNFGGenerator
from .sems_ifg_vectorized import (
    init_params,
    simulate_data_linear,
    simulate_data_nn,
)


class DatasetSPNIFGGenerator:
    """Generate datasets with unknown intervention targets using SPN-FG.

    Typical use: construct, call :meth:`generate` to sample a graph and
    simulate data for every regime, then :meth:`save_data` to write the
    results to disk.
    """

    def __init__(
        self,
        n_features,
        n_modules,
        n_samples,
        hard_intv=True,
        max_copies=8,
        p_conn=(0.5, 0.5),
        sparsity_temp=(1.0, 1.0),
        n_hidden=0,
        rescale=False,
        nb_interventions=10,
        nb_test_interventions=1,
        min_nb_target=1,
        max_nb_target=1,
        noise_level=0.4,
        max_corr=1.0,
        min_corr=0.25,
        graph_dropout=0.0,
        dependency_dropout=True,
        alpha=0.1,
        scale=0.1,
        conservative=False,
        uniform=False,
        cover=False,
        verbose=True,
    ):
        """Store the configuration; nothing is sampled until ``generate``.

        Args:
            n_features: number of feature columns in the generated data.
            n_modules: number of module (factor) columns.
            n_samples: total number of samples, split across all regimes.
            hard_intv: forwarded to the simulation function; also controls
                whether per-sample intervention masks are recorded.
            max_copies, p_conn, sparsity_temp: forwarded to
                ``SPNFGGenerator``.
            n_hidden: forwarded to ``init_params``; when > 0 the
                ``simulate_data_nn`` simulator is used, otherwise
                ``simulate_data_linear``.
            rescale: if True, standardize the feature data after simulation.
            nb_interventions: number of training interventions.
            nb_test_interventions: number of held-out interventions built by
                combining two training interventions.
            min_nb_target, max_nb_target, conservative, cover: stored and
                written to the metadata file; not otherwise used by the
                code in this class.
            noise_level, graph_dropout, dependency_dropout, alpha, scale,
                uniform: forwarded to the simulation function.
            max_corr, min_corr: forwarded to ``init_params``.
            verbose: print progress messages while sampling.
        """
        self.n_features = n_features
        self.n_modules = n_modules
        self.n_total_vertices = n_features + n_modules + nb_interventions
        self.n_samples = n_samples
        self.max_copies = max_copies
        self.p_conn = p_conn
        self.sparsity_temp = sparsity_temp
        self.n_hidden = n_hidden
        self.rescale = rescale
        self.verbose = verbose
        self.uniform = uniform
        self.max_corr = max_corr
        self.min_corr = min_corr
        self.dropout = graph_dropout
        self.dependency_dropout = dependency_dropout
        self.alpha = alpha
        self.scale = scale

        self.simulation_function_nn = simulate_data_nn
        self.simulation_function = simulate_data_linear
        self.graph = None

        # attributes related to interventional data
        self.hard = hard_intv
        self.nb_interventions = nb_interventions
        self.nb_test_interventions = nb_test_interventions
        self.min_nb_target = min_nb_target
        self.max_nb_target = max_nb_target
        self.noise_level = noise_level
        self.conservative = conservative
        self.cover = cover

        self.feature_alphabet = None

    def _get_downstream_nodes(self, graph, modules):
        """Return the union of ``graph.successors(m)`` over ``modules`` as a list."""
        downstream = set()
        for module in modules:
            downstream.update(graph.successors(module))
        return list(downstream)

    def _nodes_of_type(self, node_type):
        """Return positions (in graph node iteration order) of nodes of a type."""
        return np.where(
            [self.graph.nodes[node]["type"] == node_type for node in self.graph.nodes]
        )[0]

    def _sample_graph_and_params(self):
        """Sample a new graph, derive the node index lists and init the SEM."""
        if self.verbose:
            print("Sampling the DAG")
        self.generator = SPNFGGenerator(
            self.n_features,
            self.n_modules,
            self.nb_interventions,
            'node',
            self.max_copies,
            1.0,
            self.p_conn,
            self.sparsity_temp,
        )
        self.graph, self.causal_order = self.generator()

        self.feature_list = self._nodes_of_type("node")
        self.module_list = self._nodes_of_type("module")
        self.intv_list = self._nodes_of_type("intervention")

        # Ordered feature and module lists; "nodes" here include both
        # features and modules.
        self.node_causal_order = self.causal_order.astype(int)
        self.feature_list_ordered = self.causal_order[
            np.isin(self.causal_order, self.feature_list)
        ]
        self.module_list_ordered = self.causal_order[
            np.isin(self.causal_order, self.module_list)
        ]

        # Same index sets, sorted ascending.
        self.feature_alphabet = np.sort(self.feature_list_ordered)
        self.module_alphabet = np.sort(self.module_list_ordered)

        if self.verbose:
            print("Init sem")
        self.causal_order, self.weights = init_params(
            self.graph, self.n_hidden, self.max_corr, self.min_corr
        )

    def generate(self, resample_dag=True):
        """Simulate one dataset covering every regime and store it on ``self``.

        Regimes are simulated in this order: the ``nb_interventions``
        training interventions, one regime with no targets, then the
        ``nb_test_interventions`` held-out combinations. ``n_samples`` is
        split as evenly as possible between them.

        Results are stored in ``self.data``, ``self.data_factor``,
        ``self.perturb_features``, ``self.mask_intervention`` and
        ``self.regimes``. Regime labels are ``j + 1`` for the first
        ``nb_interventions + 1`` regimes and negative integers (starting at
        -2) for the held-out ones.

        Args:
            resample_dag: if True (or if no graph exists yet) a new graph
                and new SEM parameters are sampled first.
        """
        if self.graph is None or resample_dag:
            self._sample_graph_and_params()

        self.int2module = self.generator.W

        data = np.zeros((self.n_samples, self.n_features))
        data_factor = np.zeros((self.n_samples, self.n_modules))
        perturb_features = np.zeros((self.n_samples, self.nb_interventions))
        mask_intervention = []
        regimes = []

        # Split n_samples into n_regimes almost-equal whole numbers.
        n_regimes = self.nb_interventions + self.nb_test_interventions + 1
        base, extra = divmod(self.n_samples, n_regimes)
        points_per_interv = [
            base + (1 if r < extra else 0) for r in range(n_regimes)
        ]

        # Convert the intervention-to-module matrix to per-regime target lists.
        target_list = [
            np.where(self.int2module[i] == 1)[0]
            for i in range(self.nb_interventions)
        ]
        target_list.append([])  # regime without any target
        # Held-out regimes are unions of two training target sets.
        test_combs, test_target_list = self._pick_test_targets(target_list)
        target_list.extend(test_target_list)

        assert n_regimes == len(target_list), (
            "number of regimes does not match the number of target lists"
        )

        start = 0
        for j in tqdm(range(n_regimes), desc="interventions"):
            n_points = points_per_interv[j]
            end = start + n_points

            target_modules = [self.module_list[t] for t in target_list[j]]
            targets = self._get_downstream_nodes(self.graph, target_modules)

            if j < self.nb_interventions:
                intv_src = {self.intv_list[j]}
            elif j == self.nb_interventions:
                # No intervention is active in the target-free regime.
                # Previously this case was unhandled and reused the value
                # from the preceding regime (or raised NameError).
                intv_src = set()
            else:
                comb = test_combs[j - self.nb_interventions - 1]
                intv_src = {self.intv_list[t] for t in comb}

            # Arguments shared by both simulators, in positional order.
            shared_args = (
                self.causal_order,
                targets,
                intv_src,
                self.intv_list,
                set(self.feature_list),
                self.hard,
                self.alpha,
                self.scale,
                self.noise_level,
                self.dropout,
                self.dependency_dropout,
                self.uniform,
            )
            if self.n_hidden > 0:
                dataset = self.simulation_function_nn(
                    n_points, self.weights[0], self.weights[1], *shared_args
                )
            else:
                dataset = self.simulation_function(
                    n_points, self.weights, *shared_args
                )

            # Split the simulated columns by node type.
            data[start:end, :] = dataset[:, self.feature_list]
            data_factor[start:end, :] = dataset[:, self.module_list]
            perturb_features[start:end, :] = dataset[:, self.intv_list]

            # Masks are recorded at the feature level, not the node level:
            # columns of generator.V that are nonzero for any targeted module.
            if self.hard:
                if len(target_list[j]) > 0:
                    module_rows = self.generator.V[np.array(target_list[j])]
                    affected = list(np.where(module_rows.sum(0))[0])
                else:
                    affected = []
                mask_intervention.extend([affected for _ in range(n_points)])
            else:
                mask_intervention.extend([[] for _ in range(n_points)])

            if j <= self.nb_interventions:  # training set
                label = j + 1
            else:  # negative integers encode test (unseen) interventions
                label = self.nb_interventions - 1 - j
            regimes.extend([label for _ in range(n_points)])

            start = end

        if self.rescale:
            # fit_transform returns a new array; it must be assigned.
            data = StandardScaler().fit_transform(data)

        # dump into class
        self.data = data
        self.data_factor = data_factor
        self.perturb_features = perturb_features
        self.mask_intervention = mask_intervention
        self.regimes = regimes

    def _pick_test_targets(self, train_targets):
        """Pick test targets as combinations of two training targets.

        Args:
            train_targets: per-regime target lists; only the first
                ``nb_interventions`` entries are sampled from.

        Returns:
            ``(test_combs, targets)`` where ``test_combs[k]`` holds the two
            training-intervention indices that were combined and
            ``targets[k]`` is the union of their targets. No check is made
            that a combination differs from the training regimes.
        """
        n_test = self.nb_test_interventions
        if self.nb_interventions == 1:
            # Only one intervention to combine with itself; emit one entry
            # per requested test regime so the regime count stays consistent.
            return [(0, 0)] * n_test, [train_targets[0]] * n_test

        n_ordered_pairs = self.nb_interventions * (self.nb_interventions - 1)
        has_enough_combs = n_ordered_pairs >= n_test
        test_combs = [
            np.random.choice(self.nb_interventions, 2, replace=not has_enough_combs)
            for _ in range(n_test)
        ]
        targets = [
            list(set(train_targets[first]) | set(train_targets[second]))
            for first, second in test_combs
        ]
        return test_combs, targets

    def _check_test_targets(self, train_targets, test_targets):
        """Return False if any test target set equals a training target set."""
        train_sets = [set(y) for y in train_targets]
        return not any(set(x) in train_sets for x in test_targets)

    def _is_conservative(self, elements, lists):
        """Return True if every element is absent from at least one list."""
        return all(
            any(element not in list_ for list_ in lists) for element in elements
        )

    def _is_covering(self, elements, lists):
        """Return True if the union of ``lists`` equals ``set(elements)``."""
        return set(elements) == self._union(lists)

    def _is_covering_pairs(self, elements, lists):
        """Check the paired-order condition over all ordered element pairs.

        An ordered pair ``(i, j)`` with ``i != j`` counts as covered when
        some list in ``lists`` contains ``i`` but not ``j``. Entries of
        ``lists`` are used as integer indices in ``range(len(elements))``.

        Returns:
            True (as a NumPy boolean) if every ordered pair is covered.
        """
        n_elements = len(elements)
        # The diagonal is trivially satisfied.
        covered = np.eye(n_elements, dtype=bool)
        for targets in lists:
            in_target = np.zeros(n_elements, dtype=bool)
            in_target[np.asarray(targets, dtype=int)] = True
            # Rows: targeted elements; columns: non-targeted elements.
            covered |= np.outer(in_target, ~in_target)
        return np.all(covered)

    def _union(self, lists):
        """Return the set union of all iterables in ``lists``."""
        merged = set()
        for _list in lists:
            merged.update(_list)
        return merged

    def save_data(self, folder, i):
        """Write the graph, data, regimes and metadata of run ``i`` to ``folder``.

        The folder is created if needed. Requires ``generate`` to have been
        called first, since it reads the attributes set there.

        Args:
            folder: output directory.
            i: suffix appended to every file name (e.g. ``DAG{i}.npy``).
        """
        print("saving")
        os.makedirs(folder, exist_ok=True)

        def _path(name):
            return os.path.join(folder, name)

        # graph structure and orderings
        arrays = {
            f"DAG{i}.npy": self.generator.bipartite_half,
            f"DAG_int{i}.npy": self.generator.intv_half,
            f"U{i}.npy": self.generator.U,
            f"V{i}.npy": self.generator.V,
            f"W{i}.npy": self.generator.W,
            f"adj{i}.npy": self.generator.adj_mtx,
            f"module{i}.npy": self.generator.module_list,
            f"order{i}.npy": self.node_causal_order,
            f"module_ordered{i}.npy": self.module_list_ordered,
            f"feature_ordered{i}.npy": self.feature_list_ordered,
            f"module_alphabet{i}.npy": self.module_alphabet,
            f"feature_alphabet{i}.npy": self.feature_alphabet,
        }
        for name, array in arrays.items():
            np.save(_path(name), array)

        # save data
        if self.nb_interventions == 0:
            np.save(_path(f"data{i}.npy"), self.data)
        else:
            np.save(_path(f"data_interv{i}.npy"), self.data)
            np.save(_path(f"perturb_features{i}.npy"), self.perturb_features)
            np.save(_path(f"data_factor{i}.npy"), self.data_factor)

            with open(_path(f"intervention{i}.csv"), "w", newline="") as f:
                csv.writer(f).writerows(self.mask_intervention)

        # save regimes, one label per row
        if self.regimes is not None:
            with open(_path(f"regime{i}.csv"), "w", newline="") as f:
                csv.writer(f).writerows([regime] for regime in self.regimes)

        # save simulation metadata as json
        metadata = {
            "n_samples": self.n_samples,
            "n_features": self.n_features,
            "n_modules": self.n_modules,
            "max_copies": self.max_copies,
            "p_conn": self.p_conn,
            "sparsity_temp": self.sparsity_temp,
            "n_hidden": self.n_hidden,
            "rescale": self.rescale,
            "nb_interventions": self.nb_interventions,
            "min_nb_target": self.min_nb_target,
            "max_nb_target": self.max_nb_target,
            "noise_level": self.noise_level,
            "max_corr": self.max_corr,
            "min_corr": self.min_corr,
            "conservative": self.conservative,
            "uniform": self.uniform,
            "cover": self.cover,
        }
        with open(_path(f"metadata{i}.json"), "w") as f:
            json.dump(metadata, f)