import csv
import json
import os

import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from .spn_graphs import SPNFGGenerator
from .sems_vectorized import (
    init_params,
    simulate_data_linear,
    simulate_data_nn
)


class DatasetSPNFGGenerator:
    """Generate datasets by sampling from an SPN-FG structural equation model.

    The graph comes from ``SPNFGGenerator``; its nodes carry a ``"type"``
    attribute that is either ``"node"`` (an observed feature) or ``"module"``.
    Samples are drawn with ``simulate_data_linear`` when ``n_hidden == 0`` and
    with ``simulate_data_nn`` otherwise.
    """

    def __init__(
        self,
        n_features,
        n_modules,
        n_samples,
        hard_intv=True,
        max_copies=8,
        p_conn=0.5,
        sparsity_temp=1.0,
        n_hidden=0,
        rescale=True,
        nb_interventions=10,
        nb_test_interventions=1,
        min_nb_target=1,
        max_nb_target=3,
        noise_level=0.4,
        max_corr=1.0,
        min_corr=0.25,
        graph_dropout=0.0,
        dependent_dropout=True,
        conservative=False,
        uniform=False,
        cover=False,
        verbose=True,
    ):
        """Store the generation settings; nothing is sampled until ``generate``.

        Args:
            n_features: number of feature nodes requested from ``SPNFGGenerator``.
            n_modules: number of module nodes requested from ``SPNFGGenerator``.
            n_samples: total number of rows produced by ``generate``.
            hard_intv: forwarded to the simulator as its intervention-type flag.
            max_copies, p_conn, sparsity_temp: forwarded to ``SPNFGGenerator``.
            n_hidden: forwarded to ``init_params``; a value > 0 selects
                ``simulate_data_nn`` instead of ``simulate_data_linear``.
            rescale: standardize the feature matrix with scikit-learn's
                ``StandardScaler`` after sampling.
            nb_interventions: number of training intervention regimes.
            nb_test_interventions: number of held-out intervention regimes.
            min_nb_target, max_nb_target: inclusive bounds on the number of
                features targeted by one training intervention.
            noise_level, graph_dropout, dependent_dropout, uniform: forwarded
                to the simulator.
            max_corr, min_corr: forwarded to ``init_params``.
            conservative: stored and written to the metadata only; it is not
                used when sampling targets in this class.
            cover: constrain the training targets (see ``_pick_targets``).
            verbose: print progress messages.
        """
        self.n_features = n_features
        self.n_modules = n_modules
        self.n_samples = n_samples
        self.hard = hard_intv
        self.max_copies = max_copies
        self.p_conn = p_conn
        self.sparsity_temp = sparsity_temp
        self.n_hidden = n_hidden
        self.rescale = rescale
        self.verbose = verbose
        self.uniform = uniform
        self.max_corr = max_corr
        self.min_corr = min_corr
        self.graph_dropout = graph_dropout
        self.dependent_dropout = dependent_dropout

        self.simulation_function_nn = simulate_data_nn
        self.simulation_function = simulate_data_linear
        self.graph = None

        # attributes related to interventional data
        self.nb_interventions = nb_interventions
        self.nb_test_interventions = nb_test_interventions
        self.min_nb_target = min_nb_target
        self.max_nb_target = max_nb_target
        self.noise_level = noise_level
        self.conservative = conservative
        self.cover = cover

        # filled in by generate()
        self.feature_alphabet = None
        self.module_alphabet = None
        self.adj_int = None

    def generate(self, intervention=True, resample_dag=True):
        """Sample a dataset, optionally split into intervention regimes.

        Args:
            intervention: if True, ``n_samples`` is split as evenly as possible
                into ``nb_interventions`` training regimes, one observational
                regime and ``nb_test_interventions`` test regimes. Otherwise all
                rows are sampled without intervention.
            resample_dag: draw a new graph and new SEM parameters even when a
                graph already exists.

        Sets ``data`` (feature columns), ``data_factor`` (module columns),
        ``mask_intervention`` (the targeted feature indices of every row) and
        ``regimes`` (one regime label per row). The last two are empty lists
        when ``intervention`` is False.
        """
        if self.graph is None or resample_dag:
            self._sample_graph()

        if intervention:
            data, data_factor, mask_intervention, regimes = self._sample_interventional()
        else:
            # no intervention matrix for purely observational data; clearing it
            # also prevents save_data from writing a stale one
            self.adj_int = None
            mask_intervention = []
            regimes = []
            # -1 is the "no target" value used for the observational call
            dataset = self._simulate(self.n_samples, np.array([-1]))
            data_factor = dataset[:, self.module_list]
            data = dataset[:, self.feature_list]

        if self.rescale:
            # fit_transform returns a new array; the input is not scaled in place
            data = StandardScaler().fit_transform(data)

        # dump into class
        self.data = data
        self.data_factor = data_factor
        self.mask_intervention = mask_intervention
        self.regimes = regimes

    def _sample_graph(self):
        """Draw a new SPN-FG, derive the feature/module index lists and init the SEM."""
        if self.verbose:
            print("Sampling the DAG")
        self.generator = SPNFGGenerator(
            self.n_features,
            self.n_modules,
            0,
            'node',
            self.max_copies,
            1.0,
            self.p_conn,
            self.sparsity_temp,
        )
        self.graph, self.causal_order = self.generator()

        # positions (in the graph's node iteration order) of features and modules
        node_types = [self.graph.nodes[node]["type"] for node in self.graph.nodes]
        self.feature_list = np.where([t == "node" for t in node_types])[0]
        self.module_list = np.where([t == "module" for t in node_types])[0]

        # Ordered feature and module lists; nodes include both features and modules.
        # NOTE(review): np.isin compares node ids with the positions above, which
        # assumes node ids equal their iteration positions — confirm in SPNFGGenerator.
        self.node_causal_order = self.causal_order.astype(int)
        self.feature_list_ordered = self.causal_order[np.isin(self.causal_order, self.feature_list)]
        self.module_list_ordered = self.causal_order[np.isin(self.causal_order, self.module_list)]

        # sorted versions of the ordered lists
        self.feature_alphabet = np.sort(self.feature_list_ordered)
        self.module_alphabet = np.sort(self.module_list_ordered)

        if self.verbose:
            print("Init sem")
        self.causal_order, self.weights = init_params(
            self.graph, self.n_hidden, self.max_corr, self.min_corr
        )

    def _simulate(self, n_samples, targets):
        """Draw ``n_samples`` rows over all graph nodes with the given node targets."""
        shared_args = (
            self.causal_order,
            targets,
            set(self.feature_list),
            self.hard,
            self.dependent_dropout,
            self.uniform,
            self.noise_level,
            self.graph_dropout,
        )
        if self.n_hidden > 0:
            return self.simulation_function_nn(
                n_samples, self.weights[0], self.weights[1], *shared_args
            )
        return self.simulation_function(n_samples, self.weights, *shared_args)

    def _sample_interventional(self):
        """Sample every regime; return (data, data_factor, mask_intervention, regimes)."""
        n_regimes = self.nb_interventions + self.nb_test_interventions + 1

        # split n_samples as evenly as possible: the first (n_samples % n_regimes)
        # regimes receive one extra row
        base, extra = divmod(self.n_samples, n_regimes)
        points_per_interv = [base + (1 if x < extra else 0) for x in range(n_regimes)]

        # training targets followed by the empty (observational) target
        target_list = self._pick_targets(nb_max_iteration=10000)
        # intervention-to-feature matrix of the training regimes
        self.adj_int = np.zeros((self.nb_interventions, self.n_features))
        for i in range(self.nb_interventions):
            self.adj_int[i, target_list[i]] = 1
        # held-out targets built from the training ones
        target_list.extend(self._pick_test_targets(target_list))

        data = np.zeros((self.n_samples, self.n_features))
        data_factor = np.zeros((self.n_samples, self.n_modules))
        mask_intervention = []
        regimes = []

        start = 0
        for j in tqdm(range(n_regimes), desc="interventions"):
            n_j = points_per_interv[j]
            # targets are feature indices; the simulator works on graph-node indices.
            # dtype=int keeps the empty observational target an integer array.
            node_targets = np.array(
                [self.feature_list[t] for t in target_list[j]], dtype=int
            )
            dataset = self._simulate(n_j, node_targets)

            end = start + n_j
            data[start:end, :] = dataset[:, self.feature_list]
            data_factor[start:end, :] = dataset[:, self.module_list]
            start = end

            # here add at the feature level, not node
            mask_intervention.extend([target_list[j]] * n_j)
            if j <= self.nb_interventions:
                # training regimes 1..nb_interventions, observational nb_interventions + 1
                label = j + 1
            else:
                # test (unseen) regimes are encoded as negative integers: -2, -3, ...
                label = self.nb_interventions - 1 - j
            regimes.extend([label] * n_j)

        return data, data_factor, mask_intervention, regimes

    def _pick_targets(self, nb_max_iteration=100000):
        """Randomly draw the feature targets of the training interventions.

        With ``max_nb_target == 1`` each intervention targets one feature; when
        ``cover`` is set the features are taken from a random permutation, so
        every feature is targeted whenever ``nb_interventions >= n_features``.
        Otherwise each intervention targets between ``min_nb_target`` and
        ``max_nb_target`` distinct features, and when ``cover`` is set draws are
        rejected until ``_is_covering_pairs`` holds.

        Returns:
            A list of ``nb_interventions`` target collections (feature indices)
            followed by an empty list for the observational regime.

        Raises:
            ValueError: if rejection sampling finds no valid target set within
                ``nb_max_iteration`` attempts.
        """
        if self.max_nb_target == 1:
            if self.cover:
                # truncate so there are never more than nb_interventions targets
                intervention = np.random.permutation(self.n_features)[: self.nb_interventions]
            else:
                intervention = np.random.choice(
                    self.n_features, self.nb_interventions, replace=False
                )
            targets = [[t] for t in intervention]
            # top up with random single targets when nb_interventions > n_features
            targets.extend(
                [[np.random.choice(self.n_features)]
                 for _ in range(self.nb_interventions - len(targets))]
            )
        else:
            nodes = range(self.n_features)
            for _attempt in range(nb_max_iteration):
                targets = [
                    np.random.choice(
                        self.n_features,
                        np.random.randint(self.min_nb_target, self.max_nb_target + 1, 1),
                        replace=False,
                    )
                    for _ in range(self.nb_interventions)
                ]
                # apply rejection sampling
                if not self.cover or self._is_covering_pairs(nodes, targets):
                    break
            else:
                raise ValueError(
                    "Could not generate appropriate targets. "
                    "Exceeded the maximal number of iterations"
                )
            targets = [np.sort(t) for t in targets]

        targets.append([])
        return targets

    def _pick_test_targets(self, train_targets):
        """Pick ``nb_test_interventions`` test targets as unions of two training targets.

        ``train_targets`` is the list returned by ``_pick_targets`` (including
        the empty observational target). Every returned target is a sorted list
        of feature indices. NOTE: the unions are not checked against the
        training targets, so a test target can coincide with a training one
        (e.g. a union with the empty target).
        """
        def _combine_targets(*targets):
            union = set()
            for target in targets:
                union.update(target)
            return sorted(union)

        n_test = self.nb_test_interventions
        if len(train_targets) == 1:
            return [list(train_targets[0]) for _ in range(n_test)]
        if len(train_targets) == 2:
            combined = _combine_targets(train_targets[0], train_targets[1])
            return [list(combined) for _ in range(n_test)]

        picks = []
        for _ in range(n_test):
            first, second = np.random.choice(len(train_targets), 2, replace=False)
            picks.append(_combine_targets(train_targets[first], train_targets[second]))
        return picks

    def _check_test_targets(self, train_targets, test_targets):
        """Return True when no test target equals (as a set) any training target."""
        train_sets = [set(y) for y in train_targets]
        return all(set(x) not in train_sets for x in test_targets)

    def _is_conservative(self, elements, lists):
        """Return True when every element is left out of at least one list."""
        for e in elements:
            if all(e in list_ for list_ in lists):
                return False
        return True

    def _is_covering(self, elements, lists):
        """Return True when the union of ``lists`` is exactly ``set(elements)``."""
        return set(elements) == self._union(lists)

    def _is_covering_pairs(self, elements, lists):
        """Check the paired-order covering condition.

        Returns True when, for every ordered pair ``(a, b)`` of distinct indices
        in ``range(len(elements))``, some target collection in ``lists``
        contains ``a`` but not ``b``.
        """
        n = len(elements)
        separated = np.eye(n, dtype=bool)  # (a, a) pairs need no separation
        for targets in lists:
            idx = np.asarray(targets, dtype=int)
            if idx.size == 0:
                continue
            targeted = np.zeros(n, dtype=bool)
            targeted[idx] = True
            # this intervention separates every (targeted, untargeted) pair
            separated |= targeted[:, None] & ~targeted[None, :]
        return bool(separated.all())

    def _union(self, lists):
        """Return the set union of all collections in ``lists``."""
        union_set = set()
        for _list in lists:
            union_set.update(_list)
        return union_set

    def save_data(self, folder, i):
        """Write the last generated graph, data, targets, regimes and metadata.

        Every file name is suffixed with ``i``. Requires a prior ``generate``
        call. ``DAG_int{i}.npy`` is only written when the last ``generate``
        produced interventional data.
        """
        if self.verbose:
            print("saving")
        os.makedirs(folder, exist_ok=True)

        def _path(stem, ext):
            return os.path.join(folder, f"{stem}{i}.{ext}")

        np.save(_path("DAG", "npy"), self.generator.bipartite_half)
        if self.adj_int is not None:
            np.save(_path("DAG_int", "npy"), self.adj_int)

        np.save(_path("U", "npy"), self.generator.U)
        np.save(_path("V", "npy"), self.generator.V)
        np.save(_path("adj", "npy"), self.generator.adj_mtx)
        np.save(_path("module", "npy"), self.generator.module_list)
        np.save(_path("order", "npy"), self.node_causal_order)
        np.save(_path("module_ordered", "npy"), self.module_list_ordered)
        np.save(_path("feature_ordered", "npy"), self.feature_list_ordered)
        np.save(_path("module_alphabet", "npy"), self.module_alphabet)
        np.save(_path("feature_alphabet", "npy"), self.feature_alphabet)

        # save data
        if len(self.mask_intervention) == 0:
            np.save(_path("data", "npy"), self.data)
        else:
            np.save(_path("data_interv", "npy"), self.data)
            with open(_path("intervention", "csv"), "w", newline="") as f:
                csv.writer(f).writerows(self.mask_intervention)

        # save regimes, one label per row
        if self.regimes is not None:
            with open(_path("regime", "csv"), "w", newline="") as f:
                csv.writer(f).writerows([regime] for regime in self.regimes)

        # save simulation metadata as json
        metadata = {
            "n_samples": self.n_samples,
            "n_features": self.n_features,
            "n_modules": self.n_modules,
            "max_copies": self.max_copies,
            "p_conn": self.p_conn,
            "sparsity_temp": self.sparsity_temp,
            "n_hidden": self.n_hidden,
            "rescale": self.rescale,
            "nb_interventions": self.nb_interventions,
            "min_nb_target": self.min_nb_target,
            "max_nb_target": self.max_nb_target,
            "noise_level": self.noise_level,
            "max_corr": self.max_corr,
            "min_corr": self.min_corr,
            "conservative": self.conservative,
            "uniform": self.uniform,
            "cover": self.cover,
            # settings that also affect sampling, recorded for reproducibility
            "nb_test_interventions": self.nb_test_interventions,
            "hard_intv": self.hard,
            "graph_dropout": self.graph_dropout,
            "dependent_dropout": self.dependent_dropout,
        }
        with open(_path("metadata", "json"), "w") as f:
            json.dump(metadata, f)
