from data_generators import generate_synthetic_data, data_distributions
from scaling_functions import *
from causallearn.search.ConstraintBased.PC import pc
import scipy.stats
import numpy as np
import pandas as pd
import networkx as nx
from dodiscover.toporder import *
from dodiscover.metrics import toporder_divergence
from scipy.sparse import csr_matrix
import anndata as ad
import scanpy as sc
import scipy.stats
import torch
from concurrent.futures import ProcessPoolExecutor
from diffintersort import diffintersort, fill_triangular
from intersort import _sort_ranking, local_search_extended, score_ordering
import argparse
from diffintersort_constraint import causal_discovery
from avici.metrics import shd, classification_metrics
from cdt.metrics import SID

def prepare_data(d, n_obs, n_interv, distribution, seed, scale_free):
    """Generate one synthetic dataset and derive a per-row intervention target.

    Returns ``(g, data, inter, I)``: ``g``, ``data`` and ``I`` are passed
    through unchanged from ``generate_synthetic_data``. ``inter[r]`` is the
    index of the first non-zero entry of ``I[r]``, or ``-1`` when that row has
    no non-zero entry (treated as observational elsewhere in this script).
    """
    g, data, intervention_mask = generate_synthetic_data(
        d, n_obs, n_interv, distribution=distribution, seed=seed, sf=scale_free
    )

    def first_target(row):
        # Position of the first non-zero entry, or -1 for an all-zero row.
        hits = np.flatnonzero(row)
        return hits[0] if hits.size else -1

    inter = np.array([first_target(row) for row in intervention_mask])
    return g, data, inter, intervention_mask

def create_graph(g):
    """Build a ``networkx.DiGraph`` from a square adjacency matrix.

    Nodes are ``0 .. g.shape[0] - 1`` (added even when isolated); an edge
    ``i -> j`` is added wherever ``g[i, j] == 1``.
    """
    digraph = nx.DiGraph()
    digraph.add_nodes_from(range(g.shape[0]))
    sources, targets = np.where(g == 1)
    digraph.add_edges_from(zip(sources, targets))
    return digraph

def fit_scalers(data, inter):
    """Instantiate the selected scalers and fit them on observational rows.

    Only the ``"Standard Scaling"`` entry of ``scaling_classes`` is used.
    Scalers exposing ``fit`` are fitted on the rows where ``inter == -1``.
    Returns a dict mapping method name to the (fitted) scaler instance.
    """
    selected = ("Standard Scaling",)
    scalers = {
        name: cls() for name, cls in scaling_classes.items() if name in selected
    }
    fittable = (candidate for candidate in scalers.values() if hasattr(candidate, "fit"))
    for candidate in fittable:
        candidate.fit(data[inter == -1, :])
    return scalers


def var_task(data, inter, var, d):
    """Compute the Wasserstein score column of a single variable.

    ``score_column[node]`` is the 1-D Wasserstein distance between the
    observational samples of column ``var`` (rows with ``inter == -1``) and
    its samples under an intervention on ``node`` (rows with ``inter == node``).
    The entry for ``node == var`` and for nodes without any interventional rows
    stays 0.0, which matches ``batch_var_task``.

    Args:
        data: 2-D sample array; column ``j`` holds variable ``j``.
        inter: per-row intervention target, ``-1`` for observational rows.
        var: column index to score.
        d: number of variables (length of the returned column).

    Returns:
        ``(var, score_column)`` with ``score_column`` a float array of length ``d``.
    """
    data_obs = data[inter == -1, var]
    score_column = np.zeros(d)
    for node in range(d):
        # The diagonal is always 0; don't waste (or crash on) its computation.
        if node == var:
            continue
        data_inter = data[inter == node, var]
        # wasserstein_distance rejects empty samples; nodes that were never
        # intervened on (e.g. after p_cover subsetting) keep a score of 0.
        if len(data_inter) > 0:
            score_column[node] = scipy.stats.wasserstein_distance(data_obs, data_inter)
    return var, score_column

def batch_var_task(data, inter, vars, d):
    """Compute Wasserstein score columns for a batch of variables.

    For every ``var`` in ``vars`` the column holds, at index ``node``, the 1-D
    Wasserstein distance between the observational samples of ``var`` (rows
    with ``inter == -1``) and its samples under an intervention on ``node``
    (rows with ``inter == node``). The entry for ``node == var`` and for nodes
    without interventional rows stays 0.0.

    Args:
        data: 2-D sample array; column ``j`` holds variable ``j``.
        inter: 1-D array with the intervention target of each row
            (``-1`` for observational rows).
        vars: iterable of column indices to score. (The name shadows the
            builtin but is kept for caller compatibility.)
        d: number of variables (length of each score column).

    Returns:
        List of ``(var, score_column)`` tuples in the order of ``vars``.
    """
    # Group row indices by intervention target once per batch instead of
    # rescanning ``inter`` for every (var, node) pair. The stable sort keeps
    # rows in their original order within each group, so ``data[rows, var]``
    # equals ``data[inter == target, var]``.
    order = np.argsort(inter, kind="stable")
    targets, starts = np.unique(inter[order], return_index=True)
    rows_by_target = dict(zip(targets.tolist(), np.split(order, starts[1:])))
    no_rows = order[:0]

    obs_rows = rows_by_target.get(-1, no_rows)
    scores = []
    for var in vars:
        score_column = np.zeros(d)
        # The observational sample of ``var`` does not depend on ``node``.
        data_obs = data[obs_rows, var]
        for node in range(d):
            if node == var:
                continue
            data_inter = data[rows_by_target.get(node, no_rows), var]
            # wasserstein_distance rejects empty samples; keep 0.0 instead.
            if len(data_inter) > 0:
                score_column[node] = scipy.stats.wasserstein_distance(data_obs, data_inter)
        scores.append((var, score_column))
    return scores

def calculate_wasserstein_distances(data, inter, d, batch_size=10, scaling_method=None, scaler=None,
                                    max_workers=14):
    """Compute the ``(d, d)`` matrix of intervention-induced Wasserstein shifts.

    Entry ``[node, var]`` is the 1-D Wasserstein distance between the
    observational samples of column ``var`` (rows with ``inter == -1``) and the
    samples of ``var`` under an intervention on ``node`` (rows with
    ``inter == node``), as computed by ``batch_var_task``.

    Variables are processed in batches of ``batch_size``, one
    ``ProcessPoolExecutor`` task per batch.

    Args:
        data: 2-D sample array; column ``j`` holds variable ``j``.
        inter: per-row intervention target, ``-1`` for observational rows.
        d: number of variables.
        batch_size: variables per worker task; must be >= 1.
        scaling_method, scaler: unused; kept so existing callers keep working.
        max_workers: worker processes for the pool (default 14, the previously
            hard-coded value); ``None`` lets ``ProcessPoolExecutor`` choose.

    Returns:
        ``numpy.ndarray`` of shape ``(d, d)``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A zero step makes range() raise, a negative one silently yields no
        # batches and therefore an all-zero score matrix.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    score_matrix = np.zeros((d, d))
    batches = [range(start, min(start + batch_size, d)) for start in range(0, d, batch_size)]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Submit every batch first so they run concurrently, then collect in order.
        futures = [executor.submit(batch_var_task, data, inter, batch, d) for batch in batches]
        for future in futures:
            for var, score_column in future.result():
                score_matrix[:, var] = score_column

    return score_matrix

def createFullyConnectedGraph(topological_order):
    """Return the complete DAG compatible with ``topological_order``.

    Every node gets an edge to every node appearing after it in the ordering.
    The result is an ``n x n`` float adjacency matrix (1.0 marks an edge) with
    ``n = len(topological_order)``.
    """
    size = len(topological_order)
    adjacency = np.zeros((size, size))
    if size:
        order = np.asarray(topological_order)
        # All position pairs (earlier, later) with earlier < later.
        earlier, later = np.triu_indices(size, k=1)
        adjacency[order[earlier], order[later]] = 1
    return adjacency


# Per-dimension hyper-parameters, keyed by the number of variables ``d``.
# test_config reads ``config[d]["scaling"]`` (boost applied to scores above eps,
# and passed to diffintersort) and ``config[d]["lr"]`` (diffintersort learning
# rate); it indexes ``config[d]`` directly, so ``d`` must be one of these keys.
config = {
    dim: {"lr": lr, "scaling": scaling}
    for dim, lr, scaling in (
        # d      lr       scaling
        (3,     0.05,    0.1),
        (10,    0.05,    0.5),
        (30,    0.01,    1.0),
        (100,   0.001,   1.0),
        (1000,  0.0005,  1.0),
        (2000,  0.0001,  1.0),
    )
}


def main(d, seeds, domain, scale_free):
    """Run ``test_config`` over all coverage levels and seeds, then save a summary CSV.

    Uses 5000 observational samples and 100 samples per intervention. For each
    intervention coverage in (0.25, 0.5, 0.75, 1.0) and each seed in
    ``range(seeds)``, results are accumulated into one list that is finally
    written to ``results_test.csv``.
    """
    n_obs, n_interv = 5000, 100
    coverages = (0.25, 0.5, 0.75, 1.0)

    collected = []
    for p_cover in coverages:
        for seed in range(seeds):
            test_config(domain, n_obs, n_interv, p_cover, d, collected, seed, scale_free)

    pd.DataFrame(collected).to_csv('results_test.csv', index=False)

# Data-generating mechanisms evaluated per domain (the ``--domain`` CLI value);
# each name is passed as ``distribution`` to generate_synthetic_data.
distributions = {
    "lin": ['lin-gauss', 'lin-gauss-heterosked', 'lin-laplace'],
    "rff": ['rff-gauss', 'rff-gauss-heterosked', 'rff-laplace'],
    "gene": ['gene-ecoli'],
}


def test_config(domain, n_obs, n_interv, p_cover, d, results_list, seed, scale_free=False):
    """Evaluate all ordering / causal-discovery methods for one (p_cover, seed) setting.

    For each distribution in ``distributions[domain]`` this:
      1. generates a dataset with ``prepare_data``;
      2. keeps only observational rows (``inter == -1``) and rows that intervene
         on the first ``int(p_cover * d)`` nodes;
      3. transforms the data with each scaler from ``fit_scalers`` and builds the
         Wasserstein score matrix;
      4. evaluates sort-ranking, local search (only for ``d <= 30``),
         diffintersort and two ``causal_discovery`` variants against the true
         graph ``g`` (toporder divergence, SID, SHD, precision, recall).

    One result dict per distribution is appended to ``results_list`` (mutated
    in place). After each distribution the whole accumulated list is rewritten
    to ``results_test_const_{sf|er}_<domain>_<d>.csv``.

    NOTE(review): ``config[d]`` is indexed directly, so ``d`` must be a key of
    ``config``; other values raise KeyError after the score matrix is built.
    """
    for distribution in distributions[domain]:
        g, data, inter, inter_a = prepare_data(d, n_obs, n_interv, distribution, seed, scale_free)
        # Keep observational rows (-1) plus interventions on nodes 0 .. n_intervened-1.
        n_intervened = int(p_cover * d)
        intervened_vars = list(range(n_intervened))
        intervened_vars.append(-1)
        subset = np.isin(inter, intervened_vars)
        data = data[subset]
        inter = inter[subset]
        if distribution == 'gene-ecoli':
            # NOTE(review): the normalized + log1p-transformed ``adata_df`` is never
            # used below; every score is computed from the raw ``data``. Confirm
            # whether ``data`` was meant to be replaced by the normalized values.
            counts = csr_matrix(data)
            adata = ad.AnnData(counts)
            adata.obs_names = [str(i) for i in inter]
            sc.pp.normalize_per_cell(adata, )
            sc.pp.log1p(adata)
            adata_df = adata.to_df()

        graph = create_graph(g)
        scalers = fit_scalers(data, inter)
        # NOTE(review): fit_scalers only returns "Standard Scaling", so this loop runs
        # at most once. ``result`` is built after the loop, so it would capture only
        # the last scaler's metrics (or raise NameError if there were no scalers).
        for method, scaler in scalers.items():
            scaled_data = scaler.transform(data)
            # score_matrix[node, var]: Wasserstein shift of ``var`` under an intervention on ``node``.
            score_matrix = calculate_wasserstein_distances(scaled_data, inter, d,)
            # Threshold separating significant shifts from noise.
            if distribution == 'gene-ecoli':
                eps = 0.5
            else:
                eps = 0.3
            # Local scorer that shadows the ``score_ordering`` imported from intersort.
            # It reads ``score_matrix`` at call time, i.e. AFTER the in-place
            # rescaling further down in this loop body.
            def score_ordering(topological_order):
                """Sum (score - eps) from each node with any positive score to every later node."""
                tot = 0
                before = list()  # appended to but never read
                after = list(range(d))
                for i in topological_order:
                    after.remove(i)
                    if np.any(score_matrix[i, :] > 0.0):
                        positive = np.sum(score_matrix[i, after] - eps)
                        tot += positive
                    before.append(i)
                return tot
            

            # _sort_ranking returns a networkx graph (topologically sorted below)
            # and a list of (i, j) pairs that are reset to 0.1 further down.
            pred_sort_ranking, to_remove = _sort_ranking(score_matrix, eps)
            topological_order_sortranking = list(nx.topological_sort(pred_sort_ranking)) 
            sid_sort_ranking = SID(g, createFullyConnectedGraph(topological_order_sortranking))
            # In-place rescaling used by every method below: entries above eps are
            # boosted by scaling * d (plus eps more for gene-ecoli); positive entries
            # below eps, and the to_remove pairs, are set to 0.1. Entries exactly
            # equal to eps are left unchanged.
            score_matrix[score_matrix > eps] +=  config[d]["scaling"] * d 
            if distribution == 'gene-ecoli':
                score_matrix[score_matrix > eps] += eps
            score_matrix[(score_matrix < eps) & (score_matrix > 0.0)] = 0.1
            for i, j in to_remove:
                score_matrix[i, j] = 0.1
            score_ordering_sortranking = toporder_divergence(graph, topological_order_sortranking)
            # Local search only for small graphs; larger d records 0 placeholders.
            if d <= 30:
                candidate = local_search_extended(topological_order_sortranking, score_matrix, d)
                score_ordering_pagerank = toporder_divergence(graph, candidate)
                score_intersort = score_ordering(candidate)
                sid_intersort = SID(g, createFullyConnectedGraph(candidate))
            else:
                score_ordering_pagerank = 0
                score_intersort = 0
                sid_intersort = 0

            obs_data = scaled_data[inter == -1, :]  # NOTE(review): never used below
            
            # NOTE(review): despite the ``_no_scale`` suffix, this re-ranks the
            # already-rescaled score_matrix.
            pred_sort_ranking, _ = _sort_ranking(score_matrix, eps)
            topological_order_sortranking_no_scale = list(nx.topological_sort(pred_sort_ranking)) 
            score_ordering_sortranking_no_scale = toporder_divergence(graph, topological_order_sortranking_no_scale)
            sid_sort_ranking_no_scale = SID(g, createFullyConnectedGraph(topological_order_sortranking_no_scale))
            # Differentiable ordering search initialised from the sort-ranking order;
            # a copy is passed so score_matrix itself is not handed over.
            pred_diffintersort, perm_matrix = diffintersort(score_matrix.copy(), d, init_ordering=topological_order_sortranking, scaling=config[d]["scaling"], n_iter=10000, lr=config[d]["lr"], t_sinkhorn = 0.05, n_iter_sinkhorn=500, eps=eps)
            score_ordering_diffintersort = toporder_divergence(graph, pred_diffintersort)
            score_diffintersort = score_ordering(pred_diffintersort)
            sid_diffintersort = SID(g, createFullyConnectedGraph(pred_diffintersort))

            # NOTE(review): full_lower is computed but never used below.
            full_lower = torch.ones(1, int((d - 1) * d / 2))
            full_lower = fill_triangular(full_lower, d, upper=True)

            # Causal discovery initialised from the sort-ranking order.
            causal_disco_graph, graph_pred, p_est = causal_discovery(scaled_data, inter, score_matrix.copy(), eps, config, init_ordering=topological_order_sortranking)
            causal_disco_graph_nx = create_graph(graph_pred)
            # Result unused; presumably kept as a DAG check (networkx raises on cycles).
            topological_order_causal_disco = list(nx.topological_sort(causal_disco_graph_nx))
            score_ordering_causal_disco = toporder_divergence(graph, p_est)
            score_causaldisco = score_ordering(p_est)
            shd_causal_disco = shd(g, graph_pred)
            sid_causal_disco = SID(g, graph_pred)
            metrics = classification_metrics(g, graph_pred)
            precision_causal_disco = metrics["precision"]
            recall_causal_disco = metrics["recall"]

            # "no cons" variant: no initial ordering and lambda_int=0.0.
            causal_disco_graph, graph_pred_no_cons, p_est  = causal_discovery(scaled_data, inter, score_matrix.copy(), eps, config, init_ordering=None, lambda_int=0.0)
            causal_disco_graph_nx = create_graph(graph_pred_no_cons)
            topological_order_causal_disco = list(nx.topological_sort(causal_disco_graph_nx))
            score_ordering_causal_disco_no_cons = toporder_divergence(graph, p_est)
            shd_causal_disco_no_cons = shd(g, graph_pred_no_cons)
            sid_causal_disco_no_cons = SID(g, graph_pred_no_cons)
            metrics = classification_metrics(g, graph_pred_no_cons)
            precision_causal_disco_no_cons = metrics["precision"]
            recall_causal_disco_no_cons = metrics["recall"]
            
            
        # One row per distribution. The "... scale free" columns hold the
        # ``_no_scale`` results; they are unrelated to the ``scale_free`` graph flag.
        result = {
                'Distribution': distribution,
                'Seed': seed,
                'Score sortranking': score_ordering_sortranking,
                'Score sortranking scale free': score_ordering_sortranking_no_scale,
                'Score pagerank ordering': score_ordering_pagerank,
                'Score diffintersort ordering': score_ordering_diffintersort,
                'Score causal disco ordering': score_ordering_causal_disco,
                'Score causal disco ordering no cons': score_ordering_causal_disco_no_cons,
                'SID sortranking': sid_sort_ranking,
                'SID sortranking scale free': sid_sort_ranking_no_scale,
                'SID pagerank ordering': sid_intersort,
                'SID diffintersort': sid_diffintersort,
                'SID causal disco': sid_causal_disco,
                'SID causal disco no cons': sid_causal_disco_no_cons,
                'SHD causal disco': shd_causal_disco,
                'SHD causal disco no cons': shd_causal_disco_no_cons,
                'Precision causal disco': precision_causal_disco,
                'Precision causal disco no cons': precision_causal_disco_no_cons,
                'Recall causal disco': recall_causal_disco,
                'Recall causal disco no cons': recall_causal_disco_no_cons,
                'Score intersort': score_intersort,
                'Score diffintersort': score_diffintersort,
                'Score causal disco': score_causaldisco,
                'Number variables': d,
                'Number edges': np.sum(g),
                'Number observations': n_obs,
                'Number interventions': n_interv,
                'Number intervened': n_intervened
            }

        results_list.append(result)
        # Rewrite the full accumulated results after every distribution (checkpoint).
        results_df = pd.DataFrame(results_list)

        if scale_free:
            filename = 'results_test_const_sf_' + domain + f'_{d}.csv'
        else:
            filename = 'results_test_const_er_' + domain + f'_{d}.csv'
        results_df.to_csv(filename, index=False)

def parse_bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process some integers.")

    # Command line options. ``choices`` rejects values that would otherwise crash
    # later on ``distributions[domain]`` / ``config[d]``.
    parser.add_argument('--d', type=int, default=10, choices=sorted(config),
                        help='An integer for the dimension size (default: 10)')
    parser.add_argument('--seeds', type=int, default=3, help='An integer for the number of seeds (default: 3)')
    parser.add_argument('--domain', type=str, default="lin", choices=list(distributions),
                        help='A string for the domain (default: "lin")')
    # Accepts "--scalefree true/false"; a bare "--scalefree" means true.
    parser.add_argument('--scalefree', type=parse_bool, nargs='?', const=True, default=False,
                        help='Whether to consider scale free graphs (true/false; bare flag means true)')

    args = parser.parse_args()

    main(d=args.d, seeds=args.seeds, domain=args.domain, scale_free=args.scalefree)
