import torch
import pandas as pd
import numpy as np
import torch.nn.parallel
import torch.utils.data

from Datasets import dataset
from castle.metrics import MetricsDAG
from cdt.metrics import SID
from utils import *
from models.diffan.diffan import DiffAN
from dodiscover import make_context
from dodiscover.toporder.score import SCORE
from dodiscover.toporder.cam import CAM
from dodiscover.toporder.das import DAS
from dodiscover.toporder.nogam import NoGAM
from models.caps.caps import train_caps
from models.base import cam_pruning, TabPFN_pruning, xgb_pruning, rf_pruning, mlp_pruning
import networkx as nx


# Widen the console line width used when printing torch tensors and NumPy
# arrays, so wide matrices are printed on fewer, longer lines.
torch.set_printoptions(linewidth=1000)
np.set_printoptions(linewidth=1000)


def blue(x):
    """Return the string ``x`` wrapped in ANSI color codes (94 = bright blue).

    The trailing ``\\033[0m`` resets the terminal color afterwards.
    """
    start, reset = "\033[94m", "\033[0m"
    return start + x + reset


def run_single_run(args, run_idx, simulation_seeds):
    """Load the data for one repetition and run ``train_test`` on it.

    Real-data families (``"sachs"``, ``"syntren"``) are loaded with the fixed
    ``simulation_seed=42``. Synthetic families (names starting with ``"Syn"``)
    use ``simulation_seeds[run_idx]`` and also forward the misspecified-scenario
    options from ``args``. Every family passes ``n=853`` and
    ``runs=run_idx + 1`` to ``dataset.load_data``; the meaning of these values
    is defined by that loader.

    Args:
        args: Experiment namespace holding the dataset/model options.
        run_idx: Zero-based index of this repetition.
        simulation_seeds: Indexable collection of seeds; only read for
            synthetic datasets.

    Returns:
        Whatever ``train_test`` returns for the loaded data (called with
        ``runs=run_idx``).

    Raises:
        Exception: if ``args.dataset`` matches neither dataset family.
    """
    name = args.dataset
    if name in ["sachs", "syntren"]:
        seed = 42
        scenario_kwargs = {}
    elif name.startswith("Syn"):
        seed = simulation_seeds[run_idx]
        # Misspecified scenario parameters (six cases studied in the paper)
        scenario_kwargs = {
            "scenario": args.scenario,
            "rho": args.rho,
            "gamma": args.gamma,
            "p_unfaithful": args.p_unfaithful,
            "exponent": args.exponent,
        }
    else:
        raise Exception("Dataset not recognized.")

    train_set, GT_DAG, data_ls = dataset.load_data(
        name,
        n=853,
        norm=args.norm,
        simulation_seed=seed,
        num_nodes=args.num_nodes,
        num_samples=args.num_samples,
        method=args.method,
        linear_sem_type=args.linear_sem_type,
        nonlinear_sem_type=args.nonlinear_sem_type,
        linear_rate=args.linear_rate,
        runs=run_idx + 1,
        **scenario_kwargs,
    )
    return train_test(args, train_set, GT_DAG, data_ls, runs=run_idx)


def order_divergence(order, adj):
    """Count adjacency entries that contradict a topological ordering.

    For each position ``pos`` in ``order``, this sums ``adj[j, order[pos]]``
    over every node ``j`` listed after ``pos``. Assuming ``adj[u, v]`` marks
    an edge ``u -> v`` (not enforced here; confirm against callers), the
    result is the total weight of edges that point from a later node back to
    an earlier one. A result of 0 means ``order`` is consistent with ``adj``.

    Args:
        order: Sliceable sequence of node indices.
        adj: 2-D array that supports ``adj[index_list, column]`` indexing.

    Returns:
        The accumulated sum; 0 for an empty ``order``.
    """
    return sum(
        adj[order[pos + 1 :], node].sum() for pos, node in enumerate(order)
    )


def evaluate(args, dag, GT_DAG):
    """Score a predicted adjacency matrix against the ground-truth one.

    Computes the gCastle ``MetricsDAG`` metrics plus the CDT ``SID``. A NaN F1
    is replaced by 0 only when precision and recall are both exactly 0;
    otherwise it is kept and a warning is printed. An ``"fnr"`` entry is added
    as ``1 - tpr``, and every metric is cast to ``float`` and rounded to 4
    decimals.

    Args:
        args: Accepted for interface compatibility; not read.
        dag: Predicted adjacency matrix.
        GT_DAG: Ground-truth adjacency matrix.

    Returns:
        The metrics dict produced by ``MetricsDAG`` with ``"sid"`` and
        ``"fnr"`` added.
    """
    metrics = MetricsDAG(dag, GT_DAG).metrics
    metrics["sid"] = float(SID(GT_DAG, dag).item())

    if np.isnan(metrics["F1"]):
        precision = metrics["precision"]
        recall = metrics["recall"]
        if precision == 0.0 and recall == 0.0:
            print("Warning: F1 is NaN (precision=0, recall=0), setting F1=0")
            metrics["F1"] = 0.0
        else:
            print(
                f"Warning: F1 is NaN (precision={precision}, recall={recall}), keeping as NaN"
            )

    # FNR = 1 - Recall(TPR)
    metrics["fnr"] = 1.0 - metrics["tpr"]
    # Overwriting existing keys during iteration is safe: the dict size is unchanged.
    for name, value in metrics.items():
        metrics[name] = round(float(value), 4)
    return metrics


def _get_order_and_graph(model_name, train_set_numpy, args):
    """Gets topological order and initial graph from various models."""
    train_df = pd.DataFrame(train_set_numpy)
    
    if model_name == "random":
        # Random ordering: shuffle node indices randomly
        n, d = train_set_numpy.shape
        order = list(range(d))  # [0, 1, 2, ..., d-1]
        np.random.shuffle(order)  # Randomly shuffle the order
        print(f"Random ordering generated: {order}")
        return order, None
    
    if model_name == "DiffAN":
        n, d = train_set_numpy.shape
        diffan = DiffAN(n_nodes=d, residue=True)
        # Assuming the user's intent was to call the ordering method after training.
        # This part might need adjustment if the original code called `train_score` separately.
        diffan.train_score(torch.FloatTensor(train_set_numpy).to(diffan.device))
        order = diffan.topological_ordering(torch.FloatTensor(train_set_numpy).to(diffan.device))
        return order, None
    
    model_map = {
        "CAM": CAM,
        "DAS": DAS,
        "SCORE": SCORE,
        "NoGAM": NoGAM,
    }

    if model_name in model_map:
        model_class = model_map[model_name]
        model_instance = model_class()
        context = make_context().variables(observed=set(train_df.columns)).build()
        model_instance.learn_graph(train_df, context)
        order = model_instance.order_
        # For DAS, SCORE, NoGAM, the graph is used directly in one of the pruning paths
        if args.dodiscover_cam: 
            graph = model_instance.graph_ if hasattr(model_instance, "graph_") else None
        else:
            graph = None
        return order, graph

    raise ValueError(f"Unknown model for order generation: {model_name}")


def _format_tau_summary(tau_stats, mode_hint=None):
    """Create a compact summary for tau diagnostics."""
    if tau_stats is None:
        tau_stats = {}
    mode = tau_stats.get('mode', mode_hint or 'unknown')
    override_value = tau_stats.get('override_value')
    values = tau_stats.get('values', [])
    finite_values = [float(v) for v in values if np.isfinite(v)]
    if finite_values:
        arr = np.asarray(finite_values, dtype=float)
        tau_min = float(arr.min())
        tau_max = float(arr.max())
        tau_mean = float(arr.mean())
    else:
        tau_min = tau_max = tau_mean = float('nan')

    mode_str = (mode or 'unknown').lower()
    if mode_str == 'manual':
        if override_value is not None and np.isfinite(override_value):
            descriptor = f"manual:{override_value:.6f}"
        else:
            descriptor = "manual"
    elif mode_str == 'mdl':
        descriptor = "MDL"
    else:
        descriptor = mode_str

    return {
        'tau': descriptor,
        'tau_mode': mode_str,
        'tau_min': tau_min,
        'tau_max': tau_max,
        'tau_mean': tau_mean,
    }


def _prune_dag(order, graph, train_set_numpy, args):
    """Applies the selected pruning method to a DAG."""
    
    # If the user wants to test the alternative TabPFN pipeline,
    # we build a full DAG from the order and prune it.
    if args.pruning_method == 'tabpfn':
        init_dag = full_DAG(order)
        dag, tau_stats = TabPFN_pruning(init_dag, train_set_numpy, args)
        return dag, _format_tau_summary(tau_stats)

    elif args.pruning_method == 'xgb':
        init_dag = full_DAG(order)
        dag, _, tau_stats = xgb_pruning(init_dag, train_set_numpy, args)
        return dag, _format_tau_summary(tau_stats)
    
    elif args.pruning_method == 'rf':
        init_dag = full_DAG(order)
        dag, _, tau_stats = rf_pruning(init_dag, train_set_numpy, args)
        return dag, _format_tau_summary(tau_stats)

    elif args.pruning_method == 'mlp':
        init_dag = full_DAG(order)
        dag, _, tau_stats = mlp_pruning(init_dag, train_set_numpy, args)
        return dag, _format_tau_summary(tau_stats)

    # In the default case ('cam'), we handle models differently based on what they return.
    elif args.pruning_method == 'cam':
        # If the model provides its own graph structure (CAM, DAS, SCORE, NoGAM), use it.
        if graph is not None:
            if hasattr(graph, 'nodes'): # Check if it's a networkx graph
                dag = nx.to_numpy_array(
                    graph, nodelist=pd.DataFrame(train_set_numpy).columns
                ).astype(int)
                return dag, _format_tau_summary({'mode': 'cam'})
            else:
                raise TypeError(f"Unsupported graph type from model {args.model}: {type(graph)}")
        
        # If the model only provides an order (like DiffAN), build a full DAG and prune it with CAM.
        else:
            init_dag = full_DAG(order)
            dag, _ = cam_pruning(init_dag, train_set_numpy, 0.001)
            return dag, _format_tau_summary({'mode': 'cam'})

    raise ValueError(f"Unknown pruning method: {args.pruning_method}")

def train_test(args, train_set, GT_DAG, data_ls, runs):
    """Run the configured causal-discovery model on ``train_set`` and score it.

    Args:
        args: Experiment namespace. Reads ``dataset`` and ``model``, plus
            whatever the chosen pipeline reads (e.g. ``pruning_method``).
        train_set: Training data. ``"CaPS"``/``"OURS"`` receive it unchanged.
            For the order-then-prune models its first column is dropped
            (``train_set[:, 1:]``) before fitting.
        GT_DAG: Ground-truth adjacency matrix passed to ``evaluate``.
        data_ls: Accepted for interface compatibility; not read here.
        runs: Accepted for interface compatibility; not read here.

    Returns:
        ``(metrics, order, tau_summary)``, where ``metrics`` is the dict
        returned by ``evaluate``.

    Raises:
        Exception: if ``args.dataset`` is not a recognized dataset family.
        ValueError: if ``args.model`` is not supported.
    """
    # Only the dataset-family check is needed here. No per-dataset value
    # (such as the input width) is used below, so none is computed.
    if not (args.dataset in ["sachs", "syntren"] or args.dataset.startswith("Syn")):
        raise Exception("Dataset not recognized.")

    # CaPS and OURS use the same pipeline but are handled by model name inside
    if args.model in ["CaPS", "OURS"]:
        dag, order, tau_stats = train_caps(train_set, args)
        tau_summary = _format_tau_summary(tau_stats)

    # Other models follow the unified "get order -> finalize graph" pipeline
    elif args.model in ["DiffAN", "CAM", "DAS", "SCORE", "NoGAM", "random"]:
        train_set_numpy = train_set[:, 1:]

        # 1. Get topological order and an initial graph structure (if available)
        order, initial_graph = _get_order_and_graph(args.model, train_set_numpy, args)
        print(blue(f"Model: {args.model}, Order: {order}"))

        # 2. Finalize the DAG based on the order, initial graph, and pruning method
        dag, tau_summary = _prune_dag(order, initial_graph, train_set_numpy, args)

    else:
        raise ValueError(f"Model {args.model} not supported in train_test.")

    return evaluate(args, dag, GT_DAG), order, tau_summary
