from data_generators import generate_synthetic_data, data_distributions
from scaling_functions import *
from causallearn.search.ConstraintBased.PC import pc
import scipy.stats
import numpy as np
import pandas as pd
import networkx as nx
from models import * 
from dodiscover.toporder import *
from dodiscover.metrics import toporder_divergence
from scipy.sparse import csr_matrix
import anndata as ad
import scanpy as sc
import gies
import gies.utils
from ease import *
import scipy.stats
from concurrent.futures import ProcessPoolExecutor

def prepare_data(d, n_obs, n_interv, distribution, seed):
    """Generate one synthetic dataset and flatten its intervention mask.

    All arguments are forwarded unchanged to ``generate_synthetic_data``.

    Returns:
        A tuple ``(g, data, inter, I)``. ``g``, ``data`` and ``I`` are exactly
        what ``generate_synthetic_data`` returned. ``inter`` is a NumPy array
        with one entry per row of ``I``: the index of the first truthy entry
        of that row, or ``-1`` when the row has no truthy entry.
    """
    g, data, mask = generate_synthetic_data(
        d, n_obs, n_interv, distribution=distribution, seed=seed
    )

    targets = []
    for row in mask:
        if np.any(row):
            # First position at which this row of the mask is set.
            targets.append(np.where(row)[0][0])
        else:
            # No entry set: mark the row with the sentinel -1.
            targets.append(-1)

    return g, data, np.array(targets), mask

def create_graph(g):
    """Convert an adjacency matrix into a ``networkx.DiGraph``.

    Args:
        g: Array-like adjacency matrix; ``g.shape[0]`` gives the node count.

    Returns:
        A directed graph with nodes ``0 .. g.shape[0] - 1`` and an edge
        ``i -> j`` for every position where ``g[i, j] == 1``.
    """
    digraph = nx.DiGraph()
    digraph.add_nodes_from(range(g.shape[0]))
    # Only entries exactly equal to 1 count as edges.
    digraph.add_edges_from(
        (src, dst) for src, dst in zip(*np.where(g == 1))
    )
    return digraph

def fit_scalers(data, inter):
    """Instantiate the selected scalers and fit them on the ``inter == -1`` rows.

    Only the ``"Standard Scaling"`` entry of ``scaling_classes`` is selected.

    Args:
        data: 2-D array indexed as ``data[row_mask, :]``.
        inter: Per-row label array; rows labelled ``-1`` are used for fitting.

    Returns:
        Dict mapping method name to its scaler instance. Scalers exposing a
        ``fit`` method have been fitted; others are returned as constructed.
    """
    fitted = {}
    for method, scaler_class in scaling_classes.items():
        if method in ["Standard Scaling"]:
            fitted[method] = scaler_class()

    for scaler in fitted.values():
        if not hasattr(scaler, "fit"):
            continue
        scaler.fit(data[inter == -1, :])

    return fitted


def var_task(data, inter, var, d):
    """Score how far each intervention shifts the distribution of one variable.

    For every node other than ``var``, the Wasserstein distance is computed
    between column ``var`` on the ``inter == -1`` rows and column ``var`` on
    the rows labelled with that node.

    Args:
        data: 2-D array of samples (rows) by variables (columns).
        inter: Per-row label array; ``-1`` marks the reference rows, any other
            value is the node the row is associated with.
        var: Column index of the variable being scored.
        d: Number of nodes; the returned column has this length.

    Returns:
        ``(var, score_column)`` where ``score_column[node]`` is the distance
        for ``node``. Entries stay ``0.0`` for ``node == var`` and for nodes
        that have no rows in ``data``.

    Raises:
        ValueError: Propagated from ``scipy.stats.wasserstein_distance`` when
            there are no ``inter == -1`` rows but some node does have rows.
    """
    score_column = np.zeros(d)
    data_obs = data[inter == -1, var]
    for node in range(d):
        if node == var:
            # The self-entry is always left at zero, so skip the computation.
            continue
        data_inter = data[inter == node, var]
        # wasserstein_distance rejects empty samples; mirror batch_var_task
        # and leave the score at 0.0 when this node has no rows.
        if len(data_inter) > 0:
            score_column[node] = scipy.stats.wasserstein_distance(data_obs, data_inter)
    return var, score_column

def batch_var_task(data, inter, vars, d):
    """Compute Wasserstein score columns for a batch of variables.

    For each ``var`` in ``vars`` and each node other than ``var``, the score
    is the Wasserstein distance between column ``var`` on the ``inter == -1``
    rows and column ``var`` on the rows labelled with that node.

    Args:
        data: 2-D array of samples (rows) by variables (columns).
        inter: Per-row label array; ``-1`` marks the reference rows.
        vars: Iterable of column indices to score.
        d: Number of nodes; every score column has this length.

    Returns:
        List of ``(var, score_column)`` tuples in the order of ``vars``.
        ``score_column[node]`` is ``0.0`` for ``node == var`` and for nodes
        without any rows.
    """
    # The row selections depend only on `inter`, not on the variable being
    # scored, so build them once instead of once per (var, node) pair.
    obs_mask = inter == -1
    node_masks = [inter == node for node in range(d)]

    scores = []
    for var in vars:
        score_column = np.zeros(d)
        data_obs = data[obs_mask, var]
        for node, node_mask in enumerate(node_masks):
            if node == var:
                continue
            data_inter = data[node_mask, var]
            # wasserstein_distance rejects empty samples; a node without rows
            # keeps a score of 0.0.
            if len(data_inter) > 0:
                score_column[node] = scipy.stats.wasserstein_distance(data_obs, data_inter)
        scores.append((var, score_column))
    return scores

def calculate_wasserstein_distances(data, inter, d, batch_size=10, scaling_method=None, scaler=None):
    """Build the ``d x d`` Wasserstein score matrix using a process pool.

    The variables ``0 .. d - 1`` are split into consecutive batches of at most
    ``batch_size`` and each batch is handed to ``batch_var_task`` in a worker
    process. Column ``var`` of the result is the score column returned for
    that variable.

    Args:
        data: 2-D array of samples (rows) by variables (columns).
        inter: Per-row label array, forwarded to ``batch_var_task``.
        d: Number of variables / nodes.
        batch_size: Maximum number of variables per submitted task.
        scaling_method: Unused; accepted for call compatibility.
        scaler: Unused; accepted for call compatibility.

    Returns:
        ``(d, d)`` float array with ``score_matrix[:, var]`` set to the score
        column of ``var``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A zero step would fail inside range() and a negative one would
        # silently yield no batches and hence an all-zero matrix.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    score_matrix = np.zeros((d, d))

    # Create batches of variables
    batches = [range(start, min(start + batch_size, d)) for start in range(0, d, batch_size)]
    if not batches:
        # Nothing to compute; avoid creating a pool at all.
        return score_matrix

    # Never ask for more workers than there are batches: the extra workers
    # would have nothing to do.
    with ProcessPoolExecutor(max_workers=min(25, len(batches))) as executor:
        futures = [executor.submit(batch_var_task, data, inter, batch, d) for batch in batches]

        for future in futures:
            for var, score_column in future.result():
                score_matrix[:, var] = score_column

    return score_matrix

def createFullyConnectedGraph(topological_order):
    """Return the adjacency matrix of the complete DAG implied by an ordering.

    Args:
        topological_order: Sequence of node indices; its length sets the
            matrix size.

    Returns:
        ``(n, n)`` float array in which entry ``[a, b]`` is 1 whenever ``a``
        appears before ``b`` in ``topological_order`` and 0 elsewhere.
    """
    size = len(topological_order)
    adjacency = np.zeros((size, size))

    for rank, source in enumerate(topological_order):
        # Every node placed later in the ordering becomes a child of `source`.
        for target in topological_order[rank + 1:]:
            adjacency[source, target] = 1

    return adjacency

def main():
    """Run the full experiment sweep and save the collected results.

    Calls ``test_config`` for every combination of ``d`` in ``[10, 30]``,
    ``p_cover`` in ``[0.0, 0.25, 0.5, 0.75, 1.0]`` and ``seed`` in
    ``range(10)``, with ``n_obs = 5000`` and ``n_interv = 100``. All rows
    accumulated in the shared results list are written to
    ``results_test.csv`` and the column means of the numeric columns are
    printed.
    """
    n_obs = 5000
    n_interv = 100

    results_list = []
    for d in [10, 30]:
        for p_cover in [0.0, 0.25, 0.5, 0.75, 1.0]:
            for seed in range(0, 10):
                test_config(n_obs, n_interv, p_cover, d, results_list, seed)

    results_df = pd.DataFrame(results_list)
    # Persist first so that a failure while summarising cannot lose the run.
    results_df.to_csv('results_test.csv', index=False)
    # The frame contains the string column 'Distribution'; restrict the mean
    # to numeric columns, otherwise recent pandas versions raise TypeError.
    print(results_df.mean(numeric_only=True))


def test_config(n_obs, n_interv, p_cover, d, results_list, seed):
    """Evaluate four ordering methods on one configuration, per distribution.

    For every key of ``data_distributions`` this generates data through
    ``prepare_data``, keeps only the rows labelled ``-1`` or labelled with one
    of the first ``int(p_cover * d)`` variable indices, scales the data, and
    derives a node ordering with each of:

    * the Wasserstein score matrix fed to ``sort_ranking`` and refined by
      ``local_search_extended``,
    * ``ease``,
    * ``gies.fit_bic``,
    * ``pc`` from causallearn.

    Each ordering is scored with ``toporder_divergence`` against the graph
    built from the adjacency matrix ``g`` returned by ``prepare_data``.

    Args:
        n_obs: Forwarded to ``prepare_data`` and stored in the result row.
        n_interv: Forwarded to ``prepare_data`` and stored in the result row.
        p_cover: Fraction of the ``d`` variables whose rows are kept.
        d: Number of variables.
        results_list: List that receives one result dict per distribution
            (mutated in place).
        seed: Forwarded to ``prepare_data`` and stored in the result row.

    Side effects:
        Appends to ``results_list`` and, after every distribution, writes the
        whole of ``results_list`` to ``results_test_{d}.csv``.
    """
    for distribution in data_distributions.keys():
        # `inter_a` (the raw mask returned by prepare_data) is not used below.
        g, data, inter, inter_a = prepare_data(d, n_obs, n_interv, distribution, seed)
        n_intervened = int(p_cover * d)
        # Keep rows labelled with variables 0 .. n_intervened - 1, plus the
        # rows labelled -1.
        intervened_vars = list(range(n_intervened))
        intervened_vars.append(-1)
        subset = np.isin(inter, intervened_vars)
        data = data[subset]
        inter = inter[subset]
        if distribution == 'gene-ecoli':
            # NOTE(review): `adata` is normalised and log-transformed here but
            # is never read afterwards in this function; `data` is what gets
            # passed on below. Confirm whether the transformed values were
            # meant to replace `data`.
            counts = csr_matrix(data)
            adata = ad.AnnData(counts)
            adata.obs_names = [str(i) for i in inter]
            sc.pp.normalize_per_cell(adata, )
            sc.pp.log1p(adata)
        graph = create_graph(g)
        scalers = fit_scalers(data, inter)
        for method, scaler in scalers.items():
            scaled_data = scaler.transform(data)
            score_matrix = calculate_wasserstein_distances(scaled_data, inter, d,)
            # Threshold subtracted from the scores; it differs for the
            # 'gene-ecoli' distribution.
            if distribution == 'gene-ecoli':
                eps = 0.5
            else:
                eps = 0.3
            def score_ordering(topological_order):
                """Objective handed to ``local_search_extended``.

                Walks the ordering; for each node ``i`` whose score row has
                any positive entry, adds ``sum(score_matrix[i, after] - eps)``
                where ``after`` holds the nodes that come later in the
                ordering.
                """
                tot = 0
                # `before` is filled but never read.
                before = list()
                after = list(range(d))
                for i in topological_order:
                    # Remove i first so `after` holds only nodes placed later.
                    after.remove(i)
                    if np.any(score_matrix[i, :] > 0.0):
                        positive = np.sum(score_matrix[i, after] - eps)
                        tot += positive
                    before.append(i)
                return tot

            # `G_sort_ranking` is not used below.
            pred_sort_ranking, G_sort_ranking = sort_ranking(score_matrix, scaled_data, inter, eps)
            topological_order_sortranking = list(nx.topological_sort(create_graph(pred_sort_ranking))) 
            candidate = local_search_extended(topological_order_sortranking, score_ordering)
            score_ordering_pagerank = toporder_divergence(graph, candidate)
            # This reassigned `pred_sort_ranking` is not read afterwards.
            pred_sort_ranking = createFullyConnectedGraph(candidate)

            obs_data = scaled_data[inter == -1, :]
            
            
            # ease is given all retained rows, not only the inter == -1 rows.
            ease_ordering = ease(pd.DataFrame(scaled_data))
            score_ordering_ease = toporder_divergence(graph, ease_ordering)

            # Inputs for gies.fit_bic: one data array per distinct label in
            # `inter` with the matching single-element target list in `I`,
            # followed by the inter == -1 rows paired with an empty target list.
            data_ = []
            I = []
            for i in set(inter):
                if i != -1:
                    I.append([i])
                    data_.append(scaled_data[inter == i, :])
            data_gies = data_.copy()
            data_gies.append(obs_data)
            I.append([])
            
            adjacency, _ = gies.fit_bic(data_gies, I, )
            only_direct = gies.utils.only_directed(adjacency)
            gies_ordering = list(nx.topological_sort(create_graph(only_direct)))
            score_ordering_gies = toporder_divergence(graph, gies_ordering)

            # pc is run on all retained rows.
            cg_with_background_knowledge = pc(scaled_data)
            directed = cg_with_background_knowledge.find_fully_directed()
            pc_graph = nx.DiGraph()
            pc_graph.add_nodes_from(range(d))
            for (i, j) in directed:
                # Skip an edge i -> j when j already reaches i, so that
                # pc_graph stays acyclic and topological_sort can succeed.
                if not nx.has_path(pc_graph, j, i):
                    pc_graph.add_edge(i, j)
            # This reassigned `cg_with_background_knowledge` is not read afterwards.
            cg_with_background_knowledge = nx.to_numpy_array(pc_graph)
            pc_ordering = list(nx.topological_sort(pc_graph))
            score_ordering_pc = toporder_divergence(graph, pc_ordering)
        # NOTE(review): this append is outside the `for method, scaler` loop,
        # so it records the scores left by the last iteration. If `scalers`
        # is empty the score variables are unbound on the first distribution
        # and stale on later ones. Confirm the intended nesting.
        results_list.append({
                'Distribution': distribution,
                'Seed': seed,
                'Score pagerank ordering': score_ordering_pagerank,
                'Score ease ordering': score_ordering_ease,
                'Score PC ordering': score_ordering_pc,
                'Score gies ordering': score_ordering_gies,
                'Number variables': d,
                'Number edges': np.sum(g),
                'Number observations': n_obs,
                'Number interventions': n_interv,
                'Number intervened': n_intervened
            })
        # Checkpoint: rewrites the file with every row collected so far
        # (results_list is shared across calls), not only rows for this d.
        results_df = pd.DataFrame(results_list)
        results_df.to_csv(f'results_test_{d}.csv', index=False)

# Run the experiment sweep only when executed as a script, not on import.
# calculate_wasserstein_distances starts worker processes, which may import
# this module (e.g. under the spawn start method); the guard keeps them from
# re-running main().
if __name__ == "__main__":
    main()
