from typing import List

import numpy as np
import pandas as pd
import networkx as nx
import random
import os

from baselines.FedDAG import datasets
import torch
from sklearn.preprocessing import MinMaxScaler


from causallearn.search.ConstraintBased.PC import pc
from causallearn.search.FCMBased import lingam
from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.SHD import SHD
from notears.linear import notears_linear
from notears.nonlinear import NotearsMLP, notears_nonlinear
import ges
import tigramite.pcmci
from tigramite.independence_tests.parcorr import ParCorr
from tigramite import data_processing as pp

from sklearn.metrics import f1_score


from src.graphs import *
from ges.scores.gauss_obs_l0_pen import GaussObsL0Pen


# Variable (column) names, grouped by prefix.
# NOTE(review): presumably the light-tunnel variables of the Causal Chambers
# dataset, in column order -- confirm against the data loader.
CAUSALCHAMBER_VARIABLES = [
    'red', 'green', 'blue',
    'current',
    'ir_1', 'ir_2', 'ir_3',
    'vis_1', 'vis_2', 'vis_3',
    'pol_1', 'pol_2',
    'angle_1', 'angle_2',
    'l_11', 'l_12', 'l_21', 'l_22', 'l_31', 'l_32',
]

# Experiment identifiers: the reference run first, then one
# "uniform_<target>_strong" run per target, then one "uniform_<target>_mid"
# run per l_* target. The order of the targets is significant and preserved.
CAUSALCHAMBER_EXPERIMENTS = [
    "uniform_reference",
    *(f"uniform_{target}_strong" for target in (
        "red", "green", "blue",
        "v_c",
        "t_ir_1", "t_ir_2", "t_ir_3",
        "t_vis_1", "t_vis_2", "t_vis_3",
        "pol_1", "pol_2",
        "v_angle_1", "v_angle_2",
    )),
    *(f"uniform_{target}_mid" for target in (
        "l_11", "l_12", "l_21", "l_22", "l_31", "l_32",
    )),
]

""" function retriving scoring function """
def get_scoring_class(scoring_function: str, lmbda: float = None):
    """Return a factory that builds a ``GaussObsL0Pen`` score for a data set.

    The score object is not built here; the returned callable builds it when
    it is given the data.

    Args:
        scoring_function: ``'aic'`` (penalty 1), ``'bic'`` (penalty
            ``0.5 * log(data.shape[0])``, computed from the data handed to
            the factory) or ``'bic_pen'`` (caller-supplied penalty).
        lmbda: Penalty weight. Required for ``'bic_pen'``; ignored by the
            other two options.

    Returns:
        A callable ``data -> GaussObsL0Pen(data, lmbda=..., method='raw')``.

    Raises:
        ValueError: If ``scoring_function`` is unknown, or if it is
            ``'bic_pen'`` and ``lmbda`` is ``None``.
    """
    if scoring_function == 'aic':
        return lambda data: GaussObsL0Pen(data, lmbda=1, method='raw')
    if scoring_function == 'bic':
        return lambda data: GaussObsL0Pen(
            data, lmbda=0.5 * np.log(data.shape[0]), method='raw')
    if scoring_function == 'bic_pen':
        # Compare against None explicitly so that a penalty of 0 is accepted.
        if lmbda is None:
            raise ValueError("lmbda must be provided for bic_pen scoring function")
        return lambda data: GaussObsL0Pen(data, lmbda=lmbda, method='raw')
    raise ValueError(f"Unknown scoring function: {scoring_function}")
    

""" function retriving cd function """
def get_cd_function(cd_function: str, linear: bool = True):
    """Return a one-argument callable that runs the named causal-discovery method.

    Args:
        cd_function: ``'pc'``, ``'lingam'``, ``'ges'``, ``'notears'`` or
            ``'pcmci'``.
        linear: Forwarded to the PC, NOTEARS and PCMCI wrappers. The LiNGAM
            and GES branches do not use it.

    Returns:
        A callable taking ``data`` and returning whatever the selected
        wrapper returns. For ``'ges'`` this is the first element of the
        ``ges.fit_bic`` result.

    Raises:
        ValueError: If ``cd_function`` is not one of the names above.
    """
    if cd_function == 'pc':
        return lambda data: pc_wrapper(data, linear)
    if cd_function == 'lingam':
        return lambda data: lingam_wrapper(data)
    if cd_function == 'ges':
        # Only the forward and backward phases are requested from GES.
        return lambda data: ges.fit_bic(data, phases=['forward', 'backward'])[0]
    if cd_function == 'notears':
        return lambda data: notears_wrapper(data, linear)
    if cd_function == 'pcmci':
        return lambda data: pcmci_wrapper(data, linear)
    raise ValueError(f"Unknown cd function: {cd_function}")
   

""" PCMCI wrapper """
def pcmci_wrapper(data: np.ndarray, linear: bool = True):
    """Run PCMCI with a partial-correlation test and return the learned graph.

    The samples are rescaled with the global minimum and maximum of the whole
    array. The rescaling is done on a float copy, so the caller's array is
    left untouched and integer inputs are accepted.

    Args:
        data: 2-D array. ``len(data)`` is used as the number of time steps
            and ``data.shape[1]`` as the number of variables.
        linear: Accepted for signature parity with the other wrappers. It is
            currently unused: ``ParCorr`` is always the independence test.

    Returns:
        ``results['graph']`` as produced by ``PCMCI.run_pcmci``.
        NOTE(review): this is presumably a lagged graph of shape
        (n_vars, n_vars, tau_max + 1), whereas the other wrappers in this
        module return a binary n x n adjacency matrix -- confirm what the
        callers expect.
    """
    tau_max = 10
    pc_alpha = 1e-2

    # Private float copy: in-place scaling must not leak into the caller's data.
    scaled = np.array(data, dtype=float)
    scaled -= scaled.min()
    peak = scaled.max()
    if peak > 0:
        # Constant input has peak == 0; skip the division rather than emit NaNs.
        scaled /= peak

    df = pp.DataFrame(scaled,
                      datatime={0: np.arange(len(scaled))},
                      var_names=list(range(scaled.shape[1])))
    pcmci = tigramite.pcmci.PCMCI(dataframe=df,
                                  cond_ind_test=ParCorr(significance='analytic'))
    results = pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha)
    return results['graph']


""" NOTEARS wrapper """
def notears_wrapper(data: np.ndarray, linear: bool = True):
    """Run NOTEARS on ``data`` and return a 0/1 adjacency matrix.

    Args:
        data: 2-D sample array; ``data.shape[1]`` is taken as the number of
            variables.
        linear: If true, use ``notears_linear`` with an L2 loss. Otherwise
            fit a ``NotearsMLP`` on min-max scaled data cast to float32.

    Returns:
        Integer array with 1 wherever the estimated matrix is non-zero and
        0 elsewhere.
    """
    if not linear:
        n_vars = data.shape[1]
        mlp = NotearsMLP(dims=[n_vars, n_vars * 2, 1])
        scaled = MinMaxScaler().fit_transform(data)
        weights = notears_nonlinear(mlp, scaled.astype(np.float32),
                                    lambda1=0.0, max_iter=200)
    else:
        weights = notears_linear(data, lambda1=0.0, loss_type='l2', max_iter=1000)
    # Keep only the support of the estimate: any non-zero weight is an edge.
    return np.where(weights != 0, 1, 0)
    

""" LiNGAM wrapper """
def lingam_wrapper(data: np.ndarray):
    """Fit DirectLiNGAM on ``data`` and return a 0/1 adjacency matrix.

    The fitted ``adjacency_matrix_`` is transposed before binarisation.
    NOTE(review): presumably this makes rows index causes and columns
    effects -- confirm against the library's convention.

    Returns:
        Integer array with 1 wherever the transposed matrix is non-zero.
    """
    estimator = lingam.DirectLiNGAM()
    estimator.fit(data)
    transposed = estimator.adjacency_matrix_.T
    return np.where(transposed != 0, 1, 0)

    
""" PC wrapper """
def pc_wrapper(data: np.ndarray, linear: bool = True):
    """Run the PC algorithm (alpha = 0.05) and return its graph as a NumPy array.

    Args:
        data: Sample array handed to ``pc`` unchanged.
        linear: Selects the independence test: ``'fisherz'`` when true,
            ``'kci'`` otherwise.

    Returns:
        ``nx.to_numpy_array`` of the result's ``nx_graph`` attribute.
    """
    ci_test = 'kci' if not linear else 'fisherz'
    causal_graph = pc(data, alpha=0.05, indep_test=ci_test, verbose=False)
    # Called before nx_graph is read; presumably this is what fills it in.
    causal_graph.to_nx_graph()
    return nx.to_numpy_array(causal_graph.nx_graph)


""" load graph """
def load_graph(graph_name: str):
    """Return the ``G`` attribute of the benchmark graph called ``graph_name``.

    Args:
        graph_name: ``'cancer'``, ``'asia'`` or ``'sachs'``.

    Raises:
        ValueError: For any other name.
    """
    if graph_name == 'cancer':
        return Cancer().G
    if graph_name == 'asia':
        return Asia().G
    if graph_name == 'sachs':
        return Sachs().G
    raise ValueError(f"Unknown graph name: {graph_name}")
    
def shd(true_adj: np.ndarray, pred_adj: np.ndarray) -> int:
    """Compute the Structural Hamming Distance (SHD) between two graphs.

    Both arguments are binary (0/1 or bool) adjacency matrices of shape
    (n, n). Self-loops are ignored. Neither input is modified.

    The distance is the sum of two parts:
      * skeleton errors: node pairs adjacent in exactly one of the graphs,
        counted once per pair;
      * orientation errors: for pairs adjacent in both graphs, the number of
        matrix entries that differ. A fully reversed edge therefore adds 2,
        and a directed-versus-undirected mismatch adds 1.

    Returns:
        The distance as a Python ``int``.
    """
    true_adj = np.asarray(true_adj)
    pred_adj = np.asarray(pred_adj)
    # NumPy refuses to subtract boolean arrays, so promote them to integers.
    if true_adj.dtype == bool:
        true_adj = true_adj.astype(int)
    if pred_adj.dtype == bool:
        pred_adj = pred_adj.astype(int)

    # Skeletons (orientation dropped). Self-loops are removed on these derived
    # arrays rather than on the inputs, so the caller's matrices stay intact.
    true_undirected = ((true_adj + true_adj.T) > 0).astype(int)
    pred_undirected = ((pred_adj + pred_adj.T) > 0).astype(int)
    np.fill_diagonal(true_undirected, 0)
    np.fill_diagonal(pred_undirected, 0)

    # Each differing pair shows up twice in the symmetric matrices.
    undirected_diff = np.sum(np.abs(true_undirected - pred_undirected)) // 2

    # Pairs adjacent in both graphs; the diagonal is excluded because both
    # skeleton diagonals are zero.
    common_edges = np.logical_and(true_undirected, pred_undirected)
    orientation_errors = np.sum(np.abs(true_adj[common_edges] - pred_adj[common_edges]))

    return int(undirected_diff + orientation_errors)


def shd_skeleton(true_adj: np.ndarray, pred_adj: np.ndarray) -> int:
    """Compute the Structural Hamming Distance between two graph skeletons.

    Both arguments are binary (0/1 or bool) adjacency matrices of shape
    (n, n). Edge orientation and self-loops are ignored. Neither input is
    modified.

    Returns:
        The number of node pairs adjacent in exactly one of the two graphs,
        as a Python ``int``.
    """
    true_adj = np.asarray(true_adj)
    pred_adj = np.asarray(pred_adj)

    # Skeletons (orientation dropped). Self-loops are removed on these derived
    # arrays rather than on the inputs, so the caller's matrices stay intact.
    true_undirected = ((true_adj + true_adj.T) > 0).astype(int)
    pred_undirected = ((pred_adj + pred_adj.T) > 0).astype(int)
    np.fill_diagonal(true_undirected, 0)
    np.fill_diagonal(pred_undirected, 0)

    # Each differing pair shows up twice in the symmetric matrices.
    undirected_diff = np.sum(np.abs(true_undirected - pred_undirected)) // 2

    return int(undirected_diff)


def f1_skeleton(true_adj: np.ndarray, pred_adj: np.ndarray) -> float:
    """Compute the F1 score between the skeletons of two adjacency matrices.

    Each matrix is symmetrised (an entry is 1 when the pair is connected in
    either direction). Every off-diagonal entry is then treated as one binary
    label, so each node pair contributes twice. The labels are scored with
    ``f1_score`` using its default settings.
    """
    def skeleton(adj):
        # An edge in either direction makes the pair adjacent.
        return ((adj + adj.T) > 0).astype(int)

    true_skeleton = skeleton(true_adj)
    pred_skeleton = skeleton(pred_adj)

    # Self-loops are excluded from the comparison.
    off_diagonal = np.logical_not(np.eye(true_skeleton.shape[0], dtype=bool))
    labels = true_skeleton[off_diagonal].astype(int)
    predictions = pred_skeleton[off_diagonal].astype(int)

    return f1_score(labels, predictions)


def f1_orientation(true_adj: np.ndarray, pred_adj: np.ndarray) -> float:
    """Compute the F1 score between two adjacency matrices, direction included.

    Every off-diagonal entry is cast to ``int`` and treated as one label;
    entry (i, j) and entry (j, i) are scored separately. The labels are
    scored with ``f1_score`` using its default settings.
    """
    n_nodes = true_adj.shape[0]
    # Self-loops are excluded from the comparison.
    off_diagonal = np.logical_not(np.eye(n_nodes, dtype=bool))
    labels = true_adj[off_diagonal].astype(int)
    predictions = pred_adj[off_diagonal].astype(int)

    return f1_score(labels, predictions)


def set_determine(seed: int = 42):
    """Seed ``random``, NumPy and PyTorch with ``seed`` and set ``PYTHONHASHSEED``.

    Args:
        seed: Value passed to each seeding call and written, as a string,
            to the ``PYTHONHASHSEED`` environment variable.
    """
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)


def identity(A: np.ndarray) -> np.ndarray:
    """Return ``A`` itself, unchanged (the same object, not a copy)."""
    return A


def cpdag_to_ucpdag(dag: np.ndarray, interventions: List[int]) -> np.ndarray:
    """Refine the CPDAG of ``dag`` using single-node interventions.

    The starting point is ``ges.utils.dag_to_cpdag(dag)``. For each
    intervention target, a copy of ``dag`` has every entry equal to 1 in the
    target's column cleared, and the CPDAG of that copy is computed. Any edge
    still undirected in the running result (both entries equal to 1) but
    asymmetric in the intervened CPDAG takes over the intervened CPDAG's
    entries.

    Args:
        dag: Square adjacency matrix; ``dag[i, j] == 1`` is read as an edge
            into node ``j``.
        interventions: Indices of the intervened nodes, processed in order.

    Returns:
        The refined matrix. ``dag`` itself is not modified here.
    """
    result = ges.utils.dag_to_cpdag(dag)
    for target in interventions:
        # Cut every incoming edge of the intervened node.
        mutilated = dag.copy()
        mutilated[mutilated[:, target] == 1, target] = 0
        intervened = ges.utils.dag_to_cpdag(mutilated)

        # Both masks are symmetric, so the two entries of a pair are always
        # updated together.
        undirected = (result == 1) & (result.T == 1)
        oriented = intervened != intervened.T
        adopt = undirected & oriented
        result[adopt] = intervened[adopt]
    return result

def union_graph(dag: np.ndarray, graphs: List[np.ndarray]) -> np.ndarray:
    """Merge several adjacency matrices into one, visiting them in order.

    The result starts as zeros with the shape and dtype of ``dag``; ``dag``
    is used for nothing else. Entries of each graph are visited row by row:

      * if the result currently has the pair (i, j) undirected (both entries
        equal to 1) and the graph has it asymmetric, both result entries are
        overwritten with the graph's entries;
      * otherwise only entry (i, j) is raised to the larger of the current
        and the graph's value. Entry (j, i) is handled when the scan reaches
        it.

    The outcome depends on the order of ``graphs`` and on the scan order.
    """
    merged = np.zeros_like(dag)
    for graph in graphs:
        for i, j in np.ndindex(graph.shape[0], graph.shape[1]):
            takes_orientation = (merged[i, j] == 1 and merged[j, i] == 1
                                 and graph[i, j] != graph[j, i])
            if takes_orientation:
                merged[i, j] = graph[i, j]
                merged[j, i] = graph[j, i]
                continue
            merged[i, j] = max(merged[i, j], graph[i, j])
    return merged

def interveined_graph(dag: np.ndarray, interventions: List[int]) -> np.ndarray:
    """Return a copy of ``dag`` with the incoming edges of each intervened node cut.

    Args:
        dag: Adjacency matrix; ``dag[i, j] == 1`` is read as an edge into
            node ``j``. Only entries exactly equal to 1 are cleared.
        interventions: Indices of the intervened nodes.

    Returns:
        A new matrix; ``dag`` is left untouched.
    """
    mutilated = dag.copy()
    for target in interventions:
        # Rows holding a 1 in the target's column are its parents.
        has_edge_into_target = mutilated[:, target] == 1
        mutilated[has_edge_into_target, target] = 0
    return mutilated